import asyncio
import shlex
import time
from typing import Optional

import paramiko

from app.core.device.manager import device_manager
from app.schemas.ssh import SSHExecRequest, SSHExecResponse, SSHConnectionInfo
from app.utils.log import get_logger

logger = get_logger(__name__)


class SSHService:
    """SSH service - runs commands and transfers files on devices over SSH.

    paramiko is a blocking library, so every network call (connect, command
    execution, SFTP transfer, close) is dispatched to the event loop's default
    thread-pool executor instead of running inline in the coroutine; running it
    inline would stall every other task on the event loop until the device
    answered.
    """

    def __init__(self):
        # device_id -> connected paramiko.SSHClient, reused across calls.
        self.connections = {}

    @staticmethod
    async def _run_blocking(func, *args):
        """Run the blocking callable ``func(*args)`` in the default executor and await its result."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func, *args)

    @staticmethod
    def _is_alive(ssh_client: paramiko.SSHClient) -> bool:
        """Return True if the client's underlying SSH transport is still active."""
        transport = ssh_client.get_transport()
        return transport is not None and transport.is_active()

    @staticmethod
    def _quote_path(path: str) -> str:
        """Shell-quote ``path`` for use after ``cd``.

        A leading ``~`` / ``~/`` is left unquoted so the remote shell still
        performs tilde expansion; everything else is quoted so paths with
        spaces or shell metacharacters are passed through literally.
        """
        if path == "~":
            return path
        if path.startswith("~/"):
            return "~/" + shlex.quote(path[2:])
        return shlex.quote(path)

    async def _get_connection(self, device_id: str) -> paramiko.SSHClient:
        """Return a connected SSH client for ``device_id``.

        A cached client is reused while its transport is active; a dead one is
        discarded and a new connection is made from the device's
        ``connection_info`` (keys: host, port, username, password, key_file,
        timeout).

        Raises:
            ValueError: if the device does not exist or the connection fails.
        """
        cached = self.connections.get(device_id)
        if cached is not None:
            if self._is_alive(cached):
                return cached
            # The cached session has dropped (e.g. device reboot or network
            # blip); reconnect instead of failing every subsequent call.
            logger.warning(f"SSH缓存连接已失效, 重新连接: {device_id}")
            self.connections.pop(device_id, None)
            try:
                await self._run_blocking(cached.close)
            except Exception as e:  # best-effort cleanup of an already-dead session
                logger.warning(f"关闭失效SSH连接失败: {device_id}, 错误: {e}")

        # Look up the connection parameters from the device manager.
        device = await device_manager.get_device(device_id)
        if not device:
            raise ValueError(f"设备 {device_id} 不存在")

        connection_info = device.connection_info

        ssh_client = paramiko.SSHClient()
        # NOTE(security): AutoAddPolicy accepts any unknown host key, so the
        # first connection to a host is not protected against MITM. Kept as-is
        # to preserve existing behavior; consider a known_hosts policy.
        ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        try:
            host = connection_info.get('host')
            port = connection_info.get('port', 22)
            connect_kwargs = {
                'hostname': host,
                'port': port,
                'username': connection_info.get('username'),
                'timeout': connection_info.get('timeout', 30),
            }
            # Key-based auth takes precedence; fall back to password auth.
            key_file = connection_info.get('key_file')
            if key_file:
                connect_kwargs['key_filename'] = key_file
            else:
                connect_kwargs['password'] = connection_info.get('password')

            await self._run_blocking(lambda: ssh_client.connect(**connect_kwargs))
        except Exception as e:
            # Release the half-open client (socket / transport thread) before reporting.
            ssh_client.close()
            logger.error(f"SSH连接失败: {device_id}, 错误: {e}")
            raise ValueError(f"SSH连接失败: {str(e)}") from e

        # Another coroutine may have connected to the same device while this
        # one was awaiting; keep the existing client and drop ours so it does
        # not leak.
        existing = self.connections.get(device_id)
        if existing is not None and self._is_alive(existing):
            await self._run_blocking(ssh_client.close)
            return existing

        self.connections[device_id] = ssh_client
        logger.info(f"SSH连接建立成功: {device_id} -> {host}:{port}")
        return ssh_client

    async def exec_command(self, device_id: str, command: str, timeout: int = 30,
                           working_directory: Optional[str] = None) -> SSHExecResponse:
        """Execute ``command`` on the device and collect its output.

        Args:
            device_id: Target device.
            command: Shell command line, executed as-is by the remote shell.
            timeout: Channel timeout in seconds passed to paramiko.
            working_directory: If given, ``cd`` into it before running the command.

        Returns:
            SSHExecResponse. Errors (including connection failures) are not
            raised; they are reported as ``success=False``, ``exit_code=-1``
            with the error text in ``stderr``.
        """
        start_time = time.monotonic()
        try:
            ssh_client = await self._get_connection(device_id)

            # Change to the working directory first, if one was requested.
            if working_directory:
                command = f"cd {self._quote_path(working_directory)} && {command}"

            def _run():
                # Runs in a worker thread: all blocking channel I/O happens here.
                _stdin, stdout, stderr = ssh_client.exec_command(command, timeout=timeout)
                out = stdout.read().decode('utf-8', errors='ignore')
                err = stderr.read().decode('utf-8', errors='ignore')
                return out, err, stdout.channel.recv_exit_status()

            stdout_data, stderr_data, exit_code = await self._run_blocking(_run)

            execution_time = time.monotonic() - start_time
            success = exit_code == 0

            logger.info(f"SSH命令执行完成: {device_id} -> {command}, 退出码: {exit_code}")

            return SSHExecResponse(
                success=success,
                stdout=stdout_data,
                stderr=stderr_data,
                exit_code=exit_code,
                execution_time=execution_time
            )

        except Exception as e:
            execution_time = time.monotonic() - start_time
            logger.error(f"SSH命令执行失败: {device_id} -> {command}, 错误: {e}")

            return SSHExecResponse(
                success=False,
                stdout="",
                stderr=str(e),
                exit_code=-1,
                execution_time=execution_time
            )

    async def upload_file(self, device_id: str, local_path: str, remote_path: str) -> dict:
        """Upload ``local_path`` to ``remote_path`` on the device via SFTP.

        Returns a dict with ``success`` and ``message`` (plus both paths on
        success); errors are reported in the dict rather than raised.
        """
        try:
            ssh_client = await self._get_connection(device_id)

            def _put():
                sftp = ssh_client.open_sftp()
                try:
                    sftp.put(local_path, remote_path)
                finally:
                    # Always release the SFTP channel, even if the transfer fails.
                    sftp.close()

            await self._run_blocking(_put)

            logger.info(f"SSH文件上传成功: {device_id} -> {local_path} -> {remote_path}")

            return {
                "success": True,
                "message": "文件上传成功",
                "local_path": local_path,
                "remote_path": remote_path
            }

        except Exception as e:
            logger.error(f"SSH文件上传失败: {device_id}, 错误: {e}")
            return {
                "success": False,
                "message": f"文件上传失败: {str(e)}"
            }

    async def download_file(self, device_id: str, remote_path: str, local_path: str) -> dict:
        """Download ``remote_path`` from the device to ``local_path`` via SFTP.

        Returns a dict with ``success`` and ``message`` (plus both paths on
        success); errors are reported in the dict rather than raised.
        """
        try:
            ssh_client = await self._get_connection(device_id)

            def _get():
                sftp = ssh_client.open_sftp()
                try:
                    sftp.get(remote_path, local_path)
                finally:
                    # Always release the SFTP channel, even if the transfer fails.
                    sftp.close()

            await self._run_blocking(_get)

            logger.info(f"SSH文件下载成功: {device_id} -> {remote_path} -> {local_path}")

            return {
                "success": True,
                "message": "文件下载成功",
                "remote_path": remote_path,
                "local_path": local_path
            }

        except Exception as e:
            logger.error(f"SSH文件下载失败: {device_id}, 错误: {e}")
            return {
                "success": False,
                "message": f"文件下载失败: {str(e)}"
            }

    async def close_connection(self, device_id: str) -> bool:
        """Close and forget the cached connection for ``device_id``.

        Returns True if a connection was closed cleanly, False if there was no
        cached connection or closing it raised. The cache entry is removed in
        either case so a broken client is never handed out again.
        """
        ssh_client = self.connections.pop(device_id, None)
        if ssh_client is None:
            return False
        try:
            await self._run_blocking(ssh_client.close)
        except Exception as e:
            logger.error(f"关闭SSH连接失败: {device_id}, 错误: {e}")
            return False
        logger.info(f"SSH连接已关闭: {device_id}")
        return True

    async def test_connection(self, device_id: str) -> bool:
        """Check connectivity by running ``echo 'test'`` and comparing the output."""
        try:
            ssh_client = await self._get_connection(device_id)

            def _probe():
                # Runs a trivial command with a short timeout in a worker thread.
                _stdin, stdout, stderr = ssh_client.exec_command("echo 'test'", timeout=5)
                return stdout.read().decode('utf-8', errors='ignore').strip()

            result = await self._run_blocking(_probe)
            return result == "test"

        except Exception as e:
            logger.error(f"SSH连接测试失败: {device_id}, 错误: {e}")
            return False