导航菜单

  • 1.langchain.intro
  • 2.langchain.chat_models
  • 3.langchain.prompts
  • 4.langchain.example_selectors
  • 5.output_parsers
  • 6.Runnable
  • 7.memory
  • 8.document_loaders
  • 9.text_splitters
  • 10.embeddings
  • 11.tool
  • 12.retrievers
  • 13.optimize
  • 14.项目介绍
  • 15.启动HTTP
  • 16.数据与模型
  • 17.权限管理
  • 18.知识库管理
  • 19.设置
  • 20.文档管理
  • 21.聊天
  • 22.API文档
  • 23.RAG优化
  • 24.索引时优化
  • 25.检索前优化
  • 26.检索后优化
  • 27.系统优化
  • 28.GraphRAG
  • 29.图
  • 30.为什么选择图数据库
  • 31.什么是 Neo4j
  • 32.安装和连接 Neo4j
  • 33.Neo4j核心概念
  • 34.Cypher基础
  • 35.模式匹配
  • 36.数据CRUD操作
  • 37.GraphRAG
  • 38.查询和过滤
  • 39.结果处理和聚合
  • 40.语句组合
  • 41.子查询
  • 42.模式和约束
  • 43.日期时间处理
  • 44.Cypher内置函数
  • 45.Python操作Neo4j
  • 46.neo4j
  • 47.py2neo
  • 48.Streamlit
  • 49.Pandas
  • 50.graphRAG
  • 51.deepdoc
  • 52.deepdoc
  • 53.deepdoc
  • 55.deepdoc
  • 54.deepdoc
  • Pillow
  • 1. 伪文档生成法(Query-to-Document)
    • 1.1 QueryToDocument.py
    • 1.2 执行流程
      • 1.2.1 核心思想
      • 1.2.2 执行流程
      • 1.2.3 类图
      • 1.2.4 时序图
        • 1.2.4.1 完整RAG流程时序图
        • 1.2.4.2 伪文档生成详细流程
      • 1.2.5 关键设计要点
      • 1.2.6 与其他方法的对比
      • 1.2.7 优势与应用场景
      • 1.2.8 提示词模板设计
  • 2. 假设文档向量化(Assume Document Vectorization)
    • 2.1 AssumeDocumentVectorization.py
    • 2.2 执行流程
      • 2.2.1 核心思想
      • 2.2.2 执行流程
      • 2.2.3 类图
      • 2.2.4 时序图
        • 2.2.4.1 完整RAG流程时序图
        • 2.2.4.2 多角度假设文档生成详细流程
        • 2.2.4.3 向量平均与检索详细流程
      • 2.2.5 关键设计要点
      • 2.2.6 与其他方法的对比
      • 2.2.7. 优势与应用场景
      • 2.2.8. 技术细节
  • 3. 问题分解策略(Sub-Question Decomposition)
    • 3.1 SubQuestionDecomposition.py
    • 3.2 执行流程
      • 3.2.1 核心思想
      • 3.2.2 执行流程
    • 3.3 类图
    • 3.4 时序图
      • 3.4.1 完整RAG流程时序图
      • 3.4.2 子问题分解详细流程
      • 3.4.3 子问题检索与去重详细流程
    • 3.5 关键设计要点
    • 3.6. 与其他方法的对比
    • 3.7. 优势与应用场景
    • 3.8. 技术细节
  • 4. 多角度查询重写(Query Rewriting)
    • 4.1 QueryRewriting.py
    • 4.2 执行过程
      • 4.2.1 核心思想
      • 4.2.2 执行流程
      • 4.2.3 类图
      • 4.2.4 时序图
        • 4.2.4.1 完整RAG流程时序图
        • 4.2.4.2 查询重写详细流程
        • 4.2.4.3 多版本检索与去重详细流程
      • 4.2.5 关键设计要点
      • 4.2.6. 与子问题分解的对比
      • 4.2.7. 优势与应用场景
      • 4.2.8. 技术细节
      • 4.2.9. 查询重写的要求
  • 5. 抽象化查询转换(Take a Step Back)
    • 5.1 TakeAStepBack.py
    • 5.2 执行过程
      • 5.2.1 核心思想
      • 5.2.2 执行流程
      • 5.2.3 类图
      • 5.2.4 时序图
        • 5.2.4.1 完整RAG流程时序图
        • 5.2.4.2 抽象化转换详细流程
        • 5.2.4.3 双重检索与合并详细流程
      • 5.2.5 关键设计要点
      • 5.2.6 与其他方法的对比
      • 5.2.7 优势与应用场景
      • 5.2.8 技术细节
      • 5.2.9 检索统计
  • 6. 检索-生成一体化(Retrieval-generation integration)
    • 6.1 RetrievalGenerationIntegration.py
    • 6.2 执行过程
      • 6.2.1 核心思想
      • 6.2.2 执行流程
      • 6.2.3 类图
      • 6.2.4 时序图
        • 6.2.4.1 完整RAG流程时序图
        • 6.2.4.2 意图识别详细流程
        • 6.2.4.3 查询重写详细流程
        • 6.2.4.4 知识库路由与检索详细流程
      • 6.2.5 关键设计要点
      • 6.2.6 优势与应用场景
      • 6.2.7 技术细节
  • 7. 上下文对话(Contextual Dialogue)
    • 7.1 ContextualDialogue.py
    • 7.2 执行过程
      • 7.2.1 核心思想
      • 7.2.2 执行流程
      • 7.2.3 类图
      • 7.2.4 时序图
        • 7.2.4.1 完整多轮对话流程时序图
        • 7.2.4.2 智能查询生成详细流程
        • 7.2.4.3 多查询检索与合并详细流程
      • 7.2.5 关键设计要点
      • 7.2.6 优势与应用场景
      • 7.2.7 技术细节
  • 8. 行业场景改写(Industry scenario adaptation)
    • 8.1 IndustryScenarioAdaptation.py
    • 8.2 执行流程
      • 8.2.1 核心思想
      • 8.2.2 执行流程
      • 8.2.3 类图
      • 8.2.4 时序图
        • 8.2.4.1 完整RAG流程时序图
        • 8.2.4.2 行业适配改写详细流程
        • 8.2.4.3 不同行业适配器示例流程
      • 8.2.5 关键设计要点
  • 9. Text2SQL
    • 9.1. Text2SQL.py
    • 9.2 执行过程
      • 9.2.1 核心思想
      • 9.2.2 执行流程
      • 9.2.3 类图
      • 9.2.4 时序图
        • 9.2.4.1 完整多轮对话流程时序图
        • 9.2.4.2 槽位提取详细流程
        • 9.2.4.3 SQL生成与数据库查询详细流程
        • 9.2.4.4 多轮对话状态管理流程
      • 9.2.5 关键设计要点

1. 伪文档生成法(Query-to-Document) #

伪文档生成法(Query-to-Document)是一种在信息检索流程中提升检索结果相关性的重要技术。其主要思路是,用户输入查询后,先由大语言模型(LLM)基于该查询生成一段伪文档。这个伪文档可能包括对原始查询的进一步解释、相关背景知识、潜在答案片段,或者是更丰富的同义词与关键词补充。

伪文档不仅弥补了原始短查询可能缺失的信息,而且通过包含更多上下文,有助于覆盖更多相关文档,提高召回率。随后,将原始查询与伪文档拼接形成增强查询(augmented query),再送入向量化和检索流程。这样,检索系统能更好地理解和匹配用户意图,尤其在知识库规模大、查询本身语意不充分的场景下,效果尤为明显。

简单来说,伪文档生成法的关键步骤包括:

  • 调用LLM,根据用户查询生成伪文档(包含所需知识、扩展表述、补充背景等)。
  • 将伪文档与原始查询拼接形成增强型检索输入。
  • 对增强后的查询进行向量化或关键词检索,提升原始检索效果。

应用场景:

  • 用户输入模糊、信息不足的查询时,通过自动补全信息实现更优检索。
  • 面向多领域或大规模知识库时,增加召回率与相关性。
  • RAG(Retrieval-Augmented Generation)等检索增强生成任务。

小结:伪文档生成法是实现智能检索系统、提升用户查询体验的核心技术之一,在RAG等方案中有着非常广泛的应用和实用价值。

1.1 QueryToDocument.py #

# 引入类型注解相关的库
from typing import List, Dict, Any, Optional
# 引入基础检索器类
from langchain_core.retrievers import BaseRetriever
# 引入文档对象
from langchain_core.documents import Document
# 引入语言模型基类
from langchain_core.language_models import BaseLanguageModel
# 引入提示词模板类
from langchain_core.prompts import PromptTemplate
# 引入自定义llm对象
from llm import llm
# 引入自定义embedding对象
from embeddings import embeddings
# 引入获取向量库的函数
from vector_store import get_vector_store

# 定义QueryToDocumentRetriever检索器,继承自BaseRetriever
class QueryToDocumentRetriever(BaseRetriever):
    # 定义向量库属性
    vector_store: Any
    # 定义LLM属性
    llm: BaseLanguageModel
    # 定义embedding属性
    embeddings: Any
    # k 默认检索文档数量
    k: int = 4
    # 可选的自定义伪文档提示词模板字符串
    pseudo_document_prompt_template_str: Optional[str] = None

    # 获取(或生成)伪文档的prompt模板
    def _get_prompt_template(self) -> PromptTemplate:
        # 如果定义了自定义模板则使用
        if self.pseudo_document_prompt_template_str:
            return PromptTemplate(
                input_variables=["query"],
                template=self.pseudo_document_prompt_template_str
            )
        # 否则使用内置模板
        return PromptTemplate(
            input_variables=["query"],
            template="""请根据以下用户查询,生成一段伪文档,需包含关键信息、相关术语及必要背景,并自然流畅呈现。\n用户查询:{query}\n伪文档:"""
        )

    # 根据查询,通过LLM生成伪文档
    def generate_pseudo_document(self, query: str) -> str:
        # 获取提示词模板并填充
        prompt = self._get_prompt_template().format(query=query)
        # 用LLM生成伪文档
        response = self.llm.invoke(prompt)
        # 返回伪文档内容,去除首尾空格
        return response.content.strip()

    # 将原查询和伪文档拼接为增强查询
    def enhance_query(self, query: str, pseudo_document: str) -> str:
        return f"{query}\n\n{pseudo_document}"

    # 检索相关文档
    def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
        # 获取需检索的文档数k
        k = kwargs.get("k", self.k)
        # 生成伪文档
        pseudo_document = self.generate_pseudo_document(query)
        # 构建增强后的查询
        enhanced_query = self.enhance_query(query, pseudo_document)
        # 用增强查询向量库检索,获取top-k文档(包含分数)
        docs = self.vector_store.similarity_search_with_score(enhanced_query, k=k)
        # 构建Document对象并附加相关元信息
        return [
            Document(
                page_content=doc.page_content,
                metadata={
                    **doc.metadata,
                    "score": float(distance),
                    "retrieval_method": "query_to_document",
                    "original_query": query,
                    "pseudo_document_length": len(pseudo_document)
                }
            ) for doc, distance in docs
        ]


# 定义RAG系统类,实现拼接、生成最终答案
class QueryToDocumentRAG:
    # 构造方法,传入检索器、LLM和可选回答模板
    def __init__(
        self,
        retriever: QueryToDocumentRetriever,
        llm: BaseLanguageModel,
        answer_prompt_template: Optional[str] = None
    ):
        self.retriever = retriever
        self.llm = llm
        # 初始化答案生成的PromptTemplate
        self.answer_prompt_template = PromptTemplate(
            input_variables=["context", "query"],
            template=answer_prompt_template or
            "已知相关信息:\n{context}\n请基于上述信息回答用户问题,如信息不足请指出。\n用户问题:{query}\n答案:"
        )

    # 生成最终答案的方法
    def generate_answer(self, query: str, k: int = 4) -> Dict[str, Any]:
        # 检索相关文档
        docs = self.retriever._get_relevant_documents(query, k=k)
        # 整合文档内容用于llm生成答案
        context = "\n\n".join([doc.page_content for doc in docs])
        # 用prompt模板拼接最终提示词
        prompt = self.answer_prompt_template.format(context=context, query=query)
        # 调用llm生成答案
        response = self.llm.invoke(prompt)
        # 返回包含问题、答案、检索文档和数量的字典
        return {
            "query": query,
            "answer": response.content.strip(),
            "retrieved_documents": docs,
            "num_documents": len(docs)
        }

# 初始化向量库
vector_store = get_vector_store(
    persist_directory="chroma_db",
    collection_name="query_to_document"
)
# 初始化QueryToDocumentRetriever对象
retriever = QueryToDocumentRetriever(
    vector_store=vector_store,
    llm=llm,
    embeddings=embeddings,
    k=4
)
# 初始化RAG系统对象
rag_system = QueryToDocumentRAG(
    retriever=retriever,
    llm=llm
)

# 构造示例文档数据
documents = [
    "人工智能(AI)是计算机科学的分支,旨在使系统完成需人类智能的任务。机器学习让计算机从数据学习,无需明确编程。深度学习通过神经网络模拟人脑。自然语言处理(NLP)让计算机理解人类语言。",
    "区块链是一种分布式账本技术,用密码学保障安全和不可篡改。比特币是区块链的首次大规模应用,解决双重支付难题。以太坊支持智能合约,便于去中心化应用开发。智能合约自动执行,无需中介。",
    "量子计算利用量子力学的新型计算方式,极具潜力。量子比特可处于0和1的叠加态。量子纠缠是其关键特性。量子计算或将应用于密码学、药物发现与优化领域。"
]

# 向向量库中批量添加文档及对应元数据
vector_store.add_texts(
    documents,
    metadatas=[
        {"topic": "人工智能", "category": "科技", "doc_id": "ai_1"},
        {"topic": "区块链", "category": "科技", "doc_id": "blockchain_1"},
        {"topic": "量子计算", "category": "科技", "doc_id": "quantum_1"},
    ]
)

# 设定一个示例用户查询
query = "什么是机器学习?它和深度学习有什么关系?"
# 执行完整RAG流程,检索3个相关文档
result = rag_system.generate_answer(query, k=3)

# 打印流程和最终输出结果的分隔线
print("\n" + "="*60)
print("最终结果")
print("="*60)
# 打印用户查询内容
print(f"\n用户查询: {result['query']}")
# 打印生成的答案
print(f"\n生成的答案:\n{result['answer']}")
# 打印检索到的文档数量
print(f"\n检索到的文档数量: {result['num_documents']}")
# 打印检索到的每个文档及相关信息
print(f"\n检索到的文档详情:")
for i, doc in enumerate(result['retrieved_documents'], 1):
    print(f"\n文档 {i} (分数: {doc.metadata.get('score', 'N/A'):.4f}):")
    print(f"内容: {doc.page_content[:200]}...")
    print(f"元数据: {doc.metadata}")

1.2 执行流程 #

1.2.1 核心思想 #

Query-to-Document 采用“查询扩展”策略:

  • 根据用户查询,使用 LLM 生成一个伪文档(pseudo document)
  • 伪文档包含关键信息、相关术语和必要背景
  • 将原始查询和伪文档拼接成增强查询
  • 使用增强查询在向量库中进行检索
  • 返回检索结果并生成最终答案

1.2.2 执行流程 #

阶段一:初始化

# 1. 获取向量存储实例
vector_store = get_vector_store(
    persist_directory="chroma_db",
    collection_name="query_to_document"
)

# 2. 创建QueryToDocumentRetriever检索器
retriever = QueryToDocumentRetriever(
    vector_store=vector_store,
    llm=llm,
    embeddings=embeddings,
    k=4
)

# 3. 创建RAG系统
rag_system = QueryToDocumentRAG(
    retriever=retriever,
    llm=llm
)

初始化时:

  • 创建向量存储实例
  • 创建检索器,配置 LLM、embeddings 和默认 k 值
  • 创建 RAG 系统,用于完整流程

阶段二:文档索引

# 添加文档到向量库
vector_store.add_texts(
    documents,
    metadatas=[...]
)

索引过程:

  • 将示例文档添加到向量库
  • 每个文档附带元数据(topic、category、doc_id)
  • 文档会被向量化并存储

阶段三:查询处理

query = "什么是机器学习?它和深度学习有什么关系?"
result = rag_system.generate_answer(query, k=3)

完整流程:

  1. 用户提交查询
  2. 生成伪文档:
    • 调用 generate_pseudo_document(query)
    • 使用提示词模板生成伪文档
    • LLM 生成包含关键信息、术语和背景的文本
  3. 增强查询:
    • 调用 enhance_query(query, pseudo_document)
    • 将原始查询和伪文档拼接
  4. 检索文档:
    • 使用增强查询在向量库中检索
    • 返回 top-k 个相关文档(带分数)
  5. 生成答案:
    • 整合检索到的文档内容作为上下文
    • 使用提示词模板构建最终 prompt
    • LLM 生成答案
  6. 返回结果:
    • 包含查询、答案、检索文档和数量

1.2.3 类图 #

classDiagram class QueryToDocumentRetriever { -vector_store: Any -llm: BaseLanguageModel -embeddings: Any -k: int -pseudo_document_prompt_template_str: Optional[str] +_get_prompt_template() PromptTemplate +generate_pseudo_document(query: str) str +enhance_query(query: str, pseudo_document: str) str +_get_relevant_documents(query: str, **kwargs) List[Document] } class QueryToDocumentRAG { -retriever: QueryToDocumentRetriever -llm: BaseLanguageModel -answer_prompt_template: PromptTemplate +__init__(retriever, llm, answer_prompt_template) +generate_answer(query: str, k: int) Dict[str, Any] } class BaseRetriever { <<abstract>> +invoke(query: str) List[Document] } class PromptTemplate { +format(**kwargs) str } class Document { +page_content: str +metadata: Dict } class VectorStore { <<interface>> +add_texts(texts: List[str], metadatas: List[Dict]) +similarity_search_with_score(query: str, k: int) List[Tuple[Document, float]] } class BaseLanguageModel { <<interface>> +invoke(prompt: str) AIMessage } QueryToDocumentRetriever --|> BaseRetriever QueryToDocumentRAG --> QueryToDocumentRetriever QueryToDocumentRAG --> BaseLanguageModel QueryToDocumentRetriever --> BaseLanguageModel QueryToDocumentRetriever --> PromptTemplate QueryToDocumentRetriever --> VectorStore QueryToDocumentRetriever ..> Document : creates VectorStore ..> Document : returns

1.2.4 时序图 #

1.2.4.1 完整RAG流程时序图 #
sequenceDiagram participant User as 用户 participant RAG as QueryToDocumentRAG participant Retriever as QueryToDocumentRetriever participant PromptTemplate as PromptTemplate participant LLM as BaseLanguageModel participant VectorStore as VectorStore User->>RAG: generate_answer("什么是机器学习?...", k=3) RAG->>Retriever: _get_relevant_documents(query, k=3) Note over Retriever: 步骤1: 生成伪文档 Retriever->>Retriever: _get_prompt_template() Retriever->>PromptTemplate: format(query=用户查询) PromptTemplate-->>Retriever: 返回完整prompt Retriever->>LLM: invoke(prompt) Note over LLM: 生成伪文档<br/>(包含关键信息、术语、背景) LLM-->>Retriever: 返回伪文档内容 Note over Retriever: 步骤2: 增强查询 Retriever->>Retriever: enhance_query(query, pseudo_document) Note over Retriever: 拼接原始查询和伪文档 Note over Retriever: 步骤3: 检索文档 Retriever->>VectorStore: similarity_search_with_score(enhanced_query, k=3) Note over VectorStore: 使用增强查询进行相似度检索 VectorStore-->>Retriever: [(doc1, score1), (doc2, score2), (doc3, score3)] Retriever->>Retriever: 构建Document对象<br/>(附加元数据) Retriever-->>RAG: 返回检索文档列表 Note over RAG: 步骤4: 生成答案 RAG->>RAG: 整合文档内容为上下文 RAG->>PromptTemplate: format(context=上下文, query=原始查询) PromptTemplate-->>RAG: 返回完整prompt RAG->>LLM: invoke(prompt) Note over LLM: 基于上下文生成答案 LLM-->>RAG: 返回答案内容 RAG-->>User: 返回结果字典<br/>(query, answer, documents, num_documents)
1.2.4.2 伪文档生成详细流程 #
sequenceDiagram participant Retriever as QueryToDocumentRetriever participant Template as PromptTemplate participant LLM as BaseLanguageModel Note over Retriever: generate_pseudo_document(query) Retriever->>Retriever: _get_prompt_template() alt 有自定义模板 Retriever->>Template: 使用自定义模板 else 使用默认模板 Retriever->>Template: 使用默认模板<br/>"请根据以下用户查询,生成一段伪文档..." end Retriever->>Template: format(query=用户查询) Template-->>Retriever: 返回完整prompt Retriever->>LLM: invoke(prompt) Note over LLM: 生成伪文档<br/>包含:<br/>- 关键信息<br/>- 相关术语<br/>- 必要背景<br/>- 自然流畅的文本 LLM-->>Retriever: 返回伪文档内容 Retriever->>Retriever: strip() 去除首尾空格 Retriever-->>Retriever: 返回伪文档字符串

1.2.5 关键设计要点 #

1. 查询增强流程

用户查询
    ↓
生成伪文档 (LLM)
    ↓
增强查询 = 原始查询 + 伪文档
    ↓
向量检索 (使用增强查询)
    ↓
返回相关文档
    ↓
生成最终答案

2. 伪文档的作用

伪文档包含:

  • 关键信息:查询相关的核心概念
  • 相关术语:专业词汇和同义词
  • 必要背景:上下文信息
  • 自然流畅:连贯的文本表达

示例:

用户查询: "什么是机器学习?"

伪文档可能生成:
"机器学习是人工智能的一个分支,它使计算机能够从数据中学习模式,
而无需明确编程。机器学习涉及算法、统计模型和数据分析技术,
广泛应用于图像识别、自然语言处理、推荐系统等领域。"

3. 增强查询的优势

  • 语义扩展:伪文档提供更多相关语义信息
  • 术语匹配:包含同义词和相关术语,提高匹配率
  • 上下文丰富:提供背景信息,帮助理解查询意图
  • 检索精度:增强查询能更准确地匹配相关文档

4. 元数据设计

检索返回的 Document 对象包含:

  • score:相似度分数(距离,越小越相似)
  • retrieval_method: "query_to_document":检索方法标识
  • original_query:原始用户查询
  • pseudo_document_length:伪文档长度
  • 继承原始文档的元数据(topic、category、doc_id 等)

1.2.6 与其他方法的对比 #

特性 传统检索 Query-to-Document
查询处理 直接使用原始查询 生成伪文档并增强查询
语义扩展 无 通过伪文档扩展
术语匹配 依赖查询中的术语 伪文档包含相关术语
计算成本 低 中等(需要LLM生成)
适用场景 简单查询 复杂、模糊查询

1.2.7 优势与应用场景 #

优势:

  • 提高检索精度:增强查询包含更多语义信息
  • 处理模糊查询:伪文档帮助理解查询意图
  • 术语扩展:自动包含相关术语和同义词
  • 上下文理解:提供背景信息,改善匹配

适用场景:

  • 复杂查询:需要语义扩展的查询
  • 模糊查询:用户表达不够精确
  • 专业领域:需要术语扩展的场景
  • 多语言场景:需要跨语言语义理解

1.2.8 提示词模板设计 #

默认伪文档生成模板:

请根据以下用户查询,生成一段伪文档,需包含关键信息、
相关术语及必要背景,并自然流畅呈现。
用户查询:{query}
伪文档:

默认答案生成模板:

已知相关信息:
{context}
请基于上述信息回答用户问题,如信息不足请指出。
用户问题:{query}
答案:

该设计通过查询扩展提升检索质量,适用于需要语义扩展和上下文理解的 RAG 应用。

2. 假设文档向量化(Assume Document Vectorization) #

在RAG(Retrieval-Augmented Generation)流程中,检索质量直接影响最终答案的准确性。传统做法通常是直接用用户查询来做向量化检索,但由于自然语言查询本身的简洁性和主观性,往往不能覆盖用户潜在的多角度信息需求,因此导致召回的文档有限或者语义相关性不足。假设文档向量化法正是对此问题的优化。

其核心流程如下:

  1. 假设文档生成
    给定用户原始查询,利用LLM(大语言模型)“扩写”出多个假设文档。这些假设文档旨在“假装”是合理相关的答案片段,从多个视角、更细粒度或更全面的表述对原始查询进行“扩展”。这样不仅弥补查询语义的不全,也能覆盖更多可能的信息需求。

  2. 独立向量化
    将原始查询和所有假设文档分别输入嵌入模型,得到各自的向量表示。每个假设文档的向量都是在不同上下文下对原查询的“信息补全”。

  3. 平均向量合成
    计算所有这些向量的算术平均值,把它作为增强后的“集成查询向量”。均值在理论上可视为多角度信息的“中心向量”,这样可提高检索相关文档的概率。

  4. 向量检索
    使用该平均向量去向量库中检索,能够显著提升与真实高相关文档的召回率,尤其对复杂问题、语义表达多变的问题尤为有效。

这种方案不仅简单(只需增加一次假设文档生成和向量平均),无需模型微调,而且对标准向量数据库和检索管道兼容性好。

典型应用场景举例:

  • 面对高度开放性问题(如“如何提高新能源电池效率?”),假设文档向量化可以自动从材料、制造工艺、充放电策略等多角度激活潜在有效的语义空间;
  • 复杂信息需求聚合型检索(如“机器学习与深度学习的异同?”)时,通过多种说法的假设文档更容易接近教材、论文等结构化知识内容。

2.1 AssumeDocumentVectorization.py #

# 假设文档向量化(Assume Document Vectorization)
"""
通过LLM辅助生成多角度假设文档,对原始查询和假设文档向量化后求平均向量,用于更精准的知识库检索与RAG增强。
"""

# 导入类型注解相关依赖
from typing import List, Dict, Any
# 导入基础检索器类
from langchain_core.retrievers import BaseRetriever
# 导入文档对象
from langchain_core.documents import Document
# 导入语言模型基类
from langchain_core.language_models import BaseLanguageModel
# 导入提示词模板
from langchain_core.prompts import PromptTemplate
# 导入自定义llm对象
from llm import llm
# 导入自定义embedding对象
from embeddings import embeddings
# 导入获取向量库的函数
from vector_store import get_vector_store
# 导入numpy用于向量计算
import numpy as np

# 定义假设文档向量化检索器,继承自BaseRetriever
class AssumeDocumentVectorizationRetriever(BaseRetriever):
    # 定义向量库属性
    vector_store: Any
    # 定义LLM属性
    llm: BaseLanguageModel
    # 定义embedding属性
    embeddings: Any
    # 默认检索文档数量
    k: int = 4
    # 生成假设文档的数量
    num_hypothetical_documents: int = 3

    # 获取不同角度的Prompt模板
    def _get_prompt_templates(self) -> List[PromptTemplate]:
        # 定义角度列表
        perspectives = [
            "学术研究角度", "实际应用角度", "基础概念角度"
        ]
        # 针对不同角度生成PromptTemplate对象
        return [
            PromptTemplate(
                input_variables=["query", "perspective"],
                template=f"请从{perspectives[i]},针对如下问题生成回答:{{query}}\n详细叙述。"
            ) for i in range(self.num_hypothetical_documents)
        ]

    # 通过LLM生成多角度假设文档
    def generate_hypothetical_documents(self, query: str) -> List[str]:
        # 存储生成的文档内容
        docs = []
        # 遍历每个Prompt模板
        for t in self._get_prompt_templates():
            # 获取当前角度描述
            p = t.template.split("从")[1].split(",")[0] if "从" in t.template else ""
            # 填充Prompt
            prompt = t.format(query=query, perspective=p)
            # 用LLM生成内容并去除首尾空白
            content = self.llm.invoke(prompt).content.strip()
            # 添加到文档列表
            docs.append(content)
        # 返回生成的假设文档列表
        return docs

    # 将查询和假设文档批量向量化
    def vectorize_documents(self, query: str, docs: List[str]) -> List[List[float]]:
        # 首先对原始查询进行向量化
        vectors = [self.embeddings.embed_query(query)]
        # 对每个假设文档进行向量化
        for doc in docs:
            vectors.append(self.embeddings.embed_query(doc))
        # 返回得到的全部向量
        return vectors

    # 计算所有向量的平均向量
    def calculate_average_vector(self, vectors: List[List[float]]) -> List[float]:
        # 使用numpy计算均值
        return np.mean(np.array(vectors), axis=0).tolist()

    # 用指定向量进行向量库检索
    def retrieve_with_vector(self, query_vector: List[float], k: int) -> List[tuple]:
        # 直接用底层Chroma接口进行向量检索
        r = self.vector_store._collection.query(query_embeddings=[query_vector], n_results=k)
        # 获取检索到的文档内容
        docs = r.get("documents", [[]])[0]
        # 获取距离值
        dists = r.get("distances", [[]])[0]
        # 获取元数据
        metas = r.get("metadatas", [[]])[0] if r.get("metadatas", [[]]) else [{}]*len(docs)
        # 将文档内容、距离、元数据打包
        return list(zip(docs, dists, metas))

    # 供RAG主流程调用的相关文档检索方法
    def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
        # 获取需检索的文档数k
        k = kwargs.get("k", self.k)
        # 生成多角度假设文档文本
        hypos = self.generate_hypothetical_documents(query)
        # 对原查询和假设文档全部向量化
        vectors = self.vectorize_documents(query, hypos)
        # 计算向量平均作为查询向量
        avg_vec = self.calculate_average_vector(vectors)
        # 使用平均向量进行检索
        results = self.retrieve_with_vector(avg_vec, k)
        # 构建Document对象并附加分数元信息
        return [
            Document(
                page_content=txt,
                metadata={**meta, "score": float(dist)}
            )
            for txt, dist, meta in results
        ]

# 定义假设文档向量化RAG系统
class AssumeDocumentVectorizationRAG:
    # 构造方法,传入检索器和LLM
    def __init__(self, retriever: AssumeDocumentVectorizationRetriever, llm: BaseLanguageModel):
        # 存储检索器
        self.retriever = retriever
        # 存储llm
        self.llm = llm
        # 定义答案生成Prompt模板
        self.answer_prompt_template = PromptTemplate(
            input_variables=["context", "query"],
            template="""已知信息:{context}\n请回答:{query}\n答案:"""
        )

    # 生成最终答案的方法
    def generate_answer(self, query: str, k: int = 4) -> Dict[str, Any]:
        # 检索相关文档
        docs = self.retriever._get_relevant_documents(query, k=k)
        # 整合检索到的文档内容
        context = "\n".join([doc.page_content for doc in docs])
        # 用Prompt模板生成回答用的提示词
        prompt = self.answer_prompt_template.format(context=context, query=query)
        # 调用LLM生成答案
        answer = self.llm.invoke(prompt).content.strip()
        # 返回结果字典
        return {
            "query": query,
            "answer": answer,
            "retrieved_documents": docs,
            "num_documents": len(docs)
        }

# 初始化向量库实例
vector_store = get_vector_store(persist_directory="chroma_db", collection_name="assume_document_vectorization")
# 初始化假设文档向量化检索器
retriever = AssumeDocumentVectorizationRetriever(
    vector_store=vector_store, llm=llm, embeddings=embeddings, k=4, num_hypothetical_documents=3
)
# 初始化RAG主流程类
rag_system = AssumeDocumentVectorizationRAG(retriever=retriever, llm=llm)

# 构造示例文档数据
documents = [
    "人工智能(AI)是计算机科学的分支,致力于让计算机完成类似人类智能的任务。机器学习让计算机从数据中学习,无需显式编程。深度学习用神经网络模拟人脑。NLP使计算机理解和生成自然语言。",
    "区块链是一种分布式账本技术,以密码学保障安全与不可篡改。比特币是首个区块链加密货币,解决了双重支付难题。以太坊支持智能合约,可构建去中心化应用。智能合约自动执行,无需中介。",
    "量子计算利用量子力学进行信息处理,具巨大潜力。量子比特可叠加多态,量子纠缠是其核心特性。量子计算或将在密码学、药物等领域带来变革。",
]
# 向向量库中添加文档及元数据
vector_store.add_texts(
    documents,
    metadatas=[
        {"topic": "人工智能", "category": "科技"},
        {"topic": "区块链", "category": "科技"},
        {"topic": "量子计算", "category": "科技"},
    ]
)

# 设定用户查询问题
query = "什么是机器学习?它和深度学习有什么关系?"
# 运行RAG流程,检索3个相关文档并生成答案
result = rag_system.generate_answer(query, k=3)
# 打印用户查询
print("\n用户查询:", result['query'])
# 打印生成的答案
print("\n生成的答案:\n", result['answer'])
# 打印检索到的文档数量
print("\n检索到的文档数量:", result['num_documents'])
# 打印检索到的每个文档的部分内容及元数据
for i, doc in enumerate(result['retrieved_documents'], 1):
    print(f"\n文档{i}: {doc.page_content[:80]}...")
    print(f"元数据: {doc.metadata}")

2.2 执行流程 #

2.2.1 核心思想 #

假设文档向量化采用“多角度假设 + 向量平均”策略:

  • 根据用户查询,使用 LLM 从多个角度生成假设文档
  • 对原始查询和所有假设文档进行向量化
  • 计算所有向量的平均向量作为查询向量
  • 使用平均向量在向量库中进行检索
  • 返回检索结果并生成最终答案

2.2.2 执行流程 #

阶段一:初始化

# 1. 获取向量存储实例
vector_store = get_vector_store(
    persist_directory="chroma_db",
    collection_name="assume_document_vectorization"
)

# 2. 创建假设文档向量化检索器
retriever = AssumeDocumentVectorizationRetriever(
    vector_store=vector_store,
    llm=llm,
    embeddings=embeddings,
    k=4,
    num_hypothetical_documents=3  # 生成3个假设文档
)

# 3. 创建RAG系统
rag_system = AssumeDocumentVectorizationRAG(
    retriever=retriever,
    llm=llm
)

初始化时:

  • 创建向量存储实例
  • 创建检索器,配置 LLM、embeddings、默认 k 值和假设文档数量
  • 创建 RAG 系统,用于完整流程

阶段二:文档索引

# 添加文档到向量库
vector_store.add_texts(
    documents,
    metadatas=[...]
)

索引过程:

  • 将示例文档添加到向量库
  • 每个文档附带元数据(topic、category)
  • 文档会被向量化并存储

阶段三:查询处理

query = "什么是机器学习?它和深度学习有什么关系?"
result = rag_system.generate_answer(query, k=3)

完整流程:

  1. 用户提交查询
  2. 生成多角度假设文档:
    • 调用 generate_hypothetical_documents(query)
    • 从3个角度生成假设文档(学术研究、实际应用、基础概念)
    • 每个角度使用不同的提示词模板
  3. 向量化处理:
    • 调用 vectorize_documents(query, hypos)
    • 对原始查询进行向量化
    • 对每个假设文档进行向量化
  4. 计算平均向量:
    • 调用 calculate_average_vector(vectors)
    • 使用 numpy 计算所有向量的平均值
  5. 向量检索:
    • 调用 retrieve_with_vector(avg_vec, k)
    • 使用平均向量在向量库中检索
    • 返回 top-k 个相关文档(带分数)
  6. 生成答案:
    • 整合检索到的文档内容作为上下文
    • 使用提示词模板构建最终 prompt
    • LLM 生成答案
  7. 返回结果:
    • 包含查询、答案、检索文档和数量

2.2.3 类图 #

classDiagram class AssumeDocumentVectorizationRetriever { -vector_store: Any -llm: BaseLanguageModel -embeddings: Any -k: int -num_hypothetical_documents: int +_get_prompt_templates() List[PromptTemplate] +generate_hypothetical_documents(query: str) List[str] +vectorize_documents(query: str, docs: List[str]) List[List[float]] +calculate_average_vector(vectors: List[List[float]]) List[float] +retrieve_with_vector(query_vector: List[float], k: int) List[tuple] +_get_relevant_documents(query: str, **kwargs) List[Document] } class AssumeDocumentVectorizationRAG { -retriever: AssumeDocumentVectorizationRetriever -llm: BaseLanguageModel -answer_prompt_template: PromptTemplate +__init__(retriever, llm) +generate_answer(query: str, k: int) Dict[str, Any] } class BaseRetriever { <<abstract>> +invoke(query: str) List[Document] } class PromptTemplate { +format(**kwargs) str } class Document { +page_content: str +metadata: Dict } class VectorStore { <<interface>> +add_texts(texts: List[str], metadatas: List[Dict]) +_collection: ChromaCollection } class BaseLanguageModel { <<interface>> +invoke(prompt: str) AIMessage } class Embeddings { <<interface>> +embed_query(text: str) List[float] } class NumPy { +mean(array, axis) ndarray } AssumeDocumentVectorizationRetriever --|> BaseRetriever AssumeDocumentVectorizationRAG --> AssumeDocumentVectorizationRetriever AssumeDocumentVectorizationRAG --> BaseLanguageModel AssumeDocumentVectorizationRetriever --> BaseLanguageModel AssumeDocumentVectorizationRetriever --> PromptTemplate AssumeDocumentVectorizationRetriever --> VectorStore AssumeDocumentVectorizationRetriever --> Embeddings AssumeDocumentVectorizationRetriever --> NumPy AssumeDocumentVectorizationRetriever ..> Document : creates VectorStore ..> Document : returns

2.2.4 时序图 #

2.2.4.1 完整RAG流程时序图 #
sequenceDiagram participant User as 用户 participant RAG as AssumeDocumentVectorizationRAG participant Retriever as AssumeDocumentVectorizationRetriever participant PromptTemplates as PromptTemplate列表 participant LLM as BaseLanguageModel participant Embeddings as Embeddings模型 participant NumPy as NumPy participant VectorStore as VectorStore User->>RAG: generate_answer("什么是机器学习?...", k=3) RAG->>Retriever: _get_relevant_documents(query, k=3) Note over Retriever: 步骤1: 生成多角度假设文档 Retriever->>Retriever: _get_prompt_templates() Retriever->>PromptTemplates: 创建3个角度的模板<br/>(学术研究、实际应用、基础概念) loop 对每个角度模板 Retriever->>PromptTemplates: format(query=用户查询, perspective=角度) PromptTemplates-->>Retriever: 返回完整prompt Retriever->>LLM: invoke(prompt) Note over LLM: 从该角度生成假设文档 LLM-->>Retriever: 返回假设文档内容 end Note over Retriever: 步骤2: 向量化处理 Retriever->>Retriever: vectorize_documents(query, hypos) Retriever->>Embeddings: embed_query(原始查询) Embeddings-->>Retriever: 返回查询向量 loop 对每个假设文档 Retriever->>Embeddings: embed_query(假设文档) Embeddings-->>Retriever: 返回文档向量 end Note over Retriever: 步骤3: 计算平均向量 Retriever->>Retriever: calculate_average_vector(vectors) Retriever->>NumPy: mean(vectors, axis=0) Note over NumPy: 计算所有向量的平均值 NumPy-->>Retriever: 返回平均向量 Note over Retriever: 步骤4: 向量检索 Retriever->>Retriever: retrieve_with_vector(avg_vec, k=3) Retriever->>VectorStore: _collection.query(query_embeddings=[avg_vec], n_results=3) Note over VectorStore: 使用平均向量进行相似度检索 VectorStore-->>Retriever: 返回(documents, distances, metadatas) Retriever->>Retriever: 构建Document对象<br/>(附加分数元数据) Retriever-->>RAG: 返回检索文档列表 Note over RAG: 步骤5: 生成答案 RAG->>RAG: 整合文档内容为上下文 RAG->>PromptTemplate: format(context=上下文, query=原始查询) PromptTemplate-->>RAG: 返回完整prompt RAG->>LLM: invoke(prompt) Note over LLM: 基于上下文生成答案 LLM-->>RAG: 返回答案内容 RAG-->>User: 返回结果字典<br/>(query, answer, documents, num_documents)
2.2.4.2 多角度假设文档生成详细流程 #
sequenceDiagram participant Retriever as AssumeDocumentVectorizationRetriever participant Templates as PromptTemplate列表 participant LLM as BaseLanguageModel Note over Retriever: generate_hypothetical_documents(query) Retriever->>Retriever: _get_prompt_templates() Note over Retriever: 创建3个角度的模板 Retriever->>Templates: 模板1: "请从学术研究角度..." Retriever->>Templates: 模板2: "请从实际应用角度..." Retriever->>Templates: 模板3: "请从基础概念角度..." loop 遍历每个模板 Retriever->>Retriever: 提取角度描述 Retriever->>Templates: format(query=查询, perspective=角度) Templates-->>Retriever: 返回完整prompt Retriever->>LLM: invoke(prompt) Note over LLM: 从该角度生成详细回答<br/>作为假设文档 LLM-->>Retriever: 返回假设文档内容 Retriever->>Retriever: strip() 去除首尾空格 Retriever->>Retriever: 添加到文档列表 end Retriever-->>Retriever: 返回假设文档列表<br/>[doc1, doc2, doc3]
2.2.4.3 向量平均与检索详细流程 #
sequenceDiagram participant Retriever as AssumeDocumentVectorizationRetriever participant Embeddings as Embeddings模型 participant NumPy as NumPy participant VectorStore as VectorStore Note over Retriever: vectorize_documents(query, hypos) Retriever->>Embeddings: embed_query(原始查询) Embeddings-->>Retriever: 返回查询向量 [v1] loop 对每个假设文档 Retriever->>Embeddings: embed_query(假设文档i) Embeddings-->>Retriever: 返回文档向量 [vi] end Note over Retriever: 现在有4个向量:<br/>[查询向量, 假设文档1, 假设文档2, 假设文档3] Retriever->>Retriever: calculate_average_vector(vectors) Retriever->>NumPy: mean(vectors, axis=0) Note over NumPy: 计算4个向量的平均值<br/>得到平均查询向量 NumPy-->>Retriever: 返回平均向量 [avg_vec] Retriever->>Retriever: retrieve_with_vector(avg_vec, k=3) Retriever->>VectorStore: _collection.query(query_embeddings=[avg_vec], n_results=3) Note over VectorStore: 使用平均向量进行相似度检索<br/>返回最相关的3个文档 VectorStore-->>Retriever: 返回(documents, distances, metadatas) Retriever->>Retriever: 构建Document对象列表 Retriever-->>Retriever: 返回检索结果

2.2.5 关键设计要点 #

1. 多角度假设文档生成流程

用户查询
    ↓
从3个角度生成假设文档:
  - 学术研究角度
  - 实际应用角度  
  - 基础概念角度
    ↓
得到3个假设文档
    ↓
向量化: [查询向量, 假设文档1向量, 假设文档2向量, 假设文档3向量]
    ↓
计算平均向量
    ↓
使用平均向量检索
    ↓
返回相关文档

2. 向量平均的优势

  • 语义融合:融合原始查询和多个角度的语义信息
  • 提高精度:平均向量能更好地代表查询意图
  • 多角度覆盖:涵盖不同角度的语义表达
  • 鲁棒性:减少单一角度可能带来的偏差

3. 角度设计

默认三个角度:

  • 学术研究角度:理论、研究、学术视角
  • 实际应用角度:实践、应用、案例视角
  • 基础概念角度:概念、定义、基础视角

示例:

用户查询: "什么是机器学习?"

学术研究角度假设文档:
"机器学习是人工智能领域的重要研究方向,涉及统计学、
算法理论和计算复杂性等学科。研究者通过设计算法使计算
机能够从数据中发现模式,建立预测模型..."

实际应用角度假设文档:
"机器学习在现实生活中广泛应用,如推荐系统、图像识别、
语音助手等。企业利用机器学习技术分析用户行为,提供
个性化服务,提升用户体验和业务效率..."

基础概念角度假设文档:
"机器学习是一种让计算机从数据中学习的方法,无需明确
编程。它通过训练数据建立模型,能够对新数据进行预测
或分类。主要包括监督学习、无监督学习和强化学习..."

4. 向量计算过程

数学表示:

设原始查询向量为 $\mathbf{q}$,假设文档向量为 $\mathbf{h}_1, \mathbf{h}_2, \ldots, \mathbf{h}_n$,则平均向量为:

$$ v_{\text{avg}} = \frac{1}{n+1}(q + \sum_{i=1}^{n} h_i) $$

其中 $n$ 为假设文档数量(默认3)。

2.2.6 与其他方法的对比 #

特性 Query-to-Document Assume Document Vectorization
生成内容 单个伪文档 多个角度假设文档
向量处理 直接使用增强查询 计算平均向量
语义融合 文本拼接 向量平均
角度覆盖 单一视角 多角度视角
计算成本 中等 较高(多文档生成+向量计算)
适用场景 简单扩展 复杂多角度查询

2.2.7. 优势与应用场景 #

优势:

  • 多角度语义融合:从不同角度理解查询意图
  • 提高检索精度:平均向量能更准确地匹配相关文档
  • 鲁棒性强:减少单一角度的偏差
  • 语义丰富:假设文档提供更多上下文信息

适用场景:

  • 复杂查询:需要多角度理解的查询
  • 专业领域:需要从不同视角分析的问题
  • 知识检索:需要全面覆盖相关知识的场景
  • 研究场景:需要学术和应用双重视角

2.2.8. 技术细节 #

  • 向量维度:所有向量必须具有相同的维度
  • 平均计算:使用 numpy 的 mean() 函数,axis=0 表示按列平均
  • 检索接口:直接使用 Chroma 的底层 _collection.query() 方法
  • 结果处理:将文档、距离、元数据打包为 Document 对象

该设计通过多角度假设文档和向量平均,在复杂查询场景下提供更准确的语义匹配,适用于需要多角度理解的 RAG 应用。

3. 问题分解策略(Sub-Question Decomposition) #

核心思想与流程
对于复杂的用户查询,直接用整体检索往往效果有限,因为原始问题结构复杂、包含若干子意图和层层递进的知识点。问题分解策略旨在用LLM先自动将复杂问题拆解(decompose)为多个更简单但相关的子问题,每个子问题都聚焦某一个点或层面。然后,分别用每个子问题对知识库进行检索,获得各自相关的文本块。最后,将各子问题检索得到的文档合并、去重,并输入LLM进行最终答案生成,确保全面覆盖原问题的所有细节。

整个流程如下图所示:

               ┌──────────┐
               │ 复杂查询 │
               └────┬─────┘
                    │(LLM分解)
        ┌───────────┼────────────┐
        │           │            │
 ┌──────▼─────┐┌────▼────┐┌─────▼─────┐
 │ 子问题1    ││ 子问题2 ││ 子问题3   │ ...
 └────┬───────┘└───┬─────┘└────┬──────┘
      │检索         │检索        │检索
 ┌────▼──┐     ┌────▼──┐    ┌───▼────┐
 │ 文档A │     │ 文档B │    │ 文档C  │ ...
 └─────┬─┘     └──────┬┘    └─┬──────┘
       └────┬─────┬───┴────┬──┘
            │合并+去重     │
            ▼
      ┌────────────┐
      │ 检索上下文 │
      └────┬───────┘
           │
    ┌──────▼─────┐
    │  LLM生成答案 │
    └──────┬─────┘
           │
     ┌─────▼──────┐
     │   最终答案  │
     └────────────┘

RAG流程对比说明

  • 传统RAG:用户原始大问题→embedding→检索→生成答案;对于多子意图大问题,检索效果易有限。
  • 分解优化RAG:用户原始大问题→LLM分解为多个子问题→每个子问题分别embedding检索→合并所有文本→生成答案。这样可以保持对每个知识点的覆盖与更细粒度的检索,提升答案完整性和准确率。

适用场景

  • 查询包含并列/递进的若干小问:如“请阐述A与B的区别以及各自应用”。
  • 复杂链式推理问题:如“请解释X的原理,并列举三个实际场景”。
  • 希望保证大问题的每个方面都被充分覆盖和回答。

关键技术点

  1. 问题拆解:调用LLM,给定复杂的问题,输出按顺序排列的子问题清单(如返回一个子问题列表)。
  2. 子问题检索:对每个子问题分别用embedding在向量库中相似度检索。
  3. 去重合并:将所有检索到的文本去重合并,防止重复内容。
  4. 回答生成:LLM以问题分解列表和全部检索材料为上下文,综合生成高质量长答案。

代码实现要点

  1. 封装SubQuestionDecompositionRetriever类,实现包含问题拆解、子问题批量检索与结果归并等流程的方法。
  2. 提供自动“问题-子问题生成”的Prompt设计、调用及结果解析。
  3. 检索时为每个文档标注其对应的子问题来源,便于后续分析和复现。

3.1 SubQuestionDecomposition.py #

# 导入类型注解所需的相关类型
from typing import List, Dict, Any, Optional
# 导入pydantic配置工具
from pydantic import ConfigDict
# 导入基础检索器基类
from langchain_core.retrievers import BaseRetriever
# 导入文档类型
from langchain_core.documents import Document
# 导入LLM基础类型
from langchain_core.language_models import BaseLanguageModel
# 导入提示词模板
from langchain_core.prompts import PromptTemplate
# 导入自定义llm对象
from llm import llm
# 导入获取向量库函数
from vector_store import get_vector_store
# 导入正则模块
import re

# 定义子问题分解检索器类,继承自BaseRetriever
class SubQuestionDecompositionRetriever(BaseRetriever):
    # 向量库对象
    vector_store: Any
    # 大语言模型对象
    llm: BaseLanguageModel
    # 检索每个子问题的文档数量
    k: int = 4
    # 子问题最少数量
    min_sub_questions: int = 3
    # 子问题最多数量
    max_sub_questions: int = 5
    # 可选的分解prompt模板字符串
    decomposition_prompt_template_str: Optional[str] = None
    # 配置模型参数,允许任意类型
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # 获取问题分解prompt模板的方法
    def _get_decomposition_prompt_template(self) -> PromptTemplate:
        # 获取或默认生成用于子问题分解的Prompt模板字符串
        template = self.decomposition_prompt_template_str or (
            "请将以下复杂问题拆解为3-5个相关的子问题。\n"
            "每个子问题应该:\n"
            "1. 独立且具体\n2. 能够帮助回答原始问题\n3. 覆盖原始问题的不同方面\n\n"
            "原始问题:{query}\n请列出子问题(每行一个,以问号结尾):"
        )
        # 返回PromptTemplate对象
        return PromptTemplate(input_variables=["query"], template=template)

    # 调用LLM将查询分解为若干子问题
    def decompose_question(self, query: str) -> List[str]:
        # 构建分解prompt
        prompt = self._get_decomposition_prompt_template().format(query=query)
        # 调用LLM进行内容生成
        resp = self.llm.invoke(prompt).content.strip()
        sub_questions = []
        # 遍历LLM返回的每一行,解析子问题文本
        for line in resp.split('\n'):
            # 去除前缀编号等符号
            line = re.sub(r'^\d+[\.、]\s*|^[-•]\s*', '', line.strip())
            # 只保留含中文或英文问号的行
            if line and ('?' in line or '?' in line):
                sub_questions.append(line)
        # 限制子问题数量不超过最大值
        sub_questions = sub_questions[:self.max_sub_questions]
        # 若数量少于最小,回退为原始问题
        if len(sub_questions) < self.min_sub_questions:
            sub_questions = [query] if not sub_questions else sub_questions
        # 返回子问题列表
        return sub_questions

    # 针对子问题列表依次检索相似文档
    def retrieve_for_sub_questions(self, sub_questions: List[str], k: int) -> List[Document]:
        all_docs = []
        # 遍历每个子问题及下标
        for i, sub_q in enumerate(sub_questions, 1):
            # 对当前子问题进行向量相似性检索
            docs = self.vector_store.similarity_search_with_score(sub_q, k=k)
            # 将检索结果统一封装为Document对象,并增加子问题相关信息
            for doc, distance in docs:
                all_docs.append(Document(
                    page_content=doc.page_content,
                    metadata={**doc.metadata, "score": float(distance), "sub_question": sub_q, "sub_question_index": i}
                ))
        # 返回所有结果列表
        return all_docs

    # 对检索到的文档内容进行去重与归并
    def deduplicate_and_merge(self, all_docs: List[Document]) -> List[Document]:
        unique = {}
        # 遍历所有文档,使用内容做去重键,保留分数更小者
        for doc in all_docs:
            c = doc.page_content
            s = doc.metadata.get("score", float('inf'))
            if c not in unique or s < unique[c].metadata.get("score", float('inf')):
                unique[c] = doc
        # 对去重后的文档按分数升序排序
        docs = sorted(unique.values(), key=lambda x: x.metadata.get("score", float('inf')))
        # 返回排序后的文档列表
        return docs

    # 主入口:根据查询获取相关文档流程
    def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
        # 获取本次检索的k值(可动态传递)
        k = kwargs.get("k", self.k)
        # 子问题分解
        sqs = self.decompose_question(query)
        # 针对子问题检索
        docs = self.retrieve_for_sub_questions(sqs, k)
        # 去重合并并返回
        return self.deduplicate_and_merge(docs)

# 定义子问题分解RAG主流程类
class SubQuestionDecompositionRAG:
    # 初始化,注入检索器和LLM,支持自定义答案生成Prompt
    def __init__(self, retriever: SubQuestionDecompositionRetriever, llm: BaseLanguageModel, answer_prompt_template: Optional[str]=None):
        self.retriever = retriever
        self.llm = llm
        # 设置答案生成Prompt模板,支持自定义或默认
        template = answer_prompt_template or (
            "已知以下相关信息:\n\n{context}\n\n原始问题:{query}\n\n"
            "该问题被分解为以下子问题:\n{sub_questions}\n\n"
            "请根据上述信息,全面回答原始问题。确保覆盖所有子问题的答案。\n\n答案:"
        )
        # 构建最终答案生成的PromptTemplate对象
        self.answer_prompt_template = PromptTemplate(
            input_variables=["context", "query", "sub_questions"], template=template
        )

    # 输入查询及可选k,返回答案及检索相关信息
    def generate_answer(self, query: str, k: int = 4) -> Dict[str, Any]:
        # 检索相关文档
        docs = self.retriever._get_relevant_documents(query, k=k)
        # 汇总所有子问题(去重)
        sub_questions = list(set([doc.metadata.get("sub_question","") for doc in docs if doc.metadata.get("sub_question")]))
        # 子问题格式化
        sub_questions_text = "\n".join([f"{i+1}. {sq}" for i, sq in enumerate(sub_questions)])
        # 合并上下文内容
        context = "\n\n".join([f"文档 {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
        # 构建答案生成Prompt
        prompt = self.answer_prompt_template.format(context=context, query=query, sub_questions=sub_questions_text)
        # 调用LLM生成答案
        answer = self.llm.invoke(prompt).content.strip()
        # 整理并返回完整流程结果
        return {
            "query": query,
            "answer": answer,
            "sub_questions": sub_questions,
            "retrieved_documents": docs,
            "num_documents": len(docs)
        }

# 初始化向量库对象
vector_store = get_vector_store(
    persist_directory="chroma_db",
    collection_name="sub_question_decomposition"
)
# 初始化子问题分解检索器
retriever = SubQuestionDecompositionRetriever(
    vector_store=vector_store,
    llm=llm,
    k=3,
    min_sub_questions=3,
    max_sub_questions=5
)
# 初始化RAG系统
rag_system = SubQuestionDecompositionRAG(
    retriever=retriever,
    llm=llm
)
# 定义示例文档内容列表
documents = [
    "人工智能(AI)是计算机科学的一个分支,旨在创建能够执行通常需要人类智能的任务的系统。"
    "机器学习是人工智能的核心技术之一,它使计算机能够从数据中学习,而无需明确编程。"
    "深度学习是机器学习的一个子集,使用人工神经网络来模拟人脑的工作方式。"
    "自然语言处理(NLP)是AI的另一个重要领域,专注于使计算机能够理解和生成人类语言。",
    "区块链技术是一种分布式账本技术,通过密码学方法确保数据的安全性和不可篡改性。"
    "比特币是第一个成功应用区块链技术的加密货币,它解决了数字货币的双重支付问题。"
    "以太坊是一个支持智能合约的区块链平台,允许开发者在其上构建去中心化应用(DApps)。"
    "智能合约是自动执行的合约,其条款直接写入代码中,无需第三方中介。",
    "量子计算是一种利用量子力学现象进行计算的新兴技术,具有巨大的计算潜力。"
    "量子比特(qubit)是量子计算的基本单位,与经典比特不同,它可以同时处于0和1的叠加态。"
    "量子纠缠是量子计算的关键特性,允许量子比特之间建立特殊的关联关系。"
    "量子计算在密码学、药物发现和优化问题等领域具有潜在的应用前景。",
]
# 向向量库中添加文档及元数据
vector_store.add_texts(
    documents,
    metadatas=[
        {"topic": "人工智能", "category": "科技"},
        {"topic": "区块链", "category": "科技"},
        {"topic": "量子计算", "category": "科技"},
    ]
)
# 指定一个复杂问题作为用户查询
query = "人工智能和机器学习的关系是什么?深度学习在其中的作用如何?它们在实际应用中有哪些典型案例?"
# 运行RAG流程并获取结果
result = rag_system.generate_answer(query, k=3)
# 打印分隔线
print("="*60)
print("最终结果")
print("="*60)
# 打印用户查询
print(f"\n用户查询: {result['query']}")
# 打印子问题列表及数量
print(f"\n分解的子问题 ({len(result['sub_questions'])} 个):")
for i, sq in enumerate(result['sub_questions'], 1):
    print(f"  {i}. {sq}")
# 打印生成的最终答案
print(f"\n生成的答案:\n{result['answer']}")
# 打印检索到的文档数量
print(f"\n检索到的文档数量: {result['num_documents']}")
# 打印部分检索到的文档的细节信息
print(f"\n检索到的文档详情:")
for i, doc in enumerate(result['retrieved_documents'][:3], 1):
    print(f"\n文档 {i} (分数: {doc.metadata.get('score', 'N/A'):.4f}):")
    print(f"  来源子问题: {doc.metadata.get('sub_question', 'N/A')[:60]}...")
    print(f"  内容: {doc.page_content[:150]}...")

3.2 执行流程 #

3.2.1 核心思想 #

子问题分解采用“分而治之”策略:

  • 将复杂问题拆解为多个独立的子问题
  • 对每个子问题分别进行向量检索
  • 合并所有子问题的检索结果
  • 对结果去重并按分数排序
  • 基于所有子问题的检索结果生成最终答案

3.2.2 执行流程 #

阶段一:初始化

# 1. 获取向量存储实例
vector_store = get_vector_store(
    persist_directory="chroma_db",
    collection_name="sub_question_decomposition"
)

# 2. 创建子问题分解检索器
retriever = SubQuestionDecompositionRetriever(
    vector_store=vector_store,
    llm=llm,
    k=3,                      # 每个子问题检索3个文档
    min_sub_questions=3,       # 最少3个子问题
    max_sub_questions=5       # 最多5个子问题
)

# 3. 创建RAG系统
rag_system = SubQuestionDecompositionRAG(
    retriever=retriever,
    llm=llm
)

初始化时:

  • 创建向量存储实例
  • 创建检索器,配置 LLM、向量库、k 值和子问题数量范围
  • 创建 RAG 系统,用于完整流程

阶段二:文档索引

# 添加文档到向量库
vector_store.add_texts(
    documents,
    metadatas=[...]
)

索引过程:

  • 将示例文档添加到向量库
  • 每个文档附带元数据(topic、category)
  • 文档会被向量化并存储

阶段三:查询处理

query = "人工智能和机器学习的关系是什么?深度学习在其中的作用如何?它们在实际应用中有哪些典型案例?"
result = rag_system.generate_answer(query, k=3)

完整流程:

  1. 用户提交复杂查询
  2. 子问题分解:
    • 调用 decompose_question(query)
    • 使用 LLM 将复杂问题拆解为 3-5 个子问题
    • 解析并验证子问题数量
  3. 子问题检索:
    • 调用 retrieve_for_sub_questions(sqs, k)
    • 对每个子问题分别进行向量检索
    • 为每个检索结果添加子问题元数据
  4. 去重合并:
    • 调用 deduplicate_and_merge(all_docs)
    • 按文档内容去重,保留最佳分数
    • 按分数排序
  5. 生成答案:
    • 整合所有检索文档作为上下文
    • 提取所有子问题并格式化
    • 使用提示词模板构建最终 prompt
    • LLM 生成答案
  6. 返回结果:
    • 包含查询、答案、子问题列表、检索文档和数量

3.3 类图 #

classDiagram class SubQuestionDecompositionRetriever { -vector_store: Any -llm: BaseLanguageModel -k: int -min_sub_questions: int -max_sub_questions: int -decomposition_prompt_template_str: Optional[str] +_get_decomposition_prompt_template() PromptTemplate +decompose_question(query: str) List[str] +retrieve_for_sub_questions(sub_questions: List[str], k: int) List[Document] +deduplicate_and_merge(all_docs: List[Document]) List[Document] +_get_relevant_documents(query: str, **kwargs) List[Document] } class SubQuestionDecompositionRAG { -retriever: SubQuestionDecompositionRetriever -llm: BaseLanguageModel -answer_prompt_template: PromptTemplate +__init__(retriever, llm, answer_prompt_template) +generate_answer(query: str, k: int) Dict[str, Any] } class BaseRetriever { <<abstract>> +invoke(query: str) List[Document] } class PromptTemplate { +format(**kwargs) str } class Document { +page_content: str +metadata: Dict } class VectorStore { <<interface>> +add_texts(texts: List[str], metadatas: List[Dict]) +similarity_search_with_score(query: str, k: int) List[Tuple[Document, float]] } class BaseLanguageModel { <<interface>> +invoke(prompt: str) AIMessage } SubQuestionDecompositionRetriever --|> BaseRetriever SubQuestionDecompositionRAG --> SubQuestionDecompositionRetriever SubQuestionDecompositionRAG --> BaseLanguageModel SubQuestionDecompositionRetriever --> BaseLanguageModel SubQuestionDecompositionRetriever --> PromptTemplate SubQuestionDecompositionRetriever --> VectorStore SubQuestionDecompositionRetriever ..> Document : creates VectorStore ..> Document : returns

3.4 时序图 #

3.4.1 完整RAG流程时序图 #

sequenceDiagram participant User as 用户 participant RAG as SubQuestionDecompositionRAG participant Retriever as SubQuestionDecompositionRetriever participant PromptTemplate as PromptTemplate participant LLM as BaseLanguageModel participant VectorStore as VectorStore User->>RAG: generate_answer(复杂查询, k=3) RAG->>Retriever: _get_relevant_documents(query, k=3) Note over Retriever: 步骤1: 子问题分解 Retriever->>Retriever: _get_decomposition_prompt_template() Retriever->>PromptTemplate: format(query=复杂查询) PromptTemplate-->>Retriever: 返回完整prompt Retriever->>LLM: invoke(prompt) Note over LLM: 将复杂问题拆解为<br/>3-5个子问题 LLM-->>Retriever: 返回子问题文本 Retriever->>Retriever: 解析子问题<br/>(去除编号,提取问号行) Retriever->>Retriever: 验证子问题数量<br/>(3-5个) Retriever->>Retriever: 返回子问题列表 Note over Retriever: 步骤2: 子问题检索 Retriever->>Retriever: retrieve_for_sub_questions(sqs, k=3) loop 对每个子问题 Retriever->>VectorStore: similarity_search_with_score(子问题, k=3) Note over VectorStore: 对当前子问题进行<br/>向量相似度检索 VectorStore-->>Retriever: 返回3个相关文档(带分数) loop 对每个检索结果 Retriever->>Retriever: 构建Document对象<br/>(添加sub_question元数据) end end Note over Retriever: 步骤3: 去重合并 Retriever->>Retriever: deduplicate_and_merge(all_docs) Note over Retriever: 按内容去重<br/>保留最佳分数<br/>按分数排序 Retriever-->>RAG: 返回去重后的文档列表 Note over RAG: 步骤4: 生成答案 RAG->>RAG: 提取所有子问题并格式化 RAG->>RAG: 整合文档内容为上下文 RAG->>PromptTemplate: format(context, query, sub_questions) PromptTemplate-->>RAG: 返回完整prompt RAG->>LLM: invoke(prompt) Note over LLM: 基于所有子问题的<br/>检索结果生成答案 LLM-->>RAG: 返回答案内容 RAG-->>User: 返回结果字典<br/>(query, answer, sub_questions, documents)

3.4.2 子问题分解详细流程 #

sequenceDiagram participant Retriever as SubQuestionDecompositionRetriever participant Template as PromptTemplate participant LLM as BaseLanguageModel Note over Retriever: decompose_question(query) Retriever->>Retriever: _get_decomposition_prompt_template() Note over Retriever: 获取或创建分解模板<br/>"请将以下复杂问题拆解为3-5个相关的子问题..." Retriever->>Template: format(query=复杂查询) Template-->>Retriever: 返回完整prompt Retriever->>LLM: invoke(prompt) Note over LLM: 生成子问题列表<br/>(每行一个,以问号结尾) LLM-->>Retriever: 返回子问题文本 Retriever->>Retriever: 按行分割响应文本 loop 遍历每一行 Retriever->>Retriever: 去除前缀编号<br/>(如"1. "、"1、"、"- "等) Retriever->>Retriever: 检查是否包含问号(?或?) alt 包含问号 Retriever->>Retriever: 添加到子问题列表 end end Retriever->>Retriever: 限制数量为max_sub_questions alt 子问题数量 < min_sub_questions Retriever->>Retriever: 回退为原始查询 end Retriever-->>Retriever: 返回子问题列表<br/>[sq1, sq2, sq3, ...]

3.4.3 子问题检索与去重详细流程 #

sequenceDiagram participant Retriever as SubQuestionDecompositionRetriever participant VectorStore as VectorStore Note over Retriever: retrieve_for_sub_questions(sqs, k=3) loop 对每个子问题 (sq1, sq2, sq3, ...) Retriever->>VectorStore: similarity_search_with_score(子问题i, k=3) Note over VectorStore: 对子问题i进行<br/>向量相似度检索 VectorStore-->>Retriever: 返回3个文档(带分数) loop 对每个检索结果 Retriever->>Retriever: 构建Document对象<br/>添加元数据:<br/>- score: 分数<br/>- sub_question: 子问题<br/>- sub_question_index: 索引 Retriever->>Retriever: 添加到all_docs列表 end end Note over Retriever: 现在all_docs包含所有子问题的检索结果<br/>(可能有重复文档) Retriever->>Retriever: deduplicate_and_merge(all_docs) Note over Retriever: 去重逻辑 loop 遍历所有文档 Retriever->>Retriever: 获取文档内容作为键 alt 文档首次出现 Retriever->>Retriever: 添加到unique字典 else 文档已存在 alt 当前分数 < 已有分数 Retriever->>Retriever: 更新为当前文档<br/>(保留更好的分数) end end end Retriever->>Retriever: 按分数升序排序 Retriever-->>Retriever: 返回去重并排序后的文档列表

3.5 关键设计要点 #

1. 子问题分解流程

复杂查询
    ↓
LLM分解为子问题 (3-5个)
    ↓
解析并验证子问题
    ↓
对每个子问题分别检索
    ↓
合并所有检索结果
    ↓
去重并排序
    ↓
返回最终文档列表

2. 子问题分解示例

原始查询: "人工智能和机器学习的关系是什么?深度学习在其中的作用如何?它们在实际应用中有哪些典型案例?"

分解后的子问题:
1. 人工智能和机器学习的关系是什么?
2. 深度学习在人工智能中的作用如何?
3. 机器学习和深度学习在实际应用中有哪些典型案例?
4. 人工智能、机器学习和深度学习之间的层次关系是什么?
5. 这些技术在实际应用中的具体案例有哪些?

3. 去重策略

  • 去重键:使用文档内容(page_content)作为唯一标识
  • 分数选择:如果同一文档被多个子问题检索到,保留最佳分数(距离最小)
  • 排序:去重后按分数升序排序,最相关的文档在前

示例:

子问题1检索到: [docA(score=0.2), docB(score=0.3), docC(score=0.4)]
子问题2检索到: [docA(score=0.1), docD(score=0.3), docE(score=0.5)]

去重后: [docA(score=0.1), docB(score=0.3), docC(score=0.4), docD(score=0.3), docE(score=0.5)]

4. 元数据设计

检索返回的 Document 对象包含:

  • score:相似度分数(距离,越小越相似)
  • sub_question:来源子问题
  • sub_question_index:子问题索引(1, 2, 3...)
  • 继承原始文档的元数据(topic、category 等)

5. 答案生成模板

默认答案生成模板:

已知以下相关信息:

{context}

原始问题:{query}

该问题被分解为以下子问题:
{sub_questions}

请根据上述信息,全面回答原始问题。确保覆盖所有子问题的答案。

答案:

该模板:

  • 包含所有检索文档作为上下文
  • 明确列出所有子问题
  • 要求覆盖所有子问题的答案

3.6. 与其他方法的对比 #

特性 传统检索 子问题分解
查询处理 直接使用原始查询 拆解为多个子问题
检索方式 单次检索 多次检索(每个子问题)
结果处理 直接返回 去重合并
适用场景 简单查询 复杂多部分查询
覆盖度 可能遗漏 更全面覆盖

3.7. 优势与应用场景 #

优势:

  • 全面覆盖:通过子问题确保覆盖查询的各个方面
  • 提高精度:每个子问题独立检索,更精准匹配
  • 处理复杂查询:适合多部分、多角度的复杂问题
  • 结果去重:自动处理重复文档,保留最佳分数

适用场景:

  • 复杂问题:包含多个子问题的复合查询
  • 多角度查询:需要从不同角度回答的问题
  • 综合分析:需要综合多个方面信息的查询
  • 研究场景:需要全面覆盖相关知识的场景

3.8. 技术细节 #

  • 子问题数量控制:
    • 最少 min_sub_questions 个(默认3)
    • 最多 max_sub_questions 个(默认5)
    • 如果少于最小值,回退为原始查询
  • 子问题解析:
    • 去除编号前缀(如 "1. "、"1、"、"- ")
    • 只保留包含问号的行
  • 去重算法:
    • 使用字典以内容为键
    • 保留最佳分数(距离最小)

该设计通过“分而治之”策略,在复杂查询场景下提供更全面的检索覆盖,适用于需要多角度、多部分回答的 RAG 应用。

4. 多角度查询重写(Query Rewriting) #

传统的RAG(Retrieval-Augmented Generation,检索增强生成)流程,直接用用户的自然语言查询去做embedding检索,其效果很依赖用户的表述方式:同一个查询如果表达方式不同,底层向量检索的命中内容可能有很大区别,导致召回信息不完整、答案片面甚至遗漏关键信息。尤其是

  • 用户表达简略,未穷举出所有意图的措辞
  • 查询存在同义/多义,或有不同专业表述(例如「人工智能」VS「AI」VS「智能系统」)
  • 用户问题本身较模糊或抽象

为此,可以通过查询重写(Query Rewriting)技术,将用户原始问题自动改写为多个表述不同、风格多样的查询版本,从而用多角度、丰富的语义视角去覆盖底层知识库的潜在高相关片段,大幅提升检索的召回率和结果多样性,实现信息“组团”式召回与补全。

步骤说明与关键技术点如下:

  1. 多版本查询生成

    • 调用LLM,输入一个用户查询,让大模型帮我们生成3-5个不同措辞、风格或视角的查询表达。
    • 要求每个版本都“忠于核心意图”但具备表述差异,可以参考如下Prompt引导:

      请将以下查询改写为3-5个不同表达方式的查询版本。
      每个查询版本应该:
      1. 保持原始查询的核心意图
      2. 使用不同的措辞和表达方式
      3. 可以从不同角度(正式、通俗、专业、简洁等)表达
      4. 确保所有版本都能检索到相关信息
      
      原始查询:{query}
      请列出查询版本(每行一个):
  2. 多版本批量检索

    • 针对上述每一个改写后的查询,分别进行embedding后在向量库中检索K个高相关文档。
    • 可为每条检索结果标注其对应的“查询版本”来源,便于后续分析。
  3. 去重合并(Deduplicate & Merge)

    • 很多时候,不同的查询版本可能检索到内容大致相同的文本,因此需要对结果按内容hash或embedding聚类做去重,只保留分数最高最相关的若干条。
  4. 答案生成

    • 可以汇总所有检索到的文档内容,附带各自的来源查询版本,统一拼接成上下文,再用LLM生成高质量响应答案。

优势总结

  • 极大提升了召回率和检索多样性,最大化覆盖用户表达中的隐含意图和边界情况
  • 对于表达不精准、领域同义词多、知识表述丰富的任务有巨大效果(如医学、法务、科技百科等)
  • 还能为后续调优改写prompt与分析检索“盲区”提供基础

典型适用场景:

  • 用户写的问题比较通用/宽泛/有歧义时
  • 知识库表述丰富,内容多样,传统直检检索经常漏掉重要信息时
  • 希望提高长尾召回、提升答案全面性时

方法小结:

  • Query Rewriting ≈ 多视角“扩展式”检索 ≈ 类似于“查询增强(Query Expansion)”+“多表达聚合”
  • 自动化结合大模型和向量检索,可插拔适配各种RAG/检索框架

4.1 QueryRewriting.py #

# 导入类型注解所需的相关类型
from typing import List, Dict, Any, Optional
# 导入pydantic配置工具
from pydantic import ConfigDict
# 导入基础检索器基类
from langchain_core.retrievers import BaseRetriever
# 导入文档类型
from langchain_core.documents import Document
# 导入LLM基础类型
from langchain_core.language_models import BaseLanguageModel
# 导入提示词模板
from langchain_core.prompts import PromptTemplate
# 导入自定义llm对象
from llm import llm
# 导入获取向量库函数
from vector_store import get_vector_store
# 导入正则模块
import re

# 定义查询重写检索器类,继承自BaseRetriever
class QueryRewritingRetriever(BaseRetriever):
    # 向量库对象
    vector_store: Any
    # 大语言模型对象
    llm: BaseLanguageModel
    # 检索每个查询版本的文档数量
    k: int = 4
    # 查询版本最少数量
    min_versions: int = 3
    # 查询版本最多数量
    max_versions: int = 5
    # 可选的查询重写prompt模板字符串
    rewriting_prompt_template_str: Optional[str] = None
    # 配置模型参数,允许任意类型
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # 获取查询重写prompt模板的方法
    def _get_rewriting_prompt_template(self) -> PromptTemplate:
        # 获取或默认生成用于查询重写的Prompt模板字符串
        template = self.rewriting_prompt_template_str or (
            "请将以下查询改写为3-5个不同表达方式的查询版本。\n"
            "每个查询版本应该:\n"
            "1. 保持原始查询的核心意图\n"
            "2. 使用不同的措辞和表达方式\n"
            "3. 可以从不同角度(正式、通俗、专业、简洁等)表达\n"
            "4. 确保所有版本都能检索到相关信息\n\n"
            "原始查询:{query}\n请列出查询版本(每行一个):"
        )
        # 返回PromptTemplate对象
        return PromptTemplate(input_variables=["query"], template=template)

    # 调用LLM生成多个不同表达方式的查询版本
    def rewrite_query(self, query: str) -> List[str]:
        # 构建重写prompt
        prompt = self._get_rewriting_prompt_template().format(query=query)
        # 调用LLM进行内容生成
        resp = self.llm.invoke(prompt).content.strip()
        rewritten_queries = []
        # 遍历LLM返回的每一行,解析查询版本文本
        for line in resp.split('\n'):
            # 去除前缀编号等符号
            line = re.sub(r'^\d+[\.、]\s*|^[-•]\s*', '', line.strip())
            # 保留非空行作为查询版本
            if line:
                rewritten_queries.append(line)
        # 限制查询版本数量不超过最大值
        rewritten_queries = rewritten_queries[:self.max_versions]
        # 若数量少于最小,至少包含原始查询
        if len(rewritten_queries) < self.min_versions:
            if query not in rewritten_queries:
                rewritten_queries.insert(0, query)
        # 返回查询版本列表
        return rewritten_queries

    # 针对查询版本列表依次检索相似文档
    def retrieve_for_queries(self, queries: List[str], k: int) -> List[Document]:
        all_docs = []
        # 遍历每个查询版本及下标
        for i, query_version in enumerate(queries, 1):
            # 对当前查询版本进行向量相似性检索
            docs = self.vector_store.similarity_search_with_score(query_version, k=k)
            # 将检索结果统一封装为Document对象,并增加查询版本相关信息
            for doc, distance in docs:
                all_docs.append(Document(
                    page_content=doc.page_content,
                    metadata={**doc.metadata, "score": float(distance), "query_version": query_version, "version_index": i}
                ))
        # 返回所有结果列表
        return all_docs

    # 对检索到的文档内容进行去重与归并
    def deduplicate_and_merge(self, all_docs: List[Document]) -> List[Document]:
        unique = {}
        # 遍历所有文档,使用内容做去重键,保留分数更小者
        for doc in all_docs:
            c = doc.page_content
            s = doc.metadata.get("score", float('inf'))
            if c not in unique or s < unique[c].metadata.get("score", float('inf')):
                unique[c] = doc
        # 对去重后的文档按分数升序排序
        docs = sorted(unique.values(), key=lambda x: x.metadata.get("score", float('inf')))
        # 返回排序后的文档列表
        return docs

    # 主入口:根据查询获取相关文档流程
    def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
        # 获取本次检索的k值(可动态传递)
        k = kwargs.get("k", self.k)
        # 查询重写:生成多个查询版本
        query_versions = self.rewrite_query(query)
        # 针对查询版本检索
        docs = self.retrieve_for_queries(query_versions, k)
        # 去重合并并返回
        return self.deduplicate_and_merge(docs)

# 定义查询重写RAG主流程类
class QueryRewritingRAG:
    # 初始化,注入检索器和LLM,支持自定义答案生成Prompt
    def __init__(self, retriever: QueryRewritingRetriever, llm: BaseLanguageModel, answer_prompt_template: Optional[str]=None):
        self.retriever = retriever
        self.llm = llm
        # 设置答案生成Prompt模板,支持自定义或默认
        template = answer_prompt_template or (
            "已知以下相关信息:\n\n{context}\n\n"
            "原始查询:{query}\n\n"
            "该查询被重写为以下版本:\n{query_versions}\n\n"
            "请根据上述信息,准确全面地回答原始查询。\n\n答案:"
        )
        # 构建最终答案生成的PromptTemplate对象
        self.answer_prompt_template = PromptTemplate(
            input_variables=["context", "query", "query_versions"], template=template
        )

    # 输入查询及可选k,返回答案及检索相关信息
    def generate_answer(self, query: str, k: int = 4) -> Dict[str, Any]:
        # 检索相关文档
        docs = self.retriever._get_relevant_documents(query, k=k)
        # 汇总所有查询版本(去重)
        query_versions = list(set([doc.metadata.get("query_version","") for doc in docs if doc.metadata.get("query_version")]))
        # 查询版本格式化
        query_versions_text = "\n".join([f"{i+1}. {qv}" for i, qv in enumerate(query_versions)])
        # 合并上下文内容
        context = "\n\n".join([f"文档 {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
        # 构建答案生成Prompt
        prompt = self.answer_prompt_template.format(context=context, query=query, query_versions=query_versions_text)
        # 调用LLM生成答案
        answer = self.llm.invoke(prompt).content.strip()
        # 整理并返回完整流程结果
        return {
            "query": query,
            "answer": answer,
            "query_versions": query_versions,
            "retrieved_documents": docs,
            "num_documents": len(docs)
        }

# 初始化向量库对象
vector_store = get_vector_store(
    persist_directory="chroma_db",
    collection_name="query_rewriting"
)
# 初始化查询重写检索器
retriever = QueryRewritingRetriever(
    vector_store=vector_store,
    llm=llm,
    k=3,
    min_versions=3,
    max_versions=5
)
# 初始化RAG系统
rag_system = QueryRewritingRAG(
    retriever=retriever,
    llm=llm
)
# 定义示例文档内容列表
documents = [
    "人工智能(AI)是计算机科学的一个分支,旨在创建能够执行通常需要人类智能的任务的系统。"
    "机器学习是人工智能的核心技术之一,它使计算机能够从数据中学习,而无需明确编程。"
    "深度学习是机器学习的一个子集,使用人工神经网络来模拟人脑的工作方式。"
    "自然语言处理(NLP)是AI的另一个重要领域,专注于使计算机能够理解和生成人类语言。",
    "区块链技术是一种分布式账本技术,通过密码学方法确保数据的安全性和不可篡改性。"
    "比特币是第一个成功应用区块链技术的加密货币,它解决了数字货币的双重支付问题。"
    "以太坊是一个支持智能合约的区块链平台,允许开发者在其上构建去中心化应用(DApps)。"
    "智能合约是自动执行的合约,其条款直接写入代码中,无需第三方中介。",
    "量子计算是一种利用量子力学现象进行计算的新兴技术,具有巨大的计算潜力。"
    "量子比特(qubit)是量子计算的基本单位,与经典比特不同,它可以同时处于0和1的叠加态。"
    "量子纠缠是量子计算的关键特性,允许量子比特之间建立特殊的关联关系。"
    "量子计算在密码学、药物发现和优化问题等领域具有潜在的应用前景。",
]
# 向向量库中添加文档及元数据
vector_store.add_texts(
    documents,
    metadatas=[
        {"topic": "人工智能", "category": "科技"},
        {"topic": "区块链", "category": "科技"},
        {"topic": "量子计算", "category": "科技"},
    ]
)
# 指定一个查询作为用户输入
query = "机器学习是什么?"
# 运行RAG流程并获取结果
result = rag_system.generate_answer(query, k=3)
# 打印分隔线
print("="*60)
print("最终结果")
print("="*60)
# 打印用户查询
print(f"\n用户查询: {result['query']}")
# 打印查询版本列表及数量
print(f"\n生成的查询版本 ({len(result['query_versions'])} 个):")
for i, qv in enumerate(result['query_versions'], 1):
    print(f"  {i}. {qv}")
# 打印生成的最终答案
print(f"\n生成的答案:\n{result['answer']}")
# 打印检索到的文档数量
print(f"\n检索到的文档数量: {result['num_documents']}")
# 打印部分检索到的文档的细节信息
print(f"\n检索到的文档详情:")
for i, doc in enumerate(result['retrieved_documents'][:3], 1):
    print(f"\n文档 {i} (分数: {doc.metadata.get('score', 'N/A'):.4f}):")
    print(f"  来源查询版本: {doc.metadata.get('query_version', 'N/A')[:60]}...")
    print(f"  内容: {doc.page_content[:150]}...")

4.2 执行过程 #

4.2.1 核心思想 #

查询重写采用“多版本查询”策略:

  • 将用户查询改写为多个不同表达方式的查询版本
  • 每个版本保持核心意图,但使用不同措辞和角度
  • 对每个查询版本分别进行向量检索
  • 合并所有查询版本的检索结果
  • 对结果去重并按分数排序
  • 基于所有查询版本的检索结果生成最终答案

4.2.2 执行流程 #

阶段一:初始化

# 1. 获取向量存储实例
vector_store = get_vector_store(
    persist_directory="chroma_db",
    collection_name="query_rewriting"
)

# 2. 创建查询重写检索器
retriever = QueryRewritingRetriever(
    vector_store=vector_store,
    llm=llm,
    k=3,                      # 每个查询版本检索3个文档
    min_versions=3,           # 最少3个查询版本
    max_versions=5            # 最多5个查询版本
)

# 3. 创建RAG系统
rag_system = QueryRewritingRAG(
    retriever=retriever,
    llm=llm
)

初始化时:

  • 创建向量存储实例
  • 创建检索器,配置 LLM、向量库、k 值和查询版本数量范围
  • 创建 RAG 系统,用于完整流程

阶段二:文档索引

# 添加文档到向量库
vector_store.add_texts(
    documents,
    metadatas=[...]
)

索引过程:

  • 将示例文档添加到向量库
  • 每个文档附带元数据(topic、category)
  • 文档会被向量化并存储

阶段三:查询处理

query = "机器学习是什么?"
result = rag_system.generate_answer(query, k=3)

完整流程:

  1. 用户提交查询
  2. 查询重写:
    • 调用 rewrite_query(query)
    • 使用 LLM 将查询改写为 3-5 个不同表达方式的版本
    • 解析并验证查询版本数量
  3. 多版本检索:
    • 调用 retrieve_for_queries(query_versions, k)
    • 对每个查询版本分别进行向量检索
    • 为每个检索结果添加查询版本元数据
  4. 去重合并:
    • 调用 deduplicate_and_merge(all_docs)
    • 按文档内容去重,保留最佳分数
    • 按分数排序
  5. 生成答案:
    • 整合所有检索文档作为上下文
    • 提取所有查询版本并格式化
    • 使用提示词模板构建最终 prompt
    • LLM 生成答案
  6. 返回结果:
    • 包含查询、答案、查询版本列表、检索文档和数量

4.2.3 类图 #

classDiagram class QueryRewritingRetriever { -vector_store: Any -llm: BaseLanguageModel -k: int -min_versions: int -max_versions: int -rewriting_prompt_template_str: Optional[str] +_get_rewriting_prompt_template() PromptTemplate +rewrite_query(query: str) List[str] +retrieve_for_queries(queries: List[str], k: int) List[Document] +deduplicate_and_merge(all_docs: List[Document]) List[Document] +_get_relevant_documents(query: str, **kwargs) List[Document] } class QueryRewritingRAG { -retriever: QueryRewritingRetriever -llm: BaseLanguageModel -answer_prompt_template: PromptTemplate +__init__(retriever, llm, answer_prompt_template) +generate_answer(query: str, k: int) Dict[str, Any] } class BaseRetriever { <<abstract>> +invoke(query: str) List[Document] } class PromptTemplate { +format(**kwargs) str } class Document { +page_content: str +metadata: Dict } class VectorStore { <<interface>> +add_texts(texts: List[str], metadatas: List[Dict]) +similarity_search_with_score(query: str, k: int) List[Tuple[Document, float]] } class BaseLanguageModel { <<interface>> +invoke(prompt: str) AIMessage } QueryRewritingRetriever --|> BaseRetriever QueryRewritingRAG --> QueryRewritingRetriever QueryRewritingRAG --> BaseLanguageModel QueryRewritingRetriever --> BaseLanguageModel QueryRewritingRetriever --> PromptTemplate QueryRewritingRetriever --> VectorStore QueryRewritingRetriever ..> Document : creates VectorStore ..> Document : returns

4.2.4 时序图 #

4.2.4.1 完整RAG流程时序图 #
sequenceDiagram participant User as 用户 participant RAG as QueryRewritingRAG participant Retriever as QueryRewritingRetriever participant PromptTemplate as PromptTemplate participant LLM as BaseLanguageModel participant VectorStore as VectorStore User->>RAG: generate_answer("机器学习是什么?", k=3) RAG->>Retriever: _get_relevant_documents(query, k=3) Note over Retriever: 步骤1: 查询重写 Retriever->>Retriever: _get_rewriting_prompt_template() Retriever->>PromptTemplate: format(query=用户查询) PromptTemplate-->>Retriever: 返回完整prompt Retriever->>LLM: invoke(prompt) Note over LLM: 将查询改写为<br/>3-5个不同表达方式的版本 LLM-->>Retriever: 返回查询版本文本 Retriever->>Retriever: 解析查询版本<br/>(去除编号,保留非空行) Retriever->>Retriever: 验证查询版本数量<br/>(3-5个,不足则添加原始查询) Retriever->>Retriever: 返回查询版本列表 Note over Retriever: 步骤2: 多版本检索 Retriever->>Retriever: retrieve_for_queries(query_versions, k=3) loop 对每个查询版本 Retriever->>VectorStore: similarity_search_with_score(查询版本, k=3) Note over VectorStore: 对当前查询版本进行<br/>向量相似度检索 VectorStore-->>Retriever: 返回3个相关文档(带分数) loop 对每个检索结果 Retriever->>Retriever: 构建Document对象<br/>(添加query_version元数据) end end Note over Retriever: 步骤3: 去重合并 Retriever->>Retriever: deduplicate_and_merge(all_docs) Note over Retriever: 按内容去重<br/>保留最佳分数<br/>按分数排序 Retriever-->>RAG: 返回去重后的文档列表 Note over RAG: 步骤4: 生成答案 RAG->>RAG: 提取所有查询版本并格式化 RAG->>RAG: 整合文档内容为上下文 RAG->>PromptTemplate: format(context, query, query_versions) PromptTemplate-->>RAG: 返回完整prompt RAG->>LLM: invoke(prompt) Note over LLM: 基于所有查询版本的<br/>检索结果生成答案 LLM-->>RAG: 返回答案内容 RAG-->>User: 返回结果字典<br/>(query, answer, query_versions, documents)
4.2.4.2 查询重写详细流程 #
sequenceDiagram participant Retriever as QueryRewritingRetriever participant Template as PromptTemplate participant LLM as BaseLanguageModel Note over Retriever: rewrite_query(query) Retriever->>Retriever: _get_rewriting_prompt_template() Note over Retriever: 获取或创建重写模板<br/>"请将以下查询改写为3-5个不同表达方式的查询版本..." Retriever->>Template: format(query=用户查询) Template-->>Retriever: 返回完整prompt Retriever->>LLM: invoke(prompt) Note over LLM: 生成多个查询版本<br/>(不同措辞、角度、表达方式) LLM-->>Retriever: 返回查询版本文本 Retriever->>Retriever: 按行分割响应文本 loop 遍历每一行 Retriever->>Retriever: 去除前缀编号<br/>(如"1. "、"1、"、"- "等) Retriever->>Retriever: 检查是否非空 alt 非空行 Retriever->>Retriever: 添加到查询版本列表 end end Retriever->>Retriever: 限制数量为max_versions alt 查询版本数量 < min_versions alt 原始查询不在列表中 Retriever->>Retriever: 在开头插入原始查询 end end Retriever-->>Retriever: 返回查询版本列表<br/>[qv1, qv2, qv3, ...]
4.2.4.3 多版本检索与去重详细流程 #
sequenceDiagram participant Retriever as QueryRewritingRetriever participant VectorStore as VectorStore Note over Retriever: retrieve_for_queries(query_versions, k=3) loop 对每个查询版本 (qv1, qv2, qv3, ...) Retriever->>VectorStore: similarity_search_with_score(查询版本i, k=3) Note over VectorStore: 对查询版本i进行<br/>向量相似度检索 VectorStore-->>Retriever: 返回3个文档(带分数) loop 对每个检索结果 Retriever->>Retriever: 构建Document对象<br/>添加元数据:<br/>- score: 分数<br/>- query_version: 查询版本<br/>- version_index: 版本索引 Retriever->>Retriever: 添加到all_docs列表 end end Note over Retriever: 现在all_docs包含所有查询版本的检索结果<br/>(可能有重复文档) Retriever->>Retriever: deduplicate_and_merge(all_docs) Note over Retriever: 去重逻辑 loop 遍历所有文档 Retriever->>Retriever: 获取文档内容作为键 alt 文档首次出现 Retriever->>Retriever: 添加到unique字典 else 文档已存在 alt 当前分数 < 已有分数 Retriever->>Retriever: 更新为当前文档<br/>(保留更好的分数) end end end Retriever->>Retriever: 按分数升序排序 Retriever-->>Retriever: 返回去重并排序后的文档列表

4.2.5 关键设计要点 #

1. 查询重写流程

用户查询
    ↓
LLM改写为多个版本 (3-5个)
    ↓
解析并验证查询版本
    ↓
对每个查询版本分别检索
    ↓
合并所有检索结果
    ↓
去重并排序
    ↓
返回最终文档列表

2. 查询重写示例

原始查询: "机器学习是什么?"

重写后的查询版本:
1. 机器学习的定义是什么?
2. 什么是机器学习技术?
3. 机器学习的概念和原理
4. 如何理解机器学习?
5. 机器学习的基本含义

不同表达方式:

  • 正式表达:"机器学习的定义是什么?"
  • 通俗表达:"什么是机器学习技术?"
  • 专业表达:"机器学习的概念和原理"
  • 简洁表达:"如何理解机器学习?"
  • 基础表达:"机器学习的基本含义"

3. 去重策略

  • 去重键:使用文档内容(page_content)作为唯一标识
  • 分数选择:如果同一文档被多个查询版本检索到,保留最佳分数(距离最小)
  • 排序:去重后按分数升序排序,最相关的文档在前

示例:

查询版本1检索到: [docA(score=0.2), docB(score=0.3), docC(score=0.4)]
查询版本2检索到: [docA(score=0.15), docD(score=0.3), docE(score=0.5)]
查询版本3检索到: [docB(score=0.25), docF(score=0.35), docG(score=0.45)]

去重后: [docA(score=0.15), docB(score=0.25), docC(score=0.4), docD(score=0.3), docE(score=0.5), docF(score=0.35), docG(score=0.45)]

4. 元数据设计

检索返回的 Document 对象包含:

  • score:相似度分数(距离,越小越相似)
  • query_version:来源查询版本
  • version_index:查询版本索引(1, 2, 3...)
  • 继承原始文档的元数据(topic、category 等)

5. 答案生成模板

默认答案生成模板:

已知以下相关信息:

{context}

原始查询:{query}

该查询被重写为以下版本:
{query_versions}

请根据上述信息,准确全面地回答原始查询。

答案:

该模板:

  • 包含所有检索文档作为上下文
  • 明确列出所有查询版本
  • 要求准确全面地回答原始查询

4.2.6. 与子问题分解的对比 #

特性 子问题分解 查询重写
处理方式 拆解为子问题 改写为不同表达方式
核心意图 覆盖不同方面 保持核心意图不变
表达方式 不同角度的问题 不同措辞和角度
适用场景 复杂多部分查询 单一查询的不同表达
检索策略 每个子问题独立检索 每个版本独立检索

4.2.7. 优势与应用场景 #

优势:

  • 提高召回率:不同表达方式可能匹配到不同文档
  • 处理表达多样性:适应不同的查询表达习惯
  • 增强鲁棒性:减少单一表达方式可能带来的遗漏
  • 结果去重:自动处理重复文档,保留最佳分数

适用场景:

  • 用户表达多样性:不同用户可能用不同方式表达相同意图
  • 术语变体:同一概念的不同术语表达
  • 语言风格差异:正式、通俗、专业等不同风格
  • 跨语言场景:需要处理不同语言表达方式

4.2.8. 技术细节 #

  • 查询版本数量控制:
    • 最少 min_versions 个(默认3)
    • 最多 max_versions 个(默认5)
    • 如果少于最小值,确保包含原始查询
  • 查询版本解析:
    • 去除编号前缀(如 "1. "、"1、"、"- ")
    • 保留所有非空行作为查询版本
  • 去重算法:
    • 使用字典以内容为键
    • 保留最佳分数(距离最小)

4.2.9. 查询重写的要求 #

根据提示词模板,每个查询版本应该:

  1. 保持原始查询的核心意图
  2. 使用不同的措辞和表达方式
  3. 可以从不同角度(正式、通俗、专业、简洁等)表达
  4. 确保所有版本都能检索到相关信息

该设计通过“多版本查询”策略,在单一查询场景下提供更全面的检索覆盖,适用于需要处理表达多样性的 RAG 应用。

5. 抽象化查询转换(Take a Step Back) #

抽象化查询转换(Take a Step Back)是一种通过引导大语言模型(LLM)将用户提出的具体问题转化为更高层次、更加抽象且通用的问题,从而实现“广谱”信息检索的技术手段。这种方法的核心思想是“退一步”,不拘泥于问题的细节表述,而是挖掘出用户背后更广泛的核心诉求,进而覆盖更多相关内容,避免仅因提问限制而遗漏潜在高相关知识。

主要动机与应用场景

  1. 用户问题常常包含大量细节或上下文假设,造成检索空间受限。
  2. 知识库中很多内容用更抽象或泛化的方式表达,直接用具体问题检索可能无法命中这些内容。
  3. 支持“举一反三”型场景,用户关切的并不仅仅是问题本身,更关心相关原理、通用方法、背景知识。

比如,用户问:

  • “如何使用Python的scikit-learn库训练一个支持向量机模型来分类鸢尾花数据集?” 抽象化后可转换为:
  • “如何利用机器学习算法进行分类任务?”
  • “支持向量机模型在分类问题中的应用方法是什么?”

这样,可以检索到更全面的泛化知识、典型流程和相关概念,丰富后续答案的广度与深度。

技术流程拆解

  1. 抽象化转换

    • 针对用户原始查询,构造Prompt引导大模型去“去细节、保主旨”,生成一个更高层次、覆盖更广的抽象化问题。
    • 常用Prompt示例:

      请将以下具体问题转化为更高层次的抽象问题。
      抽象化问题应该:
      1. 去除具体细节,保留核心概念和意图
      2. 使用更通用的术语和表达方式
      3. 能够匹配更广泛的相关文档
      4. 保持与原始问题的语义关联
      
      具体问题:{query}
      
      抽象化问题:
    • 得到抽象查询后,通常与原始查询一同组合用于检索。
  2. 双路混合检索(Abstract + Concrete)

    • 对抽象化查询检索较多候选文档(广覆盖)。
    • 对原始具体查询检索精确相关文档(高匹配)。
    • 合并两路检索结果,内容去重、按相关性打分筛选。
  3. 答案生成/合成

    • 整理上述所有检索到的文档内容,合成rich context。
    • LLM生成答案时,参考原始查询+抽象查询,提升准确性与覆盖面。

方法优势与适用性

  • 提升召回能力: 避免“细节漏检”,显著增加知识检索广度。
  • 支持模糊查询: 用户问题不清晰或缺失背景时,自动补全潜在上位主题。
  • 增强泛化表达: 能发现教材/论文/百科等资料通用讲法,补足典型套路与场景案例。
  • 极适用于: 教育问答、技术原理解读、流程方法通用性归纳、案例“举一反三”等领域。

方法小结

  • Take a Step Back(抽象化查询) = “拿掉细节、提升视角、扩展召回”。
  • 本质是引入更泛化的表达,使知识检索“拉宽一圈”,与具体查询并用,最大化信息检索和答案生成的全面性及鲁棒性。

5.1 TakeAStepBack.py #

# 导入类型注解所需的相关类型
from typing import List, Dict, Any, Optional
# 导入pydantic配置工具
from pydantic import ConfigDict
# 导入基础检索器基类
from langchain_core.retrievers import BaseRetriever
# 导入文档类型
from langchain_core.documents import Document
# 导入LLM基础类型
from langchain_core.language_models import BaseLanguageModel
# 导入提示词模板
from langchain_core.prompts import PromptTemplate
# 导入自定义llm对象
from llm import llm
# 导入获取向量库函数
from vector_store import get_vector_store

# 定义抽象化查询转换检索器类,继承自BaseRetriever
class TakeAStepBackRetriever(BaseRetriever):
    # 向量库对象
    vector_store: Any
    # 大语言模型对象
    llm: BaseLanguageModel
    # 使用抽象化查询进行广泛检索的结果数量
    abstract_k: int = 5
    # 使用原始查询进行精确检索的结果数量
    concrete_k: int = 3
    # 可选的抽象化转换prompt模板字符串
    abstraction_prompt_template_str: Optional[str] = None
    # 配置模型参数,允许任意类型
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # 获取抽象化转换prompt模板的方法
    def _get_abstraction_prompt_template(self) -> PromptTemplate:
        # 获取或默认生成用于抽象化转换的Prompt模板字符串
        template = self.abstraction_prompt_template_str or (
            "请将以下具体问题转化为更高层次的抽象问题。\n"
            "抽象化问题应该:\n"
            "1. 去除具体细节,保留核心概念和意图\n"
            "2. 使用更通用的术语和表达方式\n"
            "3. 能够匹配更广泛的相关文档\n"
            "4. 保持与原始问题的语义关联\n\n"
            "具体问题:{query}\n\n"
            "抽象化问题:"
        )
        # 返回PromptTemplate对象
        return PromptTemplate(input_variables=["query"], template=template)

    # 调用LLM将具体问题转化为抽象问题
    def abstract_query(self, query: str) -> str:
        # 构建抽象化prompt
        prompt = self._get_abstraction_prompt_template().format(query=query)
        # 调用LLM进行内容生成
        abstract_query = self.llm.invoke(prompt).content.strip()
        # 返回抽象化查询
        return abstract_query

    # 使用抽象化查询进行广泛检索
    def retrieve_with_abstract_query(self, abstract_query: str, k: Optional[int] = None) -> List[Document]:
        # 如果未提供k参数,则使用实例默认值
        k = k if k is not None else self.abstract_k
        # 对抽象化查询进行向量相似性检索,获取更多结果
        docs = self.vector_store.similarity_search_with_score(abstract_query, k=k)
        # 将检索结果统一封装为Document对象,并标记为抽象化检索结果
        abstract_docs = []
        for doc, distance in docs:
            abstract_docs.append(Document(
                page_content=doc.page_content,
                metadata={**doc.metadata, "score": float(distance), "retrieval_type": "abstract", "query": abstract_query}
            ))
        # 返回抽象化检索结果
        return abstract_docs

    # 使用原始查询进行精确检索
    def retrieve_with_concrete_query(self, concrete_query: str, k: Optional[int] = None) -> List[Document]:
        # 如果未提供k参数,则使用实例默认值
        k = k if k is not None else self.concrete_k
        # 对原始查询进行向量相似性检索,获取精确结果
        docs = self.vector_store.similarity_search_with_score(concrete_query, k=k)
        # 将检索结果统一封装为Document对象,并标记为精确检索结果
        concrete_docs = []
        for doc, distance in docs:
            concrete_docs.append(Document(
                page_content=doc.page_content,
                metadata={**doc.metadata, "score": float(distance), "retrieval_type": "concrete", "query": concrete_query}
            ))
        # 返回精确检索结果
        return concrete_docs

    # 合并检索结果并去重
    def merge_and_deduplicate(self, abstract_docs: List[Document], concrete_docs: List[Document]) -> List[Document]:
        # 合并所有检索结果
        all_docs = abstract_docs + concrete_docs
        unique = {}
        # 遍历所有文档,使用内容做去重键,保留分数更小者
        for doc in all_docs:
            c = doc.page_content
            s = doc.metadata.get("score", float('inf'))
            # 如果内容未见过,或当前分数更好,则更新
            if c not in unique or s < unique[c].metadata.get("score", float('inf')):
                unique[c] = doc
        # 对去重后的文档按分数升序排序
        docs = sorted(unique.values(), key=lambda x: x.metadata.get("score", float('inf')))
        # 返回排序后的文档列表
        return docs

    # 主入口:根据查询获取相关文档流程
    def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
        # 获取可选的检索数量参数(如果提供则覆盖默认值)
        abstract_k = kwargs.get("abstract_k", self.abstract_k)
        concrete_k = kwargs.get("concrete_k", self.concrete_k)
        # 获取可选的抽象化查询(如果提供则直接使用,避免重复生成)
        abstract_query = kwargs.get("abstract_query")
        if abstract_query is None:
            # 抽象化转换:将具体问题转化为抽象问题
            abstract_query = self.abstract_query(query)

        # 混合检索:使用抽象化查询进行广泛检索,使用原始查询进行精确检索
        # 直接传递检索数量参数,无需临时修改实例属性
        abstract_docs = self.retrieve_with_abstract_query(abstract_query, k=abstract_k)
        concrete_docs = self.retrieve_with_concrete_query(query, k=concrete_k)

        # 结果合并:合并检索结果并去重
        final_docs = self.merge_and_deduplicate(abstract_docs, concrete_docs)

        # 返回最终文档列表
        return final_docs

# 定义抽象化查询转换RAG主流程类
class TakeAStepBackRAG:
    # 初始化,注入检索器和LLM,支持自定义答案生成Prompt
    def __init__(self, retriever: TakeAStepBackRetriever, llm: BaseLanguageModel, answer_prompt_template: Optional[str]=None):
        self.retriever = retriever
        self.llm = llm
        # 设置答案生成Prompt模板,支持自定义或默认
        template = answer_prompt_template or (
            "已知以下相关信息:\n\n{context}\n\n"
            "原始查询:{query}\n"
            "抽象化查询:{abstract_query}\n\n"
            "请根据上述信息,准确全面地回答原始查询。\n\n答案:"
        )
        # 构建最终答案生成的PromptTemplate对象
        self.answer_prompt_template = PromptTemplate(
            input_variables=["context", "query", "abstract_query"], template=template
        )

    # 输入查询及可选检索数量,返回答案及检索相关信息
    def generate_answer(self, query: str, abstract_k: int = 5, concrete_k: int = 3) -> Dict[str, Any]:
        # 先生成抽象化查询
        abstract_query = self.retriever.abstract_query(query)

        # 检索相关文档(包含抽象化和精确检索),传入抽象化查询避免重复生成
        docs = self.retriever._get_relevant_documents(query, abstract_k=abstract_k, concrete_k=concrete_k, abstract_query=abstract_query)

        # 合并上下文内容
        context = "\n\n".join([f"文档 {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs)])

        # 构建答案生成Prompt
        prompt = self.answer_prompt_template.format(context=context, query=query, abstract_query=abstract_query)

        # 调用LLM生成答案
        answer = self.llm.invoke(prompt).content.strip()

        # 统计检索类型
        abstract_count = sum(1 for doc in docs if doc.metadata.get("retrieval_type") == "abstract")
        concrete_count = sum(1 for doc in docs if doc.metadata.get("retrieval_type") == "concrete")

        # 整理并返回完整流程结果
        return {
            "query": query,
            "abstract_query": abstract_query,
            "answer": answer,
            "retrieved_documents": docs,
            "num_documents": len(docs),
            "abstract_count": abstract_count,
            "concrete_count": concrete_count
        }

# 初始化向量库对象
vector_store = get_vector_store(
    persist_directory="chroma_db",
    collection_name="take_a_step_back"
)
# 初始化抽象化查询转换检索器
retriever = TakeAStepBackRetriever(
    vector_store=vector_store,
    llm=llm,
    abstract_k=5,
    concrete_k=3
)
# 初始化RAG系统
rag_system = TakeAStepBackRAG(
    retriever=retriever,
    llm=llm
)
# 定义示例文档内容列表
documents = [
    "人工智能(AI)是计算机科学的一个分支,旨在创建能够执行通常需要人类智能的任务的系统。"
    "机器学习是人工智能的核心技术之一,它使计算机能够从数据中学习,而无需明确编程。"
    "深度学习是机器学习的一个子集,使用人工神经网络来模拟人脑的工作方式。"
    "自然语言处理(NLP)是AI的另一个重要领域,专注于使计算机能够理解和生成人类语言。",
    "区块链技术是一种分布式账本技术,通过密码学方法确保数据的安全性和不可篡改性。"
    "比特币是第一个成功应用区块链技术的加密货币,它解决了数字货币的双重支付问题。"
    "以太坊是一个支持智能合约的区块链平台,允许开发者在其上构建去中心化应用(DApps)。"
    "智能合约是自动执行的合约,其条款直接写入代码中,无需第三方中介。",
    "量子计算是一种利用量子力学现象进行计算的新兴技术,具有巨大的计算潜力。"
    "量子比特(qubit)是量子计算的基本单位,与经典比特不同,它可以同时处于0和1的叠加态。"
    "量子纠缠是量子计算的关键特性,允许量子比特之间建立特殊的关联关系。"
    "量子计算在密码学、药物发现和优化问题等领域具有潜在的应用前景。",
]
# 向向量库中添加文档及元数据
vector_store.add_texts(
    documents,
    metadatas=[
        {"topic": "人工智能", "category": "科技"},
        {"topic": "区块链", "category": "科技"},
        {"topic": "量子计算", "category": "科技"},
    ]
)
# 指定一个具体查询作为用户输入
query = "如何使用Python的scikit-learn库训练一个支持向量机模型来分类鸢尾花数据集?"
# 运行RAG流程并获取结果
result = rag_system.generate_answer(query, abstract_k=5, concrete_k=3)
# 打印分隔线
print("="*60)
print("最终结果")
print("="*60)
# 打印用户查询
print(f"\n原始查询: {result['query']}")
# 打印抽象化查询
print(f"\n抽象化查询: {result['abstract_query']}")
# 打印生成的最终答案
print(f"\n生成的答案:\n{result['answer']}")
# 打印检索统计信息
print(f"\n检索统计:")
print(f"  总文档数: {result['num_documents']}")
print(f"  抽象化检索: {result['abstract_count']} 个")
print(f"  精确检索: {result['concrete_count']} 个")
# 打印部分检索到的文档的细节信息
print(f"\n检索到的文档详情:")
for i, doc in enumerate(result['retrieved_documents'][:5], 1):
    retrieval_type = doc.metadata.get("retrieval_type", "unknown")
    print(f"\n文档 {i} (分数: {doc.metadata.get('score', 'N/A'):.4f}, 类型: {retrieval_type}):")
    print(f"  内容: {doc.page_content[:150]}...")

5.2 执行过程 #

5.2.1 核心思想 #

Take a Step Back 采用“抽象+精确”的双重检索策略:

  • 将具体问题转化为更高层次的抽象问题
  • 使用抽象化查询进行广泛检索(获取更多相关文档)
  • 使用原始查询进行精确检索(获取精确匹配的文档)
  • 合并两种检索结果并去重
  • 基于合并结果生成最终答案

5.2.2 执行流程 #

阶段一:初始化

# 1. 获取向量存储实例
vector_store = get_vector_store(
    persist_directory="chroma_db",
    collection_name="take_a_step_back"
)

# 2. 创建Take a Step Back检索器
retriever = TakeAStepBackRetriever(
    vector_store=vector_store,
    llm=llm,
    abstract_k=5,    # 抽象化检索返回5个文档
    concrete_k=3     # 精确检索返回3个文档
)

# 3. 创建RAG系统
rag_system = TakeAStepBackRAG(
    retriever=retriever,
    llm=llm
)

初始化时:

  • 创建向量存储实例
  • 创建检索器,配置 LLM、向量库、抽象化检索数量(abstract_k)和精确检索数量(concrete_k)
  • 创建 RAG 系统,用于完整流程

阶段二:文档索引

# 添加文档到向量库
vector_store.add_texts(
    documents,
    metadatas=[...]
)

索引过程:

  • 将示例文档添加到向量库
  • 每个文档附带元数据(topic、category)
  • 文档会被向量化并存储

阶段三:查询处理

query = "如何使用Python的scikit-learn库训练一个支持向量机模型来分类鸢尾花数据集?"
result = rag_system.generate_answer(query, abstract_k=5, concrete_k=3)

完整流程:

  1. 用户提交具体查询
  2. 抽象化转换:
    • 调用 abstract_query(query)
    • 使用 LLM 将具体问题转化为抽象问题
    • 去除具体细节,保留核心概念
  3. 双重检索:
    • 使用抽象化查询进行广泛检索(abstract_k=5)
    • 使用原始查询进行精确检索(concrete_k=3)
  4. 结果合并:
    • 调用 merge_and_deduplicate(abstract_docs, concrete_docs)
    • 合并两种检索结果
    • 按文档内容去重,保留最佳分数
    • 按分数排序
  5. 生成答案:
    • 整合所有检索文档作为上下文
    • 使用提示词模板构建最终 prompt(包含原始查询和抽象化查询)
    • LLM 生成答案
  6. 返回结果:
    • 包含查询、抽象化查询、答案、检索文档、统计信息等

5.2.3 类图 #

classDiagram class TakeAStepBackRetriever { -vector_store: Any -llm: BaseLanguageModel -abstract_k: int -concrete_k: int -abstraction_prompt_template_str: Optional[str] +_get_abstraction_prompt_template() PromptTemplate +abstract_query(query: str) str +retrieve_with_abstract_query(abstract_query: str) List[Document] +retrieve_with_concrete_query(concrete_query: str) List[Document] +merge_and_deduplicate(abstract_docs, concrete_docs) List[Document] +_get_relevant_documents(query: str, **kwargs) List[Document] } class TakeAStepBackRAG { -retriever: TakeAStepBackRetriever -llm: BaseLanguageModel -answer_prompt_template: PromptTemplate +__init__(retriever, llm, answer_prompt_template) +generate_answer(query: str, abstract_k: int, concrete_k: int) Dict[str, Any] } class BaseRetriever { <<abstract>> +invoke(query: str) List[Document] } class PromptTemplate { +format(**kwargs) str } class Document { +page_content: str +metadata: Dict } class VectorStore { <<interface>> +add_texts(texts: List[str], metadatas: List[Dict]) +similarity_search_with_score(query: str, k: int) List[Tuple[Document, float]] } class BaseLanguageModel { <<interface>> +invoke(prompt: str) AIMessage } TakeAStepBackRetriever --|> BaseRetriever TakeAStepBackRAG --> TakeAStepBackRetriever TakeAStepBackRAG --> BaseLanguageModel TakeAStepBackRetriever --> BaseLanguageModel TakeAStepBackRetriever --> PromptTemplate TakeAStepBackRetriever --> VectorStore TakeAStepBackRetriever ..> Document : creates VectorStore ..> Document : returns

5.2.4 时序图 #

5.2.4.1 完整RAG流程时序图 #
sequenceDiagram participant User as 用户 participant RAG as TakeAStepBackRAG participant Retriever as TakeAStepBackRetriever participant PromptTemplate as PromptTemplate participant LLM as BaseLanguageModel participant VectorStore as VectorStore User->>RAG: generate_answer(具体查询, abstract_k=5, concrete_k=3) Note over RAG: 步骤1: 生成抽象化查询 RAG->>Retriever: abstract_query(query) Retriever->>Retriever: _get_abstraction_prompt_template() Retriever->>PromptTemplate: format(query=具体查询) PromptTemplate-->>Retriever: 返回完整prompt Retriever->>LLM: invoke(prompt) Note over LLM: 将具体问题转化为<br/>抽象问题<br/>(去除细节,保留核心概念) LLM-->>Retriever: 返回抽象化查询 Retriever-->>RAG: 返回抽象化查询 Note over RAG: 步骤2: 双重检索 RAG->>Retriever: _get_relevant_documents(query, abstract_k=5, concrete_k=3, abstract_query) Note over Retriever: 并行执行两种检索 par 抽象化检索 Retriever->>Retriever: retrieve_with_abstract_query(abstract_query) Retriever->>VectorStore: similarity_search_with_score(抽象化查询, k=5) Note over VectorStore: 广泛检索<br/>匹配更广泛的相关文档 VectorStore-->>Retriever: 返回5个文档(带分数) Retriever->>Retriever: 构建Document对象<br/>(标记retrieval_type="abstract") and 精确检索 Retriever->>Retriever: retrieve_with_concrete_query(原始查询) Retriever->>VectorStore: similarity_search_with_score(原始查询, k=3) Note over VectorStore: 精确检索<br/>匹配具体查询的文档 VectorStore-->>Retriever: 返回3个文档(带分数) Retriever->>Retriever: 构建Document对象<br/>(标记retrieval_type="concrete") end Note over Retriever: 步骤3: 合并去重 Retriever->>Retriever: merge_and_deduplicate(abstract_docs, concrete_docs) Note over Retriever: 合并结果<br/>按内容去重<br/>保留最佳分数<br/>按分数排序 Retriever-->>RAG: 返回合并后的文档列表 Note over RAG: 步骤4: 生成答案 RAG->>RAG: 整合文档内容为上下文 RAG->>RAG: 统计检索类型数量 RAG->>PromptTemplate: format(context, query, abstract_query) PromptTemplate-->>RAG: 返回完整prompt RAG->>LLM: invoke(prompt) Note over LLM: 基于双重检索结果生成答案 LLM-->>RAG: 返回答案内容 RAG-->>User: 返回结果字典<br/>(query, abstract_query, answer, documents, statistics)
5.2.4.2 抽象化转换详细流程 #
sequenceDiagram participant Retriever as TakeAStepBackRetriever participant Template as PromptTemplate participant LLM as BaseLanguageModel Note over Retriever: abstract_query(query) Retriever->>Retriever: _get_abstraction_prompt_template() Note over Retriever: 获取或创建抽象化模板<br/>"请将以下具体问题转化为更高层次的抽象问题..." Retriever->>Template: format(query=具体查询) Template-->>Retriever: 返回完整prompt Retriever->>LLM: invoke(prompt) Note over LLM: 生成抽象化问题<br/>要求:<br/>1. 去除具体细节<br/>2. 保留核心概念<br/>3. 使用通用术语<br/>4. 匹配更广泛文档 LLM-->>Retriever: 返回抽象化查询 Retriever->>Retriever: strip() 去除首尾空格 Retriever-->>Retriever: 返回抽象化查询字符串
5.2.4.3 双重检索与合并详细流程 #
sequenceDiagram participant Retriever as TakeAStepBackRetriever participant VectorStore as VectorStore Note over Retriever: 双重检索流程 par 抽象化检索 (abstract_k=5) Retriever->>VectorStore: similarity_search_with_score(抽象化查询, k=5) Note over VectorStore: 使用抽象化查询<br/>进行广泛检索<br/>匹配相关概念和原理 VectorStore-->>Retriever: 返回5个文档<br/>[(doc1, score1), ..., (doc5, score5)] loop 对每个检索结果 Retriever->>Retriever: 构建Document对象<br/>添加元数据:<br/>- score: 分数<br/>- retrieval_type: "abstract"<br/>- query: 抽象化查询 end and 精确检索 (concrete_k=3) Retriever->>VectorStore: similarity_search_with_score(原始查询, k=3) Note over VectorStore: 使用原始查询<br/>进行精确检索<br/>匹配具体实现和细节 VectorStore-->>Retriever: 返回3个文档<br/>[(docA, scoreA), (docB, scoreB), (docC, scoreC)] loop 对每个检索结果 Retriever->>Retriever: 构建Document对象<br/>添加元数据:<br/>- score: 分数<br/>- retrieval_type: "concrete"<br/>- query: 原始查询 end end Note over Retriever: 现在有abstract_docs和concrete_docs<br/>(可能有重复文档) Retriever->>Retriever: merge_and_deduplicate(abstract_docs, concrete_docs) Note over Retriever: 合并和去重逻辑 Retriever->>Retriever: 合并所有文档: all_docs = abstract_docs + concrete_docs loop 遍历所有文档 Retriever->>Retriever: 获取文档内容作为键 alt 文档首次出现 Retriever->>Retriever: 添加到unique字典 else 文档已存在 alt 当前分数 < 已有分数 Retriever->>Retriever: 更新为当前文档<br/>(保留更好的分数) end end end Retriever->>Retriever: 按分数升序排序 Retriever-->>Retriever: 返回去重并排序后的文档列表

5.2.5 关键设计要点 #

1. 抽象化转换流程

具体查询
    ↓
LLM抽象化转换
    ↓
抽象化查询 (去除细节,保留核心概念)
    ↓
双重检索:
  - 抽象化查询 → 广泛检索 (abstract_k=5)
  - 原始查询 → 精确检索 (concrete_k=3)
    ↓
合并结果并去重
    ↓
返回最终文档列表

2. 抽象化转换示例

原始查询: "如何使用Python的scikit-learn库训练一个支持向量机模型来分类鸢尾花数据集?"

抽象化查询: "如何使用机器学习算法进行分类任务?"
或
抽象化查询: "机器学习分类模型的训练方法"

抽象化要求:

  • 去除具体细节:Python、scikit-learn、支持向量机、鸢尾花数据集
  • 保留核心概念:机器学习、分类、训练
  • 使用通用术语:算法、模型、方法
  • 匹配更广泛文档:能匹配到机器学习相关的所有文档

3. 双重检索策略

  • 抽象化检索(广泛检索):
    • 使用抽象化查询
    • 返回更多文档(abstract_k=5)
    • 匹配相关概念和原理
    • 提供背景知识和理论基础
  • 精确检索(精确匹配):
    • 使用原始查询
    • 返回精确文档(concrete_k=3)
    • 匹配具体实现和细节
    • 提供直接相关的答案

4. 去重策略

  • 去重键:使用文档内容(page_content)作为唯一标识
  • 分数选择:如果同一文档被两种检索方式都检索到,保留最佳分数(距离最小)
  • 排序:去重后按分数升序排序,最相关的文档在前

示例:

抽象化检索结果: [docA(score=0.2), docB(score=0.3), docC(score=0.4), docD(score=0.5), docE(score=0.6)]
精确检索结果: [docA(score=0.15), docF(score=0.25), docG(score=0.35)]

合并去重后: [docA(score=0.15), docF(score=0.25), docB(score=0.3), docG(score=0.35), docC(score=0.4), docD(score=0.5), docE(score=0.6)]

5. 元数据设计

检索返回的 Document 对象包含:

  • score:相似度分数(距离,越小越相似)
  • retrieval_type:检索类型("abstract" 或 "concrete")
  • query:使用的查询(抽象化查询或原始查询)
  • 继承原始文档的元数据(topic、category 等)

6. 答案生成模板

默认答案生成模板:

已知以下相关信息:

{context}

原始查询:{query}
抽象化查询:{abstract_query}

请根据上述信息,准确全面地回答原始查询。

答案:

该模板:

  • 包含所有检索文档作为上下文
  • 明确列出原始查询和抽象化查询
  • 要求准确全面地回答原始查询

5.2.6 与其他方法的对比 #

特性 传统检索 Take a Step Back
查询处理 直接使用原始查询 抽象化转换 + 原始查询
检索方式 单次检索 双重检索(广泛+精确)
覆盖范围 单一匹配 广泛+精确双重覆盖
适用场景 简单查询 具体但需要背景知识的查询
优势 简单直接 兼顾广泛性和精确性

5.2.7 优势与应用场景 #

优势:

  • 广泛覆盖:抽象化检索提供背景知识和理论基础
  • 精确匹配:精确检索提供直接相关的答案
  • 平衡策略:兼顾广泛性和精确性
  • 结果去重:自动处理重复文档,保留最佳分数

适用场景:

  • 具体技术问题:需要具体实现和背景知识
  • 学习场景:需要理论背景和实际应用
  • 研究场景:需要概念理解和具体案例
  • 复杂查询:需要多层次信息的查询

5.2.8 技术细节 #

  • 检索数量控制:
    • 抽象化检索:abstract_k(默认5)
    • 精确检索:concrete_k(默认3)
    • 可通过参数动态调整
  • 抽象化查询缓存:
    • 在 _get_relevant_documents 中支持传入 abstract_query
    • 避免重复生成抽象化查询
  • 去重算法:
    • 使用字典以内容为键
    • 保留最佳分数(距离最小)

5.2.9 检索统计 #

返回结果包含统计信息:

  • abstract_count:来自抽象化检索的文档数量
  • concrete_count:来自精确检索的文档数量
  • num_documents:总文档数量(去重后)

该设计通过“抽象+精确”的双重检索策略,在具体查询场景下提供更全面的检索覆盖,适用于需要背景知识和具体答案的 RAG 应用。

6. 检索-生成一体化(Retrieval-generation integration) #

传统RAG(Retrieval-Augmented Generation)体系通常将“检索”与“生成”划分为独立模块:先对查询进行信息检索,再将检索到的文档输入大语言模型生成答案。这种串联模式虽然结构清晰,但也存在以下潜在不足:

  • 检索与生成的分界较为生硬,检索阶段难以充分感知生成任务的具体要求。
  • 生成模型无法灵活动态地影响检索策略与范围。
  • 在复杂问题和多轮对话场景下,检索结果的相关性和准确性难以保障。

为了解决上述问题,近年来兴起了“检索-生成一体化”思想——即通过深度融合检索与生成过程,实现二者的协同优化。其核心目标在于让生成模型动态、主动参与到文档检索阶段,例如:

  • 让大模型根据当前生成任务多轮拟定检索子查询,分阶段检索并综合更多相关信息。
  • 检索结果与生成阶段通过奖励机制、指令调优(instruction tuning)、强化学习等方式联合训练,增强全流程对齐能力。
  • 构建端到端可微分检索生成链路,实现前馈式联合优化。

典型的一体化实现包括:

  1. Query Routing: 生成模型自动识别查询意图,有针对性地选择检索策略(如结构化、非结构化、混合型文档库)。
  2. 多轮增量式检索-生成:生成模型先粗略生成答案→识别缺失信息→动态追加检索→迭代完善答案。
  3. 检索-生成共设计(Co-design):如指令调优的融合检索生成Agent,将信息查询和答案生成步骤交错建模。

这一趋势使得RAG系统更加智能、灵活,能更好地处理开放域、复杂任务与场景,如多跳推理、事实查证、科技问答及对话系统。

6.1 RetrievalGenerationIntegration.py #

# 导入List、Dict、Any、Optional类型用于类型注解
from typing import List, Dict, Any, Optional

# 导入pydantic用于配置数据模型
from pydantic import ConfigDict

# 导入基础检索器基类
from langchain_core.retrievers import BaseRetriever

# 导入文档对象类型
from langchain_core.documents import Document

# 导入基础LLM类型
from langchain_core.language_models import BaseLanguageModel

# 导入提示词模板类
from langchain_core.prompts import PromptTemplate

# 导入自定义llm对象
from llm import llm

# 导入获取向量存储的函数
from vector_store import get_vector_store

# 定义意图分类器类
class IntentClassifier:
    """意图分类器:识别用户问题的类型"""

    # 构造函数,初始化llm对象和意图类别列表
    def __init__(self, llm: BaseLanguageModel, intent_categories: List[str]):
        self.llm = llm
        self.intent_categories = intent_categories
        # 定义意图分类任务的prompt模板
        self.classification_template = PromptTemplate(
            input_variables=["query", "categories"],
            template="请判断以下用户查询属于哪个类别。\n"
            "可用类别:{categories}\n\n"
            "用户查询:{query}\n\n"
            "请只返回类别名称,不要返回其他内容:"
        )

    # 对单条用户查询进行意图分类
    def classify(self, query: str) -> str:
        """对查询进行意图分类"""
        # 用中文顿号拼接所有意图类别
        categories_str = "、".join(self.intent_categories)
        # 填充prompt生成具体任务
        prompt = self.classification_template.format(query=query, categories=categories_str)
        # 调用llm得到结果
        response = self.llm.invoke(prompt).content.strip()
        # 检查返回的类别是否在目标类别中
        for category in self.intent_categories:
            if category in response:
                return category
        # 默认返回第一个类别
        return self.intent_categories[0]

# 定义查询重写器类
class QueryRewriter:
    """查询重写器:将口语化问题转换为标准术语"""

    # 构造,接收llm对象和意图到重写策略的映射
    def __init__(self, llm: BaseLanguageModel, rewriting_strategies: Dict[str, str]):
        self.llm = llm
        self.rewriting_strategies = rewriting_strategies

    # 根据意图重写查询内容
    def rewrite(self, query: str, intent: str) -> str:
        """根据意图重写查询"""
        # 获取指定意图的重写提示词,没有则用默认
        strategy = self.rewriting_strategies.get(intent, "请将以下查询转换为标准术语进行检索:{query}")

        # 根据策略格式化生成prompt
        if "{query}" in strategy:
            prompt = strategy.format(query=query)
        else:
            prompt = f"{strategy}\n\n用户查询:{query}\n\n标准术语查询:"

        # 使用llm调用重写查询
        rewritten = self.llm.invoke(prompt).content.strip()
        # 对重写为空的情况做降级处理
        return rewritten if rewritten else query

# 定义知识库管理器类
class KnowledgeBaseManager:
    """知识库管理器:管理多个意图对应的知识库"""

    # 构造函数,初始化知识库存储路径和知识库字典
    def __init__(self, persist_directory: str):
        self.persist_directory = persist_directory
        self.knowledge_bases: Dict[str, Any] = {}

    # 获取指定意图和集合名的知识库实例(不存在则自动创建)
    def get_knowledge_base(self, intent: str, collection_name: str) -> Any:
        """获取或创建指定意图的知识库"""
        # 如果知识库不存在则创建
        if intent not in self.knowledge_bases:
            self.knowledge_bases[intent] = get_vector_store(
                persist_directory=self.persist_directory,
                collection_name=collection_name
            )
        return self.knowledge_bases[intent]

    # 向指定知识库添加文档列表及元数据
    def add_documents(self, intent: str, collection_name: str, texts: List[str], metadatas: List[Dict] = None):
        """向指定意图的知识库添加文档"""
        kb = self.get_knowledge_base(intent, collection_name)
        kb.add_texts(texts, metadatas=metadatas or [{}] * len(texts))

# 定义检索生成一体化检索器
class RetrievalGenerationIntegrationRetriever(BaseRetriever):
    """检索-生成一体化检索器"""

    # 定义知识库管理器类型
    kb_manager: KnowledgeBaseManager
    # 定义意图分类器类型
    intent_classifier: IntentClassifier
    # 定义查询重写器类型
    query_rewriter: QueryRewriter
    # 意图到集合名的映射表
    intent_to_collection: Dict[str, str]
    # 默认检索文档数量
    k: int = 4
    # pydantic模型配置,允许任意类型字段
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # 获取与查询相关的文档
    def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
        """一体化检索流程"""
        # 优先使用外部传入的k,否则用默认
        k = kwargs.get("k", self.k)

        # 步骤1: 意图识别
        intent = self.intent_classifier.classify(query)

        # 步骤2: 查询重写
        rewritten_query = self.query_rewriter.rewrite(query, intent)

        # 步骤3: 检索对应知识库
        collection_name = self.intent_to_collection.get(intent, "default")
        kb = self.kb_manager.get_knowledge_base(intent, collection_name)
        docs = kb.similarity_search_with_score(rewritten_query, k=k)

        # 步骤4: 文档结果组装成Document对象,加入相关元数据
        result_docs = []
        for doc, distance in docs:
            result_docs.append(Document(
                page_content=doc.page_content,
                metadata={
                    **doc.metadata,
                    "score": float(distance),
                    "intent": intent,
                    "original_query": query,
                    "rewritten_query": rewritten_query
                }
            ))

        return result_docs

# 定义检索生成一体化RAG系统
class RetrievalGenerationIntegrationRAG:
    """检索-生成一体化RAG系统"""

    # 构造方法,初始化retriever与llm,并配置答案生成prompt
    def __init__(self, retriever: RetrievalGenerationIntegrationRetriever, llm: BaseLanguageModel, answer_prompt_template: Optional[str] = None):
        self.retriever = retriever
        self.llm = llm
        # 若无自定义模板则使用默认
        template = answer_prompt_template or (
            "已知以下相关信息:\n\n{context}\n\n"
            "原始查询:{query}\n"
            "识别意图:{intent}\n"
            "重写查询:{rewritten_query}\n\n"
            "请根据上述信息,准确全面地回答原始查询。\n\n答案:"
        )
        # 初始化生成答案所需的prompt模板
        self.answer_prompt_template = PromptTemplate(
            input_variables=["context", "query", "intent", "rewritten_query"], template=template
        )

    # 对用户查询生成最终答案
    def generate_answer(self, query: str, k: int = 4) -> Dict[str, Any]:
        """生成答案"""
        # 步骤1: 检索相关文档
        docs = self.retriever._get_relevant_documents(query, k=k)

        # 步骤2: 提取相关元数据
        intent = docs[0].metadata.get("intent", "unknown") if docs else "unknown"
        rewritten_query = docs[0].metadata.get("rewritten_query", query) if docs else query

        # 步骤3: 组装检索到的文档内容作为上下文
        context = "\n\n".join([f"文档 {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs)])

        # 步骤4: 构建答案生成prompt
        prompt = self.answer_prompt_template.format(
            context=context,
            query=query,
            intent=intent,
            rewritten_query=rewritten_query
        )
        # 步骤5: 调用llm生成答案
        answer = self.llm.invoke(prompt).content.strip()

        # 汇总所有输出信息
        return {
            "query": query,
            "intent": intent,
            "rewritten_query": rewritten_query,
            "answer": answer,
            "retrieved_documents": docs,
            "num_documents": len(docs)
        }

# 初始化各种组件配置参数

# 定义意图类别(如人事政策场景)
intent_categories = ["假期政策", "调岗流程", "薪资福利", "其他"]

# 定义针对不同意图的查询重写策略
rewriting_strategies = {
    "假期政策": "请从以下查询中提取假期类型(如年假、病假、产假、婚假等)的标准术语,然后生成检索查询。\n用户查询:{query}\n标准术语查询:",
    "调岗流程": "请从以下查询中提取调岗相关的标准术语和流程关键词,然后生成检索查询。\n用户查询:{query}\n标准术语查询:",
    "薪资福利": "请从以下查询中提取薪资福利相关的标准术语,然后生成检索查询。\n用户查询:{query}\n标准术语查询:",
    "其他": "请将以下查询转换为标准术语进行检索:{query}"
}

# 定义意图到向量集合名的映射
intent_to_collection = {
    "假期政策": "leave_policy",
    "调岗流程": "transfer_process",
    "薪资福利": "salary_benefits",
    "其他": "default"
}

# 创建意图分类器对象
intent_classifier = IntentClassifier(llm, intent_categories)

# 创建查询重写器对象
query_rewriter = QueryRewriter(llm, rewriting_strategies)

# 创建知识库管理器对象
kb_manager = KnowledgeBaseManager("chroma_db")

# 按依赖注入创建一体化检索器
retriever = RetrievalGenerationIntegrationRetriever(
    kb_manager=kb_manager,
    intent_classifier=intent_classifier,
    query_rewriter=query_rewriter,
    intent_to_collection=intent_to_collection,
    k=4
)

# 创建一体化RAG主系统
rag_system = RetrievalGenerationIntegrationRAG(retriever=retriever, llm=llm)

# 示例:准备“假期政策”知识库文档
leave_policy_docs = [
    "年假政策:员工工作满一年后可享受年假,年假天数根据工作年限递增。",
    "病假政策:员工因病需要请假时,需提供医院证明,病假期间工资按相关规定发放。",
    "产假政策:女性员工生育可享受产假,产假期间工资照发,具体天数根据国家规定执行。",
    "婚假政策:员工结婚可申请婚假,婚假天数根据公司规定执行,通常为3-7天。"
]

# 示例:准备“调岗流程”知识库文档
transfer_process_docs = [
    "调岗申请流程:员工需填写调岗申请表,经部门主管审批后提交人力资源部。",
    "调岗审批流程:人力资源部审核调岗申请,评估员工能力和岗位匹配度,做出审批决定。",
    "调岗执行流程:调岗获批后,人力资源部通知相关部门,安排员工到新岗位报到。"
]

# 示例:准备“薪资福利”知识库文档
salary_benefits_docs = [
    "薪资结构:员工薪资由基本工资、绩效奖金、津贴等部分组成。",
    "福利待遇:公司为员工提供五险一金、带薪年假、节日福利等。",
    "绩效奖金:根据员工工作表现和公司业绩,发放季度和年度绩效奖金。"
]

# 各意图知识库分别添加相关文档及元数据
print("正在初始化知识库...")
kb_manager.add_documents("假期政策", "leave_policy", leave_policy_docs, 
                         [{"topic": "假期政策", "type": "年假"}] * 1 +
                         [{"topic": "假期政策", "type": "病假"}] * 1 +
                         [{"topic": "假期政策", "type": "产假"}] * 1 +
                         [{"topic": "假期政策", "type": "婚假"}] * 1)

kb_manager.add_documents("调岗流程", "transfer_process", transfer_process_docs,
                         [{"topic": "调岗流程"}] * 3)

kb_manager.add_documents("薪资福利", "salary_benefits", salary_benefits_docs,
                         [{"topic": "薪资福利"}] * 3)

print("知识库初始化完成!\n")

# 口语化用户查询示例列表
queries = [
    "我下个月要结婚,能请几天假?",  # 应该识别为"假期政策",重写为"婚假"
    "家里人生病了,公司有相关假期吗?",  # 应该识别为"假期政策",重写为"病假"
    "我想换个部门工作,需要走什么流程?"  # 应该识别为"调岗流程"
]

# 测试每个查询,打印RAG系统的完整回答和检索信息
for query in queries:
    print("="*60)
    print(f"用户查询: {query}")
    print("="*60)

    result = rag_system.generate_answer(query, k=3)

    print(f"\n识别意图: {result['intent']}")
    print(f"重写查询: {result['rewritten_query']}")
    print(f"\n生成答案:\n{result['answer']}")
    print(f"\n检索到 {result['num_documents']} 个文档")
    print("\n" + "-"*60 + "\n")

6.2 执行过程 #

6.2.1 核心思想 #

检索-生成一体化采用“意图驱动”的多知识库策略:

  • 意图识别:识别用户问题的类型(如假期政策、调岗流程、薪资福利等)
  • 查询重写:根据意图将口语化问题转换为标准术语
  • 知识库路由:根据意图选择对应的知识库
  • 检索:在对应的知识库中检索相关文档
  • 生成答案:基于检索结果生成最终答案

6.2.2 执行流程 #

阶段一:初始化

# 1. 定义意图类别
intent_categories = ["假期政策", "调岗流程", "薪资福利", "其他"]

# 2. 定义查询重写策略
rewriting_strategies = {
    "假期政策": "...",
    "调岗流程": "...",
    "薪资福利": "...",
    "其他": "..."
}

# 3. 定义意图到集合名的映射
intent_to_collection = {
    "假期政策": "leave_policy",
    "调岗流程": "transfer_process",
    "薪资福利": "salary_benefits",
    "其他": "default"
}

# 4. 创建各个组件
intent_classifier = IntentClassifier(llm, intent_categories)
query_rewriter = QueryRewriter(llm, rewriting_strategies)
kb_manager = KnowledgeBaseManager("chroma_db")

# 5. 创建一体化检索器
retriever = RetrievalGenerationIntegrationRetriever(...)

# 6. 创建RAG系统
rag_system = RetrievalGenerationIntegrationRAG(retriever=retriever, llm=llm)

初始化时:

  • 定义意图类别和重写策略
  • 创建意图分类器、查询重写器、知识库管理器
  • 创建一体化检索器和RAG系统

阶段二:知识库初始化

# 为不同意图的知识库添加文档
kb_manager.add_documents("假期政策", "leave_policy", leave_policy_docs, ...)
kb_manager.add_documents("调岗流程", "transfer_process", transfer_process_docs, ...)
kb_manager.add_documents("薪资福利", "salary_benefits", salary_benefits_docs, ...)

知识库初始化:

  • 为每个意图创建独立的知识库
  • 添加对应的文档和元数据
  • 每个知识库使用不同的集合名

阶段三:查询处理

query = "我下个月要结婚,能请几天假?"
result = rag_system.generate_answer(query, k=3)

完整流程:

  1. 用户提交口语化查询
  2. 意图识别:
    • 调用 intent_classifier.classify(query)
    • 使用 LLM 识别查询的意图类别
  3. 查询重写:
    • 调用 query_rewriter.rewrite(query, intent)
    • 根据意图使用对应的重写策略
    • 将口语化问题转换为标准术语
  4. 知识库路由:
    • 根据意图获取对应的集合名
    • 获取或创建对应的知识库实例
  5. 检索:
    • 在对应的知识库中使用重写后的查询进行检索
    • 返回 top-k 个相关文档
  6. 生成答案:
    • 整合检索文档作为上下文
    • 使用提示词模板构建最终 prompt(包含原始查询、意图、重写查询)
    • LLM 生成答案
  7. 返回结果:
    • 包含查询、意图、重写查询、答案、检索文档等

6.2.3 类图 #

classDiagram class IntentClassifier { -llm: BaseLanguageModel -intent_categories: List[str] -classification_template: PromptTemplate +__init__(llm, intent_categories) +classify(query: str) str } class QueryRewriter { -llm: BaseLanguageModel -rewriting_strategies: Dict[str, str] +__init__(llm, rewriting_strategies) +rewrite(query: str, intent: str) str } class KnowledgeBaseManager { -persist_directory: str -knowledge_bases: Dict[str, Any] +__init__(persist_directory) +get_knowledge_base(intent: str, collection_name: str) Any +add_documents(intent: str, collection_name: str, texts: List[str], metadatas: List[Dict]) } class RetrievalGenerationIntegrationRetriever { -kb_manager: KnowledgeBaseManager -intent_classifier: IntentClassifier -query_rewriter: QueryRewriter -intent_to_collection: Dict[str, str] -k: int +_get_relevant_documents(query: str, **kwargs) List[Document] } class RetrievalGenerationIntegrationRAG { -retriever: RetrievalGenerationIntegrationRetriever -llm: BaseLanguageModel -answer_prompt_template: PromptTemplate +__init__(retriever, llm, answer_prompt_template) +generate_answer(query: str, k: int) Dict[str, Any] } class BaseRetriever { <<abstract>> +invoke(query: str) List[Document] } class PromptTemplate { +format(**kwargs) str } class Document { +page_content: str +metadata: Dict } class VectorStore { <<interface>> +add_texts(texts: List[str], metadatas: List[Dict]) +similarity_search_with_score(query: str, k: int) List[Tuple[Document, float]] } class BaseLanguageModel { <<interface>> +invoke(prompt: str) AIMessage } RetrievalGenerationIntegrationRetriever --|> BaseRetriever RetrievalGenerationIntegrationRAG --> RetrievalGenerationIntegrationRetriever RetrievalGenerationIntegrationRAG --> BaseLanguageModel RetrievalGenerationIntegrationRetriever --> IntentClassifier RetrievalGenerationIntegrationRetriever --> QueryRewriter RetrievalGenerationIntegrationRetriever --> KnowledgeBaseManager IntentClassifier --> BaseLanguageModel QueryRewriter --> BaseLanguageModel KnowledgeBaseManager --> VectorStore RetrievalGenerationIntegrationRetriever ..> Document : creates VectorStore ..> Document : returns

6.2.4 时序图 #

6.2.4.1 完整RAG流程时序图 #
sequenceDiagram participant User as 用户 participant RAG as RetrievalGenerationIntegrationRAG participant Retriever as RetrievalGenerationIntegrationRetriever participant IntentClassifier as IntentClassifier participant QueryRewriter as QueryRewriter participant KBManager as KnowledgeBaseManager participant VectorStore as VectorStore participant LLM as BaseLanguageModel User->>RAG: generate_answer("我下个月要结婚,能请几天假?", k=3) RAG->>Retriever: _get_relevant_documents(query, k=3) Note over Retriever: 步骤1: 意图识别 Retriever->>IntentClassifier: classify(query) IntentClassifier->>IntentClassifier: 构建分类prompt<br/>"请判断以下用户查询属于哪个类别..." IntentClassifier->>LLM: invoke(prompt) Note over LLM: 识别查询意图<br/>(假期政策、调岗流程、薪资福利、其他) LLM-->>IntentClassifier: 返回意图类别 IntentClassifier-->>Retriever: 返回意图("假期政策") Note over Retriever: 步骤2: 查询重写 Retriever->>QueryRewriter: rewrite(query, intent) QueryRewriter->>QueryRewriter: 获取意图对应的重写策略 QueryRewriter->>QueryRewriter: 构建重写prompt<br/>"请从以下查询中提取假期类型..." QueryRewriter->>LLM: invoke(prompt) Note over LLM: 将口语化查询转换为标准术语<br/>"我下个月要结婚,能请几天假?"<br/>→ "婚假" LLM-->>QueryRewriter: 返回重写后的查询 QueryRewriter-->>Retriever: 返回重写查询("婚假") Note over Retriever: 步骤3: 知识库路由 Retriever->>Retriever: 根据意图获取集合名<br/>intent_to_collection["假期政策"] = "leave_policy" Retriever->>KBManager: get_knowledge_base("假期政策", "leave_policy") KBManager->>KBManager: 检查知识库是否存在 alt 知识库不存在 KBManager->>VectorStore: get_vector_store("leave_policy") VectorStore-->>KBManager: 返回新知识库实例 KBManager->>KBManager: 存储到knowledge_bases字典 end KBManager-->>Retriever: 返回知识库实例 Note over Retriever: 步骤4: 检索 Retriever->>VectorStore: similarity_search_with_score("婚假", k=3) Note over VectorStore: 在"假期政策"知识库中<br/>使用重写后的查询检索 VectorStore-->>Retriever: 返回3个相关文档(带分数) Retriever->>Retriever: 构建Document对象<br/>添加元数据:<br/>- intent: 意图<br/>- original_query: 原始查询<br/>- rewritten_query: 重写查询 Retriever-->>RAG: 返回检索文档列表 Note over RAG: 步骤5: 生成答案 RAG->>RAG: 提取意图和重写查询 RAG->>RAG: 整合文档内容为上下文 RAG->>PromptTemplate: format(context, query, intent, rewritten_query) PromptTemplate-->>RAG: 返回完整prompt RAG->>LLM: invoke(prompt) Note over LLM: 基于检索结果生成答案 LLM-->>RAG: 返回答案内容 RAG-->>User: 返回结果字典<br/>(query, intent, rewritten_query, answer, documents)
6.2.4.2 意图识别详细流程 #
sequenceDiagram participant Retriever as RetrievalGenerationIntegrationRetriever participant Classifier as IntentClassifier participant Template as PromptTemplate participant LLM as BaseLanguageModel Note over Retriever: classify(query) Classifier->>Classifier: 构建分类prompt模板 Note over Classifier: "请判断以下用户查询属于哪个类别。<br/>可用类别:{categories}<br/>用户查询:{query}" Classifier->>Classifier: 拼接意图类别字符串<br/>"假期政策、调岗流程、薪资福利、其他" Classifier->>Template: format(query=用户查询, categories=类别字符串) Template-->>Classifier: 返回完整prompt Classifier->>LLM: invoke(prompt) Note over LLM: 识别查询意图<br/>返回类别名称 LLM-->>Classifier: 返回意图类别文本 Classifier->>Classifier: 检查返回文本是否包含目标类别 loop 遍历所有意图类别 alt 返回文本包含该类别 Classifier-->>Retriever: 返回该意图类别 end end alt 未匹配到任何类别 Classifier-->>Retriever: 返回第一个类别(默认) end
6.2.4.3 查询重写详细流程 #
sequenceDiagram participant Retriever as RetrievalGenerationIntegrationRetriever participant Rewriter as QueryRewriter participant LLM as BaseLanguageModel Note over Retriever: rewrite(query, intent) Retriever->>Rewriter: rewrite("我下个月要结婚,能请几天假?", "假期政策") Rewriter->>Rewriter: 获取意图对应的重写策略<br/>rewriting_strategies["假期政策"] Note over Rewriter: 策略: "请从以下查询中提取假期类型<br/>(如年假、病假、产假、婚假等)<br/>的标准术语,然后生成检索查询。" alt 策略包含{query}占位符 Rewriter->>Rewriter: format(query=用户查询) else 策略不包含占位符 Rewriter->>Rewriter: 拼接策略和查询 end Rewriter->>LLM: invoke(prompt) Note over LLM: 将口语化查询转换为标准术语<br/>"我下个月要结婚,能请几天假?"<br/>→ "婚假" LLM-->>Rewriter: 返回重写后的查询 alt 重写结果为空 Rewriter-->>Retriever: 返回原始查询(降级处理) else 重写结果非空 Rewriter-->>Retriever: 返回重写后的查询 end
6.2.4.4 知识库路由与检索详细流程 #
sequenceDiagram participant Retriever as RetrievalGenerationIntegrationRetriever participant KBManager as KnowledgeBaseManager participant VectorStore as VectorStore Note over Retriever: 知识库路由与检索 Retriever->>Retriever: 根据意图获取集合名<br/>intent_to_collection["假期政策"] = "leave_policy" Retriever->>KBManager: get_knowledge_base("假期政策", "leave_policy") KBManager->>KBManager: 检查knowledge_bases字典<br/>是否存在"假期政策"知识库 alt 知识库不存在 KBManager->>VectorStore: get_vector_store(<br/>persist_directory="chroma_db",<br/>collection_name="leave_policy") Note over VectorStore: 创建新的向量存储实例<br/>使用指定的集合名 VectorStore-->>KBManager: 返回知识库实例 KBManager->>KBManager: 存储到knowledge_bases["假期政策"] else 知识库已存在 KBManager->>KBManager: 直接从knowledge_bases获取 end KBManager-->>Retriever: 返回知识库实例 Retriever->>VectorStore: similarity_search_with_score("婚假", k=3) Note over VectorStore: 在"leave_policy"集合中<br/>使用重写后的查询检索 VectorStore-->>Retriever: 返回3个相关文档(带分数) Retriever->>Retriever: 构建Document对象<br/>添加元数据:<br/>- intent: "假期政策"<br/>- original_query: "我下个月要结婚,能请几天假?"<br/>- rewritten_query: "婚假" Retriever-->>Retriever: 返回检索文档列表

6.2.5 关键设计要点 #

1. 一体化流程

用户查询 (口语化)
    ↓
意图识别 → 意图类别
    ↓
查询重写 → 标准术语查询
    ↓
知识库路由 → 选择对应知识库
    ↓
检索 → 在对应知识库中检索
    ↓
生成答案 → 基于检索结果生成

2. 意图识别示例

用户查询: "我下个月要结婚,能请几天假?"
    ↓
意图识别: "假期政策"
    ↓
查询重写: "婚假"
    ↓
知识库路由: "leave_policy"
    ↓
检索结果: 婚假政策相关文档

3. 多知识库架构

  • 每个意图对应一个独立的知识库
  • 使用不同的集合名(collection_name)区分
  • 知识库管理器统一管理
  • 按需创建,避免重复初始化

知识库结构:

chroma_db/
  ├── leave_policy/          (假期政策知识库)
  ├── transfer_process/      (调岗流程知识库)
  ├── salary_benefits/       (薪资福利知识库)
  └── default/               (其他知识库)

4. 查询重写策略

不同意图使用不同的重写策略:

  • 假期政策:提取假期类型(年假、病假、产假、婚假等)
  • 调岗流程:提取调岗相关的标准术语和流程关键词
  • 薪资福利:提取薪资福利相关的标准术语
  • 其他:通用转换策略

示例:

原始查询: "我下个月要结婚,能请几天假?"
意图: "假期政策"
重写策略: "请从以下查询中提取假期类型..."
重写结果: "婚假"

5. 元数据设计

检索返回的 Document 对象包含:

  • score:相似度分数(距离,越小越相似)
  • intent:识别的意图类别
  • original_query:原始用户查询
  • rewritten_query:重写后的查询
  • 继承原始文档的元数据(topic、type 等)

6. 答案生成模板

默认答案生成模板:

已知以下相关信息:

{context}

原始查询:{query}
识别意图:{intent}
重写查询:{rewritten_query}

请根据上述信息,准确全面地回答原始查询。

答案:

该模板:

  • 包含所有检索文档作为上下文
  • 明确列出原始查询、意图和重写查询
  • 要求准确全面地回答原始查询

6.2.6 优势与应用场景 #

优势:

  • 意图驱动:根据意图选择对应的知识库,提高检索精度
  • 查询优化:将口语化问题转换为标准术语,提高匹配率
  • 知识库隔离:不同意图使用独立知识库,避免干扰
  • 可扩展性:易于添加新的意图类别和知识库

适用场景:

  • 企业知识库:HR政策、流程文档等
  • 多领域问答:不同领域使用不同知识库
  • 专业场景:需要精确匹配特定领域知识
  • 口语化查询:需要将自然语言转换为标准术语

6.2.7 技术细节 #

  • 意图识别:
    • 使用 LLM 进行分类
    • 检查返回文本是否包含目标类别
    • 默认返回第一个类别
  • 查询重写:
    • 根据意图使用对应的重写策略
    • 支持自定义策略模板
    • 重写失败时降级为原始查询
  • 知识库管理:
    • 使用字典缓存已创建的知识库
    • 按需创建,避免重复初始化
    • 支持动态添加文档

该设计通过“意图驱动+多知识库”策略,在特定领域场景下提供更精确的检索和答案生成,适用于需要意图识别和知识库路由的 RAG 应用。

7. 上下文对话(Contextual Dialogue) #

在许多真实场景中,用户的查询并不是一次性的,而是涉及多轮、上下文相关的对话。例如,用户第一次询问“我手机坏了怎么办?”,接下来又问“保修期还没过,怎么处理?”——后者的理解依赖于前文信息。传统检索流程对每次输入都“独立处理”,难以充分利用上下文已知的信息来精准理解当前问题、缩小检索范围或者消歧。

上下文对话(Contextual Dialogue)系统的主要任务,是结合对话历史,智能解析和补全当前用户的意图。例如:

  1. 指代消解
    当用户问“它还能返厂维修吗?”时,系统应自动分析出“它”指代前文提到的具体产品,并据此生成准确的检索或回答。

  2. 缺省信息补全
    如果用户前面已经说明产品型号、事件等,后面的问句出现省略(如直接问“换电池多少钱?”),系统可自动补充上下文条件,避免重复追问。

  3. 多轮追问/比较
    对于“那与隔壁品牌的比,价格贵吗?”、“两者维修周期哪个短?”等对话,需要系统抓取对比对象,理解推理链路,综合分析。

为提升RAG系统在多轮对话场景下的表现,常见技术实践包括:

  • 对话历史记录与回溯:为每位用户维护一份完整的会话轮次记录,并将历史摘要与当前输入共同作为检索和生成的条件。
  • 指代消解与实体跟踪:使用NLP技术自动判断“他/它/他们/这种/这里/那个”等指代词,结合历史找到确切指向。
  • 上下文增强的检索重写:对当前用户输入,通过LLM或专用重写组件,将含糊/不全的提问扩展为明确、完整的检索查询。
  • 多轮信息整合生成:生成阶段基于综合的历史+检索片段,回答不仅准确,还符合连续对话语境。

举例说明:

  • 用户A:

    1. “我买了你们的Alpha智能手表,但好像坏了。”
    2. “是去年买的。”
    3. “保修期还有多久?”
      上述第3问,需要算法自动整合用户第1轮的产品信息和第2轮的购买时间,形成完整的条件(如“2023年购买的Alpha智能手表保修期还有多久”),再进行知识库检索和答案生成。
  • 用户B:

    1. “我在看Beta蓝牙耳机和Gamma智能音箱,想了解一下售后政策。”
    2. “主要是保修时间。”
    3. “哪个产品保修期更长?”
      第3问属于“对比型”,系统要抓出比较对象并准确回答。

应用难点与要点:

  • 上下文解析的准确性直接影响检索质量,特别是在指代、歧义、信息不全时。
  • 在RAG流程中,应让检索步骤和生成步骤都能访问到“上下文融合”的综合查询。
  • 实际工程中,LangChain等框架提供了类似“ConversationMemory”、“ChatPromptTemplate”等用于上下文对话处理的工具。

综上,支持上下文对话是RAG从“检索-生成”到“智能陪伴体”的关键一步,是企业知识助手、客户服务机器人等场景不可或缺的技术模块。下面将展示一个典型的多轮上下文QA对话的代码样例。

7.1 ContextualDialogue.py #

# 导入List、Dict、Any、Optional类型用于类型注解
from typing import List, Dict, Any, Optional

# 导入pydantic用于配置数据模型
from pydantic import ConfigDict

# 导入基础检索器基类
from langchain_core.retrievers import BaseRetriever

# 导入文档对象类型
from langchain_core.documents import Document

# 导入基础LLM类型
from langchain_core.language_models import BaseLanguageModel

# 导入提示词模板类
from langchain_core.prompts import PromptTemplate

# 导入自定义llm对象
from llm import llm

# 导入获取向量存储的函数
from vector_store import get_vector_store

# 导入正则模块用于解析query字符串
import re

# 定义对话历史记录管理器类
class DialogueHistory:
    """对话历史管理器:管理多轮对话记录"""

    # 初始化函数,创建空的历史列表
    def __init__(self):
        self.history: List[Dict[str, str]] = []

    # 添加一轮对话,参数为角色和内容
    def add_turn(self, role: str, content: str):
        """添加一轮对话"""
        self.history.append({"role": role, "content": content})

    # 获取对话历史记录的文本串,格式为"角色:内容"
    def get_history_text(self) -> str:
        """获取对话历史的文本表示"""
        lines = []
        # 遍历每一轮对话,拼接字符串
        for turn in self.history:
            lines.append(f"{turn['role']}:{turn['content']}")
        # 用换行符拼接所有轮次
        return "\n".join(lines)

    # 清空对话历史
    def clear(self):
        """清空对话历史"""
        self.history = []

# 定义Query归纳Agent类
class QueryGenerationAgent:
    """Agent:分析对话历史和最新问题,生成检索query列表"""

    # 初始化函数,配置llm和模板
    def __init__(self, llm: BaseLanguageModel, query_generation_template: Optional[str] = None):
        # 保存llm对象
        self.llm = llm
        # 若未指定模板则用默认模板
        template = query_generation_template or (
            "请根据以下多轮对话内容和用户最新问题,智能归纳出1到多条适合知识库检索的query。\n"
            "要求:\n"
            "1. 结合历史对话补全关键信息(如产品名称、时间等)\n"
            "2. 识别模糊指代(如'它们'、'这个'等)的具体对象\n"
            "3. 对于对比型问题,为每个对象生成独立的query\n"
            "4. 对于多意图问题,为每个意图生成独立的query\n"
            "5. 将反问型问题转换为正面询问\n"
            "6. 将条件型问题转化为检索参数\n\n"
            "历史对话:\n{history}\n\n"
            "用户最新问题:{latest_question}\n\n"
            "请输出所有应检索的query,每行一条,不要编号:"
        )
        # 创建提示词模板对象
        self.query_generation_template = PromptTemplate(
            input_variables=["history", "latest_question"], template=template
        )

    # 根据历史和最新提问生成query
    def generate_queries(self, history: str, latest_question: str) -> List[str]:
        """根据对话历史和最新问题生成检索query列表"""
        # 填充模板得到prompt
        prompt = self.query_generation_template.format(history=history, latest_question=latest_question)
        # 用llm生成结果,去除首尾空白
        response = self.llm.invoke(prompt).content.strip()

        # 解析生成的query字符串列表
        queries = []
        # 按行处理每个生成的查询
        for line in response.split('\n'):
            # 去除两端空白字符
            line = line.strip()
            # 移除行首编号(如"1. "、"1、"等)
            line = re.sub(r'^\d+[\.、]\s*', '', line)
            # 移除列表标记(如"- "、"• "等)
            line = re.sub(r'^[-•]\s*', '', line)
            # 若不为空则加入列表
            if line:
                queries.append(line)

        # 如果没解析出结果,则返回原始问题
        if not queries:
            queries = [latest_question]

        return queries

# 定义上下文对话检索器类
class ContextualDialogueRetriever(BaseRetriever):
    """上下文对话检索器:结合对话历史进行检索"""

    # 定义成员:向量数据库对象
    vector_store: Any
    # 定义成员:Query归纳Agent
    agent: QueryGenerationAgent
    # 定义成员:默认返回文档数量
    k: int = 4
    # 配置pydantic模型,允许任意类型参数
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # 根据输入问题和历史,返回相关文档列表
    def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
        """根据查询和对话历史检索相关文档"""
        # 获取k值,若未给则用默认
        k = kwargs.get("k", self.k)
        # 获取对话历史文本
        history = kwargs.get("history", "")

        # 若有历史,则生成queries,否则直接用query
        if history:
            queries = self.agent.generate_queries(history, query)
        else:
            queries = [query]

        # 保存所有检索到的文档
        all_docs = []
        # 遍历每个query进行向量检索
        for q in queries:
            # 用向量库检索相似文档和分数
            docs = self.vector_store.similarity_search_with_score(q, k=k)
            # 遍历每一个文档和分数,组装为Document对象并记录meta信息
            for doc, distance in docs:
                all_docs.append(Document(
                    page_content=doc.page_content,
                    metadata={
                        **doc.metadata,
                        "score": float(distance),
                        "generated_query": q,
                        "original_query": query
                    }
                ))

        # 去重:同一文档内容只保留分数最低(最相关)的那条
        unique_docs = {}
        for doc in all_docs:
            content = doc.page_content
            score = doc.metadata.get("score", float('inf'))
            # 如果文档没见过,或更优分数,更新
            if content not in unique_docs or score < unique_docs[content].metadata.get("score", float('inf')):
                unique_docs[content] = doc

        # 按分数升序排序,分数越低越相关
        result_docs = sorted(unique_docs.values(), key=lambda x: x.metadata.get("score", float('inf')))

        return result_docs

# 定义上下文对话RAG主系统类
class ContextualDialogueRAG:
    """上下文对话RAG系统:支持多轮对话的智能问答"""

    # 初始化函数,注入检索器和llm,并可选注入答案模板
    def __init__(self, retriever: ContextualDialogueRetriever, llm: BaseLanguageModel, answer_prompt_template: Optional[str] = None):
        self.retriever = retriever
        self.llm = llm
        # 创建对话历史管理对象
        self.dialogue_history = DialogueHistory()

        # 设置回答prompt模板
        template = answer_prompt_template or (
            "已知以下相关信息:\n\n{context}\n\n"
            "对话历史:\n{history}\n\n"
            "用户最新问题:{query}\n\n"
            "请根据上述信息,结合对话历史,准确全面地回答用户最新问题。\n\n答案:"
        )
        # 创建模板对象
        self.answer_prompt_template = PromptTemplate(
            input_variables=["context", "history", "query"], template=template
        )

    # 一轮对话主逻辑
    def chat(self, user_query: str, k: int = 4) -> Dict[str, Any]:
        """处理一轮对话"""
        # 获取历史对话文本
        history_text = self.dialogue_history.get_history_text()
        # 检索相关文档(结合历史)
        docs = self.retriever._get_relevant_documents(user_query, k=k, history=history_text)
        # 从文档里提取所有本轮生成的query(去重)
        generated_queries = list(set([doc.metadata.get("generated_query", "") for doc in docs if doc.metadata.get("generated_query")]))
        # 构建上下文字符串,每个文档单独展示
        context = "\n\n".join([f"文档 {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
        # 用模板生成prompt,注意若历史为空显示“无历史对话”
        prompt = self.answer_prompt_template.format(
            context=context,
            history=history_text if history_text else "(无历史对话)",
            query=user_query
        )
        # 用llm得到答案,去尾部空白
        answer = self.llm.invoke(prompt).content.strip()
        # 对话历史添加用户提问
        self.dialogue_history.add_turn("用户", user_query)
        # 对话历史添加AI回答
        self.dialogue_history.add_turn("AI", answer)
        # 返回本轮的全部信息
        return {
            "user_query": user_query,
            "answer": answer,
            "generated_queries": generated_queries,
            "retrieved_documents": docs,
            "num_documents": len(docs)
        }

    # 清空历史
    def clear_history(self):
        """清空对话历史"""
        self.dialogue_history.clear()

# ---------------- 以下为组件初始化和样例 ----------------

# 创建Query归纳Agent
agent = QueryGenerationAgent(llm)

# 创建向量数据库对象
vector_store = get_vector_store(
    persist_directory="chroma_db",
    collection_name="contextual_dialogue"
)

# 创建上下文对话检索器
retriever = ContextualDialogueRetriever(
    vector_store=vector_store,
    agent=agent,
    k=3
)

# 创建上下文对话RAG系统
rag_system = ContextualDialogueRAG(retriever=retriever, llm=llm)

# 示例产品知识库文档
documents = [
    "Alpha智能手表:2022年购买的产品,保修期为2年,支持无线充电,颜色有黑色和白色。",
    "Beta蓝牙耳机:保修期为1年,支持无线充电,颜色有黑色、白色和红色。",
    "Gamma智能音箱:保修期为3年,不支持无线充电,颜色有黑色和白色。",
    "Delta笔记本电脑:保修期为2年,不支持无线充电,颜色有黑色和银色。",
    "Epsilon平板电脑:保修期为1年,支持无线充电,颜色有黑色、白色和金色。",
    "Zeta无线充电器:保修期为1年,颜色有黑色和白色,支持多种设备充电。"
]

# 索引文档到向量库
print("正在初始化知识库...")
vector_store.add_texts(
    documents,
    metadatas=[{"topic": "产品信息"}] * len(documents)
)
print("知识库初始化完成!\n")

# ------------------- 示例1:上下文依赖对话 -------------------
print("="*60)
print("示例1:上下文依赖型对话")
print("="*60)
# 清空历史
rag_system.clear_history()

# 第一轮---智能手表问题
result1 = rag_system.chat("我最近买了个智能手表。")
print(f"用户:{result1['user_query']}")
print(f"AI:{result1['answer']}\n")

# 第二轮---补充产品关键信息
result2 = rag_system.chat("是Alpha智能手表,去年买的。")
print(f"用户:{result2['user_query']}")
print(f"AI:{result2['answer']}\n")

# 第三轮---上下文依赖查询,提及保修期
result3 = rag_system.chat("保修期还有多久?")
print(f"用户:{result3['user_query']}")
print(f"生成的检索query:{result3['generated_queries']}")
print(f"AI:{result3['answer']}\n")

# ------------------- 示例2:对比型对话 -------------------
print("\n" + "="*60)
print("示例2:对比型对话")
print("="*60)
# 清空历史
rag_system.clear_history()

result1 = rag_system.chat("我在看Beta蓝牙耳机和Gamma智能音箱,想了解一下售后政策。")
print(f"用户:{result1['user_query']}")
print(f"AI:{result1['answer']}\n")

result2 = rag_system.chat("主要是保修时间。")
print(f"用户:{result2['user_query']}")
print(f"AI:{result2['answer']}\n")

result3 = rag_system.chat("哪个产品保修期更长?")
print(f"用户:{result3['user_query']}")
print(f"生成的检索query:{result3['generated_queries']}")
print(f"AI:{result3['answer']}\n")

# ------------------- 示例3:模糊指代对话 -------------------
print("\n" + "="*60)
print("示例3:模糊指代型对话")
print("="*60)
# 清空历史
rag_system.clear_history()

result1 = rag_system.chat("我买了你们家的笔记本电脑和平板电脑。")
print(f"用户:{result1['user_query']}")
print(f"AI:{result1['answer']}\n")

result2 = rag_system.chat("是Delta笔记本电脑和Epsilon平板电脑。")
print(f"用户:{result2['user_query']}")
print(f"AI:{result2['answer']}\n")

result3 = rag_system.chat("它们都支持无线充电吗?")
print(f"用户:{result3['user_query']}")
print(f"生成的检索query:{result3['generated_queries']}")
print(f"AI:{result3['answer']}\n")

7.2 执行过程 #

7.2.1 核心思想 #

上下文对话采用“历史感知+智能查询生成”策略:

  • 对话历史管理:维护多轮对话记录
  • 智能查询生成:根据对话历史和最新问题,使用 Agent 生成适合检索的 query 列表
  • 多查询检索:对每个生成的 query 分别进行检索
  • 结果合并:合并所有检索结果并去重
  • 上下文生成:基于检索结果和对话历史生成答案

7.2.2 执行流程 #

阶段一:初始化

# 1. 创建Query归纳Agent
agent = QueryGenerationAgent(llm)

# 2. 创建向量存储实例
vector_store = get_vector_store(
    persist_directory="chroma_db",
    collection_name="contextual_dialogue"
)

# 3. 创建上下文对话检索器
retriever = ContextualDialogueRetriever(
    vector_store=vector_store,
    agent=agent,
    k=3
)

# 4. 创建上下文对话RAG系统
rag_system = ContextualDialogueRAG(retriever=retriever, llm=llm)

初始化时:

  • 创建 Query 归纳 Agent,用于生成检索 query
  • 创建向量存储实例
  • 创建检索器,配置 Agent 和默认 k 值
  • 创建 RAG 系统,内部包含对话历史管理器

阶段二:知识库初始化

# 添加文档到向量库
vector_store.add_texts(
    documents,
    metadatas=[...]
)

索引过程:

  • 将示例文档添加到向量库
  • 每个文档附带元数据(topic)
  • 文档会被向量化并存储

阶段三:多轮对话处理

# 第一轮对话
result1 = rag_system.chat("我最近买了个智能手表。")

# 第二轮对话
result2 = rag_system.chat("是Alpha智能手表,去年买的。")

# 第三轮对话(上下文依赖)
result3 = rag_system.chat("保修期还有多久?")

完整流程(每轮对话):

  1. 用户提交查询
  2. 获取对话历史:
    • 从对话历史管理器获取历史文本
  3. 智能查询生成:
    • 如果有历史,调用 agent.generate_queries(history, query)
    • 使用 LLM 根据历史和最新问题生成检索 query 列表
    • 如果没有历史,直接使用原始查询
  4. 多查询检索:
    • 对每个生成的 query 分别进行向量检索
    • 为每个检索结果添加元数据(generated_query、original_query)
  5. 结果合并:
    • 按文档内容去重,保留最佳分数
    • 按分数排序
  6. 生成答案:
    • 整合检索文档作为上下文
    • 使用提示词模板构建最终 prompt(包含上下文、历史、查询)
    • LLM 生成答案
  7. 更新历史:
    • 将用户查询添加到历史
    • 将 AI 回答添加到历史
  8. 返回结果:
    • 包含用户查询、答案、生成的 query 列表、检索文档等

7.2.3 类图 #

classDiagram class DialogueHistory { -history: List[Dict[str, str]] +__init__() +add_turn(role: str, content: str) +get_history_text() str +clear() } class QueryGenerationAgent { -llm: BaseLanguageModel -query_generation_template: PromptTemplate +__init__(llm, query_generation_template) +generate_queries(history: str, latest_question: str) List[str] } class ContextualDialogueRetriever { -vector_store: Any -agent: QueryGenerationAgent -k: int +_get_relevant_documents(query: str, **kwargs) List[Document] } class ContextualDialogueRAG { -retriever: ContextualDialogueRetriever -llm: BaseLanguageModel -dialogue_history: DialogueHistory -answer_prompt_template: PromptTemplate +__init__(retriever, llm, answer_prompt_template) +chat(user_query: str, k: int) Dict[str, Any] +clear_history() } class BaseRetriever { <<abstract>> +invoke(query: str) List[Document] } class PromptTemplate { +format(**kwargs) str } class Document { +page_content: str +metadata: Dict } class VectorStore { <<interface>> +add_texts(texts: List[str], metadatas: List[Dict]) +similarity_search_with_score(query: str, k: int) List[Tuple[Document, float]] } class BaseLanguageModel { <<interface>> +invoke(prompt: str) AIMessage } ContextualDialogueRetriever --|> BaseRetriever ContextualDialogueRAG --> ContextualDialogueRetriever ContextualDialogueRAG --> BaseLanguageModel ContextualDialogueRAG --> DialogueHistory ContextualDialogueRetriever --> QueryGenerationAgent ContextualDialogueRetriever --> VectorStore QueryGenerationAgent --> BaseLanguageModel QueryGenerationAgent --> PromptTemplate ContextualDialogueRetriever ..> Document : creates VectorStore ..> Document : returns

7.2.4 时序图 #

7.2.4.1 完整多轮对话流程时序图 #
sequenceDiagram participant User as 用户 participant RAG as ContextualDialogueRAG participant History as DialogueHistory participant Retriever as ContextualDialogueRetriever participant Agent as QueryGenerationAgent participant VectorStore as VectorStore participant LLM as BaseLanguageModel Note over User,RAG: 第一轮对话 User->>RAG: chat("我最近买了个智能手表。") RAG->>History: get_history_text() History-->>RAG: 返回空历史(第一轮) RAG->>Retriever: _get_relevant_documents(query, k=3, history="") Note over Retriever: 无历史,直接使用原始查询 Retriever->>Retriever: queries = [原始查询] Retriever->>VectorStore: similarity_search_with_score(原始查询, k=3) VectorStore-->>Retriever: 返回3个文档(带分数) Retriever->>Retriever: 构建Document对象<br/>添加元数据 Retriever-->>RAG: 返回检索文档列表 RAG->>RAG: 整合文档内容为上下文 RAG->>PromptTemplate: format(context, history="(无历史对话)", query) PromptTemplate-->>RAG: 返回完整prompt RAG->>LLM: invoke(prompt) LLM-->>RAG: 返回答案 RAG->>History: add_turn("用户", "我最近买了个智能手表。") RAG->>History: add_turn("AI", 答案) RAG-->>User: 返回结果 Note over User,RAG: 第二轮对话 User->>RAG: chat("是Alpha智能手表,去年买的。") RAG->>History: get_history_text() History-->>RAG: 返回历史文本<br/>"用户:我最近买了个智能手表。\nAI:..." RAG->>Retriever: _get_relevant_documents(query, k=3, history=历史文本) Note over Retriever: 有历史,生成智能查询 Retriever->>Agent: generate_queries(历史, 最新问题) Agent->>Agent: 构建查询生成prompt Agent->>LLM: invoke(prompt) Note over LLM: 根据历史和最新问题<br/>生成检索query列表 LLM-->>Agent: 返回query列表文本 Agent->>Agent: 解析query列表<br/>(去除编号、列表标记) Agent-->>Retriever: 返回query列表 loop 对每个生成的query Retriever->>VectorStore: similarity_search_with_score(query, k=3) VectorStore-->>Retriever: 返回3个文档(带分数) Retriever->>Retriever: 构建Document对象<br/>添加元数据 end Retriever->>Retriever: 去重合并<br/>按分数排序 Retriever-->>RAG: 返回合并后的文档列表 RAG->>RAG: 整合文档内容为上下文 RAG->>PromptTemplate: format(context, history=历史文本, query) PromptTemplate-->>RAG: 返回完整prompt RAG->>LLM: invoke(prompt) LLM-->>RAG: 返回答案 RAG->>History: add_turn("用户", "是Alpha智能手表,去年买的。") RAG->>History: add_turn("AI", 答案) RAG-->>User: 返回结果
7.2.4.2 智能查询生成详细流程 #
sequenceDiagram participant Retriever as ContextualDialogueRetriever participant Agent as QueryGenerationAgent participant Template as PromptTemplate participant LLM as BaseLanguageModel Note over Retriever: generate_queries(history, latest_question) Retriever->>Agent: generate_queries(历史文本, 最新问题) Agent->>Template: format(history=历史, latest_question=最新问题) Note over Template: 模板要求:<br/>1. 结合历史补全关键信息<br/>2. 识别模糊指代<br/>3. 对比型问题生成多个query<br/>4. 多意图问题生成多个query<br/>5. 反问转正面询问<br/>6. 条件型问题转检索参数 Template-->>Agent: 返回完整prompt Agent->>LLM: invoke(prompt) Note over LLM: 生成检索query列表<br/>每行一条,不要编号 LLM-->>Agent: 返回query列表文本 Agent->>Agent: 按行分割响应文本 loop 遍历每一行 Agent->>Agent: 去除两端空白 Agent->>Agent: 移除行首编号<br/>(如"1. "、"1、"等) Agent->>Agent: 移除列表标记<br/>(如"- "、"• "等) alt 行非空 Agent->>Agent: 添加到query列表 end end alt 未解析出任何query Agent-->>Retriever: 返回[原始问题](降级处理) else 解析出query列表 Agent-->>Retriever: 返回query列表<br/>[query1, query2, ...] end
7.2.4.3 多查询检索与合并详细流程 #
sequenceDiagram participant Retriever as ContextualDialogueRetriever participant VectorStore as VectorStore Note over Retriever: 多查询检索与合并 Retriever->>Retriever: 获取生成的query列表<br/>[query1, query2, query3, ...] loop 对每个生成的query Retriever->>VectorStore: similarity_search_with_score(query, k=3) Note over VectorStore: 对当前query进行<br/>向量相似度检索 VectorStore-->>Retriever: 返回3个文档(带分数) loop 对每个检索结果 Retriever->>Retriever: 构建Document对象<br/>添加元数据:<br/>- score: 分数<br/>- generated_query: 生成的query<br/>- original_query: 原始查询 Retriever->>Retriever: 添加到all_docs列表 end end Note over Retriever: 现在all_docs包含所有query的检索结果<br/>(可能有重复文档) Retriever->>Retriever: 去重逻辑 Note over Retriever: 使用文档内容作为键 loop 遍历所有文档 Retriever->>Retriever: 获取文档内容作为键 alt 文档首次出现 Retriever->>Retriever: 添加到unique_docs字典 else 文档已存在 alt 当前分数 < 已有分数 Retriever->>Retriever: 更新为当前文档<br/>(保留更好的分数) end end end Retriever->>Retriever: 按分数升序排序 Retriever-->>Retriever: 返回去重并排序后的文档列表

7.2.5 关键设计要点 #

1. 上下文对话流程

用户查询
    ↓
获取对话历史
    ↓
智能查询生成 (Agent)
    ↓
多查询检索
    ↓
结果合并去重
    ↓
生成答案 (结合历史)
    ↓
更新对话历史

2. 智能查询生成示例

场景1:上下文依赖

历史对话:
用户:我最近买了个智能手表。
AI:...

最新问题: "保修期还有多久?"

生成的query: ["Alpha智能手表保修期"]

场景2:对比型问题

历史对话:
用户:我在看Beta蓝牙耳机和Gamma智能音箱,想了解一下售后政策。
AI:...

最新问题: "哪个产品保修期更长?"

生成的query: ["Beta蓝牙耳机保修期", "Gamma智能音箱保修期"]

场景3:模糊指代

历史对话:
用户:我买了你们家的笔记本电脑和平板电脑。
用户:是Delta笔记本电脑和Epsilon平板电脑。
AI:...

最新问题: "它们都支持无线充电吗?"

生成的query: ["Delta笔记本电脑无线充电", "Epsilon平板电脑无线充电"]

3. 对话历史管理

  • 历史格式:每轮对话以"角色:内容"格式存储
  • 历史文本:用于查询生成和答案生成
  • 自动更新:每轮对话后自动添加用户查询和AI回答
  • 可清空:支持清空历史重新开始对话

历史示例:

用户:我最近买了个智能手表。
AI:很高兴听到您购买了智能手表...
用户:是Alpha智能手表,去年买的。
AI:Alpha智能手表是一款不错的产品...

4. 查询生成要求

根据提示词模板,Agent需要:

  1. 结合历史补全关键信息(如产品名称、时间等)
  2. 识别模糊指代(如"它们"、"这个"等)的具体对象
  3. 对于对比型问题,为每个对象生成独立的query
  4. 对于多意图问题,为每个意图生成独立的query
  5. 将反问型问题转换为正面询问
  6. 将条件型问题转化为检索参数

5. 元数据设计

检索返回的 Document 对象包含:

  • score:相似度分数(距离,越小越相似)
  • generated_query:生成的检索query
  • original_query:原始用户查询
  • 继承原始文档的元数据(topic 等)

6. 答案生成模板

默认答案生成模板:

已知以下相关信息:

{context}

对话历史:
{history}

用户最新问题:{query}

请根据上述信息,结合对话历史,准确全面地回答用户最新问题。

答案:

该模板:

  • 包含所有检索文档作为上下文
  • 包含完整的对话历史
  • 要求结合对话历史回答最新问题

7. 去重策略

  • 去重键:使用文档内容(page_content)作为唯一标识
  • 分数选择:如果同一文档被多个query检索到,保留最佳分数(距离最小)
  • 排序:去重后按分数升序排序,最相关的文档在前

7.2.6 优势与应用场景 #

优势:

  • 上下文感知:结合对话历史理解用户意图
  • 智能查询生成:自动补全关键信息、识别指代
  • 多查询检索:提高检索覆盖率和准确性
  • 自然对话:支持多轮连续对话

适用场景:

  • 客服系统:需要理解上下文的多轮对话
  • 产品咨询:需要结合历史信息回答问题
  • 技术支持:需要追踪问题背景
  • 对话式搜索:需要理解用户意图的连续查询

7.2.7 技术细节 #

  • 查询解析:
    • 去除编号前缀(如 "1. "、"1、")
    • 去除列表标记(如 "- "、"• ")
    • 保留非空行作为query
  • 降级处理:
    • 如果未解析出任何query,使用原始问题
  • 历史处理:
    • 如果历史为空,显示"(无历史对话)"

该设计通过“历史感知+智能查询生成”策略,在多轮对话场景下提供更准确的检索和答案生成,适用于需要理解上下文和连续对话的 RAG 应用。

8. 行业场景改写(Industry scenario adaptation) #

行业场景适配(Industry scenario adaptation)是面向垂直行业RAG应用时的重要检索前优化手段。其核心思想是:针对不同行业(如教育、医疗、法律等),由于用户的提问习惯、行业术语表达方式、检索需求差异较大,在将用户query送入RAG检索链路之前,引入“行业适配器”对提问进行智能改写或结构化,从而大幅提升匹配的精准度和最终答案的专业性。

典型应用场景举例

  • 教育行业
    用户常用叙述性和生活化描述提出数学、物理等题目。适配器需从表面语境中抽取核心的数学关系,消除无关细节,使检索query精确指向题干本质。例如:
    “一个建筑工人每天能挖4米深的洞,挖了6天后还差3米就完成了,这个洞总共有多深?”
    适配后变为:“已知每天挖4米,挖6天后剩3米,求总长度。”

  • 医疗行业
    用户多为口语化描述身体症状,适配器需将其转换为标准医疗术语和专业表达:
    “我最近总是头疼,还有点恶心,这是什么病?”
    改写后:“症状:头痛、恶心,常见可能原因有哪些?”

  • 法律行业
    查询问题往往带有具体纠纷细节或情绪表述。适配器应抽取法律关系、要素,形成适合法规检索的结构化query:
    “我在网上买东西,商家不发货也不退钱,我该怎么办?”
    改写后:“网购交易,商家未履约且拒绝退款,涉及的法律条款与处理方式有哪些?”

技术机制与流程

  1. 行业场景管理器(IndustryScenarioManager)
    负责管理多行业的适配器,每种适配器实现行业独有的query重写策略。支持灵活注册和获取。

  2. 行业适配器(IndustryScenarioAdapter)
    基类定义统一接口,不同行业可灵活继承实现。内部常用LLM+PromptTemplate设计,可根据行业特点设计定制化Prompt。

  3. RAG检索前优化
    当有用户query和行业类型输入时,行业场景管理器自动将query交给对应的适配器进行改写,再送入后续检索和生成模块,使召回和理解更加精准。

代码样例摘要

以“教育、医疗、法律”三行业为例,重写器和RAG主流程如下——

# 初始化行业场景管理器(可拓展更多行业)
scenario_manager = IndustryScenarioManager()
scenario_manager.register_adapter("教育", EducationScenarioAdapter(llm))
scenario_manager.register_adapter("医疗", MedicalScenarioAdapter(llm))
scenario_manager.register_adapter("法律", LegalScenarioAdapter(llm))

# 查询改写流程
edu_query = "小明每天跑步5圈,跑了7天后还剩2圈,问总共要跑多少圈?"
edu_adapted = scenario_manager.adapt_query(edu_query, "教育")
print(f"教育场景改写后:{edu_adapted}")

medical_query = "我最近总是头疼,有时候还会恶心,这是什么病?"
medical_adapted = scenario_manager.adapt_query(medical_query, "医疗")
print(f"医疗场景改写后:{medical_adapted}")

legal_query = "我在网上买东西,商家不发货也不退钱,我该怎么办?"
legal_adapted = scenario_manager.adapt_query(legal_query, "法律")
print(f"法律场景改写后:{legal_adapted}")

技术优势与建议

  • 提升行业检索召回率与准确率:有针对性地处理行业术语、隐含要素、表达变体。
  • 极大降低后续生成压力:让生成任务聚焦在精要的检索片段上,减少业务歧义。
  • 方便扩展和行业微调:只需增加适用于新行业的适配器及prompt模板,即可复用整套RAG流程。

总结:行业场景适配是RAG系统“以术语为桥梁、以意图为导向”走向高质量产业落地的关键策略之一。建议在各类垂直知识库应用中作为标准前置优化组件使用,并结合行业专家和数据持续迭代优化适配prompt与重写逻辑。

8.1 IndustryScenarioAdaptation.py #

# 导入List、Dict、Any、Optional类型用于类型注解
from typing import List, Dict, Any, Optional

# 导入pydantic用于配置数据模型
from pydantic import ConfigDict

# 导入langchain核心模块中的基础检索器基类
from langchain_core.retrievers import BaseRetriever

# 导入langchain核心模块中的文档对象类型
from langchain_core.documents import Document

# 导入langchain核心模块中的基础LLM类型
from langchain_core.language_models import BaseLanguageModel

# 导入langchain核心模块中的提示词模板类
from langchain_core.prompts import PromptTemplate

# 导入自定义的llm对象
from llm import llm

# 导入自定义的获取向量存储的函数
from vector_store import get_vector_store

# 定义行业场景适配器基类:用于实现不同行业的query改写策略
class IndustryScenarioAdapter:
    """行业场景适配器基类:定义不同行业的query改写策略"""

    # 初始化适配器,保存llm对象、行业名和改写模板
    def __init__(self, llm: BaseLanguageModel, industry_name: str, rewriting_template: str):
        self.llm = llm
        self.industry_name = industry_name
        # 使用PromptTemplate来封装改写模板
        self.rewriting_template = PromptTemplate(
            input_variables=["query"], template=rewriting_template
        )

    # 行业相关的query改写方法
    def adapt_query(self, query: str) -> str:
        """根据行业特点改写query"""
        # 使用模板格式化query
        prompt = self.rewriting_template.format(query=query)
        # 调用llm生成内容并去除首尾空格
        adapted = self.llm.invoke(prompt).content.strip()
        # 若llm输出不为空则返回,否则返回原query
        return adapted if adapted else query

# 定义教育场景的适配器:用于数学题目表述抽取
class EducationScenarioAdapter(IndustryScenarioAdapter):
    """教育场景适配器:抽取题干,剥离表层信息,保留核心数学关系"""

    # 构造方法,传入llm对象并设置教育场景的改写模板
    def __init__(self, llm: BaseLanguageModel):
        template = (
            "请对以下题目进行题干抽取,剥离人物、情境、故事等表层信息,"
            "只保留关键的数学关系、物理关系或求解目标。\n\n"
            "原题目:{query}\n\n"
            "抽取后的题干(只保留核心关系和求解目标):"
        )
        super().__init__(llm, "教育", template)

# 定义医疗场景适配器:用于口语化表述的专业化转换
class MedicalScenarioAdapter(IndustryScenarioAdapter):
    """医疗场景适配器:将口语化描述转换为专业术语"""

    # 构造方法,传入llm对象并设置医疗场景的改写模板
    def __init__(self, llm: BaseLanguageModel):
        template = (
            "请将以下医疗相关的口语化描述转换为专业术语和标准表达,"
            "便于在医疗知识库中检索。\n\n"
            "用户描述:{query}\n\n"
            "专业术语查询:"
        )
        super().__init__(llm, "医疗", template)

# 定义法律场景适配器:用于法律问题关键词提取
class LegalScenarioAdapter(IndustryScenarioAdapter):
    """法律场景适配器:提取法条关键词和案例要素"""

    # 构造方法,传入llm对象并设置法律场景的改写模板
    def __init__(self, llm: BaseLanguageModel):
        template = (
            "请从以下法律问题中提取关键要素:法条类型、法律关系、争议焦点等,"
            "生成适合法律知识库检索的query。\n\n"
            "法律问题:{query}\n\n"
            "检索query(包含法条类型和关键要素):"
        )
        super().__init__(llm, "法律", template)

# 定义行业场景管理器,用于管理所有行业的适配器
class IndustryScenarioManager:
    """行业场景管理器:管理多个行业的适配器"""

    # 初始化,创建空的适配器字典
    def __init__(self):
        self.adapters: Dict[str, IndustryScenarioAdapter] = {}

    # 注册指定行业及其对应的适配器
    def register_adapter(self, industry: str, adapter: IndustryScenarioAdapter):
        """注册行业适配器"""
        self.adapters[industry] = adapter

    # 获取某行业的适配器,没有则返回None
    def get_adapter(self, industry: str) -> Optional[IndustryScenarioAdapter]:
        """获取指定行业的适配器"""
        return self.adapters.get(industry)

    # 对query进行行业相关改写
    def adapt_query(self, query: str, industry: str) -> str:
        """根据行业改写query"""
        adapter = self.get_adapter(industry)
        # 存在适配器则用适配器的方法,没有则原样返回query
        if adapter:
            return adapter.adapt_query(query)
        return query

# 定义行业场景检索器:用于行业相关的向量搜索检索
class IndustryScenarioRetriever(BaseRetriever):
    """行业场景检索器:根据行业特点进行检索"""

    # 声明成员变量类型:向量库对象
    vector_store: Any
    # 行业场景管理器
    scenario_manager: IndustryScenarioManager
    # 当前所属行业
    industry: str
    # 默认的检索文档数量
    k: int = 4
    # 允许任意类型参数的pydantic配置
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # 获取与query相关的文档(行业适配方案下)
    def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
        """根据行业特点检索相关文档"""
        # 获取检索数量参数,优先用传入值
        k = kwargs.get("k", self.k)
        # 获取行业参数,优先用传入值
        industry = kwargs.get("industry", self.industry)

        # 先对query进行行业相关改写
        adapted_query = self.scenario_manager.adapt_query(query, industry)

        # 用改写后的query在向量库中检索
        docs = self.vector_store.similarity_search_with_score(adapted_query, k=k)

        # 遍历检索结果,每个结果添加附加元数据
        result_docs = []
        for doc, distance in docs:
            result_docs.append(Document(
                page_content=doc.page_content,
                metadata={
                    **doc.metadata,
                    "score": float(distance),
                    "industry": industry,
                    "original_query": query,
                    "adapted_query": adapted_query
                }
            ))

        # 返回封装后的文档对象列表
        return result_docs

# 定义行业场景RAG系统:集成query改写和行业检索
class IndustryScenarioRAG:
    """行业场景RAG系统:支持不同行业的query改写"""

    # 初始化RAG系统,配置检索器、llm和答案输出模板
    def __init__(self, retriever: IndustryScenarioRetriever, llm: BaseLanguageModel, answer_prompt_template: Optional[str] = None):
        self.retriever = retriever
        self.llm = llm

        # 指定答案生成模板,如未指定则使用默认模板
        template = answer_prompt_template or (
            "已知以下相关信息:\n\n{context}\n\n"
            "用户问题:{query}\n"
            "行业场景:{industry}\n"
            "改写后的查询:{adapted_query}\n\n"
            "请根据上述信息,准确全面地回答用户问题。\n\n答案:"
        )
        # 初始化PromptTemplate对象
        self.answer_prompt_template = PromptTemplate(
            input_variables=["context", "query", "industry", "adapted_query"], template=template
        )

    # 生成答案的主入口
    def generate_answer(self, query: str, industry: str, k: int = 4) -> Dict[str, Any]:
        """生成答案"""
        # 检索与query与行业相关的文档
        docs = self.retriever._get_relevant_documents(query, industry=industry, k=k)

        # 提取第一个检索文档的adapted_query,否则使用原query
        adapted_query = docs[0].metadata.get("adapted_query", query) if docs else query

        # 拼接所有检索文档内容作为上下文
        context = "\n\n".join([f"文档 {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs)])

        # 用答案模板组织final prompt
        prompt = self.answer_prompt_template.format(
            context=context,
            query=query,
            industry=industry,
            adapted_query=adapted_query
        )
        # 利用llm生成答案并去掉首尾空白
        answer = self.llm.invoke(prompt).content.strip()

        # 返回相关信息组成的结果字典
        return {
            "query": query,
            "industry": industry,
            "adapted_query": adapted_query,
            "answer": answer,
            "retrieved_documents": docs,
            "num_documents": len(docs)
        }

# =============================
# 以下为初始化与使用示例
# =============================

# 创建行业场景管理器
scenario_manager = IndustryScenarioManager()

# 注册教育、医疗、法律三类行业的适配器
scenario_manager.register_adapter("教育", EducationScenarioAdapter(llm))
scenario_manager.register_adapter("医疗", MedicalScenarioAdapter(llm))
scenario_manager.register_adapter("法律", LegalScenarioAdapter(llm))

# 创建向量存储对象,指定存储目录和集合名称
vector_store = get_vector_store(
    persist_directory="chroma_db",
    collection_name="industry_scenario"
)

# 创建行业场景检索器,指定初始行业为教育,检索返回3条
retriever = IndustryScenarioRetriever(
    vector_store=vector_store,
    scenario_manager=scenario_manager,
    industry="教育",
    k=3
)

# 创建行业场景RAG系统
rag_system = IndustryScenarioRAG(retriever=retriever, llm=llm)

# 构造教育场景用的示例题目文档列表
education_documents = [
    "题目:工人挖桥洞问题。一个工人每天挖3米,挖了5天后还剩2米,问桥洞总长多少米?\n"
    "核心关系:已知每天工作量、工作天数、剩余量,求总量。\n"
    "解题方法:总量 = 每天工作量 × 工作天数 + 剩余量 = 3 × 5 + 2 = 17米",

    "题目:李白买酒问题。李白去买酒,每次买3升,买了4次后还剩1升,问总共需要多少升?\n"
    "核心关系:已知每次购买量、购买次数、剩余量,求总量。\n"
    "解题方法:总量 = 每次购买量 × 购买次数 + 剩余量 = 3 × 4 + 1 = 13升",

    "题目:小明存钱问题。小明每天存5元,存了6天后还剩10元,问总共需要存多少钱?\n"
    "核心关系:已知每天存钱数、存钱天数、剩余量,求总量。\n"
    "解题方法:总量 = 每天存钱数 × 存钱天数 + 剩余量 = 5 × 6 + 10 = 40元",

    "题目:速度时间路程问题。一辆车以60km/h的速度行驶,行驶了3小时后还剩120km,问总路程多少?\n"
    "核心关系:已知速度、时间、剩余路程,求总路程。\n"
    "解题方法:总路程 = 速度 × 时间 + 剩余路程 = 60 × 3 + 120 = 300km"
]

# 索引教育场景的示例文档到向量库
print("正在初始化知识库...")
vector_store.add_texts(
    education_documents,
    metadatas=[{"topic": "数学题目", "industry": "教育"}] * len(education_documents)
)
print("知识库初始化完成!\n")

# ========== 示例1:教育场景 ==========

# 打印分隔线
print("="*60)
print("示例1:教育场景 - 题目检索(题干抽取)")
print("="*60)

# 构造测试题目1,并调用RAG系统生成答案
query1 = "一个建筑工人每天能挖4米深的洞,挖了6天后还差3米就完成了,这个洞总共有多深?"
result1 = rag_system.generate_answer(query1, industry="教育", k=3)
print(f"原题目:{result1['query']}")
print(f"抽取后的题干:{result1['adapted_query']}")
print(f"生成的答案:\n{result1['answer']}")
print(f"检索到 {result1['num_documents']} 个相关文档\n")

# 构造测试题目2,并调用RAG系统生成答案
query2 = "诗人去商店买酒,每次买2瓶,买了5次后发现还需要3瓶,问总共要买多少瓶?"
result2 = rag_system.generate_answer(query2, industry="教育", k=3)
print(f"原题目:{result2['query']}")
print(f"抽取后的题干:{result2['adapted_query']}")
print(f"生成的答案:\n{result2['answer']}")
print(f"检索到 {result2['num_documents']} 个相关文档\n")

# ========== 示例2:展示不同行业的适配能力 ==========

# 打印分隔线
print("\n" + "="*60)
print("示例2:不同行业的query改写策略")
print("="*60)

# 教育场景下的自定义query测试
edu_query = "小明每天跑步5圈,跑了7天后还剩2圈,问总共要跑多少圈?"
edu_adapted = scenario_manager.adapt_query(edu_query, "教育")
print(f"教育场景:")
print(f"  原query:{edu_query}")
print(f"  改写后:{edu_adapted}\n")

# 医疗场景下的自定义query测试
medical_query = "我最近总是头疼,有时候还会恶心,这是什么病?"
medical_adapted = scenario_manager.adapt_query(medical_query, "医疗")
print(f"医疗场景:")
print(f"  原query:{medical_query}")
print(f"  改写后:{medical_adapted}\n")

# 法律场景下的自定义query测试
legal_query = "我在网上买东西,商家不发货也不退钱,我该怎么办?"
legal_adapted = scenario_manager.adapt_query(legal_query, "法律")
print(f"法律场景:")
print(f"  原query:{legal_query}")
print(f"  改写后:{legal_adapted}\n")

8.2 执行流程 #

8.2.1 核心思想 #

行业场景适配采用“行业特定改写”策略:

  • 适配器模式:为不同行业定义专门的查询改写策略
  • 行业管理器:统一管理多个行业的适配器
  • 查询改写:根据行业特点将用户查询改写为适合检索的形式
  • 行业检索:使用改写后的查询在知识库中检索
  • 答案生成:基于检索结果和行业信息生成答案

8.2.2 执行流程 #

阶段一:初始化

# 1. 创建行业场景管理器
scenario_manager = IndustryScenarioManager()

# 2. 注册各行业的适配器
scenario_manager.register_adapter("教育", EducationScenarioAdapter(llm))
scenario_manager.register_adapter("医疗", MedicalScenarioAdapter(llm))
scenario_manager.register_adapter("法律", LegalScenarioAdapter(llm))

# 3. 创建向量存储实例
vector_store = get_vector_store(
    persist_directory="chroma_db",
    collection_name="industry_scenario"
)

# 4. 创建行业场景检索器
retriever = IndustryScenarioRetriever(
    vector_store=vector_store,
    scenario_manager=scenario_manager,
    industry="教育",
    k=3
)

# 5. 创建行业场景RAG系统
rag_system = IndustryScenarioRAG(retriever=retriever, llm=llm)

初始化时:

  • 创建行业场景管理器
  • 为教育、医疗、法律三个行业注册适配器
  • 创建向量存储实例
  • 创建检索器,配置场景管理器和默认行业
  • 创建RAG系统

阶段二:知识库初始化

# 添加教育场景的示例文档到向量库
vector_store.add_texts(
    education_documents,
    metadatas=[{"topic": "数学题目", "industry": "教育"}] * len(education_documents)
)

索引过程:

  • 将教育场景的示例文档添加到向量库
  • 每个文档附带元数据(topic、industry)
  • 文档会被向量化并存储

阶段三:查询处理

query = "一个建筑工人每天能挖4米深的洞,挖了6天后还差3米就完成了,这个洞总共有多深?"
result = rag_system.generate_answer(query, industry="教育", k=3)

完整流程:

  1. 用户提交查询并指定行业
  2. 行业适配改写:
    • 调用 scenario_manager.adapt_query(query, industry)
    • 获取对应行业的适配器
    • 使用适配器的改写模板改写查询
  3. 检索:
    • 使用改写后的查询在向量库中检索
    • 返回 top-k 个相关文档(带分数)
  4. 生成答案:
    • 整合检索文档作为上下文
    • 使用提示词模板构建最终 prompt(包含上下文、原始查询、行业、改写查询)
    • LLM 生成答案
  5. 返回结果:
    • 包含查询、行业、改写查询、答案、检索文档等

8.2.3 类图 #

classDiagram class IndustryScenarioAdapter { <<abstract>> -llm: BaseLanguageModel -industry_name: str -rewriting_template: PromptTemplate +__init__(llm, industry_name, rewriting_template) +adapt_query(query: str) str } class EducationScenarioAdapter { +__init__(llm) } class MedicalScenarioAdapter { +__init__(llm) } class LegalScenarioAdapter { +__init__(llm) } class IndustryScenarioManager { -adapters: Dict[str, IndustryScenarioAdapter] +__init__() +register_adapter(industry: str, adapter: IndustryScenarioAdapter) +get_adapter(industry: str) Optional[IndustryScenarioAdapter] +adapt_query(query: str, industry: str) str } class IndustryScenarioRetriever { -vector_store: Any -scenario_manager: IndustryScenarioManager -industry: str -k: int +_get_relevant_documents(query: str, **kwargs) List[Document] } class IndustryScenarioRAG { -retriever: IndustryScenarioRetriever -llm: BaseLanguageModel -answer_prompt_template: PromptTemplate +__init__(retriever, llm, answer_prompt_template) +generate_answer(query: str, industry: str, k: int) Dict[str, Any] } class BaseRetriever { <<abstract>> +invoke(query: str) List[Document] } class PromptTemplate { +format(**kwargs) str } class Document { +page_content: str +metadata: Dict } class VectorStore { <<interface>> +add_texts(texts: List[str], metadatas: List[Dict]) +similarity_search_with_score(query: str, k: int) List[Tuple[Document, float]] } class BaseLanguageModel { <<interface>> +invoke(prompt: str) AIMessage } EducationScenarioAdapter --|> IndustryScenarioAdapter MedicalScenarioAdapter --|> IndustryScenarioAdapter LegalScenarioAdapter --|> IndustryScenarioAdapter IndustryScenarioRetriever --|> BaseRetriever IndustryScenarioRAG --> IndustryScenarioRetriever IndustryScenarioRAG --> BaseLanguageModel IndustryScenarioRetriever --> IndustryScenarioManager IndustryScenarioManager --> IndustryScenarioAdapter IndustryScenarioAdapter --> BaseLanguageModel IndustryScenarioAdapter --> PromptTemplate IndustryScenarioRetriever --> VectorStore IndustryScenarioRetriever ..> Document : creates VectorStore ..> Document : returns

8.2.4 时序图 #

8.2.4.1 完整RAG流程时序图 #
sequenceDiagram participant User as 用户 participant RAG as IndustryScenarioRAG participant Retriever as IndustryScenarioRetriever participant Manager as IndustryScenarioManager participant Adapter as IndustryScenarioAdapter participant VectorStore as VectorStore participant LLM as BaseLanguageModel User->>RAG: generate_answer("一个建筑工人每天能挖4米...", industry="教育", k=3) RAG->>Retriever: _get_relevant_documents(query, industry="教育", k=3) Note over Retriever: 步骤1: 行业适配改写 Retriever->>Manager: adapt_query(query, industry="教育") Manager->>Manager: get_adapter("教育") Manager-->>Retriever: 返回EducationScenarioAdapter Retriever->>Adapter: adapt_query(query) Adapter->>Adapter: 使用教育场景改写模板<br/>"请对以下题目进行题干抽取..." Adapter->>PromptTemplate: format(query=原始查询) PromptTemplate-->>Adapter: 返回完整prompt Adapter->>LLM: invoke(prompt) Note over LLM: 抽取题干,剥离表层信息<br/>保留核心数学关系 LLM-->>Adapter: 返回改写后的查询 Adapter-->>Retriever: 返回改写查询 Note over Retriever: 步骤2: 检索 Retriever->>VectorStore: similarity_search_with_score(改写查询, k=3) Note over VectorStore: 使用改写后的查询<br/>在知识库中检索 VectorStore-->>Retriever: 返回3个相关文档(带分数) Retriever->>Retriever: 构建Document对象<br/>添加元数据:<br/>- industry: 行业<br/>- original_query: 原始查询<br/>- adapted_query: 改写查询 Retriever-->>RAG: 返回检索文档列表 Note over RAG: 步骤3: 生成答案 RAG->>RAG: 提取改写查询 RAG->>RAG: 整合文档内容为上下文 RAG->>PromptTemplate: format(context, query, industry, adapted_query) PromptTemplate-->>RAG: 返回完整prompt RAG->>LLM: invoke(prompt) Note over LLM: 基于检索结果和行业信息生成答案 LLM-->>RAG: 返回答案内容 RAG-->>User: 返回结果字典<br/>(query, industry, adapted_query, answer, documents)
8.2.4.2 行业适配改写详细流程 #
sequenceDiagram participant Retriever as IndustryScenarioRetriever participant Manager as IndustryScenarioManager participant Adapter as IndustryScenarioAdapter participant Template as PromptTemplate participant LLM as BaseLanguageModel Note over Retriever: adapt_query(query, industry) Retriever->>Manager: adapt_query("一个建筑工人每天能挖4米...", "教育") Manager->>Manager: get_adapter("教育") Manager->>Manager: 从adapters字典获取适配器 alt 适配器存在 Manager-->>Retriever: 返回EducationScenarioAdapter Retriever->>Adapter: adapt_query(query) Adapter->>Adapter: 获取教育场景改写模板 Note over Adapter: 模板: "请对以下题目进行题干抽取,<br/>剥离人物、情境、故事等表层信息,<br/>只保留关键的数学关系..." Adapter->>Template: format(query=原始查询) Template-->>Adapter: 返回完整prompt Adapter->>LLM: invoke(prompt) Note over LLM: 根据行业特点改写查询<br/>教育场景: 抽取题干,保留核心关系 LLM-->>Adapter: 返回改写后的查询 alt 改写结果为空 Adapter-->>Retriever: 返回原始查询(降级处理) else 改写结果非空 Adapter-->>Retriever: 返回改写后的查询 end else 适配器不存在 Manager-->>Retriever: 返回原始查询(无适配器) end
8.2.4.3 不同行业适配器示例流程 #
sequenceDiagram participant Manager as IndustryScenarioManager participant EduAdapter as EducationScenarioAdapter participant MedAdapter as MedicalScenarioAdapter participant LegAdapter as LegalScenarioAdapter participant LLM as BaseLanguageModel Note over Manager: 教育场景适配 Manager->>EduAdapter: adapt_query("小明每天跑步5圈,跑了7天后还剩2圈...") EduAdapter->>LLM: invoke(教育场景改写模板) Note over LLM: 抽取题干,剥离表层信息<br/>保留核心数学关系 LLM-->>EduAdapter: "已知每天数量、天数、剩余量,求总量" EduAdapter-->>Manager: 返回改写查询 Note over Manager: 医疗场景适配 Manager->>MedAdapter: adapt_query("我最近总是头疼,有时候还会恶心...") MedAdapter->>LLM: invoke(医疗场景改写模板) Note over LLM: 将口语化描述转换为专业术语 LLM-->>MedAdapter: "头痛、恶心症状,可能疾病诊断" MedAdapter-->>Manager: 返回改写查询 Note over Manager: 法律场景适配 Manager->>LegAdapter: adapt_query("我在网上买东西,商家不发货也不退钱...") LegAdapter->>LLM: invoke(法律场景改写模板) Note over LLM: 提取法条类型、法律关系、争议焦点 LLM-->>LegAdapter: "网络购物合同纠纷,商家违约,消费者权益保护" LegAdapter-->>Manager: 返回改写查询

8.2.5 关键设计要点 #

1. 行业适配流程

用户查询 + 行业标识
    ↓
获取行业适配器
    ↓
使用行业特定模板改写查询
    ↓
使用改写后的查询检索
    ↓
生成答案 (包含行业信息)

2. 适配器模式设计

  • 基类:IndustryScenarioAdapter
    • 定义通用的 adapt_query() 方法
    • 使用模板和 LLM 进行查询改写
  • 子类:具体行业适配器
    • EducationScenarioAdapter:教育场景
    • MedicalScenarioAdapter:医疗场景
    • LegalScenarioAdapter:法律场景

3. 行业改写策略示例

教育场景:

原始查询: "一个建筑工人每天能挖4米深的洞,挖了6天后还差3米就完成了,这个洞总共有多深?"

改写后: "已知每天工作量、工作天数、剩余量,求总量"

医疗场景:

原始查询: "我最近总是头疼,有时候还会恶心,这是什么病?"

改写后: "头痛、恶心症状,可能疾病诊断"

法律场景:

原始查询: "我在网上买东西,商家不发货也不退钱,我该怎么办?"

改写后: "网络购物合同纠纷,商家违约,消费者权益保护"

4. 行业场景管理器

  • 适配器注册:使用字典存储行业到适配器的映射
  • 适配器获取:根据行业名称获取对应的适配器
  • 统一接口:提供 adapt_query() 方法统一处理

管理器结构:

adapters = {
    "教育": EducationScenarioAdapter,
    "医疗": MedicalScenarioAdapter,
    "法律": LegalScenarioAdapter
}

5. 元数据设计

检索返回的 Document 对象包含:

  • score:相似度分数(距离,越小越相似)
  • industry:行业标识
  • original_query:原始用户查询
  • adapted_query:改写后的查询
  • 继承原始文档的元数据(topic、industry 等)

6. 答案生成模板

默认答案生成模板:

已知以下相关信息:

{context}

用户问题:{query}
行业场景:{industry}
改写后的查询:{adapted_query}

请根据上述信息,准确全面地回答用户问题。

答案:

该模板:

  • 包含所有检索文档作为上下文
  • 明确列出原始查询、行业和改写查询
  • 要求准确全面地回答用户问题

7. 优势与应用场景

优势:

  • 行业定制:不同行业使用专门的改写策略
  • 可扩展性:易于添加新的行业适配器
  • 提高精度:行业特定的改写提高检索匹配率
  • 灵活配置:支持动态切换行业

适用场景:

  • 教育领域:数学题目、物理题目等需要抽取核心关系
  • 医疗领域:口语化症状描述转换为专业术语
  • 法律领域:法律问题提取关键要素和法条类型
  • 多行业系统:需要支持多个行业的统一系统

8. 技术细节

  • 适配器注册:
    • 使用字典存储,键为行业名称,值为适配器实例
    • 支持动态注册新的适配器
  • 查询改写:
    • 使用行业特定的提示词模板
    • 如果改写结果为空,降级为原始查询
  • 行业切换:
    • 支持在查询时动态指定行业
    • 检索器可以设置默认行业

9. 扩展性设计

添加新行业的步骤:

  1. 创建新的适配器类,继承 IndustryScenarioAdapter
  2. 定义行业特定的改写模板
  3. 在管理器中注册新适配器

示例:

class FinanceScenarioAdapter(IndustryScenarioAdapter):
    def __init__(self, llm: BaseLanguageModel):
        template = "请将以下金融问题转换为专业术语..."
        super().__init__(llm, "金融", template)

# 注册新适配器
scenario_manager.register_adapter("金融", FinanceScenarioAdapter(llm))

该设计通过“适配器模式+行业管理器”策略,在多行业场景下提供更精确的查询改写和检索,适用于需要支持多个行业的 RAG 应用。

9. Text2SQL #

本节介绍如何实现文本到SQL(Text2SQL)任务,即将自然语言的问题转化为对应的SQL查询语句。 Text2SQL技术在数据库问答、智能客服、数据分析等场景中有着广泛的应用。其基本流程通常包括:

  1. 理解用户意图:分析自然语言中表述的查询需求,例如检索条件、目标表、要查询的字段等。
  2. 槽位抽取:从用户问题中结构化地提取出相关的实体、属性和条件(如部门、日期范围、数值过滤等),为SQL拼接做准备。
  3. SQL模板生成:根据提取的信息结合数据库表结构,动态生成对应的标准SQL语句。
  4. 执行与结果返回:将生成的SQL语句提交到数据库执行,并将查询结果以易懂的方式反馈给用户。

9.1. Text2SQL.py #

# 导入类型注解相关的类型List、Dict、Any、Optional、Tuple
from typing import List, Dict, Any, Optional, Tuple

# 导入pydantic中用于配置数据模型的ConfigDict
from pydantic import ConfigDict

# 导入langchain核心模块中的基础大模型类型
from langchain_core.language_models import BaseLanguageModel

# 导入langchain核心模块中的提示模板类
from langchain_core.prompts import PromptTemplate

# 导入自定义的llm对象
from llm import llm

# 导入json模块和re模块用于数据解析
import json
import re
# 导入sqlite3用于数据库操作
import sqlite3

# 定义槽位提取器类,用于从自然语言中提取结构化条件
class SlotExtractor:
    """槽位提取器:从自然语言中提取结构化条件"""
    # 初始化方法,llm为大模型对象,slot_schema为槽位定义
    def __init__(self, llm: BaseLanguageModel, slot_schema: Dict[str, Any]):
        # 保存传入的大模型对象
        self.llm = llm
        # 保存槽位定义
        self.slot_schema = slot_schema
        # 构建可用槽位描述文本
        schema_text = "\n".join([f"- {slot}: {desc}" for slot, desc in slot_schema.items()])
        # 构建llm提示模板
        template = (
            "请从以下用户查询中提取结构化条件,以JSON格式返回。\n\n"
            "可用槽位:\n{schema}\n\n"
            "用户查询:{query}\n\n"
            "请提取所有相关槽位,返回JSON格式,例如:"
            '{{"position": "Java开发", "experience_years": 3, "max_salary": 20000, "work_mode": "远程"}}\n'
            "如果某个槽位未提及,则不要包含在JSON中。只返回JSON,不要其他内容:"
        )
        # 实例化PromptTemplate
        self.extraction_template = PromptTemplate(
            input_variables=["query", "schema"], template=template
        )

    # 从用户输入中提取槽位
    def extract_slots(self, query: str) -> Dict[str, Any]:
        """从查询中提取槽位"""
        # 构造槽位schema的文本
        schema_text = "\n".join([f"- {slot}: {desc}" for slot, desc in self.slot_schema.items()])
        # 填充prompt模板
        prompt = self.extraction_template.format(query=query, schema=schema_text)
        # 使用大模型得到输出内容
        response = self.llm.invoke(prompt).content.strip()

        try:
            # 用正则提取JSON片段
            json_match = re.search(r'\{[^}]+\}', response, re.DOTALL)
            if json_match:
                # 将JSON解析为字典
                slots = json.loads(json_match.group())
                return slots
        except:
            # 捕获异常,解析失败则跳过
            pass

        # 未能解析则返回空字典
        return {}

# 定义对话状态管理器类
class DialogueStateManager:
    """对话状态管理器:管理多轮对话中的槽位状态"""

    # 初始化,保存当前槽位和对话历史
    def __init__(self):
        # 初始化当前槽位为空
        self.slots: Dict[str, Any] = {}
        # 初始化对话历史为空
        self.history: List[Dict[str, str]] = []

    # 更新槽位状态,新加槽位覆盖旧槽位
    def update_slots(self, new_slots: Dict[str, Any], query: str) -> Dict[str, Any]:
        """更新槽位状态,新槽位会覆盖旧槽位"""
        # 记录本轮执行前的历史
        self.history.append({"query": query, "slots_before": self.slots.copy()})

        # 更新槽位字典,新值覆盖老值
        self.slots.update(new_slots)

        # 记录更新后的槽位状态
        self.history[-1]["slots_after"] = self.slots.copy()

        # 返回最新的槽位状态副本
        return self.slots.copy()

    # 获取当前槽位状态
    def get_slots(self) -> Dict[str, Any]:
        """获取当前槽位状态"""
        return self.slots.copy()

    # 清空所有槽位和对话历史
    def clear(self):
        """清空状态"""
        self.slots = {}
        self.history = []

# 定义SQL生成器类,根据槽位生成SQL语句
class SQLGenerator:
    """SQL生成器:根据槽位生成SQL语句(使用参数化查询)"""

    # 初始化,指定要查询的表名
    def __init__(self, table_name: str = "candidates"):
        # 保存表名
        self.table_name = table_name

    # 根据输入的槽位生成SQL与参数
    def generate_sql(self, slots: Dict[str, Any]) -> Tuple[str, List[Any]]:
        """根据槽位生成SQL查询语句和参数,返回(SQL语句, 参数列表)"""
        # 条件列表
        conditions = []
        # 参数列表
        params = []

        # 若槽位有岗位,根据岗位添加模糊查询
        if "position" in slots:
            pos = slots['position']
            conditions.append("position LIKE ?")
            params.append(f"%{pos}%")

        # 若槽位有工作经验,添加经验过滤
        if "experience_years" in slots:
            exp = slots['experience_years']
            conditions.append("experience_years >= ?")
            params.append(exp)

        # 若槽位有最低期望薪资,添加对应条件
        if "min_salary" in slots:
            min_sal = slots['min_salary']
            conditions.append("expected_salary >= ?")
            params.append(min_sal)

        # 若槽位有最高期望薪资,添加对应条件
        if "max_salary" in slots:
            max_sal = slots['max_salary']
            conditions.append("expected_salary <= ?")
            params.append(max_sal)

        # 若槽位有工作模式,添加对应条件
        if "work_mode" in slots:
            mode = slots['work_mode']
            conditions.append("work_mode = ?")
            params.append(mode)

        # 拼接WHERE子句,没有即为1=1(查全部)
        where_clause = " AND ".join(conditions) if conditions else "1=1"
        # 拼装最终SQL语句
        sql = f"SELECT * FROM {self.table_name} WHERE {where_clause}"

        # 返回SQL和参数
        return sql, params

# 定义SQLite数据库访问类
class SQLiteDatabase:
    """SQLite数据库:使用真实数据库进行查询"""

    # 初始化,指定数据库文件及表名
    def __init__(self, db_path: str = "candidates.db", table_name: str = "candidates"):
        # 保存数据库路径
        self.db_path = db_path
        # 保存表名
        self.table_name = table_name
        # 初始化数据库及表结构
        self._init_database()

    # 初始化数据库和表,插入样例数据
    def _init_database(self):
        """初始化数据库,创建表并插入示例数据"""
        # 连接数据库
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        # 创建候选人表
        cursor.execute(f"""
            CREATE TABLE IF NOT EXISTS {self.table_name} (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                position TEXT NOT NULL,
                experience_years INTEGER NOT NULL,
                expected_salary INTEGER NOT NULL,
                work_mode TEXT NOT NULL
            )
        """)

        # 检查表数据数量
        cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")
        count = cursor.fetchone()[0]

        # 若表为空则插入样例数据
        if count == 0:
            sample_data = [
                ("张三", "Java开发", 5, 18000, "远程"),
                ("李四", "Java开发", 3, 15000, "远程"),
                ("王五", "Java开发", 4, 22000, "现场"),
                ("赵六", "Python开发", 3, 19000, "远程"),
                ("钱七", "Java开发", 2, 12000, "远程"),
            ]
            cursor.executemany(
                f"INSERT INTO {self.table_name} (name, position, experience_years, expected_salary, work_mode) VALUES (?, ?, ?, ?, ?)",
                sample_data
            )

        # 提交并关闭连接
        conn.commit()
        conn.close()

    # 执行SQL查询,返回字典列表,支持参数化查询
    def execute_sql(self, sql: str, params: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
        """执行SQL查询,返回结果列表(支持参数化查询)"""
        # 连接数据库
        conn = sqlite3.connect(self.db_path)
        # 设置行工厂为Row对象,实现结果为字典输出
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()

        try:
            # 执行SQL,带参数则传参数
            if params:
                cursor.execute(sql, params)
            else:
                cursor.execute(sql)
            # 获取所有行数据
            rows = cursor.fetchall()
            # 每一行转为字典,组合为列表
            results = [dict(row) for row in rows]
            return results
        except Exception as e:
            # 打印错误信息,如有错误返回空列表
            print(f"SQL执行错误: {e}")
            return []
        finally:
            # 关闭数据库连接
            conn.close()

# 定义Text2SQL检索增强生成(RAG)系统
class Text2SQLRAG:
    """Text2SQL RAG系统:将自然语言转换为SQL查询"""

    # 初始化方法,配置llm对象、槽位定义、表名、数据库位置
    def __init__(self, llm: BaseLanguageModel, slot_schema: Dict[str, Any], table_name: str = "candidates", db_path: str = "candidates.db"):
        # 保存llm对象
        self.llm = llm
        # 创建槽位提取器
        self.slot_extractor = SlotExtractor(llm, slot_schema)
        # 创建对话状态管理器
        self.state_manager = DialogueStateManager()
        # 创建SQL生成器
        self.sql_generator = SQLGenerator(table_name)
        # 创建SQLite数据库对象
        self.database = SQLiteDatabase(db_path, table_name)

        # 定义大模型答案生成模板
        self.answer_template = PromptTemplate(
            input_variables=["query", "sql", "results"],
            template=(
                "用户查询:{query}\n\n"
                "生成的SQL:{sql}\n\n"
                "查询结果:{results}\n\n"
                "请根据查询结果,用自然语言回答用户的问题。\n\n答案:"
            )
        )

    # 处理一次自然语言用户查询,返回结构化结果
    def query(self, user_query: str) -> Dict[str, Any]:
        """处理用户查询"""
        # 槽位提取
        new_slots = self.slot_extractor.extract_slots(user_query)

        # 用新槽位更新对话状态
        updated_slots = self.state_manager.update_slots(new_slots, user_query)

        # 根据当前槽位生成SQL语句和参数
        sql, params = self.sql_generator.generate_sql(updated_slots)

        # 执行SQL查询,获取结果
        results = self.database.execute_sql(sql, params)

        # 将结果格式化为简明文本
        results_text = "\n".join([f"- {r['name']}: {r['position']}, {r['experience_years']}年经验, 期望薪资{r['expected_salary']}元, {r['work_mode']}" for r in results])
        if not results_text:
            results_text = "未找到符合条件的候选人"

        # 使用大模型生成问答输出内容
        prompt = self.answer_template.format(
            query=user_query,
            sql=sql,
            results=results_text
        )
        answer = self.llm.invoke(prompt).content.strip()

        # 返回完整结构化结果
        return {
            "query": user_query,
            "extracted_slots": new_slots,
            "current_slots": updated_slots,
            "sql": sql,
            "results": results,
            "answer": answer
        }

    # 清空当前对话状态
    def clear_state(self):
        """清空对话状态"""
        self.state_manager.clear()

# 定义槽位模式,用于文本提取与SQL拼接
slot_schema = {
    "position": "岗位名称,如Java开发、Python开发等",
    "experience_years": "工作经验年限(数字)",
    "min_salary": "最低期望薪资(数字,单位:元)",
    "max_salary": "最高期望薪资(数字,单位:元)",
    "work_mode": "工作方式,如远程、现场、混合等"
}

# 实例化Text2SQL RAG系统对象
rag_system = Text2SQLRAG(llm, slot_schema, table_name="candidates")

# ========== 示例:多轮对话场景 ==========
# 打印分隔线和案例标题
print("="*60)
print("示例:多轮对话 - 逐步补充条件")
print("="*60)

# 第一轮对话:只给经验和岗位
print("\n【第一轮】")
result1 = rag_system.query("请帮我查找有三年以上Java开发经验的候选人")
print(f"用户查询:{result1['query']}")
print(f"提取的槽位:{result1['extracted_slots']}")
print(f"生成的SQL:{result1['sql']}")
result1_count = len(result1['results'])
print(f"查询结果数量:{result1_count}")
print(f"AI回答:{result1['answer']}\n")

# 第二轮对话:补充薪资条件
print("\n【第二轮】")
result2 = rag_system.query("期望薪资不超过2万元")
print(f"用户查询:{result2['query']}")
print(f"提取的槽位:{result2['extracted_slots']}")
print(f"当前所有槽位:{result2['current_slots']}")
print(f"生成的SQL:{result2['sql']}")
print(f"查询结果数量:{len(result2['results'])}")
print(f"AI回答:{result2['answer']}\n")

# 第三轮对话:补充工作方式
print("\n【第三轮】")
result3 = rag_system.query("必须能远程办公")
print(f"用户查询:{result3['query']}")
print(f"提取的槽位:{result3['extracted_slots']}")
print(f"当前所有槽位:{result3['current_slots']}")
print(f"生成的SQL:{result3['sql']}")
result3_count = len(result3['results'])
print(f"查询结果数量:{result3_count}")
print(f"AI回答:{result3['answer']}\n")

# ========== 示例:多条件组合查询 ==========
print("\n" + "="*60)
print("示例:多条件组合查询")
print("="*60)

# 清空会话,开始新一轮多条件对话
rag_system.clear_state()

# 第一轮:只设薪资上限
print("\n【第一轮】")
result1 = rag_system.query("期望薪资不超过2万")
print(f"用户查询:{result1['query']}")
print(f"提取的槽位:{result1['extracted_slots']}")
print(f"生成的SQL:{result1['sql']}")
print(f"查询结果数量:{len(result1['results'])}")
print(f"AI回答:{result1['answer']}\n")

# 第二轮:补充薪资下限
print("\n【第二轮】")
result2 = rag_system.query("有没有薪资高于1.5万的?")
print(f"用户查询:{result2['query']}")
print(f"提取的槽位:{result2['extracted_slots']}")
print(f"当前所有槽位:{result2['current_slots']}")
print(f"生成的SQL:{result2['sql']}")
print(f"查询结果数量:{len(result2['results'])}")
print(f"AI回答:{result2['answer']}\n")

9.2 执行过程 #

9.2.1 核心思想 #

Text2SQL 采用“槽位提取+SQL生成”策略:

  • 槽位提取:从自然语言中提取结构化条件(槽位)
  • 对话状态管理:管理多轮对话中的槽位状态,支持逐步补充条件
  • SQL生成:根据槽位生成SQL查询语句(使用参数化查询防止SQL注入)
  • 数据库查询:使用SQLite数据库执行查询
  • 答案生成:基于查询结果生成自然语言答案

9.2.2 执行流程 #

阶段一:初始化

# 1. 定义槽位模式
slot_schema = {
    "position": "岗位名称,如Java开发、Python开发等",
    "experience_years": "工作经验年限(数字)",
    "min_salary": "最低期望薪资(数字,单位:元)",
    "max_salary": "最高期望薪资(数字,单位:元)",
    "work_mode": "工作方式,如远程、现场、混合等"
}

# 2. 创建Text2SQL RAG系统
rag_system = Text2SQLRAG(llm, slot_schema, table_name="candidates")

初始化时:

  • 定义槽位模式,描述每个槽位的含义
  • 创建Text2SQL RAG系统,内部会创建:
    • 槽位提取器(SlotExtractor)
    • 对话状态管理器(DialogueStateManager)
    • SQL生成器(SQLGenerator)
    • SQLite数据库对象(SQLiteDatabase)

阶段二:数据库初始化

# SQLiteDatabase初始化时自动执行
database = SQLiteDatabase(db_path="candidates.db", table_name="candidates")

数据库初始化:

  • 自动创建数据库文件和表结构
  • 如果表为空,自动插入示例数据(5个候选人记录)

阶段三:多轮对话处理

# 第一轮对话
result1 = rag_system.query("请帮我查找有三年以上Java开发经验的候选人")

# 第二轮对话(补充条件)
result2 = rag_system.query("期望薪资不超过2万元")

# 第三轮对话(继续补充条件)
result3 = rag_system.query("必须能远程办公")

完整流程(每轮对话):

  1. 用户提交自然语言查询
  2. 槽位提取:
    • 调用 slot_extractor.extract_slots(query)
    • 使用LLM从查询中提取结构化槽位
    • 返回JSON格式的槽位字典
  3. 状态更新:
    • 调用 state_manager.update_slots(new_slots, query)
    • 将新槽位合并到当前槽位状态
    • 记录对话历史
  4. SQL生成:
    • 调用 sql_generator.generate_sql(updated_slots)
    • 根据槽位生成SQL语句和参数列表(参数化查询)
  5. 数据库查询:
    • 调用 database.execute_sql(sql, params)
    • 执行SQL查询,返回结果列表
  6. 结果格式化:
    • 将查询结果格式化为文本
  7. 答案生成:
    • 使用提示词模板构建最终prompt
    • LLM生成自然语言答案
  8. 返回结果:
    • 包含查询、提取的槽位、当前槽位、SQL、结果、答案等

9.2.3 类图 #

classDiagram class SlotExtractor { -llm: BaseLanguageModel -slot_schema: Dict[str, Any] -extraction_template: PromptTemplate +__init__(llm, slot_schema) +extract_slots(query: str) Dict[str, Any] } class DialogueStateManager { -slots: Dict[str, Any] -history: List[Dict[str, str]] +__init__() +update_slots(new_slots: Dict[str, Any], query: str) Dict[str, Any] +get_slots() Dict[str, Any] +clear() } class SQLGenerator { -table_name: str +__init__(table_name) +generate_sql(slots: Dict[str, Any]) Tuple[str, List[Any]] } class SQLiteDatabase { -db_path: str -table_name: str +__init__(db_path, table_name) -_init_database() +execute_sql(sql: str, params: Optional[List[Any]]) List[Dict[str, Any]] } class Text2SQLRAG { -llm: BaseLanguageModel -slot_extractor: SlotExtractor -state_manager: DialogueStateManager -sql_generator: SQLGenerator -database: SQLiteDatabase -answer_template: PromptTemplate +__init__(llm, slot_schema, table_name, db_path) +query(user_query: str) Dict[str, Any] +clear_state() } class PromptTemplate { +format(**kwargs) str } class BaseLanguageModel { <<interface>> +invoke(prompt: str) AIMessage } Text2SQLRAG --> SlotExtractor Text2SQLRAG --> DialogueStateManager Text2SQLRAG --> SQLGenerator Text2SQLRAG --> SQLiteDatabase Text2SQLRAG --> BaseLanguageModel SlotExtractor --> BaseLanguageModel SlotExtractor --> PromptTemplate Text2SQLRAG --> PromptTemplate

9.2.4 时序图 #

9.2.4.1 完整多轮对话流程时序图 #
sequenceDiagram participant User as 用户 participant RAG as Text2SQLRAG participant Extractor as SlotExtractor participant StateManager as DialogueStateManager participant SQLGen as SQLGenerator participant Database as SQLiteDatabase participant LLM as BaseLanguageModel Note over User,RAG: 第一轮对话 User->>RAG: query("请帮我查找有三年以上Java开发经验的候选人") Note over RAG: 步骤1: 槽位提取 RAG->>Extractor: extract_slots(query) Extractor->>Extractor: 构建槽位提取prompt Extractor->>LLM: invoke(prompt) Note over LLM: 从查询中提取结构化槽位<br/>返回JSON格式 LLM-->>Extractor: 返回JSON文本 Extractor->>Extractor: 解析JSON<br/>(使用正则提取) Extractor-->>RAG: 返回槽位字典<br/>{"position": "Java开发", "experience_years": 3} Note over RAG: 步骤2: 状态更新 RAG->>StateManager: update_slots(new_slots, query) StateManager->>StateManager: 记录历史(slots_before) StateManager->>StateManager: 更新槽位(合并新槽位) StateManager->>StateManager: 记录历史(slots_after) StateManager-->>RAG: 返回更新后的槽位 Note over RAG: 步骤3: SQL生成 RAG->>SQLGen: generate_sql(updated_slots) SQLGen->>SQLGen: 根据槽位构建WHERE条件<br/>使用参数化查询 SQLGen-->>RAG: 返回(SQL语句, 参数列表) Note over RAG: 步骤4: 数据库查询 RAG->>Database: execute_sql(sql, params) Database->>Database: 连接SQLite数据库 Database->>Database: 执行参数化查询 Database-->>RAG: 返回查询结果列表 Note over RAG: 步骤5: 答案生成 RAG->>RAG: 格式化查询结果为文本 RAG->>PromptTemplate: format(query, sql, results) PromptTemplate-->>RAG: 返回完整prompt RAG->>LLM: invoke(prompt) Note over LLM: 基于查询结果生成自然语言答案 LLM-->>RAG: 返回答案内容 RAG-->>User: 返回结果字典 Note over User,RAG: 第二轮对话(补充条件) User->>RAG: query("期望薪资不超过2万元") RAG->>Extractor: extract_slots(query) Extractor->>LLM: invoke(prompt) LLM-->>Extractor: 返回槽位<br/>{"max_salary": 20000} Extractor-->>RAG: 返回新槽位 RAG->>StateManager: update_slots(new_slots, query) Note over StateManager: 合并槽位:<br/>保留第一轮的槽位<br/>+ 添加第二轮的槽位 StateManager-->>RAG: 返回合并后的槽位 RAG->>SQLGen: generate_sql(updated_slots) Note over SQLGen: 生成包含所有条件的SQL<br/>position AND experience_years AND max_salary SQLGen-->>RAG: 返回SQL和参数 RAG->>Database: execute_sql(sql, params) Database-->>RAG: 返回筛选后的结果 RAG->>LLM: invoke(prompt) LLM-->>RAG: 返回答案 RAG-->>User: 返回结果
9.2.4.2 槽位提取详细流程 #
sequenceDiagram participant RAG as Text2SQLRAG participant Extractor as SlotExtractor participant Template as PromptTemplate participant LLM as BaseLanguageModel Note over RAG: extract_slots(query) RAG->>Extractor: extract_slots("请帮我查找有三年以上Java开发经验的候选人") Extractor->>Extractor: 构建槽位schema文本<br/>"- position: 岗位名称...\n- experience_years: 工作经验年限..." Extractor->>Template: format(query=用户查询, schema=槽位schema) Template-->>Extractor: 返回完整prompt Extractor->>LLM: invoke(prompt) Note over LLM: 提取结构化槽位<br/>返回JSON格式<br/>{"position": "Java开发", "experience_years": 3} LLM-->>Extractor: 返回JSON文本 Extractor->>Extractor: 使用正则提取JSON片段<br/>re.search(r'\{[^}]+\}', response) alt JSON解析成功 Extractor->>Extractor: json.loads(json_match.group()) Extractor-->>RAG: 返回槽位字典 else JSON解析失败 Extractor-->>RAG: 返回空字典(降级处理) end
9.2.4.3 SQL生成与数据库查询详细流程 #
sequenceDiagram participant RAG as Text2SQLRAG participant SQLGen as SQLGenerator participant Database as SQLiteDatabase participant SQLite as SQLite数据库 Note over RAG: generate_sql(updated_slots) RAG->>SQLGen: generate_sql({"position": "Java开发", "experience_years": 3, "max_salary": 20000}) SQLGen->>SQLGen: 遍历槽位,构建条件列表 alt 有position槽位 SQLGen->>SQLGen: conditions.append("position LIKE ?")<br/>params.append("%Java开发%") end alt 有experience_years槽位 SQLGen->>SQLGen: conditions.append("experience_years >= ?")<br/>params.append(3) end alt 有max_salary槽位 SQLGen->>SQLGen: conditions.append("expected_salary <= ?")<br/>params.append(20000) end SQLGen->>SQLGen: 拼接WHERE子句<br/>"position LIKE ? AND experience_years >= ? AND expected_salary <= ?" SQLGen->>SQLGen: 拼装完整SQL<br/>"SELECT * FROM candidates WHERE ..." SQLGen-->>RAG: 返回(SQL语句, 参数列表) RAG->>Database: execute_sql(sql, params) Database->>SQLite: connect(db_path) Database->>SQLite: 设置row_factory为Row对象 Database->>SQLite: execute(sql, params) Note over SQLite: 执行参数化查询<br/>防止SQL注入 SQLite-->>Database: 返回查询结果行 Database->>Database: 将Row对象转换为字典列表 Database-->>RAG: 返回结果列表<br/>[{"id": 1, "name": "张三", ...}, ...]
9.2.4.4 多轮对话状态管理流程 #
sequenceDiagram participant RAG as Text2SQLRAG participant StateManager as DialogueStateManager Note over RAG: 第一轮对话 RAG->>StateManager: update_slots({"position": "Java开发", "experience_years": 3}, query1) StateManager->>StateManager: 记录历史<br/>history.append({"query": query1, "slots_before": {}}) StateManager->>StateManager: 更新槽位<br/>slots.update(new_slots) StateManager->>StateManager: 记录历史<br/>history[-1]["slots_after"] = slots StateManager-->>RAG: 返回{"position": "Java开发", "experience_years": 3} Note over RAG: 第二轮对话(补充条件) RAG->>StateManager: update_slots({"max_salary": 20000}, query2) StateManager->>StateManager: 记录历史<br/>history.append({"query": query2, "slots_before": {"position": "Java开发", "experience_years": 3}}) StateManager->>StateManager: 更新槽位<br/>slots.update({"max_salary": 20000}) Note over StateManager: 合并后的槽位:<br/>{"position": "Java开发", "experience_years": 3, "max_salary": 20000} StateManager->>StateManager: 记录历史<br/>history[-1]["slots_after"] = slots StateManager-->>RAG: 返回合并后的槽位 Note over RAG: 第三轮对话(继续补充) RAG->>StateManager: update_slots({"work_mode": "远程"}, query3) StateManager->>StateManager: 更新槽位(合并) Note over StateManager: 最终槽位:<br/>{"position": "Java开发", "experience_years": 3, "max_salary": 20000, "work_mode": "远程"} StateManager-->>RAG: 返回最终槽位

9.2.5 关键设计要点 #

1. Text2SQL流程

自然语言查询
    ↓
槽位提取 (LLM)
    ↓
状态更新 (合并槽位)
    ↓
SQL生成 (参数化查询)
    ↓
数据库查询 (SQLite)
    ↓
结果格式化
    ↓
答案生成 (LLM)

2. 槽位提取示例

用户查询: "请帮我查找有三年以上Java开发经验的候选人"

提取的槽位: {
    "position": "Java开发",
    "experience_years": 3
}

多轮对话槽位累积:

第一轮: {"position": "Java开发", "experience_years": 3}
第二轮: + {"max_salary": 20000}
第三轮: + {"work_mode": "远程"}

最终槽位: {
    "position": "Java开发",
    "experience_years": 3,
    "max_salary": 20000,
    "work_mode": "远程"
}

3. SQL生成策略

  • 参数化查询:使用 ? 占位符,防止SQL注入
  • 条件构建:根据槽位动态构建WHERE条件
  • 参数列表:与SQL语句分离,安全传递参数

SQL生成示例:

-- 槽位: {"position": "Java开发", "experience_years": 3, "max_salary": 20000}

-- 生成的SQL:
SELECT * FROM candidates 
WHERE position LIKE ? 
  AND experience_years >= ? 
  AND expected_salary <= ?

-- 参数列表:
["%Java开发%", 3, 20000]

4. 对话状态管理

  • 槽位累积:新槽位会合并到现有槽位中
  • 历史记录:记录每轮对话前后的槽位状态
  • 状态持久化:支持多轮对话的状态保持
  • 可清空:支持清空状态重新开始

历史记录结构:

history = [
    {
        "query": "请帮我查找有三年以上Java开发经验的候选人",
        "slots_before": {},
        "slots_after": {"position": "Java开发", "experience_years": 3}
    },
    {
        "query": "期望薪资不超过2万元",
        "slots_before": {"position": "Java开发", "experience_years": 3},
        "slots_after": {"position": "Java开发", "experience_years": 3, "max_salary": 20000}
    }
]

5. 数据库设计

表结构:

CREATE TABLE candidates (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT NOT NULL,
    position TEXT NOT NULL,
    experience_years INTEGER NOT NULL,
    expected_salary INTEGER NOT NULL,
    work_mode TEXT NOT NULL
)

示例数据:

  • 张三:Java开发,5年经验,18000元,远程
  • 李四:Java开发,3年经验,15000元,远程
  • 王五:Java开发,4年经验,22000元,现场
  • 赵六:Python开发,3年经验,19000元,远程
  • 钱七:Java开发,2年经验,12000元,远程

6. 答案生成模板

默认答案生成模板:

用户查询:{query}

生成的SQL:{sql}

查询结果:{results}

请根据查询结果,用自然语言回答用户的问题。

答案:

该模板:

  • 包含原始查询
  • 显示生成的SQL(便于调试)
  • 包含查询结果
  • 要求用自然语言回答

7. 优势与应用场景

优势:

  • 自然语言交互:用户可以用自然语言查询数据库
  • 多轮对话:支持逐步补充条件,无需一次性提供所有信息
  • 安全性:使用参数化查询,防止SQL注入
  • 可扩展性:易于添加新的槽位类型

适用场景:

  • 数据库查询系统:将自然语言转换为SQL查询
  • 智能客服:支持多轮对话的查询系统
  • 数据检索:非技术人员查询结构化数据
  • 业务系统:HR系统、招聘系统等

8. 技术细节

  • 槽位提取:
    • 使用LLM提取结构化信息
    • 使用正则表达式提取JSON
    • 解析失败时返回空字典
  • SQL生成:
    • 使用参数化查询(? 占位符)
    • 动态构建WHERE条件
    • 无条件时使用 1=1(返回所有记录)
  • 数据库操作:
    • 使用SQLite的Row工厂,返回字典格式
    • 支持参数化查询执行
    • 异常处理:捕获SQL执行错误

9. 多轮对话示例

场景:逐步补充条件

第一轮: "请帮我查找有三年以上Java开发经验的候选人"
  → SQL: WHERE position LIKE '%Java开发%' AND experience_years >= 3
  → 结果: 3个候选人

第二轮: "期望薪资不超过2万元"
  → SQL: WHERE position LIKE '%Java开发%' AND experience_years >= 3 AND expected_salary <= 20000
  → 结果: 2个候选人(筛选掉22000元的)

第三轮: "必须能远程办公"
  → SQL: WHERE position LIKE '%Java开发%' AND experience_years >= 3 AND expected_salary <= 20000 AND work_mode = '远程'
  → 结果: 1个候选人(最终匹配)

该设计通过“槽位提取+SQL生成+多轮对话”策略,在数据库查询场景下提供自然语言交互能力,适用于需要将自然语言转换为SQL查询的RAG应用。

← 上一节 24.索引时优化 下一节 26.检索后优化 →

访问验证

请输入访问令牌

Token不正确,请重新输入