1. 伪文档生成法(Query-to-Document) #
伪文档生成法(Query-to-Document)是一种在信息检索流程中提升检索结果相关性的重要技术。其主要思路是,用户输入查询后,先由大语言模型(LLM)基于该查询生成一段伪文档。这个伪文档可能包括对原始查询的进一步解释、相关背景知识、潜在答案片段,或者是更丰富的同义词与关键词补充。
伪文档不仅弥补了原始短查询可能缺失的信息,而且通过包含更多上下文,有助于覆盖更多相关文档,提高召回率。随后,将原始查询与伪文档拼接形成增强查询(augmented query),再送入向量化和检索流程。这样,检索系统能更好地理解和匹配用户意图,尤其在知识库规模大、查询本身语意不充分的场景下,效果尤为明显。
简单来说,伪文档生成法的关键步骤包括:
- 调用LLM,根据用户查询生成伪文档(包含所需知识、扩展表述、补充背景等)。
- 将伪文档与原始查询拼接形成增强型检索输入。
- 对增强后的查询进行向量化或关键词检索,提升原始检索效果。
应用场景:
- 用户输入模糊、信息不足的查询时,通过自动补全信息实现更优检索。
- 面向多领域或大规模知识库时,增加召回率与相关性。
- RAG(Retrieval-Augmented Generation)等检索增强生成任务。
小结:伪文档生成法是实现智能检索系统、提升用户查询体验的核心技术之一,在RAG等方案中有着非常广泛的应用和实用价值。
1.1 QueryToDocument.py #
# 引入类型注解相关的库
from typing import List, Dict, Any, Optional
# 引入基础检索器类
from langchain_core.retrievers import BaseRetriever
# 引入文档对象
from langchain_core.documents import Document
# 引入语言模型基类
from langchain_core.language_models import BaseLanguageModel
# 引入提示词模板类
from langchain_core.prompts import PromptTemplate
# 引入自定义llm对象
from llm import llm
# 引入自定义embedding对象
from embeddings import embeddings
# 引入获取向量库的函数
from vector_store import get_vector_store
# 定义QueryToDocumentRetriever检索器,继承自BaseRetriever
class QueryToDocumentRetriever(BaseRetriever):
# 定义向量库属性
vector_store: Any
# 定义LLM属性
llm: BaseLanguageModel
# 定义embedding属性
embeddings: Any
# k 默认检索文档数量
k: int = 4
# 可选的自定义伪文档提示词模板字符串
pseudo_document_prompt_template_str: Optional[str] = None
# 获取(或生成)伪文档的prompt模板
def _get_prompt_template(self) -> PromptTemplate:
# 如果定义了自定义模板则使用
if self.pseudo_document_prompt_template_str:
return PromptTemplate(
input_variables=["query"],
template=self.pseudo_document_prompt_template_str
)
# 否则使用内置模板
return PromptTemplate(
input_variables=["query"],
template="""请根据以下用户查询,生成一段伪文档,需包含关键信息、相关术语及必要背景,并自然流畅呈现。\n用户查询:{query}\n伪文档:"""
)
# 根据查询,通过LLM生成伪文档
def generate_pseudo_document(self, query: str) -> str:
# 获取提示词模板并填充
prompt = self._get_prompt_template().format(query=query)
# 用LLM生成伪文档
response = self.llm.invoke(prompt)
# 返回伪文档内容,去除首尾空格
return response.content.strip()
# 将原查询和伪文档拼接为增强查询
def enhance_query(self, query: str, pseudo_document: str) -> str:
return f"{query}\n\n{pseudo_document}"
# 检索相关文档
def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
# 获取需检索的文档数k
k = kwargs.get("k", self.k)
# 生成伪文档
pseudo_document = self.generate_pseudo_document(query)
# 构建增强后的查询
enhanced_query = self.enhance_query(query, pseudo_document)
# 用增强查询向量库检索,获取top-k文档(包含分数)
docs = self.vector_store.similarity_search_with_score(enhanced_query, k=k)
# 构建Document对象并附加相关元信息
return [
Document(
page_content=doc.page_content,
metadata={
**doc.metadata,
"score": float(distance),
"retrieval_method": "query_to_document",
"original_query": query,
"pseudo_document_length": len(pseudo_document)
}
) for doc, distance in docs
]
# 定义RAG系统类,实现拼接、生成最终答案
class QueryToDocumentRAG:
# 构造方法,传入检索器、LLM和可选回答模板
def __init__(
self,
retriever: QueryToDocumentRetriever,
llm: BaseLanguageModel,
answer_prompt_template: Optional[str] = None
):
self.retriever = retriever
self.llm = llm
# 初始化答案生成的PromptTemplate
self.answer_prompt_template = PromptTemplate(
input_variables=["context", "query"],
template=answer_prompt_template or
"已知相关信息:\n{context}\n请基于上述信息回答用户问题,如信息不足请指出。\n用户问题:{query}\n答案:"
)
# 生成最终答案的方法
def generate_answer(self, query: str, k: int = 4) -> Dict[str, Any]:
# 检索相关文档
docs = self.retriever._get_relevant_documents(query, k=k)
# 整合文档内容用于llm生成答案
context = "\n\n".join([doc.page_content for doc in docs])
# 用prompt模板拼接最终提示词
prompt = self.answer_prompt_template.format(context=context, query=query)
# 调用llm生成答案
response = self.llm.invoke(prompt)
# 返回包含问题、答案、检索文档和数量的字典
return {
"query": query,
"answer": response.content.strip(),
"retrieved_documents": docs,
"num_documents": len(docs)
}
# 初始化向量库
vector_store = get_vector_store(
persist_directory="chroma_db",
collection_name="query_to_document"
)
# 初始化QueryToDocumentRetriever对象
retriever = QueryToDocumentRetriever(
vector_store=vector_store,
llm=llm,
embeddings=embeddings,
k=4
)
# 初始化RAG系统对象
rag_system = QueryToDocumentRAG(
retriever=retriever,
llm=llm
)
# 构造示例文档数据
documents = [
"人工智能(AI)是计算机科学的分支,旨在使系统完成需人类智能的任务。机器学习让计算机从数据学习,无需明确编程。深度学习通过神经网络模拟人脑。自然语言处理(NLP)让计算机理解人类语言。",
"区块链是一种分布式账本技术,用密码学保障安全和不可篡改。比特币是区块链的首次大规模应用,解决双重支付难题。以太坊支持智能合约,便于去中心化应用开发。智能合约自动执行,无需中介。",
"量子计算利用量子力学的新型计算方式,极具潜力。量子比特可处于0和1的叠加态。量子纠缠是其关键特性。量子计算或将应用于密码学、药物发现与优化领域。"
]
# 向向量库中批量添加文档及对应元数据
vector_store.add_texts(
documents,
metadatas=[
{"topic": "人工智能", "category": "科技", "doc_id": "ai_1"},
{"topic": "区块链", "category": "科技", "doc_id": "blockchain_1"},
{"topic": "量子计算", "category": "科技", "doc_id": "quantum_1"},
]
)
# 设定一个示例用户查询
query = "什么是机器学习?它和深度学习有什么关系?"
# 执行完整RAG流程,检索3个相关文档
result = rag_system.generate_answer(query, k=3)
# 打印流程和最终输出结果的分隔线
print("\n" + "="*60)
print("最终结果")
print("="*60)
# 打印用户查询内容
print(f"\n用户查询: {result['query']}")
# 打印生成的答案
print(f"\n生成的答案:\n{result['answer']}")
# 打印检索到的文档数量
print(f"\n检索到的文档数量: {result['num_documents']}")
# 打印检索到的每个文档及相关信息
print(f"\n检索到的文档详情:")
for i, doc in enumerate(result['retrieved_documents'], 1):
print(f"\n文档 {i} (分数: {doc.metadata.get('score', 'N/A'):.4f}):")
print(f"内容: {doc.page_content[:200]}...")
print(f"元数据: {doc.metadata}")1.2 执行流程 #
1.2.1 核心思想 #
Query-to-Document 采用“查询扩展”策略:
- 根据用户查询,使用 LLM 生成一个伪文档(pseudo document)
- 伪文档包含关键信息、相关术语和必要背景
- 将原始查询和伪文档拼接成增强查询
- 使用增强查询在向量库中进行检索
- 返回检索结果并生成最终答案
1.2.2 执行流程 #
阶段一:初始化
# 1. 获取向量存储实例
vector_store = get_vector_store(
persist_directory="chroma_db",
collection_name="query_to_document"
)
# 2. 创建QueryToDocumentRetriever检索器
retriever = QueryToDocumentRetriever(
vector_store=vector_store,
llm=llm,
embeddings=embeddings,
k=4
)
# 3. 创建RAG系统
rag_system = QueryToDocumentRAG(
retriever=retriever,
llm=llm
)初始化时:
- 创建向量存储实例
- 创建检索器,配置 LLM、embeddings 和默认 k 值
- 创建 RAG 系统,用于完整流程
阶段二:文档索引
# 添加文档到向量库
vector_store.add_texts(
documents,
metadatas=[...]
)索引过程:
- 将示例文档添加到向量库
- 每个文档附带元数据(topic、category、doc_id)
- 文档会被向量化并存储
阶段三:查询处理
query = "什么是机器学习?它和深度学习有什么关系?"
result = rag_system.generate_answer(query, k=3)完整流程:
- 用户提交查询
- 生成伪文档:
- 调用
generate_pseudo_document(query) - 使用提示词模板生成伪文档
- LLM 生成包含关键信息、术语和背景的文本
- 调用
- 增强查询:
- 调用
enhance_query(query, pseudo_document) - 将原始查询和伪文档拼接
- 调用
- 检索文档:
- 使用增强查询在向量库中检索
- 返回 top-k 个相关文档(带分数)
- 生成答案:
- 整合检索到的文档内容作为上下文
- 使用提示词模板构建最终 prompt
- LLM 生成答案
- 返回结果:
- 包含查询、答案、检索文档和数量
1.2.3 类图 #
1.2.4 时序图 #
1.2.4.1 完整RAG流程时序图 #
1.2.4.2 伪文档生成详细流程 #
1.2.5 关键设计要点 #
1. 查询增强流程
用户查询
↓
生成伪文档 (LLM)
↓
增强查询 = 原始查询 + 伪文档
↓
向量检索 (使用增强查询)
↓
返回相关文档
↓
生成最终答案2. 伪文档的作用
伪文档包含:
- 关键信息:查询相关的核心概念
- 相关术语:专业词汇和同义词
- 必要背景:上下文信息
- 自然流畅:连贯的文本表达
示例:
用户查询: "什么是机器学习?"
伪文档可能生成:
"机器学习是人工智能的一个分支,它使计算机能够从数据中学习模式,
而无需明确编程。机器学习涉及算法、统计模型和数据分析技术,
广泛应用于图像识别、自然语言处理、推荐系统等领域。"3. 增强查询的优势
- 语义扩展:伪文档提供更多相关语义信息
- 术语匹配:包含同义词和相关术语,提高匹配率
- 上下文丰富:提供背景信息,帮助理解查询意图
- 检索精度:增强查询能更准确地匹配相关文档
4. 元数据设计
检索返回的 Document 对象包含:
score:相似度分数(距离,越小越相似)retrieval_method: "query_to_document":检索方法标识original_query:原始用户查询pseudo_document_length:伪文档长度- 继承原始文档的元数据(topic、category、doc_id 等)
1.2.6 与其他方法的对比 #
| 特性 | 传统检索 | Query-to-Document |
|---|---|---|
| 查询处理 | 直接使用原始查询 | 生成伪文档并增强查询 |
| 语义扩展 | 无 | 通过伪文档扩展 |
| 术语匹配 | 依赖查询中的术语 | 伪文档包含相关术语 |
| 计算成本 | 低 | 中等(需要LLM生成) |
| 适用场景 | 简单查询 | 复杂、模糊查询 |
1.2.7 优势与应用场景 #
优势:
- 提高检索精度:增强查询包含更多语义信息
- 处理模糊查询:伪文档帮助理解查询意图
- 术语扩展:自动包含相关术语和同义词
- 上下文理解:提供背景信息,改善匹配
适用场景:
- 复杂查询:需要语义扩展的查询
- 模糊查询:用户表达不够精确
- 专业领域:需要术语扩展的场景
- 多语言场景:需要跨语言语义理解
1.2.8 提示词模板设计 #
默认伪文档生成模板:
请根据以下用户查询,生成一段伪文档,需包含关键信息、
相关术语及必要背景,并自然流畅呈现。
用户查询:{query}
伪文档:默认答案生成模板:
已知相关信息:
{context}
请基于上述信息回答用户问题,如信息不足请指出。
用户问题:{query}
答案:该设计通过查询扩展提升检索质量,适用于需要语义扩展和上下文理解的 RAG 应用。
2. 假设文档向量化(Assume Document Vectorization) #
在RAG(Retrieval-Augmented Generation)流程中,检索质量直接影响最终答案的准确性。传统做法通常是直接用用户查询来做向量化检索,但由于自然语言查询本身的简洁性和主观性,往往不能覆盖用户潜在的多角度信息需求,因此导致召回的文档有限或者语义相关性不足。假设文档向量化法正是对此问题的优化。
其核心流程如下:
假设文档生成
给定用户原始查询,利用LLM(大语言模型)“扩写”出多个假设文档。这些假设文档旨在“假装”是合理相关的答案片段,从多个视角、更细粒度或更全面的表述对原始查询进行“扩展”。这样不仅弥补查询语义的不全,也能覆盖更多可能的信息需求。独立向量化
将原始查询和所有假设文档分别输入嵌入模型,得到各自的向量表示。每个假设文档的向量都是在不同上下文下对原查询的“信息补全”。平均向量合成
计算所有这些向量的算术平均值,把它作为增强后的“集成查询向量”。均值在理论上可视为多角度信息的“中心向量”,这样可提高检索相关文档的概率。向量检索
使用该平均向量去向量库中检索,能够显著提升与真实高相关文档的召回率,尤其对复杂问题、语义表达多变的问题尤为有效。
这种方案不仅简单(只需增加一次假设文档生成和向量平均),无需模型微调,而且对标准向量数据库和检索管道兼容性好。
典型应用场景举例:
- 面对高度开放性问题(如“如何提高新能源电池效率?”),假设文档向量化可以自动从材料、制造工艺、充放电策略等多角度激活潜在有效的语义空间;
- 复杂信息需求聚合型检索(如“机器学习与深度学习的异同?”)时,通过多种说法的假设文档更容易接近教材、论文等结构化知识内容。
2.1 AssumeDocumentVectorization.py #
# 假设文档向量化(Assume Document Vectorization)
"""
通过LLM辅助生成多角度假设文档,对原始查询和假设文档向量化后求平均向量,用于更精准的知识库检索与RAG增强。
"""
# 导入类型注解相关依赖
from typing import List, Dict, Any
# 导入基础检索器类
from langchain_core.retrievers import BaseRetriever
# 导入文档对象
from langchain_core.documents import Document
# 导入语言模型基类
from langchain_core.language_models import BaseLanguageModel
# 导入提示词模板
from langchain_core.prompts import PromptTemplate
# 导入自定义llm对象
from llm import llm
# 导入自定义embedding对象
from embeddings import embeddings
# 导入获取向量库的函数
from vector_store import get_vector_store
# 导入numpy用于向量计算
import numpy as np
# 定义假设文档向量化检索器,继承自BaseRetriever
class AssumeDocumentVectorizationRetriever(BaseRetriever):
# 定义向量库属性
vector_store: Any
# 定义LLM属性
llm: BaseLanguageModel
# 定义embedding属性
embeddings: Any
# 默认检索文档数量
k: int = 4
# 生成假设文档的数量
num_hypothetical_documents: int = 3
# 获取不同角度的Prompt模板
def _get_prompt_templates(self) -> List[PromptTemplate]:
# 定义角度列表
perspectives = [
"学术研究角度", "实际应用角度", "基础概念角度"
]
# 针对不同角度生成PromptTemplate对象
return [
PromptTemplate(
input_variables=["query", "perspective"],
template=f"请从{perspectives[i]},针对如下问题生成回答:{{query}}\n详细叙述。"
) for i in range(self.num_hypothetical_documents)
]
# 通过LLM生成多角度假设文档
def generate_hypothetical_documents(self, query: str) -> List[str]:
# 存储生成的文档内容
docs = []
# 遍历每个Prompt模板
for t in self._get_prompt_templates():
# 获取当前角度描述
p = t.template.split("从")[1].split(",")[0] if "从" in t.template else ""
# 填充Prompt
prompt = t.format(query=query, perspective=p)
# 用LLM生成内容并去除首尾空白
content = self.llm.invoke(prompt).content.strip()
# 添加到文档列表
docs.append(content)
# 返回生成的假设文档列表
return docs
# 将查询和假设文档批量向量化
def vectorize_documents(self, query: str, docs: List[str]) -> List[List[float]]:
# 首先对原始查询进行向量化
vectors = [self.embeddings.embed_query(query)]
# 对每个假设文档进行向量化
for doc in docs:
vectors.append(self.embeddings.embed_query(doc))
# 返回得到的全部向量
return vectors
# 计算所有向量的平均向量
def calculate_average_vector(self, vectors: List[List[float]]) -> List[float]:
# 使用numpy计算均值
return np.mean(np.array(vectors), axis=0).tolist()
# 用指定向量进行向量库检索
def retrieve_with_vector(self, query_vector: List[float], k: int) -> List[tuple]:
# 直接用底层Chroma接口进行向量检索
r = self.vector_store._collection.query(query_embeddings=[query_vector], n_results=k)
# 获取检索到的文档内容
docs = r.get("documents", [[]])[0]
# 获取距离值
dists = r.get("distances", [[]])[0]
# 获取元数据
metas = r.get("metadatas", [[]])[0] if r.get("metadatas", [[]]) else [{}]*len(docs)
# 将文档内容、距离、元数据打包
return list(zip(docs, dists, metas))
# 供RAG主流程调用的相关文档检索方法
def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
# 获取需检索的文档数k
k = kwargs.get("k", self.k)
# 生成多角度假设文档文本
hypos = self.generate_hypothetical_documents(query)
# 对原查询和假设文档全部向量化
vectors = self.vectorize_documents(query, hypos)
# 计算向量平均作为查询向量
avg_vec = self.calculate_average_vector(vectors)
# 使用平均向量进行检索
results = self.retrieve_with_vector(avg_vec, k)
# 构建Document对象并附加分数元信息
return [
Document(
page_content=txt,
metadata={**meta, "score": float(dist)}
)
for txt, dist, meta in results
]
# 定义假设文档向量化RAG系统
class AssumeDocumentVectorizationRAG:
# 构造方法,传入检索器和LLM
def __init__(self, retriever: AssumeDocumentVectorizationRetriever, llm: BaseLanguageModel):
# 存储检索器
self.retriever = retriever
# 存储llm
self.llm = llm
# 定义答案生成Prompt模板
self.answer_prompt_template = PromptTemplate(
input_variables=["context", "query"],
template="""已知信息:{context}\n请回答:{query}\n答案:"""
)
# 生成最终答案的方法
def generate_answer(self, query: str, k: int = 4) -> Dict[str, Any]:
# 检索相关文档
docs = self.retriever._get_relevant_documents(query, k=k)
# 整合检索到的文档内容
context = "\n".join([doc.page_content for doc in docs])
# 用Prompt模板生成回答用的提示词
prompt = self.answer_prompt_template.format(context=context, query=query)
# 调用LLM生成答案
answer = self.llm.invoke(prompt).content.strip()
# 返回结果字典
return {
"query": query,
"answer": answer,
"retrieved_documents": docs,
"num_documents": len(docs)
}
# 初始化向量库实例
vector_store = get_vector_store(persist_directory="chroma_db", collection_name="assume_document_vectorization")
# 初始化假设文档向量化检索器
retriever = AssumeDocumentVectorizationRetriever(
vector_store=vector_store, llm=llm, embeddings=embeddings, k=4, num_hypothetical_documents=3
)
# 初始化RAG主流程类
rag_system = AssumeDocumentVectorizationRAG(retriever=retriever, llm=llm)
# 构造示例文档数据
documents = [
"人工智能(AI)是计算机科学的分支,致力于让计算机完成类似人类智能的任务。机器学习让计算机从数据中学习,无需显式编程。深度学习用神经网络模拟人脑。NLP使计算机理解和生成自然语言。",
"区块链是一种分布式账本技术,以密码学保障安全与不可篡改。比特币是首个区块链加密货币,解决了双重支付难题。以太坊支持智能合约,可构建去中心化应用。智能合约自动执行,无需中介。",
"量子计算利用量子力学进行信息处理,具巨大潜力。量子比特可叠加多态,量子纠缠是其核心特性。量子计算或将在密码学、药物等领域带来变革。",
]
# 向向量库中添加文档及元数据
vector_store.add_texts(
documents,
metadatas=[
{"topic": "人工智能", "category": "科技"},
{"topic": "区块链", "category": "科技"},
{"topic": "量子计算", "category": "科技"},
]
)
# 设定用户查询问题
query = "什么是机器学习?它和深度学习有什么关系?"
# 运行RAG流程,检索3个相关文档并生成答案
result = rag_system.generate_answer(query, k=3)
# 打印用户查询
print("\n用户查询:", result['query'])
# 打印生成的答案
print("\n生成的答案:\n", result['answer'])
# 打印检索到的文档数量
print("\n检索到的文档数量:", result['num_documents'])
# 打印检索到的每个文档的部分内容及元数据
for i, doc in enumerate(result['retrieved_documents'], 1):
print(f"\n文档{i}: {doc.page_content[:80]}...")
print(f"元数据: {doc.metadata}")2.2 执行流程 #
2.2.1 核心思想 #
假设文档向量化采用“多角度假设 + 向量平均”策略:
- 根据用户查询,使用 LLM 从多个角度生成假设文档
- 对原始查询和所有假设文档进行向量化
- 计算所有向量的平均向量作为查询向量
- 使用平均向量在向量库中进行检索
- 返回检索结果并生成最终答案
2.2.2 执行流程 #
阶段一:初始化
# 1. 获取向量存储实例
vector_store = get_vector_store(
persist_directory="chroma_db",
collection_name="assume_document_vectorization"
)
# 2. 创建假设文档向量化检索器
retriever = AssumeDocumentVectorizationRetriever(
vector_store=vector_store,
llm=llm,
embeddings=embeddings,
k=4,
num_hypothetical_documents=3 # 生成3个假设文档
)
# 3. 创建RAG系统
rag_system = AssumeDocumentVectorizationRAG(
retriever=retriever,
llm=llm
)初始化时:
- 创建向量存储实例
- 创建检索器,配置 LLM、embeddings、默认 k 值和假设文档数量
- 创建 RAG 系统,用于完整流程
阶段二:文档索引
# 添加文档到向量库
vector_store.add_texts(
documents,
metadatas=[...]
)索引过程:
- 将示例文档添加到向量库
- 每个文档附带元数据(topic、category)
- 文档会被向量化并存储
阶段三:查询处理
query = "什么是机器学习?它和深度学习有什么关系?"
result = rag_system.generate_answer(query, k=3)完整流程:
- 用户提交查询
- 生成多角度假设文档:
- 调用
generate_hypothetical_documents(query) - 从3个角度生成假设文档(学术研究、实际应用、基础概念)
- 每个角度使用不同的提示词模板
- 调用
- 向量化处理:
- 调用
vectorize_documents(query, hypos) - 对原始查询进行向量化
- 对每个假设文档进行向量化
- 调用
- 计算平均向量:
- 调用
calculate_average_vector(vectors) - 使用 numpy 计算所有向量的平均值
- 调用
- 向量检索:
- 调用
retrieve_with_vector(avg_vec, k) - 使用平均向量在向量库中检索
- 返回 top-k 个相关文档(带分数)
- 调用
- 生成答案:
- 整合检索到的文档内容作为上下文
- 使用提示词模板构建最终 prompt
- LLM 生成答案
- 返回结果:
- 包含查询、答案、检索文档和数量
2.2.3 类图 #
2.2.4 时序图 #
2.2.4.1 完整RAG流程时序图 #
2.2.4.2 多角度假设文档生成详细流程 #
2.2.4.3 向量平均与检索详细流程 #
2.2.5 关键设计要点 #
1. 多角度假设文档生成流程
用户查询
↓
从3个角度生成假设文档:
- 学术研究角度
- 实际应用角度
- 基础概念角度
↓
得到3个假设文档
↓
向量化: [查询向量, 假设文档1向量, 假设文档2向量, 假设文档3向量]
↓
计算平均向量
↓
使用平均向量检索
↓
返回相关文档2. 向量平均的优势
- 语义融合:融合原始查询和多个角度的语义信息
- 提高精度:平均向量能更好地代表查询意图
- 多角度覆盖:涵盖不同角度的语义表达
- 鲁棒性:减少单一角度可能带来的偏差
3. 角度设计
默认三个角度:
- 学术研究角度:理论、研究、学术视角
- 实际应用角度:实践、应用、案例视角
- 基础概念角度:概念、定义、基础视角
示例:
用户查询: "什么是机器学习?"
学术研究角度假设文档:
"机器学习是人工智能领域的重要研究方向,涉及统计学、
算法理论和计算复杂性等学科。研究者通过设计算法使计算
机能够从数据中发现模式,建立预测模型..."
实际应用角度假设文档:
"机器学习在现实生活中广泛应用,如推荐系统、图像识别、
语音助手等。企业利用机器学习技术分析用户行为,提供
个性化服务,提升用户体验和业务效率..."
基础概念角度假设文档:
"机器学习是一种让计算机从数据中学习的方法,无需明确
编程。它通过训练数据建立模型,能够对新数据进行预测
或分类。主要包括监督学习、无监督学习和强化学习..."4. 向量计算过程
数学表示:
设原始查询向量为 $\mathbf{q}$,假设文档向量为 $\mathbf{h}_1, \mathbf{h}_2, \ldots, \mathbf{h}_n$,则平均向量为:
$$ v_{\text{avg}} = \frac{1}{n+1}(q + \sum_{i=1}^{n} h_i) $$
其中 $n$ 为假设文档数量(默认3)。
2.2.6 与其他方法的对比 #
| 特性 | Query-to-Document | Assume Document Vectorization |
|---|---|---|
| 生成内容 | 单个伪文档 | 多个角度假设文档 |
| 向量处理 | 直接使用增强查询 | 计算平均向量 |
| 语义融合 | 文本拼接 | 向量平均 |
| 角度覆盖 | 单一视角 | 多角度视角 |
| 计算成本 | 中等 | 较高(多文档生成+向量计算) |
| 适用场景 | 简单扩展 | 复杂多角度查询 |
2.2.7. 优势与应用场景 #
优势:
- 多角度语义融合:从不同角度理解查询意图
- 提高检索精度:平均向量能更准确地匹配相关文档
- 鲁棒性强:减少单一角度的偏差
- 语义丰富:假设文档提供更多上下文信息
适用场景:
- 复杂查询:需要多角度理解的查询
- 专业领域:需要从不同视角分析的问题
- 知识检索:需要全面覆盖相关知识的场景
- 研究场景:需要学术和应用双重视角
2.2.8. 技术细节 #
- 向量维度:所有向量必须具有相同的维度
- 平均计算:使用 numpy 的
mean()函数,axis=0 表示按列平均 - 检索接口:直接使用 Chroma 的底层
_collection.query()方法 - 结果处理:将文档、距离、元数据打包为 Document 对象
该设计通过多角度假设文档和向量平均,在复杂查询场景下提供更准确的语义匹配,适用于需要多角度理解的 RAG 应用。
3. 问题分解策略(Sub-Question Decomposition) #
核心思想与流程
对于复杂的用户查询,直接用整体检索往往效果有限,因为原始问题结构复杂、包含若干子意图和层层递进的知识点。问题分解策略旨在用LLM先自动将复杂问题拆解(decompose)为多个更简单但相关的子问题,每个子问题都聚焦某一个点或层面。然后,分别用每个子问题对知识库进行检索,获得各自相关的文本块。最后,将各子问题检索得到的文档合并、去重,并输入LLM进行最终答案生成,确保全面覆盖原问题的所有细节。
整个流程如下图所示:
┌──────────┐
│ 复杂查询 │
└────┬─────┘
│(LLM分解)
┌───────────┼────────────┐
│ │ │
┌──────▼─────┐┌────▼────┐┌─────▼─────┐
│ 子问题1 ││ 子问题2 ││ 子问题3 │ ...
└────┬───────┘└───┬─────┘└────┬──────┘
│检索 │检索 │检索
┌────▼──┐ ┌────▼──┐ ┌───▼────┐
│ 文档A │ │ 文档B │ │ 文档C │ ...
└─────┬─┘ └──────┬┘ └─┬──────┘
└────┬─────┬───┴────┬──┘
│合并+去重 │
▼
┌────────────┐
│ 检索上下文 │
└────┬───────┘
│
┌──────▼─────┐
│ LLM生成答案 │
└──────┬─────┘
│
┌─────▼──────┐
│ 最终答案 │
└────────────┘RAG流程对比说明
- 传统RAG:用户原始大问题→embedding→检索→生成答案;对于多子意图大问题,检索效果易有限。
- 分解优化RAG:用户原始大问题→LLM分解为多个子问题→每个子问题分别embedding检索→合并所有文本→生成答案。这样可以保持对每个知识点的覆盖与更细粒度的检索,提升答案完整性和准确率。
适用场景
- 查询包含并列/递进的若干小问:如“请阐述A与B的区别以及各自应用”。
- 复杂链式推理问题:如“请解释X的原理,并列举三个实际场景”。
- 希望保证大问题的每个方面都被充分覆盖和回答。
关键技术点
- 问题拆解:调用LLM,给定复杂的问题,输出按顺序排列的子问题清单(如返回一个子问题列表)。
- 子问题检索:对每个子问题分别用embedding在向量库中相似度检索。
- 去重合并:将所有检索到的文本去重合并,防止重复内容。
- 回答生成:LLM以问题分解列表和全部检索材料为上下文,综合生成高质量长答案。
代码实现要点
- 封装SubQuestionDecompositionRetriever类,实现包含问题拆解、子问题批量检索与结果归并等流程的方法。
- 提供自动“问题-子问题生成”的Prompt设计、调用及结果解析。
- 检索时为每个文档标注其对应的子问题来源,便于后续分析和复现。
3.1 SubQuestionDecomposition.py #
# 导入类型注解所需的相关类型
from typing import List, Dict, Any, Optional
# 导入pydantic配置工具
from pydantic import ConfigDict
# 导入基础检索器基类
from langchain_core.retrievers import BaseRetriever
# 导入文档类型
from langchain_core.documents import Document
# 导入LLM基础类型
from langchain_core.language_models import BaseLanguageModel
# 导入提示词模板
from langchain_core.prompts import PromptTemplate
# 导入自定义llm对象
from llm import llm
# 导入获取向量库函数
from vector_store import get_vector_store
# 导入正则模块
import re
# 定义子问题分解检索器类,继承自BaseRetriever
class SubQuestionDecompositionRetriever(BaseRetriever):
# 向量库对象
vector_store: Any
# 大语言模型对象
llm: BaseLanguageModel
# 检索每个子问题的文档数量
k: int = 4
# 子问题最少数量
min_sub_questions: int = 3
# 子问题最多数量
max_sub_questions: int = 5
# 可选的分解prompt模板字符串
decomposition_prompt_template_str: Optional[str] = None
# 配置模型参数,允许任意类型
model_config = ConfigDict(arbitrary_types_allowed=True)
# 获取问题分解prompt模板的方法
def _get_decomposition_prompt_template(self) -> PromptTemplate:
# 获取或默认生成用于子问题分解的Prompt模板字符串
template = self.decomposition_prompt_template_str or (
"请将以下复杂问题拆解为3-5个相关的子问题。\n"
"每个子问题应该:\n"
"1. 独立且具体\n2. 能够帮助回答原始问题\n3. 覆盖原始问题的不同方面\n\n"
"原始问题:{query}\n请列出子问题(每行一个,以问号结尾):"
)
# 返回PromptTemplate对象
return PromptTemplate(input_variables=["query"], template=template)
# 调用LLM将查询分解为若干子问题
def decompose_question(self, query: str) -> List[str]:
# 构建分解prompt
prompt = self._get_decomposition_prompt_template().format(query=query)
# 调用LLM进行内容生成
resp = self.llm.invoke(prompt).content.strip()
sub_questions = []
# 遍历LLM返回的每一行,解析子问题文本
for line in resp.split('\n'):
# 去除前缀编号等符号
line = re.sub(r'^\d+[\.、]\s*|^[-•]\s*', '', line.strip())
# 只保留含中文或英文问号的行
if line and ('?' in line or '?' in line):
sub_questions.append(line)
# 限制子问题数量不超过最大值
sub_questions = sub_questions[:self.max_sub_questions]
# 若数量少于最小,回退为原始问题
if len(sub_questions) < self.min_sub_questions:
sub_questions = [query] if not sub_questions else sub_questions
# 返回子问题列表
return sub_questions
# 针对子问题列表依次检索相似文档
def retrieve_for_sub_questions(self, sub_questions: List[str], k: int) -> List[Document]:
all_docs = []
# 遍历每个子问题及下标
for i, sub_q in enumerate(sub_questions, 1):
# 对当前子问题进行向量相似性检索
docs = self.vector_store.similarity_search_with_score(sub_q, k=k)
# 将检索结果统一封装为Document对象,并增加子问题相关信息
for doc, distance in docs:
all_docs.append(Document(
page_content=doc.page_content,
metadata={**doc.metadata, "score": float(distance), "sub_question": sub_q, "sub_question_index": i}
))
# 返回所有结果列表
return all_docs
# 对检索到的文档内容进行去重与归并
def deduplicate_and_merge(self, all_docs: List[Document]) -> List[Document]:
unique = {}
# 遍历所有文档,使用内容做去重键,保留分数更小者
for doc in all_docs:
c = doc.page_content
s = doc.metadata.get("score", float('inf'))
if c not in unique or s < unique[c].metadata.get("score", float('inf')):
unique[c] = doc
# 对去重后的文档按分数升序排序
docs = sorted(unique.values(), key=lambda x: x.metadata.get("score", float('inf')))
# 返回排序后的文档列表
return docs
# 主入口:根据查询获取相关文档流程
def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
# 获取本次检索的k值(可动态传递)
k = kwargs.get("k", self.k)
# 子问题分解
sqs = self.decompose_question(query)
# 针对子问题检索
docs = self.retrieve_for_sub_questions(sqs, k)
# 去重合并并返回
return self.deduplicate_and_merge(docs)
# 定义子问题分解RAG主流程类
class SubQuestionDecompositionRAG:
# 初始化,注入检索器和LLM,支持自定义答案生成Prompt
def __init__(self, retriever: SubQuestionDecompositionRetriever, llm: BaseLanguageModel, answer_prompt_template: Optional[str]=None):
self.retriever = retriever
self.llm = llm
# 设置答案生成Prompt模板,支持自定义或默认
template = answer_prompt_template or (
"已知以下相关信息:\n\n{context}\n\n原始问题:{query}\n\n"
"该问题被分解为以下子问题:\n{sub_questions}\n\n"
"请根据上述信息,全面回答原始问题。确保覆盖所有子问题的答案。\n\n答案:"
)
# 构建最终答案生成的PromptTemplate对象
self.answer_prompt_template = PromptTemplate(
input_variables=["context", "query", "sub_questions"], template=template
)
# 输入查询及可选k,返回答案及检索相关信息
def generate_answer(self, query: str, k: int = 4) -> Dict[str, Any]:
# 检索相关文档
docs = self.retriever._get_relevant_documents(query, k=k)
# 汇总所有子问题(去重)
sub_questions = list(set([doc.metadata.get("sub_question","") for doc in docs if doc.metadata.get("sub_question")]))
# 子问题格式化
sub_questions_text = "\n".join([f"{i+1}. {sq}" for i, sq in enumerate(sub_questions)])
# 合并上下文内容
context = "\n\n".join([f"文档 {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
# 构建答案生成Prompt
prompt = self.answer_prompt_template.format(context=context, query=query, sub_questions=sub_questions_text)
# 调用LLM生成答案
answer = self.llm.invoke(prompt).content.strip()
# 整理并返回完整流程结果
return {
"query": query,
"answer": answer,
"sub_questions": sub_questions,
"retrieved_documents": docs,
"num_documents": len(docs)
}
# 初始化向量库对象
vector_store = get_vector_store(
persist_directory="chroma_db",
collection_name="sub_question_decomposition"
)
# 初始化子问题分解检索器
retriever = SubQuestionDecompositionRetriever(
vector_store=vector_store,
llm=llm,
k=3,
min_sub_questions=3,
max_sub_questions=5
)
# 初始化RAG系统
rag_system = SubQuestionDecompositionRAG(
retriever=retriever,
llm=llm
)
# 定义示例文档内容列表
documents = [
"人工智能(AI)是计算机科学的一个分支,旨在创建能够执行通常需要人类智能的任务的系统。"
"机器学习是人工智能的核心技术之一,它使计算机能够从数据中学习,而无需明确编程。"
"深度学习是机器学习的一个子集,使用人工神经网络来模拟人脑的工作方式。"
"自然语言处理(NLP)是AI的另一个重要领域,专注于使计算机能够理解和生成人类语言。",
"区块链技术是一种分布式账本技术,通过密码学方法确保数据的安全性和不可篡改性。"
"比特币是第一个成功应用区块链技术的加密货币,它解决了数字货币的双重支付问题。"
"以太坊是一个支持智能合约的区块链平台,允许开发者在其上构建去中心化应用(DApps)。"
"智能合约是自动执行的合约,其条款直接写入代码中,无需第三方中介。",
"量子计算是一种利用量子力学现象进行计算的新兴技术,具有巨大的计算潜力。"
"量子比特(qubit)是量子计算的基本单位,与经典比特不同,它可以同时处于0和1的叠加态。"
"量子纠缠是量子计算的关键特性,允许量子比特之间建立特殊的关联关系。"
"量子计算在密码学、药物发现和优化问题等领域具有潜在的应用前景。",
]
# 向向量库中添加文档及元数据
vector_store.add_texts(
documents,
metadatas=[
{"topic": "人工智能", "category": "科技"},
{"topic": "区块链", "category": "科技"},
{"topic": "量子计算", "category": "科技"},
]
)
# 指定一个复杂问题作为用户查询
query = "人工智能和机器学习的关系是什么?深度学习在其中的作用如何?它们在实际应用中有哪些典型案例?"
# 运行RAG流程并获取结果
result = rag_system.generate_answer(query, k=3)
# 打印分隔线
print("="*60)
print("最终结果")
print("="*60)
# 打印用户查询
print(f"\n用户查询: {result['query']}")
# 打印子问题列表及数量
print(f"\n分解的子问题 ({len(result['sub_questions'])} 个):")
for i, sq in enumerate(result['sub_questions'], 1):
print(f" {i}. {sq}")
# 打印生成的最终答案
print(f"\n生成的答案:\n{result['answer']}")
# 打印检索到的文档数量
print(f"\n检索到的文档数量: {result['num_documents']}")
# 打印部分检索到的文档的细节信息
print(f"\n检索到的文档详情:")
for i, doc in enumerate(result['retrieved_documents'][:3], 1):
print(f"\n文档 {i} (分数: {doc.metadata.get('score', 'N/A'):.4f}):")
print(f" 来源子问题: {doc.metadata.get('sub_question', 'N/A')[:60]}...")
print(f" 内容: {doc.page_content[:150]}...")
3.2 执行流程 #
3.2.1 核心思想 #
子问题分解采用“分而治之”策略:
- 将复杂问题拆解为多个独立的子问题
- 对每个子问题分别进行向量检索
- 合并所有子问题的检索结果
- 对结果去重并按分数排序
- 基于所有子问题的检索结果生成最终答案
3.2.2 执行流程 #
阶段一:初始化
# 1. 获取向量存储实例
vector_store = get_vector_store(
persist_directory="chroma_db",
collection_name="sub_question_decomposition"
)
# 2. 创建子问题分解检索器
retriever = SubQuestionDecompositionRetriever(
vector_store=vector_store,
llm=llm,
k=3, # 每个子问题检索3个文档
min_sub_questions=3, # 最少3个子问题
max_sub_questions=5 # 最多5个子问题
)
# 3. 创建RAG系统
rag_system = SubQuestionDecompositionRAG(
retriever=retriever,
llm=llm
)初始化时:
- 创建向量存储实例
- 创建检索器,配置 LLM、向量库、k 值和子问题数量范围
- 创建 RAG 系统,用于完整流程
阶段二:文档索引
# 添加文档到向量库
vector_store.add_texts(
documents,
metadatas=[...]
)索引过程:
- 将示例文档添加到向量库
- 每个文档附带元数据(topic、category)
- 文档会被向量化并存储
阶段三:查询处理
query = "人工智能和机器学习的关系是什么?深度学习在其中的作用如何?它们在实际应用中有哪些典型案例?"
result = rag_system.generate_answer(query, k=3)完整流程:
- 用户提交复杂查询
- 子问题分解:
- 调用
decompose_question(query) - 使用 LLM 将复杂问题拆解为 3-5 个子问题
- 解析并验证子问题数量
- 调用
- 子问题检索:
- 调用
retrieve_for_sub_questions(sqs, k) - 对每个子问题分别进行向量检索
- 为每个检索结果添加子问题元数据
- 调用
- 去重合并:
- 调用
deduplicate_and_merge(all_docs) - 按文档内容去重,保留最佳分数
- 按分数排序
- 调用
- 生成答案:
- 整合所有检索文档作为上下文
- 提取所有子问题并格式化
- 使用提示词模板构建最终 prompt
- LLM 生成答案
- 返回结果:
- 包含查询、答案、子问题列表、检索文档和数量
3.3 类图 #
3.4 时序图 #
3.4.1 完整RAG流程时序图 #
3.4.2 子问题分解详细流程 #
3.4.3 子问题检索与去重详细流程 #
3.5 关键设计要点 #
1. 子问题分解流程
复杂查询
↓
LLM分解为子问题 (3-5个)
↓
解析并验证子问题
↓
对每个子问题分别检索
↓
合并所有检索结果
↓
去重并排序
↓
返回最终文档列表2. 子问题分解示例
原始查询: "人工智能和机器学习的关系是什么?深度学习在其中的作用如何?它们在实际应用中有哪些典型案例?"
分解后的子问题:
1. 人工智能和机器学习的关系是什么?
2. 深度学习在人工智能中的作用如何?
3. 机器学习和深度学习在实际应用中有哪些典型案例?
4. 人工智能、机器学习和深度学习之间的层次关系是什么?
5. 这些技术在实际应用中的具体案例有哪些?3. 去重策略
- 去重键:使用文档内容(
page_content)作为唯一标识 - 分数选择:如果同一文档被多个子问题检索到,保留最佳分数(距离最小)
- 排序:去重后按分数升序排序,最相关的文档在前
示例:
子问题1检索到: [docA(score=0.2), docB(score=0.3), docC(score=0.4)]
子问题2检索到: [docA(score=0.1), docD(score=0.3), docE(score=0.5)]
去重后: [docA(score=0.1), docB(score=0.3), docC(score=0.4), docD(score=0.3), docE(score=0.5)]4. 元数据设计
检索返回的 Document 对象包含:
score:相似度分数(距离,越小越相似)sub_question:来源子问题sub_question_index:子问题索引(1, 2, 3...)- 继承原始文档的元数据(topic、category 等)
5. 答案生成模板
默认答案生成模板:
已知以下相关信息:
{context}
原始问题:{query}
该问题被分解为以下子问题:
{sub_questions}
请根据上述信息,全面回答原始问题。确保覆盖所有子问题的答案。
答案:该模板:
- 包含所有检索文档作为上下文
- 明确列出所有子问题
- 要求覆盖所有子问题的答案
3.6. 与其他方法的对比 #
| 特性 | 传统检索 | 子问题分解 |
|---|---|---|
| 查询处理 | 直接使用原始查询 | 拆解为多个子问题 |
| 检索方式 | 单次检索 | 多次检索(每个子问题) |
| 结果处理 | 直接返回 | 去重合并 |
| 适用场景 | 简单查询 | 复杂多部分查询 |
| 覆盖度 | 可能遗漏 | 更全面覆盖 |
3.7. 优势与应用场景 #
优势:
- 全面覆盖:通过子问题确保覆盖查询的各个方面
- 提高精度:每个子问题独立检索,更精准匹配
- 处理复杂查询:适合多部分、多角度的复杂问题
- 结果去重:自动处理重复文档,保留最佳分数
适用场景:
- 复杂问题:包含多个子问题的复合查询
- 多角度查询:需要从不同角度回答的问题
- 综合分析:需要综合多个方面信息的查询
- 研究场景:需要全面覆盖相关知识的场景
3.8. 技术细节 #
- 子问题数量控制:
- 最少
min_sub_questions个(默认3) - 最多
max_sub_questions个(默认5) - 如果少于最小值,回退为原始查询
- 最少
- 子问题解析:
- 去除编号前缀(如 "1. "、"1、"、"- ")
- 只保留包含问号的行
- 去重算法:
- 使用字典以内容为键
- 保留最佳分数(距离最小)
该设计通过“分而治之”策略,在复杂查询场景下提供更全面的检索覆盖,适用于需要多角度、多部分回答的 RAG 应用。
4. 多角度查询重写(Query Rewriting) #
传统的RAG(Retrieval-Augmented Generation,检索增强生成)流程,直接用用户的自然语言查询去做embedding检索,其效果很依赖用户的表述方式:同一个查询如果表达方式不同,底层向量检索的命中内容可能有很大区别,导致召回信息不完整、答案片面甚至遗漏关键信息。尤其是
- 用户表达简略,未穷举出所有意图的措辞
- 查询存在同义/多义,或有不同专业表述(例如「人工智能」VS「AI」VS「智能系统」)
- 用户问题本身较模糊或抽象
为此,可以通过查询重写(Query Rewriting)技术,将用户原始问题自动改写为多个表述不同、风格多样的查询版本,从而用多角度、丰富的语义视角去覆盖底层知识库的潜在高相关片段,大幅提升检索的召回率和结果多样性,实现信息“组团”式召回与补全。
步骤说明与关键技术点如下:
多版本查询生成
- 调用LLM,输入一个用户查询,让大模型帮我们生成3-5个不同措辞、风格或视角的查询表达。
要求每个版本都“忠于核心意图”但具备表述差异,可以参考如下Prompt引导:
请将以下查询改写为3-5个不同表达方式的查询版本。 每个查询版本应该: 1. 保持原始查询的核心意图 2. 使用不同的措辞和表达方式 3. 可以从不同角度(正式、通俗、专业、简洁等)表达 4. 确保所有版本都能检索到相关信息 原始查询:{query} 请列出查询版本(每行一个):
多版本批量检索
- 针对上述每一个改写后的查询,分别进行embedding后在向量库中检索K个高相关文档。
- 可为每条检索结果标注其对应的“查询版本”来源,便于后续分析。
去重合并(Deduplicate & Merge)
- 很多时候,不同的查询版本可能检索到内容大致相同的文本,因此需要对结果按内容hash或embedding聚类做去重,只保留分数最高最相关的若干条。
答案生成
- 可以汇总所有检索到的文档内容,附带各自的来源查询版本,统一拼接成上下文,再用LLM生成高质量响应答案。
优势总结
- 极大提升了召回率和检索多样性,最大化覆盖用户表达中的隐含意图和边界情况
- 对于表达不精准、领域同义词多、知识表述丰富的任务有巨大效果(如医学、法务、科技百科等)
- 还能为后续调优改写prompt与分析检索“盲区”提供基础
典型适用场景:
- 用户写的问题比较通用/宽泛/有歧义时
- 知识库表述丰富,内容多样,传统直检检索经常漏掉重要信息时
- 希望提高长尾召回、提升答案全面性时
方法小结:
- Query Rewriting ≈ 多视角“扩展式”检索 ≈ 类似于“查询增强(Query Expansion)”+“多表达聚合”
- 自动化结合大模型和向量检索,可插拔适配各种RAG/检索框架
4.1 QueryRewriting.py #
# 导入类型注解所需的相关类型
from typing import List, Dict, Any, Optional
# 导入pydantic配置工具
from pydantic import ConfigDict
# 导入基础检索器基类
from langchain_core.retrievers import BaseRetriever
# 导入文档类型
from langchain_core.documents import Document
# 导入LLM基础类型
from langchain_core.language_models import BaseLanguageModel
# 导入提示词模板
from langchain_core.prompts import PromptTemplate
# 导入自定义llm对象
from llm import llm
# 导入获取向量库函数
from vector_store import get_vector_store
# 导入正则模块
import re
# 定义查询重写检索器类,继承自BaseRetriever
class QueryRewritingRetriever(BaseRetriever):
# 向量库对象
vector_store: Any
# 大语言模型对象
llm: BaseLanguageModel
# 检索每个查询版本的文档数量
k: int = 4
# 查询版本最少数量
min_versions: int = 3
# 查询版本最多数量
max_versions: int = 5
# 可选的查询重写prompt模板字符串
rewriting_prompt_template_str: Optional[str] = None
# 配置模型参数,允许任意类型
model_config = ConfigDict(arbitrary_types_allowed=True)
# 获取查询重写prompt模板的方法
def _get_rewriting_prompt_template(self) -> PromptTemplate:
# 获取或默认生成用于查询重写的Prompt模板字符串
template = self.rewriting_prompt_template_str or (
"请将以下查询改写为3-5个不同表达方式的查询版本。\n"
"每个查询版本应该:\n"
"1. 保持原始查询的核心意图\n"
"2. 使用不同的措辞和表达方式\n"
"3. 可以从不同角度(正式、通俗、专业、简洁等)表达\n"
"4. 确保所有版本都能检索到相关信息\n\n"
"原始查询:{query}\n请列出查询版本(每行一个):"
)
# 返回PromptTemplate对象
return PromptTemplate(input_variables=["query"], template=template)
# 调用LLM生成多个不同表达方式的查询版本
def rewrite_query(self, query: str) -> List[str]:
# 构建重写prompt
prompt = self._get_rewriting_prompt_template().format(query=query)
# 调用LLM进行内容生成
resp = self.llm.invoke(prompt).content.strip()
rewritten_queries = []
# 遍历LLM返回的每一行,解析查询版本文本
for line in resp.split('\n'):
# 去除前缀编号等符号
line = re.sub(r'^\d+[\.、]\s*|^[-•]\s*', '', line.strip())
# 保留非空行作为查询版本
if line:
rewritten_queries.append(line)
# 限制查询版本数量不超过最大值
rewritten_queries = rewritten_queries[:self.max_versions]
# 若数量少于最小,至少包含原始查询
if len(rewritten_queries) < self.min_versions:
if query not in rewritten_queries:
rewritten_queries.insert(0, query)
# 返回查询版本列表
return rewritten_queries
# 针对查询版本列表依次检索相似文档
def retrieve_for_queries(self, queries: List[str], k: int) -> List[Document]:
all_docs = []
# 遍历每个查询版本及下标
for i, query_version in enumerate(queries, 1):
# 对当前查询版本进行向量相似性检索
docs = self.vector_store.similarity_search_with_score(query_version, k=k)
# 将检索结果统一封装为Document对象,并增加查询版本相关信息
for doc, distance in docs:
all_docs.append(Document(
page_content=doc.page_content,
metadata={**doc.metadata, "score": float(distance), "query_version": query_version, "version_index": i}
))
# 返回所有结果列表
return all_docs
# 对检索到的文档内容进行去重与归并
def deduplicate_and_merge(self, all_docs: List[Document]) -> List[Document]:
unique = {}
# 遍历所有文档,使用内容做去重键,保留分数更小者
for doc in all_docs:
c = doc.page_content
s = doc.metadata.get("score", float('inf'))
if c not in unique or s < unique[c].metadata.get("score", float('inf')):
unique[c] = doc
# 对去重后的文档按分数升序排序
docs = sorted(unique.values(), key=lambda x: x.metadata.get("score", float('inf')))
# 返回排序后的文档列表
return docs
# 主入口:根据查询获取相关文档流程
def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
# 获取本次检索的k值(可动态传递)
k = kwargs.get("k", self.k)
# 查询重写:生成多个查询版本
query_versions = self.rewrite_query(query)
# 针对查询版本检索
docs = self.retrieve_for_queries(query_versions, k)
# 去重合并并返回
return self.deduplicate_and_merge(docs)
# 定义查询重写RAG主流程类
class QueryRewritingRAG:
# 初始化,注入检索器和LLM,支持自定义答案生成Prompt
def __init__(self, retriever: QueryRewritingRetriever, llm: BaseLanguageModel, answer_prompt_template: Optional[str]=None):
self.retriever = retriever
self.llm = llm
# 设置答案生成Prompt模板,支持自定义或默认
template = answer_prompt_template or (
"已知以下相关信息:\n\n{context}\n\n"
"原始查询:{query}\n\n"
"该查询被重写为以下版本:\n{query_versions}\n\n"
"请根据上述信息,准确全面地回答原始查询。\n\n答案:"
)
# 构建最终答案生成的PromptTemplate对象
self.answer_prompt_template = PromptTemplate(
input_variables=["context", "query", "query_versions"], template=template
)
# 输入查询及可选k,返回答案及检索相关信息
def generate_answer(self, query: str, k: int = 4) -> Dict[str, Any]:
# 检索相关文档
docs = self.retriever._get_relevant_documents(query, k=k)
# 汇总所有查询版本(去重)
query_versions = list(set([doc.metadata.get("query_version","") for doc in docs if doc.metadata.get("query_version")]))
# 查询版本格式化
query_versions_text = "\n".join([f"{i+1}. {qv}" for i, qv in enumerate(query_versions)])
# 合并上下文内容
context = "\n\n".join([f"文档 {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
# 构建答案生成Prompt
prompt = self.answer_prompt_template.format(context=context, query=query, query_versions=query_versions_text)
# 调用LLM生成答案
answer = self.llm.invoke(prompt).content.strip()
# 整理并返回完整流程结果
return {
"query": query,
"answer": answer,
"query_versions": query_versions,
"retrieved_documents": docs,
"num_documents": len(docs)
}
# 初始化向量库对象
vector_store = get_vector_store(
persist_directory="chroma_db",
collection_name="query_rewriting"
)
# 初始化查询重写检索器
retriever = QueryRewritingRetriever(
vector_store=vector_store,
llm=llm,
k=3,
min_versions=3,
max_versions=5
)
# 初始化RAG系统
rag_system = QueryRewritingRAG(
retriever=retriever,
llm=llm
)
# 定义示例文档内容列表
documents = [
"人工智能(AI)是计算机科学的一个分支,旨在创建能够执行通常需要人类智能的任务的系统。"
"机器学习是人工智能的核心技术之一,它使计算机能够从数据中学习,而无需明确编程。"
"深度学习是机器学习的一个子集,使用人工神经网络来模拟人脑的工作方式。"
"自然语言处理(NLP)是AI的另一个重要领域,专注于使计算机能够理解和生成人类语言。",
"区块链技术是一种分布式账本技术,通过密码学方法确保数据的安全性和不可篡改性。"
"比特币是第一个成功应用区块链技术的加密货币,它解决了数字货币的双重支付问题。"
"以太坊是一个支持智能合约的区块链平台,允许开发者在其上构建去中心化应用(DApps)。"
"智能合约是自动执行的合约,其条款直接写入代码中,无需第三方中介。",
"量子计算是一种利用量子力学现象进行计算的新兴技术,具有巨大的计算潜力。"
"量子比特(qubit)是量子计算的基本单位,与经典比特不同,它可以同时处于0和1的叠加态。"
"量子纠缠是量子计算的关键特性,允许量子比特之间建立特殊的关联关系。"
"量子计算在密码学、药物发现和优化问题等领域具有潜在的应用前景。",
]
# 向向量库中添加文档及元数据
vector_store.add_texts(
documents,
metadatas=[
{"topic": "人工智能", "category": "科技"},
{"topic": "区块链", "category": "科技"},
{"topic": "量子计算", "category": "科技"},
]
)
# 指定一个查询作为用户输入
query = "机器学习是什么?"
# 运行RAG流程并获取结果
result = rag_system.generate_answer(query, k=3)
# 打印分隔线
print("="*60)
print("最终结果")
print("="*60)
# 打印用户查询
print(f"\n用户查询: {result['query']}")
# 打印查询版本列表及数量
print(f"\n生成的查询版本 ({len(result['query_versions'])} 个):")
for i, qv in enumerate(result['query_versions'], 1):
print(f" {i}. {qv}")
# 打印生成的最终答案
print(f"\n生成的答案:\n{result['answer']}")
# 打印检索到的文档数量
print(f"\n检索到的文档数量: {result['num_documents']}")
# 打印部分检索到的文档的细节信息
print(f"\n检索到的文档详情:")
for i, doc in enumerate(result['retrieved_documents'][:3], 1):
print(f"\n文档 {i} (分数: {doc.metadata.get('score', 'N/A'):.4f}):")
print(f" 来源查询版本: {doc.metadata.get('query_version', 'N/A')[:60]}...")
print(f" 内容: {doc.page_content[:150]}...")4.2 执行过程 #
4.2.1 核心思想 #
查询重写采用“多版本查询”策略:
- 将用户查询改写为多个不同表达方式的查询版本
- 每个版本保持核心意图,但使用不同措辞和角度
- 对每个查询版本分别进行向量检索
- 合并所有查询版本的检索结果
- 对结果去重并按分数排序
- 基于所有查询版本的检索结果生成最终答案
4.2.2 执行流程 #
阶段一:初始化
# 1. 获取向量存储实例
vector_store = get_vector_store(
persist_directory="chroma_db",
collection_name="query_rewriting"
)
# 2. 创建查询重写检索器
retriever = QueryRewritingRetriever(
vector_store=vector_store,
llm=llm,
k=3, # 每个查询版本检索3个文档
min_versions=3, # 最少3个查询版本
max_versions=5 # 最多5个查询版本
)
# 3. 创建RAG系统
rag_system = QueryRewritingRAG(
retriever=retriever,
llm=llm
)初始化时:
- 创建向量存储实例
- 创建检索器,配置 LLM、向量库、k 值和查询版本数量范围
- 创建 RAG 系统,用于完整流程
阶段二:文档索引
# 添加文档到向量库
vector_store.add_texts(
documents,
metadatas=[...]
)索引过程:
- 将示例文档添加到向量库
- 每个文档附带元数据(topic、category)
- 文档会被向量化并存储
阶段三:查询处理
query = "机器学习是什么?"
result = rag_system.generate_answer(query, k=3)完整流程:
- 用户提交查询
- 查询重写:
- 调用
rewrite_query(query) - 使用 LLM 将查询改写为 3-5 个不同表达方式的版本
- 解析并验证查询版本数量
- 调用
- 多版本检索:
- 调用
retrieve_for_queries(query_versions, k) - 对每个查询版本分别进行向量检索
- 为每个检索结果添加查询版本元数据
- 调用
- 去重合并:
- 调用
deduplicate_and_merge(all_docs) - 按文档内容去重,保留最佳分数
- 按分数排序
- 调用
- 生成答案:
- 整合所有检索文档作为上下文
- 提取所有查询版本并格式化
- 使用提示词模板构建最终 prompt
- LLM 生成答案
- 返回结果:
- 包含查询、答案、查询版本列表、检索文档和数量
4.2.3 类图 #
4.2.4 时序图 #
4.2.4.1 完整RAG流程时序图 #
4.2.4.2 查询重写详细流程 #
4.2.4.3 多版本检索与去重详细流程 #
4.2.5 关键设计要点 #
1. 查询重写流程
用户查询
↓
LLM改写为多个版本 (3-5个)
↓
解析并验证查询版本
↓
对每个查询版本分别检索
↓
合并所有检索结果
↓
去重并排序
↓
返回最终文档列表2. 查询重写示例
原始查询: "机器学习是什么?"
重写后的查询版本:
1. 机器学习的定义是什么?
2. 什么是机器学习技术?
3. 机器学习的概念和原理
4. 如何理解机器学习?
5. 机器学习的基本含义不同表达方式:
- 正式表达:"机器学习的定义是什么?"
- 通俗表达:"什么是机器学习技术?"
- 专业表达:"机器学习的概念和原理"
- 简洁表达:"如何理解机器学习?"
- 基础表达:"机器学习的基本含义"
3. 去重策略
- 去重键:使用文档内容(
page_content)作为唯一标识 - 分数选择:如果同一文档被多个查询版本检索到,保留最佳分数(距离最小)
- 排序:去重后按分数升序排序,最相关的文档在前
示例:
查询版本1检索到: [docA(score=0.2), docB(score=0.3), docC(score=0.4)]
查询版本2检索到: [docA(score=0.15), docD(score=0.3), docE(score=0.5)]
查询版本3检索到: [docB(score=0.25), docF(score=0.35), docG(score=0.45)]
去重后: [docA(score=0.15), docB(score=0.25), docC(score=0.4), docD(score=0.3), docE(score=0.5), docF(score=0.35), docG(score=0.45)]4. 元数据设计
检索返回的 Document 对象包含:
score:相似度分数(距离,越小越相似)query_version:来源查询版本version_index:查询版本索引(1, 2, 3...)- 继承原始文档的元数据(topic、category 等)
5. 答案生成模板
默认答案生成模板:
已知以下相关信息:
{context}
原始查询:{query}
该查询被重写为以下版本:
{query_versions}
请根据上述信息,准确全面地回答原始查询。
答案:该模板:
- 包含所有检索文档作为上下文
- 明确列出所有查询版本
- 要求准确全面地回答原始查询
4.2.6. 与子问题分解的对比 #
| 特性 | 子问题分解 | 查询重写 |
|---|---|---|
| 处理方式 | 拆解为子问题 | 改写为不同表达方式 |
| 核心意图 | 覆盖不同方面 | 保持核心意图不变 |
| 表达方式 | 不同角度的问题 | 不同措辞和角度 |
| 适用场景 | 复杂多部分查询 | 单一查询的不同表达 |
| 检索策略 | 每个子问题独立检索 | 每个版本独立检索 |
4.2.7. 优势与应用场景 #
优势:
- 提高召回率:不同表达方式可能匹配到不同文档
- 处理表达多样性:适应不同的查询表达习惯
- 增强鲁棒性:减少单一表达方式可能带来的遗漏
- 结果去重:自动处理重复文档,保留最佳分数
适用场景:
- 用户表达多样性:不同用户可能用不同方式表达相同意图
- 术语变体:同一概念的不同术语表达
- 语言风格差异:正式、通俗、专业等不同风格
- 跨语言场景:需要处理不同语言表达方式
4.2.8. 技术细节 #
- 查询版本数量控制:
- 最少
min_versions个(默认3) - 最多
max_versions个(默认5) - 如果少于最小值,确保包含原始查询
- 最少
- 查询版本解析:
- 去除编号前缀(如 "1. "、"1、"、"- ")
- 保留所有非空行作为查询版本
- 去重算法:
- 使用字典以内容为键
- 保留最佳分数(距离最小)
4.2.9. 查询重写的要求 #
根据提示词模板,每个查询版本应该:
- 保持原始查询的核心意图
- 使用不同的措辞和表达方式
- 可以从不同角度(正式、通俗、专业、简洁等)表达
- 确保所有版本都能检索到相关信息
该设计通过“多版本查询”策略,在单一查询场景下提供更全面的检索覆盖,适用于需要处理表达多样性的 RAG 应用。
5. 抽象化查询转换(Take a Step Back) #
抽象化查询转换(Take a Step Back)是一种通过引导大语言模型(LLM)将用户提出的具体问题转化为更高层次、更加抽象且通用的问题,从而实现“广谱”信息检索的技术手段。这种方法的核心思想是“退一步”,不拘泥于问题的细节表述,而是挖掘出用户背后更广泛的核心诉求,进而覆盖更多相关内容,避免仅因提问限制而遗漏潜在高相关知识。
主要动机与应用场景
- 用户问题常常包含大量细节或上下文假设,造成检索空间受限。
- 知识库中很多内容用更抽象或泛化的方式表达,直接用具体问题检索可能无法命中这些内容。
- 支持“举一反三”型场景,用户关切的并不仅仅是问题本身,更关心相关原理、通用方法、背景知识。
比如,用户问:
- “如何使用Python的scikit-learn库训练一个支持向量机模型来分类鸢尾花数据集?” 抽象化后可转换为:
- “如何利用机器学习算法进行分类任务?”
- “支持向量机模型在分类问题中的应用方法是什么?”
这样,可以检索到更全面的泛化知识、典型流程和相关概念,丰富后续答案的广度与深度。
技术流程拆解
抽象化转换
- 针对用户原始查询,构造Prompt引导大模型去“去细节、保主旨”,生成一个更高层次、覆盖更广的抽象化问题。
常用Prompt示例:
请将以下具体问题转化为更高层次的抽象问题。 抽象化问题应该: 1. 去除具体细节,保留核心概念和意图 2. 使用更通用的术语和表达方式 3. 能够匹配更广泛的相关文档 4. 保持与原始问题的语义关联 具体问题:{query} 抽象化问题:- 得到抽象查询后,通常与原始查询一同组合用于检索。
双路混合检索(Abstract + Concrete)
- 对抽象化查询检索较多候选文档(广覆盖)。
- 对原始具体查询检索精确相关文档(高匹配)。
- 合并两路检索结果,内容去重、按相关性打分筛选。
答案生成/合成
- 整理上述所有检索到的文档内容,合成rich context。
- LLM生成答案时,参考原始查询+抽象查询,提升准确性与覆盖面。
方法优势与适用性
- 提升召回能力: 避免“细节漏检”,显著增加知识检索广度。
- 支持模糊查询: 用户问题不清晰或缺失背景时,自动补全潜在上位主题。
- 增强泛化表达: 能发现教材/论文/百科等资料通用讲法,补足典型套路与场景案例。
- 极适用于: 教育问答、技术原理解读、流程方法通用性归纳、案例“举一反三”等领域。
方法小结
- Take a Step Back(抽象化查询) = “拿掉细节、提升视角、扩展召回”。
- 本质是引入更泛化的表达,使知识检索“拉宽一圈”,与具体查询并用,最大化信息检索和答案生成的全面性及鲁棒性。
5.1 TakeAStepBack.py #
# 导入类型注解所需的相关类型
from typing import List, Dict, Any, Optional
# 导入pydantic配置工具
from pydantic import ConfigDict
# 导入基础检索器基类
from langchain_core.retrievers import BaseRetriever
# 导入文档类型
from langchain_core.documents import Document
# 导入LLM基础类型
from langchain_core.language_models import BaseLanguageModel
# 导入提示词模板
from langchain_core.prompts import PromptTemplate
# 导入自定义llm对象
from llm import llm
# 导入获取向量库函数
from vector_store import get_vector_store
# 定义抽象化查询转换检索器类,继承自BaseRetriever
class TakeAStepBackRetriever(BaseRetriever):
# 向量库对象
vector_store: Any
# 大语言模型对象
llm: BaseLanguageModel
# 使用抽象化查询进行广泛检索的结果数量
abstract_k: int = 5
# 使用原始查询进行精确检索的结果数量
concrete_k: int = 3
# 可选的抽象化转换prompt模板字符串
abstraction_prompt_template_str: Optional[str] = None
# 配置模型参数,允许任意类型
model_config = ConfigDict(arbitrary_types_allowed=True)
# 获取抽象化转换prompt模板的方法
def _get_abstraction_prompt_template(self) -> PromptTemplate:
# 获取或默认生成用于抽象化转换的Prompt模板字符串
template = self.abstraction_prompt_template_str or (
"请将以下具体问题转化为更高层次的抽象问题。\n"
"抽象化问题应该:\n"
"1. 去除具体细节,保留核心概念和意图\n"
"2. 使用更通用的术语和表达方式\n"
"3. 能够匹配更广泛的相关文档\n"
"4. 保持与原始问题的语义关联\n\n"
"具体问题:{query}\n\n"
"抽象化问题:"
)
# 返回PromptTemplate对象
return PromptTemplate(input_variables=["query"], template=template)
# 调用LLM将具体问题转化为抽象问题
def abstract_query(self, query: str) -> str:
# 构建抽象化prompt
prompt = self._get_abstraction_prompt_template().format(query=query)
# 调用LLM进行内容生成
abstract_query = self.llm.invoke(prompt).content.strip()
# 返回抽象化查询
return abstract_query
# 使用抽象化查询进行广泛检索
def retrieve_with_abstract_query(self, abstract_query: str, k: Optional[int] = None) -> List[Document]:
# 如果未提供k参数,则使用实例默认值
k = k if k is not None else self.abstract_k
# 对抽象化查询进行向量相似性检索,获取更多结果
docs = self.vector_store.similarity_search_with_score(abstract_query, k=k)
# 将检索结果统一封装为Document对象,并标记为抽象化检索结果
abstract_docs = []
for doc, distance in docs:
abstract_docs.append(Document(
page_content=doc.page_content,
metadata={**doc.metadata, "score": float(distance), "retrieval_type": "abstract", "query": abstract_query}
))
# 返回抽象化检索结果
return abstract_docs
# 使用原始查询进行精确检索
def retrieve_with_concrete_query(self, concrete_query: str, k: Optional[int] = None) -> List[Document]:
# 如果未提供k参数,则使用实例默认值
k = k if k is not None else self.concrete_k
# 对原始查询进行向量相似性检索,获取精确结果
docs = self.vector_store.similarity_search_with_score(concrete_query, k=k)
# 将检索结果统一封装为Document对象,并标记为精确检索结果
concrete_docs = []
for doc, distance in docs:
concrete_docs.append(Document(
page_content=doc.page_content,
metadata={**doc.metadata, "score": float(distance), "retrieval_type": "concrete", "query": concrete_query}
))
# 返回精确检索结果
return concrete_docs
# 合并检索结果并去重
def merge_and_deduplicate(self, abstract_docs: List[Document], concrete_docs: List[Document]) -> List[Document]:
# 合并所有检索结果
all_docs = abstract_docs + concrete_docs
unique = {}
# 遍历所有文档,使用内容做去重键,保留分数更小者
for doc in all_docs:
c = doc.page_content
s = doc.metadata.get("score", float('inf'))
# 如果内容未见过,或当前分数更好,则更新
if c not in unique or s < unique[c].metadata.get("score", float('inf')):
unique[c] = doc
# 对去重后的文档按分数升序排序
docs = sorted(unique.values(), key=lambda x: x.metadata.get("score", float('inf')))
# 返回排序后的文档列表
return docs
# 主入口:根据查询获取相关文档流程
def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
# 获取可选的检索数量参数(如果提供则覆盖默认值)
abstract_k = kwargs.get("abstract_k", self.abstract_k)
concrete_k = kwargs.get("concrete_k", self.concrete_k)
# 获取可选的抽象化查询(如果提供则直接使用,避免重复生成)
abstract_query = kwargs.get("abstract_query")
if abstract_query is None:
# 抽象化转换:将具体问题转化为抽象问题
abstract_query = self.abstract_query(query)
# 混合检索:使用抽象化查询进行广泛检索,使用原始查询进行精确检索
# 直接传递检索数量参数,无需临时修改实例属性
abstract_docs = self.retrieve_with_abstract_query(abstract_query, k=abstract_k)
concrete_docs = self.retrieve_with_concrete_query(query, k=concrete_k)
# 结果合并:合并检索结果并去重
final_docs = self.merge_and_deduplicate(abstract_docs, concrete_docs)
# 返回最终文档列表
return final_docs
# 定义抽象化查询转换RAG主流程类
class TakeAStepBackRAG:
# 初始化,注入检索器和LLM,支持自定义答案生成Prompt
def __init__(self, retriever: TakeAStepBackRetriever, llm: BaseLanguageModel, answer_prompt_template: Optional[str]=None):
self.retriever = retriever
self.llm = llm
# 设置答案生成Prompt模板,支持自定义或默认
template = answer_prompt_template or (
"已知以下相关信息:\n\n{context}\n\n"
"原始查询:{query}\n"
"抽象化查询:{abstract_query}\n\n"
"请根据上述信息,准确全面地回答原始查询。\n\n答案:"
)
# 构建最终答案生成的PromptTemplate对象
self.answer_prompt_template = PromptTemplate(
input_variables=["context", "query", "abstract_query"], template=template
)
# 输入查询及可选检索数量,返回答案及检索相关信息
def generate_answer(self, query: str, abstract_k: int = 5, concrete_k: int = 3) -> Dict[str, Any]:
# 先生成抽象化查询
abstract_query = self.retriever.abstract_query(query)
# 检索相关文档(包含抽象化和精确检索),传入抽象化查询避免重复生成
docs = self.retriever._get_relevant_documents(query, abstract_k=abstract_k, concrete_k=concrete_k, abstract_query=abstract_query)
# 合并上下文内容
context = "\n\n".join([f"文档 {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
# 构建答案生成Prompt
prompt = self.answer_prompt_template.format(context=context, query=query, abstract_query=abstract_query)
# 调用LLM生成答案
answer = self.llm.invoke(prompt).content.strip()
# 统计检索类型
abstract_count = sum(1 for doc in docs if doc.metadata.get("retrieval_type") == "abstract")
concrete_count = sum(1 for doc in docs if doc.metadata.get("retrieval_type") == "concrete")
# 整理并返回完整流程结果
return {
"query": query,
"abstract_query": abstract_query,
"answer": answer,
"retrieved_documents": docs,
"num_documents": len(docs),
"abstract_count": abstract_count,
"concrete_count": concrete_count
}
# 初始化向量库对象
vector_store = get_vector_store(
persist_directory="chroma_db",
collection_name="take_a_step_back"
)
# 初始化抽象化查询转换检索器
retriever = TakeAStepBackRetriever(
vector_store=vector_store,
llm=llm,
abstract_k=5,
concrete_k=3
)
# 初始化RAG系统
rag_system = TakeAStepBackRAG(
retriever=retriever,
llm=llm
)
# 定义示例文档内容列表
documents = [
"人工智能(AI)是计算机科学的一个分支,旨在创建能够执行通常需要人类智能的任务的系统。"
"机器学习是人工智能的核心技术之一,它使计算机能够从数据中学习,而无需明确编程。"
"深度学习是机器学习的一个子集,使用人工神经网络来模拟人脑的工作方式。"
"自然语言处理(NLP)是AI的另一个重要领域,专注于使计算机能够理解和生成人类语言。",
"区块链技术是一种分布式账本技术,通过密码学方法确保数据的安全性和不可篡改性。"
"比特币是第一个成功应用区块链技术的加密货币,它解决了数字货币的双重支付问题。"
"以太坊是一个支持智能合约的区块链平台,允许开发者在其上构建去中心化应用(DApps)。"
"智能合约是自动执行的合约,其条款直接写入代码中,无需第三方中介。",
"量子计算是一种利用量子力学现象进行计算的新兴技术,具有巨大的计算潜力。"
"量子比特(qubit)是量子计算的基本单位,与经典比特不同,它可以同时处于0和1的叠加态。"
"量子纠缠是量子计算的关键特性,允许量子比特之间建立特殊的关联关系。"
"量子计算在密码学、药物发现和优化问题等领域具有潜在的应用前景。",
]
# 向向量库中添加文档及元数据
vector_store.add_texts(
documents,
metadatas=[
{"topic": "人工智能", "category": "科技"},
{"topic": "区块链", "category": "科技"},
{"topic": "量子计算", "category": "科技"},
]
)
# 指定一个具体查询作为用户输入
query = "如何使用Python的scikit-learn库训练一个支持向量机模型来分类鸢尾花数据集?"
# 运行RAG流程并获取结果
result = rag_system.generate_answer(query, abstract_k=5, concrete_k=3)
# 打印分隔线
print("="*60)
print("最终结果")
print("="*60)
# 打印用户查询
print(f"\n原始查询: {result['query']}")
# 打印抽象化查询
print(f"\n抽象化查询: {result['abstract_query']}")
# 打印生成的最终答案
print(f"\n生成的答案:\n{result['answer']}")
# 打印检索统计信息
print(f"\n检索统计:")
print(f" 总文档数: {result['num_documents']}")
print(f" 抽象化检索: {result['abstract_count']} 个")
print(f" 精确检索: {result['concrete_count']} 个")
# 打印部分检索到的文档的细节信息
print(f"\n检索到的文档详情:")
for i, doc in enumerate(result['retrieved_documents'][:5], 1):
retrieval_type = doc.metadata.get("retrieval_type", "unknown")
print(f"\n文档 {i} (分数: {doc.metadata.get('score', 'N/A'):.4f}, 类型: {retrieval_type}):")
print(f" 内容: {doc.page_content[:150]}...")5.2 执行过程 #
5.2.1 核心思想 #
Take a Step Back 采用“抽象+精确”的双重检索策略:
- 将具体问题转化为更高层次的抽象问题
- 使用抽象化查询进行广泛检索(获取更多相关文档)
- 使用原始查询进行精确检索(获取精确匹配的文档)
- 合并两种检索结果并去重
- 基于合并结果生成最终答案
5.2.2 执行流程 #
阶段一:初始化
# 1. 获取向量存储实例
vector_store = get_vector_store(
persist_directory="chroma_db",
collection_name="take_a_step_back"
)
# 2. 创建Take a Step Back检索器
retriever = TakeAStepBackRetriever(
vector_store=vector_store,
llm=llm,
abstract_k=5, # 抽象化检索返回5个文档
concrete_k=3 # 精确检索返回3个文档
)
# 3. 创建RAG系统
rag_system = TakeAStepBackRAG(
retriever=retriever,
llm=llm
)初始化时:
- 创建向量存储实例
- 创建检索器,配置 LLM、向量库、抽象化检索数量(abstract_k)和精确检索数量(concrete_k)
- 创建 RAG 系统,用于完整流程
阶段二:文档索引
# 添加文档到向量库
vector_store.add_texts(
documents,
metadatas=[...]
)索引过程:
- 将示例文档添加到向量库
- 每个文档附带元数据(topic、category)
- 文档会被向量化并存储
阶段三:查询处理
query = "如何使用Python的scikit-learn库训练一个支持向量机模型来分类鸢尾花数据集?"
result = rag_system.generate_answer(query, abstract_k=5, concrete_k=3)完整流程:
- 用户提交具体查询
- 抽象化转换:
- 调用
abstract_query(query) - 使用 LLM 将具体问题转化为抽象问题
- 去除具体细节,保留核心概念
- 调用
- 双重检索:
- 使用抽象化查询进行广泛检索(abstract_k=5)
- 使用原始查询进行精确检索(concrete_k=3)
- 结果合并:
- 调用
merge_and_deduplicate(abstract_docs, concrete_docs) - 合并两种检索结果
- 按文档内容去重,保留最佳分数
- 按分数排序
- 调用
- 生成答案:
- 整合所有检索文档作为上下文
- 使用提示词模板构建最终 prompt(包含原始查询和抽象化查询)
- LLM 生成答案
- 返回结果:
- 包含查询、抽象化查询、答案、检索文档、统计信息等
5.2.3 类图 #
5.2.4 时序图 #
5.2.4.1 完整RAG流程时序图 #
5.2.4.2 抽象化转换详细流程 #
5.2.4.3 双重检索与合并详细流程 #
5.2.5 关键设计要点 #
1. 抽象化转换流程
具体查询
↓
LLM抽象化转换
↓
抽象化查询 (去除细节,保留核心概念)
↓
双重检索:
- 抽象化查询 → 广泛检索 (abstract_k=5)
- 原始查询 → 精确检索 (concrete_k=3)
↓
合并结果并去重
↓
返回最终文档列表2. 抽象化转换示例
原始查询: "如何使用Python的scikit-learn库训练一个支持向量机模型来分类鸢尾花数据集?"
抽象化查询: "如何使用机器学习算法进行分类任务?"
或
抽象化查询: "机器学习分类模型的训练方法"抽象化要求:
- 去除具体细节:Python、scikit-learn、支持向量机、鸢尾花数据集
- 保留核心概念:机器学习、分类、训练
- 使用通用术语:算法、模型、方法
- 匹配更广泛文档:能匹配到机器学习相关的所有文档
3. 双重检索策略
- 抽象化检索(广泛检索):
- 使用抽象化查询
- 返回更多文档(abstract_k=5)
- 匹配相关概念和原理
- 提供背景知识和理论基础
- 精确检索(精确匹配):
- 使用原始查询
- 返回精确文档(concrete_k=3)
- 匹配具体实现和细节
- 提供直接相关的答案
4. 去重策略
- 去重键:使用文档内容(
page_content)作为唯一标识 - 分数选择:如果同一文档被两种检索方式都检索到,保留最佳分数(距离最小)
- 排序:去重后按分数升序排序,最相关的文档在前
示例:
抽象化检索结果: [docA(score=0.2), docB(score=0.3), docC(score=0.4), docD(score=0.5), docE(score=0.6)]
精确检索结果: [docA(score=0.15), docF(score=0.25), docG(score=0.35)]
合并去重后: [docA(score=0.15), docF(score=0.25), docB(score=0.3), docG(score=0.35), docC(score=0.4), docD(score=0.5), docE(score=0.6)]5. 元数据设计
检索返回的 Document 对象包含:
score:相似度分数(距离,越小越相似)retrieval_type:检索类型("abstract" 或 "concrete")query:使用的查询(抽象化查询或原始查询)- 继承原始文档的元数据(topic、category 等)
6. 答案生成模板
默认答案生成模板:
已知以下相关信息:
{context}
原始查询:{query}
抽象化查询:{abstract_query}
请根据上述信息,准确全面地回答原始查询。
答案:该模板:
- 包含所有检索文档作为上下文
- 明确列出原始查询和抽象化查询
- 要求准确全面地回答原始查询
5.2.6 与其他方法的对比 #
| 特性 | 传统检索 | Take a Step Back |
|---|---|---|
| 查询处理 | 直接使用原始查询 | 抽象化转换 + 原始查询 |
| 检索方式 | 单次检索 | 双重检索(广泛+精确) |
| 覆盖范围 | 单一匹配 | 广泛+精确双重覆盖 |
| 适用场景 | 简单查询 | 具体但需要背景知识的查询 |
| 优势 | 简单直接 | 兼顾广泛性和精确性 |
5.2.7 优势与应用场景 #
优势:
- 广泛覆盖:抽象化检索提供背景知识和理论基础
- 精确匹配:精确检索提供直接相关的答案
- 平衡策略:兼顾广泛性和精确性
- 结果去重:自动处理重复文档,保留最佳分数
适用场景:
- 具体技术问题:需要具体实现和背景知识
- 学习场景:需要理论背景和实际应用
- 研究场景:需要概念理解和具体案例
- 复杂查询:需要多层次信息的查询
5.2.8 技术细节 #
- 检索数量控制:
- 抽象化检索:
abstract_k(默认5) - 精确检索:
concrete_k(默认3) - 可通过参数动态调整
- 抽象化检索:
- 抽象化查询缓存:
- 在
_get_relevant_documents中支持传入abstract_query - 避免重复生成抽象化查询
- 在
- 去重算法:
- 使用字典以内容为键
- 保留最佳分数(距离最小)
5.2.9 检索统计 #
返回结果包含统计信息:
abstract_count:来自抽象化检索的文档数量concrete_count:来自精确检索的文档数量num_documents:总文档数量(去重后)
该设计通过“抽象+精确”的双重检索策略,在具体查询场景下提供更全面的检索覆盖,适用于需要背景知识和具体答案的 RAG 应用。
6. 检索-生成一体化(Retrieval-generation integration) #
传统RAG(Retrieval-Augmented Generation)体系通常将“检索”与“生成”划分为独立模块:先对查询进行信息检索,再将检索到的文档输入大语言模型生成答案。这种串联模式虽然结构清晰,但也存在以下潜在不足:
- 检索与生成的分界较为生硬,检索阶段难以充分感知生成任务的具体要求。
- 生成模型无法灵活动态地影响检索策略与范围。
- 在复杂问题和多轮对话场景下,检索结果的相关性和准确性难以保障。
为了解决上述问题,近年来兴起了“检索-生成一体化”思想——即通过深度融合检索与生成过程,实现二者的协同优化。其核心目标在于让生成模型动态、主动参与到文档检索阶段,例如:
- 让大模型根据当前生成任务多轮拟定检索子查询,分阶段检索并综合更多相关信息。
- 检索结果与生成阶段通过奖励机制、指令调优(instruction tuning)、强化学习等方式联合训练,增强全流程对齐能力。
- 构建端到端可微分检索生成链路,实现前馈式联合优化。
典型的一体化实现包括:
- Query Routing: 生成模型自动识别查询意图,有针对性地选择检索策略(如结构化、非结构化、混合型文档库)。
- 多轮增量式检索-生成:生成模型先粗略生成答案→识别缺失信息→动态追加检索→迭代完善答案。
- 检索-生成共设计(Co-design):如指令调优的融合检索生成Agent,将信息查询和答案生成步骤交错建模。
这一趋势使得RAG系统更加智能、灵活,能更好地处理开放域、复杂任务与场景,如多跳推理、事实查证、科技问答及对话系统。
6.1 RetrievalGenerationIntegration.py #
# 导入List、Dict、Any、Optional类型用于类型注解
from typing import List, Dict, Any, Optional
# 导入pydantic用于配置数据模型
from pydantic import ConfigDict
# 导入基础检索器基类
from langchain_core.retrievers import BaseRetriever
# 导入文档对象类型
from langchain_core.documents import Document
# 导入基础LLM类型
from langchain_core.language_models import BaseLanguageModel
# 导入提示词模板类
from langchain_core.prompts import PromptTemplate
# 导入自定义llm对象
from llm import llm
# 导入获取向量存储的函数
from vector_store import get_vector_store
# 定义意图分类器类
class IntentClassifier:
"""意图分类器:识别用户问题的类型"""
# 构造函数,初始化llm对象和意图类别列表
def __init__(self, llm: BaseLanguageModel, intent_categories: List[str]):
self.llm = llm
self.intent_categories = intent_categories
# 定义意图分类任务的prompt模板
self.classification_template = PromptTemplate(
input_variables=["query", "categories"],
template="请判断以下用户查询属于哪个类别。\n"
"可用类别:{categories}\n\n"
"用户查询:{query}\n\n"
"请只返回类别名称,不要返回其他内容:"
)
# 对单条用户查询进行意图分类
def classify(self, query: str) -> str:
"""对查询进行意图分类"""
# 用中文顿号拼接所有意图类别
categories_str = "、".join(self.intent_categories)
# 填充prompt生成具体任务
prompt = self.classification_template.format(query=query, categories=categories_str)
# 调用llm得到结果
response = self.llm.invoke(prompt).content.strip()
# 检查返回的类别是否在目标类别中
for category in self.intent_categories:
if category in response:
return category
# 默认返回第一个类别
return self.intent_categories[0]
# 定义查询重写器类
class QueryRewriter:
"""查询重写器:将口语化问题转换为标准术语"""
# 构造,接收llm对象和意图到重写策略的映射
def __init__(self, llm: BaseLanguageModel, rewriting_strategies: Dict[str, str]):
self.llm = llm
self.rewriting_strategies = rewriting_strategies
# 根据意图重写查询内容
def rewrite(self, query: str, intent: str) -> str:
"""根据意图重写查询"""
# 获取指定意图的重写提示词,没有则用默认
strategy = self.rewriting_strategies.get(intent, "请将以下查询转换为标准术语进行检索:{query}")
# 根据策略格式化生成prompt
if "{query}" in strategy:
prompt = strategy.format(query=query)
else:
prompt = f"{strategy}\n\n用户查询:{query}\n\n标准术语查询:"
# 使用llm调用重写查询
rewritten = self.llm.invoke(prompt).content.strip()
# 对重写为空的情况做降级处理
return rewritten if rewritten else query
# 定义知识库管理器类
class KnowledgeBaseManager:
"""知识库管理器:管理多个意图对应的知识库"""
# 构造函数,初始化知识库存储路径和知识库字典
def __init__(self, persist_directory: str):
self.persist_directory = persist_directory
self.knowledge_bases: Dict[str, Any] = {}
# 获取指定意图和集合名的知识库实例(不存在则自动创建)
def get_knowledge_base(self, intent: str, collection_name: str) -> Any:
"""获取或创建指定意图的知识库"""
# 如果知识库不存在则创建
if intent not in self.knowledge_bases:
self.knowledge_bases[intent] = get_vector_store(
persist_directory=self.persist_directory,
collection_name=collection_name
)
return self.knowledge_bases[intent]
# 向指定知识库添加文档列表及元数据
def add_documents(self, intent: str, collection_name: str, texts: List[str], metadatas: List[Dict] = None):
"""向指定意图的知识库添加文档"""
kb = self.get_knowledge_base(intent, collection_name)
kb.add_texts(texts, metadatas=metadatas or [{}] * len(texts))
# 定义检索生成一体化检索器
class RetrievalGenerationIntegrationRetriever(BaseRetriever):
"""检索-生成一体化检索器"""
# 定义知识库管理器类型
kb_manager: KnowledgeBaseManager
# 定义意图分类器类型
intent_classifier: IntentClassifier
# 定义查询重写器类型
query_rewriter: QueryRewriter
# 意图到集合名的映射表
intent_to_collection: Dict[str, str]
# 默认检索文档数量
k: int = 4
# pydantic模型配置,允许任意类型字段
model_config = ConfigDict(arbitrary_types_allowed=True)
# 获取与查询相关的文档
def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
"""一体化检索流程"""
# 优先使用外部传入的k,否则用默认
k = kwargs.get("k", self.k)
# 步骤1: 意图识别
intent = self.intent_classifier.classify(query)
# 步骤2: 查询重写
rewritten_query = self.query_rewriter.rewrite(query, intent)
# 步骤3: 检索对应知识库
collection_name = self.intent_to_collection.get(intent, "default")
kb = self.kb_manager.get_knowledge_base(intent, collection_name)
docs = kb.similarity_search_with_score(rewritten_query, k=k)
# 步骤4: 文档结果组装成Document对象,加入相关元数据
result_docs = []
for doc, distance in docs:
result_docs.append(Document(
page_content=doc.page_content,
metadata={
**doc.metadata,
"score": float(distance),
"intent": intent,
"original_query": query,
"rewritten_query": rewritten_query
}
))
return result_docs
# 定义检索生成一体化RAG系统
class RetrievalGenerationIntegrationRAG:
"""检索-生成一体化RAG系统"""
# 构造方法,初始化retriever与llm,并配置答案生成prompt
def __init__(self, retriever: RetrievalGenerationIntegrationRetriever, llm: BaseLanguageModel, answer_prompt_template: Optional[str] = None):
self.retriever = retriever
self.llm = llm
# 若无自定义模板则使用默认
template = answer_prompt_template or (
"已知以下相关信息:\n\n{context}\n\n"
"原始查询:{query}\n"
"识别意图:{intent}\n"
"重写查询:{rewritten_query}\n\n"
"请根据上述信息,准确全面地回答原始查询。\n\n答案:"
)
# 初始化生成答案所需的prompt模板
self.answer_prompt_template = PromptTemplate(
input_variables=["context", "query", "intent", "rewritten_query"], template=template
)
# 对用户查询生成最终答案
def generate_answer(self, query: str, k: int = 4) -> Dict[str, Any]:
"""生成答案"""
# 步骤1: 检索相关文档
docs = self.retriever._get_relevant_documents(query, k=k)
# 步骤2: 提取相关元数据
intent = docs[0].metadata.get("intent", "unknown") if docs else "unknown"
rewritten_query = docs[0].metadata.get("rewritten_query", query) if docs else query
# 步骤3: 组装检索到的文档内容作为上下文
context = "\n\n".join([f"文档 {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
# 步骤4: 构建答案生成prompt
prompt = self.answer_prompt_template.format(
context=context,
query=query,
intent=intent,
rewritten_query=rewritten_query
)
# 步骤5: 调用llm生成答案
answer = self.llm.invoke(prompt).content.strip()
# 汇总所有输出信息
return {
"query": query,
"intent": intent,
"rewritten_query": rewritten_query,
"answer": answer,
"retrieved_documents": docs,
"num_documents": len(docs)
}
# 初始化各种组件配置参数
# 定义意图类别(如人事政策场景)
intent_categories = ["假期政策", "调岗流程", "薪资福利", "其他"]
# 定义针对不同意图的查询重写策略
rewriting_strategies = {
"假期政策": "请从以下查询中提取假期类型(如年假、病假、产假、婚假等)的标准术语,然后生成检索查询。\n用户查询:{query}\n标准术语查询:",
"调岗流程": "请从以下查询中提取调岗相关的标准术语和流程关键词,然后生成检索查询。\n用户查询:{query}\n标准术语查询:",
"薪资福利": "请从以下查询中提取薪资福利相关的标准术语,然后生成检索查询。\n用户查询:{query}\n标准术语查询:",
"其他": "请将以下查询转换为标准术语进行检索:{query}"
}
# 定义意图到向量集合名的映射
intent_to_collection = {
"假期政策": "leave_policy",
"调岗流程": "transfer_process",
"薪资福利": "salary_benefits",
"其他": "default"
}
# 创建意图分类器对象
intent_classifier = IntentClassifier(llm, intent_categories)
# 创建查询重写器对象
query_rewriter = QueryRewriter(llm, rewriting_strategies)
# 创建知识库管理器对象
kb_manager = KnowledgeBaseManager("chroma_db")
# 按依赖注入创建一体化检索器
retriever = RetrievalGenerationIntegrationRetriever(
kb_manager=kb_manager,
intent_classifier=intent_classifier,
query_rewriter=query_rewriter,
intent_to_collection=intent_to_collection,
k=4
)
# 创建一体化RAG主系统
rag_system = RetrievalGenerationIntegrationRAG(retriever=retriever, llm=llm)
# 示例:准备“假期政策”知识库文档
leave_policy_docs = [
"年假政策:员工工作满一年后可享受年假,年假天数根据工作年限递增。",
"病假政策:员工因病需要请假时,需提供医院证明,病假期间工资按相关规定发放。",
"产假政策:女性员工生育可享受产假,产假期间工资照发,具体天数根据国家规定执行。",
"婚假政策:员工结婚可申请婚假,婚假天数根据公司规定执行,通常为3-7天。"
]
# 示例:准备“调岗流程”知识库文档
transfer_process_docs = [
"调岗申请流程:员工需填写调岗申请表,经部门主管审批后提交人力资源部。",
"调岗审批流程:人力资源部审核调岗申请,评估员工能力和岗位匹配度,做出审批决定。",
"调岗执行流程:调岗获批后,人力资源部通知相关部门,安排员工到新岗位报到。"
]
# 示例:准备“薪资福利”知识库文档
salary_benefits_docs = [
"薪资结构:员工薪资由基本工资、绩效奖金、津贴等部分组成。",
"福利待遇:公司为员工提供五险一金、带薪年假、节日福利等。",
"绩效奖金:根据员工工作表现和公司业绩,发放季度和年度绩效奖金。"
]
# 各意图知识库分别添加相关文档及元数据
print("正在初始化知识库...")
kb_manager.add_documents("假期政策", "leave_policy", leave_policy_docs,
[{"topic": "假期政策", "type": "年假"}] * 1 +
[{"topic": "假期政策", "type": "病假"}] * 1 +
[{"topic": "假期政策", "type": "产假"}] * 1 +
[{"topic": "假期政策", "type": "婚假"}] * 1)
kb_manager.add_documents("调岗流程", "transfer_process", transfer_process_docs,
[{"topic": "调岗流程"}] * 3)
kb_manager.add_documents("薪资福利", "salary_benefits", salary_benefits_docs,
[{"topic": "薪资福利"}] * 3)
print("知识库初始化完成!\n")
# 口语化用户查询示例列表
queries = [
"我下个月要结婚,能请几天假?", # 应该识别为"假期政策",重写为"婚假"
"家里人生病了,公司有相关假期吗?", # 应该识别为"假期政策",重写为"病假"
"我想换个部门工作,需要走什么流程?" # 应该识别为"调岗流程"
]
# 测试每个查询,打印RAG系统的完整回答和检索信息
for query in queries:
print("="*60)
print(f"用户查询: {query}")
print("="*60)
result = rag_system.generate_answer(query, k=3)
print(f"\n识别意图: {result['intent']}")
print(f"重写查询: {result['rewritten_query']}")
print(f"\n生成答案:\n{result['answer']}")
print(f"\n检索到 {result['num_documents']} 个文档")
print("\n" + "-"*60 + "\n")
6.2 执行过程 #
6.2.1 核心思想 #
检索-生成一体化采用“意图驱动”的多知识库策略:
- 意图识别:识别用户问题的类型(如假期政策、调岗流程、薪资福利等)
- 查询重写:根据意图将口语化问题转换为标准术语
- 知识库路由:根据意图选择对应的知识库
- 检索:在对应的知识库中检索相关文档
- 生成答案:基于检索结果生成最终答案
6.2.2 执行流程 #
阶段一:初始化
# 1. 定义意图类别
intent_categories = ["假期政策", "调岗流程", "薪资福利", "其他"]
# 2. 定义查询重写策略
rewriting_strategies = {
"假期政策": "...",
"调岗流程": "...",
"薪资福利": "...",
"其他": "..."
}
# 3. 定义意图到集合名的映射
intent_to_collection = {
"假期政策": "leave_policy",
"调岗流程": "transfer_process",
"薪资福利": "salary_benefits",
"其他": "default"
}
# 4. 创建各个组件
intent_classifier = IntentClassifier(llm, intent_categories)
query_rewriter = QueryRewriter(llm, rewriting_strategies)
kb_manager = KnowledgeBaseManager("chroma_db")
# 5. 创建一体化检索器
retriever = RetrievalGenerationIntegrationRetriever(...)
# 6. 创建RAG系统
rag_system = RetrievalGenerationIntegrationRAG(retriever=retriever, llm=llm)初始化时:
- 定义意图类别和重写策略
- 创建意图分类器、查询重写器、知识库管理器
- 创建一体化检索器和RAG系统
阶段二:知识库初始化
# 为不同意图的知识库添加文档
kb_manager.add_documents("假期政策", "leave_policy", leave_policy_docs, ...)
kb_manager.add_documents("调岗流程", "transfer_process", transfer_process_docs, ...)
kb_manager.add_documents("薪资福利", "salary_benefits", salary_benefits_docs, ...)知识库初始化:
- 为每个意图创建独立的知识库
- 添加对应的文档和元数据
- 每个知识库使用不同的集合名
阶段三:查询处理
query = "我下个月要结婚,能请几天假?"
result = rag_system.generate_answer(query, k=3)完整流程:
- 用户提交口语化查询
- 意图识别:
- 调用
intent_classifier.classify(query) - 使用 LLM 识别查询的意图类别
- 调用
- 查询重写:
- 调用
query_rewriter.rewrite(query, intent) - 根据意图使用对应的重写策略
- 将口语化问题转换为标准术语
- 调用
- 知识库路由:
- 根据意图获取对应的集合名
- 获取或创建对应的知识库实例
- 检索:
- 在对应的知识库中使用重写后的查询进行检索
- 返回 top-k 个相关文档
- 生成答案:
- 整合检索文档作为上下文
- 使用提示词模板构建最终 prompt(包含原始查询、意图、重写查询)
- LLM 生成答案
- 返回结果:
- 包含查询、意图、重写查询、答案、检索文档等
6.2.3 类图 #
6.2.4 时序图 #
6.2.4.1 完整RAG流程时序图 #
6.2.4.2 意图识别详细流程 #
6.2.4.3 查询重写详细流程 #
6.2.4.4 知识库路由与检索详细流程 #
6.2.5 关键设计要点 #
1. 一体化流程
用户查询 (口语化)
↓
意图识别 → 意图类别
↓
查询重写 → 标准术语查询
↓
知识库路由 → 选择对应知识库
↓
检索 → 在对应知识库中检索
↓
生成答案 → 基于检索结果生成2. 意图识别示例
用户查询: "我下个月要结婚,能请几天假?"
↓
意图识别: "假期政策"
↓
查询重写: "婚假"
↓
知识库路由: "leave_policy"
↓
检索结果: 婚假政策相关文档3. 多知识库架构
- 每个意图对应一个独立的知识库
- 使用不同的集合名(collection_name)区分
- 知识库管理器统一管理
- 按需创建,避免重复初始化
知识库结构:
chroma_db/
├── leave_policy/ (假期政策知识库)
├── transfer_process/ (调岗流程知识库)
├── salary_benefits/ (薪资福利知识库)
└── default/ (其他知识库)4. 查询重写策略
不同意图使用不同的重写策略:
- 假期政策:提取假期类型(年假、病假、产假、婚假等)
- 调岗流程:提取调岗相关的标准术语和流程关键词
- 薪资福利:提取薪资福利相关的标准术语
- 其他:通用转换策略
示例:
原始查询: "我下个月要结婚,能请几天假?"
意图: "假期政策"
重写策略: "请从以下查询中提取假期类型..."
重写结果: "婚假"5. 元数据设计
检索返回的 Document 对象包含:
score:相似度分数(距离,越小越相似)intent:识别的意图类别original_query:原始用户查询rewritten_query:重写后的查询- 继承原始文档的元数据(topic、type 等)
6. 答案生成模板
默认答案生成模板:
已知以下相关信息:
{context}
原始查询:{query}
识别意图:{intent}
重写查询:{rewritten_query}
请根据上述信息,准确全面地回答原始查询。
答案:该模板:
- 包含所有检索文档作为上下文
- 明确列出原始查询、意图和重写查询
- 要求准确全面地回答原始查询
6.2.6 优势与应用场景 #
优势:
- 意图驱动:根据意图选择对应的知识库,提高检索精度
- 查询优化:将口语化问题转换为标准术语,提高匹配率
- 知识库隔离:不同意图使用独立知识库,避免干扰
- 可扩展性:易于添加新的意图类别和知识库
适用场景:
- 企业知识库:HR政策、流程文档等
- 多领域问答:不同领域使用不同知识库
- 专业场景:需要精确匹配特定领域知识
- 口语化查询:需要将自然语言转换为标准术语
6.2.7 技术细节 #
- 意图识别:
- 使用 LLM 进行分类
- 检查返回文本是否包含目标类别
- 默认返回第一个类别
- 查询重写:
- 根据意图使用对应的重写策略
- 支持自定义策略模板
- 重写失败时降级为原始查询
- 知识库管理:
- 使用字典缓存已创建的知识库
- 按需创建,避免重复初始化
- 支持动态添加文档
该设计通过“意图驱动+多知识库”策略,在特定领域场景下提供更精确的检索和答案生成,适用于需要意图识别和知识库路由的 RAG 应用。
7. 上下文对话(Contextual Dialogue) #
在许多真实场景中,用户的查询并不是一次性的,而是涉及多轮、上下文相关的对话。例如,用户第一次询问“我手机坏了怎么办?”,接下来又问“保修期还没过,怎么处理?”——后者的理解依赖于前文信息。传统检索流程对每次输入都“独立处理”,难以充分利用上下文已知的信息来精准理解当前问题、缩小检索范围或者消歧。
上下文对话(Contextual Dialogue)系统的主要任务,是结合对话历史,智能解析和补全当前用户的意图。例如:
指代消解
当用户问“它还能返厂维修吗?”时,系统应自动分析出“它”指代前文提到的具体产品,并据此生成准确的检索或回答。缺省信息补全
如果用户前面已经说明产品型号、事件等,后面的问句出现省略(如直接问“换电池多少钱?”),系统可自动补充上下文条件,避免重复追问。多轮追问/比较
对于“那与隔壁品牌的比,价格贵吗?”、“两者维修周期哪个短?”等对话,需要系统抓取对比对象,理解推理链路,综合分析。
为提升RAG系统在多轮对话场景下的表现,常见技术实践包括:
- 对话历史记录与回溯:为每位用户维护一份完整的会话轮次记录,并将历史摘要与当前输入共同作为检索和生成的条件。
- 指代消解与实体跟踪:使用NLP技术自动判断“他/它/他们/这种/这里/那个”等指代词,结合历史找到确切指向。
- 上下文增强的检索重写:对当前用户输入,通过LLM或专用重写组件,将含糊/不全的提问扩展为明确、完整的检索查询。
- 多轮信息整合生成:生成阶段基于综合的历史+检索片段,回答不仅准确,还符合连续对话语境。
举例说明:
用户A:
- “我买了你们的Alpha智能手表,但好像坏了。”
- “是去年买的。”
- “保修期还有多久?”
上述第3问,需要算法自动整合用户第1轮的产品信息和第2轮的购买时间,形成完整的条件(如“2023年购买的Alpha智能手表保修期还有多久”),再进行知识库检索和答案生成。
用户B:
- “我在看Beta蓝牙耳机和Gamma智能音箱,想了解一下售后政策。”
- “主要是保修时间。”
- “哪个产品保修期更长?”
第3问属于“对比型”,系统要抓出比较对象并准确回答。
应用难点与要点:
- 上下文解析的准确性直接影响检索质量,特别是在指代、歧义、信息不全时。
- 在RAG流程中,应让检索步骤和生成步骤都能访问到“上下文融合”的综合查询。
- 实际工程中,LangChain等框架提供了类似“ConversationMemory”、“ChatPromptTemplate”等用于上下文对话处理的工具。
综上,支持上下文对话是RAG从“检索-生成”到“智能陪伴体”的关键一步,是企业知识助手、客户服务机器人等场景不可或缺的技术模块。下面将展示一个典型的多轮上下文QA对话的代码样例。
7.1 ContextualDialogue.py #
# 导入List、Dict、Any、Optional类型用于类型注解
from typing import List, Dict, Any, Optional
# 导入pydantic用于配置数据模型
from pydantic import ConfigDict
# 导入基础检索器基类
from langchain_core.retrievers import BaseRetriever
# 导入文档对象类型
from langchain_core.documents import Document
# 导入基础LLM类型
from langchain_core.language_models import BaseLanguageModel
# 导入提示词模板类
from langchain_core.prompts import PromptTemplate
# 导入自定义llm对象
from llm import llm
# 导入获取向量存储的函数
from vector_store import get_vector_store
# 导入正则模块用于解析query字符串
import re
# 定义对话历史记录管理器类
class DialogueHistory:
"""对话历史管理器:管理多轮对话记录"""
# 初始化函数,创建空的历史列表
def __init__(self):
self.history: List[Dict[str, str]] = []
# 添加一轮对话,参数为角色和内容
def add_turn(self, role: str, content: str):
"""添加一轮对话"""
self.history.append({"role": role, "content": content})
# 获取对话历史记录的文本串,格式为"角色:内容"
def get_history_text(self) -> str:
"""获取对话历史的文本表示"""
lines = []
# 遍历每一轮对话,拼接字符串
for turn in self.history:
lines.append(f"{turn['role']}:{turn['content']}")
# 用换行符拼接所有轮次
return "\n".join(lines)
# 清空对话历史
def clear(self):
"""清空对话历史"""
self.history = []
# 定义Query归纳Agent类
class QueryGenerationAgent:
"""Agent:分析对话历史和最新问题,生成检索query列表"""
# 初始化函数,配置llm和模板
def __init__(self, llm: BaseLanguageModel, query_generation_template: Optional[str] = None):
# 保存llm对象
self.llm = llm
# 若未指定模板则用默认模板
template = query_generation_template or (
"请根据以下多轮对话内容和用户最新问题,智能归纳出1到多条适合知识库检索的query。\n"
"要求:\n"
"1. 结合历史对话补全关键信息(如产品名称、时间等)\n"
"2. 识别模糊指代(如'它们'、'这个'等)的具体对象\n"
"3. 对于对比型问题,为每个对象生成独立的query\n"
"4. 对于多意图问题,为每个意图生成独立的query\n"
"5. 将反问型问题转换为正面询问\n"
"6. 将条件型问题转化为检索参数\n\n"
"历史对话:\n{history}\n\n"
"用户最新问题:{latest_question}\n\n"
"请输出所有应检索的query,每行一条,不要编号:"
)
# 创建提示词模板对象
self.query_generation_template = PromptTemplate(
input_variables=["history", "latest_question"], template=template
)
# 根据历史和最新提问生成query
def generate_queries(self, history: str, latest_question: str) -> List[str]:
"""根据对话历史和最新问题生成检索query列表"""
# 填充模板得到prompt
prompt = self.query_generation_template.format(history=history, latest_question=latest_question)
# 用llm生成结果,去除首尾空白
response = self.llm.invoke(prompt).content.strip()
# 解析生成的query字符串列表
queries = []
# 按行处理每个生成的查询
for line in response.split('\n'):
# 去除两端空白字符
line = line.strip()
# 移除行首编号(如"1. "、"1、"等)
line = re.sub(r'^\d+[\.、]\s*', '', line)
# 移除列表标记(如"- "、"• "等)
line = re.sub(r'^[-•]\s*', '', line)
# 若不为空则加入列表
if line:
queries.append(line)
# 如果没解析出结果,则返回原始问题
if not queries:
queries = [latest_question]
return queries
# 定义上下文对话检索器类
class ContextualDialogueRetriever(BaseRetriever):
"""上下文对话检索器:结合对话历史进行检索"""
# 定义成员:向量数据库对象
vector_store: Any
# 定义成员:Query归纳Agent
agent: QueryGenerationAgent
# 定义成员:默认返回文档数量
k: int = 4
# 配置pydantic模型,允许任意类型参数
model_config = ConfigDict(arbitrary_types_allowed=True)
# 根据输入问题和历史,返回相关文档列表
def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
"""根据查询和对话历史检索相关文档"""
# 获取k值,若未给则用默认
k = kwargs.get("k", self.k)
# 获取对话历史文本
history = kwargs.get("history", "")
# 若有历史,则生成queries,否则直接用query
if history:
queries = self.agent.generate_queries(history, query)
else:
queries = [query]
# 保存所有检索到的文档
all_docs = []
# 遍历每个query进行向量检索
for q in queries:
# 用向量库检索相似文档和分数
docs = self.vector_store.similarity_search_with_score(q, k=k)
# 遍历每一个文档和分数,组装为Document对象并记录meta信息
for doc, distance in docs:
all_docs.append(Document(
page_content=doc.page_content,
metadata={
**doc.metadata,
"score": float(distance),
"generated_query": q,
"original_query": query
}
))
# 去重:同一文档内容只保留分数最低(最相关)的那条
unique_docs = {}
for doc in all_docs:
content = doc.page_content
score = doc.metadata.get("score", float('inf'))
# 如果文档没见过,或更优分数,更新
if content not in unique_docs or score < unique_docs[content].metadata.get("score", float('inf')):
unique_docs[content] = doc
# 按分数升序排序,分数越低越相关
result_docs = sorted(unique_docs.values(), key=lambda x: x.metadata.get("score", float('inf')))
return result_docs
# 定义上下文对话RAG主系统类
class ContextualDialogueRAG:
"""上下文对话RAG系统:支持多轮对话的智能问答"""
# 初始化函数,注入检索器和llm,并可选注入答案模板
def __init__(self, retriever: ContextualDialogueRetriever, llm: BaseLanguageModel, answer_prompt_template: Optional[str] = None):
self.retriever = retriever
self.llm = llm
# 创建对话历史管理对象
self.dialogue_history = DialogueHistory()
# 设置回答prompt模板
template = answer_prompt_template or (
"已知以下相关信息:\n\n{context}\n\n"
"对话历史:\n{history}\n\n"
"用户最新问题:{query}\n\n"
"请根据上述信息,结合对话历史,准确全面地回答用户最新问题。\n\n答案:"
)
# 创建模板对象
self.answer_prompt_template = PromptTemplate(
input_variables=["context", "history", "query"], template=template
)
# 一轮对话主逻辑
def chat(self, user_query: str, k: int = 4) -> Dict[str, Any]:
"""处理一轮对话"""
# 获取历史对话文本
history_text = self.dialogue_history.get_history_text()
# 检索相关文档(结合历史)
docs = self.retriever._get_relevant_documents(user_query, k=k, history=history_text)
# 从文档里提取所有本轮生成的query(去重)
generated_queries = list(set([doc.metadata.get("generated_query", "") for doc in docs if doc.metadata.get("generated_query")]))
# 构建上下文字符串,每个文档单独展示
context = "\n\n".join([f"文档 {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
# 用模板生成prompt,注意若历史为空显示“无历史对话”
prompt = self.answer_prompt_template.format(
context=context,
history=history_text if history_text else "(无历史对话)",
query=user_query
)
# 用llm得到答案,去尾部空白
answer = self.llm.invoke(prompt).content.strip()
# 对话历史添加用户提问
self.dialogue_history.add_turn("用户", user_query)
# 对话历史添加AI回答
self.dialogue_history.add_turn("AI", answer)
# 返回本轮的全部信息
return {
"user_query": user_query,
"answer": answer,
"generated_queries": generated_queries,
"retrieved_documents": docs,
"num_documents": len(docs)
}
# 清空历史
def clear_history(self):
"""清空对话历史"""
self.dialogue_history.clear()
# ---------------- 以下为组件初始化和样例 ----------------
# 创建Query归纳Agent
agent = QueryGenerationAgent(llm)
# 创建向量数据库对象
vector_store = get_vector_store(
persist_directory="chroma_db",
collection_name="contextual_dialogue"
)
# 创建上下文对话检索器
retriever = ContextualDialogueRetriever(
vector_store=vector_store,
agent=agent,
k=3
)
# 创建上下文对话RAG系统
rag_system = ContextualDialogueRAG(retriever=retriever, llm=llm)
# 示例产品知识库文档
documents = [
"Alpha智能手表:2022年购买的产品,保修期为2年,支持无线充电,颜色有黑色和白色。",
"Beta蓝牙耳机:保修期为1年,支持无线充电,颜色有黑色、白色和红色。",
"Gamma智能音箱:保修期为3年,不支持无线充电,颜色有黑色和白色。",
"Delta笔记本电脑:保修期为2年,不支持无线充电,颜色有黑色和银色。",
"Epsilon平板电脑:保修期为1年,支持无线充电,颜色有黑色、白色和金色。",
"Zeta无线充电器:保修期为1年,颜色有黑色和白色,支持多种设备充电。"
]
# 索引文档到向量库
print("正在初始化知识库...")
vector_store.add_texts(
documents,
metadatas=[{"topic": "产品信息"}] * len(documents)
)
print("知识库初始化完成!\n")
# ------------------- 示例1:上下文依赖对话 -------------------
print("="*60)
print("示例1:上下文依赖型对话")
print("="*60)
# 清空历史
rag_system.clear_history()
# 第一轮---智能手表问题
result1 = rag_system.chat("我最近买了个智能手表。")
print(f"用户:{result1['user_query']}")
print(f"AI:{result1['answer']}\n")
# 第二轮---补充产品关键信息
result2 = rag_system.chat("是Alpha智能手表,去年买的。")
print(f"用户:{result2['user_query']}")
print(f"AI:{result2['answer']}\n")
# 第三轮---上下文依赖查询,提及保修期
result3 = rag_system.chat("保修期还有多久?")
print(f"用户:{result3['user_query']}")
print(f"生成的检索query:{result3['generated_queries']}")
print(f"AI:{result3['answer']}\n")
# ------------------- 示例2:对比型对话 -------------------
print("\n" + "="*60)
print("示例2:对比型对话")
print("="*60)
# 清空历史
rag_system.clear_history()
result1 = rag_system.chat("我在看Beta蓝牙耳机和Gamma智能音箱,想了解一下售后政策。")
print(f"用户:{result1['user_query']}")
print(f"AI:{result1['answer']}\n")
result2 = rag_system.chat("主要是保修时间。")
print(f"用户:{result2['user_query']}")
print(f"AI:{result2['answer']}\n")
result3 = rag_system.chat("哪个产品保修期更长?")
print(f"用户:{result3['user_query']}")
print(f"生成的检索query:{result3['generated_queries']}")
print(f"AI:{result3['answer']}\n")
# ------------------- 示例3:模糊指代对话 -------------------
print("\n" + "="*60)
print("示例3:模糊指代型对话")
print("="*60)
# 清空历史
rag_system.clear_history()
result1 = rag_system.chat("我买了你们家的笔记本电脑和平板电脑。")
print(f"用户:{result1['user_query']}")
print(f"AI:{result1['answer']}\n")
result2 = rag_system.chat("是Delta笔记本电脑和Epsilon平板电脑。")
print(f"用户:{result2['user_query']}")
print(f"AI:{result2['answer']}\n")
result3 = rag_system.chat("它们都支持无线充电吗?")
print(f"用户:{result3['user_query']}")
print(f"生成的检索query:{result3['generated_queries']}")
print(f"AI:{result3['answer']}\n")
7.2 执行过程 #
7.2.1 核心思想 #
上下文对话采用“历史感知+智能查询生成”策略:
- 对话历史管理:维护多轮对话记录
- 智能查询生成:根据对话历史和最新问题,使用 Agent 生成适合检索的 query 列表
- 多查询检索:对每个生成的 query 分别进行检索
- 结果合并:合并所有检索结果并去重
- 上下文生成:基于检索结果和对话历史生成答案
7.2.2 执行流程 #
阶段一:初始化
# 1. 创建Query归纳Agent
agent = QueryGenerationAgent(llm)
# 2. 创建向量存储实例
vector_store = get_vector_store(
persist_directory="chroma_db",
collection_name="contextual_dialogue"
)
# 3. 创建上下文对话检索器
retriever = ContextualDialogueRetriever(
vector_store=vector_store,
agent=agent,
k=3
)
# 4. 创建上下文对话RAG系统
rag_system = ContextualDialogueRAG(retriever=retriever, llm=llm)初始化时:
- 创建 Query 归纳 Agent,用于生成检索 query
- 创建向量存储实例
- 创建检索器,配置 Agent 和默认 k 值
- 创建 RAG 系统,内部包含对话历史管理器
阶段二:知识库初始化
# 添加文档到向量库
vector_store.add_texts(
documents,
metadatas=[...]
)索引过程:
- 将示例文档添加到向量库
- 每个文档附带元数据(topic)
- 文档会被向量化并存储
阶段三:多轮对话处理
# 第一轮对话
result1 = rag_system.chat("我最近买了个智能手表。")
# 第二轮对话
result2 = rag_system.chat("是Alpha智能手表,去年买的。")
# 第三轮对话(上下文依赖)
result3 = rag_system.chat("保修期还有多久?")完整流程(每轮对话):
- 用户提交查询
- 获取对话历史:
- 从对话历史管理器获取历史文本
- 智能查询生成:
- 如果有历史,调用
agent.generate_queries(history, query) - 使用 LLM 根据历史和最新问题生成检索 query 列表
- 如果没有历史,直接使用原始查询
- 如果有历史,调用
- 多查询检索:
- 对每个生成的 query 分别进行向量检索
- 为每个检索结果添加元数据(generated_query、original_query)
- 结果合并:
- 按文档内容去重,保留最佳分数
- 按分数排序
- 生成答案:
- 整合检索文档作为上下文
- 使用提示词模板构建最终 prompt(包含上下文、历史、查询)
- LLM 生成答案
- 更新历史:
- 将用户查询添加到历史
- 将 AI 回答添加到历史
- 返回结果:
- 包含用户查询、答案、生成的 query 列表、检索文档等
7.2.3 类图 #
7.2.4 时序图 #
7.2.4.1 完整多轮对话流程时序图 #
7.2.4.2 智能查询生成详细流程 #
7.2.4.3 多查询检索与合并详细流程 #
7.2.5 关键设计要点 #
1. 上下文对话流程
用户查询
↓
获取对话历史
↓
智能查询生成 (Agent)
↓
多查询检索
↓
结果合并去重
↓
生成答案 (结合历史)
↓
更新对话历史2. 智能查询生成示例
场景1:上下文依赖
历史对话:
用户:我最近买了个智能手表。
AI:...
最新问题: "保修期还有多久?"
生成的query: ["Alpha智能手表保修期"]场景2:对比型问题
历史对话:
用户:我在看Beta蓝牙耳机和Gamma智能音箱,想了解一下售后政策。
AI:...
最新问题: "哪个产品保修期更长?"
生成的query: ["Beta蓝牙耳机保修期", "Gamma智能音箱保修期"]场景3:模糊指代
历史对话:
用户:我买了你们家的笔记本电脑和平板电脑。
用户:是Delta笔记本电脑和Epsilon平板电脑。
AI:...
最新问题: "它们都支持无线充电吗?"
生成的query: ["Delta笔记本电脑无线充电", "Epsilon平板电脑无线充电"]3. 对话历史管理
- 历史格式:每轮对话以"角色:内容"格式存储
- 历史文本:用于查询生成和答案生成
- 自动更新:每轮对话后自动添加用户查询和AI回答
- 可清空:支持清空历史重新开始对话
历史示例:
用户:我最近买了个智能手表。
AI:很高兴听到您购买了智能手表...
用户:是Alpha智能手表,去年买的。
AI:Alpha智能手表是一款不错的产品...4. 查询生成要求
根据提示词模板,Agent需要:
- 结合历史补全关键信息(如产品名称、时间等)
- 识别模糊指代(如"它们"、"这个"等)的具体对象
- 对于对比型问题,为每个对象生成独立的query
- 对于多意图问题,为每个意图生成独立的query
- 将反问型问题转换为正面询问
- 将条件型问题转化为检索参数
5. 元数据设计
检索返回的 Document 对象包含:
score:相似度分数(距离,越小越相似)generated_query:生成的检索queryoriginal_query:原始用户查询- 继承原始文档的元数据(topic 等)
6. 答案生成模板
默认答案生成模板:
已知以下相关信息:
{context}
对话历史:
{history}
用户最新问题:{query}
请根据上述信息,结合对话历史,准确全面地回答用户最新问题。
答案:该模板:
- 包含所有检索文档作为上下文
- 包含完整的对话历史
- 要求结合对话历史回答最新问题
7. 去重策略
- 去重键:使用文档内容(
page_content)作为唯一标识 - 分数选择:如果同一文档被多个query检索到,保留最佳分数(距离最小)
- 排序:去重后按分数升序排序,最相关的文档在前
7.2.6 优势与应用场景 #
优势:
- 上下文感知:结合对话历史理解用户意图
- 智能查询生成:自动补全关键信息、识别指代
- 多查询检索:提高检索覆盖率和准确性
- 自然对话:支持多轮连续对话
适用场景:
- 客服系统:需要理解上下文的多轮对话
- 产品咨询:需要结合历史信息回答问题
- 技术支持:需要追踪问题背景
- 对话式搜索:需要理解用户意图的连续查询
7.2.7 技术细节 #
- 查询解析:
- 去除编号前缀(如 "1. "、"1、")
- 去除列表标记(如 "- "、"• ")
- 保留非空行作为query
- 降级处理:
- 如果未解析出任何query,使用原始问题
- 历史处理:
- 如果历史为空,显示"(无历史对话)"
该设计通过“历史感知+智能查询生成”策略,在多轮对话场景下提供更准确的检索和答案生成,适用于需要理解上下文和连续对话的 RAG 应用。
8. 行业场景改写(Industry scenario adaptation) #
行业场景适配(Industry scenario adaptation)是面向垂直行业RAG应用时的重要检索前优化手段。其核心思想是:针对不同行业(如教育、医疗、法律等),由于用户的提问习惯、行业术语表达方式、检索需求差异较大,在将用户query送入RAG检索链路之前,引入“行业适配器”对提问进行智能改写或结构化,从而大幅提升匹配的精准度和最终答案的专业性。
典型应用场景举例
教育行业
用户常用叙述性和生活化描述提出数学、物理等题目。适配器需从表面语境中抽取核心的数学关系,消除无关细节,使检索query精确指向题干本质。例如:
“一个建筑工人每天能挖4米深的洞,挖了6天后还差3米就完成了,这个洞总共有多深?”
适配后变为:“已知每天挖4米,挖6天后剩3米,求总长度。”医疗行业
用户多为口语化描述身体症状,适配器需将其转换为标准医疗术语和专业表达:
“我最近总是头疼,还有点恶心,这是什么病?”
改写后:“症状:头痛、恶心,常见可能原因有哪些?”法律行业
查询问题往往带有具体纠纷细节或情绪表述。适配器应抽取法律关系、要素,形成适合法规检索的结构化query:
“我在网上买东西,商家不发货也不退钱,我该怎么办?”
改写后:“网购交易,商家未履约且拒绝退款,涉及的法律条款与处理方式有哪些?”
技术机制与流程
行业场景管理器(IndustryScenarioManager)
负责管理多行业的适配器,每种适配器实现行业独有的query重写策略。支持灵活注册和获取。行业适配器(IndustryScenarioAdapter)
基类定义统一接口,不同行业可灵活继承实现。内部常用LLM+PromptTemplate设计,可根据行业特点设计定制化Prompt。RAG检索前优化
当有用户query和行业类型输入时,行业场景管理器自动将query交给对应的适配器进行改写,再送入后续检索和生成模块,使召回和理解更加精准。
代码样例摘要
以“教育、医疗、法律”三行业为例,重写器和RAG主流程如下——
# 初始化行业场景管理器(可拓展更多行业)
scenario_manager = IndustryScenarioManager()
scenario_manager.register_adapter("教育", EducationScenarioAdapter(llm))
scenario_manager.register_adapter("医疗", MedicalScenarioAdapter(llm))
scenario_manager.register_adapter("法律", LegalScenarioAdapter(llm))
# 查询改写流程
edu_query = "小明每天跑步5圈,跑了7天后还剩2圈,问总共要跑多少圈?"
edu_adapted = scenario_manager.adapt_query(edu_query, "教育")
print(f"教育场景改写后:{edu_adapted}")
medical_query = "我最近总是头疼,有时候还会恶心,这是什么病?"
medical_adapted = scenario_manager.adapt_query(medical_query, "医疗")
print(f"医疗场景改写后:{medical_adapted}")
legal_query = "我在网上买东西,商家不发货也不退钱,我该怎么办?"
legal_adapted = scenario_manager.adapt_query(legal_query, "法律")
print(f"法律场景改写后:{legal_adapted}")技术优势与建议
- 提升行业检索召回率与准确率:有针对性地处理行业术语、隐含要素、表达变体。
- 极大降低后续生成压力:让生成任务聚焦在精要的检索片段上,减少业务歧义。
- 方便扩展和行业微调:只需增加适用于新行业的适配器及prompt模板,即可复用整套RAG流程。
总结:行业场景适配是RAG系统“以术语为桥梁、以意图为导向”走向高质量产业落地的关键策略之一。建议在各类垂直知识库应用中作为标准前置优化组件使用,并结合行业专家和数据持续迭代优化适配prompt与重写逻辑。
8.1 IndustryScenarioAdaptation.py #
# 导入List、Dict、Any、Optional类型用于类型注解
from typing import List, Dict, Any, Optional
# 导入pydantic用于配置数据模型
from pydantic import ConfigDict
# 导入langchain核心模块中的基础检索器基类
from langchain_core.retrievers import BaseRetriever
# 导入langchain核心模块中的文档对象类型
from langchain_core.documents import Document
# 导入langchain核心模块中的基础LLM类型
from langchain_core.language_models import BaseLanguageModel
# 导入langchain核心模块中的提示词模板类
from langchain_core.prompts import PromptTemplate
# 导入自定义的llm对象
from llm import llm
# 导入自定义的获取向量存储的函数
from vector_store import get_vector_store
# 定义行业场景适配器基类:用于实现不同行业的query改写策略
class IndustryScenarioAdapter:
"""行业场景适配器基类:定义不同行业的query改写策略"""
# 初始化适配器,保存llm对象、行业名和改写模板
def __init__(self, llm: BaseLanguageModel, industry_name: str, rewriting_template: str):
self.llm = llm
self.industry_name = industry_name
# 使用PromptTemplate来封装改写模板
self.rewriting_template = PromptTemplate(
input_variables=["query"], template=rewriting_template
)
# 行业相关的query改写方法
def adapt_query(self, query: str) -> str:
"""根据行业特点改写query"""
# 使用模板格式化query
prompt = self.rewriting_template.format(query=query)
# 调用llm生成内容并去除首尾空格
adapted = self.llm.invoke(prompt).content.strip()
# 若llm输出不为空则返回,否则返回原query
return adapted if adapted else query
# 定义教育场景的适配器:用于数学题目表述抽取
class EducationScenarioAdapter(IndustryScenarioAdapter):
"""教育场景适配器:抽取题干,剥离表层信息,保留核心数学关系"""
# 构造方法,传入llm对象并设置教育场景的改写模板
def __init__(self, llm: BaseLanguageModel):
template = (
"请对以下题目进行题干抽取,剥离人物、情境、故事等表层信息,"
"只保留关键的数学关系、物理关系或求解目标。\n\n"
"原题目:{query}\n\n"
"抽取后的题干(只保留核心关系和求解目标):"
)
super().__init__(llm, "教育", template)
# 定义医疗场景适配器:用于口语化表述的专业化转换
class MedicalScenarioAdapter(IndustryScenarioAdapter):
"""医疗场景适配器:将口语化描述转换为专业术语"""
# 构造方法,传入llm对象并设置医疗场景的改写模板
def __init__(self, llm: BaseLanguageModel):
template = (
"请将以下医疗相关的口语化描述转换为专业术语和标准表达,"
"便于在医疗知识库中检索。\n\n"
"用户描述:{query}\n\n"
"专业术语查询:"
)
super().__init__(llm, "医疗", template)
# 定义法律场景适配器:用于法律问题关键词提取
class LegalScenarioAdapter(IndustryScenarioAdapter):
"""法律场景适配器:提取法条关键词和案例要素"""
# 构造方法,传入llm对象并设置法律场景的改写模板
def __init__(self, llm: BaseLanguageModel):
template = (
"请从以下法律问题中提取关键要素:法条类型、法律关系、争议焦点等,"
"生成适合法律知识库检索的query。\n\n"
"法律问题:{query}\n\n"
"检索query(包含法条类型和关键要素):"
)
super().__init__(llm, "法律", template)
# 定义行业场景管理器,用于管理所有行业的适配器
class IndustryScenarioManager:
"""行业场景管理器:管理多个行业的适配器"""
# 初始化,创建空的适配器字典
def __init__(self):
self.adapters: Dict[str, IndustryScenarioAdapter] = {}
# 注册指定行业及其对应的适配器
def register_adapter(self, industry: str, adapter: IndustryScenarioAdapter):
"""注册行业适配器"""
self.adapters[industry] = adapter
# 获取某行业的适配器,没有则返回None
def get_adapter(self, industry: str) -> Optional[IndustryScenarioAdapter]:
"""获取指定行业的适配器"""
return self.adapters.get(industry)
# 对query进行行业相关改写
def adapt_query(self, query: str, industry: str) -> str:
"""根据行业改写query"""
adapter = self.get_adapter(industry)
# 存在适配器则用适配器的方法,没有则原样返回query
if adapter:
return adapter.adapt_query(query)
return query
# 定义行业场景检索器:用于行业相关的向量搜索检索
class IndustryScenarioRetriever(BaseRetriever):
"""行业场景检索器:根据行业特点进行检索"""
# 声明成员变量类型:向量库对象
vector_store: Any
# 行业场景管理器
scenario_manager: IndustryScenarioManager
# 当前所属行业
industry: str
# 默认的检索文档数量
k: int = 4
# 允许任意类型参数的pydantic配置
model_config = ConfigDict(arbitrary_types_allowed=True)
# 获取与query相关的文档(行业适配方案下)
def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
"""根据行业特点检索相关文档"""
# 获取检索数量参数,优先用传入值
k = kwargs.get("k", self.k)
# 获取行业参数,优先用传入值
industry = kwargs.get("industry", self.industry)
# 先对query进行行业相关改写
adapted_query = self.scenario_manager.adapt_query(query, industry)
# 用改写后的query在向量库中检索
docs = self.vector_store.similarity_search_with_score(adapted_query, k=k)
# 遍历检索结果,每个结果添加附加元数据
result_docs = []
for doc, distance in docs:
result_docs.append(Document(
page_content=doc.page_content,
metadata={
**doc.metadata,
"score": float(distance),
"industry": industry,
"original_query": query,
"adapted_query": adapted_query
}
))
# 返回封装后的文档对象列表
return result_docs
# 定义行业场景RAG系统:集成query改写和行业检索
class IndustryScenarioRAG:
"""行业场景RAG系统:支持不同行业的query改写"""
# 初始化RAG系统,配置检索器、llm和答案输出模板
def __init__(self, retriever: IndustryScenarioRetriever, llm: BaseLanguageModel, answer_prompt_template: Optional[str] = None):
self.retriever = retriever
self.llm = llm
# 指定答案生成模板,如未指定则使用默认模板
template = answer_prompt_template or (
"已知以下相关信息:\n\n{context}\n\n"
"用户问题:{query}\n"
"行业场景:{industry}\n"
"改写后的查询:{adapted_query}\n\n"
"请根据上述信息,准确全面地回答用户问题。\n\n答案:"
)
# 初始化PromptTemplate对象
self.answer_prompt_template = PromptTemplate(
input_variables=["context", "query", "industry", "adapted_query"], template=template
)
# 生成答案的主入口
def generate_answer(self, query: str, industry: str, k: int = 4) -> Dict[str, Any]:
"""生成答案"""
# 检索与query与行业相关的文档
docs = self.retriever._get_relevant_documents(query, industry=industry, k=k)
# 提取第一个检索文档的adapted_query,否则使用原query
adapted_query = docs[0].metadata.get("adapted_query", query) if docs else query
# 拼接所有检索文档内容作为上下文
context = "\n\n".join([f"文档 {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
# 用答案模板组织final prompt
prompt = self.answer_prompt_template.format(
context=context,
query=query,
industry=industry,
adapted_query=adapted_query
)
# 利用llm生成答案并去掉首尾空白
answer = self.llm.invoke(prompt).content.strip()
# 返回相关信息组成的结果字典
return {
"query": query,
"industry": industry,
"adapted_query": adapted_query,
"answer": answer,
"retrieved_documents": docs,
"num_documents": len(docs)
}
# =============================
# 以下为初始化与使用示例
# =============================
# 创建行业场景管理器
scenario_manager = IndustryScenarioManager()
# 注册教育、医疗、法律三类行业的适配器
scenario_manager.register_adapter("教育", EducationScenarioAdapter(llm))
scenario_manager.register_adapter("医疗", MedicalScenarioAdapter(llm))
scenario_manager.register_adapter("法律", LegalScenarioAdapter(llm))
# 创建向量存储对象,指定存储目录和集合名称
vector_store = get_vector_store(
persist_directory="chroma_db",
collection_name="industry_scenario"
)
# 创建行业场景检索器,指定初始行业为教育,检索返回3条
retriever = IndustryScenarioRetriever(
vector_store=vector_store,
scenario_manager=scenario_manager,
industry="教育",
k=3
)
# 创建行业场景RAG系统
rag_system = IndustryScenarioRAG(retriever=retriever, llm=llm)
# 构造教育场景用的示例题目文档列表
education_documents = [
"题目:工人挖桥洞问题。一个工人每天挖3米,挖了5天后还剩2米,问桥洞总长多少米?\n"
"核心关系:已知每天工作量、工作天数、剩余量,求总量。\n"
"解题方法:总量 = 每天工作量 × 工作天数 + 剩余量 = 3 × 5 + 2 = 17米",
"题目:李白买酒问题。李白去买酒,每次买3升,买了4次后还剩1升,问总共需要多少升?\n"
"核心关系:已知每次购买量、购买次数、剩余量,求总量。\n"
"解题方法:总量 = 每次购买量 × 购买次数 + 剩余量 = 3 × 4 + 1 = 13升",
"题目:小明存钱问题。小明每天存5元,存了6天后还剩10元,问总共需要存多少钱?\n"
"核心关系:已知每天存钱数、存钱天数、剩余量,求总量。\n"
"解题方法:总量 = 每天存钱数 × 存钱天数 + 剩余量 = 5 × 6 + 10 = 40元",
"题目:速度时间路程问题。一辆车以60km/h的速度行驶,行驶了3小时后还剩120km,问总路程多少?\n"
"核心关系:已知速度、时间、剩余路程,求总路程。\n"
"解题方法:总路程 = 速度 × 时间 + 剩余路程 = 60 × 3 + 120 = 300km"
]
# 索引教育场景的示例文档到向量库
print("正在初始化知识库...")
vector_store.add_texts(
education_documents,
metadatas=[{"topic": "数学题目", "industry": "教育"}] * len(education_documents)
)
print("知识库初始化完成!\n")
# ========== 示例1:教育场景 ==========
# 打印分隔线
print("="*60)
print("示例1:教育场景 - 题目检索(题干抽取)")
print("="*60)
# 构造测试题目1,并调用RAG系统生成答案
query1 = "一个建筑工人每天能挖4米深的洞,挖了6天后还差3米就完成了,这个洞总共有多深?"
result1 = rag_system.generate_answer(query1, industry="教育", k=3)
print(f"原题目:{result1['query']}")
print(f"抽取后的题干:{result1['adapted_query']}")
print(f"生成的答案:\n{result1['answer']}")
print(f"检索到 {result1['num_documents']} 个相关文档\n")
# 构造测试题目2,并调用RAG系统生成答案
query2 = "诗人去商店买酒,每次买2瓶,买了5次后发现还需要3瓶,问总共要买多少瓶?"
result2 = rag_system.generate_answer(query2, industry="教育", k=3)
print(f"原题目:{result2['query']}")
print(f"抽取后的题干:{result2['adapted_query']}")
print(f"生成的答案:\n{result2['answer']}")
print(f"检索到 {result2['num_documents']} 个相关文档\n")
# ========== 示例2:展示不同行业的适配能力 ==========
# 打印分隔线
print("\n" + "="*60)
print("示例2:不同行业的query改写策略")
print("="*60)
# 教育场景下的自定义query测试
edu_query = "小明每天跑步5圈,跑了7天后还剩2圈,问总共要跑多少圈?"
edu_adapted = scenario_manager.adapt_query(edu_query, "教育")
print(f"教育场景:")
print(f" 原query:{edu_query}")
print(f" 改写后:{edu_adapted}\n")
# 医疗场景下的自定义query测试
medical_query = "我最近总是头疼,有时候还会恶心,这是什么病?"
medical_adapted = scenario_manager.adapt_query(medical_query, "医疗")
print(f"医疗场景:")
print(f" 原query:{medical_query}")
print(f" 改写后:{medical_adapted}\n")
# 法律场景下的自定义query测试
legal_query = "我在网上买东西,商家不发货也不退钱,我该怎么办?"
legal_adapted = scenario_manager.adapt_query(legal_query, "法律")
print(f"法律场景:")
print(f" 原query:{legal_query}")
print(f" 改写后:{legal_adapted}\n")8.2 执行流程 #
8.2.1 核心思想 #
行业场景适配采用“行业特定改写”策略:
- 适配器模式:为不同行业定义专门的查询改写策略
- 行业管理器:统一管理多个行业的适配器
- 查询改写:根据行业特点将用户查询改写为适合检索的形式
- 行业检索:使用改写后的查询在知识库中检索
- 答案生成:基于检索结果和行业信息生成答案
8.2.2 执行流程 #
阶段一:初始化
# 1. 创建行业场景管理器
scenario_manager = IndustryScenarioManager()
# 2. 注册各行业的适配器
scenario_manager.register_adapter("教育", EducationScenarioAdapter(llm))
scenario_manager.register_adapter("医疗", MedicalScenarioAdapter(llm))
scenario_manager.register_adapter("法律", LegalScenarioAdapter(llm))
# 3. 创建向量存储实例
vector_store = get_vector_store(
persist_directory="chroma_db",
collection_name="industry_scenario"
)
# 4. 创建行业场景检索器
retriever = IndustryScenarioRetriever(
vector_store=vector_store,
scenario_manager=scenario_manager,
industry="教育",
k=3
)
# 5. 创建行业场景RAG系统
rag_system = IndustryScenarioRAG(retriever=retriever, llm=llm)初始化时:
- 创建行业场景管理器
- 为教育、医疗、法律三个行业注册适配器
- 创建向量存储实例
- 创建检索器,配置场景管理器和默认行业
- 创建RAG系统
阶段二:知识库初始化
# 添加教育场景的示例文档到向量库
vector_store.add_texts(
education_documents,
metadatas=[{"topic": "数学题目", "industry": "教育"}] * len(education_documents)
)索引过程:
- 将教育场景的示例文档添加到向量库
- 每个文档附带元数据(topic、industry)
- 文档会被向量化并存储
阶段三:查询处理
query = "一个建筑工人每天能挖4米深的洞,挖了6天后还差3米就完成了,这个洞总共有多深?"
result = rag_system.generate_answer(query, industry="教育", k=3)完整流程:
- 用户提交查询并指定行业
- 行业适配改写:
- 调用
scenario_manager.adapt_query(query, industry) - 获取对应行业的适配器
- 使用适配器的改写模板改写查询
- 调用
- 检索:
- 使用改写后的查询在向量库中检索
- 返回 top-k 个相关文档(带分数)
- 生成答案:
- 整合检索文档作为上下文
- 使用提示词模板构建最终 prompt(包含上下文、原始查询、行业、改写查询)
- LLM 生成答案
- 返回结果:
- 包含查询、行业、改写查询、答案、检索文档等
8.2.3 类图 #
8.2.4 时序图 #
8.2.4.1 完整RAG流程时序图 #
8.2.4.2 行业适配改写详细流程 #
8.2.4.3 不同行业适配器示例流程 #
8.2.5 关键设计要点 #
1. 行业适配流程
用户查询 + 行业标识
↓
获取行业适配器
↓
使用行业特定模板改写查询
↓
使用改写后的查询检索
↓
生成答案 (包含行业信息)2. 适配器模式设计
- 基类:
IndustryScenarioAdapter- 定义通用的
adapt_query()方法 - 使用模板和 LLM 进行查询改写
- 定义通用的
- 子类:具体行业适配器
EducationScenarioAdapter:教育场景MedicalScenarioAdapter:医疗场景LegalScenarioAdapter:法律场景
3. 行业改写策略示例
教育场景:
原始查询: "一个建筑工人每天能挖4米深的洞,挖了6天后还差3米就完成了,这个洞总共有多深?"
改写后: "已知每天工作量、工作天数、剩余量,求总量"医疗场景:
原始查询: "我最近总是头疼,有时候还会恶心,这是什么病?"
改写后: "头痛、恶心症状,可能疾病诊断"法律场景:
原始查询: "我在网上买东西,商家不发货也不退钱,我该怎么办?"
改写后: "网络购物合同纠纷,商家违约,消费者权益保护"4. 行业场景管理器
- 适配器注册:使用字典存储行业到适配器的映射
- 适配器获取:根据行业名称获取对应的适配器
- 统一接口:提供
adapt_query()方法统一处理
管理器结构:
adapters = {
"教育": EducationScenarioAdapter,
"医疗": MedicalScenarioAdapter,
"法律": LegalScenarioAdapter
}5. 元数据设计
检索返回的 Document 对象包含:
score:相似度分数(距离,越小越相似)industry:行业标识original_query:原始用户查询adapted_query:改写后的查询- 继承原始文档的元数据(topic、industry 等)
6. 答案生成模板
默认答案生成模板:
已知以下相关信息:
{context}
用户问题:{query}
行业场景:{industry}
改写后的查询:{adapted_query}
请根据上述信息,准确全面地回答用户问题。
答案:该模板:
- 包含所有检索文档作为上下文
- 明确列出原始查询、行业和改写查询
- 要求准确全面地回答用户问题
7. 优势与应用场景
优势:
- 行业定制:不同行业使用专门的改写策略
- 可扩展性:易于添加新的行业适配器
- 提高精度:行业特定的改写提高检索匹配率
- 灵活配置:支持动态切换行业
适用场景:
- 教育领域:数学题目、物理题目等需要抽取核心关系
- 医疗领域:口语化症状描述转换为专业术语
- 法律领域:法律问题提取关键要素和法条类型
- 多行业系统:需要支持多个行业的统一系统
8. 技术细节
- 适配器注册:
- 使用字典存储,键为行业名称,值为适配器实例
- 支持动态注册新的适配器
- 查询改写:
- 使用行业特定的提示词模板
- 如果改写结果为空,降级为原始查询
- 行业切换:
- 支持在查询时动态指定行业
- 检索器可以设置默认行业
9. 扩展性设计
添加新行业的步骤:
- 创建新的适配器类,继承
IndustryScenarioAdapter - 定义行业特定的改写模板
- 在管理器中注册新适配器
示例:
class FinanceScenarioAdapter(IndustryScenarioAdapter):
def __init__(self, llm: BaseLanguageModel):
template = "请将以下金融问题转换为专业术语..."
super().__init__(llm, "金融", template)
# 注册新适配器
scenario_manager.register_adapter("金融", FinanceScenarioAdapter(llm))该设计通过“适配器模式+行业管理器”策略,在多行业场景下提供更精确的查询改写和检索,适用于需要支持多个行业的 RAG 应用。
9. Text2SQL #
本节介绍如何实现文本到SQL(Text2SQL)任务,即将自然语言的问题转化为对应的SQL查询语句。 Text2SQL技术在数据库问答、智能客服、数据分析等场景中有着广泛的应用。其基本流程通常包括:
- 理解用户意图:分析自然语言中表述的查询需求,例如检索条件、目标表、要查询的字段等。
- 槽位抽取:从用户问题中结构化地提取出相关的实体、属性和条件(如部门、日期范围、数值过滤等),为SQL拼接做准备。
- SQL模板生成:根据提取的信息结合数据库表结构,动态生成对应的标准SQL语句。
- 执行与结果返回:将生成的SQL语句提交到数据库执行,并将查询结果以易懂的方式反馈给用户。
9.1. Text2SQL.py #
# 导入类型注解相关的类型List、Dict、Any、Optional、Tuple
from typing import List, Dict, Any, Optional, Tuple
# 导入pydantic中用于配置数据模型的ConfigDict
from pydantic import ConfigDict
# 导入langchain核心模块中的基础大模型类型
from langchain_core.language_models import BaseLanguageModel
# 导入langchain核心模块中的提示模板类
from langchain_core.prompts import PromptTemplate
# 导入自定义的llm对象
from llm import llm
# 导入json模块和re模块用于数据解析
import json
import re
# 导入sqlite3用于数据库操作
import sqlite3
# 定义槽位提取器类,用于从自然语言中提取结构化条件
class SlotExtractor:
"""槽位提取器:从自然语言中提取结构化条件"""
# 初始化方法,llm为大模型对象,slot_schema为槽位定义
def __init__(self, llm: BaseLanguageModel, slot_schema: Dict[str, Any]):
# 保存传入的大模型对象
self.llm = llm
# 保存槽位定义
self.slot_schema = slot_schema
# 构建可用槽位描述文本
schema_text = "\n".join([f"- {slot}: {desc}" for slot, desc in slot_schema.items()])
# 构建llm提示模板
template = (
"请从以下用户查询中提取结构化条件,以JSON格式返回。\n\n"
"可用槽位:\n{schema}\n\n"
"用户查询:{query}\n\n"
"请提取所有相关槽位,返回JSON格式,例如:"
'{{"position": "Java开发", "experience_years": 3, "max_salary": 20000, "work_mode": "远程"}}\n'
"如果某个槽位未提及,则不要包含在JSON中。只返回JSON,不要其他内容:"
)
# 实例化PromptTemplate
self.extraction_template = PromptTemplate(
input_variables=["query", "schema"], template=template
)
# 从用户输入中提取槽位
def extract_slots(self, query: str) -> Dict[str, Any]:
"""从查询中提取槽位"""
# 构造槽位schema的文本
schema_text = "\n".join([f"- {slot}: {desc}" for slot, desc in self.slot_schema.items()])
# 填充prompt模板
prompt = self.extraction_template.format(query=query, schema=schema_text)
# 使用大模型得到输出内容
response = self.llm.invoke(prompt).content.strip()
try:
# 用正则提取JSON片段
json_match = re.search(r'\{[^}]+\}', response, re.DOTALL)
if json_match:
# 将JSON解析为字典
slots = json.loads(json_match.group())
return slots
except:
# 捕获异常,解析失败则跳过
pass
# 未能解析则返回空字典
return {}
# 定义对话状态管理器类
class DialogueStateManager:
"""对话状态管理器:管理多轮对话中的槽位状态"""
# 初始化,保存当前槽位和对话历史
def __init__(self):
# 初始化当前槽位为空
self.slots: Dict[str, Any] = {}
# 初始化对话历史为空
self.history: List[Dict[str, str]] = []
# 更新槽位状态,新加槽位覆盖旧槽位
def update_slots(self, new_slots: Dict[str, Any], query: str) -> Dict[str, Any]:
"""更新槽位状态,新槽位会覆盖旧槽位"""
# 记录本轮执行前的历史
self.history.append({"query": query, "slots_before": self.slots.copy()})
# 更新槽位字典,新值覆盖老值
self.slots.update(new_slots)
# 记录更新后的槽位状态
self.history[-1]["slots_after"] = self.slots.copy()
# 返回最新的槽位状态副本
return self.slots.copy()
# 获取当前槽位状态
def get_slots(self) -> Dict[str, Any]:
"""获取当前槽位状态"""
return self.slots.copy()
# 清空所有槽位和对话历史
def clear(self):
"""清空状态"""
self.slots = {}
self.history = []
# 定义SQL生成器类,根据槽位生成SQL语句
class SQLGenerator:
"""SQL生成器:根据槽位生成SQL语句(使用参数化查询)"""
# 初始化,指定要查询的表名
def __init__(self, table_name: str = "candidates"):
# 保存表名
self.table_name = table_name
# 根据输入的槽位生成SQL与参数
def generate_sql(self, slots: Dict[str, Any]) -> Tuple[str, List[Any]]:
"""根据槽位生成SQL查询语句和参数,返回(SQL语句, 参数列表)"""
# 条件列表
conditions = []
# 参数列表
params = []
# 若槽位有岗位,根据岗位添加模糊查询
if "position" in slots:
pos = slots['position']
conditions.append("position LIKE ?")
params.append(f"%{pos}%")
# 若槽位有工作经验,添加经验过滤
if "experience_years" in slots:
exp = slots['experience_years']
conditions.append("experience_years >= ?")
params.append(exp)
# 若槽位有最低期望薪资,添加对应条件
if "min_salary" in slots:
min_sal = slots['min_salary']
conditions.append("expected_salary >= ?")
params.append(min_sal)
# 若槽位有最高期望薪资,添加对应条件
if "max_salary" in slots:
max_sal = slots['max_salary']
conditions.append("expected_salary <= ?")
params.append(max_sal)
# 若槽位有工作模式,添加对应条件
if "work_mode" in slots:
mode = slots['work_mode']
conditions.append("work_mode = ?")
params.append(mode)
# 拼接WHERE子句,没有即为1=1(查全部)
where_clause = " AND ".join(conditions) if conditions else "1=1"
# 拼装最终SQL语句
sql = f"SELECT * FROM {self.table_name} WHERE {where_clause}"
# 返回SQL和参数
return sql, params
# 定义SQLite数据库访问类
class SQLiteDatabase:
"""SQLite数据库:使用真实数据库进行查询"""
# 初始化,指定数据库文件及表名
def __init__(self, db_path: str = "candidates.db", table_name: str = "candidates"):
# 保存数据库路径
self.db_path = db_path
# 保存表名
self.table_name = table_name
# 初始化数据库及表结构
self._init_database()
# 初始化数据库和表,插入样例数据
def _init_database(self):
"""初始化数据库,创建表并插入示例数据"""
# 连接数据库
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# 创建候选人表
cursor.execute(f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
position TEXT NOT NULL,
experience_years INTEGER NOT NULL,
expected_salary INTEGER NOT NULL,
work_mode TEXT NOT NULL
)
""")
# 检查表数据数量
cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")
count = cursor.fetchone()[0]
# 若表为空则插入样例数据
if count == 0:
sample_data = [
("张三", "Java开发", 5, 18000, "远程"),
("李四", "Java开发", 3, 15000, "远程"),
("王五", "Java开发", 4, 22000, "现场"),
("赵六", "Python开发", 3, 19000, "远程"),
("钱七", "Java开发", 2, 12000, "远程"),
]
cursor.executemany(
f"INSERT INTO {self.table_name} (name, position, experience_years, expected_salary, work_mode) VALUES (?, ?, ?, ?, ?)",
sample_data
)
# 提交并关闭连接
conn.commit()
conn.close()
# 执行SQL查询,返回字典列表,支持参数化查询
def execute_sql(self, sql: str, params: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
"""执行SQL查询,返回结果列表(支持参数化查询)"""
# 连接数据库
conn = sqlite3.connect(self.db_path)
# 设置行工厂为Row对象,实现结果为字典输出
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
try:
# 执行SQL,带参数则传参数
if params:
cursor.execute(sql, params)
else:
cursor.execute(sql)
# 获取所有行数据
rows = cursor.fetchall()
# 每一行转为字典,组合为列表
results = [dict(row) for row in rows]
return results
except Exception as e:
# 打印错误信息,如有错误返回空列表
print(f"SQL执行错误: {e}")
return []
finally:
# 关闭数据库连接
conn.close()
# 定义Text2SQL检索增强生成(RAG)系统
class Text2SQLRAG:
"""Text2SQL RAG系统:将自然语言转换为SQL查询"""
# 初始化方法,配置llm对象、槽位定义、表名、数据库位置
def __init__(self, llm: BaseLanguageModel, slot_schema: Dict[str, Any], table_name: str = "candidates", db_path: str = "candidates.db"):
# 保存llm对象
self.llm = llm
# 创建槽位提取器
self.slot_extractor = SlotExtractor(llm, slot_schema)
# 创建对话状态管理器
self.state_manager = DialogueStateManager()
# 创建SQL生成器
self.sql_generator = SQLGenerator(table_name)
# 创建SQLite数据库对象
self.database = SQLiteDatabase(db_path, table_name)
# 定义大模型答案生成模板
self.answer_template = PromptTemplate(
input_variables=["query", "sql", "results"],
template=(
"用户查询:{query}\n\n"
"生成的SQL:{sql}\n\n"
"查询结果:{results}\n\n"
"请根据查询结果,用自然语言回答用户的问题。\n\n答案:"
)
)
# 处理一次自然语言用户查询,返回结构化结果
def query(self, user_query: str) -> Dict[str, Any]:
"""处理用户查询"""
# 槽位提取
new_slots = self.slot_extractor.extract_slots(user_query)
# 用新槽位更新对话状态
updated_slots = self.state_manager.update_slots(new_slots, user_query)
# 根据当前槽位生成SQL语句和参数
sql, params = self.sql_generator.generate_sql(updated_slots)
# 执行SQL查询,获取结果
results = self.database.execute_sql(sql, params)
# 将结果格式化为简明文本
results_text = "\n".join([f"- {r['name']}: {r['position']}, {r['experience_years']}年经验, 期望薪资{r['expected_salary']}元, {r['work_mode']}" for r in results])
if not results_text:
results_text = "未找到符合条件的候选人"
# 使用大模型生成问答输出内容
prompt = self.answer_template.format(
query=user_query,
sql=sql,
results=results_text
)
answer = self.llm.invoke(prompt).content.strip()
# 返回完整结构化结果
return {
"query": user_query,
"extracted_slots": new_slots,
"current_slots": updated_slots,
"sql": sql,
"results": results,
"answer": answer
}
# 清空当前对话状态
def clear_state(self):
"""清空对话状态"""
self.state_manager.clear()
# 定义槽位模式,用于文本提取与SQL拼接
slot_schema = {
"position": "岗位名称,如Java开发、Python开发等",
"experience_years": "工作经验年限(数字)",
"min_salary": "最低期望薪资(数字,单位:元)",
"max_salary": "最高期望薪资(数字,单位:元)",
"work_mode": "工作方式,如远程、现场、混合等"
}
# 实例化Text2SQL RAG系统对象
rag_system = Text2SQLRAG(llm, slot_schema, table_name="candidates")
# ========== 示例:多轮对话场景 ==========
# 打印分隔线和案例标题
print("="*60)
print("示例:多轮对话 - 逐步补充条件")
print("="*60)
# 第一轮对话:只给经验和岗位
print("\n【第一轮】")
result1 = rag_system.query("请帮我查找有三年以上Java开发经验的候选人")
print(f"用户查询:{result1['query']}")
print(f"提取的槽位:{result1['extracted_slots']}")
print(f"生成的SQL:{result1['sql']}")
result1_count = len(result1['results'])
print(f"查询结果数量:{result1_count}")
print(f"AI回答:{result1['answer']}\n")
# 第二轮对话:补充薪资条件
print("\n【第二轮】")
result2 = rag_system.query("期望薪资不超过2万元")
print(f"用户查询:{result2['query']}")
print(f"提取的槽位:{result2['extracted_slots']}")
print(f"当前所有槽位:{result2['current_slots']}")
print(f"生成的SQL:{result2['sql']}")
print(f"查询结果数量:{len(result2['results'])}")
print(f"AI回答:{result2['answer']}\n")
# 第三轮对话:补充工作方式
print("\n【第三轮】")
result3 = rag_system.query("必须能远程办公")
print(f"用户查询:{result3['query']}")
print(f"提取的槽位:{result3['extracted_slots']}")
print(f"当前所有槽位:{result3['current_slots']}")
print(f"生成的SQL:{result3['sql']}")
result3_count = len(result3['results'])
print(f"查询结果数量:{result3_count}")
print(f"AI回答:{result3['answer']}\n")
# ========== 示例:多条件组合查询 ==========
print("\n" + "="*60)
print("示例:多条件组合查询")
print("="*60)
# 清空会话,开始新一轮多条件对话
rag_system.clear_state()
# 第一轮:只设薪资上限
print("\n【第一轮】")
result1 = rag_system.query("期望薪资不超过2万")
print(f"用户查询:{result1['query']}")
print(f"提取的槽位:{result1['extracted_slots']}")
print(f"生成的SQL:{result1['sql']}")
print(f"查询结果数量:{len(result1['results'])}")
print(f"AI回答:{result1['answer']}\n")
# 第二轮:补充薪资下限
print("\n【第二轮】")
result2 = rag_system.query("有没有薪资高于1.5万的?")
print(f"用户查询:{result2['query']}")
print(f"提取的槽位:{result2['extracted_slots']}")
print(f"当前所有槽位:{result2['current_slots']}")
print(f"生成的SQL:{result2['sql']}")
print(f"查询结果数量:{len(result2['results'])}")
print(f"AI回答:{result2['answer']}\n")9.2 执行过程 #
9.2.1 核心思想 #
Text2SQL 采用“槽位提取+SQL生成”策略:
- 槽位提取:从自然语言中提取结构化条件(槽位)
- 对话状态管理:管理多轮对话中的槽位状态,支持逐步补充条件
- SQL生成:根据槽位生成SQL查询语句(使用参数化查询防止SQL注入)
- 数据库查询:使用SQLite数据库执行查询
- 答案生成:基于查询结果生成自然语言答案
9.2.2 执行流程 #
阶段一:初始化
# 1. 定义槽位模式
slot_schema = {
"position": "岗位名称,如Java开发、Python开发等",
"experience_years": "工作经验年限(数字)",
"min_salary": "最低期望薪资(数字,单位:元)",
"max_salary": "最高期望薪资(数字,单位:元)",
"work_mode": "工作方式,如远程、现场、混合等"
}
# 2. 创建Text2SQL RAG系统
rag_system = Text2SQLRAG(llm, slot_schema, table_name="candidates")初始化时:
- 定义槽位模式,描述每个槽位的含义
- 创建Text2SQL RAG系统,内部会创建:
- 槽位提取器(SlotExtractor)
- 对话状态管理器(DialogueStateManager)
- SQL生成器(SQLGenerator)
- SQLite数据库对象(SQLiteDatabase)
阶段二:数据库初始化
# SQLiteDatabase初始化时自动执行
database = SQLiteDatabase(db_path="candidates.db", table_name="candidates")数据库初始化:
- 自动创建数据库文件和表结构
- 如果表为空,自动插入示例数据(5个候选人记录)
阶段三:多轮对话处理
# 第一轮对话
result1 = rag_system.query("请帮我查找有三年以上Java开发经验的候选人")
# 第二轮对话(补充条件)
result2 = rag_system.query("期望薪资不超过2万元")
# 第三轮对话(继续补充条件)
result3 = rag_system.query("必须能远程办公")完整流程(每轮对话):
- 用户提交自然语言查询
- 槽位提取:
- 调用
slot_extractor.extract_slots(query) - 使用LLM从查询中提取结构化槽位
- 返回JSON格式的槽位字典
- 调用
- 状态更新:
- 调用
state_manager.update_slots(new_slots, query) - 将新槽位合并到当前槽位状态
- 记录对话历史
- 调用
- SQL生成:
- 调用
sql_generator.generate_sql(updated_slots) - 根据槽位生成SQL语句和参数列表(参数化查询)
- 调用
- 数据库查询:
- 调用
database.execute_sql(sql, params) - 执行SQL查询,返回结果列表
- 调用
- 结果格式化:
- 将查询结果格式化为文本
- 答案生成:
- 使用提示词模板构建最终prompt
- LLM生成自然语言答案
- 返回结果:
- 包含查询、提取的槽位、当前槽位、SQL、结果、答案等
9.2.3 类图 #
9.2.4 时序图 #
9.2.4.1 完整多轮对话流程时序图 #
9.2.4.2 槽位提取详细流程 #
9.2.4.3 SQL生成与数据库查询详细流程 #
9.2.4.4 多轮对话状态管理流程 #
9.2.5 关键设计要点 #
1. Text2SQL流程
自然语言查询
↓
槽位提取 (LLM)
↓
状态更新 (合并槽位)
↓
SQL生成 (参数化查询)
↓
数据库查询 (SQLite)
↓
结果格式化
↓
答案生成 (LLM)2. 槽位提取示例
用户查询: "请帮我查找有三年以上Java开发经验的候选人"
提取的槽位: {
"position": "Java开发",
"experience_years": 3
}多轮对话槽位累积:
第一轮: {"position": "Java开发", "experience_years": 3}
第二轮: + {"max_salary": 20000}
第三轮: + {"work_mode": "远程"}
最终槽位: {
"position": "Java开发",
"experience_years": 3,
"max_salary": 20000,
"work_mode": "远程"
}3. SQL生成策略
- 参数化查询:使用
?占位符,防止SQL注入 - 条件构建:根据槽位动态构建WHERE条件
- 参数列表:与SQL语句分离,安全传递参数
SQL生成示例:
-- 槽位: {"position": "Java开发", "experience_years": 3, "max_salary": 20000}
-- 生成的SQL:
SELECT * FROM candidates
WHERE position LIKE ?
AND experience_years >= ?
AND expected_salary <= ?
-- 参数列表:
["%Java开发%", 3, 20000]4. 对话状态管理
- 槽位累积:新槽位会合并到现有槽位中
- 历史记录:记录每轮对话前后的槽位状态
- 状态持久化:支持多轮对话的状态保持
- 可清空:支持清空状态重新开始
历史记录结构:
history = [
{
"query": "请帮我查找有三年以上Java开发经验的候选人",
"slots_before": {},
"slots_after": {"position": "Java开发", "experience_years": 3}
},
{
"query": "期望薪资不超过2万元",
"slots_before": {"position": "Java开发", "experience_years": 3},
"slots_after": {"position": "Java开发", "experience_years": 3, "max_salary": 20000}
}
]5. 数据库设计
表结构:
CREATE TABLE candidates (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
position TEXT NOT NULL,
experience_years INTEGER NOT NULL,
expected_salary INTEGER NOT NULL,
work_mode TEXT NOT NULL
)示例数据:
- 张三:Java开发,5年经验,18000元,远程
- 李四:Java开发,3年经验,15000元,远程
- 王五:Java开发,4年经验,22000元,现场
- 赵六:Python开发,3年经验,19000元,远程
- 钱七:Java开发,2年经验,12000元,远程
6. 答案生成模板
默认答案生成模板:
用户查询:{query}
生成的SQL:{sql}
查询结果:{results}
请根据查询结果,用自然语言回答用户的问题。
答案:该模板:
- 包含原始查询
- 显示生成的SQL(便于调试)
- 包含查询结果
- 要求用自然语言回答
7. 优势与应用场景
优势:
- 自然语言交互:用户可以用自然语言查询数据库
- 多轮对话:支持逐步补充条件,无需一次性提供所有信息
- 安全性:使用参数化查询,防止SQL注入
- 可扩展性:易于添加新的槽位类型
适用场景:
- 数据库查询系统:将自然语言转换为SQL查询
- 智能客服:支持多轮对话的查询系统
- 数据检索:非技术人员查询结构化数据
- 业务系统:HR系统、招聘系统等
8. 技术细节
- 槽位提取:
- 使用LLM提取结构化信息
- 使用正则表达式提取JSON
- 解析失败时返回空字典
- SQL生成:
- 使用参数化查询(
?占位符) - 动态构建WHERE条件
- 无条件时使用
1=1(返回所有记录)
- 使用参数化查询(
- 数据库操作:
- 使用SQLite的Row工厂,返回字典格式
- 支持参数化查询执行
- 异常处理:捕获SQL执行错误
9. 多轮对话示例
场景:逐步补充条件
第一轮: "请帮我查找有三年以上Java开发经验的候选人"
→ SQL: WHERE position LIKE '%Java开发%' AND experience_years >= 3
→ 结果: 3个候选人
第二轮: "期望薪资不超过2万元"
→ SQL: WHERE position LIKE '%Java开发%' AND experience_years >= 3 AND expected_salary <= 20000
→ 结果: 2个候选人(筛选掉22000元的)
第三轮: "必须能远程办公"
→ SQL: WHERE position LIKE '%Java开发%' AND experience_years >= 3 AND expected_salary <= 20000 AND work_mode = '远程'
→ 结果: 1个候选人(最终匹配)该设计通过“槽位提取+SQL生成+多轮对话”策略,在数据库查询场景下提供自然语言交互能力,适用于需要将自然语言转换为SQL查询的RAG应用。