You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
93 lines
3.5 KiB
93 lines
3.5 KiB
from typing import List
|
|
from langchain.schema import Document
|
|
from langchain_chroma import Chroma
|
|
from langchain_ollama import OllamaEmbeddings
|
|
import chromadb
|
|
import os
|
|
import shutil
|
|
from app.core.logger import logger
|
|
|
|
class VectorStoreService:
    """Service wrapper around a persistent Chroma collection.

    Text is embedded with an Ollama embedding model and stored in a Chroma
    collection persisted under ``persist_directory``.
    """

    def __init__(
        self,
        persist_directory: str = "data/chroma",
        embedding_model: str = "qwen2.5:latest",
        collection_name: str = "langchain",
    ):
        """Create the service and open (or create) the backing collection.

        Args:
            persist_directory: Directory where Chroma persists its data.
            embedding_model: Ollama model name used to embed documents and
                queries. Defaults to the previously hard-coded model.
            collection_name: Chroma collection to read from and write to.
                Defaults to the previously hard-coded name.

        Raises:
            Exception: Re-raised from ``_initialize_vector_store`` if the
                store cannot be opened.
        """
        self.persist_directory = persist_directory
        self.collection_name = collection_name
        self.embeddings = OllamaEmbeddings(model=embedding_model)
        self.vector_store = None
        self._initialize_vector_store()

    def _initialize_vector_store(self):
        """Open the persistent Chroma client and bind the LangChain wrapper.

        Creates ``persist_directory`` if needed, makes sure the collection
        exists, and stores the resulting ``Chroma`` object on
        ``self.vector_store``.

        Raises:
            Exception: Any error from directory creation or Chroma is logged
                and re-raised.
        """
        try:
            # exist_ok=True: an existing directory (and its data) is reused.
            os.makedirs(self.persist_directory, exist_ok=True)

            client = chromadb.PersistentClient(
                path=self.persist_directory,
                settings=chromadb.Settings(
                    anonymized_telemetry=False,
                    allow_reset=True,
                ),
            )

            # get_or_create_collection replaces "get, and create on any
            # exception"; that pattern could hide the real failure behind a
            # secondary error raised by create_collection.
            client.get_or_create_collection(self.collection_name)

            self.vector_store = Chroma(
                client=client,
                collection_name=self.collection_name,
                embedding_function=self.embeddings,
            )
            logger.info("向量存储初始化成功")
        except Exception as e:
            logger.error(f"初始化向量存储时出错: {str(e)}", exc_info=True)
            raise

    def add_documents(self, documents: List[Document]):
        """Embed and add ``documents`` to the vector store.

        An empty (or otherwise falsy) ``documents`` argument is a no-op.

        Raises:
            Exception: Any error from the store is logged and re-raised.
        """
        if not documents:
            return
        try:
            self.vector_store.add_documents(documents)
            logger.info(f"成功添加 {len(documents)} 个文档到向量存储")
        except Exception as e:
            logger.error(f"添加文档到向量存储时出错: {str(e)}", exc_info=True)
            raise

    def similarity_search(
        self, query: str, k: int = 6, score_threshold: float = 0.7
    ) -> List[Document]:
        """Similarity search.

        Args:
            query (str): Query text.
            k (int): Maximum number of documents to retrieve.
            score_threshold (float): Minimum relevance score (higher means
                more similar) a document must reach to be returned.

        Returns:
            The retrieved documents whose relevance score is
            ``>= score_threshold``, in the order returned by the store. If
            none pass the threshold but the search found something, a
            one-element list holding the highest-scoring document instead.

        Raises:
            Exception: Any error from the store is logged and re-raised.
        """
        try:
            # similarity_search_with_score yields Chroma *distances*, where a
            # lower value means more similar. Comparing those with ">=" against
            # a similarity threshold kept the LEAST similar documents.
            # Relevance scores are oriented so that higher = more similar.
            # NOTE(review): LangChain's default relevance function for the L2
            # space assumes unit-norm embeddings; if this Ollama model's
            # vectors are not normalized, scores may fall outside [0, 1] —
            # confirm, or create the collection with a cosine space.
            results = self.vector_store.similarity_search_with_relevance_scores(
                query, k=k
            )

            filtered_results = [
                doc for doc, score in results if score >= score_threshold
            ]

            # Nothing cleared the threshold: fall back to the single best hit.
            if not filtered_results and results:
                best_doc, _ = max(results, key=lambda pair: pair[1])
                return [best_doc]

            return filtered_results
        except Exception as e:
            logger.error(f"执行相似度搜索时出错: {str(e)}", exc_info=True)
            raise

    def clear(self):
        """Delete the collection and re-create an empty, usable one.

        Raises:
            Exception: Any error from Chroma is logged and re-raised.
        """
        try:
            if self.vector_store is not None:
                # Drop the whole collection, including all stored vectors.
                self.vector_store.delete_collection()
            # Re-open so self.vector_store points at a fresh, empty collection
            # rather than the deleted one.
            self._initialize_vector_store()
            logger.info("向量存储已清空并重新初始化")
        except Exception as e:
            logger.error(f"清空向量存储时出错: {str(e)}", exc_info=True)
            raise
|