2025-10-15 17:41:05 +02:00

231 lines
10 KiB
Python

import asyncio
import json
import logging
import posixpath
import shlex
import uuid

from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncJsonWebsocketConsumer
from django.contrib.auth import get_user_model

from .models import RemoteHost, CommandLog, BatchScript, CommandTask
from .services.ssh_client import open_connection, run_command, SSHError
# Project's configured user model class.
# NOTE(review): `User` is not referenced by the code visible in this module —
# confirm it is still needed before relying on or removing it.
User = get_user_model()
class SSHStreamConsumer(AsyncJsonWebsocketConsumer):
    """Run commands on a ``RemoteHost`` over SSH and stream their output.

    Clients send JSON objects whose ``action`` is ``start`` (one command, given
    inline as ``command`` or looked up via ``task_key``), ``start_batch`` (a
    ``BatchScript`` by ``batch_id``), ``cancel`` or ``disconnect``. Output and
    lifecycle events (``chunk``, ``progress``, ``completed``, ``error``) are
    broadcast to the channel group ``ssh_session_<session_id>``, so every socket
    joined to the same session receives them. At most one command or batch runs
    per consumer at a time.
    """

    # Maximum number of trailing output characters persisted on the CommandLog.
    TAIL_CHARS = 32768

    _logger = logging.getLogger(__name__)

    async def connect(self):
        """Accept authenticated clients and join the session's broadcast group."""
        # Initialise state first so disconnect() is safe even when auth fails.
        self.run_task = None
        self.conn = None
        self.cancel_event = asyncio.Event()
        self.log_id = None
        if not self.scope['user'].is_authenticated:
            await self.close()
            return
        self.session_id = self.scope['url_route']['kwargs'].get('session_id')
        self.group_name = f"ssh_session_{self.session_id}"
        await self.channel_layer.group_add(self.group_name, self.channel_name)
        await self.accept()
        await self.send_json({'event': 'connected', 'session': str(self.session_id)})

    async def disconnect(self, close_code):  # noqa: D401
        """Stop any running command, close the SSH connection and leave the group.

        The connection close and ``group_discard`` run in ``finally`` so they
        still happen if waiting for the run task fails.
        """
        try:
            run_task = getattr(self, 'run_task', None)
            if run_task is not None and not run_task.done():
                self.cancel_event.set()
                try:
                    # wait_for cancels the task if it is still running after 3 s.
                    await asyncio.wait_for(run_task, timeout=3)
                except asyncio.TimeoutError:
                    pass
                except Exception:
                    # The socket is going away regardless; record the failure
                    # but don't let it abort the cleanup below.
                    self._logger.exception('SSH run task failed while disconnecting')
        finally:
            conn = getattr(self, 'conn', None)
            if conn:
                conn.close()
            if hasattr(self, 'group_name'):
                await self.channel_layer.group_discard(self.group_name, self.channel_name)

    async def receive_json(self, content, **kwargs):  # noqa: D401
        """Dispatch a client message on its ``action`` field.

        Anything other than ``start``, ``start_batch``, ``cancel`` or
        ``disconnect`` — including a payload that is not a JSON object — is
        answered with an ``Unknown action`` error.
        """
        # A JSON array/scalar has no .get(); treat it as an unknown action
        # instead of letting AttributeError kill the consumer.
        action = content.get('action') if isinstance(content, dict) else None
        if action == 'start':
            await self._start_single(content)
        elif action == 'start_batch':
            await self._start_batch(content)
        elif action == 'cancel':
            if self._is_running():
                self.cancel_event.set()
                await self.send_json({'event': 'canceling'})
        elif action == 'disconnect':
            await self.close()
        else:
            await self._send_error('runtime', 'Unknown action')

    def _is_running(self) -> bool:
        """Return True while a command/batch task is still in progress."""
        return self.run_task is not None and not self.run_task.done()

    async def _send_error(self, error_type: str, message: str):
        """Send an ``error`` event to this client only."""
        await self.send_json({'event': 'error', 'type': error_type, 'message': message})

    async def _broadcast(self, payload: dict):
        """Send *payload* to every socket in this session's group via ``ssh_message``."""
        await self.channel_layer.group_send(self.group_name, {'type': 'ssh.message', 'payload': payload})

    async def _start_single(self, content: dict):
        """Handle ``start``: resolve the command (inline or via ``task_key``) and launch it."""
        if self._is_running():
            await self._send_error('runtime', 'Command already running')
            return
        command = content.get('command')
        task_key = content.get('task_key')
        if task_key and not command:
            task = await self.get_db_task(task_key)
            if not task:
                await self._send_error('runtime', 'Invalid task')
                return
            command = task.command
        if not command:
            await self._send_error('runtime', 'No command provided')
            return
        host = await self.get_host(content.get('host_id'))
        if not host:
            await self._send_error('runtime', 'Host not found')
            return
        await self._launch(
            host,
            command,
            'single',
            lambda log_id: self._run_and_stream(log_id, command),
            command,
        )

    async def _start_batch(self, content: dict):
        """Handle ``start_batch``: look up the host and batch script, then launch it."""
        if self._is_running():
            await self._send_error('runtime', 'Command already running')
            return
        host = await self.get_host(content.get('host_id'))
        if not host:
            await self._send_error('runtime', 'Host not found')
            return
        batch = await self.get_batch(content.get('batch_id'))
        if not batch:
            await self._send_error('runtime', 'Batch not found')
            return
        await self._launch(
            host,
            f"BATCH:{batch.name}\n{batch.script}",
            'batch',
            lambda log_id: self._run_batch_and_stream(log_id, batch),
            f'BATCH {batch.name}',
        )

    async def _launch(self, host, log_command: str, run_type: str, make_runner, label: str):
        """Open SSH to *host*, create the CommandLog and start the background run task.

        *make_runner* receives the new log id and returns the coroutine to run.
        On SSHError an ``ssh`` error event is sent and nothing is started.
        """
        # NOTE(review): any authenticated user may target any host_id/batch_id;
        # no per-user permission check is visible here — confirm this is intended.
        try:
            self.conn = await open_connection(host)
        except SSHError as e:
            await self._send_error('ssh', str(e))
            return
        try:
            log = await self.create_log(host, log_command, run_type=run_type)
        except Exception:
            # Don't leak the freshly opened connection if the log row can't be written.
            self.conn.close()
            raise
        self.log_id = log.id
        self.cancel_event = asyncio.Event()
        self.run_task = asyncio.create_task(make_runner(log.id))
        await self.send_json({'event': 'started', 'log_id': log.id, 'command': label})

    async def _finish_run(self, conn, log_id: int, status: str, exit_code, tail: str, failed_step: int | None = None):
        """Persist the run outcome, broadcast ``completed`` and close *conn*.

        ``completed`` is broadcast even if saving the log fails, and the
        connection is closed even if either of those fails.
        """
        try:
            try:
                await self.update_log(log_id, status, exit_code, tail, failed_step)
            finally:
                await self._broadcast({'event': 'completed', 'status': status, 'exit_code': exit_code})
        finally:
            if conn:
                conn.close()

    async def _run_and_stream(self, log_id: int, command: str):
        """Run *command* on the open connection and broadcast its output.

        Final status is 'ok', 'failed' (non-zero exit), 'canceled', or 'error'
        (SSHError or any other unexpected exception). The status, exit code and
        the last ``TAIL_CHARS`` characters of output are written to the log via
        ``_finish_run``.
        """
        conn = self.conn
        tail = ''

        async def on_chunk(stream_name, data):
            nonlocal tail
            tail = (tail + data)[-self.TAIL_CHARS:]
            await self._broadcast({'event': 'chunk', 'stream': stream_name, 'data': data})

        status = 'ok'
        exit_code = None
        try:
            exit_code = await run_command(conn, command, on_chunk, self.cancel_event)
            if self.cancel_event.is_set():
                status = 'canceled'
            elif exit_code != 0:
                status = 'failed'
        except SSHError as e:
            status = 'error'
            await self._broadcast({'event': 'error', 'type': 'ssh', 'message': str(e)})
        except asyncio.CancelledError:
            # e.g. disconnect()'s wait_for timed out; record it, keep cancelling.
            status = 'canceled'
            raise
        except Exception:
            # Never record an unexpected failure as 'ok'.
            status = 'error'
            self._logger.exception('Command log %s failed unexpectedly', log_id)
        finally:
            await self._finish_run(conn, log_id, status, exit_code, tail)

    async def _run_batch_and_stream(self, log_id: int, batch: BatchScript):
        """Run *batch*'s script step by step, broadcasting progress and output.

        Each non-``cd`` line runs as its own remote command. The batch stops at
        the first step that raises SSHError ('error'), exits non-zero
        ('failed') or is canceled ('canceled'). The log records the status, the
        failing step's exit code (0 otherwise), its 1-based index and the
        output tail via ``_finish_run``.
        """
        conn = self.conn
        tail = ''

        async def emit(stream, data):
            nonlocal tail
            tail = (tail + data)[-self.TAIL_CHARS:]
            await self._broadcast({'event': 'chunk', 'stream': stream, 'data': data})

        status = 'ok'
        overall_exit = 0
        failed_step_index = None
        try:
            steps = self._parse_batch_steps(batch.script)
            total = len(steps)
            for idx, (display_cmd, exec_cmd) in enumerate(steps, start=1):
                if self.cancel_event.is_set():
                    status = 'canceled'
                    break
                await self._broadcast({'event': 'progress', 'current': idx, 'total': total, 'step': display_cmd})
                await emit('stdout', f"\n>>> [{idx}/{total}] {display_cmd}\n")
                if not exec_cmd:
                    continue  # `cd` steps only set the directory for later steps
                try:
                    ec = await run_command(conn, exec_cmd, emit, self.cancel_event)
                except SSHError as e:
                    await emit('stderr', f"ERROR: {e}\n")
                    status = 'error'
                    failed_step_index = idx
                    break
                if self.cancel_event.is_set():
                    status = 'canceled'
                    break
                if ec != 0:
                    overall_exit = ec
                    status = 'failed'
                    failed_step_index = idx
                    await emit('stderr', f"Step failed with exit code {ec}, aborting batch.\n")
                    break
        except asyncio.CancelledError:
            status = 'canceled'
            raise
        except Exception:
            # Never record an unexpected failure as 'ok'.
            status = 'error'
            self._logger.exception('Batch log %s failed unexpectedly', log_id)
        finally:
            await self._finish_run(conn, log_id, status, overall_exit, tail, failed_step_index)

    @classmethod
    def _parse_batch_steps(cls, script: str) -> list[tuple[str, str | None]]:
        """Split a batch script into ``(display_text, shell_command_or_None)`` steps.

        Blank lines and lines starting with ``#`` are dropped. Every step runs
        in a fresh remote shell, so a ``cd <dir>`` line is not executed itself;
        it sets the directory prefixed (``cd <dir> && ...``) to later commands.
        Relative targets resolve against the previous ``cd``.
        """
        steps = []
        current_dir = None
        for raw in script.splitlines():
            line = raw.strip()
            if not line or line.startswith('#'):
                continue
            if line.lower().startswith('cd '):
                current_dir = cls._resolve_cd(current_dir, cls._cd_target(line))
                steps.append((line, None))
            elif current_dir:
                steps.append((line, f"cd {cls._quote_dir(current_dir)} && {line}"))
            else:
                steps.append((line, line))
        return steps

    @staticmethod
    def _cd_target(line: str) -> str:
        """Return the argument of a ``cd`` line, removing shell quoting when it is a single word."""
        rest = line[3:].strip()
        try:
            parts = shlex.split(rest)
        except ValueError:  # unbalanced quotes: keep the raw text
            return rest
        return parts[0] if len(parts) == 1 else rest

    @staticmethod
    def _resolve_cd(current_dir: str | None, target: str) -> str:
        """Resolve *target* against *current_dir* the way successive ``cd`` calls would."""
        if not current_dir or target.startswith(('/', '~')):
            return target
        return posixpath.join(current_dir, target)

    @staticmethod
    def _quote_dir(directory: str) -> str:
        """Shell-quote *directory*, leaving a leading ``~``/``~/`` bare so the shell expands it."""
        if directory == '~':
            return '~'
        if directory.startswith('~/'):
            rest = directory[2:]
            return '~/' + shlex.quote(rest) if rest else '~/'
        return shlex.quote(directory)

    async def ssh_message(self, event):
        """Forward a group ``ssh.message`` event's payload to this client."""
        await self.send_json(event['payload'])

    @database_sync_to_async
    def get_host(self, host_id):
        """Return the RemoteHost with id *host_id*, or None if absent or malformed."""
        if not host_id:
            return None
        try:
            return RemoteHost.objects.get(id=host_id)
        except (RemoteHost.DoesNotExist, ValueError, TypeError):
            # Django raises ValueError/TypeError when a client-supplied id can't
            # be converted to the pk type (e.g. "abc"); treat it as not found.
            return None

    @database_sync_to_async
    def create_log(self, host, command: str, run_type: str):
        """Create a CommandLog for *command* on *host*, attributed to the connected user."""
        return CommandLog.objects.create(host=host, command=command, created_by=self.scope['user'], run_type=run_type)

    @database_sync_to_async
    def update_log(self, log_id: int, status: str, exit_code, tail: str, failed_step: int | None = None):
        """Record the final outcome on CommandLog *log_id* via ``mark_finished``; no-op if it is gone."""
        try:
            log = CommandLog.objects.get(id=log_id)
        except CommandLog.DoesNotExist:  # pragma: no cover
            return
        log.mark_finished(status=status, exit_code=exit_code, tail=tail, failed_step=failed_step)

    @database_sync_to_async
    def get_batch(self, batch_id):
        """Return the BatchScript with id *batch_id*, or None if absent or malformed."""
        if not batch_id:
            return None
        try:
            return BatchScript.objects.get(id=batch_id)
        except (BatchScript.DoesNotExist, ValueError, TypeError):
            # Malformed client-supplied ids read as "not found" instead of crashing.
            return None

    @database_sync_to_async
    def get_db_task(self, name: str):
        """Return the CommandTask named *name*, or None if there is none."""
        try:
            return CommandTask.objects.get(name=name)
        except CommandTask.DoesNotExist:
            return None