# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

# 190 lines
# 6.8 KiB

import asyncio
import time
from typing import Optional
import paramiko
from app.core.device.manager import device_manager
from app.schemas.ssh import SSHExecRequest, SSHExecResponse, SSHConnectionInfo
from app.utils.log import get_logger
# Module-level logger used by SSHService for connection, command and file-transfer events.
logger = get_logger(__name__)
class SSHService:
    """SSH service: run commands and transfer files on managed devices.

    One ``paramiko.SSHClient`` per device id is cached in ``self.connections``
    and reused across calls. paramiko's API is blocking, so every network
    operation (connect, exec, SFTP, close) is run in the event loop's default
    executor instead of directly inside the coroutine, which would otherwise
    stall the whole event loop.
    """

    def __init__(self):
        # device_id -> connected paramiko.SSHClient
        self.connections = {}

    @staticmethod
    async def _run_blocking(func, *args, **kwargs):
        """Run the blocking call ``func(*args, **kwargs)`` in the default executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: func(*args, **kwargs))

    @staticmethod
    def _is_alive(ssh_client: paramiko.SSHClient) -> bool:
        """Return True if the client's underlying transport is still active."""
        transport = ssh_client.get_transport()
        return transport is not None and transport.is_active()

    async def _close_client(self, device_id: str, ssh_client: paramiko.SSHClient) -> bool:
        """Close ``ssh_client`` off the event loop; log and return False on error."""
        try:
            await self._run_blocking(ssh_client.close)
        except Exception as e:
            logger.error(f"关闭SSH连接失败: {device_id}, 错误: {e}")
            return False
        return True

    @staticmethod
    def _run_command(ssh_client: paramiko.SSHClient, command: str, timeout):
        """Blocking helper: run ``command`` and return (stdout, stderr, exit_code).

        Output is decoded as UTF-8 with undecodable bytes dropped.
        """
        _stdin, stdout, stderr = ssh_client.exec_command(command, timeout=timeout)
        stdout_data = stdout.read().decode('utf-8', errors='ignore')
        stderr_data = stderr.read().decode('utf-8', errors='ignore')
        # Fetch the exit status only after both output streams are drained.
        exit_code = stdout.channel.recv_exit_status()
        return stdout_data, stderr_data, exit_code

    @staticmethod
    def _with_sftp(ssh_client: paramiko.SSHClient, action) -> None:
        """Blocking helper: open an SFTP session, run ``action(sftp)``, always close it."""
        sftp = ssh_client.open_sftp()
        try:
            action(sftp)
        finally:
            # Close even when the transfer fails so the SFTP channel is not leaked.
            sftp.close()

    async def _get_connection(self, device_id: str) -> paramiko.SSHClient:
        """Return a connected SSH client for ``device_id``, connecting if needed.

        A cached client is reused only while its transport is still active;
        otherwise it is dropped, closed, and a new connection is made from the
        device's ``connection_info`` (keys read: host, port [default 22],
        username, password, key_file, timeout [default 30]). When ``key_file``
        is set it is used instead of ``password``.

        Raises:
            ValueError: if the device does not exist or connecting fails.
        """
        cached = self.connections.get(device_id)
        if cached is not None:
            if self._is_alive(cached):
                return cached
            # The session died (e.g. the remote side closed it); reconnect
            # instead of handing out a client on which every call would fail.
            self.connections.pop(device_id, None)
            await self._close_client(device_id, cached)

        # Look up connection parameters from the device manager.
        device = await device_manager.get_device(device_id)
        if not device:
            raise ValueError(f"设备 {device_id} 不存在")
        connection_info = device.connection_info

        ssh_client = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy accepts any unknown host key, which
        # permits man-in-the-middle attacks; consider known_hosts + RejectPolicy.
        ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            host = connection_info.get('host')
            port = connection_info.get('port', 22)
            connect_kwargs = {
                'hostname': host,
                'port': port,
                'username': connection_info.get('username'),
                'timeout': connection_info.get('timeout', 30),
            }
            key_file = connection_info.get('key_file')
            if key_file:
                connect_kwargs['key_filename'] = key_file
            else:
                connect_kwargs['password'] = connection_info.get('password')
            await self._run_blocking(ssh_client.connect, **connect_kwargs)
        except Exception as e:
            # A failed connect (e.g. rejected authentication) may leave the
            # transport open; close the client so its socket is not leaked.
            await self._close_client(device_id, ssh_client)
            logger.error(f"SSH连接失败: {device_id}, 错误: {e}")
            raise ValueError(f"SSH连接失败: {str(e)}") from e

        existing = self.connections.get(device_id)
        if existing is not None and self._is_alive(existing):
            # Another coroutine finished connecting to this device while we
            # were awaiting; keep its client and discard ours rather than
            # overwriting (and thereby leaking) one of them.
            await self._close_client(device_id, ssh_client)
            return existing

        self.connections[device_id] = ssh_client
        logger.info(f"SSH连接建立成功: {device_id} -> {host}:{port}")
        return ssh_client

    async def exec_command(self, device_id: str, command: str, timeout: int = 30,
                           working_directory: Optional[str] = None) -> SSHExecResponse:
        """Execute ``command`` on the device and collect its output.

        Args:
            device_id: Device whose SSH connection is used.
            command: Shell command line run on the remote host.
            timeout: Channel timeout in seconds, passed to paramiko's
                ``exec_command``.
            working_directory: If given, the command is prefixed with
                ``cd <working_directory> &&``. The path is inserted verbatim
                (not quoted), so the remote shell interprets it; callers must
                quote paths containing spaces or shell metacharacters.

        Returns:
            SSHExecResponse with decoded stdout/stderr, the exit code and the
            elapsed seconds. Failures are reported rather than raised: they
            yield ``success=False``, ``exit_code=-1`` and the error text in
            ``stderr``.
        """
        start_time = time.monotonic()
        try:
            ssh_client = await self._get_connection(device_id)
            if working_directory:
                command = f"cd {working_directory} && {command}"
            stdout_data, stderr_data, exit_code = await self._run_blocking(
                self._run_command, ssh_client, command, timeout
            )
            execution_time = time.monotonic() - start_time
            logger.info(f"SSH命令执行完成: {device_id} -> {command}, 退出码: {exit_code}")
            return SSHExecResponse(
                success=exit_code == 0,
                stdout=stdout_data,
                stderr=stderr_data,
                exit_code=exit_code,
                execution_time=execution_time
            )
        except Exception as e:
            execution_time = time.monotonic() - start_time
            logger.error(f"SSH命令执行失败: {device_id} -> {command}, 错误: {e}")
            return SSHExecResponse(
                success=False,
                stdout="",
                stderr=str(e),
                exit_code=-1,
                execution_time=execution_time
            )

    async def upload_file(self, device_id: str, local_path: str, remote_path: str) -> dict:
        """Upload ``local_path`` to ``remote_path`` on the device via SFTP.

        Returns:
            ``{"success": True, "message", "local_path", "remote_path"}`` on
            success, or ``{"success": False, "message"}`` on any failure
            (errors are reported, not raised).
        """
        try:
            ssh_client = await self._get_connection(device_id)
            await self._run_blocking(
                self._with_sftp, ssh_client,
                lambda sftp: sftp.put(local_path, remote_path),
            )
            logger.info(f"SSH文件上传成功: {device_id} -> {local_path} -> {remote_path}")
            return {
                "success": True,
                "message": "文件上传成功",
                "local_path": local_path,
                "remote_path": remote_path
            }
        except Exception as e:
            logger.error(f"SSH文件上传失败: {device_id}, 错误: {e}")
            return {
                "success": False,
                "message": f"文件上传失败: {str(e)}"
            }

    async def download_file(self, device_id: str, remote_path: str, local_path: str) -> dict:
        """Download ``remote_path`` from the device to ``local_path`` via SFTP.

        Returns:
            ``{"success": True, "message", "remote_path", "local_path"}`` on
            success, or ``{"success": False, "message"}`` on any failure
            (errors are reported, not raised).
        """
        try:
            ssh_client = await self._get_connection(device_id)
            await self._run_blocking(
                self._with_sftp, ssh_client,
                lambda sftp: sftp.get(remote_path, local_path),
            )
            logger.info(f"SSH文件下载成功: {device_id} -> {remote_path} -> {local_path}")
            return {
                "success": True,
                "message": "文件下载成功",
                "remote_path": remote_path,
                "local_path": local_path
            }
        except Exception as e:
            logger.error(f"SSH文件下载失败: {device_id}, 错误: {e}")
            return {
                "success": False,
                "message": f"文件下载失败: {str(e)}"
            }

    async def close_connection(self, device_id: str) -> bool:
        """Close and forget the cached connection for ``device_id``.

        The entry is removed from the cache before closing, so a failing
        ``close()`` cannot leave a stale client behind.

        Returns:
            True if a cached connection was closed; False if none was cached
            or closing it failed.
        """
        ssh_client = self.connections.pop(device_id, None)
        if ssh_client is None:
            return False
        if not await self._close_client(device_id, ssh_client):
            return False
        logger.info(f"SSH连接已关闭: {device_id}")
        return True

    async def test_connection(self, device_id: str) -> bool:
        """Check the device over SSH by running ``echo 'test'`` (5 s timeout).

        Returns:
            True only if the command's stdout, stripped of surrounding
            whitespace, is exactly ``test``; False on any failure.
        """
        try:
            ssh_client = await self._get_connection(device_id)
            stdout_data, _stderr_data, _exit_code = await self._run_blocking(
                self._run_command, ssh_client, "echo 'test'", 5
            )
            return stdout_data.strip() == "test"
        except Exception as e:
            logger.error(f"SSH连接测试失败: {device_id}, 错误: {e}")
            return False