导航菜单

  • 1.langchain.intro
  • 2.langchain.chat_models
  • 3.langchain.prompts
  • 4.langchain.example_selectors
  • 5.output_parsers
  • 6.Runnable
  • 7.memory
  • 8.document_loaders
  • 9.text_splitters
  • 10.embeddings
  • 11.tool
  • 12.retrievers
  • 13.optimize
  • 14.项目介绍
  • 15.启动HTTP
  • 16.数据与模型
  • 17.权限管理
  • 18.知识库管理
  • 19.设置
  • 20.文档管理
  • 21.聊天
  • 22.API文档
  • 23.RAG优化
  • 24.索引时优化
  • 25.检索前优化
  • 26.检索后优化
  • 27.系统优化
  • 28.GraphRAG
  • 29.图
  • 30.为什么选择图数据库
  • 31.什么是 Neo4j
  • 32.安装和连接 Neo4j
  • 33.Neo4j核心概念
  • 34.Cypher基础
  • 35.模式匹配
  • 36.数据CRUD操作
  • 37.GraphRAG
  • 38.查询和过滤
  • 39.结果处理和聚合
  • 40.语句组合
  • 41.子查询
  • 42.模式和约束
  • 43.日期时间处理
  • 44.Cypher内置函数
  • 45.Python操作Neo4j
  • 46.neo4j
  • 47.py2neo
  • 48.Streamlit
  • 49.Pandas
  • 50.graphRAG
  • 51.deepdoc
  • 52.deepdoc
  • 53.deepdoc
  • 55.deepdoc
  • 54.deepdoc
  • Pillow
  • 1.上下文压缩(Contextual Compression)
  • 2.混合检索(Hybrid Retrieval)
  • 3.文档重排序(Document Reranking)

1.上下文压缩(Contextual Compression) #

“上下文压缩(Contextual Compression)”是近年来在检索增强生成(RAG)系统中非常重要的一种优化技术。其核心思想是:不直接把检索到的完整原始文档全部喂给大模型,而是先对每份检索文档基于用户当前的查询进行“压缩”,即仅提取与本次查询密切相关的关键信息,去除无关内容,再将这些浓缩后的文本提供给大模型进行问答、生成等操作。

这样做有如下显著优势:

  • 大幅提升上下文利用率:可用token有限,压缩后可纳入更多相关知识点,有效提升大模型检索后的知识覆盖广度。
  • 噪声过滤和内容聚焦:自动去除大段无用背景和题外信息,使LLM更聚焦于“当前问题相关”部分,减少“幻觉”比例。
  • 提升推理推断能力:短小精炼的上下文有助于LLM更好地串联信息,提升复杂/细粒度问答的准确率。
  • 节省推理算力和上下文窗口:同样token预算下,能压缩、浓缩进更多知识,提升准确率和系统成本效益。

常见实现方式包括:

  1. LLM驱动的内容压缩 —— 用prompt指令让大模型对每篇文档只提取与用户问题直接相关的语句、论述,并自动丢弃冗余内容。
  2. 粗过滤+细压缩结合 —— 先用向量召回/关键词召回选出相关文档,再对这些文档一一做针对性精细压缩。
  3. 分层级压缩 —— 对超长文档可多次“递进压缩”,逐层减小内容规模。
  4. 压缩结果的结构化标注 —— 可将压缩部分和原文长度、压缩率、所属文档编号等信息同步传给大模型,便于溯源和追踪。

RAG典型工作流程如下:

  1. 用户提出查询(Query)。
  2. 检索模块召回若干候选文档。
  3. 压缩模块对每份文档基于query进行内容提取,仅保留相关核心片段。
  4. 将所有压缩后的内容拼接成最终上下文,输入大模型进行统一“阅读理解”与作答。
  5. 最终输出答案,并可溯源相关内容证据。

下方提供了一个标准Contextual Compression工作流程的实现。包括了自定义的上下文压缩器、融合检索与压缩的检索器,以及典型的文件初始化、检索、压缩和问答过程演示。这为构建“高质量、低噪声、强推理能力”的RAG系统打下了坚实基础。

适用场景举例:

  • 多轮长对话延续,历史可以用压缩方法保留关键信息供模型记忆。
  • 法律、医学、学术等领域长文案场景,帮助大模型高效阅读和准确引用长文档答案。
  • 通用RAG问答系统的查询-检索-生成链路,大幅提升端到端系统的答案质量。

ContextualCompression.py

# 导入typing模块的List、Dict、Any、Optional用于类型注解
from typing import List, Dict, Any, Optional

# 导入pydantic的ConfigDict用于模型参数配置
from pydantic import ConfigDict

# 导入langchain_core.retrievers的BaseRetriever作为检索器基类
from langchain_core.retrievers import BaseRetriever

# 导入langchain_core.documents的Document类用于封装文档对象
from langchain_core.documents import Document

# 导入langchain_core.language_models的BaseLanguageModel作为大语言模型基类
from langchain_core.language_models import BaseLanguageModel

# 导入langchain_core.prompts的PromptTemplate用于构建提示模板
from langchain_core.prompts import PromptTemplate

# 导入自定义llm对象(大语言模型)
from llm import llm

# 导入自定义的get_vector_store函数,用于获取向量数据库
from vector_store import get_vector_store

# 定义上下文压缩器,根据用户查询对文档进行压缩,仅保留相关内容
class ContextualCompressor:
    # 类描述:上下文压缩器,根据查询压缩文档内容,只保留与查询相关部分
    def __init__(self, llm: BaseLanguageModel):
        # 初始化llm
        self.llm = llm

        # 构建文档压缩用的提示模板
        self.compression_template = PromptTemplate(
            input_variables=["query", "document"],
            template=(
                "请根据用户查询,从以下文档中提取与查询相关的关键信息,"
                "去除不相关的内容,保留核心信息。\n\n"
                "用户查询:{query}\n\n"
                "原始文档:\n{document}\n\n"
                "请只返回压缩后的相关内容,去除所有不相关的信息。"
                "如果文档中没有相关信息,请返回\"无相关信息\"。\n\n"
                "压缩后的内容:"
            )
        )

    # 方法:执行文档压缩
    def compress(self, query: str, document: str) -> str:
        """
        压缩文档内容,根据查询只保留相关内容
        """
        # 格式化提示词,生成新的prompt
        prompt = self.compression_template.format(query=query, document=document)
        # 调用llm执行内容压缩,获取压缩后内容并去掉首尾空白
        compressed = self.llm.invoke(prompt).content.strip()

        # 如果结果为空或包含“无相关信息”,返回空字符串
        if not compressed or "无相关信息" in compressed:
            return ""

        # 返回压缩后的有效内容
        return compressed

# 定义带上下文压缩功能的检索器
class ContextualCompressionRetriever(BaseRetriever):
    # 类描述:先检索文档,再根据查询压缩文档内容
    vector_store: Any
    compressor: ContextualCompressor
    k: int = 4
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # 方法:根据查询检索和压缩相关文档
    def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
        # 获取参数k,默认为成员变量k
        k = kwargs.get("k", self.k)

        # 第一步:使用向量数据库进行相似度检索,选出候选文档(数量为top-k的两倍)
        candidate_docs = self.vector_store.similarity_search_with_score(query, k=k * 2)

        # 第二步:对每个候选文档内容进行压缩处理
        compressed_docs = []
        for doc, distance in candidate_docs:
            # 对文档内容进行压缩,仅保留与查询相关的部分
            compressed_content = self.compressor.compress(query, doc.page_content)

            # 只保存压缩后有内容的文档对象
            if compressed_content:
                compressed_docs.append(
                    Document(
                        page_content=compressed_content,
                        metadata={
                            **doc.metadata,
                            "score": float(distance),
                            "original_length": len(doc.page_content),
                            "compressed_length": len(compressed_content),
                            "compression_ratio": len(compressed_content) / len(doc.page_content) if doc.page_content else 0,
                            "retrieval_method": "contextual_compression"
                        }
                    )
                )

        # 第三步:按照原始检索分数排序,选出top-k个文档返回
        compressed_docs.sort(key=lambda x: x.metadata.get("score", float('inf')))
        return compressed_docs[:k]

# 定义上下文压缩RAG系统,综合检索和上下文压缩能力
class ContextualCompressionRAG:
    # 类描述:上下文压缩RAG系统,结合检索和压缩流程
    def __init__(self, retriever: ContextualCompressionRetriever, llm: BaseLanguageModel, answer_prompt_template: Optional[str] = None):
        # 初始化检索器和大语言模型
        self.retriever = retriever
        self.llm = llm

        # 若未指定自定义答案模板,则使用默认模板
        template = answer_prompt_template or (
            "已知以下相关信息(经过上下文压缩,只保留与查询相关的关键信息):\n\n{context}\n\n"
            "用户查询:{query}\n\n"
            "请根据上述信息,准确全面地回答用户问题。\n\n答案:"
        )
        # 构建大语言模型输出答案用的提示模板
        self.answer_prompt_template = PromptTemplate(
            input_variables=["context", "query"], template=template
        )

    # 方法:执行检索-压缩-问答链,生成最终答案
    def generate_answer(self, query: str, k: int = 4) -> Dict[str, Any]:
        # 检索相关文档并压缩
        docs = self.retriever._get_relevant_documents(query, k=k)

        # 将压缩后的文档内容拼接成最终上下文
        context = "\n\n".join([f"文档 {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs)])

        # 根据上下文与查询,调用llm生成答案
        prompt = self.answer_prompt_template.format(context=context, query=query)
        answer = self.llm.invoke(prompt).content.strip()

        # 计算全部原文和压缩后内容长度
        total_original_length = sum([doc.metadata.get("original_length", 0) for doc in docs])
        total_compressed_length = sum([doc.metadata.get("compressed_length", 0) for doc in docs])
        # 平均压缩比
        avg_compression_ratio = total_compressed_length / total_original_length if total_original_length > 0 else 0

        # 返回问答结果、检索文档及压缩统计信息
        return {
            "query": query,
            "answer": answer,
            "retrieved_documents": docs,
            "num_documents": len(docs),
            "compression_stats": {
                "total_original_length": total_original_length,
                "total_compressed_length": total_compressed_length,
                "avg_compression_ratio": avg_compression_ratio,
                "space_saved": total_original_length - total_compressed_length
            }
        }

# 创建向量库实例,指定持久化路径和集合名
vector_store = get_vector_store(
    persist_directory="chroma_db",
    collection_name="contextual_compression"
)

# 创建上下文压缩器实例
compressor = ContextualCompressor(llm)

# 创建带上下文压缩功能的检索器实例
compression_retriever = ContextualCompressionRetriever(
    vector_store=vector_store,
    compressor=compressor,
    k=4
)

# 创建上下文压缩RAG系统实例
rag_system = ContextualCompressionRAG(retriever=compression_retriever, llm=llm)

# 定义示例文档(内容较长且包含多个主题)
documents = [
    "人工智能(AI)是计算机科学的一个分支,旨在创建能够执行通常需要人类智能的任务的系统。"
    "机器学习是人工智能的核心技术之一,它使计算机能够从数据中学习,而无需明确编程。"
    "深度学习是机器学习的一个子集,使用人工神经网络来模拟人脑的工作方式。"
    "神经网络由多个层组成,每层包含多个神经元,通过反向传播算法进行训练。"
    "卷积神经网络(CNN)特别适合处理图像数据,而循环神经网络(RNN)适合处理序列数据。"
    "Transformer架构是NLP领域的重要突破,它使用自注意力机制来处理序列数据。",
    "自然语言处理(NLP)是人工智能的一个分支,专注于使计算机能够理解、解释和生成人类语言。"
    "BERT和GPT是Transformer架构的两个重要应用,分别用于理解任务和生成任务。"
    "BERT通过双向编码器理解上下文,GPT通过自回归生成器生成文本。"
    "词嵌入技术将词汇转换为向量表示,Word2Vec、GloVe和FastText是常见的词嵌入方法。"
    "注意力机制允许模型关注输入序列的不同部分,提高了模型的表达能力。",
    "计算机视觉是人工智能的另一个重要分支,专注于使计算机能够理解和分析视觉信息。"
    "图像分类、目标检测和图像分割是计算机视觉的三个主要任务。"
    "ResNet、YOLO和U-Net分别是这三个任务的代表性模型。"
    "数据增强技术可以提高模型的泛化能力,包括旋转、缩放、裁剪等操作。"
    "迁移学习允许将在一个任务上训练的模型应用到另一个相关任务上。",
]

# 打印初始化信息
print("正在初始化知识库...")

# 向向量库批量添加文档及元数据
vector_store.add_texts(
    documents,
    metadatas=[{"topic": "人工智能"}, {"topic": "自然语言处理"}, {"topic": "计算机视觉"}]
)

# 打印完成信息
print("知识库初始化完成!\n")

# 打印分隔线和当前示例说明
print("="*60)
print("上下文压缩示例")
print("="*60)

# 定义示例查询
query = "什么是深度学习?它和神经网络有什么关系?"

# 执行上下文压缩和答案生成流程
result = rag_system.generate_answer(query, k=3)

# 打印最终问答结果
print(f"\n用户查询:{result['query']}")
print(f"\n生成的答案:\n{result['answer']}")
print(f"\n检索到的文档数量:{result['num_documents']}")
print(f"\n压缩统计信息:")
print(f"  原始总长度:{result['compression_stats']['total_original_length']} 字符")
print(f"  压缩后总长度:{result['compression_stats']['total_compressed_length']} 字符")
print(f"  平均压缩比:{result['compression_stats']['avg_compression_ratio']:.2%}")
print(f"  节省空间:{result['compression_stats']['space_saved']} 字符")

# 打印每个压缩后文档的详细信息
print(f"\n压缩后的文档详情:")
for i, doc in enumerate(result['retrieved_documents'], 1):
    print(f"\n文档 {i}:")
    print(f"  原始长度:{doc.metadata.get('original_length', 'N/A')} 字符")
    print(f"  压缩后长度:{doc.metadata.get('compressed_length', 'N/A')} 字符")
    print(f"  压缩比:{doc.metadata.get('compression_ratio', 0):.2%}")
    print(f"  压缩后内容:{doc.page_content[:200]}...")

2.混合检索(Hybrid Retrieval) #

混合检索(Hybrid Retrieval)是一种结合稀疏检索(如关键词/符号检索)与密集检索(如语义向量检索)优势的检索方式,常用于提升信息检索系统的准确性与召回率。在大模型RAG问答系统中,通常将稀疏检索和密集检索的结果进行加权融合后,选取得分最高的文档返还给大模型生成答案。

常见做法包括:

  • 稀疏检索(如BM25):通过关键词匹配查找相关文档,适合捕获与查询词高度重合的内容。
  • 密集检索(如向量检索):将查询和文档编码为向量,计算其余弦相似度,能够发现语义相关但词面不一致的文档。
  • 加权融合:可指定稀疏分数、密集分数的权重,进行归一化加权合并,得到最终的融合分数并排序。

这样,混合检索能够兼顾传统信息检索和语义检索的优势,大幅提升复杂查询下的召回和精度,是现代RAG系统中非常推荐的一种检索方案。

HybridRetrieval.py

# 导入List、Dict、Any、Optional类型用于类型注解
from typing import List, Dict, Any, Optional

# 导入pydantic库中的ConfigDict,用于Pydantic模型配置
from pydantic import ConfigDict

# 导入langchain核心模块的基础检索器基类
from langchain_core.retrievers import BaseRetriever

# 导入langchain核心模块的文档对象类型
from langchain_core.documents import Document

# 导入langchain核心模块的基础LLM类型
from langchain_core.language_models import BaseLanguageModel

# 导入langchain核心模块的提示词模板类
from langchain_core.prompts import PromptTemplate

# 导入自定义llm对象
from llm import llm

# 导入自定义获取向量存储的函数
from vector_store import get_vector_store

# 导入自定义嵌入对象
from embeddings import embeddings

# 导入collections模块中的Counter用于词频统计
from collections import Counter

# 导入math模块用于数学计算
import math

# 定义BM25稀疏检索器
class BM25Retriever:
    """BM25稀疏检索器:基于关键词匹配的检索"""

    # 初始化BM25检索器,传入文档列表
    def __init__(self, documents: List[str]):
        self.documents = documents
        # 对每个文档进行分词
        self.tokenized_docs = [self._tokenize(doc) for doc in documents]
        # 构建BM25索引
        self.bm25_index = self._build_bm25_index()

    # 文本分词函数
    def _tokenize(self, text: str) -> List[str]:
        """简单分词(实际应用中应使用jieba等专业分词工具)"""
        # 按正则表达式分隔出单词字符,全部转小写
        import re
        tokens = re.findall(r'\w+', text.lower())
        return tokens

    # 构建BM25索引,包括文档频率、平均文档长度、文档数
    def _build_bm25_index(self):
        """构建BM25索引"""
        # 统计文档频率(DF)
        doc_freq = Counter()
        for doc_tokens in self.tokenized_docs:
            unique_tokens = set(doc_tokens)
            for token in unique_tokens:
                doc_freq[token] += 1
        # 计算所有文档长度
        doc_lengths = [len(doc) for doc in self.tokenized_docs]
        # 计算平均文档长度
        avg_doc_length = sum(doc_lengths) / len(doc_lengths) if doc_lengths else 0
        # 返回字典结构包含所有必要统计量
        return {
            "doc_freq": doc_freq,
            "avg_doc_length": avg_doc_length,
            "total_docs": len(self.documents)
        }

    # 计算逆文档频率IDF
    def _calculate_idf(self, term: str) -> float:
        """计算逆文档频率(IDF)"""
        df = self.bm25_index["doc_freq"].get(term, 0)
        if df == 0:
            return 0
        total_docs = self.bm25_index["total_docs"]
        # 按照BM25标准公式计算IDF
        return math.log((total_docs - df + 0.5) / (df + 0.5) + 1.0)

    # 计算BM25分数(query与某文档之间的得分)
    def _calculate_bm25_score(self, query_tokens: List[str], doc_tokens: List[str], doc_idx: int) -> float:
        """计算BM25分数"""
        k1 = 1.5  # BM25参数k1
        b = 0.75  # BM25参数b
        score = 0.0

        # 获取文档长度
        doc_length = len(doc_tokens)
        # 获取语料平均文档长度
        avg_doc_length = self.bm25_index["avg_doc_length"]
        # 统计题词频
        term_freq = Counter(doc_tokens)

        # 对每个query中的词,累加BM25得分
        for term in query_tokens:
            if term not in term_freq:
                continue
            # 计算IDF
            idf = self._calculate_idf(term)
            # 计算词频相关系数TF部分
            f = term_freq[term]
            numerator = f * (k1 + 1)
            denominator = f + k1 * (1 - b + b * (doc_length / avg_doc_length))
            tf = numerator / denominator if denominator != 0 else 0
            # 把该词的BM25贡献值加入总分
            score += idf * tf

        return score

    # 外部接口,执行检索
    def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """执行BM25检索"""
        # 对输入问题分词
        query_tokens = self._tokenize(query)

        # 依次对所有文档计算BM25分数
        scores = []
        for idx, doc_tokens in enumerate(self.tokenized_docs):
            score = self._calculate_bm25_score(query_tokens, doc_tokens, idx)
            # 只保留分数大于0的文档
            if score > 0:
                scores.append({
                    "document": self.documents[idx],
                    "score": score,
                    "index": idx
                })
        # 按BM25分数降序排
        scores.sort(key=lambda x: x["score"], reverse=True)
        # 取前k条返回
        return scores[:k]

# 定义密集检索器(向量检索)
class DenseRetriever:
    """密集检索器:基于向量相似度的检索"""

    # 初始化传入向量库对象
    def __init__(self, vector_store: Any):
        self.vector_store = vector_store

    # 外部接口,执行向量检索
    def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """执行密集向量检索"""
        # 使用向量库的similarity_search_with_score方法做相关性检索
        docs = self.vector_store.similarity_search_with_score(query, k=k)
        results = []
        # 将检索结果组装为字典形式,包含内容、得分和元信息
        for doc, distance in docs:
            results.append({
                "document": doc.page_content,
                "score": float(distance),
                "metadata": doc.metadata
            })
        return results

# 定义混合检索器
class HybridRetriever(BaseRetriever):
    """混合检索器:结合稀疏检索和密集检索"""

    # 向量库对象
    vector_store: Any
    # BM25检索器
    bm25_retriever: BM25Retriever
    # 密集检索器
    dense_retriever: DenseRetriever
    # 密集检索权重(默认0.7)
    dense_weight: float = 0.7
    # 稀疏检索权重(默认0.3)
    sparse_weight: float = 0.3
    # 导数排名算法的k值(默认60)
    rank_k: int = 60
    # 检索文档数量
    k: int = 4
    # 配置模型参数,允许任意类型
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # 排名变换——根据排名算导数分数,用于后续融合排序
    def _calculate_rank_score(self, rank: int) -> float:
        """
        计算加权导数排名分数
        公式:分数 = 1 / (k + rank)
        其中:rank为排名位置(从1开始),k为经验常数(通常取60)
        """
        return 1.0 / (self.rank_k + rank)

    # 融合稀疏检索和密集检索结果
    def _fuse_results(self, sparse_results: List[Dict], dense_results: List[Dict]) -> List[Document]:
        """
        融合稀疏检索和密集检索的结果
        算法:
        1. 对每个检索方法的结果,使用导数排名算法计算分数
        2. 按权重加权融合
        3. 按融合分数排序
        """
        # 建立一个文档内容字符串到其稀疏/密集分数映射
        doc_scores: Dict[str, Dict[str, float]] = {}

        # 遍历稀疏检索结果,标记其排名分数
        for rank, result in enumerate(sparse_results, 1):
            doc = result["document"]
            rank_score = self._calculate_rank_score(rank)
            if doc not in doc_scores:
                doc_scores[doc] = {"sparse_score": 0.0, "dense_score": 0.0}
            doc_scores[doc]["sparse_score"] = rank_score

        # 遍历密集检索结果,标记其排名分数
        for rank, result in enumerate(dense_results, 1):
            doc = result["document"]
            rank_score = self._calculate_rank_score(rank)
            if doc not in doc_scores:
                doc_scores[doc] = {"sparse_score": 0.0, "dense_score": 0.0}
            doc_scores[doc]["dense_score"] = rank_score

        # 计算融合分数并组装结果
        fused_results = []
        for doc, scores in doc_scores.items():
            # 融合分数 = 稀疏分数 * 稀疏权重 + 密集分数 * 密集权重
            fused_score = scores["sparse_score"] * self.sparse_weight + scores["dense_score"] * self.dense_weight
            fused_results.append({
                "document": doc,
                "fused_score": fused_score,
                "sparse_score": scores["sparse_score"],
                "dense_score": scores["dense_score"]
            })

        # 按融合分数排序
        fused_results.sort(key=lambda x: x["fused_score"], reverse=True)

        # 每个结果转换成Document对象并附带元数据
        result_docs = []
        for result in fused_results:
            result_docs.append(Document(
                page_content=result["document"],
                metadata={
                    "fused_score": result["fused_score"],
                    "sparse_score": result["sparse_score"],
                    "dense_score": result["dense_score"],
                    "retrieval_method": "hybrid"
                }
            ))

        return result_docs

    # 检索入口,返回最相关的k个文档
    def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
        """执行混合检索"""
        # 若外部未传k,则用默认的k
        k = kwargs.get("k", self.k)

        # 兼容融合,检索稀疏和密集通路各返回2k个(冗余用于融合)
        sparse_results = self.bm25_retriever.retrieve(query, k=k * 2)
        dense_results = self.dense_retriever.retrieve(query, k=k * 2)

        # 融合排序
        fused_docs = self._fuse_results(sparse_results, dense_results)

        # 返回Top-k
        return fused_docs[:k]

# 定义混合Retrieval-Augmented Generation系统
class HybridRetrievalRAG:
    """混合检索RAG系统:结合稀疏检索和密集检索"""

    # 初始化传入混合检索器、LLM对象、可选自定义模板
    def __init__(self, retriever: HybridRetriever, llm: BaseLanguageModel, answer_prompt_template: Optional[str] = None):
        self.retriever = retriever
        self.llm = llm
        # 默认答题模板
        template = answer_prompt_template or (
            "已知以下相关信息(通过混合检索技术获取):\n\n{context}\n\n"
            "用户查询:{query}\n\n"
            "请根据上述信息,准确全面地回答用户问题。\n\n答案:"
        )
        # 用PromptTemplate封装模板
        self.answer_prompt_template = PromptTemplate(
            input_variables=["context", "query"], template=template
        )

    # 外部接口,根据query生成答案
    def generate_answer(self, query: str, k: int = 4) -> Dict[str, Any]:
        """生成答案"""
        # 执行混合检索
        docs = self.retriever._get_relevant_documents(query, k=k)
        # 拼接检索到的文档作为上下文
        context = "\n\n".join([f"文档 {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
        # 用模板生成prompt
        prompt = self.answer_prompt_template.format(context=context, query=query)
        # 调用llm生成答案
        answer = self.llm.invoke(prompt).content.strip()
        # 返回包含查询、答案、检索到的文档等信息的字典
        return {
            "query": query,
            "answer": answer,
            "retrieved_documents": docs,
            "num_documents": len(docs)
        }

# 初始化组件部分

# 创建向量库实例,设置存储目录和集合名
vector_store = get_vector_store(
    persist_directory="chroma_db",
    collection_name="hybrid_retrieval"
)

# 示例文档列表,包括AI、区块链、量子计算相关内容
documents = [
    "人工智能(AI)是计算机科学的一个分支,旨在创建能够执行通常需要人类智能的任务的系统。"
    "机器学习是人工智能的核心技术之一,它使计算机能够从数据中学习,而无需明确编程。"
    "深度学习是机器学习的一个子集,使用人工神经网络来模拟人脑的工作方式。",
    "区块链技术是一种分布式账本技术,通过密码学方法确保数据的安全性和不可篡改性。"
    "比特币是第一个成功应用区块链技术的加密货币,它解决了数字货币的双重支付问题。"
    "以太坊是一个支持智能合约的区块链平台,允许开发者在其上构建去中心化应用。",
    "量子计算是一种利用量子力学现象进行计算的新兴技术,具有巨大的计算潜力。"
    "量子比特(qubit)是量子计算的基本单位,与经典比特不同,它可以同时处于0和1的叠加态。"
    "量子纠缠是量子计算的关键特性,允许量子比特之间建立特殊的关联关系。",
]

# 文档添加入向量库,并指定元数据
print("正在初始化知识库...")
vector_store.add_texts(
    documents,
    metadatas=[{"topic": "人工智能"}, {"topic": "区块链"}, {"topic": "量子计算"}]
)

# 创建BM25稀疏检索器
bm25_retriever = BM25Retriever(documents)

# 创建密集检索器
dense_retriever = DenseRetriever(vector_store)

# 创建混合检索器,并指定各类检索、权重及参数
hybrid_retriever = HybridRetriever(
    vector_store=vector_store,
    bm25_retriever=bm25_retriever,
    dense_retriever=dense_retriever,
    dense_weight=0.7,
    sparse_weight=0.3,
    rank_k=60,
    k=3
)

# 创建混合检索RAG系统
rag_system = HybridRetrievalRAG(retriever=hybrid_retriever, llm=llm)

# 打印知识库初始化完毕提示
print("知识库初始化完成!\n")

# 分割线,演示示例流程
print("="*60)
print("混合检索示例")
print("="*60)

# 示例查询
query = "什么是机器学习?"

# 执行混合检索并生成答案
result = rag_system.generate_answer(query, k=3)

# 打印用户查询内容
print(f"\n用户查询:{result['query']}")
# 打印最终生成的答案
print(f"\n生成的答案:\n{result['answer']}")
# 打印检索到的文档数量
print(f"\n检索到的文档数量:{result['num_documents']}")
# 打印检索到的文档详情,包括融合分数
print(f"\n检索到的文档详情(包含融合分数):")
for i, doc in enumerate(result['retrieved_documents'], 1):
    print(f"\n文档 {i}:")
    print(f"  融合分数:{doc.metadata.get('fused_score', 'N/A'):.6f}")
    print(f"  稀疏分数:{doc.metadata.get('sparse_score', 'N/A'):.6f}")
    print(f"  密集分数:{doc.metadata.get('dense_score', 'N/A'):.6f}")
    print(f"  内容:{doc.page_content[:150]}...")

3.文档重排序(Document Reranking) #

文档重排序(Document Reranking)是一种在检索流程中提升最终结果相关性的强有力方法。其基本思想是:

  1. 先用普通检索算法(如向量检索、关键词检索)选出一批初步候选文档;
  2. 然后利用更强的相关性判别模型(如Cross-Encoder、基于大模型的匹配评分)对每对“查询-文档”进行相关性精细打分,重新排序,最终取分数最高的文档作为RAG答案生成的依据。

核心过程

  • 候选召回(Recall): 首先使用高效的向量检索,从知识库中召回一批与查询初步相关的文档,通常数量较多(如10-20条)。
  • 重排序(Rerank): 对每一个“查询-文档”对,输入到更复杂的模型(如跨编码器Cross-Encoder,或用LLM直接相关性评分提示词)中,让模型给出0-1之间的相关性分数。
  • 选择最优(Select Top-K): 按分数从高到低排序,只保留最相关的前几条(如Top-3/Top-4)进入大模型生成环节。

实现方式 常见的重排序技术包括:

  • Cross-Encoder类模型: 输入“[查询][文档]”对,模型联合判断相关性,效果优于单独编码/余弦计算。
  • 大语言模型提示词方式: 用Prompt提示LLM扮演判分员角色,让模型直接打相关性分(见下方代码中的PromptTemplate与llm.invoke用法)。
  • BERT/RoBERTa/SimCSE等模型微调: 针对“查询-文档相关性判别”任务训练得到的判别模型。

优势

  • 能有效过滤掉表面相似但实际不相关的文档(如只包含部分关键词但含义不符的干扰文档)。
  • 极大改善“向量检索+生成”RAG中的答案准确度、可控性和鲁棒性。
  • 通常只需重排序最相关的少量文档,计算代价可控。

应用场景

  • 高精度问答/助理型RAG系统
  • 法律、医疗、金融等对精准文档匹配强依赖的领域
  • 想要摆脱“伪召回”“语义漂移”困扰的知识检索生成场景

DocumentReranking.py

# 导入List、Dict、Any、Optional这些类型,用于类型注解
from typing import List, Dict, Any, Optional

# 导入ConfigDict,Pydantic配置专用
from pydantic import ConfigDict

# 导入基础的检索器基类
from langchain_core.retrievers import BaseRetriever

# 导入Document文档类
from langchain_core.documents import Document

# 导入基础LLM类型
from langchain_core.language_models import BaseLanguageModel

# 导入PromptTemplate提示模版类
from langchain_core.prompts import PromptTemplate

# 导入自定义llm对象
from llm import llm

# 导入自定义的get_vector_store函数
from vector_store import get_vector_store

# 导入自定义的embeddings对象
from embeddings import embeddings

# 定义文档重排序器类,基于Cross-Encoder结构实现
class DocumentReranker:
    """文档重排序器:使用Cross-Encoder架构评估查询-文档相关性"""

    # 初始化方法,指定llm和最大长度
    def __init__(self, llm: BaseLanguageModel, max_length: int = 512):
        # 存储大语言模型对象
        self.llm = llm
        # 存储最大文本长度
        self.max_length = max_length

        # 配置相关性分数提示模版
        self.relevance_template = PromptTemplate(
            input_variables=["query", "document"],
            template=(
                "请评估以下查询与文档的相关性,给出0-1之间的分数(1表示完全相关,0表示完全不相关)。\n\n"
                "查询:{query}\n\n"
                "文档:{document}\n\n"
                "请只返回一个0-1之间的浮点数分数,不要其他内容:"
            )
        )

    # 文本预处理辅助方法
    def _preprocess_text(self, text: str) -> str:
        """文本预处理:标准化处理并控制长度"""
        # 移除多余空白
        text = " ".join(text.split())

        # 限制最大字符数并加省略号
        if len(text) > self.max_length:
            text = text[:self.max_length] + "..."

        return text

    # 构建 [query, document] 对辅助方法
    def _build_query_document_pairs(self, query: str, documents: List[str]) -> List[List[str]]:
        """构建查询-文档对:将用户查询与每个候选文档组成匹配对"""
        # 预处理查询
        processed_query = self._preprocess_text(query)
        # 预处理所有文档
        processed_docs = [self._preprocess_text(doc) for doc in documents]

        pairs = []
        # 遍历每篇文档,构成一对
        for doc in processed_docs:
            # 加入[query, document] 格式的对
            pairs.append([processed_query, doc])

        return pairs

    # 计算相关性分数
    def _calculate_relevance_score(self, query: str, document: str) -> float:
        """计算查询-文档对的相关性分数(使用LLM模拟Cross-Encoder)"""
        # 生成提示
        prompt = self.relevance_template.format(query=query, document=document)
        # 调用llm获取分数文本
        response = self.llm.invoke(prompt).content.strip()

        # 尝试用正则提取分数字符串
        try:
            import re
            # 匹配0-1之间的浮点数
            score_match = re.search(r'0?\.\d+|1\.0?|0\.0?', response)
            if score_match:
                score = float(score_match.group())
                # 限定在0-1范围之间
                score = max(0.0, min(1.0, score))
                return score
        except:
            pass

        # 无法解析则返回默认分数0.5
        return 0.5

    # 主重排序函数
    def rerank(self, query: str, documents: List[str], top_k: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        执行文档重排序
        技术流程:
        1. 数据处理阶段:文档对构建、文本预处理、长度控制
        2. 模型推理阶段:分词处理、特征提取、相关性计算
        3. 结果排序阶段:分数归一化、文档排序、质量筛选
        """
        # 步骤1: 构建查询-文档对
        query_doc_pairs = self._build_query_document_pairs(query, documents)

        # 步骤2: 计算每个对的相关性分数
        scored_docs = []
        for i, (processed_query, processed_doc) in enumerate(query_doc_pairs):
            score = self._calculate_relevance_score(processed_query, processed_doc)
            scored_docs.append({
                "document": documents[i],
                "score": score,
                "original_index": i
            })

        # 步骤3: 根据分数降序排序
        scored_docs.sort(key=lambda x: x["score"], reverse=True)

        # 如果指定了top_k,只保留前k个
        if top_k is not None:
            scored_docs = scored_docs[:top_k]

        return scored_docs

# 定义带重排序的检索器类
class RerankingRetriever(BaseRetriever):
    """带重排序的检索器:先进行向量检索,再进行重排序"""

    # 声明向量库对象
    vector_store: Any
    # 声明重排序器对象
    reranker: DocumentReranker
    # 初始化时检索更多文档以便后续重排序
    initial_k: int = 10
    # 最终返回文档数量
    k: int = 4
    # pydantic的配置,允许任意类型字段
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # 检索和重排序的主方法
    def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
        """执行检索和重排序"""
        # 获取最终返回数量k(允许通过kwargs临时修改)
        k = kwargs.get("k", self.k)
        # 获取初始候选数量initial_k
        initial_k = kwargs.get("initial_k", self.initial_k)

        # 第一步:向量召回,得到多个候选文档
        candidate_docs = self.vector_store.similarity_search_with_score(query, k=initial_k)

        # 提取原文档内容
        documents = [doc.page_content for doc, _ in candidate_docs]

        # 第二步:对所有候选做重排序
        reranked_results = self.reranker.rerank(query, documents, top_k=k)

        # 转换回Document对象,同时加入重排序得分等元信息
        result_docs = []
        for result in reranked_results:
            # 找到原始文档对象
            original_doc = None
            for doc, _ in candidate_docs:
                if doc.page_content == result["document"]:
                    original_doc = doc
                    break

            # 新建Document对象,合并元信息
            result_docs.append(Document(
                page_content=result["document"],
                metadata={
                    **(original_doc.metadata if original_doc else {}),
                    "rerank_score": result["score"],
                    "original_index": result["original_index"],
                    "retrieval_method": "reranking"
                }
            ))

        return result_docs

# 定义整体的文档重排序RAG系统类
class DocumentRerankingRAG:
    """文档重排序RAG系统:结合向量检索和Cross-Encoder重排序"""

    # 初始化,传入重排序检索器和llm
    def __init__(self, retriever: RerankingRetriever, llm: BaseLanguageModel, answer_prompt_template: Optional[str] = None):
        # 存储检索器
        self.retriever = retriever
        # 存储llm
        self.llm = llm

        # 配置自定义或默认的答案生成提示模版
        template = answer_prompt_template or (
            "已知以下相关信息(通过文档重排序技术筛选出的最相关文档):\n\n{context}\n\n"
            "用户查询:{query}\n\n"
            "注意:系统使用了Cross-Encoder重排序技术,能够更准确地评估查询与文档之间的相关性,"
            "有效避免相似度陷阱问题。\n\n"
            "请根据上述信息,准确全面地回答用户问题。\n\n答案:"
        )
        self.answer_prompt_template = PromptTemplate(
            input_variables=["context", "query"], template=template
        )

    # 生成答案主函数
    def generate_answer(self, query: str, k: int = 4) -> Dict[str, Any]:
        """生成答案"""
        # 先检索与重排序
        docs = self.retriever._get_relevant_documents(query, k=k)

        # 构造上下文,插入每篇文档的内容及分数
        context = "\n\n".join([
            f"文档 {i+1}(重排序分数:{doc.metadata.get('rerank_score', 'N/A'):.4f}):\n{doc.page_content}"
            for i, doc in enumerate(docs)
        ])

        # 构造最终的提示
        prompt = self.answer_prompt_template.format(context=context, query=query)
        # 调用llm获取答案
        answer = self.llm.invoke(prompt).content.strip()

        # 返回带文档和答案的结构
        return {
            "query": query,
            "answer": answer,
            "retrieved_documents": docs,
            "num_documents": len(docs)
        }

# ===================== 初始化各组件 =======================

# 创建向量库对象
vector_store = get_vector_store(
    persist_directory="chroma_db",
    collection_name="document_reranking"
)

# 创建重排序器对象
reranker = DocumentReranker(llm, max_length=512)

# 创建带重排序的检索器对象
reranking_retriever = RerankingRetriever(
    vector_store=vector_store,
    reranker=reranker,
    initial_k=10,  # 初始检索10篇文档
    k=4  # 重排序后输出4篇文档
)

# 实例化文档重排序RAG系统
rag_system = DocumentRerankingRAG(retriever=reranking_retriever, llm=llm)

# ================== 示例文档及知识库初始化 ===================

# 示例文档列表
documents = [
    "人工智能(AI)是计算机科学的一个分支,旨在创建能够执行通常需要人类智能的任务的系统。"
    "机器学习是人工智能的核心技术之一,它使计算机能够从数据中学习,而无需明确编程。"
    "深度学习是机器学习的一个子集,使用人工神经网络来模拟人脑的工作方式。",
    "区块链技术是一种分布式账本技术,通过密码学方法确保数据的安全性和不可篡改性。"
    "比特币是第一个成功应用区块链技术的加密货币,它解决了数字货币的双重支付问题。"
    "以太坊是一个支持智能合约的区块链平台,允许开发者在其上构建去中心化应用。",
    "量子计算是一种利用量子力学现象进行计算的新兴技术,具有巨大的计算潜力。"
    "量子比特(qubit)是量子计算的基本单位,与经典比特不同,它可以同时处于0和1的叠加态。"
    "量子纠缠是量子计算的关键特性,允许量子比特之间建立特殊的关联关系。",
    "自然语言处理(NLP)是人工智能的一个分支,专注于使计算机能够理解、解释和生成人类语言。"
    "Transformer架构是NLP领域的重要突破,它使用自注意力机制来处理序列数据。"
    "BERT和GPT是Transformer架构的两个重要应用,分别用于理解任务和生成任务。",
    "计算机视觉是人工智能的另一个重要分支,专注于使计算机能够理解和分析视觉信息。"
    "卷积神经网络(CNN)是计算机视觉的核心技术,特别适合处理图像数据。"
    "图像分类、目标检测和图像分割是计算机视觉的三个主要任务。",
]

# 打印初始化提示
print("正在初始化知识库...")
# 向向量库中添加文档及元信息
vector_store.add_texts(
    documents,
    metadatas=[{"topic": "人工智能"}, {"topic": "区块链"}, {"topic": "量子计算"},
               {"topic": "自然语言处理"}, {"topic": "计算机视觉"}]
)
print("知识库初始化完成!\n")

# =============== 测试整体流程 ================

# 输出分隔符及说明
print("="*60)
print("文档重排序示例")
print("="*60)

# 配置示例查询
query = "什么是机器学习?它和深度学习有什么关系?"

# 执行检索、重排序、答案生成
result = rag_system.generate_answer(query, k=4)

# 打印最终结果及详细推理
print(f"\n用户查询:{result['query']}")
print(f"\n生成的答案:\n{result['answer']}")
print(f"\n重排序后选择的文档数量:{result['num_documents']}")
print(f"\n重排序后的文档详情(按相关性分数排序):")
for i, doc in enumerate(result['retrieved_documents'], 1):
    print(f"\n文档 {i}(原始索引:{doc.metadata.get('original_index', 'N/A')}):")
    print(f"  重排序分数:{doc.metadata.get('rerank_score', 'N/A'):.4f}")
    print(f"  内容:{doc.page_content[:150]}...")

← 上一节 25.检索前优化 下一节 27.系统优化 →

访问验证

请输入访问令牌

Token不正确,请重新输入