1.上下文压缩(Contextual Compression) #
“上下文压缩(Contextual Compression)”是近年来在检索增强生成(RAG)系统中非常重要的一种优化技术。其核心思想是:不直接把检索到的完整原始文档全部喂给大模型,而是先对每份检索文档基于用户当前的查询进行“压缩”,即仅提取与本次查询密切相关的关键信息,去除无关内容,再将这些浓缩后的文本提供给大模型进行问答、生成等操作。
这样做有如下显著优势:
- 大幅提升上下文利用率:可用token有限,压缩后可纳入更多相关知识点,有效提升大模型检索后的知识覆盖广度。
- 噪声过滤和内容聚焦:自动去除大段无用背景和题外信息,使LLM更聚焦于“当前问题相关”部分,减少“幻觉”比例。
- 提升推理推断能力:短小精炼的上下文有助于LLM更好地串联信息,提升复杂/细粒度问答的准确率。
- 节省推理算力和上下文窗口:同样token预算下,能压缩、浓缩进更多知识,提升准确率和系统成本效益。
常见实现方式包括:
- LLM驱动的内容压缩 —— 用prompt指令让大模型对每篇文档只提取与用户问题直接相关的语句、论述,并自动丢弃冗余内容。
- 粗过滤+细压缩结合 —— 先用向量召回/关键词召回选出相关文档,再对这些文档一一做针对性精细压缩。
- 分层级压缩 —— 对超长文档可多次“递进压缩”,逐层减小内容规模。
- 压缩结果的结构化标注 —— 可将压缩部分和原文长度、压缩率、所属文档编号等信息同步传给大模型,便于溯源和追踪。
RAG典型工作流程如下:
- 用户提出查询(Query)。
- 检索模块召回若干候选文档。
- 压缩模块对每份文档基于query进行内容提取,仅保留相关核心片段。
- 将所有压缩后的内容拼接成最终上下文,输入大模型进行统一“阅读理解”与作答。
- 最终输出答案,并可溯源相关内容证据。
下方提供了一个标准Contextual Compression工作流程的实现。包括了自定义的上下文压缩器、融合检索与压缩的检索器,以及典型的文件初始化、检索、压缩和问答过程演示。这为构建“高质量、低噪声、强推理能力”的RAG系统打下了坚实基础。
适用场景举例:
- 多轮长对话延续,历史可以用压缩方法保留关键信息供模型记忆。
- 法律、医学、学术等领域长文案场景,帮助大模型高效阅读和准确引用长文档答案。
- 通用RAG问答系统的查询-检索-生成链路,大幅提升端到端系统的答案质量。
ContextualCompression.py
# 导入typing模块的List、Dict、Any、Optional用于类型注解
from typing import List, Dict, Any, Optional
# 导入pydantic的ConfigDict用于模型参数配置
from pydantic import ConfigDict
# 导入langchain_core.retrievers的BaseRetriever作为检索器基类
from langchain_core.retrievers import BaseRetriever
# 导入langchain_core.documents的Document类用于封装文档对象
from langchain_core.documents import Document
# 导入langchain_core.language_models的BaseLanguageModel作为大语言模型基类
from langchain_core.language_models import BaseLanguageModel
# 导入langchain_core.prompts的PromptTemplate用于构建提示模板
from langchain_core.prompts import PromptTemplate
# 导入自定义llm对象(大语言模型)
from llm import llm
# 导入自定义的get_vector_store函数,用于获取向量数据库
from vector_store import get_vector_store
# 定义上下文压缩器,根据用户查询对文档进行压缩,仅保留相关内容
class ContextualCompressor:
# 类描述:上下文压缩器,根据查询压缩文档内容,只保留与查询相关部分
def __init__(self, llm: BaseLanguageModel):
# 初始化llm
self.llm = llm
# 构建文档压缩用的提示模板
self.compression_template = PromptTemplate(
input_variables=["query", "document"],
template=(
"请根据用户查询,从以下文档中提取与查询相关的关键信息,"
"去除不相关的内容,保留核心信息。\n\n"
"用户查询:{query}\n\n"
"原始文档:\n{document}\n\n"
"请只返回压缩后的相关内容,去除所有不相关的信息。"
"如果文档中没有相关信息,请返回\"无相关信息\"。\n\n"
"压缩后的内容:"
)
)
# 方法:执行文档压缩
def compress(self, query: str, document: str) -> str:
"""
压缩文档内容,根据查询只保留相关内容
"""
# 格式化提示词,生成新的prompt
prompt = self.compression_template.format(query=query, document=document)
# 调用llm执行内容压缩,获取压缩后内容并去掉首尾空白
compressed = self.llm.invoke(prompt).content.strip()
# 如果结果为空或包含“无相关信息”,返回空字符串
if not compressed or "无相关信息" in compressed:
return ""
# 返回压缩后的有效内容
return compressed
# 定义带上下文压缩功能的检索器
class ContextualCompressionRetriever(BaseRetriever):
# 类描述:先检索文档,再根据查询压缩文档内容
vector_store: Any
compressor: ContextualCompressor
k: int = 4
model_config = ConfigDict(arbitrary_types_allowed=True)
# 方法:根据查询检索和压缩相关文档
def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
# 获取参数k,默认为成员变量k
k = kwargs.get("k", self.k)
# 第一步:使用向量数据库进行相似度检索,选出候选文档(数量为top-k的两倍)
candidate_docs = self.vector_store.similarity_search_with_score(query, k=k * 2)
# 第二步:对每个候选文档内容进行压缩处理
compressed_docs = []
for doc, distance in candidate_docs:
# 对文档内容进行压缩,仅保留与查询相关的部分
compressed_content = self.compressor.compress(query, doc.page_content)
# 只保存压缩后有内容的文档对象
if compressed_content:
compressed_docs.append(
Document(
page_content=compressed_content,
metadata={
**doc.metadata,
"score": float(distance),
"original_length": len(doc.page_content),
"compressed_length": len(compressed_content),
"compression_ratio": len(compressed_content) / len(doc.page_content) if doc.page_content else 0,
"retrieval_method": "contextual_compression"
}
)
)
# 第三步:按照原始检索分数排序,选出top-k个文档返回
compressed_docs.sort(key=lambda x: x.metadata.get("score", float('inf')))
return compressed_docs[:k]
# 定义上下文压缩RAG系统,综合检索和上下文压缩能力
class ContextualCompressionRAG:
# 类描述:上下文压缩RAG系统,结合检索和压缩流程
def __init__(self, retriever: ContextualCompressionRetriever, llm: BaseLanguageModel, answer_prompt_template: Optional[str] = None):
# 初始化检索器和大语言模型
self.retriever = retriever
self.llm = llm
# 若未指定自定义答案模板,则使用默认模板
template = answer_prompt_template or (
"已知以下相关信息(经过上下文压缩,只保留与查询相关的关键信息):\n\n{context}\n\n"
"用户查询:{query}\n\n"
"请根据上述信息,准确全面地回答用户问题。\n\n答案:"
)
# 构建大语言模型输出答案用的提示模板
self.answer_prompt_template = PromptTemplate(
input_variables=["context", "query"], template=template
)
# 方法:执行检索-压缩-问答链,生成最终答案
def generate_answer(self, query: str, k: int = 4) -> Dict[str, Any]:
# 检索相关文档并压缩
docs = self.retriever._get_relevant_documents(query, k=k)
# 将压缩后的文档内容拼接成最终上下文
context = "\n\n".join([f"文档 {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
# 根据上下文与查询,调用llm生成答案
prompt = self.answer_prompt_template.format(context=context, query=query)
answer = self.llm.invoke(prompt).content.strip()
# 计算全部原文和压缩后内容长度
total_original_length = sum([doc.metadata.get("original_length", 0) for doc in docs])
total_compressed_length = sum([doc.metadata.get("compressed_length", 0) for doc in docs])
# 平均压缩比
avg_compression_ratio = total_compressed_length / total_original_length if total_original_length > 0 else 0
# 返回问答结果、检索文档及压缩统计信息
return {
"query": query,
"answer": answer,
"retrieved_documents": docs,
"num_documents": len(docs),
"compression_stats": {
"total_original_length": total_original_length,
"total_compressed_length": total_compressed_length,
"avg_compression_ratio": avg_compression_ratio,
"space_saved": total_original_length - total_compressed_length
}
}
# 创建向量库实例,指定持久化路径和集合名
vector_store = get_vector_store(
persist_directory="chroma_db",
collection_name="contextual_compression"
)
# 创建上下文压缩器实例
compressor = ContextualCompressor(llm)
# 创建带上下文压缩功能的检索器实例
compression_retriever = ContextualCompressionRetriever(
vector_store=vector_store,
compressor=compressor,
k=4
)
# 创建上下文压缩RAG系统实例
rag_system = ContextualCompressionRAG(retriever=compression_retriever, llm=llm)
# 定义示例文档(内容较长且包含多个主题)
documents = [
"人工智能(AI)是计算机科学的一个分支,旨在创建能够执行通常需要人类智能的任务的系统。"
"机器学习是人工智能的核心技术之一,它使计算机能够从数据中学习,而无需明确编程。"
"深度学习是机器学习的一个子集,使用人工神经网络来模拟人脑的工作方式。"
"神经网络由多个层组成,每层包含多个神经元,通过反向传播算法进行训练。"
"卷积神经网络(CNN)特别适合处理图像数据,而循环神经网络(RNN)适合处理序列数据。"
"Transformer架构是NLP领域的重要突破,它使用自注意力机制来处理序列数据。",
"自然语言处理(NLP)是人工智能的一个分支,专注于使计算机能够理解、解释和生成人类语言。"
"BERT和GPT是Transformer架构的两个重要应用,分别用于理解任务和生成任务。"
"BERT通过双向编码器理解上下文,GPT通过自回归生成器生成文本。"
"词嵌入技术将词汇转换为向量表示,Word2Vec、GloVe和FastText是常见的词嵌入方法。"
"注意力机制允许模型关注输入序列的不同部分,提高了模型的表达能力。",
"计算机视觉是人工智能的另一个重要分支,专注于使计算机能够理解和分析视觉信息。"
"图像分类、目标检测和图像分割是计算机视觉的三个主要任务。"
"ResNet、YOLO和U-Net分别是这三个任务的代表性模型。"
"数据增强技术可以提高模型的泛化能力,包括旋转、缩放、裁剪等操作。"
"迁移学习允许将在一个任务上训练的模型应用到另一个相关任务上。",
]
# 打印初始化信息
print("正在初始化知识库...")
# 向向量库批量添加文档及元数据
vector_store.add_texts(
documents,
metadatas=[{"topic": "人工智能"}, {"topic": "自然语言处理"}, {"topic": "计算机视觉"}]
)
# 打印完成信息
print("知识库初始化完成!\n")
# 打印分隔线和当前示例说明
print("="*60)
print("上下文压缩示例")
print("="*60)
# 定义示例查询
query = "什么是深度学习?它和神经网络有什么关系?"
# 执行上下文压缩和答案生成流程
result = rag_system.generate_answer(query, k=3)
# 打印最终问答结果
print(f"\n用户查询:{result['query']}")
print(f"\n生成的答案:\n{result['answer']}")
print(f"\n检索到的文档数量:{result['num_documents']}")
print(f"\n压缩统计信息:")
print(f" 原始总长度:{result['compression_stats']['total_original_length']} 字符")
print(f" 压缩后总长度:{result['compression_stats']['total_compressed_length']} 字符")
print(f" 平均压缩比:{result['compression_stats']['avg_compression_ratio']:.2%}")
print(f" 节省空间:{result['compression_stats']['space_saved']} 字符")
# 打印每个压缩后文档的详细信息
print(f"\n压缩后的文档详情:")
for i, doc in enumerate(result['retrieved_documents'], 1):
print(f"\n文档 {i}:")
print(f" 原始长度:{doc.metadata.get('original_length', 'N/A')} 字符")
print(f" 压缩后长度:{doc.metadata.get('compressed_length', 'N/A')} 字符")
print(f" 压缩比:{doc.metadata.get('compression_ratio', 0):.2%}")
print(f" 压缩后内容:{doc.page_content[:200]}...")
2.混合检索(Hybrid Retrieval) #
混合检索(Hybrid Retrieval)是一种结合稀疏检索(如关键词/符号检索)与密集检索(如语义向量检索)优势的检索方式,常用于提升信息检索系统的准确性与召回率。在大模型RAG问答系统中,通常将稀疏检索和密集检索的结果进行加权融合后,选取得分最高的文档返还给大模型生成答案。
常见做法包括:
- 稀疏检索(如BM25):通过关键词匹配查找相关文档,适合捕获与查询词高度重合的内容。
- 密集检索(如向量检索):将查询和文档编码为向量,计算其余弦相似度,能够发现语义相关但词面不一致的文档。
- 加权融合:可指定稀疏分数、密集分数的权重,进行归一化加权合并,得到最终的融合分数并排序。
这样,混合检索能够兼顾传统信息检索和语义检索的优势,大幅提升复杂查询下的召回和精度,是现代RAG系统中非常推荐的一种检索方案。
HybridRetrieval.py
# 导入List、Dict、Any、Optional类型用于类型注解
from typing import List, Dict, Any, Optional
# 导入pydantic库中的ConfigDict,用于Pydantic模型配置
from pydantic import ConfigDict
# 导入langchain核心模块的基础检索器基类
from langchain_core.retrievers import BaseRetriever
# 导入langchain核心模块的文档对象类型
from langchain_core.documents import Document
# 导入langchain核心模块的基础LLM类型
from langchain_core.language_models import BaseLanguageModel
# 导入langchain核心模块的提示词模板类
from langchain_core.prompts import PromptTemplate
# 导入自定义llm对象
from llm import llm
# 导入自定义获取向量存储的函数
from vector_store import get_vector_store
# 导入自定义嵌入对象
from embeddings import embeddings
# 导入collections模块中的Counter用于词频统计
from collections import Counter
# 导入math模块用于数学计算
import math
# 定义BM25稀疏检索器
class BM25Retriever:
"""BM25稀疏检索器:基于关键词匹配的检索"""
# 初始化BM25检索器,传入文档列表
def __init__(self, documents: List[str]):
self.documents = documents
# 对每个文档进行分词
self.tokenized_docs = [self._tokenize(doc) for doc in documents]
# 构建BM25索引
self.bm25_index = self._build_bm25_index()
# 文本分词函数
def _tokenize(self, text: str) -> List[str]:
"""简单分词(实际应用中应使用jieba等专业分词工具)"""
# 按正则表达式分隔出单词字符,全部转小写
import re
tokens = re.findall(r'\w+', text.lower())
return tokens
# 构建BM25索引,包括文档频率、平均文档长度、文档数
def _build_bm25_index(self):
"""构建BM25索引"""
# 统计文档频率(DF)
doc_freq = Counter()
for doc_tokens in self.tokenized_docs:
unique_tokens = set(doc_tokens)
for token in unique_tokens:
doc_freq[token] += 1
# 计算所有文档长度
doc_lengths = [len(doc) for doc in self.tokenized_docs]
# 计算平均文档长度
avg_doc_length = sum(doc_lengths) / len(doc_lengths) if doc_lengths else 0
# 返回字典结构包含所有必要统计量
return {
"doc_freq": doc_freq,
"avg_doc_length": avg_doc_length,
"total_docs": len(self.documents)
}
# 计算逆文档频率IDF
def _calculate_idf(self, term: str) -> float:
"""计算逆文档频率(IDF)"""
df = self.bm25_index["doc_freq"].get(term, 0)
if df == 0:
return 0
total_docs = self.bm25_index["total_docs"]
# 按照BM25标准公式计算IDF
return math.log((total_docs - df + 0.5) / (df + 0.5) + 1.0)
# 计算BM25分数(query与某文档之间的得分)
def _calculate_bm25_score(self, query_tokens: List[str], doc_tokens: List[str], doc_idx: int) -> float:
"""计算BM25分数"""
k1 = 1.5 # BM25参数k1
b = 0.75 # BM25参数b
score = 0.0
# 获取文档长度
doc_length = len(doc_tokens)
# 获取语料平均文档长度
avg_doc_length = self.bm25_index["avg_doc_length"]
# 统计题词频
term_freq = Counter(doc_tokens)
# 对每个query中的词,累加BM25得分
for term in query_tokens:
if term not in term_freq:
continue
# 计算IDF
idf = self._calculate_idf(term)
# 计算词频相关系数TF部分
f = term_freq[term]
numerator = f * (k1 + 1)
denominator = f + k1 * (1 - b + b * (doc_length / avg_doc_length))
tf = numerator / denominator if denominator != 0 else 0
# 把该词的BM25贡献值加入总分
score += idf * tf
return score
# 外部接口,执行检索
def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
"""执行BM25检索"""
# 对输入问题分词
query_tokens = self._tokenize(query)
# 依次对所有文档计算BM25分数
scores = []
for idx, doc_tokens in enumerate(self.tokenized_docs):
score = self._calculate_bm25_score(query_tokens, doc_tokens, idx)
# 只保留分数大于0的文档
if score > 0:
scores.append({
"document": self.documents[idx],
"score": score,
"index": idx
})
# 按BM25分数降序排
scores.sort(key=lambda x: x["score"], reverse=True)
# 取前k条返回
return scores[:k]
# 定义密集检索器(向量检索)
class DenseRetriever:
"""密集检索器:基于向量相似度的检索"""
# 初始化传入向量库对象
def __init__(self, vector_store: Any):
self.vector_store = vector_store
# 外部接口,执行向量检索
def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
"""执行密集向量检索"""
# 使用向量库的similarity_search_with_score方法做相关性检索
docs = self.vector_store.similarity_search_with_score(query, k=k)
results = []
# 将检索结果组装为字典形式,包含内容、得分和元信息
for doc, distance in docs:
results.append({
"document": doc.page_content,
"score": float(distance),
"metadata": doc.metadata
})
return results
# 定义混合检索器
class HybridRetriever(BaseRetriever):
"""混合检索器:结合稀疏检索和密集检索"""
# 向量库对象
vector_store: Any
# BM25检索器
bm25_retriever: BM25Retriever
# 密集检索器
dense_retriever: DenseRetriever
# 密集检索权重(默认0.7)
dense_weight: float = 0.7
# 稀疏检索权重(默认0.3)
sparse_weight: float = 0.3
# 导数排名算法的k值(默认60)
rank_k: int = 60
# 检索文档数量
k: int = 4
# 配置模型参数,允许任意类型
model_config = ConfigDict(arbitrary_types_allowed=True)
# 排名变换——根据排名算导数分数,用于后续融合排序
def _calculate_rank_score(self, rank: int) -> float:
"""
计算加权导数排名分数
公式:分数 = 1 / (k + rank)
其中:rank为排名位置(从1开始),k为经验常数(通常取60)
"""
return 1.0 / (self.rank_k + rank)
# 融合稀疏检索和密集检索结果
def _fuse_results(self, sparse_results: List[Dict], dense_results: List[Dict]) -> List[Document]:
"""
融合稀疏检索和密集检索的结果
算法:
1. 对每个检索方法的结果,使用导数排名算法计算分数
2. 按权重加权融合
3. 按融合分数排序
"""
# 建立一个文档内容字符串到其稀疏/密集分数映射
doc_scores: Dict[str, Dict[str, float]] = {}
# 遍历稀疏检索结果,标记其排名分数
for rank, result in enumerate(sparse_results, 1):
doc = result["document"]
rank_score = self._calculate_rank_score(rank)
if doc not in doc_scores:
doc_scores[doc] = {"sparse_score": 0.0, "dense_score": 0.0}
doc_scores[doc]["sparse_score"] = rank_score
# 遍历密集检索结果,标记其排名分数
for rank, result in enumerate(dense_results, 1):
doc = result["document"]
rank_score = self._calculate_rank_score(rank)
if doc not in doc_scores:
doc_scores[doc] = {"sparse_score": 0.0, "dense_score": 0.0}
doc_scores[doc]["dense_score"] = rank_score
# 计算融合分数并组装结果
fused_results = []
for doc, scores in doc_scores.items():
# 融合分数 = 稀疏分数 * 稀疏权重 + 密集分数 * 密集权重
fused_score = scores["sparse_score"] * self.sparse_weight + scores["dense_score"] * self.dense_weight
fused_results.append({
"document": doc,
"fused_score": fused_score,
"sparse_score": scores["sparse_score"],
"dense_score": scores["dense_score"]
})
# 按融合分数排序
fused_results.sort(key=lambda x: x["fused_score"], reverse=True)
# 每个结果转换成Document对象并附带元数据
result_docs = []
for result in fused_results:
result_docs.append(Document(
page_content=result["document"],
metadata={
"fused_score": result["fused_score"],
"sparse_score": result["sparse_score"],
"dense_score": result["dense_score"],
"retrieval_method": "hybrid"
}
))
return result_docs
# 检索入口,返回最相关的k个文档
def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
"""执行混合检索"""
# 若外部未传k,则用默认的k
k = kwargs.get("k", self.k)
# 兼容融合,检索稀疏和密集通路各返回2k个(冗余用于融合)
sparse_results = self.bm25_retriever.retrieve(query, k=k * 2)
dense_results = self.dense_retriever.retrieve(query, k=k * 2)
# 融合排序
fused_docs = self._fuse_results(sparse_results, dense_results)
# 返回Top-k
return fused_docs[:k]
# 定义混合Retrieval-Augmented Generation系统
class HybridRetrievalRAG:
"""混合检索RAG系统:结合稀疏检索和密集检索"""
# 初始化传入混合检索器、LLM对象、可选自定义模板
def __init__(self, retriever: HybridRetriever, llm: BaseLanguageModel, answer_prompt_template: Optional[str] = None):
self.retriever = retriever
self.llm = llm
# 默认答题模板
template = answer_prompt_template or (
"已知以下相关信息(通过混合检索技术获取):\n\n{context}\n\n"
"用户查询:{query}\n\n"
"请根据上述信息,准确全面地回答用户问题。\n\n答案:"
)
# 用PromptTemplate封装模板
self.answer_prompt_template = PromptTemplate(
input_variables=["context", "query"], template=template
)
# 外部接口,根据query生成答案
def generate_answer(self, query: str, k: int = 4) -> Dict[str, Any]:
"""生成答案"""
# 执行混合检索
docs = self.retriever._get_relevant_documents(query, k=k)
# 拼接检索到的文档作为上下文
context = "\n\n".join([f"文档 {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
# 用模板生成prompt
prompt = self.answer_prompt_template.format(context=context, query=query)
# 调用llm生成答案
answer = self.llm.invoke(prompt).content.strip()
# 返回包含查询、答案、检索到的文档等信息的字典
return {
"query": query,
"answer": answer,
"retrieved_documents": docs,
"num_documents": len(docs)
}
# 初始化组件部分
# 创建向量库实例,设置存储目录和集合名
vector_store = get_vector_store(
persist_directory="chroma_db",
collection_name="hybrid_retrieval"
)
# 示例文档列表,包括AI、区块链、量子计算相关内容
documents = [
"人工智能(AI)是计算机科学的一个分支,旨在创建能够执行通常需要人类智能的任务的系统。"
"机器学习是人工智能的核心技术之一,它使计算机能够从数据中学习,而无需明确编程。"
"深度学习是机器学习的一个子集,使用人工神经网络来模拟人脑的工作方式。",
"区块链技术是一种分布式账本技术,通过密码学方法确保数据的安全性和不可篡改性。"
"比特币是第一个成功应用区块链技术的加密货币,它解决了数字货币的双重支付问题。"
"以太坊是一个支持智能合约的区块链平台,允许开发者在其上构建去中心化应用。",
"量子计算是一种利用量子力学现象进行计算的新兴技术,具有巨大的计算潜力。"
"量子比特(qubit)是量子计算的基本单位,与经典比特不同,它可以同时处于0和1的叠加态。"
"量子纠缠是量子计算的关键特性,允许量子比特之间建立特殊的关联关系。",
]
# 文档添加入向量库,并指定元数据
print("正在初始化知识库...")
vector_store.add_texts(
documents,
metadatas=[{"topic": "人工智能"}, {"topic": "区块链"}, {"topic": "量子计算"}]
)
# 创建BM25稀疏检索器
bm25_retriever = BM25Retriever(documents)
# 创建密集检索器
dense_retriever = DenseRetriever(vector_store)
# 创建混合检索器,并指定各类检索、权重及参数
hybrid_retriever = HybridRetriever(
vector_store=vector_store,
bm25_retriever=bm25_retriever,
dense_retriever=dense_retriever,
dense_weight=0.7,
sparse_weight=0.3,
rank_k=60,
k=3
)
# 创建混合检索RAG系统
rag_system = HybridRetrievalRAG(retriever=hybrid_retriever, llm=llm)
# 打印知识库初始化完毕提示
print("知识库初始化完成!\n")
# 分割线,演示示例流程
print("="*60)
print("混合检索示例")
print("="*60)
# 示例查询
query = "什么是机器学习?"
# 执行混合检索并生成答案
result = rag_system.generate_answer(query, k=3)
# 打印用户查询内容
print(f"\n用户查询:{result['query']}")
# 打印最终生成的答案
print(f"\n生成的答案:\n{result['answer']}")
# 打印检索到的文档数量
print(f"\n检索到的文档数量:{result['num_documents']}")
# 打印检索到的文档详情,包括融合分数
print(f"\n检索到的文档详情(包含融合分数):")
for i, doc in enumerate(result['retrieved_documents'], 1):
print(f"\n文档 {i}:")
print(f" 融合分数:{doc.metadata.get('fused_score', 'N/A'):.6f}")
print(f" 稀疏分数:{doc.metadata.get('sparse_score', 'N/A'):.6f}")
print(f" 密集分数:{doc.metadata.get('dense_score', 'N/A'):.6f}")
print(f" 内容:{doc.page_content[:150]}...")3.文档重排序(Document Reranking) #
文档重排序(Document Reranking)是一种在检索流程中提升最终结果相关性的强有力方法。其基本思想是:
- 先用普通检索算法(如向量检索、关键词检索)选出一批初步候选文档;
- 然后利用更强的相关性判别模型(如Cross-Encoder、基于大模型的匹配评分)对每对“查询-文档”进行相关性精细打分,重新排序,最终取分数最高的文档作为RAG答案生成的依据。
核心过程
- 候选召回(Recall): 首先使用高效的向量检索,从知识库中召回一批与查询初步相关的文档,通常数量较多(如10-20条)。
- 重排序(Rerank): 对每一个“查询-文档”对,输入到更复杂的模型(如跨编码器Cross-Encoder,或用LLM直接相关性评分提示词)中,让模型给出0-1之间的相关性分数。
- 选择最优(Select Top-K): 按分数从高到低排序,只保留最相关的前几条(如Top-3/Top-4)进入大模型生成环节。
实现方式 常见的重排序技术包括:
- Cross-Encoder类模型: 输入“[查询][文档]”对,模型联合判断相关性,效果优于单独编码/余弦计算。
- 大语言模型提示词方式: 用Prompt提示LLM扮演判分员角色,让模型直接打相关性分(见下方代码中的
PromptTemplate与llm.invoke用法)。 - BERT/RoBERTa/SimCSE等模型微调: 针对“查询-文档相关性判别”任务训练得到的判别模型。
优势
- 能有效过滤掉表面相似但实际不相关的文档(如只包含部分关键词但含义不符的干扰文档)。
- 极大改善“向量检索+生成”RAG中的答案准确度、可控性和鲁棒性。
- 通常只需重排序最相关的少量文档,计算代价可控。
应用场景
- 高精度问答/助理型RAG系统
- 法律、医疗、金融等对精准文档匹配强依赖的领域
- 想要摆脱“伪召回”“语义漂移”困扰的知识检索生成场景
DocumentReranking.py
# 导入List、Dict、Any、Optional这些类型,用于类型注解
from typing import List, Dict, Any, Optional
# 导入ConfigDict,Pydantic配置专用
from pydantic import ConfigDict
# 导入基础的检索器基类
from langchain_core.retrievers import BaseRetriever
# 导入Document文档类
from langchain_core.documents import Document
# 导入基础LLM类型
from langchain_core.language_models import BaseLanguageModel
# 导入PromptTemplate提示模版类
from langchain_core.prompts import PromptTemplate
# 导入自定义llm对象
from llm import llm
# 导入自定义的get_vector_store函数
from vector_store import get_vector_store
# 导入自定义的embeddings对象
from embeddings import embeddings
# 定义文档重排序器类,基于Cross-Encoder结构实现
class DocumentReranker:
"""文档重排序器:使用Cross-Encoder架构评估查询-文档相关性"""
# 初始化方法,指定llm和最大长度
def __init__(self, llm: BaseLanguageModel, max_length: int = 512):
# 存储大语言模型对象
self.llm = llm
# 存储最大文本长度
self.max_length = max_length
# 配置相关性分数提示模版
self.relevance_template = PromptTemplate(
input_variables=["query", "document"],
template=(
"请评估以下查询与文档的相关性,给出0-1之间的分数(1表示完全相关,0表示完全不相关)。\n\n"
"查询:{query}\n\n"
"文档:{document}\n\n"
"请只返回一个0-1之间的浮点数分数,不要其他内容:"
)
)
# 文本预处理辅助方法
def _preprocess_text(self, text: str) -> str:
"""文本预处理:标准化处理并控制长度"""
# 移除多余空白
text = " ".join(text.split())
# 限制最大字符数并加省略号
if len(text) > self.max_length:
text = text[:self.max_length] + "..."
return text
# 构建 [query, document] 对辅助方法
def _build_query_document_pairs(self, query: str, documents: List[str]) -> List[List[str]]:
"""构建查询-文档对:将用户查询与每个候选文档组成匹配对"""
# 预处理查询
processed_query = self._preprocess_text(query)
# 预处理所有文档
processed_docs = [self._preprocess_text(doc) for doc in documents]
pairs = []
# 遍历每篇文档,构成一对
for doc in processed_docs:
# 加入[query, document] 格式的对
pairs.append([processed_query, doc])
return pairs
# 计算相关性分数
def _calculate_relevance_score(self, query: str, document: str) -> float:
"""计算查询-文档对的相关性分数(使用LLM模拟Cross-Encoder)"""
# 生成提示
prompt = self.relevance_template.format(query=query, document=document)
# 调用llm获取分数文本
response = self.llm.invoke(prompt).content.strip()
# 尝试用正则提取分数字符串
try:
import re
# 匹配0-1之间的浮点数
score_match = re.search(r'0?\.\d+|1\.0?|0\.0?', response)
if score_match:
score = float(score_match.group())
# 限定在0-1范围之间
score = max(0.0, min(1.0, score))
return score
except:
pass
# 无法解析则返回默认分数0.5
return 0.5
# 主重排序函数
def rerank(self, query: str, documents: List[str], top_k: Optional[int] = None) -> List[Dict[str, Any]]:
"""
执行文档重排序
技术流程:
1. 数据处理阶段:文档对构建、文本预处理、长度控制
2. 模型推理阶段:分词处理、特征提取、相关性计算
3. 结果排序阶段:分数归一化、文档排序、质量筛选
"""
# 步骤1: 构建查询-文档对
query_doc_pairs = self._build_query_document_pairs(query, documents)
# 步骤2: 计算每个对的相关性分数
scored_docs = []
for i, (processed_query, processed_doc) in enumerate(query_doc_pairs):
score = self._calculate_relevance_score(processed_query, processed_doc)
scored_docs.append({
"document": documents[i],
"score": score,
"original_index": i
})
# 步骤3: 根据分数降序排序
scored_docs.sort(key=lambda x: x["score"], reverse=True)
# 如果指定了top_k,只保留前k个
if top_k is not None:
scored_docs = scored_docs[:top_k]
return scored_docs
# 定义带重排序的检索器类
class RerankingRetriever(BaseRetriever):
"""带重排序的检索器:先进行向量检索,再进行重排序"""
# 声明向量库对象
vector_store: Any
# 声明重排序器对象
reranker: DocumentReranker
# 初始化时检索更多文档以便后续重排序
initial_k: int = 10
# 最终返回文档数量
k: int = 4
# pydantic的配置,允许任意类型字段
model_config = ConfigDict(arbitrary_types_allowed=True)
# 检索和重排序的主方法
def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
"""执行检索和重排序"""
# 获取最终返回数量k(允许通过kwargs临时修改)
k = kwargs.get("k", self.k)
# 获取初始候选数量initial_k
initial_k = kwargs.get("initial_k", self.initial_k)
# 第一步:向量召回,得到多个候选文档
candidate_docs = self.vector_store.similarity_search_with_score(query, k=initial_k)
# 提取原文档内容
documents = [doc.page_content for doc, _ in candidate_docs]
# 第二步:对所有候选做重排序
reranked_results = self.reranker.rerank(query, documents, top_k=k)
# 转换回Document对象,同时加入重排序得分等元信息
result_docs = []
for result in reranked_results:
# 找到原始文档对象
original_doc = None
for doc, _ in candidate_docs:
if doc.page_content == result["document"]:
original_doc = doc
break
# 新建Document对象,合并元信息
result_docs.append(Document(
page_content=result["document"],
metadata={
**(original_doc.metadata if original_doc else {}),
"rerank_score": result["score"],
"original_index": result["original_index"],
"retrieval_method": "reranking"
}
))
return result_docs
# 定义整体的文档重排序RAG系统类
class DocumentRerankingRAG:
"""文档重排序RAG系统:结合向量检索和Cross-Encoder重排序"""
# 初始化,传入重排序检索器和llm
def __init__(self, retriever: RerankingRetriever, llm: BaseLanguageModel, answer_prompt_template: Optional[str] = None):
# 存储检索器
self.retriever = retriever
# 存储llm
self.llm = llm
# 配置自定义或默认的答案生成提示模版
template = answer_prompt_template or (
"已知以下相关信息(通过文档重排序技术筛选出的最相关文档):\n\n{context}\n\n"
"用户查询:{query}\n\n"
"注意:系统使用了Cross-Encoder重排序技术,能够更准确地评估查询与文档之间的相关性,"
"有效避免相似度陷阱问题。\n\n"
"请根据上述信息,准确全面地回答用户问题。\n\n答案:"
)
self.answer_prompt_template = PromptTemplate(
input_variables=["context", "query"], template=template
)
# 生成答案主函数
def generate_answer(self, query: str, k: int = 4) -> Dict[str, Any]:
"""生成答案"""
# 先检索与重排序
docs = self.retriever._get_relevant_documents(query, k=k)
# 构造上下文,插入每篇文档的内容及分数
context = "\n\n".join([
f"文档 {i+1}(重排序分数:{doc.metadata.get('rerank_score', 'N/A'):.4f}):\n{doc.page_content}"
for i, doc in enumerate(docs)
])
# 构造最终的提示
prompt = self.answer_prompt_template.format(context=context, query=query)
# 调用llm获取答案
answer = self.llm.invoke(prompt).content.strip()
# 返回带文档和答案的结构
return {
"query": query,
"answer": answer,
"retrieved_documents": docs,
"num_documents": len(docs)
}
# ===================== 初始化各组件 =======================
# 创建向量库对象
vector_store = get_vector_store(
persist_directory="chroma_db",
collection_name="document_reranking"
)
# 创建重排序器对象
reranker = DocumentReranker(llm, max_length=512)
# 创建带重排序的检索器对象
reranking_retriever = RerankingRetriever(
vector_store=vector_store,
reranker=reranker,
initial_k=10, # 初始检索10篇文档
k=4 # 重排序后输出4篇文档
)
# 实例化文档重排序RAG系统
rag_system = DocumentRerankingRAG(retriever=reranking_retriever, llm=llm)
# ================== 示例文档及知识库初始化 ===================
# 示例文档列表
documents = [
"人工智能(AI)是计算机科学的一个分支,旨在创建能够执行通常需要人类智能的任务的系统。"
"机器学习是人工智能的核心技术之一,它使计算机能够从数据中学习,而无需明确编程。"
"深度学习是机器学习的一个子集,使用人工神经网络来模拟人脑的工作方式。",
"区块链技术是一种分布式账本技术,通过密码学方法确保数据的安全性和不可篡改性。"
"比特币是第一个成功应用区块链技术的加密货币,它解决了数字货币的双重支付问题。"
"以太坊是一个支持智能合约的区块链平台,允许开发者在其上构建去中心化应用。",
"量子计算是一种利用量子力学现象进行计算的新兴技术,具有巨大的计算潜力。"
"量子比特(qubit)是量子计算的基本单位,与经典比特不同,它可以同时处于0和1的叠加态。"
"量子纠缠是量子计算的关键特性,允许量子比特之间建立特殊的关联关系。",
"自然语言处理(NLP)是人工智能的一个分支,专注于使计算机能够理解、解释和生成人类语言。"
"Transformer架构是NLP领域的重要突破,它使用自注意力机制来处理序列数据。"
"BERT和GPT是Transformer架构的两个重要应用,分别用于理解任务和生成任务。",
"计算机视觉是人工智能的另一个重要分支,专注于使计算机能够理解和分析视觉信息。"
"卷积神经网络(CNN)是计算机视觉的核心技术,特别适合处理图像数据。"
"图像分类、目标检测和图像分割是计算机视觉的三个主要任务。",
]
# 打印初始化提示
print("正在初始化知识库...")
# 向向量库中添加文档及元信息
vector_store.add_texts(
documents,
metadatas=[{"topic": "人工智能"}, {"topic": "区块链"}, {"topic": "量子计算"},
{"topic": "自然语言处理"}, {"topic": "计算机视觉"}]
)
print("知识库初始化完成!\n")
# =============== 测试整体流程 ================
# 输出分隔符及说明
print("="*60)
print("文档重排序示例")
print("="*60)
# 配置示例查询
query = "什么是机器学习?它和深度学习有什么关系?"
# 执行检索、重排序、答案生成
result = rag_system.generate_answer(query, k=4)
# 打印最终结果及详细推理
print(f"\n用户查询:{result['query']}")
print(f"\n生成的答案:\n{result['answer']}")
print(f"\n重排序后选择的文档数量:{result['num_documents']}")
print(f"\n重排序后的文档详情(按相关性分数排序):")
for i, doc in enumerate(result['retrieved_documents'], 1):
print(f"\n文档 {i}(原始索引:{doc.metadata.get('original_index', 'N/A')}):")
print(f" 重排序分数:{doc.metadata.get('rerank_score', 'N/A'):.4f}")
print(f" 内容:{doc.page_content[:150]}...")