You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
190 lines
6.8 KiB
190 lines
6.8 KiB
import asyncio
import shlex
import time
from typing import Optional

import paramiko

from app.core.device.manager import device_manager
from app.schemas.ssh import SSHExecRequest, SSHExecResponse, SSHConnectionInfo
from app.utils.log import get_logger
|
|
|
|
# Module-level logger, created by the project's logging helper with this module's name.
logger = get_logger(__name__)
|
|
|
|
class SSHService:
    """SSH service: run commands and transfer files on managed devices.

    One ``paramiko.SSHClient`` is cached per device id in ``self.connections``
    and reused across calls. paramiko is a blocking library, so every network
    operation is pushed to the event loop's default executor to keep the loop
    responsive while a device is slow or unreachable.
    """

    def __init__(self):
        # Maps device_id -> connected paramiko.SSHClient.
        self.connections = {}

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    async def _run_blocking(func, *args):
        """Run the blocking callable ``func(*args)`` in the default executor.

        Returns whatever ``func`` returns; exceptions raised by ``func``
        propagate to the awaiting coroutine.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func, *args)

    @staticmethod
    def _is_alive(ssh_client) -> bool:
        """Return True when the client still has an active transport."""
        transport = ssh_client.get_transport()
        return transport is not None and transport.is_active()

    @staticmethod
    def _close_quietly(ssh_client) -> None:
        """Best-effort close used on cleanup paths; never raises."""
        try:
            ssh_client.close()
        except Exception as e:  # cleanup must not mask the real error
            logger.debug(f"SSH client close failed during cleanup: {e}")

    @staticmethod
    def _build_command(command: str, working_directory: Optional[str] = None) -> str:
        """Prefix ``command`` with a ``cd`` into ``working_directory``.

        The directory is shell-quoted so that spaces and metacharacters are
        treated as part of the path. A leading ``~`` / ``~/`` is left
        unquoted so the remote shell still expands it to the home directory.
        Other shell expansions in the path (``$VAR``, ``~user``) are NOT
        performed. When no directory is given the command is returned as is.
        """
        if not working_directory:
            return command

        if working_directory == "~" or working_directory.startswith("~/"):
            remainder = working_directory[2:]
            target = "~/" + shlex.quote(remainder) if remainder else "~"
        else:
            target = shlex.quote(working_directory)
        return f"cd {target} && {command}"

    async def _sftp_transfer(self, device_id: str, operation: str,
                             source: str, destination: str) -> None:
        """Run one SFTP ``put`` or ``get`` and always close the SFTP session.

        ``operation`` is the name of the ``SFTPClient`` method to call
        (``"put"`` or ``"get"``); it receives ``source`` and ``destination``.
        """
        ssh_client = await self._get_connection(device_id)

        def _transfer():
            sftp = ssh_client.open_sftp()
            try:
                getattr(sftp, operation)(source, destination)
            finally:
                # Close even when the transfer fails so the channel is not leaked.
                sftp.close()

        await self._run_blocking(_transfer)

    async def _get_connection(self, device_id: str) -> "paramiko.SSHClient":
        """Return a live SSH client for ``device_id``, connecting if needed.

        A cached client is reused only while its transport is active; a stale
        one is discarded and replaced by a fresh connection.

        Raises:
            ValueError: if the device is unknown or the connection fails.
        """
        cached = self.connections.get(device_id)
        if cached is not None:
            if self._is_alive(cached):
                return cached
            # The peer dropped the session; forget it and reconnect below.
            self.connections.pop(device_id, None)
            self._close_quietly(cached)
            logger.warning(f"SSH连接已失效, 重新建立连接: {device_id}")

        # Look up the connection parameters from the device manager.
        device = await device_manager.get_device(device_id)
        if not device:
            raise ValueError(f"设备 {device_id} 不存在")

        connection_info = device.connection_info

        ssh_client = paramiko.SSHClient()
        # NOTE(security): AutoAddPolicy accepts unknown host keys without
        # verification, which permits man-in-the-middle attacks. Kept for
        # backward compatibility; consider loading known_hosts instead.
        ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        try:
            host = connection_info.get('host')
            port = connection_info.get('port', 22)
            connect_kwargs = {
                "hostname": host,
                "port": port,
                "username": connection_info.get('username'),
                "timeout": connection_info.get('timeout', 30),
            }
            # A key file takes precedence over password authentication.
            key_file = connection_info.get('key_file')
            if key_file:
                connect_kwargs["key_filename"] = key_file
            else:
                connect_kwargs["password"] = connection_info.get('password')

            await self._run_blocking(lambda: ssh_client.connect(**connect_kwargs))
        except Exception as e:
            # Release the socket of the half-initialised client.
            self._close_quietly(ssh_client)
            logger.error(f"SSH连接失败: {device_id}, 错误: {e}")
            raise ValueError(f"SSH连接失败: {str(e)}") from e

        # Another coroutine may have connected to the same device while this
        # one was waiting on the executor; keep a single connection per device.
        current = self.connections.get(device_id)
        if current is not None and current is not ssh_client and self._is_alive(current):
            self._close_quietly(ssh_client)
            return current

        self.connections[device_id] = ssh_client
        logger.info(f"SSH连接建立成功: {device_id} -> {host}:{port}")
        return ssh_client

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def exec_command(self, device_id: str, command: str, timeout: int = 30,
                           working_directory: Optional[str] = None) -> "SSHExecResponse":
        """Execute ``command`` on the device and collect its output.

        Args:
            device_id: id of the target device.
            command: shell command to run.
            timeout: channel timeout in seconds passed to paramiko.
            working_directory: optional directory to ``cd`` into first.

        Returns:
            An ``SSHExecResponse``. Failures never raise: they are reported
            with ``success=False``, ``exit_code=-1`` and the error in ``stderr``.
        """
        started = time.monotonic()

        try:
            ssh_client = await self._get_connection(device_id)
            command = self._build_command(command, working_directory)

            def _run():
                _stdin, stdout, stderr = ssh_client.exec_command(command, timeout=timeout)
                out = stdout.read()
                err = stderr.read()
                return out, err, stdout.channel.recv_exit_status()

            raw_out, raw_err, exit_code = await self._run_blocking(_run)
            execution_time = time.monotonic() - started

            logger.info(f"SSH命令执行完成: {device_id} -> {command}, 退出码: {exit_code}")

            return SSHExecResponse(
                success=exit_code == 0,
                stdout=raw_out.decode('utf-8', errors='ignore'),
                stderr=raw_err.decode('utf-8', errors='ignore'),
                exit_code=exit_code,
                execution_time=execution_time,
            )

        except Exception as e:
            execution_time = time.monotonic() - started
            logger.error(f"SSH命令执行失败: {device_id} -> {command}, 错误: {e}")

            return SSHExecResponse(
                success=False,
                stdout="",
                stderr=str(e),
                exit_code=-1,
                execution_time=execution_time,
            )

    async def upload_file(self, device_id: str, local_path: str, remote_path: str) -> dict:
        """Upload ``local_path`` to ``remote_path`` on the device via SFTP.

        Returns a dict with ``success`` and ``message``; on success it also
        echoes ``local_path`` and ``remote_path``. Errors are reported in the
        dict rather than raised.
        """
        try:
            await self._sftp_transfer(device_id, "put", local_path, remote_path)
        except Exception as e:
            logger.error(f"SSH文件上传失败: {device_id}, 错误: {e}")
            return {
                "success": False,
                "message": f"文件上传失败: {str(e)}"
            }

        logger.info(f"SSH文件上传成功: {device_id} -> {local_path} -> {remote_path}")
        return {
            "success": True,
            "message": "文件上传成功",
            "local_path": local_path,
            "remote_path": remote_path
        }

    async def download_file(self, device_id: str, remote_path: str, local_path: str) -> dict:
        """Download ``remote_path`` from the device to ``local_path`` via SFTP.

        Returns a dict with ``success`` and ``message``; on success it also
        echoes ``remote_path`` and ``local_path``. Errors are reported in the
        dict rather than raised.
        """
        try:
            await self._sftp_transfer(device_id, "get", remote_path, local_path)
        except Exception as e:
            logger.error(f"SSH文件下载失败: {device_id}, 错误: {e}")
            return {
                "success": False,
                "message": f"文件下载失败: {str(e)}"
            }

        logger.info(f"SSH文件下载成功: {device_id} -> {remote_path} -> {local_path}")
        return {
            "success": True,
            "message": "文件下载成功",
            "remote_path": remote_path,
            "local_path": local_path
        }

    async def close_connection(self, device_id: str) -> bool:
        """Close and forget the cached connection for ``device_id``.

        Returns True when a connection existed and was closed cleanly, False
        when there was none or closing raised. The cache entry is removed in
        either case so a broken client is never handed out again.
        """
        ssh_client = self.connections.pop(device_id, None)
        if ssh_client is None:
            return False

        try:
            ssh_client.close()
        except Exception as e:
            logger.error(f"关闭SSH连接失败: {device_id}, 错误: {e}")
            return False

        logger.info(f"SSH连接已关闭: {device_id}")
        return True

    async def test_connection(self, device_id: str) -> bool:
        """Check the connection by running ``echo 'test'`` on the device.

        Returns True only when the command output is exactly ``test``; any
        error is logged and reported as False.
        """
        try:
            ssh_client = await self._get_connection(device_id)

            def _probe():
                _stdin, stdout, _stderr = ssh_client.exec_command("echo 'test'", timeout=5)
                return stdout.read().decode('utf-8').strip()

            return await self._run_blocking(_probe) == "test"
        except Exception as e:
            logger.error(f"SSH连接测试失败: {device_id}, 错误: {e}")
            return False
|