You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
48 lines
1.6 KiB
48 lines
1.6 KiB
1 month ago
|
from typing import List
|
||
|
from langchain.chains import RetrievalQA
|
||
|
from langchain_ollama import OllamaLLM
|
||
|
from langchain.prompts import PromptTemplate
|
||
|
from app.services.vector_store_service import VectorStoreService
|
||
|
|
||
|
class RAGChain:
|
||
|
def __init__(self, vector_store_service: VectorStoreService):
|
||
|
self.vector_store = vector_store_service
|
||
|
self.llm = OllamaLLM(model="qwen2.5:latest")
|
||
|
self.qa_chain = self._create_qa_chain()
|
||
|
|
||
|
def _create_qa_chain(self) -> RetrievalQA:
|
||
|
"""创建问答链"""
|
||
|
prompt_template = """你是一个专业的问答助手。请基于以下上下文信息来回答问题。如果上下文中没有足够的信息来回答问题,请明确说明"根据提供的上下文,我无法回答这个问题"。
|
||
|
|
||
|
请遵循以下规则:
|
||
|
1. 只使用提供的上下文信息来回答问题
|
||
|
2. 如果上下文信息不足,不要编造答案
|
||
|
3. 如果上下文信息有冲突,请指出这一点
|
||
|
4. 回答要简洁、准确、专业
|
||
|
|
||
|
上下文信息:
|
||
|
{context}
|
||
|
|
||
|
问题: {question}
|
||
|
|
||
|
请提供回答:"""
|
||
|
|
||
|
PROMPT = PromptTemplate(
|
||
|
template=prompt_template, input_variables=["context", "question"]
|
||
|
)
|
||
|
|
||
|
chain_type_kwargs = {"prompt": PROMPT}
|
||
|
|
||
|
return RetrievalQA.from_chain_type(
|
||
|
llm=self.llm,
|
||
|
chain_type="stuff",
|
||
|
retriever=self.vector_store.vector_store.as_retriever(
|
||
|
search_kwargs={"k": 6}
|
||
|
),
|
||
|
chain_type_kwargs=chain_type_kwargs,
|
||
|
return_source_documents=True
|
||
|
)
|
||
|
|
||
|
def query(self, question: str) -> dict:
|
||
|
"""查询问题"""
|
||
|
return self.qa_chain({"query": question})
|