You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
72 lines
2.6 KiB
72 lines
2.6 KiB
1 month ago
|
from fastapi import APIRouter, HTTPException
|
||
|
from pydantic import BaseModel, validator
|
||
|
from app.services.document_service import DocumentService
|
||
|
from app.services.vector_store_service import VectorStoreService
|
||
|
from app.chains.rag_chain import RAGChain
|
||
|
from app.core.logger import logger
|
||
|
import os
|
||
|
from pathlib import Path
|
||
|
|
||
|
# All endpoints in this module are registered on this router: paths are
# prefixed with "/rag" and grouped under the "RAG" tag in the OpenAPI schema.
router = APIRouter(prefix="/rag", tags=["RAG"])

# Initialise services.
# These module-level instances are created once, when this module is imported,
# and are shared by every endpoint below (/upload, /query, /clear).
document_service = DocumentService()
vector_store_service = VectorStoreService()
# The RAG chain is built on the same vector store instance that /upload adds
# documents to and /clear empties.
rag_chain = RAGChain(vector_store_service)
|
||
|
|
||
|
class DirectoryRequest(BaseModel):
    """Request body for ``POST /rag/upload``.

    Despite its name, ``directory`` may point at either a directory or a single
    file on the server's filesystem. After validation it holds the absolute
    path as a string.
    """

    # Server-side path to ingest (a file or a directory).
    directory: str

    @validator('directory')
    def validate_directory(cls, v):
        """Validate that ``v`` names an existing file or directory.

        Args:
            v: The raw ``directory`` value from the request body.

        Returns:
            ``str`` of the absolute path. ``Path.absolute()`` does not resolve
            ``..`` segments or symlinks, so those are kept as given.

        Raises:
            ValueError: If ``v`` is empty, does not exist, or exists but is
                neither a regular file nor a directory.
        """
        # Path("") is Path("."), which always exists and is a directory, so
        # without this guard an empty value would silently select the server's
        # current working directory.
        if not v:
            raise ValueError(f"无效的路径: {v}")

        # Normalise path separators via pathlib (duplicate separators are
        # collapsed; on Windows '/' is converted to '\\').
        path = Path(v)
        if not path.exists():
            raise ValueError(f"路径不存在: {v}")
        # exists() is also true for special files (sockets, FIFOs, devices);
        # only regular files and directories are accepted.
        if not path.is_file() and not path.is_dir():
            raise ValueError(f"无效的路径: {v}")
        return str(path.absolute())
|
||
|
|
||
|
class QuestionRequest(BaseModel):
    """Request body for ``POST /rag/query``."""

    # The question, forwarded as-is to ``rag_chain.query``.
    question: str

    @validator('question')
    def validate_question(cls, v):
        """Reject empty or whitespace-only questions.

        This refuses obviously useless input while the request is parsed,
        rather than forwarding it to the RAG chain.

        Args:
            v: The raw ``question`` value from the request body.

        Returns:
            ``v`` unchanged (surrounding whitespace is not stripped).

        Raises:
            ValueError: If ``v`` contains no non-whitespace characters.
        """
        if not v.strip():
            raise ValueError("问题不能为空")
        return v
|
||
|
|
||
|
@router.post("/upload")
async def upload_documents(request: DirectoryRequest):
    """Process the given file or directory and add its chunks to the vector store.

    Returns a ``{"message": ...}`` dict reporting either how many document
    chunks were added, or that there was nothing new to process. Any failure
    is logged and surfaced as HTTP 500 with the exception text as the detail.
    """
    try:
        source_path = request.directory
        logger.info(f"开始处理文档: {source_path}")

        # NOTE(review): process_documents/add_documents are called without
        # `await`, so they run synchronously on the event loop and block other
        # requests while they execute. If they are slow, consider
        # fastapi.concurrency.run_in_threadpool, after confirming the services
        # are safe to call from worker threads.
        chunks = document_service.process_documents(source_path)

        if chunks:
            vector_store_service.add_documents(chunks)
            summary = f"成功处理 {len(chunks)} 个文档块"
            logger.info(summary)
            return {"message": summary}

        # Nothing came back from processing. Presumably DocumentService skips
        # files it has already recorded as processed; confirm.
        return {"message": "没有新的文档需要处理"}
    except Exception as exc:
        reason = str(exc)
        logger.error(f"处理文档时发生错误: {reason}", exc_info=True)
        raise HTTPException(status_code=500, detail=reason)
|
||
|
|
||
|
@router.post("/query")
async def query(request: QuestionRequest):
    """Answer a question with the RAG chain.

    Returns a dict with ``answer`` (the chain's ``result`` value) and
    ``sources`` (the ``page_content`` of each returned source document).
    Any failure is logged and surfaced as HTTP 500 with the exception text as
    the detail.
    """
    try:
        # Assumes rag_chain.query returns a mapping with "result" and
        # "source_documents" keys. A missing key ends up as a 500 via the
        # handler below; confirm against RAGChain.
        chain_output = rag_chain.query(request.question)
        answer_text = chain_output["result"]
        source_texts = [source.page_content for source in chain_output["source_documents"]]
        return {"answer": answer_text, "sources": source_texts}
    except Exception as err:
        message = str(err)
        logger.error(f"查询问题时发生错误: {message}", exc_info=True)
        raise HTTPException(status_code=500, detail=message)
|
||
|
|
||
|
@router.post("/clear")
async def clear_vector_store():
    """Empty the vector store and reset the processed-file records.

    Returns a ``{"message": ...}`` confirmation dict. Any failure is logged
    and surfaced as HTTP 500 with the exception text as the detail.
    """
    try:
        vector_store_service.clear()
        # Reset the processed-file records as well, presumably so that a later
        # /upload re-ingests files instead of reporting nothing new; confirm.
        # NOTE(review): if this call fails after clear() succeeded, the store
        # is empty while the records remain.
        document_service.clear_processed_files()
    except Exception as err:
        message = str(err)
        logger.error(f"清空向量存储时发生错误: {message}", exc_info=True)
        raise HTTPException(status_code=500, detail=message)
    else:
        return {"message": "向量存储和文件处理记录已清空"}
|