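"""RAG question-answering chain: retrieves context from the application's vector store
and answers with a local Ollama model (qwen2.5) via LangChain's RetrievalQA."""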
from typing import List

from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_ollama import OllamaLLM

from app.services.vector_store_service import VectorStoreService


class RAGChain:
    def __init__(self, vector_store_service: VectorStoreService):
        self.vector_store = vector_store_service
        self.llm = OllamaLLM(model="qwen2.5:latest")
        self.qa_chain = self._create_qa_chain()

    def _create_qa_chain(self) -> RetrievalQA:
        """Build the retrieval question-answering chain."""
        # Chinese system prompt: answer only from the provided context, state clearly
        # when the context is insufficient, point out conflicts, and keep answers
        # concise, accurate, and professional.
        prompt_template = """你是一个专业的问答助手。请基于以下上下文信息来回答问题。如果上下文中没有足够的信息来回答问题,请明确说明"根据提供的上下文,我无法回答这个问题"。

请遵循以下规则:
1. 只使用提供的上下文信息来回答问题
2. 如果上下文信息不足,不要编造答案
3. 如果上下文信息有冲突,请指出这一点
4. 回答要简洁、准确、专业

上下文信息:
{context}

问题: {question}

请提供回答:"""

        PROMPT = PromptTemplate(
            template=prompt_template, input_variables=["context", "question"]
        )
        chain_type_kwargs = {"prompt": PROMPT}

        return RetrievalQA.from_chain_type(
            llm=self.llm,
            # "stuff" concatenates all retrieved chunks into a single prompt context
            chain_type="stuff",
            retriever=self.vector_store.vector_store.as_retriever(
                search_kwargs={"k": 6}  # return the 6 most similar chunks
            ),
            chain_type_kwargs=chain_type_kwargs,
            return_source_documents=True,
        )

    def query(self, question: str) -> dict:
        """Run the chain on a question; the returned dict holds "result" and "source_documents"."""
        return self.qa_chain.invoke({"query": question})
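

# --- Minimal usage sketch (illustrative only) --------------------------------------
# Assumes VectorStoreService can be constructed here with no arguments and already
# exposes an indexed `vector_store`; the real constructor in
# app/services/vector_store_service.py may require configuration.
if __name__ == "__main__":
    store = VectorStoreService()  # hypothetical zero-argument construction
    rag = RAGChain(store)
    answer = rag.query("What topics are covered in the indexed documents?")
    print(answer["result"])  # generated answer text
    for doc in answer["source_documents"]:
        # each retrieved chunk carries its source metadata
        print(doc.metadata.get("source"))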