231 lines
10 KiB
Python
231 lines
10 KiB
Python
import asyncio
|
|
import json
|
|
import uuid
|
|
from channels.generic.websocket import AsyncJsonWebsocketConsumer
|
|
from channels.db import database_sync_to_async
|
|
from django.contrib.auth import get_user_model
|
|
from .models import RemoteHost, CommandLog, BatchScript, CommandTask
|
|
from .services.ssh_client import open_connection, run_command, SSHError
|
|
import shlex
|
|
|
|
# Resolve the project's active user model through Django's get_user_model().
# NOTE(review): `User` is not referenced anywhere in the visible part of this
# module - confirm it is still needed before keeping it.
User = get_user_model()
|
|
|
|
class SSHStreamConsumer(AsyncJsonWebsocketConsumer):
    """WebSocket consumer that runs SSH commands and streams their output.

    Client messages are JSON objects carrying an ``action`` key:

    * ``start`` - run one command (``host_id`` plus ``command`` or ``task_key``).
    * ``start_batch`` - run a stored batch script (``host_id``, ``batch_id``).
    * ``cancel`` - request cancellation of the running command.
    * ``disconnect`` - close the socket.

    Output and status events are published to the channel-layer group
    ``ssh_session_<session_id>`` and relayed to each socket in that group by
    :meth:`ssh_message`. At most one command runs per consumer at a time.
    """

    # Number of trailing output characters kept for the persisted log tail.
    TAIL_LIMIT = 32768
    # Seconds ``disconnect`` waits for a running command to wind down.
    SHUTDOWN_TIMEOUT = 3

    # ------------------------------------------------------------------
    # Channels lifecycle
    # ------------------------------------------------------------------

    async def connect(self):
        """Accept the socket for authenticated users and join the session group."""
        # Initialise run state first so disconnect() is safe even when the
        # socket is rejected below.
        self.run_task = None
        self.conn = None
        self.cancel_event = asyncio.Event()
        self.log_id = None
        if not self.scope['user'].is_authenticated:
            await self.close()
            return
        self.session_id = self.scope['url_route']['kwargs'].get('session_id')
        self.group_name = f"ssh_session_{self.session_id}"
        # NOTE(review): the group is keyed only by the URL's session_id, so any
        # authenticated socket opened with the same id receives this session's
        # output - confirm that sharing is intended.
        await self.channel_layer.group_add(self.group_name, self.channel_name)
        await self.accept()
        await self.send_json({'event': 'connected', 'session': str(self.session_id)})

    async def disconnect(self, close_code):
        """Stop any running command, close the SSH connection, leave the group.

        The running task is asked to cancel cooperatively and given
        ``SHUTDOWN_TIMEOUT`` seconds; ``asyncio.wait_for`` cancels it when the
        timeout expires. Cleanup runs even if the task finished with an
        exception (which is then re-raised after cleanup).
        """
        run_task = getattr(self, 'run_task', None)
        try:
            if run_task and not run_task.done():
                self.cancel_event.set()
                try:
                    await asyncio.wait_for(run_task, timeout=self.SHUTDOWN_TIMEOUT)
                except asyncio.TimeoutError:
                    # wait_for has already cancelled the task; nothing more to do.
                    pass
        finally:
            try:
                self._close_connection()
            finally:
                # group_name is only set once authentication succeeded.
                if hasattr(self, 'group_name'):
                    await self.channel_layer.group_discard(self.group_name, self.channel_name)

    async def receive_json(self, content, **kwargs):
        """Dispatch one client message according to its ``action`` key."""
        # A JSON list/scalar has no .get(); treat it like an unknown action
        # instead of crashing the consumer with AttributeError.
        action = content.get('action') if isinstance(content, dict) else None
        if action == 'start':
            await self._handle_start(content)
        elif action == 'start_batch':
            await self._handle_start_batch(content)
        elif action == 'cancel':
            if self._is_running():
                self.cancel_event.set()
                await self.send_json({'event': 'canceling'})
        elif action == 'disconnect':
            await self.close()
        else:
            await self._send_error('Unknown action')

    async def ssh_message(self, event):
        """Relay a group ``ssh.message`` event's payload to this socket."""
        await self.send_json(event['payload'])

    # ------------------------------------------------------------------
    # Action handlers
    # ------------------------------------------------------------------

    async def _handle_start(self, content):
        """Handle ``start``: resolve the command and host, then launch the run."""
        if self._is_running():
            await self._send_error('Command already running')
            return
        command = content.get('command')
        task_key = content.get('task_key')
        # A stored task is only consulted when no explicit command was sent.
        if task_key and not command:
            task = await self.get_db_task(task_key)
            if not task:
                await self._send_error('Invalid task')
                return
            command = task.command
        if not command:
            await self._send_error('No command provided')
            return
        host = await self.get_host(content.get('host_id'))
        if not host:
            await self._send_error('Host not found')
            return
        log = await self._open_and_log(host, command, 'single')
        if log is None:
            return
        self.run_task = asyncio.create_task(self._run_and_stream(log.id, command))
        await self.send_json({'event': 'started', 'log_id': log.id, 'command': command})

    async def _handle_start_batch(self, content):
        """Handle ``start_batch``: resolve host and batch, then launch the run."""
        if self._is_running():
            await self._send_error('Command already running')
            return
        host = await self.get_host(content.get('host_id'))
        if not host:
            await self._send_error('Host not found')
            return
        batch = await self.get_batch(content.get('batch_id'))
        if not batch:
            await self._send_error('Batch not found')
            return
        log = await self._open_and_log(host, f"BATCH:{batch.name}\n{batch.script}", 'batch')
        if log is None:
            return
        self.run_task = asyncio.create_task(self._run_batch_and_stream(log.id, batch))
        await self.send_json({'event': 'started', 'log_id': log.id, 'command': f'BATCH {batch.name}'})

    async def _open_and_log(self, host, logged_command: str, run_type: str):
        """Open the SSH connection and create the log row for a new run.

        Returns the created log, or ``None`` after reporting an ``ssh`` error
        to the client when the connection cannot be opened. On success the
        per-run state (``conn``, ``log_id``, a fresh ``cancel_event``) is set.
        """
        try:
            conn = await open_connection(host)
        except SSHError as e:
            await self._send_error(str(e), 'ssh')
            return None
        try:
            log = await self.create_log(host, logged_command, run_type=run_type)
        except BaseException:
            # No run task will own this connection, so close it here rather
            # than leak it; the original error still propagates.
            conn.close()
            raise
        self.conn = conn
        self.log_id = log.id
        self.cancel_event = asyncio.Event()
        return log

    # ------------------------------------------------------------------
    # Command execution
    # ------------------------------------------------------------------

    async def _run_and_stream(self, log_id: int, command: str):
        """Run one command, stream its output, and record the outcome.

        Final status is ``ok``, ``failed`` (non-zero exit), ``canceled`` or
        ``error``. Unexpected exceptions and task cancellation are recorded
        as ``error`` / ``canceled`` before being re-raised, so the log never
        reports ``ok`` for a run that did not complete.
        """
        tail = {'data': ''}
        emit = self._chunk_emitter(tail)
        status = 'ok'
        exit_code = None
        try:
            exit_code = await run_command(self.conn, command, emit, self.cancel_event)
            if self.cancel_event.is_set():
                status = 'canceled'
            elif exit_code != 0:
                status = 'failed'
        except SSHError as e:
            status = 'error'
            await self._broadcast({'event': 'error', 'type': 'ssh', 'message': str(e)})
        except asyncio.CancelledError:
            status = 'canceled'
            raise
        except Exception:
            status = 'error'
            raise
        finally:
            await self._finish_run(log_id, status, exit_code, tail['data'])

    async def _run_batch_and_stream(self, log_id: int, batch: BatchScript):
        """Run a batch script step by step, aborting on the first failure.

        Each step is announced with a ``progress`` event and a banner chunk.
        ``cd`` lines are display-only; their effect is folded into the
        commands that follow (see :meth:`_plan_batch_steps`). The 1-based
        index of a failing step is stored with the log.
        """
        tail = {'data': ''}
        emit = self._chunk_emitter(tail)
        steps = self._plan_batch_steps(batch.script)
        total = len(steps)
        status = 'ok'
        overall_exit = 0
        failed_step_index = None
        try:
            for idx, (display_cmd, exec_cmd) in enumerate(steps, start=1):
                if self.cancel_event.is_set():
                    status = 'canceled'
                    break
                await self._broadcast({'event': 'progress', 'current': idx, 'total': total, 'step': display_cmd})
                await emit('stdout', f"\n>>> [{idx}/{total}] {display_cmd}\n")
                if not exec_cmd:
                    # Display-only step (a ``cd`` line).
                    continue
                try:
                    ec = await run_command(self.conn, exec_cmd, emit, self.cancel_event)
                except SSHError as e:
                    status = 'error'
                    failed_step_index = idx
                    await emit('stderr', f"ERROR: {e}\n")
                    break
                # Check cancellation before the exit code: a cancelled step's
                # exit code is not a genuine failure.
                if self.cancel_event.is_set():
                    status = 'canceled'
                    break
                if ec != 0:
                    overall_exit = ec
                    status = 'failed'
                    failed_step_index = idx
                    await emit('stderr', f"Step failed with exit code {ec}, aborting batch.\n")
                    break
        except asyncio.CancelledError:
            status = 'canceled'
            raise
        except Exception:
            status = 'error'
            raise
        finally:
            await self._finish_run(log_id, status, overall_exit, tail['data'], failed_step_index)

    @staticmethod
    def _plan_batch_steps(script: str) -> list:
        """Turn a batch script into ``(display, exec_cmd)`` steps.

        Blank lines and ``#`` comment lines are dropped. A ``cd`` line becomes
        a display-only step (``exec_cmd`` is ``None``) and updates the tracked
        working directory; every other line is prefixed with
        ``cd <dir> &&`` once a directory is being tracked, since each step is
        executed through a separate ``run_command`` call.

        Relative ``cd`` targets are joined onto the tracked directory so that
        ``cd /opt`` followed by ``cd app`` yields ``/opt/app``; a bare ``cd``
        clears the tracked directory.
        """
        steps = []
        current_dir = None
        for raw_line in script.splitlines():
            line = raw_line.strip()
            if not line or line.startswith('#'):
                continue
            lowered = line.lower()
            if lowered == 'cd':
                current_dir = None
                steps.append((line, None))
            elif lowered.startswith('cd '):
                target = line[3:].strip()
                if current_dir and not target.startswith(('/', '~')):
                    current_dir = f"{current_dir.rstrip('/')}/{target}"
                else:
                    current_dir = target
                steps.append((line, None))
            elif current_dir:
                steps.append((line, f"cd {SSHStreamConsumer._quote_dir(current_dir)} && {line}"))
            else:
                steps.append((line, line))
        return steps

    @staticmethod
    def _quote_dir(path: str) -> str:
        """Shell-quote a directory while keeping a leading ``~`` expandable.

        ``shlex.quote`` wraps the whole path in single quotes, which stops the
        shell from expanding ``~``; the ``~`` / ``~/`` prefix is therefore left
        outside the quotes.
        """
        if path == '~':
            return '~'
        if path.startswith('~/'):
            return '~/' + shlex.quote(path[2:])
        return shlex.quote(path)

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    def _is_running(self) -> bool:
        """Return True while a previously started command task is unfinished."""
        task = getattr(self, 'run_task', None)
        return task is not None and not task.done()

    def _close_connection(self):
        """Close the current SSH connection, if any, and forget it."""
        conn, self.conn = getattr(self, 'conn', None), None
        if conn:
            conn.close()

    async def _send_error(self, message: str, error_type: str = 'runtime'):
        """Send an ``error`` event to this socket only."""
        await self.send_json({'event': 'error', 'type': error_type, 'message': message})

    async def _broadcast(self, payload: dict):
        """Publish ``payload`` to every socket in the session group."""
        await self.channel_layer.group_send(self.group_name, {'type': 'ssh.message', 'payload': payload})

    def _chunk_emitter(self, tail: dict):
        """Build the output callback used for a run.

        The returned coroutine function appends each chunk to
        ``tail['data']`` (trimmed to the last ``TAIL_LIMIT`` characters) and
        broadcasts it as a ``chunk`` event.
        """
        async def emit(stream_name, data):
            tail['data'] = (tail['data'] + data)[-self.TAIL_LIMIT:]
            await self._broadcast({'event': 'chunk', 'stream': stream_name, 'data': data})

        return emit

    async def _finish_run(self, log_id: int, status: str, exit_code, tail: str,
                          failed_step: int | None = None):
        """Persist the outcome, announce completion, and close the connection.

        The steps are nested in ``finally`` blocks so that a failure while
        saving the log or broadcasting cannot leave the SSH connection open.
        """
        try:
            await self.update_log(log_id, status, exit_code, tail, failed_step)
        finally:
            try:
                await self._broadcast({'event': 'completed', 'status': status, 'exit_code': exit_code})
            finally:
                self._close_connection()

    # ------------------------------------------------------------------
    # Database access (executed through database_sync_to_async)
    # ------------------------------------------------------------------

    @database_sync_to_async
    def get_host(self, host_id):
        """Return the ``RemoteHost`` with ``host_id``, or ``None`` if absent."""
        # NOTE(review): the lookup is by id only, with no per-user filter -
        # confirm every authenticated user may target every host.
        if not host_id:
            return None
        try:
            return RemoteHost.objects.get(id=host_id)
        except RemoteHost.DoesNotExist:
            return None

    @database_sync_to_async
    def create_log(self, host, command: str, run_type: str):
        """Create a ``CommandLog`` row for a run started by the socket's user."""
        return CommandLog.objects.create(
            host=host,
            command=command,
            created_by=self.scope['user'],
            run_type=run_type,
        )

    @database_sync_to_async
    def update_log(self, log_id: int, status: str, exit_code, tail: str, failed_step: int | None = None):
        """Mark the log ``log_id`` as finished; do nothing if it no longer exists."""
        try:
            log = CommandLog.objects.get(id=log_id)
        except CommandLog.DoesNotExist:  # pragma: no cover
            return
        log.mark_finished(status=status, exit_code=exit_code, tail=tail, failed_step=failed_step)

    @database_sync_to_async
    def get_batch(self, batch_id):
        """Return the ``BatchScript`` with ``batch_id``, or ``None`` if absent."""
        if not batch_id:
            return None
        try:
            return BatchScript.objects.get(id=batch_id)
        except BatchScript.DoesNotExist:
            return None

    @database_sync_to_async
    def get_db_task(self, name: str):
        """Return the ``CommandTask`` named ``name``, or ``None`` if absent."""
        try:
            return CommandTask.objects.get(name=name)
        except CommandTask.DoesNotExist:
            return None
|