import asyncio
import json
import logging
import shlex
import uuid

from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncJsonWebsocketConsumer
from django.contrib.auth import get_user_model
from django.core.exceptions import ValidationError

from .models import BatchScript, CommandLog, CommandTask, RemoteHost
from .services.ssh_client import SSHError, open_connection, run_command

logger = logging.getLogger(__name__)

User = get_user_model()

# Characters the shell treats as control/redirection operators. A ``cd`` line
# containing any of them is run as an ordinary command rather than being
# tracked as a directory change for later batch steps.
_SHELL_OPERATOR_CHARS = frozenset('();<>|&')


class SSHStreamConsumer(AsyncJsonWebsocketConsumer):
    """Run remote SSH commands and stream their output over a WebSocket.

    Clients send JSON objects with an ``action`` key:

    * ``start`` -- run one command (``command``, or the stored ``task_key``)
      on the RemoteHost ``host_id``.
    * ``start_batch`` -- run the BatchScript ``batch_id`` line by line on
      ``host_id``, aborting on the first step that fails.
    * ``cancel`` -- request cancellation of the running command/batch.
    * ``disconnect`` -- close the socket.

    Output is fanned out to the ``ssh_session_<id>`` channel-layer group as
    ``chunk`` / ``progress`` / ``error`` / ``completed`` events (delivered to
    sockets by :meth:`ssh_message`). The final status and the last
    ``tail_limit`` characters of output are stored on a CommandLog. At most one
    command runs per consumer at a time.
    """

    # Maximum number of trailing output characters kept for the CommandLog.
    tail_limit = 32768

    async def connect(self):
        """Authenticate the user, join the session group and accept the socket."""
        # Initialise state first so disconnect() is safe even if auth fails.
        self.run_task = None
        self.conn = None
        self.cancel_event = asyncio.Event()
        self.log_id = None
        if not self.scope['user'].is_authenticated:
            await self.close()
            return
        session_id = self.scope['url_route']['kwargs'].get('session_id')
        if session_id in (None, ''):
            # Without an explicit id, use a private one so unrelated clients
            # never share the "ssh_session_None" group (and each other's output).
            session_id = uuid.uuid4().hex
        self.session_id = session_id
        # NOTE(review): any authenticated user who knows a session id joins its
        # group and receives its output -- confirm that is intended.
        self.group_name = f"ssh_session_{self.session_id}"
        await self.channel_layer.group_add(self.group_name, self.channel_name)
        await self.accept()
        await self.send_json({'event': 'connected', 'session': str(self.session_id)})

    async def disconnect(self, close_code):
        """Cancel any running command, close the SSH connection and leave the group."""
        run_task = getattr(self, 'run_task', None)
        if run_task and not run_task.done():
            self.cancel_event.set()
            try:
                await asyncio.wait_for(run_task, timeout=3)
            except asyncio.TimeoutError:
                # wait_for has cancelled the task; its finally block still
                # records the log (as 'canceled') and closes the connection.
                pass
        conn = getattr(self, 'conn', None)
        if conn:
            conn.close()
            self.conn = None
        if hasattr(self, 'group_name'):
            await self.channel_layer.group_discard(self.group_name, self.channel_name)

    async def receive_json(self, content, **kwargs):
        """Dispatch one client message according to its ``action`` key."""
        # Non-object JSON (list, string, number) carries no action.
        action = content.get('action') if isinstance(content, dict) else None
        if action == 'start':
            await self._handle_start(content)
        elif action == 'start_batch':
            await self._handle_start_batch(content)
        elif action == 'cancel':
            if self.run_task and not self.run_task.done():
                self.cancel_event.set()
                await self.send_json({'event': 'canceling'})
        elif action == 'disconnect':
            await self.close()
        else:
            await self._send_runtime_error('Unknown action')

    async def _send_runtime_error(self, message: str) -> None:
        """Send a ``runtime`` error event to this socket only."""
        await self.send_json({'event': 'error', 'type': 'runtime', 'message': message})

    async def _reject_if_busy(self) -> bool:
        """Report an error and return True when a command is already running."""
        if self.run_task and not self.run_task.done():
            await self._send_runtime_error('Command already running')
            return True
        return False

    async def _handle_start(self, content: dict) -> None:
        """Validate a ``start`` request and launch a single command."""
        if await self._reject_if_busy():
            return
        host_id = content.get('host_id')
        command = content.get('command')
        task_key = content.get('task_key')
        if task_key and not command:
            task = await self.get_db_task(task_key)
            if not task:
                await self._send_runtime_error('Invalid task')
                return
            command = task.command
        if not command:
            await self._send_runtime_error('No command provided')
            return
        host = await self.get_host(host_id)
        if not host:
            await self._send_runtime_error('Host not found')
            return
        await self._launch(
            host,
            log_command=command,
            run_type='single',
            runner=lambda log_id: self._run_and_stream(log_id, command),
            started_label=command,
        )

    async def _handle_start_batch(self, content: dict) -> None:
        """Validate a ``start_batch`` request and launch the batch script."""
        if await self._reject_if_busy():
            return
        host = await self.get_host(content.get('host_id'))
        if not host:
            await self._send_runtime_error('Host not found')
            return
        batch = await self.get_batch(content.get('batch_id'))
        if not batch:
            await self._send_runtime_error('Batch not found')
            return
        await self._launch(
            host,
            log_command=f"BATCH:{batch.name}\n{batch.script}",
            run_type='batch',
            runner=lambda log_id: self._run_batch_and_stream(log_id, batch),
            started_label=f'BATCH {batch.name}',
        )

    async def _launch(self, host, *, log_command: str, run_type: str, runner, started_label: str) -> None:
        """Open an SSH connection, create the CommandLog and start *runner* in the background.

        *runner* is called with the new log id and must return the coroutine
        that performs the streaming. An SSHError while connecting is reported
        to the client as an ``ssh`` error and nothing is started.
        """
        try:
            conn = await open_connection(host)
        except SSHError as e:
            await self.send_json({'event': 'error', 'type': 'ssh', 'message': str(e)})
            return
        try:
            log = await self.create_log(host, log_command, run_type=run_type)
        except BaseException:
            # Don't leak the freshly opened connection if the log can't be written.
            conn.close()
            raise
        self.conn = conn
        self.log_id = log.id
        self.cancel_event = asyncio.Event()
        self.run_task = asyncio.create_task(runner(log.id))
        await self.send_json({'event': 'started', 'log_id': log.id, 'command': started_label})

    async def _broadcast(self, payload: dict) -> None:
        """Send *payload* to every socket in this session's group (see ssh_message)."""
        await self.channel_layer.group_send(self.group_name, {'type': 'ssh.message', 'payload': payload})

    def _release_connection(self, conn) -> None:
        """Close *conn* and forget it if it is still this consumer's current connection."""
        if conn:
            conn.close()
        if self.conn is conn:
            self.conn = None

    async def _run_and_stream(self, log_id: int, command: str):
        """Run *command* on the current connection, streaming its output to the group.

        The final status is ``ok``, ``failed`` (non-zero exit), ``canceled`` or
        ``error``. It is always written to the log together with the output
        tail, after which the connection is closed.
        """
        conn = self.conn
        tail = ''

        async def on_chunk(stream_name, data):
            nonlocal tail
            tail = (tail + data)[-self.tail_limit:]
            await self._broadcast({'event': 'chunk', 'stream': stream_name, 'data': data})

        status = 'ok'
        exit_code = None
        try:
            exit_code = await run_command(conn, command, on_chunk, self.cancel_event)
            if self.cancel_event.is_set():
                status = 'canceled'
            elif exit_code != 0:
                status = 'failed'
        except SSHError as e:
            status = 'error'
            await self._broadcast({'event': 'error', 'type': 'ssh', 'message': str(e)})
        except asyncio.CancelledError:
            # Task cancelled (e.g. disconnect() timed out): never record it as 'ok'.
            status = 'canceled'
            raise
        except Exception:
            status = 'error'
            logger.exception("Unexpected error while running command for log %s", log_id)
            await self._broadcast({'event': 'error', 'type': 'runtime', 'message': 'Internal error'})
        finally:
            try:
                await self.update_log(log_id, status, exit_code, tail)
                await self._broadcast({'event': 'completed', 'status': status, 'exit_code': exit_code})
            finally:
                self._release_connection(conn)

    @staticmethod
    def _cd_target(line: str) -> str | None:
        """Return the shell text to ``cd`` into if *line* is a plain ``cd`` command.

        Returns None for every other line, including ``cd`` combined with shell
        operators (``cd /app && make``), which must run as a normal command.
        Returns '' for a bare ``cd``. The target is kept verbatim so the
        script's own quoting, ``~`` and ``$VAR`` expansion behave as in a shell;
        an unquoted target containing spaces is quoted as one literal path.
        """
        # Case-insensitive "cd" followed by whitespace or end of line.
        if line[:2].lower() != 'cd' or (len(line) > 2 and not line[2].isspace()):
            return None
        lexer = shlex.shlex(line, posix=True, punctuation_chars=True)
        lexer.whitespace_split = True
        lexer.commenters = ''
        try:
            words = list(lexer)
        except ValueError:
            # Unbalanced quotes: run the line as-is and let the shell report it.
            return None
        if any(w and set(w) <= _SHELL_OPERATOR_CHARS for w in words):
            return None
        target = line[2:].strip()
        if len(words) > 2:
            return shlex.quote(target)
        return target

    @classmethod
    def _plan_batch_steps(cls, script: str) -> list[tuple[str, str | None]]:
        """Split a batch script into ``(display_text, command_to_run)`` steps.

        Blank lines and ``#`` comment lines are skipped. Plain ``cd`` lines
        become steps whose command is None. Like in a shell, they accumulate
        into a ``cd ... && cd ... &&`` prefix that is prepended to every later
        step, so relative directory changes compose.
        """
        steps: list[tuple[str, str | None]] = []
        cd_chain: list[str] = []
        for raw_line in script.splitlines():
            line = raw_line.strip()
            if not line or line.startswith('#'):
                continue
            target = cls._cd_target(line)
            if target is None:
                steps.append((line, ' && '.join([*cd_chain, line])))
                continue
            if not target or target == '~' or target.startswith(('/', '~/')):
                # Absolute or home targets make earlier directory changes irrelevant.
                cd_chain.clear()
            cd_chain.append(f"cd {target}" if target else "cd")
            steps.append((line, None))
        return steps

    async def _run_batch_and_stream(self, log_id: int, batch: BatchScript):
        """Run *batch* step by step, streaming output and progress to the group.

        Each step is announced with a ``progress`` event and a ``>>>`` header
        line. The batch stops at the first step that raises SSHError, returns
        a non-zero exit code or is canceled. The status, last exit code, output
        tail and failing step index are always written to the log, after which
        the connection is closed.
        """
        conn = self.conn
        steps = self._plan_batch_steps(batch.script)
        total = len(steps)
        tail = ''

        async def emit(stream, data):
            nonlocal tail
            tail = (tail + data)[-self.tail_limit:]
            await self._broadcast({'event': 'chunk', 'stream': stream, 'data': data})

        status = 'ok'
        overall_exit = 0
        failed_step_index = None
        current_step = None
        try:
            for idx, (display_cmd, exec_cmd) in enumerate(steps, start=1):
                current_step = idx
                if self.cancel_event.is_set():
                    status = 'canceled'
                    break
                await self._broadcast(
                    {'event': 'progress', 'current': idx, 'total': total, 'step': display_cmd}
                )
                await emit('stdout', f"\n>>> [{idx}/{total}] {display_cmd}\n")
                if exec_cmd is None:
                    # Directory change only; it is applied to later steps.
                    continue
                try:
                    ec = await run_command(conn, exec_cmd, emit, self.cancel_event)
                except SSHError as e:
                    await emit('stderr', f"ERROR: {e}\n")
                    status = 'error'
                    failed_step_index = idx
                    break
                if self.cancel_event.is_set():
                    status = 'canceled'
                    break
                if ec != 0:
                    overall_exit = ec
                    status = 'failed'
                    failed_step_index = idx
                    await emit('stderr', f"Step failed with exit code {ec}, aborting batch.\n")
                    break
        except asyncio.CancelledError:
            # Task cancelled (e.g. disconnect() timed out): never record it as 'ok'.
            status = 'canceled'
            raise
        except Exception:
            status = 'error'
            failed_step_index = current_step
            logger.exception("Unexpected error while running batch for log %s", log_id)
            await self._broadcast({'event': 'error', 'type': 'runtime', 'message': 'Internal error'})
        finally:
            try:
                await self.update_log(log_id, status, overall_exit, tail, failed_step_index)
                await self._broadcast({'event': 'completed', 'status': status, 'exit_code': overall_exit})
            finally:
                self._release_connection(conn)

    async def ssh_message(self, event):
        """Forward an ``ssh.message`` group event's payload to this socket."""
        await self.send_json(event['payload'])

    @database_sync_to_async
    def get_host(self, host_id):
        """Return the RemoteHost with *host_id*, or None if it is missing or malformed."""
        # NOTE(review): no per-user permission check on the host -- confirm intended.
        if not host_id:
            return None
        try:
            return RemoteHost.objects.get(id=host_id)
        except (RemoteHost.DoesNotExist, ValueError, TypeError, ValidationError):
            # Malformed ids (e.g. 'abc' for an integer or UUID pk) raise these
            # instead of DoesNotExist; treat them as "not found".
            return None

    @database_sync_to_async
    def create_log(self, host, command: str, run_type: str):
        """Create and return a CommandLog for *command* owned by the connected user."""
        return CommandLog.objects.create(
            host=host, command=command, created_by=self.scope['user'], run_type=run_type
        )

    @database_sync_to_async
    def update_log(self, log_id: int, status: str, exit_code, tail: str, failed_step: int | None = None):
        """Record the final outcome on CommandLog *log_id*; a missing log is ignored."""
        try:
            log = CommandLog.objects.get(id=log_id)
        except CommandLog.DoesNotExist:  # pragma: no cover
            return
        log.mark_finished(status=status, exit_code=exit_code, tail=tail, failed_step=failed_step)

    @database_sync_to_async
    def get_batch(self, batch_id):
        """Return the BatchScript with *batch_id*, or None if it is missing or malformed."""
        if not batch_id:
            return None
        try:
            return BatchScript.objects.get(id=batch_id)
        except (BatchScript.DoesNotExist, ValueError, TypeError, ValidationError):
            return None

    @database_sync_to_async
    def get_db_task(self, name: str):
        """Return the CommandTask named *name*, or None if there is none."""
        try:
            return CommandTask.objects.get(name=name)
        except CommandTask.DoesNotExist:
            return None