50.as_retriever #
as_retriever 是 Chroma 向量数据库中的一个方法,用于将向量数据库包装为兼容 LangChain 检索接口(Retriever)的对象。这样就可以方便地与各类LangChain链路、Agent等组件进行拼接,实现如RAG、问答等复杂应用。
常见用法如下:
search_type:设定检索方式,如 "similarity"(相似度检索)、"mmr"(最大边际相关性),可根据实际场景选择。search_kwargs:用于传递检索参数,如返回文档条数k。
调用 as_retriever 后返回的对象支持直接 .get_relevant_documents(query) 方法,从而获取与 query 最相关的文本内容,便于下游调用。
该方法极大地简化了数据与检索逻辑的衔接,实现了与大模型问答等能力的无缝连接。
50.1. 53.as_retriever.py #
53.as_retriever.py
#from langchain_chroma import Chroma
#from langchain_huggingface import HuggingFaceEmbeddings
from smartchain.embeddings import HuggingFaceEmbeddings
from smartchain.vectorstores import Chroma
# 模型路径(本地)
model_path = "C:/Users/Administrator/.cache/modelscope/hub/models/sentence-transformers/all-MiniLM-L6-v2"
# 初始化嵌入模型
embeddings = HuggingFaceEmbeddings(
model_name=model_path,
model_kwargs={"device": "cpu"}
)
# 初始化 Chroma 向量数据库
chroma_db = Chroma(
persist_directory="chroma_database",
embedding_function=embeddings,
collection_name="test",
collection_metadata={"hnsw:space": "cosine"}
)
# 检查数据库是否已包含数据
if not chroma_db._collection.count():
# 待入库的文本
texts = [
"你好,世界!",
"人工智能非常有趣。",
"机器学习是人工智能的重要领域。",
"深度学习通过神经网络模拟人脑。",
"欢迎使用Chroma向量数据库。",
]
# 每条文本对应的元数据
metadatas = [
{"lang": "en", "category": "greeting"},
{"lang": "en", "category": "tech"},
{"lang": "zh", "category": "tech"},
{"lang": "en", "category": "tech"},
{"lang": "zh", "category": "demo"},
]
# 向数据库中批量添加文本及元信息
chroma_db.add_texts(texts, metadatas)
# 创建检索器,指定检索类型和返回条数
retriever = chroma_db.as_retriever(search_type="similarity", search_kwargs={"k": 2})
# 用中文/英文查询检索相关文档
results = retriever.invoke("什么是人工智能?") # 可替换为英文或中文查询
for i, doc in enumerate(results):
print(f"检索结果{i}:{doc.page_content}")50.2. init.py #
smartchain/retrievers/init.py
from .vector_store import VectorStoreRetriever
from .base import BaseRetriever
__all__ = [
"VectorStoreRetriever",
"BaseRetriever",
]50.3. base.py #
smartchain/retrievers/base.py
# 导入抽象基类ABC和抽象方法装饰器abstractmethod
from abc import ABC, abstractmethod
# 定义检索器抽象基类
class BaseRetriever(ABC):
# 检索器抽象基类说明文档字符串
"""检索器抽象基类
所有检索器都应该继承此类并实现 _get_relevant_documents 方法。
"""
# 构造方法,初始化检索器
def __init__(self, **kwargs):
# 初始化检索器,这里不做具体实现
pass
# 抽象方法,子类必须实现该方法用于获取相关文档
@abstractmethod
def _get_relevant_documents(self, query, **kwargs):
# 获取相关文档(抽象方法,由子类实现)
# 参数:
# query: 查询字符串
# **kwargs: 其他参数
# 返回:
# 相关文档列表
pass
# 对外接口方法,调用检索器获取相关文档
def invoke(self, query, **kwargs):
# 调用检索器获取相关文档
# 参数:
# query: 查询字符串
# **kwargs: 其他参数
# 返回:
# 相关文档列表
return self._get_relevant_documents(query, **kwargs)
50.4. vector_store.py #
smartchain/retrievers/vector_store.py
from .base import BaseRetriever
# 向量存储检索器类,继承自BaseRetriever
class VectorStoreRetriever(BaseRetriever):
# 向量存储检索器说明文档字符串
"""向量存储检索器
用于从向量存储中检索文档,支持多种搜索类型。
"""
# 允许的搜索类型枚举
ALLOWED_SEARCH_TYPES = ("similarity", "similarity_score_threshold", "mmr")
# 构造方法,初始化向量存储检索器
def __init__(
self,
vectorstore,
search_type="similarity",
search_kwargs=None,
**kwargs
):
# 初始化向量存储检索器
# 参数:
# vectorstore: 向量存储实例
# search_type: 搜索类型,可选值:'similarity', 'similarity_score_threshold', 'mmr'
# search_kwargs: 搜索参数,如 {"k": 4, "score_threshold": 0.8}
# **kwargs: 其他参数
super().__init__(**kwargs)
# 保存向量存储实例
self.vectorstore = vectorstore
# 保存搜索类型
self.search_type = search_type
# 保存搜索参数字典,若为None则赋空字典
self.search_kwargs = search_kwargs or {}
# 检查搜索类型是否合法
if search_type not in self.ALLOWED_SEARCH_TYPES:
# 抛出ValueError异常,提示搜索类型不允许
raise ValueError(
f"search_type '{search_type}' 不允许。"
f"允许的值: {self.ALLOWED_SEARCH_TYPES}"
)
# 若选择 similarity_score_threshold,必须提供 score_threshold 参数
if search_type == "similarity_score_threshold":
# 获取分数阈值
score_threshold = self.search_kwargs.get("score_threshold")
# 判断分数阈值是否为数值类型且非None
if score_threshold is None or not isinstance(score_threshold, (int, float)):
# 抛出异常提示必须提供score_threshold参数
raise ValueError(
"使用 'similarity_score_threshold' 搜索类型时,"
"必须在 search_kwargs 中提供 'score_threshold' (float, 0~1)"
)
# 获取相关文档主方法
def _get_relevant_documents(self, query, **kwargs):
# 获取相关文档
# 参数:
# query: 查询字符串
# **kwargs: 其他参数,将与search_kwargs合并
# 返回:
# 相关文档列表
# 合并父类构造时的search_kwargs与本次调用传入的kwargs
search_kwargs = {**self.search_kwargs, **kwargs}
# 根据当前search_type分支调用不同的检索方法
if self.search_type == "similarity":
# 相似度搜索,调用vectorstore的similarity_search方法
docs = self.vectorstore.similarity_search(query, **search_kwargs)
elif self.search_type == "similarity_score_threshold":
# 带分数阈值的相似度搜索,先获取(doc, score)对
docs_and_scores = self.vectorstore.similarity_search_with_score(
query, **search_kwargs
)
# 获取分数阈值(最大距离阈值)
score_threshold = search_kwargs.get("score_threshold", 0.0)
# 注意:Chroma返回的是距离,距离越小越相似
# 这里直接过滤距离大于阈值的文档,假设score_threshold是最大距离阈值
docs = [
doc for doc, score in docs_and_scores
if score <= score_threshold
]
elif self.search_type == "mmr":
# 最大边际相关性检索,调用vectorstore的max_marginal_relevance_search方法
docs = self.vectorstore.max_marginal_relevance_search(
query, **search_kwargs
)
else:
# 不支持的检索类型,抛出异常
raise ValueError(f"不支持的搜索类型: {self.search_type}")
# 返回检索到的相关文档
return docs50.5. vectorstores.py #
smartchain/vectorstores.py
import os
import numpy as np
from abc import ABC, abstractmethod
import faiss
import chromadb
import uuid
+from .retrievers import VectorStoreRetriever
# 计算一个向量与多个向量的余弦相似度
def cosine_similarity(from_vec, to_vecs):
# 将from_vec转换为numpy数组并且强制类型为float
from_vec = np.array(from_vec, dtype=float)
to_vecs = np.array(to_vecs, dtype=float)
# 计算from_vec模长
norm1 = np.linalg.norm(from_vec)
similarities = []
for to_vec in to_vecs:
dot_product = np.sum(from_vec * to_vec)
norm_vec = np.linalg.norm(to_vec)
similarity = dot_product / (norm1 * norm_vec)
similarities.append(similarity)
return np.array(similarities)
def mmr_select(query_vector, doc_vectors, k=3, lambda_mult=0.5):
# 计算每个文档向量与查询向量(Query)的余弦相似度。这代表了文档的“相关性”。
quer_similarities = cosine_similarity(query_vector, doc_vectors)
# 选择与查询向量相似度最高的文档:文档 1。
# 找到与查询向量最相关的文档的下标,作为初始的已选文档 S:选择的结果集=selected=[0]
selected = [int(np.argmax(quer_similarities))]
while len(selected) < k:
# 存放每个候选文档的MMR分数
mmr_scores = []
for i in range(len(doc_vectors)):
if i not in selected:
# 相关性,指的是i对应的候选文档和查询文档这间的相似性
relevance = quer_similarities[i]
# 获取当前已选文档的向量集合
selected_vecs = doc_vectors[selected] # - S:选择的结果集
# 计算当前文档与所有的已选文档的余弦相似度
sims = cosine_similarity(doc_vectors[i], selected_vecs)
# 获取对已选中的文档最最大相似度 最不多样性的那个
# 与已选节点有最大相似度的那个就是最不具有多样性的节点
max_sim = np.max(sims)
mmr_score = lambda_mult * relevance - (1 - lambda_mult) * max_sim
mmr_scores.append((i, mmr_score))
# 选出MMR分数最高的文档索引
best_idx, best_score = max(mmr_scores, key=lambda x: x[1])
# 将选中的文档添加到已选文档中
selected.append(best_idx)
return selected
# 定义向量存储的抽象的基类 faiss=chromdb
class VectorStore(ABC):
# 添加文本到向量存储
@abstractmethod
def add_texts(self, texts, metadatas=None):
pass
# 最大边际相关性检索
def max_marginal_relevance_search(self, query, k=3, fetch_k=20):
pass
# 从文本批量构建向量存储
@abstractmethod
def from_texts(self, texts, embeddings, metadatas=None):
pass
@abstractmethod
def similarity_search(self, query: str, k: int = 4):
"""相似度搜索"""
pass
class Document:
def __init__(self, page_content, metadata, embedding_value=None):
self.page_content = page_content # 文档内容
self.metadata = metadata or {} # 文档的元数据
self.embedding_value = embedding_value # 文档对应的嵌入向量
class FAISS(VectorStore):
def __init__(self, embeddings):
self.embeddings = embeddings
# 初始化索引为空
self.index = None
# 初始化文档字典,键为文档的ID,值为一个文档Document对象
self.documents_by_id = {}
def add_texts(self, texts, metadatas):
if metadatas is None:
metadatas = [{}] * len(texts)
embedding_values = self.embeddings.embed_documents(texts)
embedding_values = np.array(embedding_values, dtype=np.float32)
# 如果还没有建立FAISS索引库,则新建它
if self.index is None:
dimension = len(embedding_values[0])
self.index = faiss.IndexFlatL2(dimension)
# 添加嵌入向量到FAISS索引库中
self.index.add(embedding_values)
# 获取已知的文档数量,用于新文档的编号
start_index = len(self.documents_by_id)
for i, (text, metadata, embedding_value) in enumerate(
zip(texts, metadatas, embedding_values)
):
# 构建文档ID
doc_id = str(start_index + i)
# 构建文档对象
doc = Document(
page_content=text, metadata=metadata, embedding_value=embedding_value
)
# 保存到字典里
self.documents_by_id[doc_id] = doc
# 类方法,通过文本批量创建FAISS向量数据库的实例
@classmethod
def from_texts(cls, texts, embeddings, metadatas=None):
instance = cls(embeddings=embeddings)
instance.add_texts(texts, metadatas=metadatas)
return instance
def max_marginal_relevance_search(self, query, k, fetch_k, lambda_mult=0.5):
# 获取查询文本的嵌入向量
query_embedding = self.embeddings.embed_query(query)
# 转换为二维numpy数组
query_vectors = np.array([query_embedding], dtype=np.float32)
if isinstance(self.index, faiss.Index):
# 执行索引库中的检索,返回索引及距离
_, indices = self.index.search(query_vectors, fetch_k)
# 获取候选文档对应的索引列表
candidate_indices = indices[0]
else:
raise RuntimeError("FAISS 不可用")
# 如果说候选文档数不足K个,则直接返回这些文档,不需要再走MMR了
if len(candidate_indices) <= k:
docs = []
for idx in candidate_indices:
doc_id = str(idx)
if doc_id in self.documents_by_id:
docs.append(self.documents_by_id[doc_id])
return docs
# 从candidate_indices用MMR算法挑选出K个元素
# 从字典中提取候选文档的嵌入向量 5个
candidate_vectors = np.array(
[self.documents_by_id[str(i)].embedding_value for i in candidate_indices],
dtype=np.float32,
)
# 通过MMR算法获取MMR选出的下标
selected_indices = mmr_select(
query_embedding, candidate_vectors, k=k, lambda_mult=lambda_mult
)
# 根据下标选出最终的文档对象
docs = []
# 遍历选中的索引,这个索引candidate_indices列表中的索引
for idx in selected_indices:
# 获取真实的文档索引或者说文档ID
doc_id = str(candidate_indices[idx])
# 通过真实的文档ID找到对应的文档对象
if doc_id in self.documents_by_id:
docs.append(self.documents_by_id[doc_id])
return docs
# 定义相似度检索方法,返回与查询最近的k个文档
def similarity_search(self, query: str, k: int = 4):
"""
相似度搜索
Args:
query: 查询文本
k: 返回的文档数量
Returns:
List[Document]: 最相似的文档列表
"""
# 获取查询文本的嵌入向量
query_embedding = self.embeddings.embed_query(query)
# 将嵌入向量转换为NumPy二维数组(形状为1行,d维)
query_vector = np.array([query_embedding], dtype=np.float32)
# 用FAISS索引执行k近邻检索,得到距离最近的k个索引
_, indices = self.index.search(query_vector, k)
# 创建用于存放检索到文档对象的列表
docs = []
# 遍历返回的每个文档索引
for idx in indices[0]:
# 把数字索引转为字符串形式的文档id
doc_id = str(idx)
# 只有字典中存在这个id的文档才加入最终结果
if doc_id in self.documents_by_id:
docs.append(self.documents_by_id[doc_id])
# 返回最终的相似文档列表
return docs
class Chroma(VectorStore):
# 设置默认 的集合名称为langchain
LANCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
def __init__(
self,
collection_name,
embedding_function,
persist_directory,
collection_metadata,
):
self.embedding_function = embedding_function
self.collection_name = collection_name
self.collection_metadata = collection_metadata
if persist_directory is not None:
# 如果指定了持久化目录,则使用执行化客户端
self._client = chromadb.PersistentClient(path=persist_directory)
else:
# 使用默认的内存客户端
self._client = chromadb.Client()
# 使用客户端获取或新建集合
self._collection = self._client.get_or_create_collection(
name=self.collection_name, metadata=self.collection_metadata
)
def add_texts(self, texts, metadatas, ids=None, **kwargs):
if ids is None:
ids = [str(uuid.uuid4()) for _ in texts]
else:
ids = [str(_id) if _id is not None else str(uuid.uuid4()) for _id in ids]
if metadatas is None:
metadatas = [{}] * len(texts)
else:
# 如果元数据数量少于文本数量,则补齐元数据
length_diff = len(texts) - len(metadatas)
if length_diff > 0:
metadatas = metadatas + [{}] * length_diff
embedding_values = self.embedding_function.embed_documents(texts)
# 将文本 ID 嵌入 添加到集合中 upsert=update +insert
self._collection.upsert(
ids=ids, documents=texts, embeddings=embedding_values, metadatas=metadatas
)
return ids
# 类方法,通过文本批量创建FAISS向量数据库的实例
@classmethod
def from_texts(
cls,
texts,
embedding_function,
metadatas,
collection_name,
persist_directory,
collection_metadata,
):
instance = cls(
collection_name=collection_name,
embedding_function=embedding_function,
persist_directory=persist_directory,
collection_metadata=collection_metadata,
)
instance.add_texts(texts, metadatas=metadatas)
return instance
# 定义相似度检索方法,返回与查询最近的k个文档
def similarity_search(self, query: str, k: int = 4, filter=None, **kwargs):
docs_and_scores = self.similarity_search_with_score(
query=query, k=k, filter=filter, **kwargs
)
return [doc for doc, _ in docs_and_scores]
def _results_to_docs_and_scores(self, results):
docs_and_scores = []
if not results.get("ids") or not results["ids"][0]:
return docs_and_scores
ids = results["ids"][0]
documents = results["documents"][0]
metadatas = results["metadatas"][0]
distances = results["distances"][0]
for i, doc_id in enumerate(ids):
if documents[i] is not None:
doc = Document(
page_content=documents[i],
metadata=(
metadatas[i] if i < len(metadatas) and metadatas[i] else {}
),
)
score = distances[i] if i < len(distances) else 0.0
docs_and_scores.append((doc, score))
return docs_and_scores
# 处理chroma检索结果,转为(doc,score)元组列表
def similarity_search_with_score(
self, query: str, k: int = 4, filter=None, **kwargs
):
# query_embedding是一个向量,当然一个向量就是一个小数的列表 [0.1,0.2] 一维列表
query_embedding = self.embedding_function.embed_query(query)
if query_embedding is not None:
results = self._collection.query(
query_embeddings=[query_embedding], n_results=k, where=filter
)
else:
results = self._collection.query(
query_texts=[query], n_results=k, where=filter
)
return self._results_to_docs_and_scores(results)
+ def as_retriever(self, search_type="similarity", search_kwargs=None, **kwargs):
# """
# 创建向量存储检索器
#
# 参数:
# search_type: 搜索类型,可选值:
# - "similarity": 相似度搜索(默认)
# - "similarity_score_threshold": 带分数阈值的相似度搜索
# - "mmr": 最大边际相关性搜索
# search_kwargs: 搜索参数,例如:
# - {"k": 4}: 返回4个文档
# - {"score_threshold": 0.8}: 相似度阈值(用于 similarity_score_threshold)
# - {"fetch_k": 20, "lambda_mult": 0.5}: MMR参数
# **kwargs: 其他附加参数
#
# 返回:
# VectorStoreRetriever: 检索器实例
# """
# 如果没有提供search_kwargs则默认使用空字典
+ return VectorStoreRetriever(
+ vectorstore=self, # 传入当前vectorstore实例
+ search_type=search_type, # 搜索类型
+ search_kwargs=search_kwargs or {}, # 搜索参数字典,默认为空字典
+ **kwargs # 其他命名参数
+ )
50.6. 类 #
50.6.1 类说明 #
| 类名 | 类型 | 作用 | 关键属性/方法 |
|---|---|---|---|
| Document | 数据类 | 存储文档内容和元数据 | page_content: 文档文本内容metadata: 元数据字典 |
| HuggingFaceEmbeddings | 嵌入模型类 | 使用 HuggingFace 模型生成文本嵌入向量 | model_name: 模型名称embed_query(): 嵌入单个查询文本embed_documents(): 批量嵌入文档文本 |
| Chroma | 向量存储类 | Chroma 向量数据库封装 | embedding_function: 嵌入函数_collection: Chroma 集合_client: Chroma 客户端add_texts(): 添加文本到数据库similarity_search(): 相似度搜索as_retriever(): 创建检索器 |
| VectorStoreRetriever | 检索器类 | 从向量存储中检索文档 | vectorstore: 向量存储实例search_type: 搜索类型search_kwargs: 搜索参数invoke(): 执行检索_get_relevant_documents(): 核心检索逻辑 |
| BaseRetriever | 抽象基类 | 定义检索器接口规范 | invoke(): 对外调用接口_get_relevant_documents(): 抽象方法(需子类实现) |
| VectorStore | 抽象基类 | 定义向量存储接口规范 | add_texts(): 抽象方法similarity_search(): 抽象方法 |
50.6.2 类图 #
50.6.3 时序图 #
50.6.4 调用过程 #
阶段 1:初始化嵌入模型
embeddings = HuggingFaceEmbeddings(
model_name=model_path,
model_kwargs={"device": "cpu"}
)- 创建
HuggingFaceEmbeddings实例 - 指定本地模型路径和设备(CPU)
- 用于后续文本向量化
阶段 2:初始化向量数据库
chroma_db = Chroma(
persist_directory="chroma_database",
embedding_function=embeddings,
collection_name="test",
collection_metadata={"hnsw:space": "cosine"}
)调用链:
__init__()方法- 保存
embedding_function、collection_name、collection_metadata - 创建 ChromaDB 客户端(持久化或内存)
- 获取或创建集合(collection)
- 保存
阶段 3:检查并添加数据
if not chroma_db._collection.count():
chroma_db.add_texts(texts, metadatas)调用链:
_collection.count()检查集合是否为空add_texts()方法(如果为空)- 为每个文本生成 UUID 作为 ID
- 补齐元数据(如果数量不足)
- 调用
embedding_function.embed_documents(texts)批量计算嵌入向量 - 调用
_collection.upsert()将文本、嵌入向量和元数据存入数据库
阶段 4:创建检索器
retriever = chroma_db.as_retriever(search_type="similarity", search_kwargs={"k": 2})调用链:
as_retriever()方法- 创建
VectorStoreRetriever实例 - 传入当前
Chroma实例、搜索类型和搜索参数
- 创建
VectorStoreRetriever.__init__()方法- 调用
super().__init__()初始化基类 - 保存
vectorstore、search_type、search_kwargs - 验证搜索类型是否合法
- 调用
阶段 5:执行检索
results = retriever.invoke("什么是人工智能?")调用链:
invoke()方法(继承自BaseRetriever)- 调用
_get_relevant_documents(query)
- 调用
_get_relevant_documents()方法(核心检索逻辑)根据
search_type选择检索方式:similarity 模式(本例使用):
docs = self.vectorstore.similarity_search(query, **search_kwargs)内部调用链:
similarity_search()→similarity_search_with_score()embed_query(query)计算查询嵌入向量_collection.query()在 ChromaDB 中查询相似文档_results_to_docs_and_scores()将结果转换为 Document 对象列表- 返回 Document 列表
阶段 6:输出结果
for i, doc in enumerate(results):
print(f"检索结果{i}:{doc.page_content}")- 遍历
results(Document列表) - 打印每个文档的
page_content
50.6.5 搜索类型说明 #
VectorStoreRetriever 支持三种搜索类型:
| 搜索类型 | 说明 | 参数 |
|---|---|---|
| similarity | 相似度搜索(默认) | {"k": 4} - 返回前k个最相似的文档 |
| similarity_score_threshold | 带分数阈值的相似度搜索 | {"k": 4, "score_threshold": 0.8} - 返回相似度大于阈值的文档 |
| mmr | 最大边际相关性搜索 | {"k": 4, "fetch_k": 20, "lambda_mult": 0.5} - 平衡相关性和多样性 |
51.TFIDFRetriever #
TFIDFRetriever 是一种基于传统文本信息检索技术(词频-逆文档频率,Term Frequency-Inverse Document Frequency, 简称 TF-IDF)的检索器,无需依赖深度学习模型或嵌入模型,即可对一批文本文档实现高效的关键词相关性检索。其核心优势在于不需要下载任何模型,速度快,适用于基于关键词匹配的小型场景或无 GPU、无法联网时的轻量化检索需求。
工作机制
- 对所有文档内容进行分词(中英文均支持),统计每个词的词频(TF)和逆文档频率(IDF)。
- 输入查询时,将查询内容与所有文档分别计算 tf-idf 向量。
- 通过计算查询和文档的 tf-idf 相似度(通常为余弦相似度)进行排序,返回得分最高的若干文档。
典型应用场合
- 数据量较小,或本地调试时的快速检索。
- 没有嵌入模型,或不适合向量检索的关键词场合。
- 文本为中文、英文或多语言均可,中文采用简易分词算法。
- 当 VectorStore、Embeddings 暂不可用时的备选方案。
接口一览
TFIDFRetriever.from_documents(docs):根据输入的 Document 列表构建检索器,并自动计算 tf-idf 索引。retriever.invoke(query, *, k=4):根据用户输入的查询内容,返回相关性最高的前 k 条 Document(默认通常返回全部排序后的文档)。- 支持自定义分词、去停用词、返回数量等参数(参见 smartchain/retrievers/TFIDFRetriever 代码)。
注意事项
- 适合结构较为规范、文本内容不长的场景。
- 对于复杂语义、同义词、语序变化的检索,TFIDF 通常不如深度语义模型,但响应速度极快。
- 中文文本依赖内置分词算法,无需额外安装分词包。
51.1. 51.TFIDFRetriever.py #
51.TFIDFRetriever.py
#from langchain_community.retrievers import TFIDFRetriever
#from langchain_core.documents import Document
# 导入自定义Document类
from smartchain.documents import Document
# 导入自定义的TFIDFRetriever检索器
from smartchain.retrievers import TFIDFRetriever
# 构造一个包含4条文本的示例文档列表
docs = [
Document(page_content="深度学习通过神经网络取得突破。"),
Document(page_content="机器人学结合机械工程与人工智能。"),
Document(page_content="美食点评:这家餐厅的川菜很正宗。"),
Document(page_content="人工智能正推动机器人自主学习与创新。"),
]
# 用文档列表创建一个TFIDF检索器实例
retriever = TFIDFRetriever.from_documents(docs)
# 设置检索的查询语句
query = "机器人与人工智能"
# 调用检索器进行查询,返回与query最相关的2条文档
results = retriever.invoke(query,k=2)
# 打印查询内容
print(f"查询:{query}")
# 打印返回的文档数量
print(f"返回 {len(results)} 条:")
# 遍历输出每条检索结果的内容
for i, doc in enumerate(results, 1):
print(f"{i}. {doc.page_content}")51.2. tfidf.py #
smartchain/retrievers/tfidf.py
# 导入基础检索器基类
from smartchain.retrievers.base import BaseRetriever
# uv add scikit-learn
# 导入TF-IDF向量化器
from sklearn.feature_extraction.text import TfidfVectorizer
# 导入余弦相似度计算函数
from sklearn.metrics.pairwise import cosine_similarity
# 导入numpy库用于数组处理
import numpy as np
# 导入jieba用于中文分词
import jieba
import re
# 定义中文分词函数,作为tokenizer传递给TfidfVectorizer
def _chinese_tokenizer(text):
# 使用jieba的搜索引擎模式进行分词,会对长词再次切分,提高召回率
tokens = list(jieba.cut_for_search(text))
# 过滤掉标点符号、空白字符和单字符词,只保留有效词
tokens = [t for t in tokens if re.match(r'^[\u4e00-\u9fa5a-zA-Z0-9]+$', t) and len(t) > 1]
# 用空格拼接成字符串返回
return " ".join(tokens)
# 定义TFIDFRetriever检索器,继承自BaseRetriever
class TFIDFRetriever(BaseRetriever):
# TFIDF 检索器说明文档字符串
"""
TFIDF 检索器
使用 TFIDF (Term Frequency-Inverse Document Frequency) 算法进行文档检索。
通过计算查询文本与文档集合的 TFIDF 向量,使用余弦相似度找到最相关的文档。
"""
# 初始化方法
def __init__(self, vectorizer=None, documents=None, **kwargs):
"""
初始化 TFIDF 检索器
参数:
vectorizer: TfidfVectorizer 实例,如果为 None 则创建默认实例
documents: Document 列表,用于训练向量化器
**kwargs: 传递给 TfidfVectorizer 的其他参数
"""
# 调用父类初始化方法,传递可能的关键字参数
super().__init__(**kwargs)
# 默认采用中文tokenizer,如果未手动指定,则设置为_chinese_tokenizer
if 'tokenizer' not in kwargs:
kwargs['tokenizer'] = _chinese_tokenizer
# 初始化TfidfVectorizer
self.vectorizer = TfidfVectorizer(**kwargs)
# 保存文档列表,如果为None则赋值为空列表
self.documents = documents or []
# 如果文档列表不为空,则训练向量化器
if self.documents:
self._fit_vectorizer()
# 训练向量化器并生成所有文档TF-IDF向量
def _fit_vectorizer(self):
"""训练 TFIDF 向量化器并计算所有文档的向量"""
# 提取所有文档的内容,组成文本列表
texts = [doc.page_content for doc in self.documents]
# 用所有文档内容训练TF-IDF向量化器,像学习字典,先看所有文档,建立词汇表和规则
self.vectorizer.fit(texts)
# 将所有文档内容转换为TF-IDF向量矩阵,像查字典,根据词汇表和规则,将文档转为向量
self.document_vectors = self.vectorizer.transform(texts)
# 通过类方法根据文档列表创建检索器对象
@classmethod
def from_documents(cls, documents, **kwargs):
"""
从文档列表创建 TFIDFRetriever 实例
参数:
documents: Document 列表
**kwargs: 其他参数,可传递给 TfidfVectorizer
返回:
TFIDFRetriever 实例
"""
# 创建检索器实例,传递文档列表及其他参数
instance = cls(documents=documents, **kwargs)
return instance
# 主体检索方法,返回与查询相关的文档
def _get_relevant_documents(self, query, k=4, **kwargs):
"""
获取相关文档
参数:
query: 查询字符串
k: 返回的文档数量,默认 4
**kwargs: 其他参数
返回:
Document 列表,按相似度从高到低排序
"""
# 如果没有文档,则直接返回空列表
if not self.documents:
return []
# 对查询进行分词,用于检查是否有匹配的词
query_tokens = set(_chinese_tokenizer(query).split())
# 将查询转为向量
query_vector = self.vectorizer.transform([query])
# 计算查询向量和所有文档向量的余弦相似度
# 取相似度矩阵的第一行(唯一一行),得到所有文档与查询的相似度分数
similarities = cosine_similarity(query_vector, self.document_vectors)[0]
# 检查每个文档是否包含至少一个查询词
doc_has_query_word = []
for doc in self.documents:
doc_tokens = set(_chinese_tokenizer(doc.page_content).split())
# 文档必须包含至少一个查询词,且相似度大于0
doc_has_query_word.append(len(query_tokens & doc_tokens) > 0)
# 获取相似度排序后的索引(降序)
sorted_indices = np.argsort(similarities)[::-1]
# 只保留相似度大于0且包含至少一个查询词的文档
valid_indices = [i for i in sorted_indices if similarities[i] > 0 and doc_has_query_word[i]]
# 取前k个
top_indices = valid_indices[:k]
# 返回对应的Document对象列表
return [self.documents[i] for i in top_indices]
# 对外调用的检索接口
def invoke(self, query, **kwargs):
"""
调用检索器获取相关文档
参数:
query: 查询字符串
**kwargs: 其他参数,如 k(返回文档数量)
返回:
Document 列表
"""
# 调用内部实际检索函数
return self._get_relevant_documents(query, **kwargs)51.3. init.py #
smartchain/retrievers/init.py
from .vector_store import VectorStoreRetriever
from .base import BaseRetriever
+from .tfidf import TFIDFRetriever
__all__ = [
"VectorStoreRetriever",
"BaseRetriever",
+ "TFIDFRetriever"
]51.4. 类 #
51.4.1 类说明 #
| 类名 | 类型 | 作用 | 关键属性/方法 |
|---|---|---|---|
| Document | 数据类 | 存储文档内容和元数据 | page_content: 文本内容metadata: 元数据字典 |
| BaseRetriever | 抽象基类 | 定义检索器接口规范 | invoke(): 对外调用接口_get_relevant_documents(): 抽象方法(需子类实现) |
| TFIDFRetriever | 具体实现类 | 实现基于 TF-IDF 的文档检索 | vectorizer: TfidfVectorizer 向量化器documents: 文档列表document_vectors: 文档向量矩阵from_documents(): 类方法创建实例invoke(): 执行检索_get_relevant_documents(): 核心检索逻辑_fit_vectorizer(): 训练向量化器 |
| _chinese_tokenizer | 函数 | 中文分词预处理 | 使用 jieba 搜索引擎模式分词 过滤标点和单字符词 |
51.4.2 类图 #
51.4.3 时序图 #
51.4.4 调用过程 #
阶段 1:文档对象创建
docs = [
Document(page_content="深度学习通过神经网络取得突破。"),
Document(page_content="机器人学结合机械工程与人工智能。"),
Document(page_content="美食点评:这家餐厅的川菜很正宗。"),
Document(page_content="人工智能正推动机器人自主学习与创新。"),
]- 创建 4 个
Document实例 - 每个对象包含
page_content和默认的metadata
阶段 2:检索器初始化
retriever = TFIDFRetriever.from_documents(docs)调用链:
from_documents()类方法- 调用
cls(documents=docs, **kwargs),即__init__()
- 调用
__init__()方法- 调用
super().__init__()初始化基类 - 设置默认 tokenizer 为
_chinese_tokenizer - 创建
TfidfVectorizer实例 - 保存
documents列表 - 若
documents非空,调用_fit_vectorizer()
- 调用
_fit_vectorizer()方法- 提取所有文档的
page_content - 调用
vectorizer.fit(texts)训练(构建词汇表和 IDF) - 调用
vectorizer.transform(texts)生成文档向量矩阵
- 提取所有文档的
阶段 3:执行检索
results = retriever.invoke(query, k=2)调用链:
invoke()方法(继承自BaseRetriever)- 调用
_get_relevant_documents(query, k=2)
- 调用
_get_relevant_documents()方法(核心检索逻辑)步骤 1:预处理
- 检查文档列表是否为空
- 对查询进行分词:
query_tokens = set(_chinese_tokenizer(query).split())
步骤 2:向量化查询
- 调用
vectorizer.transform([query])将查询转换为 TF-IDF 向量
步骤 3:计算余弦相似度
- 调用
cosine_similarity(query_vector, self.document_vectors)计算相似度数组
步骤 4:检查查询词匹配
for doc in self.documents: doc_tokens = set(_chinese_tokenizer(doc.page_content).split()) doc_has_query_word.append(len(query_tokens & doc_tokens) > 0)确保文档至少包含一个查询词
步骤 5:排序和筛选
- 按相似度降序排序:
sorted_indices = np.argsort(similarities)[::-1] - 筛选有效文档:
valid_indices = [i for i in sorted_indices if similarities[i] > 0 and doc_has_query_word[i]] - 取前 k 个:
top_indices = valid_indices[:k] - 返回对应的文档对象列表:
return [self.documents[i] for i in top_indices]
阶段 4:结果输出
print(f"查询:{query}")
print(f"返回 {len(results)} 条:")
for i, doc in enumerate(results, 1):
print(f"{i}. {doc.page_content}")- 遍历
results(Document列表) - 打印每个文档的
page_content
51.5 TF-IDF 算法 #
TF-IDF(Term Frequency-Inverse Document Frequency)用于衡量词在文档中的重要性。
TF-IDF 公式
对于词 $t$ 和文档 $d$:
$$\text{TF-IDF}(t, d) = \text{TF}(t, d) \times \text{IDF}(t)$$
其中:
- $\text{TF}(t, d) = \frac{\text{词t在文档d中的出现次数}}{\text{文档d的总词数}}$(词频)
- $\text{IDF}(t) = \log\frac{N}{df(t)}$(逆文档频率)
- $N$:文档总数
- $df(t)$:包含词 $t$ 的文档数量
51.6. 余弦相似度 #
查询向量 $\vec{q}$ 和文档向量 $\vec{d}$ 的余弦相似度:
$$\cos(\theta) = \frac{\vec{q} \cdot \vec{d}}{|\vec{q}| \times |\vec{d}|} = \frac{\sum_{i} q_i \times d_i}{\sqrt{\sum_{i} q_i^2} \times \sqrt{\sum_{i} d_i^2}}$$
相似度范围:$[-1, 1]$,值越大越相似。
52.BM25Retriever #
BM25 检索器(BM25Retriever)是一种经典的信息检索算法,在不依赖向量嵌入的情况下实现关键词检索。其主要适用于文本检索、FAQ 系统、知识库初步召回等场景。
主要功能
- 高效关键词检索:BM25 根据 query 与文档的关键词匹配度为文档打分,返回最相关的文本。
- 免向量化依赖:无需复杂的 embedding,仅依赖分词和排名,部署简单,召回快。
- 可定制预处理:支持自定义分词、预处理函数,适配不同的语言或业务数据。
52.1. 52.BM25Retriever.py #
52.BM25Retriever.py
# 导入 Document 类(文档数据结构)
from smartchain.documents import Document
# 导入 BM25Retriever 检索器类
from smartchain.retrievers import BM25Retriever
# 构造示例文档列表,每个 Document 对象代表一段文本
docs = [
# 文档1:关于深度学习的描述
Document(page_content="深度学习通过神经网络取得突破。"),
# 文档2:机器人学与人工智能关系
Document(page_content="机器人学结合机械工程与人工智能。"),
# 文档3:美食点评内容
Document(page_content="美食点评:这家餐厅的川菜很正宗。"),
# 文档4:人工智能正推动机器人自主学习与创新
Document(page_content="人工智能正推动机器人自主学习与创新。"),
]
# 使用 from_documents 类方法创建 BM25 检索器实例
retriever = BM25Retriever.from_documents(docs)
# 设置检索时的查询字符串
query = "机器人与人工智能"
# 调用 invoke 方法执行检索,k=2 表示返回前2个相关文档
results = retriever.invoke(query, k=2)
# 打印本次检索的查询内容
print(f"查询:{query}")
# 打印检索结果的文档数量
print(f"返回 {len(results)} 条:")
# 遍历检索出的文档,逐条输出其内容
for i, doc in enumerate(results, 1):
print(f"{i}. {doc.page_content}")52.2. bm25.py #
smartchain/retrievers/bm25.py
# 导入BM25Okapi用于BM25算法实现
from rank_bm25 import BM25Okapi
# 导入 BaseRetriever 基类
from .base import BaseRetriever
# 导入自定义 Document 类
from ..documents import Document
# 导入jieba用于中文分词
import jieba
import re
# 定义中文分词函数,返回分词列表
def _chinese_tokenizer(text: str):
# 使用jieba的搜索引擎模式进行分词
tokens = list(jieba.cut_for_search(text))
# 过滤掉标点符号、空白字符和单字符词,只保留有效词
tokens = [t for t in tokens if re.match(r'^[\u4e00-\u9fa5a-zA-Z0-9]+$', t) and len(t) > 1]
# 返回分词结果列表
return tokens
# 定义默认的分词预处理函数(按空格切分文本)
def default_preprocessing_func(text: str):
# 按空格将文本切分为词组成的列表
return text.split()
# 定义BM25检索器类,继承自BaseRetriever
class BM25Retriever(BaseRetriever):
# 初始化BM25Retriever类
def __init__(self, vectorizer=None, docs=None, k=4, preprocess_func=None):
# 调用父类的初始化方法
super().__init__()
# 保存BM25Okapi向量化器对象
self.vectorizer = vectorizer
# 保存文档列表,如果未指定则为空列表
self.docs = docs or []
# 保存每次检索返回的文档条数
self.k = k
# 保存分词预处理函数,默认使用中文分词器
self.preprocess_func = preprocess_func or _chinese_tokenizer
# 类方法:通过文本批量创建BM25检索器实例
@classmethod
def from_texts(
cls,
texts,
metadatas=None,
ids=None,
bm25_params=None,
preprocess_func=None,
k=4,
**kwargs
):
# 如果未指定分词预处理函数,默认使用中文分词器
preprocess_func = preprocess_func or _chinese_tokenizer
# 对每个文本应用分词处理,得到分词后的文本列表
texts_processed = [preprocess_func(t) for t in texts]
# 如果未指定BM25参数,设置为空字典
bm25_params = bm25_params or {}
# 用分词后的文本初始化BM25Okapi向量化器
vectorizer = BM25Okapi(texts_processed, **bm25_params)
# 如果未指定元数据,自动生成空字典生成器
metadatas = metadatas or ({} for _ in texts)
# 判断是否有ids字段,构建Document对象列表(id目前未使用)
if ids:
docs = [
Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)
]
else:
docs = [
Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)
]
# 创建并返回BM25Retriever实例
return cls(
vectorizer=vectorizer,
docs=docs,
preprocess_func=preprocess_func,
k=k,
**kwargs
)
# 类方法:通过Document对象列表创建BM25检索器实例
@classmethod
def from_documents(
cls,
documents,
bm25_params=None,
preprocess_func=None,
k=4,
**kwargs
):
# 从每个Document对象提取文本内容组成列表
texts = [doc.page_content for doc in documents]
# 从每个Document对象提取元数据,没有则使用空字典
metadatas = [doc.metadata if hasattr(doc, 'metadata') else {} for doc in documents]
# 调用from_texts构建检索器实例
return cls.from_texts(
texts=texts,
metadatas=metadatas,
bm25_params=bm25_params,
preprocess_func=preprocess_func,
k=k,
**kwargs
)
# 内部方法:检索最相关的文档
def _get_relevant_documents(self, query, **kwargs):
# 优先使用参数k,否则用默认的k值
k = kwargs.get("k", self.k)
# 对查询字符串进行分词预处理
processed_query = self.preprocess_func(query)
# 如果文档列表为空,则直接返回空列表
if not self.docs:
return []
# 对每个文档内容进行分词,得到文档分词列表
docs_tokenized = [self.preprocess_func(doc.page_content) for doc in self.docs]
# 文档总数
N = len(docs_tokenized)
# 计算文档平均长度
avgdl = sum(len(doc_tokens) for doc_tokens in docs_tokenized) / N if N > 0 else 0
# 设置BM25参数k1和b
k1, b = 1.5, 0.75
# 计算查询词的IDF(逆文档频率)
query_word_idf = {}
for word in processed_query:
# 统计包含该词的文档数
df = sum(1 for doc_tokens in docs_tokenized if word in doc_tokens)
if df > 0:
# 按BM25公式计算IDF
query_word_idf[word] = (N - df + 0.5) / (df + 0.5)
else:
# 若不存在该词,IDF记为0
query_word_idf[word] = 0
# 初始化文档分数列表
doc_scores = []
# 逐个文档计算得分
for i, doc_tokens in enumerate(docs_tokenized):
score = 0
# 遍历每个查询词
for word in processed_query:
# 仅对文档中存在的词计算分数
if word in doc_tokens:
# 词频统计
tf = doc_tokens.count(word)
# 查询词IDF
idf = query_word_idf[word]
# BM25核心公式
score += idf * (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * len(doc_tokens) / avgdl))
# 保存本条文档的得分及索引
doc_scores.append((score, i))
# 按照分数降序排列,取前k条
doc_scores.sort(reverse=True, key=lambda x: x[0])
# 选择分数大于0的前k个文档索引
top_indices = [idx for score, idx in doc_scores[:k] if score > 0]
# 返回对应的文档对象列表
return [self.docs[i] for i in top_indices]52.3. init.py #
smartchain/retrievers/init.py
from .vector_store import VectorStoreRetriever
from .base import BaseRetriever
from .tfidf import TFIDFRetriever
+from .bm25 import BM25Retriever
__all__ = [
"VectorStoreRetriever",
"BaseRetriever",
+ "TFIDFRetriever",
+ "BM25Retriever"
]52.4.类 #
52.4.1 类说明 #
| 类名 | 类型 | 作用 | 关键属性/方法 |
|---|---|---|---|
| Document | 数据类 | 存储文档内容和元数据 | page_content: 文本内容metadata: 元数据字典 |
| BaseRetriever | 抽象基类 | 定义检索器接口规范 | invoke(): 对外调用接口_get_relevant_documents(): 抽象方法 |
| BM25Retriever | 具体实现类 | 实现基于 BM25 算法的文档检索 | vectorizer: BM25Okapi 向量化器docs: 文档列表preprocess_func: 分词函数from_documents(): 类方法invoke(): 检索_get_relevant_documents(): BM25分数计算 |
| _chinese_tokenizer | 函数 | 中文分词预处理 | 使用 jieba 搜索引擎模式分词 过滤标点和单字符词 |
52.4.2 类图 #
52.4.3 时序图 #
52.4.4 调用过程 #
阶段 1:文档对象创建
docs = [
Document(page_content="深度学习通过神经网络取得突破。"),
Document(page_content="机器人学结合机械工程与人工智能。"),
Document(page_content="美食点评:这家餐厅的川菜很正宗。"),
Document(page_content="人工智能技术正推动机器人自主学习与创新。"),
]- 创建 4 个
Document实例 - 每个对象包含
page_content和默认的metadata
阶段 2:检索器初始化
retriever = BM25Retriever.from_documents(docs)调用链:
from_documents()类方法- 提取所有文档的
page_content和metadata - 调用
from_texts()
- 提取所有文档的
from_texts()类方法- 使用
_chinese_tokenizer对每个文本分词 - 创建
BM25Okapi向量化器(用于后续可能的功能) - 创建
Document对象列表 - 调用
__init__()创建实例
- 使用
__init__()方法- 调用
super().__init__()初始化基类 - 保存
vectorizer、docs、k、preprocess_func
- 调用
阶段 3:执行检索
results = retriever.invoke(query, k=2)调用链:
invoke()方法(继承自BaseRetriever)- 调用
_get_relevant_documents(query, k=2)
- 调用
_get_relevant_documents()方法(核心检索逻辑)步骤 1:预处理
- 获取
k值(优先使用参数,否则使用默认值) - 对查询进行分词:
processed_query = self.preprocess_func(query) - 检查文档列表是否为空
步骤 2:文档分词
- 对每个文档内容进行分词:
docs_tokenized = [self.preprocess_func(doc.page_content) for doc in self.docs]
步骤 3:计算统计量
N = len(docs_tokenized)(文档总数)avgdl = sum(len(doc_tokens) for doc_tokens in docs_tokenized) / N(平均文档长度)- 设置 BM25 参数:
k1=1.5, b=0.75
步骤 4:计算查询词的 IDF(逆文档频率)
for word in processed_query: df = sum(1 for doc_tokens in docs_tokenized if word in doc_tokens) if df > 0: query_word_idf[word] = (N - df + 0.5) / (df + 0.5)步骤 5:计算每个文档的 BM25 分数
for i, doc_tokens in enumerate(docs_tokenized): score = 0 for word in processed_query: if word in doc_tokens: tf = doc_tokens.count(word) # 词频 idf = query_word_idf[word] # 逆文档频率 # BM25核心公式 score += idf * (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * len(doc_tokens) / avgdl)) doc_scores.append((score, i))步骤 6:排序和筛选
- 按分数降序排序:
doc_scores.sort(reverse=True, key=lambda x: x[0]) - 取前 k 个分数 > 0 的文档:
top_indices = [idx for score, idx in doc_scores[:k] if score > 0] - 返回对应的文档对象列表:
return [self.docs[i] for i in top_indices]
- 获取
阶段 4:结果输出
print(f"查询:{query}")
print(f"返回 {len(results)} 条:")
for i, doc in enumerate(results, 1):
print(f"{i}. {doc.page_content}")- 遍历
results(Document列表) - 打印每个文档的
page_content
52.4.5 BM25 算法 #
BM25(Best Matching 25)是用于信息检索的排序函数。
52.4.5.1 BM25 公式 #
对于查询词 $q$ 和文档 $d$,BM25 分数计算为:
$$\text{BM25}(q, d) = \sum_{i=1}^{n} \text{IDF}(q_i) \cdot \frac{f(q_i, d) \cdot (k_1 + 1)}{f(q_i, d) + k_1 \cdot \left(1 - b + b \cdot \frac{|d|}{\text{avgdl}}\right)}$$
其中:
- $q_i$:查询中的第 $i$ 个词
- $f(q_i, d)$:词 $q_i$ 在文档 $d$ 中的词频(TF)
- $\text{IDF}(q_i)$:词 $q_i$ 的逆文档频率
- $|d|$:文档 $d$ 的长度(词数)
- $\text{avgdl}$:所有文档的平均长度
- $k_1$:词频饱和度参数(通常为 1.5)
- $b$:长度归一化参数(通常为 0.75)
52.4.5.2 IDF 计算 #
$$\text{IDF}(q_i) = \frac{N - df(q_i) + 0.5}{df(q_i) + 0.5}$$
其中:
- $N$:文档总数
- $df(q_i)$:包含词 $q_i$ 的文档数量
53.VectorSimilarityRetriever #
向量相似度检索器(VectorSimilarityRetriever)是一种基于向量化表示和余弦相似度的经典检索工具。它和之前介绍的向量数据库型检索器不同,VectorSimilarityRetriever 常用于本地小规模场景下的内存型向量匹配,特点如下:
- 无需外部依赖数据库:直接将文档转为嵌入向量后,全部保存在内存中,适用于文档规模不大、运行在轻量系统或教学场景。
- 原理朴素直观:每当检索时,对查询文本进行向量化,与所有已存文档向量做 1:N 的余弦相似度比较,返回最相似的前 k 个文档。
- 用法简明:只需提供文档列表,自动完成嵌入、存储和检索,不依赖外部检索服务、数据库或分布式系统。
- 局限性:由于所有向量均保存在内存,随着文档增多性能和内存消耗会迅速变大,因此主要用于实验、演示、测试或小样本系统。
在 smartchain.retrievers.vector.VectorSimilarityRetriever 中,典型接口为:
from_documents(docs):类方法,初始化时传入 Document 列表,自动完成向量化和存储。invoke(query, k=None):执行检索,返回与 query 语义最相似的前 k 个文档,k 可选,默认值如4。
底层实现流程如下:
- 把每个文档文本通过预定义/自带的嵌入器转为向量,保存在内存数组中。
- 用户输入查询时,同样用嵌入器转向量。
- 通过余弦相似度计算 query 向量与所有文档向量的相似度,筛选出最大 top_k 的文档返回。
- 可针对 zero-shot、few-shot 场景使用,也可作为向量数据库方案的 lightweight baseline 对比。
53.1. 53.SimilarityRetriever.py #
53.SimilarityRetriever.py
#from langchain_classic.retrievers.bm25 import BM25Retriever
#from langchain_core.documents import Document
# 导入 Document 类(文档表示类)
from smartchain.documents import Document
# 导入 BM25Retriever 检索器
from smartchain.retrievers import VectorSimilarityRetriever
# 准备示例文档,构造 Document 列表,每个包含一段文本
docs = [
Document(page_content="深度学习通过神经网络取得突破。"),
Document(page_content="机器人学结合机械工程与人工智能。"),
Document(page_content="美食点评:这家餐厅的川菜很正宗。"),
Document(page_content="人工智能技术正推动机器人自主学习与创新。"),
]
# 使用 from_documents 类方法创建 BM25 检索器,传入文档列表
retriever = VectorSimilarityRetriever.from_documents(docs)
# 设置检索的查询字符串
query = "机器人与人工智能"
# 调用 invoke 方法执行检索,获取相关文档
results = retriever.invoke(query)
# 打印查询内容
print(f"查询:{query}")
# 打印返回的相关文档数量
print(f"返回 {len(results)} 条:")
# 遍历检索结果并逐条打印文档内容
for i, doc in enumerate(results, 1):
print(f"{i}. {doc.page_content}")53.2. vector.py #
smartchain/retrievers/vector.py
# 导入 BaseRetriever 基类
from .base import BaseRetriever
# 导入自定义 Document 类
from ..documents import Document
# 导入numpy用于数组操作
import numpy as np
# 定义余弦相似度计算函数
def cosine_similarity(query_embedding, doc_embeddings):
"""
计算查询向量与文档向量的余弦相似度
参数:
query_embedding: 查询的嵌入向量(一维数组)
doc_embeddings: 文档的嵌入向量列表(二维数组)
返回:
相似度分数数组
"""
# 将查询嵌入转换为numpy数组
query_emb = np.array(query_embedding)
# 将文档嵌入转换为numpy数组
doc_embs = np.array(doc_embeddings)
# 若查询为一维,转换为二维以便向量批量操作
if query_emb.ndim == 1:
query_emb = query_emb.reshape(1, -1)
# 计算查询和所有文档嵌入的点积
dot_product = np.dot(query_emb, doc_embs.T)
# 计算查询向量的范数
query_norm = np.linalg.norm(query_emb, axis=1, keepdims=True)
# 计算所有文档嵌入的范数,并转置方便后续相乘
doc_norms = np.linalg.norm(doc_embs, axis=1, keepdims=True).T
# 使用点积/范数乘积得到余弦相似度
similarity = dot_product / (query_norm * doc_norms)
# 将因除零产生的nan、正负无穷转为0
similarity = np.nan_to_num(similarity, nan=0.0, posinf=0.0, neginf=0.0)
# 返回第一行的相似度结果(只支持一个query)
return similarity[0]
# 定义VectorSimilarityRetriever检索器,继承自BaseRetriever
class VectorSimilarityRetriever(BaseRetriever):
# VectorSimilarity 检索器说明文档字符串
"""
VectorSimilarity 检索器
使用向量嵌入进行文档检索。
通过计算查询文本与文档集合的嵌入向量,使用余弦相似度找到最相关的文档。
"""
# 初始化方法
def __init__(self, embeddings, documents=None, k=4, **kwargs):
"""
初始化 VectorSimilarity 检索器
参数:
embeddings: 嵌入模型实例,必须实现 embed_query 和 embed_documents 方法
documents: Document 列表,用于计算嵌入向量
k: 默认返回的文档数量,默认 4
**kwargs: 其他参数
"""
# 调用父类初始化方法
super().__init__(**kwargs)
# 保存嵌入模型
self.embeddings = embeddings
# 保存返回文档数量
self.k = k
# 保存文档列表,如果为None则赋值为空列表
self.documents = documents or []
# 文档嵌入向量列表
self.doc_embeddings = None
# 如果文档列表不为空,则计算文档嵌入向量
if self.documents:
self._compute_document_embeddings()
# 计算所有文档的嵌入向量
def _compute_document_embeddings(self):
"""计算所有文档的嵌入向量"""
# 提取所有文档的内容,组成文本列表
texts = [doc.page_content for doc in self.documents]
# 批量计算所有文档的嵌入向量
self.doc_embeddings = self.embeddings.embed_documents(texts)
# 通过类方法根据文档列表创建检索器对象
@classmethod
def from_documents(cls, documents, embeddings=None, k=4, **kwargs):
"""
从文档列表创建 VectorSimilarityRetriever 实例
参数:
documents: Document 列表
embeddings: 嵌入模型实例,如果为None则使用默认的HuggingFaceEmbeddings
k: 默认返回的文档数量,默认 4
**kwargs: 其他参数,可传递给嵌入模型
返回:
VectorSimilarityRetriever 实例
"""
# 如果没有提供嵌入模型,使用默认的HuggingFaceEmbeddings
if embeddings is None:
from ..embeddings import HuggingFaceEmbeddings
embeddings = HuggingFaceEmbeddings(**kwargs)
# 创建检索器实例,传递文档列表及嵌入模型
instance = cls(embeddings=embeddings, documents=documents, k=k)
return instance
# 主体检索方法,返回与查询相关的文档
def _get_relevant_documents(self, query, k=None, **kwargs):
"""
获取相关文档
参数:
query: 查询字符串
k: 返回的文档数量,如果为None则使用默认值
**kwargs: 其他参数
返回:
Document 列表,按相似度从高到低排序
"""
# 如果没有文档,则直接返回空列表
if not self.documents or self.doc_embeddings is None:
return []
# 获取k值,优先使用传入参数,否则使用默认值
k = k if k is not None else self.k
# 计算查询的嵌入向量
query_embedding = self.embeddings.embed_query(query)
# 计算查询向量和所有文档向量的余弦相似度
similarities = cosine_similarity(query_embedding, self.doc_embeddings)
# 获取相似度排序后的索引(降序),并取前k个
top_indices = np.argsort(similarities)[::-1][:k]
# 只保留相似度大于0的文档,返回对应的Document对象列表
return [self.documents[i] for i in top_indices if similarities[i] > 0]53.3. init.py #
smartchain/retrievers/init.py
from .vector_store import VectorStoreRetriever
from .base import BaseRetriever
from .tfidf import TFIDFRetriever
from .bm25 import BM25Retriever
+from .vector import VectorSimilarityRetriever
__all__ = [
"VectorStoreRetriever",
"BaseRetriever",
"TFIDFRetriever",
+ "BM25Retriever",
+ "VectorSimilarityRetriever"
]53.4.类 #
53.4.1 类说明 #
| 类名 | 类型 | 作用 | 关键属性/方法 |
|---|---|---|---|
| Document | 数据类 | 存储文档内容和元数据 | page_content: 文档文本内容metadata: 元数据字典 |
| BaseRetriever | 抽象基类 | 定义检索器接口规范 | invoke(): 对外调用接口_get_relevant_documents(): 抽象方法(需子类实现) |
| VectorSimilarityRetriever | 具体实现类 | 实现基于向量嵌入的文档检索 | embeddings: 嵌入模型实例documents: 文档列表doc_embeddings: 文档嵌入向量列表k: 默认返回文档数量from_documents(): 类方法创建实例invoke(): 执行检索_get_relevant_documents(): 核心检索逻辑_compute_document_embeddings(): 计算文档嵌入向量 |
| HuggingFaceEmbeddings | 嵌入模型类 | 使用 HuggingFace 模型生成文本嵌入向量 | model_name: 模型名称embed_query(): 嵌入单个查询文本embed_documents(): 批量嵌入文档文本 |
| Embedding | 抽象基类 | 定义嵌入模型接口规范 | embed_query(): 抽象方法embed_documents(): 抽象方法 |
| cosine_similarity | 函数 | 计算余弦相似度 | 计算查询向量与文档向量的余弦相似度 |
53.4.2 类图 #
53.4.3 时序图 #
53.4.4 调用过程 #
阶段 1:文档对象创建
docs = [
Document(page_content="深度学习通过神经网络取得突破。"),
Document(page_content="机器人学结合机械工程与人工智能。"),
Document(page_content="美食点评:这家餐厅的川菜很正宗。"),
Document(page_content="人工智能技术正推动机器人自主学习与创新。"),
]- 创建 4 个
Document实例 - 每个对象包含
page_content和默认的metadata
阶段 2:检索器初始化
retriever = VectorSimilarityRetriever.from_documents(docs)调用链:
from_documents()类方法- 检查
embeddings是否为None - 若为
None,创建默认的HuggingFaceEmbeddings实例(使用all-MiniLM-L6-v2模型) - 调用
cls(embeddings=embeddings, documents=documents, k=k),即__init__()
- 检查
__init__()方法- 调用
super().__init__()初始化基类 - 保存
embeddings、documents、k - 初始化
doc_embeddings = None - 若
documents非空,调用_compute_document_embeddings()
- 调用
_compute_document_embeddings()方法- 提取所有文档的
page_content - 调用
embeddings.embed_documents(texts)批量计算文档嵌入向量 - 保存到
self.doc_embeddings
- 提取所有文档的
阶段 3:执行检索
results = retriever.invoke(query)调用链:
invoke()方法(继承自BaseRetriever)- 调用
_get_relevant_documents(query, k=None)
- 调用
_get_relevant_documents()方法(核心检索逻辑)步骤 1:预处理
- 检查
documents和doc_embeddings是否为空 - 确定
k值(优先使用参数,否则使用默认值)
步骤 2:计算查询嵌入向量
- 调用
embeddings.embed_query(query)将查询转换为嵌入向量
步骤 3:计算余弦相似度
- 调用
cosine_similarity(query_embedding, self.doc_embeddings)计算相似度 - 内部计算:
dot_product = np.dot(query_emb, doc_embs.T) query_norm = np.linalg.norm(query_emb, axis=1, keepdims=True) doc_norms = np.linalg.norm(doc_embs, axis=1, keepdims=True).T similarity = dot_product / (query_norm * doc_norms)
步骤 4:排序和筛选
- 按相似度降序排序:
top_indices = np.argsort(similarities)[::-1][:k] - 筛选相似度 > 0 的文档:
return [self.documents[i] for i in top_indices if similarities[i] > 0]
- 检查
阶段 4:结果输出
print(f"查询:{query}")
print(f"返回 {len(results)} 条:")
for i, doc in enumerate(results, 1):
print(f"{i}. {doc.page_content}")- 遍历
results(Document列表) - 打印每个文档的
page_content
53.4.5 向量嵌入与余弦相似度 #
向量嵌入(Embedding)
将文本转换为数值向量,捕获语义信息:
- 语义相似的文本在向量空间中距离更近
- 使用预训练模型(如
all-MiniLM-L6-v2)生成固定维度的向量
余弦相似度公式
对于查询向量 $\vec{q}$ 和文档向量 $\vec{d}$:
$$\cos(\theta) = \frac{\vec{q} \cdot \vec{d}}{|\vec{q}| \times |\vec{d}|} = \frac{\sum_{i=1}^{n} q_i \times d_i}{\sqrt{\sum_{i=1}^{n} q_i^2} \times \sqrt{\sum_{i=1}^{n} d_i^2}}$$
其中:
- $\vec{q} \cdot \vec{d}$:点积
- $|\vec{q}|$、$|\vec{d}|$:向量的 L2 范数
- 相似度范围:$[-1, 1]$,值越大越相似
优势
- 语义理解:能捕获语义相似性,不依赖关键词匹配
- 多语言支持:预训练模型支持多种语言
- 上下文感知:考虑上下文信息
53.4.6 设计模式 #
- 模板方法模式
BaseRetriever.invoke()定义调用流程VectorSimilarityRetriever._get_relevant_documents()实现具体检索逻辑
- 工厂方法模式
from_documents()作为类方法工厂,简化实例创建
- 策略模式
BaseRetriever定义接口,VectorSimilarityRetriever提供具体策略Embedding定义嵌入接口,HuggingFaceEmbeddings提供具体实现
该设计便于扩展新的检索器和嵌入模型实现。
53.4.7 与 TFIDF/BM25 的区别 #
| 特性 | VectorSimilarityRetriever | TFIDFRetriever | BM25Retriever |
|---|---|---|---|
| 算法基础 | 向量嵌入(深度学习) | TF-IDF(统计) | BM25(统计) |
| 语义理解 | ✅ 支持 | ❌ 不支持 | ❌ 不支持 |
| 关键词匹配 | ❌ 不依赖 | ✅ 依赖 | ✅ 依赖 |
| 计算复杂度 | 较高(需要模型推理) | 较低 | 较低 |
| 多语言支持 | ✅ 好 | 需分词器 | 需分词器 |
| 上下文理解 | ✅ 强 | ❌ 弱 | ❌ 弱 |
VectorSimilarityRetriever 更适合需要语义理解的场景,而 TFIDF/BM25 更适合关键词匹配场景。
54.EnsembleRetriever #
EnsembleRetriever 是一种通过集成多个独立检索器(如不同数据源或算法)来提升整体检索效果的高级用法。
EnsembleRetriever 采用加权 Reciprocal Rank Fusion (RRF) 算法,将多个检索器的返回结果进行融合,并根据设定的权重对每个检索源影响力加权排序。其主要特点和典型应用场景如下:
- 灵活集成:可同时结合多个不同来源或不同算法的检索器,提高召回率和鲁棒性。
- 权重可调:支持自定义各子检索器的权重,便于根据实际效果调整信息融合策略。
- 去重与排序:自动对不同检索器的结果按唯一键去重,并结合权重做最终排序,兼顾多样性与准确性。
- 适用场景:适合需要整合多种索引、异构数据源或不同检索策略的复合应用。
54.1. 54.EnsembleRetriever.py #
54.EnsembleRetriever.py
# 导入 Document 类,用于表示文档对象
from smartchain.documents import Document
# 导入集成检索器(EnsembleRetriever),用于多检索器融合
from smartchain.retrievers import EnsembleRetriever
# 构建示例文档列表,每个元素是一个 Document 对象,包含一段文本内容
docs = [
Document(page_content="深度学习通过神经网络取得突破。"),
Document(page_content="机器人学结合机械工程与人工智能。"),
Document(page_content="美食点评:这家餐厅的川菜很正宗。"),
Document(page_content="人工智能技术正推动机器人自主学习与创新。"),
]
# 使用类方法 from_documents 创建集成检索器实例,传入文档列表
# weights=[0.5, 0.5] 表示 BM25 检索器和向量相似度检索器权重各占 0.5
retriever = EnsembleRetriever.from_documents(
docs,
weights=[0.5, 0.5] # BM25 和向量检索权重设置
)
# 设置检索的查询语句
query = "机器人与人工智能"
# 调用 invoke 方法执行检索操作,k=2 表示取出前2个相关文档
results = retriever.invoke(query, k=2)
# 打印查询内容
print(f"查询:{query}")
# 打印返回的相关文档数量
print(f"返回 {len(results)} 条:")
# 遍历所有检索结果,依次打印每条文档的内容
for i, doc in enumerate(results, 1):
print(f"{i}. {doc.page_content}")54.2. ensemble.py #
smartchain/retrievers/ensemble.py
# 导入基础检索器基类
from .base import BaseRetriever
# 导入BM25检索器
from .bm25 import BM25Retriever
# 导入向量相似度检索器
from .vector import VectorSimilarityRetriever
# 导入HuggingFace嵌入模型
from ..embeddings import HuggingFaceEmbeddings
# 定义集成检索器类(融合多种检索器)
class EnsembleRetriever(BaseRetriever):
# 初始化方法
def __init__(
self,
retrievers=None, # 传入的检索器列表
weights=None, # 各检索器的权重
k=60, # RRF分数计算时的k值
):
# 初始化父类
super().__init__()
# 设置检索器列表,如果未指定则为空列表
self.retrievers = retrievers or []
# 设置权重
self.weights = weights
# 设置RRF算法的k参数
self.k = k
# 类方法:根据文档自动构建集成检索器
@classmethod
def from_documents(
cls,
documents,
weights=None,
):
# 创建BM25检索器
bm25_retriever = BM25Retriever.from_documents(documents)
# 用指定模型创建文本嵌入实例(默认使用CPU)
embeddings = HuggingFaceEmbeddings(
model_name="BAAI/bge-base-zh-v1.5",
model_kwargs={"device": "cpu"}
)
# 创建向量相似度检索器
vector_retriever = VectorSimilarityRetriever.from_documents(
documents,
embeddings=embeddings,
)
# 如果没有指定权重,则默认均分
if weights is None:
weights = [0.5, 0.5]
# 打包检索器列表
retrievers = [bm25_retriever, vector_retriever]
# 返回集成检索器实例
return cls(retrievers=retrievers, weights=weights, k=60)
# 获取文档唯一标识(此处直接用内容)
def _get_doc_id(self, doc):
return doc.page_content
# 根据所有检索器得到的排位计算RRF融合分数
def _calculate_rrf_score(self, doc_ranks):
# 存储每个文档的RRF分数
rrf_scores = {}
# 遍历每个文档的各检索器排名
for doc_id, ranks in doc_ranks.items():
score = 0.0
# 依次遍历各检索器排名
for i, rank in enumerate(ranks):
# 如果有排名(未命中为None)
if rank is not None:
# 使用指定权重,如果没有,则为1.0
weight = self.weights[i] if self.weights else 1.0
# RRF分数累加
score += weight / (self.k + rank)
# 汇总该文档的总分
rrf_scores[doc_id] = score
# 返回所有文档分数
return rrf_scores
# 获取最终融合后的相关文档
def _get_relevant_documents(self, query, k=4, **kwargs):
# 如果检索器列表为空,直接返回空结果
if not self.retrievers:
return []
# 如果权重未设置,则自动平均分配
if self.weights is None:
self.weights = [1.0 / len(self.retrievers)] * len(self.retrievers)
# 权重数量必须等于检索器数量,否则报错
if len(self.weights) != len(self.retrievers):
raise ValueError(
f"权重数量 ({len(self.weights)}) 必须与检索器数量 ({len(self.retrievers)}) 一致"
)
# all_results用于存储每个检索器的返回结果
all_results = []
# 分别用每个检索器检索,大k值,提高召回
for retriever in self.retrievers:
results = retriever.invoke(query, k=k * 2, **kwargs)
all_results.append(results)
# 建立从ID到文档对象的映射
doc_id_to_doc = {}
# 存储每个文档在各检索器中的排名
doc_ranks = {}
# 遍历所有检索器的召回结果
for retriever_idx, results in enumerate(all_results):
for rank, doc in enumerate(results, start=1):
# 获取文档ID
doc_id = self._get_doc_id(doc)
# 首次出现该文档,初始化
if doc_id not in doc_id_to_doc:
doc_id_to_doc[doc_id] = doc
doc_ranks[doc_id] = [None] * len(self.retrievers)
# 仅记录第一次出现的rank
if doc_ranks[doc_id][retriever_idx] is None:
doc_ranks[doc_id][retriever_idx] = rank
# 融合所有排名,返回RRF融合分数
rrf_scores = self._calculate_rrf_score(doc_ranks)
# 按分数降序排列
sorted_docs = sorted(
rrf_scores.items(),
key=lambda x: x[1],
reverse=True
)
# 收集分数大于0的top k文档
top_k_docs = []
for doc_id, score in sorted_docs[:k]:
if score > 0:
top_k_docs.append(doc_id_to_doc[doc_id])
# 返回最终文档列表
return top_k_docs54.3. init.py #
smartchain/retrievers/init.py
from .vector_store import VectorStoreRetriever
from .base import BaseRetriever
from .tfidf import TFIDFRetriever
from .bm25 import BM25Retriever
from .vector import VectorSimilarityRetriever
+from .ensemble import EnsembleRetriever
__all__ = [
"VectorStoreRetriever",
"BaseRetriever",
"TFIDFRetriever",
"BM25Retriever",
+ "VectorSimilarityRetriever",
+ "EnsembleRetriever"
]54.4. 类 #
54.4.1 类说明 #
| 类名 | 类型 | 作用 | 关键属性/方法 |
|---|---|---|---|
| Document | 数据类 | 存储文档内容和元数据 | page_content: 文档文本内容metadata: 元数据字典 |
| BaseRetriever | 抽象基类 | 定义检索器接口规范 | invoke(): 对外调用接口_get_relevant_documents(): 抽象方法(需子类实现) |
| EnsembleRetriever | 集成检索器类 | 融合多个检索器的结果 | retrievers: 检索器列表weights: 各检索器权重k: RRF算法k值from_documents(): 类方法创建实例invoke(): 执行检索_get_relevant_documents(): 核心检索逻辑_calculate_rrf_score(): 计算RRF融合分数_get_doc_id(): 获取文档唯一标识 |
| BM25Retriever | 检索器类 | 基于BM25算法的文档检索 | from_documents(): 类方法创建实例invoke(): 执行检索 |
| VectorSimilarityRetriever | 检索器类 | 基于向量嵌入的文档检索 | embeddings: 嵌入模型实例from_documents(): 类方法创建实例invoke(): 执行检索 |
| HuggingFaceEmbeddings | 嵌入模型类 | 使用HuggingFace模型生成文本嵌入向量 | model_name: 模型名称embed_query(): 嵌入单个查询文本embed_documents(): 批量嵌入文档文本 |
54.4.2 类图 #
54.4.3 时序图 #
54.4.4 调用过程 #
阶段 1:文档对象创建
docs = [
Document(page_content="深度学习通过神经网络取得突破。"),
Document(page_content="机器人学结合机械工程与人工智能。"),
Document(page_content="美食点评:这家餐厅的川菜很正宗。"),
Document(page_content="人工智能技术正推动机器人自主学习与创新。"),
]- 创建 4 个
Document实例 - 每个对象包含
page_content和默认的metadata
阶段 2:创建集成检索器
retriever = EnsembleRetriever.from_documents(
docs,
weights=[0.5, 0.5] # BM25 和向量检索权重设置
)调用链:
from_documents()类方法- 创建
BM25Retriever实例:bm25_retriever = BM25Retriever.from_documents(documents) - 创建
HuggingFaceEmbeddings实例(使用BAAI/bge-base-zh-v1.5模型) - 创建
VectorSimilarityRetriever实例:vector_retriever = VectorSimilarityRetriever.from_documents(documents, embeddings=embeddings) - 如果没有指定权重,默认设置为
[0.5, 0.5] - 打包检索器列表:
retrievers = [bm25_retriever, vector_retriever] - 调用
cls(retrievers=retrievers, weights=weights, k=60)创建实例
- 创建
__init__()方法- 调用
super().__init__()初始化基类 - 保存
retrievers、weights、k(RRF 算法参数,默认 60)
- 调用
阶段 3:执行检索
results = retriever.invoke(query, k=2)调用链:
invoke()方法(继承自BaseRetriever)- 调用
_get_relevant_documents(query, k=2)
- 调用
_get_relevant_documents()方法(核心检索逻辑)步骤 1:验证和初始化
- 检查检索器列表是否为空
- 如果权重未设置,自动平均分配
- 验证权重数量与检索器数量一致
步骤 2:并行调用多个检索器
for retriever in self.retrievers: results = retriever.invoke(query, k=k * 2, **kwargs) # k*2 提高召回率 all_results.append(results)- 每个检索器返回
k*2个文档(提高召回率) - 收集所有检索器的结果
步骤 3:建立文档映射和排名记录
doc_id_to_doc = {} # ID到文档对象的映射 doc_ranks = {} # 每个文档在各检索器中的排名 for retriever_idx, results in enumerate(all_results): for rank, doc in enumerate(results, start=1): doc_id = self._get_doc_id(doc) # 获取文档ID(使用page_content) if doc_id not in doc_id_to_doc: doc_id_to_doc[doc_id] = doc doc_ranks[doc_id] = [None] * len(self.retrievers) if doc_ranks[doc_id][retriever_idx] is None: doc_ranks[doc_id][retriever_idx] = rank步骤 4:计算 RRF 融合分数
rrf_scores = self._calculate_rrf_score(doc_ranks)内部计算逻辑:
for doc_id, ranks in doc_ranks.items(): score = 0.0 for i, rank in enumerate(ranks): if rank is not None: weight = self.weights[i] if self.weights else 1.0 score += weight / (self.k + rank) # RRF公式 rrf_scores[doc_id] = score步骤 5:排序和筛选
- 按 RRF 分数降序排序
- 取前 k 个分数 > 0 的文档
- 返回对应的 Document 对象列表
阶段 4:结果输出
print(f"查询:{query}")
print(f"返回 {len(results)} 条:")
for i, doc in enumerate(results, 1):
print(f"{i}. {doc.page_content}")- 遍历
results(Document列表) - 打印每个文档的
page_content
54.4.5 RRF 算法 #
RRF(Reciprocal Rank Fusion)用于融合多个检索器的排名结果。
RRF 公式
对于文档 $d$,其 RRF 分数计算为:
$$\text{RRF}(d) = \sum_{i=1}^{n} \frac{w_i}{k + \text{rank}_i(d)}$$
其中:
- $n$:检索器数量
- $w_i$:第 $i$ 个检索器的权重
- $k$:RRF 算法参数(通常为 60)
- $\text{rank}_i(d)$:文档 $d$ 在第 $i$ 个检索器中的排名(如果未出现则为 None,不参与计算)
RRF 算法优势
- 无需归一化:不同检索器的分数无需归一化即可融合
- 处理缺失:某个检索器未返回的文档仍可参与融合
- 权重可调:通过
weights参数调整各检索器的重要性 - 提高召回:融合多个检索器,提高召回率
示例计算
假设:
- 文档 A 在 BM25 中排名 1,在向量检索中排名 2
- 文档 B 在 BM25 中排名 2,在向量检索中排名 1
- 权重:
[0.5, 0.5],k=60
文档 A 的 RRF 分数: $$\text{RRF}(A) = \frac{0.5}{60 + 1} + \frac{0.5}{60 + 2} = \frac{0.5}{61} + \frac{0.5}{62} \approx 0.0164$$
文档 B 的 RRF 分数: $$\text{RRF}(B) = \frac{0.5}{60 + 2} + \frac{0.5}{60 + 1} = \frac{0.5}{62} + \frac{0.5}{61} \approx 0.0164$$
两者分数接近,但文档 B 在两个检索器中排名更靠前,最终排名可能更高。
55.LLMChainExtractor #
本节介绍如何利用大语言模型(LLM)提升检索问答系统的相关性和答案精度,具体通过文档压缩与上下文压缩检索技术实现。
在传统检索系统中,仅靠相似度召回的文档常常包含冗余内容或部分不相关片段,导致后续大模型生成的答案长且噪声多。
LLM驱动的文档压缩器(如LLMChainExtractor)能够智能地从每个召回文档中,仅抽取与用户问题直接相关的段落,极大提升答案简洁性和相关度。
LLMChainExtractor文档压缩器
- 该压缩器会对每个检索到的文档,结合问题,通过大模型抽取出最相关的部分。如果判定文档与问题无关,则自动过滤该文档。
- 适用于中/英文场景,支持按需自定义抽取模板。
- 能有效剔除冗余、弱相关内容,为问答系统后续推理提供精准上下文。
ContextualCompressionRetriever上下文压缩检索器
- 这是一个高级检索管道,将“基础检索器”与“文档压缩器”无缝衔接。
- 流程为:首先使用相似度等方式初步召回一批文档,然后对这些文档应用LLMChainExtractor压缩,仅保留高相关内容或段落。
- 这样,不仅提升了下游大模型生成答案的准确性,也显著降低无用信息干扰。
典型应用场景
- 非结构化知识库问答、企业内部知识检索、FAQ系统等,尤其适用于原始文档较长、信息密度低、但只需部分内容即可作答的问题型场景。
一体化示例
- 下方示例代码展示了如何集成向量数据库、嵌入模型、检索器和LLM压缩器,构建高效高相关性的智能问答系统。
- 其中核心语句为:
compressor = LLMChainExtractor.from_llm(llm) compression_retriever = ContextualCompressionRetriever( base_compressor=compressor, base_retriever=base_retriever) results = compression_retriever.invoke(query) - 只需更换“llm”、“嵌入模型”或检索参数,即可灵活适配不同业务需求。
55.1. 55.LLMChainExtractor.py #
55.LLMChainExtractor.py
#from langchain_chroma import Chroma
#from langchain_huggingface import HuggingFaceEmbeddings
#from langchain_classic.retrievers.contextual_compression import ContextualCompressionRetriever
#from langchain_classic.retrievers.document_compressors.chain_extract import LLMChainExtractor
#from langchain_deepseek import ChatDeepSeek
# 导入 Chroma 向量数据库类
from smartchain.vectorstores import Chroma
# 导入 HuggingFaceEmbeddings 嵌入模型类
from smartchain.embeddings import HuggingFaceEmbeddings
# 导入上下文压缩型检索器
from smartchain.retrievers import ContextualCompressionRetriever
# 导入 LLMChainExtractor 文档压缩器
from smartchain.document_compressors import LLMChainExtractor
# 导入 ChatDeepSeek LLM 模型
from smartchain.chat_models import ChatDeepSeek
# 设置本地嵌入模型的路径
model_path = "C:/Users/Administrator/.cache/modelscope/hub/models/sentence-transformers/all-MiniLM-L6-v2"
# 初始化嵌入模型,指定模型路径与推理设备
embeddings = HuggingFaceEmbeddings(
model_name=model_path,
model_kwargs={"device": "cpu"}
)
# 初始化 Chroma 向量数据库实例,指定持久化目录、嵌入函数、集合名称和元数据
chroma_db = Chroma(
persist_directory="chroma_db",
embedding_function=embeddings,
collection_name="test",
collection_metadata={"hnsw:space": "cosine"}
)
# 检查数据库是否为空,若为空批量插入初始文本及其元数据
if not chroma_db._collection.count():
# 定义插入的文本列表
texts = [
"人工智能(AI)是一种让机器模拟人类智能行为的技术。",
"深度学习是人工智能的一个重要分支,通过多层神经网络学习数据。",
"ChatGPT是OpenAI开发的强大自然语言模型。",
"向量数据库可以高效地存储和检索文本的嵌入向量。",
"机器人学结合了人工智能和机械工程,推动自动化发展。",
"AI可以辅助医生进行医学影像分析。",
"大模型在对话、问答、摘要等领域不断取得突破。",
"知识库问答系统常用于企业信息检索场景。",
]
# 对应的元数据列表
metadatas = [
{"topic": "ai"},
{"topic": "ai"},
{"topic": "nlp"},
{"topic": "vector_db"},
{"topic": "robotics"},
{"topic": "healthcare"},
{"topic": "llm"},
{"topic": "retrieval"},
]
# 向向量数据库批量插入文本和元数据
chroma_db.add_texts(texts, metadatas)
# 创建 ChatDeepSeek LLM 对象,使用 deepseek-chat 模型
llm = ChatDeepSeek(model="deepseek-chat")
# 创建基础检索器,指定检索类型为 similarity 和返回数量 k=20
# 建议 k 取较大值,因为压缩器后续会过滤部分不相关的结果
base_retriever = chroma_db.as_retriever(search_type="similarity", search_kwargs={"k": 20})
# 创建 LLMChainExtractor 文档压缩器,利用 LLM 从文档抽取相关部分
compressor = LLMChainExtractor.from_llm(llm=llm)
# 创建上下文压缩检索器,结合基础检索器与文档压缩器
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor,
base_retriever=base_retriever
)
# 指定检索查询内容(支持中文或英文)
query = "人工智能" # 可以更换为任意中英文问题
# 调用上下文压缩检索器进行检索
results = compression_retriever.invoke(query)
# 打印查询内容
print(f"查询:{query}")
# 打印检索到的文档数量
print(f"共检索到 {len(results)} 个文档")
# 遍历输出每条检索结果及其元数据
for i, doc in enumerate(results):
print(f"检索结果{i}:{doc.page_content}")
print(f" 元数据:{doc.metadata}\n")55.2. document_compressors.py #
smartchain/document_compressors.py
# 从abc模块导入抽象基类支持
from abc import ABC, abstractmethod
# 定义一个默认的中文抽取Prompt模板
CHINESE_EXTRACT_PROMPT = """给定以下问题和上下文,提取上下文中与回答问题相关的任何部分(保持原样)。如果上下文都不相关,返回 NO_OUTPUT。
记住,*不要*编辑提取的上下文部分。
> 问题:{question}
> 上下文:
>>>
{context}
>>>
提取的相关部分:"""
# 定义文档压缩器的抽象基类
class BaseDocumentCompressor(ABC):
# 设置文档压缩器的类说明文档
"""文档压缩器抽象基类
用于对检索到的文档进行后处理,例如提取相关内容、过滤不相关文档等。
"""
# 声明子类必须实现的抽象方法
@abstractmethod
def compress_documents(
self,
documents,
query: str,
callbacks=None,
):
# 方法说明文档
"""压缩文档
参数:
documents: 检索到的文档列表
query: 查询字符串
callbacks: 可选的回调函数
返回:
Sequence: 压缩后的文档列表
"""
# 抽象方法体不实现
pass
# 定义LLM链文档提取器类,继承自基类
class LLMChainExtractor(BaseDocumentCompressor):
# 设置LLM链提取器的类说明文档
"""LLM链提取器
使用 LLM 链从文档中提取与查询相关的部分。
"""
# 构造函数
def __init__(
self,
llm_chain=None,
llm=None,
prompt=None,
get_input=None,
):
# 如果传入llm_chain则直接赋值
if llm_chain is not None:
self.llm_chain = llm_chain
# 如果传入llm而未传llm_chain,则根据参数生成链
elif llm is not None:
# 延迟导入PromptTemplate工具
from .prompts import PromptTemplate
# 如果未自定义提示词模板,使用默认中文抽取Prompt
if prompt is None:
prompt_template = PromptTemplate.from_template(CHINESE_EXTRACT_PROMPT)
else:
prompt_template = prompt
# 用私有方法构造链
self.llm_chain = self._create_chain(prompt_template, llm)
# 两者都没给时报错
else:
raise ValueError("必须提供 llm_chain 或 llm 参数")
# 设置get_input函数(用于将query/doc打包成模型输入),可以自定义
self.get_input = get_input or self._default_get_input
# 默认输入转换函数,把query和doc打包成dict
def _default_get_input(self, query: str, doc) -> dict:
# 返回一个包含问题和文本内容的字典
return {"question": query, "context": doc.page_content}
# 私有方法,实现用prompt_template和llm创建简单链对象
def _create_chain(self, prompt_template, llm):
# 定义内部类SimpleChain
class SimpleChain:
# 构造函数,保存prompt模板和llm对象
def __init__(self, prompt_template, llm):
self.prompt_template = prompt_template
self.llm = llm
# 实现invoke方法,将输入dict格式化prompt并调用llm
def invoke(self, input_dict, **kwargs):
# 格式化prompt
formatted_prompt = self.prompt_template.format(**input_dict)
# 用llm推理,得到响应
response = self.llm.invoke(formatted_prompt, **kwargs)
# 如果响应有content属性,则用其内容
if hasattr(response, 'content'):
text = response.content
# 如果响应直接为字符串
elif isinstance(response, str):
text = response
# 否则强转为字符串
else:
text = str(response)
# 去掉文本首尾空格
text = text.strip()
# 若内容为NO_OUTPUT或空,则返回空字符串
if text == "NO_OUTPUT" or text == "":
return ""
# 返回抽取的文本
return text
# 返回SimpleChain对象实例
return SimpleChain(prompt_template, llm)
# 类方法,支持通过llm快速初始化LLMChainExtractor
@classmethod
def from_llm(
cls,
llm,
prompt=None,
get_input=None,
):
# 返回初始化好的提取器对象
return cls(
llm=llm,
prompt=prompt,
get_input=get_input
)
# 实现文档压缩主接口
def compress_documents(
self,
documents,
query: str,
callbacks=None,
):
# 用于存储压缩结果的列表
compressed_docs = []
# 遍历所有输入文档
for doc in documents:
# 构造当前doc与query的输入结构
_input = self.get_input(query, doc)
# 用llm链抽取相关内容
output = self.llm_chain.invoke(_input)
# 如果没抽取到任何内容则跳过该文档
if not output or len(output) == 0:
continue
# 动态导入Document类
from .vectorstores import Document
# 用抽取得到内容与原有元数据构造新文档
compressed_doc = Document(
page_content=output,
metadata=doc.metadata if hasattr(doc, 'metadata') else {}
)
# 加入压缩结果列表
compressed_docs.append(compressed_doc)
# 返回压缩结果文档列表
return compressed_docs55.3. contextual.py #
smartchain/retrievers/contextual.py
# 导入基础检索器基类
from .base import BaseRetriever
# 导入基础文档压缩器基类
from ..document_compressors import BaseDocumentCompressor
# 定义上下文压缩检索器类,继承自BaseRetriever
class ContextualCompressionRetriever(BaseRetriever):
# 设置类文档字符串,说明该类包装基础检索器和文档压缩器
"""上下文压缩检索器
包装基础检索器,使用文档压缩器对检索结果进行压缩。
"""
# 定义初始化方法
def __init__(
self,
base_compressor: BaseDocumentCompressor,
base_retriever: BaseRetriever,
):
# 方法文档字符串,说明参数
"""初始化上下文压缩检索器
参数:
base_compressor: 文档压缩器实例
base_retriever: 基础检索器实例
"""
# 调用父类构造函数进行初始化
super().__init__()
# 保存文档压缩器对象
self.base_compressor = base_compressor
# 保存基础检索器对象
self.base_retriever = base_retriever
# 重写获取相关文档的钩子方法
def _get_relevant_documents(self, query, **kwargs):
# 方法文档,阐述参数与返回值
"""获取相关文档
参数:
query: 查询字符串
**kwargs: 其他参数
返回:
List: 压缩后的文档列表
"""
# 使用基础检索器进行初步检索,获取原始文档
docs = self.base_retriever.invoke(query, **kwargs)
# 如果没有检索到任何文档,则直接返回空列表
if not docs:
return []
# 使用文档压缩器对检索到的文档进行压缩处理
compressed_docs = self.base_compressor.compress_documents(
docs,
query
)
# 返回压缩处理后的文档列表
return list(compressed_docs)55.4. init.py #
smartchain/retrievers/init.py
from .vector_store import VectorStoreRetriever
from .base import BaseRetriever
from .tfidf import TFIDFRetriever
from .bm25 import BM25Retriever
from .vector import VectorSimilarityRetriever
from .ensemble import EnsembleRetriever
+from .contextual import ContextualCompressionRetriever
__all__ = [
"VectorStoreRetriever",
"BaseRetriever",
"TFIDFRetriever",
"BM25Retriever",
"VectorSimilarityRetriever",
+ "EnsembleRetriever",
+ "ContextualCompressionRetriever"
]55.5. 类 #
55.5.1 类说明 #
| 类名 | 类型 | 作用 | 关键属性/方法 |
|---|---|---|---|
| Document | 数据类 | 存储文档内容和元数据 | page_content: 文本内容metadata: 元数据字典 |
| HuggingFaceEmbeddings | 嵌入模型类 | 使用 HuggingFace 模型生成文本嵌入向量 | model_name: 模型名称embed_query(): 嵌入单个查询embed_documents(): 批量嵌入文档 |
| Chroma | 向量存储类 | Chroma 向量数据库封装 | embedding_function: 嵌入函数_collection: Chroma 集合add_texts(): 添加文本similarity_search(): 相似度搜索as_retriever(): 创建检索器 |
| VectorStoreRetriever | 检索器类 | 从向量存储中检索文档 | vectorstore: 向量存储实例search_type: 搜索类型invoke(): 执行检索 |
| BaseRetriever | 抽象基类 | 定义检索器接口规范 | invoke(): 对外调用接口_get_relevant_documents(): 抽象方法(需子类实现) |
| ContextualCompressionRetriever | 检索器类 | 包装基础检索器和文档压缩器 | base_retriever: 基础检索器实例base_compressor: 文档压缩器实例invoke(): 先检索再压缩_get_relevant_documents(): 核心逻辑 |
| BaseDocumentCompressor | 抽象基类 | 定义文档压缩器接口规范 | compress_documents(): 抽象方法(需子类实现) |
| LLMChainExtractor | 文档压缩器类 | 使用 LLM 提取文档相关内容 | llm_chain: LLM 链对象get_input: 输入转换from_llm(): 类方法创建实例compress_documents(): 压缩文档 |
| ChatDeepSeek | LLM 模型类 | DeepSeek 大语言模型 | model: 模型名称invoke(): 生成回复 |
| PromptTemplate | 提示词模板类 | 格式化提示词模板 | from_template(): 创建实例format(): 格式化模板 |
55.5.2 类图 #
55.5.3 时序图 #
55.5.4 调用过程 #
阶段 1:初始化嵌入模型和向量数据库
embeddings = HuggingFaceEmbeddings(
model_name=model_path,
model_kwargs={"device": "cpu"}
)
chroma_db = Chroma(
persist_directory="chroma_db",
embedding_function=embeddings,
collection_name="test",
collection_metadata={"hnsw:space": "cosine"}
)- 创建
HuggingFaceEmbeddings实例 - 创建
Chroma向量数据库实例,连接持久化存储
阶段 2:检查并添加数据
if not chroma_db._collection.count():
chroma_db.add_texts(texts, metadatas)- 检查数据库是否为空
- 如果为空,批量添加文本和元数据到数据库
阶段 3:创建 LLM 和压缩器
llm = ChatDeepSeek(model="deepseek-chat")
compressor = LLMChainExtractor.from_llm(llm=llm)调用链:
ChatDeepSeek.__init__()创建 LLM 实例LLMChainExtractor.from_llm()类方法- 调用
cls(llm=llm, prompt=None, get_input=None)
- 调用
LLMChainExtractor.__init__()方法- 创建默认的
PromptTemplate(使用CHINESE_EXTRACT_PROMPT) - 调用
_create_chain(prompt_template, llm)创建SimpleChain对象 - 设置
get_input函数(默认使用_default_get_input)
- 创建默认的
阶段 4:创建基础检索器和压缩检索器
base_retriever = chroma_db.as_retriever(search_type="similarity", search_kwargs={"k": 20})
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor,
base_retriever=base_retriever
)调用链:
Chroma.as_retriever()创建VectorStoreRetriever实例ContextualCompressionRetriever.__init__()方法- 调用
super().__init__()初始化基类 - 保存
base_compressor和base_retriever
- 调用
阶段 5:执行检索和压缩
results = compression_retriever.invoke(query)调用链:
invoke()方法(继承自BaseRetriever)- 调用
_get_relevant_documents(query)
- 调用
_get_relevant_documents()方法(核心检索逻辑)步骤 1:基础检索
docs = self.base_retriever.invoke(query, **kwargs)- 调用
VectorStoreRetriever.invoke()从向量数据库检索文档(返回 k=20 个文档)
步骤 2:文档压缩
compressed_docs = self.base_compressor.compress_documents(docs, query)内部处理流程(
LLMChainExtractor.compress_documents()):for doc in documents: _input = self.get_input(query, doc) # {"question": query, "context": doc.page_content} output = self.llm_chain.invoke(_input) # 调用LLM提取相关内容 if not output or len(output) == 0: continue # 跳过不相关文档 compressed_doc = Document(page_content=output, metadata=doc.metadata) compressed_docs.append(compressed_doc)LLM 链调用流程(
SimpleChain.invoke()):formatted_prompt = self.prompt_template.format(**input_dict) response = self.llm.invoke(formatted_prompt) text = response.content.strip() if text == "NO_OUTPUT" or text == "": return "" # 文档不相关 return text # 返回提取的相关内容步骤 3:返回压缩后的文档列表
- 返回过滤和压缩后的
Document列表
- 调用
阶段 6:结果输出
print(f"查询:{query}")
print(f"共检索到 {len(results)} 个文档")
for i, doc in enumerate(results):
print(f"检索结果{i}:{doc.page_content}")
print(f" 元数据:{doc.metadata}\n")- 遍历
results(压缩后的Document列表) - 打印每个文档的
page_content和metadata
55.5.5 文档压缩流程 #
LLMChainExtractor 工作原理
- 输入:检索到的文档列表 + 查询字符串
- 处理:对每个文档
- 构造提示词:将查询和文档内容组合成提示词
- LLM 提取:调用 LLM 提取文档中与查询相关的部分
- 过滤:如果 LLM 返回 "NO_OUTPUT" 或空字符串,则过滤该文档
- 输出:压缩后的文档列表(只包含相关内容)
默认提示词模板
CHINESE_EXTRACT_PROMPT = """给定以下问题和上下文,提取上下文中与回答问题相关的任何部分(保持原样)。如果上下文都不相关,返回 NO_OUTPUT。
记住,*不要*编辑提取的上下文部分。
> 问题:{question}
> 上下文:
>>>
{context}
>>>
提取的相关部分:"""56.EmbeddingsFilter #
本节将介绍如何使用 EmbeddingsFilter 实现向量检索结果的相关性过滤与压缩。EmbeddingsFilter 属于文档压缩器的一种,常用于“上下文压缩型检索器”场景 —— 即先使用向量/关键词等手段召回一定数量的候选文档,再利用 EmbeddingsFilter 对这些候选文档根据语义相关性进行二次筛选,提升下游LLM问答的准确性与精简性。
核心思路如下:
- 先召回一批候选文档(如向量数据库筛选top50)。
- 使用
EmbeddingFilter计算每个文档与查询(query)的语义相似度。 - 可以选择只保留前k个最相关文档,或者设置相似度阈值只保留相关性高于阈值的文档。
- 支持自定义相似度函数,默认使用余弦相似度。
典型应用流程如下:
- 初始化嵌入模型和
Chroma向量数据库; - 构建
EmbeddingsFilter,传入 embeddings、top-k 或相关性阈值等参数; - 可结合 ContextualCompressionRetriever,实现“检索 + 嵌入压缩”端到端问答。
具体实现与调优建议:
- 若侧重精度,建议适当调小k或设置较高阈值,只让最相关文档进入大模型;
- 若注重召回广度,可放宽阈值或加大k值。
- 支持自定义 embedding 模型、相似度函数等,方便适配不同业务场景。
56.1. 56.EmbeddingsFilter.py #
56.EmbeddingsFilter.py
#from langchain_chroma import Chroma
#from langchain_huggingface import HuggingFaceEmbeddings
#from langchain_classic.retrievers.contextual_compression import ContextualCompressionRetriever
#from langchain_classic.retrievers.document_compressors.embeddings_filter import EmbeddingsFilter
#from langchain_deepseek import ChatDeepSeek
# 导入 Chroma 向量数据库类
from smartchain.vectorstores import Chroma
# 导入 HuggingFaceEmbeddings 嵌入模型类
from smartchain.embeddings import HuggingFaceEmbeddings
# 导入上下文压缩型检索器
from smartchain.retrievers import ContextualCompressionRetriever
# 导入 LLMChainExtractor 文档压缩器
from smartchain.document_compressors import EmbeddingsFilter
# 导入 ChatDeepSeek LLM 模型
from smartchain.chat_models import ChatDeepSeek
# 设置本地嵌入模型的路径
model_path = "C:/Users/Administrator/.cache/modelscope/hub/models/sentence-transformers/all-MiniLM-L6-v2"
# 初始化嵌入模型,指定模型路径与推理设备
embeddings = HuggingFaceEmbeddings(
model_name=model_path,
model_kwargs={"device": "cpu"}
)
# 初始化 Chroma 向量数据库实例,指定持久化目录、嵌入函数、集合名称和元数据
chroma_db = Chroma(
persist_directory="chroma_db",
embedding_function=embeddings,
collection_name="test",
collection_metadata={"hnsw:space": "cosine"}
)
# 检查数据库是否为空,若为空批量插入初始文本及其元数据
if not chroma_db._collection.count():
# 定义插入的文本列表
texts = [
"人工智能(AI)是一种让机器模拟人类智能行为的技术。",
"深度学习是人工智能的一个重要分支,通过多层神经网络学习数据。",
"ChatGPT是OpenAI开发的强大自然语言模型。",
"向量数据库可以高效地存储和检索文本的嵌入向量。",
"机器人学结合了人工智能和机械工程,推动自动化发展。",
"AI可以辅助医生进行医学影像分析。",
"大模型在对话、问答、摘要等领域不断取得突破。",
"知识库问答系统常用于企业信息检索场景。",
]
# 对应的元数据列表
metadatas = [
{"topic": "ai"},
{"topic": "ai"},
{"topic": "nlp"},
{"topic": "vector_db"},
{"topic": "robotics"},
{"topic": "healthcare"},
{"topic": "llm"},
{"topic": "retrieval"},
]
# 向向量数据库批量插入文本和元数据
chroma_db.add_texts(texts, metadatas)
# 创建 ChatDeepSeek LLM 对象,使用 deepseek-chat 模型
llm = ChatDeepSeek(model="deepseek-chat")
# 创建基础检索器,指定检索类型为 similarity 和返回数量 k=20
# 建议 k 取较大值,因为压缩器后续会过滤部分不相关的结果
base_retriever = chroma_db.as_retriever(search_type="similarity", search_kwargs={"k": 20})
compressor = EmbeddingsFilter(embeddings=embeddings,k=5)
# 创建上下文压缩检索器,结合基础检索器与文档压缩器
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor,
base_retriever=base_retriever
)
# 指定检索查询内容(支持中文或英文)
query = "人工智能" # 可以更换为任意中英文问题
# 调用上下文压缩检索器进行检索
results = compression_retriever.invoke(query)
# 打印查询内容
print(f"查询:{query}")
# 打印检索到的文档数量
print(f"共检索到 {len(results)} 个文档")
# 遍历输出每条检索结果及其元数据
for i, doc in enumerate(results):
print(f"检索结果{i}:{doc.page_content}")
print(f" 元数据:{doc.metadata}\n")56.2. document_compressors.py #
smartchain/document_compressors.py
# 从abc模块导入抽象基类支持
from abc import ABC, abstractmethod
+import numpy as np
# 定义一个默认的中文抽取Prompt模板
CHINESE_EXTRACT_PROMPT = """给定以下问题和上下文,提取上下文中与回答问题相关的任何部分(保持原样)。如果上下文都不相关,返回 NO_OUTPUT。
记住,*不要*编辑提取的上下文部分。
> 问题:{question}
> 上下文:
>>>
{context}
>>>
提取的相关部分:"""
# 定义文档压缩器的抽象基类
class BaseDocumentCompressor(ABC):
# 设置文档压缩器的类说明文档
"""文档压缩器抽象基类
用于对检索到的文档进行后处理,例如提取相关内容、过滤不相关文档等。
"""
# 声明子类必须实现的抽象方法
@abstractmethod
def compress_documents(
self,
documents,
query: str,
callbacks=None,
):
# 方法说明文档
"""压缩文档
参数:
documents: 检索到的文档列表
query: 查询字符串
callbacks: 可选的回调函数
返回:
Sequence: 压缩后的文档列表
"""
# 抽象方法体不实现
pass
# 定义LLM链文档提取器类,继承自基类
class LLMChainExtractor(BaseDocumentCompressor):
# 设置LLM链提取器的类说明文档
"""LLM链提取器
使用 LLM 链从文档中提取与查询相关的部分。
"""
# 构造函数
def __init__(
self,
llm_chain=None,
llm=None,
prompt=None,
get_input=None,
):
# 如果传入llm_chain则直接赋值
if llm_chain is not None:
self.llm_chain = llm_chain
# 如果传入llm而未传llm_chain,则根据参数生成链
elif llm is not None:
# 延迟导入PromptTemplate工具
from .prompts import PromptTemplate
# 如果未自定义提示词模板,使用默认中文抽取Prompt
if prompt is None:
prompt_template = PromptTemplate.from_template(CHINESE_EXTRACT_PROMPT)
else:
prompt_template = prompt
# 用私有方法构造链
self.llm_chain = self._create_chain(prompt_template, llm)
# 两者都没给时报错
else:
raise ValueError("必须提供 llm_chain 或 llm 参数")
# 设置get_input函数(用于将query/doc打包成模型输入),可以自定义
self.get_input = get_input or self._default_get_input
# 默认输入转换函数,把query和doc打包成dict
def _default_get_input(self, query: str, doc) -> dict:
# 返回一个包含问题和文本内容的字典
return {"question": query, "context": doc.page_content}
# 私有方法,实现用prompt_template和llm创建简单链对象
def _create_chain(self, prompt_template, llm):
# 定义内部类SimpleChain
class SimpleChain:
# 构造函数,保存prompt模板和llm对象
def __init__(self, prompt_template, llm):
self.prompt_template = prompt_template
self.llm = llm
# 实现invoke方法,将输入dict格式化prompt并调用llm
def invoke(self, input_dict, **kwargs):
# 格式化prompt
formatted_prompt = self.prompt_template.format(**input_dict)
# 用llm推理,得到响应
response = self.llm.invoke(formatted_prompt, **kwargs)
# 如果响应有content属性,则用其内容
if hasattr(response, 'content'):
text = response.content
# 如果响应直接为字符串
elif isinstance(response, str):
text = response
# 否则强转为字符串
else:
text = str(response)
# 去掉文本首尾空格
text = text.strip()
# 若内容为NO_OUTPUT或空,则返回空字符串
if text == "NO_OUTPUT" or text == "":
return ""
# 返回抽取的文本
return text
# 返回SimpleChain对象实例
return SimpleChain(prompt_template, llm)
# 类方法,支持通过llm快速初始化LLMChainExtractor
@classmethod
def from_llm(
cls,
llm,
prompt=None,
get_input=None,
):
# 返回初始化好的提取器对象
return cls(
llm=llm,
prompt=prompt,
get_input=get_input
)
# 实现文档压缩主接口
def compress_documents(
self,
documents,
query: str,
callbacks=None,
):
# 用于存储压缩结果的列表
compressed_docs = []
# 遍历所有输入文档
for doc in documents:
# 构造当前doc与query的输入结构
_input = self.get_input(query, doc)
# 用llm链抽取相关内容
output = self.llm_chain.invoke(_input)
# 如果没抽取到任何内容则跳过该文档
if not output or len(output) == 0:
continue
# 动态导入Document类
from .vectorstores import Document
# 用抽取得到内容与原有元数据构造新文档
compressed_doc = Document(
page_content=output,
metadata=doc.metadata if hasattr(doc, 'metadata') else {}
)
# 加入压缩结果列表
compressed_docs.append(compressed_doc)
# 返回压缩结果文档列表
return compressed_docs
# 定义余弦相似度计算函数
+def cosine_similarity(query_embeddings, doc_embeddings):
# 将查询嵌入转换为numpy数组
+ query_emb = np.array(query_embeddings)
# 将文档嵌入转换为numpy数组
+ doc_embs = np.array(doc_embeddings)
# 若查询为一维,转换为二维以便向量批量操作
+ if query_emb.ndim == 1:
+ query_emb = query_emb.reshape(1, -1)
# 计算查询和所有文档嵌入的点积
+ dot_product = np.dot(query_emb, doc_embs.T)
# 计算查询向量的范数
+ query_norm = np.linalg.norm(query_emb, axis=1, keepdims=True)
# 计算所有文档嵌入的范数,并转置方便后续相乘
+ doc_norms = np.linalg.norm(doc_embs, axis=1, keepdims=True).T
# 使用点积/范数乘积得到余弦相似度
+ similarity = dot_product / (query_norm * doc_norms)
# 将因除零产生的nan、正负无穷转为0
+ similarity = np.nan_to_num(similarity, nan=0.0, posinf=0.0, neginf=0.0)
# 返回第一行的相似度结果(只支持一个query)
+ return similarity[0]
# 定义嵌入过滤器类,继承自BaseDocumentCompressor
+class EmbeddingsFilter(BaseDocumentCompressor):
# 类文档字符串,说明该类用途
+ """嵌入过滤器
+ 使用嵌入模型计算查询与文档的相似度,根据相似度过滤文档。
+ """
# 构造方法定义
+ def __init__(
+ self,
+ embeddings,
+ similarity_fn=None,
+ k = 20,
+ similarity_threshold = None,
+ ):
# 参数校验,如k和阈值都没传则报错
+ if k is None and similarity_threshold is None:
+ raise ValueError("必须指定 k 或 similarity_threshold 之一")
# 存储嵌入对象
+ self.embeddings = embeddings
# 如果没传相似度函数,默认使用余弦相似度
+ self.similarity_fn = similarity_fn or cosine_similarity
# 存储k值(返回多少最相似文档)
+ self.k = k
# 存储相似度阈值
+ self.similarity_threshold = similarity_threshold
# 文档压缩/过滤主入口
+ def compress_documents(
+ self,
+ documents,
+ query: str,
+ callbacks=None,
+ ):
# 若输入文档为空,直接返回空列表
+ if not documents:
+ return []
# 收集每个文档的文本内容
+ doc_texts = [doc.page_content if hasattr(doc, 'page_content') else str(doc) for doc in documents]
# 计算所有文档嵌入
+ doc_embeddings = self.embeddings.embed_documents(doc_texts)
# 计算查询嵌入
+ query_embedding = self.embeddings.embed_query(query)
# 计算查询和所有文档的相似度分数
+ similarities = self.similarity_fn([query_embedding], doc_embeddings)
# 初始包含所有文档的索引
+ included_idxs = np.arange(len(documents))
# 如果指定了k值,保留相似度最高的k篇文档
+ if self.k is not None:
+ sorted_indices = np.argsort(similarities)[::-1]
+ included_idxs = sorted_indices[:self.k]
# 如果指定了相似度阈值,进一步只保留超过阈值的文档
+ if self.similarity_threshold is not None:
+ similar_enough = similarities[included_idxs] > self.similarity_threshold
+ included_idxs = included_idxs[similar_enough]
# 返回最终保留的文档列表
+ return [documents[i] for i in included_idxs]
56.3. 类 #
56.3.1 类说明 #
| 类名 | 类型 | 作用 | 关键属性/方法 |
|---|---|---|---|
| Document | 数据类 | 存储文档内容和元数据 | page_content: 文档文本内容metadata: 元数据字典 |
| HuggingFaceEmbeddings | 嵌入模型类 | 使用 HuggingFace 模型生成文本嵌入向量 | model_name: 模型名称embed_query(): 嵌入单个查询文本embed_documents(): 批量嵌入文档文本 |
| Chroma | 向量存储类 | Chroma 向量数据库封装 | embedding_function: 嵌入函数_collection: Chroma 集合add_texts(): 添加文本到数据库similarity_search(): 相似度搜索as_retriever(): 创建检索器 |
| VectorStoreRetriever | 检索器类 | 从向量存储中检索文档 | vectorstore: 向量存储实例search_type: 搜索类型invoke(): 执行检索 |
| BaseRetriever | 抽象基类 | 定义检索器接口规范 | invoke(): 对外调用接口_get_relevant_documents(): 抽象方法(需子类实现) |
| ContextualCompressionRetriever | 检索器类 | 包装基础检索器和文档压缩器 | base_retriever: 基础检索器实例base_compressor: 文档压缩器实例invoke(): 执行检索(先检索后压缩)_get_relevant_documents(): 核心检索逻辑 |
| BaseDocumentCompressor | 抽象基类 | 定义文档压缩器接口规范 | compress_documents(): 抽象方法(需子类实现) |
| EmbeddingsFilter | 文档压缩器类 | 使用嵌入相似度过滤文档 | embeddings: 嵌入模型实例similarity_fn: 相似度计算函数k: 返回文档数量similarity_threshold: 相似度阈值compress_documents(): 压缩文档(基于相似度过滤) |
| cosine_similarity | 函数 | 计算余弦相似度 | 计算查询向量与文档向量的余弦相似度 |
56.3.2 类图 #
56.3.3 时序图 #
56.3.4 调用过程 #
阶段 1:初始化嵌入模型和向量数据库
embeddings = HuggingFaceEmbeddings(
model_name=model_path,
model_kwargs={"device": "cpu"}
)
chroma_db = Chroma(
persist_directory="chroma_db",
embedding_function=embeddings,
collection_name="test",
collection_metadata={"hnsw:space": "cosine"}
)- 创建
HuggingFaceEmbeddings实例 - 创建
Chroma向量数据库实例,连接持久化存储
阶段 2:检查并添加数据
if not chroma_db._collection.count():
chroma_db.add_texts(texts, metadatas)- 检查数据库是否为空
- 如果为空,批量添加文本和元数据到数据库
阶段 3:创建基础检索器和压缩器
base_retriever = chroma_db.as_retriever(search_type="similarity", search_kwargs={"k": 20})
compressor = EmbeddingsFilter(embeddings=embeddings, k=5)调用链:
Chroma.as_retriever()创建VectorStoreRetriever实例EmbeddingsFilter.__init__()方法- 验证参数:必须指定
k或similarity_threshold之一 - 保存
embeddings、similarity_fn(默认使用cosine_similarity)、k、similarity_threshold
- 验证参数:必须指定
阶段 4:创建压缩检索器
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor,
base_retriever=base_retriever
)- 创建
ContextualCompressionRetriever实例 - 保存
base_compressor和base_retriever
阶段 5:执行检索和压缩
results = compression_retriever.invoke(query)调用链:
invoke()方法(继承自BaseRetriever)- 调用
_get_relevant_documents(query)
- 调用
_get_relevant_documents()方法(核心检索逻辑)步骤 1:基础检索
docs = self.base_retriever.invoke(query, **kwargs)- 调用
VectorStoreRetriever.invoke()从向量数据库检索文档(返回 k=20 个文档)
步骤 2:文档压缩
compressed_docs = self.base_compressor.compress_documents(docs, query)内部处理流程(
EmbeddingsFilter.compress_documents()):步骤 2.1:计算嵌入向量
doc_texts = [doc.page_content for doc in documents] doc_embeddings = self.embeddings.embed_documents(doc_texts) # 批量计算文档嵌入 query_embedding = self.embeddings.embed_query(query) # 计算查询嵌入步骤 2.2:计算相似度
similarities = self.similarity_fn([query_embedding], doc_embeddings)内部调用
cosine_similarity()函数:dot_product = np.dot(query_emb, doc_embs.T) query_norm = np.linalg.norm(query_emb, axis=1, keepdims=True) doc_norms = np.linalg.norm(doc_embs, axis=1, keepdims=True).T similarity = dot_product / (query_norm * doc_norms)步骤 2.3:过滤文档
included_idxs = np.arange(len(documents)) if self.k is not None: sorted_indices = np.argsort(similarities)[::-1] # 降序排序 included_idxs = sorted_indices[:self.k] # 取前k个 if self.similarity_threshold is not None: similar_enough = similarities[included_idxs] > self.similarity_threshold included_idxs = included_idxs[similar_enough] # 进一步过滤步骤 2.4:返回过滤后的文档
return [documents[i] for i in included_idxs]步骤 3:返回压缩后的文档列表
- 返回过滤后的
Document列表(最多 k=5 个)
- 调用
阶段 6:结果输出
print(f"查询:{query}")
print(f"共检索到 {len(results)} 个文档")
for i, doc in enumerate(results):
print(f"检索结果{i}:{doc.page_content}")
print(f" 元数据:{doc.metadata}\n")- 遍历
results(过滤后的Document列表) - 打印每个文档的
page_content和metadata
56.3.5 EmbeddingsFilter 工作原理 #
过滤机制
EmbeddingsFilter 使用嵌入相似度对文档进行二次过滤:
- 计算相似度:对检索到的文档重新计算与查询的嵌入相似度
- 排序和筛选:
- 如果指定了
k:保留相似度最高的 k 个文档 - 如果指定了
similarity_threshold:只保留相似度超过阈值的文档 - 两者可以同时使用:先按 k 筛选,再按阈值过滤
- 如果指定了
为什么需要二次过滤?
- 向量数据库的相似度计算可能不够精确
- 使用不同的嵌入模型可以更准确地评估相关性
- 可以设置更严格的相似度阈值,提高结果质量
余弦相似度公式
对于查询向量 $\vec{q}$ 和文档向量 $\vec{d}$:
$$\cos(\theta) = \frac{\vec{q} \cdot \vec{d}}{|\vec{q}| \times |\vec{d}|} = \frac{\sum_{i=1}^{n} q_i \times d_i}{\sqrt{\sum_{i=1}^{n} q_i^2} \times \sqrt{\sum_{i=1}^{n} d_i^2}}$$
相似度范围:$[-1, 1]$,值越大越相似。
56.3.6 EmbeddingsFilter vs LLMChainExtractor #
| 特性 | EmbeddingsFilter | LLMChainExtractor |
|---|---|---|
| 过滤方式 | 基于嵌入相似度 | 基于 LLM 语义理解 |
| 计算成本 | 较低(向量计算) | 较高(LLM 推理) |
| 速度 | 快 | 较慢 |
| 准确性 | 中等(依赖嵌入质量) | 高(LLM 理解能力强) |
| 文档修改 | 不修改文档内容 | 提取相关内容,可能修改内容 |
| 适用场景 | 快速过滤、大批量文档 | 需要精确提取、小批量文档 |
57.CrossEncoderReranker #
CrossEncoderReranker(交叉编码器重排器)是一种基于大模型或深度学习模型的文档相关性重排序方法。它通常在初步召回了一批候选文档后,对文档与查询的匹配度做更精细的打分和排序,从而过滤掉不相关的内容,仅保留与查询最相关的若干条信息。
典型流程如下:
- 基础检索器(如向量数据库)召回一批与查询(query)内容相似的文档,通常数量较多(例如 k=20)。
- CrossEncoderReranker 将候选文档与查询两两拼接,输入到交叉编码器模型(如 BAAI/bge-reranker-base)中,获得每个文档与查询的相关性分数。
- 按分数降序排列,仅返回 top_n 个最相关的文档。
这种双阶段(先召回再重排)的方法,兼顾了“检索速度”与“结果相关性”。其中,基础召回用较快的向量检索,重排用较精细的神经网络模型。该方法常见于智能问答、知识检索、RAG等自然语言应用。
代码结构要点:
CrossEncoderReranker通常作为ContextualCompressionRetriever的base_compressor参数,负责文档的重排与压缩。- 需要配合如 HuggingFaceCrossEncoder 这样的模型,使用如 BGE-Reranker 等开箱即用的重排权重文件。
- 可以设置
top_n控制最终输出的文档数,以防止无关内容干扰下游任务。
使用场景举例:
- 你检索了20条可能相关的知识库文档,但只希望下游看到与提问“最紧密相关”的3条,此时就应在检索管道中加入
CrossEncoderReranker。 - 特别适用于多轮问答、长文档检索等上下文相关性要求高的场合。
这样可以大幅提升检索相关性的精度和用户体验。
57.1. 57.CrossEncoderReranker.py #
57.CrossEncoderReranker.py
#from langchain_chroma import Chroma
#from langchain_huggingface import HuggingFaceEmbeddings
#from langchain_classic.retrievers.contextual_compression import ContextualCompressionRetriever
#from langchain_classic.retrievers.document_compressors import CrossEncoderReranker
#from langchain_community.cross_encoders import HuggingFaceCrossEncoder
#from langchain_deepseek import ChatDeepSeek
# 导入 Chroma 向量数据库类
from smartchain.vectorstores import Chroma
# 导入 HuggingFaceEmbeddings 嵌入模型类
from smartchain.embeddings import HuggingFaceEmbeddings
# 导入上下文压缩型检索器
from smartchain.retrievers import ContextualCompressionRetriever
# 导入 CrossEncoderReranker 文档压缩器
from smartchain.document_compressors import CrossEncoderReranker
# 导入 HuggingFaceCrossEncoder 交叉编码器
from smartchain.cross_encoders import HuggingFaceCrossEncoder
# 导入 ChatDeepSeek LLM 模型
from smartchain.chat_models import ChatDeepSeek
# 设置本地嵌入模型的路径
model_path = "C:/Users/Administrator/.cache/modelscope/hub/models/sentence-transformers/all-MiniLM-L6-v2"
# 初始化嵌入模型,指定模型路径与推理设备
embeddings = HuggingFaceEmbeddings(
model_name=model_path,
model_kwargs={"device": "cpu"}
)
# 初始化 Chroma 向量数据库实例,指定持久化目录、嵌入函数、集合名称和元数据
chroma_db = Chroma(
persist_directory="chroma_db",
embedding_function=embeddings,
collection_name="test",
collection_metadata={"hnsw:space": "cosine"}
)
# 检查数据库是否为空,若为空批量插入初始文本及其元数据
if not chroma_db._collection.count():
# 定义插入的文本列表
texts = [
"人工智能(AI)是一种让机器模拟人类智能行为的技术。",
"深度学习是人工智能的一个重要分支,通过多层神经网络学习数据。",
"ChatGPT是OpenAI开发的强大自然语言模型。",
"向量数据库可以高效地存储和检索文本的嵌入向量。",
"机器人学结合了人工智能和机械工程,推动自动化发展。",
"AI可以辅助医生进行医学影像分析。",
"大模型在对话、问答、摘要等领域不断取得突破。",
"知识库问答系统常用于企业信息检索场景。",
]
# 对应的元数据列表
metadatas = [
{"topic": "ai"},
{"topic": "ai"},
{"topic": "nlp"},
{"topic": "vector_db"},
{"topic": "robotics"},
{"topic": "healthcare"},
{"topic": "llm"},
{"topic": "retrieval"},
]
# 向向量数据库批量插入文本和元数据
chroma_db.add_texts(texts, metadatas)
# 创建 ChatDeepSeek LLM 对象,使用 deepseek-chat 模型
llm = ChatDeepSeek(model="deepseek-chat")
# 创建基础检索器,指定检索类型为 similarity 和返回数量 k=20
# 建议 k 取较大值,因为压缩器后续会过滤部分不相关的结果
base_retriever = chroma_db.as_retriever(search_type="similarity", search_kwargs={"k": 20})
# 指定 CrossEncoder 重排模型路径
reranker_model_path = "C:/Users/Administrator/.cache/modelscope/hub/models/BAAI/bge-reranker-base"
# 加载 HuggingFaceCrossEncoder 进行交叉编码
cross_encoder = HuggingFaceCrossEncoder(model_name=reranker_model_path)
# 用 CrossEncoderReranker 封装,并指定重排后返回 top 3
compressor = CrossEncoderReranker(model=cross_encoder, top_n=3)
# 创建上下文压缩检索器,结合基础检索器与文档压缩器
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor,
base_retriever=base_retriever
)
# 指定检索查询内容(支持中文或英文)
query = "人工智能" # 可以更换为任意中英文问题
# 调用上下文压缩检索器进行检索
results = compression_retriever.invoke(query)
# 打印查询内容
print(f"查询:{query}")
# 打印检索到的文档数量
print(f"共检索到 {len(results)} 个文档")
# 遍历输出每条检索结果及其元数据
for i, doc in enumerate(results):
print(f"检索结果{i}:{doc.page_content}")
print(f" 元数据:{doc.metadata}\n")57.2. cross_encoders.py #
smartchain/cross_encoders.py
# 导入 sentence_transformers 包中的 CrossEncoder 类
from sentence_transformers import CrossEncoder
# 定义一个轻量包装的 HuggingFaceCrossEncoder 类
class HuggingFaceCrossEncoder:
"""轻量封装 sentence_transformers.CrossEncoder"""
# 初始化方法,接收模型名称、设备和其它模型参数
def __init__(self, model_name: str, device: str | None = None, **model_kwargs):
# 加载指定的 CrossEncoder 模型,如果未指定设备则默认用 "cpu"
self.model = CrossEncoder(model_name_or_path=model_name, device=device or "cpu", **model_kwargs)
# 定义预测方法,输入为句对组成的列表
def predict(self, pairs):
# 调用 CrossEncoder 的 predict 方法,返回分数
return self.model.predict(pairs)57.3. init.py #
smartchain/document_compressors/init.py
from .base import BaseDocumentCompressor
from .llm_chain import LLMChainExtractor
from .embeddings import EmbeddingsFilter
from .cross_encoder import CrossEncoderReranker
__all__ = [
"BaseDocumentCompressor",
"LLMChainExtractor",
"EmbeddingsFilter",
"CrossEncoderReranker",
]
57.4. base.py #
smartchain/document_compressors/base.py
# 从abc模块导入抽象基类支持
from abc import ABC, abstractmethod
# 定义文档压缩器的抽象基类
class BaseDocumentCompressor(ABC):
# 设置文档压缩器的类说明文档
"""文档压缩器抽象基类
用于对检索到的文档进行后处理,例如提取相关内容、过滤不相关文档等。
"""
# 声明子类必须实现的抽象方法
@abstractmethod
def compress_documents(
self,
documents,
query: str,
callbacks=None,
):
# 方法说明文档
"""压缩文档
参数:
documents: 检索到的文档列表
query: 查询字符串
callbacks: 可选的回调函数
返回:
Sequence: 压缩后的文档列表
"""
# 抽象方法体不实现
pass57.5. cross_encoder.py #
smartchain/document_compressors/cross_encoder.py
from ..cross_encoders import HuggingFaceCrossEncoder
from .base import BaseDocumentCompressor
# 定义 CrossEncoderReranker 类,用于通过交叉编码器对检索结果打分并重排序
class CrossEncoderReranker(BaseDocumentCompressor):
"""用 cross-encoder 打分并按分数截断"""
# 初始化方法,接收封装过的 cross-encoder 模型对象和 top_n 截断值
def __init__(self, model: HuggingFaceCrossEncoder, top_n: int = 3):
# 保存 cross-encoder 模型实例
self.model = model
# 保存需要返回的文档数量 top_n
self.top_n = top_n
# 定义 compress_documents 方法,对输入的文档进行重排序和截断
def compress_documents(self, documents, query, callbacks=None):
# 如果传入文档列表为空,直接返回空列表
if not documents:
return []
# 以 (query, 文档内容) 形式为每个文档构造输入对
pairs = [(query, doc.page_content) for doc in documents]
# 用 cross-encoder 对所有句对进行评分
scores = self.model.predict(pairs)
# 将文档及其分数打包为元组,并按分数从高到低排序
ranked = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
# 仅返回 top_n 个分数最高的文档
return [doc for doc, _ in ranked[: self.top_n]]57.6. embeddings.py #
smartchain/document_compressors/embeddings.py
import numpy as np
from .base import BaseDocumentCompressor
# 定义余弦相似度计算函数
def cosine_similarity(query_embeddings, doc_embeddings):
# 将查询嵌入转换为numpy数组
query_emb = np.array(query_embeddings)
# 将文档嵌入转换为numpy数组
doc_embs = np.array(doc_embeddings)
# 若查询为一维,转换为二维以便向量批量操作
if query_emb.ndim == 1:
query_emb = query_emb.reshape(1, -1)
# 计算查询和所有文档嵌入的点积
dot_product = np.dot(query_emb, doc_embs.T)
# 计算查询向量的范数
query_norm = np.linalg.norm(query_emb, axis=1, keepdims=True)
# 计算所有文档嵌入的范数,并转置方便后续相乘
doc_norms = np.linalg.norm(doc_embs, axis=1, keepdims=True).T
# 使用点积/范数乘积得到余弦相似度
similarity = dot_product / (query_norm * doc_norms)
# 将因除零产生的nan、正负无穷转为0
similarity = np.nan_to_num(similarity, nan=0.0, posinf=0.0, neginf=0.0)
# 返回第一行的相似度结果(只支持一个query)
return similarity[0]
# 定义嵌入过滤器类,继承自BaseDocumentCompressor
class EmbeddingsFilter(BaseDocumentCompressor):
# 类文档字符串,说明该类用途
"""嵌入过滤器
使用嵌入模型计算查询与文档的相似度,根据相似度过滤文档。
"""
# 构造方法定义
def __init__(
self,
embeddings,
similarity_fn=None,
k = 20,
similarity_threshold = None,
):
# 参数校验,如k和阈值都没传则报错
if k is None and similarity_threshold is None:
raise ValueError("必须指定 k 或 similarity_threshold 之一")
# 存储嵌入对象
self.embeddings = embeddings
# 如果没传相似度函数,默认使用余弦相似度
self.similarity_fn = similarity_fn or cosine_similarity
# 存储k值(返回多少最相似文档)
self.k = k
# 存储相似度阈值
self.similarity_threshold = similarity_threshold
# 文档压缩/过滤主入口
def compress_documents(
self,
documents,
query: str,
callbacks=None,
):
# 若输入文档为空,直接返回空列表
if not documents:
return []
# 收集每个文档的文本内容
doc_texts = [doc.page_content if hasattr(doc, 'page_content') else str(doc) for doc in documents]
# 计算所有文档嵌入
doc_embeddings = self.embeddings.embed_documents(doc_texts)
# 计算查询嵌入
query_embedding = self.embeddings.embed_query(query)
# 计算查询和所有文档的相似度分数
similarities = self.similarity_fn([query_embedding], doc_embeddings)
# 初始包含所有文档的索引
included_idxs = np.arange(len(documents))
# 如果指定了k值,保留相似度最高的k篇文档
if self.k is not None:
sorted_indices = np.argsort(similarities)[::-1]
included_idxs = sorted_indices[:self.k]
# 如果指定了相似度阈值,进一步只保留超过阈值的文档
if self.similarity_threshold is not None:
similar_enough = similarities[included_idxs] > self.similarity_threshold
included_idxs = included_idxs[similar_enough]
# 返回最终保留的文档列表
return [documents[i] for i in included_idxs]
57.7. llm_chain.py #
smartchain/document_compressors/llm_chain.py
from smartchain.document_compressors.base import BaseDocumentCompressor
from ..prompts import PromptTemplate
from ..documents import Document
# 定义一个默认的中文抽取Prompt模板
CHINESE_EXTRACT_PROMPT = """给定以下问题和上下文,提取上下文中与回答问题相关的任何部分(保持原样)。如果上下文都不相关,返回 NO_OUTPUT。
记住,*不要*编辑提取的上下文部分。
> 问题:{question}
> 上下文:
>>>
{context}
>>>
提取的相关部分:"""
# 定义LLM链文档提取器类,继承自基类
class LLMChainExtractor(BaseDocumentCompressor):
# 设置LLM链提取器的类说明文档
"""LLM链提取器
使用 LLM 链从文档中提取与查询相关的部分。
"""
# 构造函数
def __init__(
self,
llm_chain=None,
llm=None,
prompt=None,
get_input=None,
):
# 如果传入llm_chain则直接赋值
if llm_chain is not None:
self.llm_chain = llm_chain
# 如果传入llm而未传llm_chain,则根据参数生成链
elif llm is not None:
# 如果未自定义提示词模板,使用默认中文抽取Prompt
if prompt is None:
prompt_template = PromptTemplate.from_template(CHINESE_EXTRACT_PROMPT)
else:
prompt_template = prompt
# 用私有方法构造链
self.llm_chain = self._create_chain(prompt_template, llm)
# 两者都没给时报错
else:
raise ValueError("必须提供 llm_chain 或 llm 参数")
# 设置get_input函数(用于将query/doc打包成模型输入),可以自定义
self.get_input = get_input or self._default_get_input
# 默认输入转换函数,把query和doc打包成dict
def _default_get_input(self, query: str, doc) -> dict:
# 返回一个包含问题和文本内容的字典
return {"question": query, "context": doc.page_content}
# 私有方法,实现用prompt_template和llm创建简单链对象
def _create_chain(self, prompt_template, llm):
# 定义内部类SimpleChain
class SimpleChain:
# 构造函数,保存prompt模板和llm对象
def __init__(self, prompt_template, llm):
self.prompt_template = prompt_template
self.llm = llm
# 实现invoke方法,将输入dict格式化prompt并调用llm
def invoke(self, input_dict, **kwargs):
# 格式化prompt
formatted_prompt = self.prompt_template.format(**input_dict)
# 用llm推理,得到响应
response = self.llm.invoke(formatted_prompt, **kwargs)
# 如果响应有content属性,则用其内容
if hasattr(response, 'content'):
text = response.content
# 如果响应直接为字符串
elif isinstance(response, str):
text = response
# 否则强转为字符串
else:
text = str(response)
# 去掉文本首尾空格
text = text.strip()
# 若内容为NO_OUTPUT或空,则返回空字符串
if text == "NO_OUTPUT" or text == "":
return ""
# 返回抽取的文本
return text
# 返回SimpleChain对象实例
return SimpleChain(prompt_template, llm)
# 类方法,支持通过llm快速初始化LLMChainExtractor
@classmethod
def from_llm(
cls,
llm,
prompt=None,
get_input=None,
):
# 返回初始化好的提取器对象
return cls(
llm=llm,
prompt=prompt,
get_input=get_input
)
# 实现文档压缩主接口
def compress_documents(
self,
documents,
query: str,
callbacks=None,
):
# 用于存储压缩结果的列表
compressed_docs = []
# 遍历所有输入文档
for doc in documents:
# 构造当前doc与query的输入结构
_input = self.get_input(query, doc)
# 用llm链抽取相关内容
output = self.llm_chain.invoke(_input)
# 如果没抽取到任何内容则跳过该文档
if not output or len(output) == 0:
continue
# 用抽取得到内容与原有元数据构造新文档
compressed_doc = Document(
page_content=output,
metadata=doc.metadata if hasattr(doc, 'metadata') else {}
)
# 加入压缩结果列表
compressed_docs.append(compressed_doc)
# 返回压缩结果文档列表
return compressed_docs
57.8. 类 #
57.8.1 类说明 #
| 类名 | 类型 | 作用 | 关键属性/方法 |
|---|---|---|---|
| Document | 数据类 | 存储文档内容和元数据 | page_content: 文本内容metadata: 元数据字典 |
| HuggingFaceEmbeddings | 嵌入模型类 | 使用 HuggingFace 模型生成文本嵌入向量 | model_name: 模型名称embed_query(): 嵌入查询文本embed_documents(): 批量嵌入文本 |
| Chroma | 向量存储类 | Chroma 向量数据库封装 | embedding_function: 嵌入函数_collection: Chroma 集合add_texts(): 添加文本similarity_search(): 相似度搜索as_retriever(): 创建检索器 |
| VectorStoreRetriever | 检索器类 | 从向量存储中检索文档 | vectorstore: 向量存储实例search_type: 搜索类型invoke(): 执行检索 |
| BaseRetriever | 抽象基类 | 定义检索器接口规范 | invoke(): 对外调用接口_get_relevant_documents(): 抽象方法 |
| ContextualCompressionRetriever | 检索器类 | 包装基础检索器和文档压缩器 | base_retriever: 基础检索器base_compressor: 文档压缩器invoke(): 执行检索_get_relevant_documents(): 检索逻辑 |
| BaseDocumentCompressor | 抽象基类 | 定义文档压缩器接口规范 | compress_documents(): 抽象方法 |
| CrossEncoderReranker | 文档压缩器类 | 使用交叉编码器对文档重排序 | model: 交叉编码器模型top_n: 返回文档数量compress_documents(): 重排序 |
| HuggingFaceCrossEncoder | 交叉编码器类 | 封装 sentence-transformers 的 CrossEncoder | model: CrossEncoder 实例predict(): 相关性分数预测 |
57.8.2 类图 #
57.8.3 时序图 #
57.8.4 调用过程 #
阶段 1:初始化嵌入模型和向量数据库
embeddings = HuggingFaceEmbeddings(
model_name=model_path,
model_kwargs={"device": "cpu"}
)
chroma_db = Chroma(
persist_directory="chroma_db",
embedding_function=embeddings,
collection_name="test",
collection_metadata={"hnsw:space": "cosine"}
)- 创建
HuggingFaceEmbeddings实例 - 创建
Chroma向量数据库实例,连接持久化存储
阶段 2:检查并添加数据
if not chroma_db._collection.count():
chroma_db.add_texts(texts, metadatas)- 检查数据库是否为空
- 如果为空,批量添加文本和元数据到数据库
阶段 3:创建基础检索器和交叉编码器
base_retriever = chroma_db.as_retriever(search_type="similarity", search_kwargs={"k": 20})
cross_encoder = HuggingFaceCrossEncoder(model_name=reranker_model_path)
compressor = CrossEncoderReranker(model=cross_encoder, top_n=3)调用链:
Chroma.as_retriever()创建VectorStoreRetriever实例HuggingFaceCrossEncoder.__init__()方法- 加载
sentence_transformers.CrossEncoder模型 - 保存模型实例
- 加载
CrossEncoderReranker.__init__()方法- 保存
model(交叉编码器实例)和top_n(返回文档数量,默认 3)
- 保存
阶段 4:创建压缩检索器
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor,
base_retriever=base_retriever
)- 创建
ContextualCompressionRetriever实例 - 保存
base_compressor和base_retriever
阶段 5:执行检索和重排序
results = compression_retriever.invoke(query)调用链:
invoke()方法(继承自BaseRetriever)- 调用
_get_relevant_documents(query)
- 调用
_get_relevant_documents()方法(核心检索逻辑)步骤 1:基础检索
docs = self.base_retriever.invoke(query, **kwargs)- 调用
VectorStoreRetriever.invoke()从向量数据库检索文档(返回 k=20 个文档)
步骤 2:文档重排序
compressed_docs = self.base_compressor.compress_documents(docs, query)内部处理流程(
CrossEncoderReranker.compress_documents()):步骤 2.1:构造查询-文档对
pairs = [(query, doc.page_content) for doc in documents]- 为每个文档构造 (查询, 文档内容) 对
步骤 2.2:计算相关性分数
scores = self.model.predict(pairs)- 调用
HuggingFaceCrossEncoder.predict()对所有句对进行评分 - 内部调用
sentence_transformers.CrossEncoder.predict(),返回每个文档与查询的相关性分数
步骤 2.3:排序和截断
ranked = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True) return [doc for doc, _ in ranked[:self.top_n]]- 按分数降序排序
- 取前
top_n=3个分数最高的文档
步骤 3:返回重排序后的文档列表
- 返回重排序后的
Document列表(最多 3 个)
- 调用
阶段 6:结果输出
print(f"查询:{query}")
print(f"共检索到 {len(results)} 个文档")
for i, doc in enumerate(results):
print(f"检索结果{i}:{doc.page_content}")
print(f" 元数据:{doc.metadata}\n")- 遍历
results(重排序后的Document列表) - 打印每个文档的
page_content和metadata
57.8.5 CrossEncoder 工作原理 #
CrossEncoder vs BiEncoder
| 特性 | CrossEncoder | BiEncoder(如向量检索) |
|---|---|---|
| 计算方式 | 查询和文档同时输入模型 | 查询和文档分别编码后计算相似度 |
| 准确性 | 高(考虑交互) | 中等(独立编码) |
| 速度 | 较慢(需要成对计算) | 快(可预先编码) |
| 适用场景 | 重排序、小批量文档 | 大规模检索、召回阶段 |
CrossEncoder 优势
- 精确评分:同时考虑查询和文档的交互,评分更准确
- 上下文理解:能理解查询与文档的语义关系
- 适合重排序:在召回阶段后对候选文档进行精确排序
工作流程
- 召回阶段:使用向量数据库快速召回 k=20 个候选文档
- 重排序阶段:使用 CrossEncoder 对所有候选文档进行精确评分
- 截断阶段:按分数排序,返回 top_n=3 个最相关的文档
双阶段检索的优势
| 阶段 | 方法 | 目标 | 特点 |
|---|---|---|---|
| 召回阶段 | 向量检索 | 快速召回大量候选文档 | 速度快,覆盖广 |
| 重排序阶段 | CrossEncoder | 精确排序候选文档 | 准确度高,结果精 |
这种两阶段方法兼顾了速度和准确性。
57.8.6 CrossEncoderReranker vs 其他压缩器 #
| 特性 | CrossEncoderReranker | EmbeddingsFilter | LLMChainExtractor |
|---|---|---|---|
| 计算方式 | CrossEncoder 模型评分 | 嵌入向量余弦相似度 | LLM 语义理解 |
| 准确性 | 高 | 中等 | 高 |
| 速度 | 中等 | 快 | 慢 |
| 文档修改 | 不修改 | 不修改 | 提取相关内容 |
| 适用场景 | 重排序、中等规模 | 快速过滤、大规模 | 精确提取、小规模 |
| 成本 | 中等(模型推理) | 低(向量计算) | 高(LLM API调用) |
CrossEncoderReranker 在准确性和速度之间取得平衡,适合需要精确重排序的场景。
58.LongContextReorder #
本节介绍了 LongContextReorder 的用法和实现原理。LongContextReorder 是一种“长上下文重排器”,用于解决大语言模型处理长文档时信息容易丢失在中间部分(lost-in-the-middle)的问题。其核心思想是将最重要、与查询最相关的文档片段放在序列的开头和结尾,而把相关性较低的内容安排在中间,以提升模型的整体理解能力。
重排流程一般如下:
- 先将输入的文档列表反转;
- 交替地将文档插入新列表的头部或尾部——偶数位的插到头部,奇数位的加到尾部,实现一种“两头包夹中间”的顺序;
- 返回重排后的文档序列。
举个例子,原始文档顺序 [A, B, C, D],经过 LongContextReorder 后,顺序会变为 [B, D, A, C]。
具体如何使用,可参考上方代码示例:初始化后调用 transform_documents(docs) 即可得到重排结果。
这种方法在大模型多文档问答、检索增强生成(RAG)等场景中颇为常见,有助于最大化命中重要信息,提高模型生成答案的准确性。
58.1. 58.LongContextReorder.py #
58.LongContextReorder.py
#from langchain_community.document_transformers import LongContextReorder
#from langchain_core.documents import Document
# 导入 Document 类,用于创建文档对象
from smartchain.documents import Document
# 导入 LongContextReorder 类,用于实现 Lost-in-the-middle 重排
from smartchain.document_transformers import LongContextReorder
# 构建一个包含 4 个文档的列表,每个文档附带不同的内容和元数据 id
docs = [
# 第1个文档,介绍部分
Document(page_content="介绍:本文讨论人工智能的基础。", metadata={"id": 1}),
# 第2个文档,细节部分
Document(page_content="细节:机器学习是人工智能的核心方法。", metadata={"id": 2}),
# 第3个文档,扩展部分
Document(page_content="扩展:深度学习通过神经网络取得突破。", metadata={"id": 3}),
# 第4个文档,结论部分
Document(page_content="结论:人工智能正在改变世界。", metadata={"id": 4}),
]
# 创建长上下文重排器实例
reorder = LongContextReorder()
# 对文档列表进行重排,得到重排后的文档顺序
reordered = reorder.transform_documents(docs)
# 输出重排后文档的顺序和内容
print("重排后顺序:")
# 枚举遍历重排后的文档,i 为顺序编号,从 1 开始
for i, d in enumerate(reordered, 1):
# 按格式输出:序号. 文档id: 文档内容
print(f"{i}. {d.metadata.get('id')}: {d.page_content}")58.2. init.py #
smartchain/document_transformers/init.py
# 从 base 模块导入文档转换器基类
from .base import BaseDocumentTransformer
# 从 long_context 模块导入长上下文重排器
from .long_context import LongContextReorder
# 定义模块对外暴露的类名列表
__all__ = [
"BaseDocumentTransformer",
"LongContextReorder",
]58.3. base.py #
smartchain/document_transformers/base.py
# 导入 Sequence 类型用于类型标注
from typing import Sequence
# 从上级目录导入 Document 类
from ..documents import Document
# 定义文档转换器基类
class BaseDocumentTransformer:
# 类的文档字符串,说明这是文档转换器基类
"""文档转换器基类"""
# 定义抽象方法,用于转换文档列表
def transform_documents(self, documents: Sequence[Document]) -> Sequence[Document]:
# 方法文档字符串,说明此方法用于转换文档列表
"""转换文档列表"""
# 抛出未实现异常,子类需实现此方法
raise NotImplementedError58.4. long_context.py #
smartchain/document_transformers/long_context.py
# 导入类型注解 Sequence,用于类型标注
from typing import Sequence
# 导入基础文档转换器基类
from .base import BaseDocumentTransformer
# 导入文档对象类
from ..documents import Document
# 定义长上下文重排器,继承自基础文档转换器
class LongContextReorder(BaseDocumentTransformer):
# 类文档字符串,说明用途及对应论文
"""长上下文重排器
实现 Lost-in-the-middle 重排算法:
将最相关的文档放在开头和结尾,不太相关的文档放在中间。
参考论文: https://arxiv.org/abs/2307.03172
"""
# 重写 transform_documents 方法,实现文档重排
def transform_documents(self, documents: Sequence[Document]) -> Sequence[Document]:
# 将输入的文档序列转换为列表,便于操作
docs = list(documents)
# 反转文档列表
docs.reverse()
# 初始化重排后的文档列表
reordered = []
# 遍历反转后的文档列表,枚举索引和文档
for i, doc in enumerate(docs):
# 如果索引为奇数
if i % 2 == 1:
# 将当前文档追加到重排列表的末尾
reordered.append(doc)
# 如果索引为偶数
else:
# 将当前文档插入到重排列表的开头
reordered.insert(0, doc)
# 返回重排后的文档列表
return reordered58.5. 对比 #
BaseDocumentTransformer 和 BaseDocumentCompressor 都是 LangChain 中处理检索文档的抽象基类,但它们的核心区别在于目的和操作方式:一个负责“变换”文档内容本身,另一个负责“压缩/筛选”文档集。
58.5.1 关键区别 #
| 维度 | BaseDocumentTransformer (文档转换器) | BaseDocumentCompressor (文档压缩器) |
|---|---|---|
| 核心目的 | 转换文档的内容或格式 | 减少文档的数量或长度 |
| 主要操作对象 | 单个文档的内容(如文本) | 文档集合(多个Document对象) |
| 典型输出 | 修改后的文档列表(数量通常不变) | 过滤或精简后的文档子集(数量减少) |
| 关键方法 | transform_documents(documents, **kwargs) |
compress_documents(documents, query, **kwargs) |
| 类比 | 文档编辑器:重写、翻译、总结文本 | 文档过滤器:按相关性筛选、截断文本 |
| 常见子类 | LongContextReorder (重排顺序)、EmbeddingsRedundantFilter (去重) |
CrossEncoderReranker (重排序并筛选)、LLMChainExtractor (LLM提取摘要) |
| 是否依赖查询(query) | 通常否(对文档进行独立处理) | 是(压缩过程依赖于具体的查询) |
58.5.2 工作原理 #
58.5.2.1. BaseDocumentTransformer(转换文档) #
核心思想:接收一批文档,对它们逐一进行某种变换,返回数量相同但内容可能已被修改的文档列表。
from langchain.schema import BaseDocumentTransformer
from langchain.schema import Document
class SimpleTextCleaner(BaseDocumentTransformer):
"""一个简单的转换器示例:清理文档文本"""
def transform_documents(self, documents, **kwargs):
transformed = []
for doc in documents:
# 转换操作:例如,移除多余空格和特殊字符
cleaned_text = " ".join(doc.page_content.split())
transformed.append(Document(page_content=cleaned_text, metadata=doc.metadata))
return transformed
# 使用
original_docs = [Document(page_content="Hello World! ")]
cleaner = SimpleTextCleaner()
transformed_docs = cleaner.transform_documents(original_docs)
print(transformed_docs[0].page_content) # 输出: "Hello World!"58.5.2.2. BaseDocumentCompressor(压缩/筛选文档) #
核心思想:接收一批文档和一个查询,根据查询的相关性来减少文档的数量或长度。
from langchain.retrievers.document_compressors import BaseDocumentCompressor
from langchain.schema import Document
class SimpleTopKCompressor(BaseDocumentCompressor):
"""一个简单的压缩器示例:只保留前K个文档(假设已按相关性排序)"""
def __init__(self, k=3):
self.k = k
def compress_documents(self, documents, query, **kwargs):
# 压缩操作:这里简单截取前K个
return documents[:self.k]
# 使用
original_docs = [Document(page_content=f"Doc {i}") for i in range(10)] # 10个文档
query = "test query"
compressor = SimpleTopKCompressor(k=2)
compressed_docs = compressor.compress_documents(original_docs, query)
print(len(compressed_docs)) # 输出: 258.5.3 在RAG流程中的典型协作 #
在实际的RAG系统中,它们常被串联使用,形成一个处理流水线(Pipeline):
# 假设在一个完整的检索增强流程中:
1. **检索**:从向量数据库获取100个相关文档(粗召回)
2. **转换**:使用 `EmbeddingsRedundantFilter` (转换器) 去除重复文档 -> 剩下80个
3. **压缩**:使用 `CrossEncoderReranker` (压缩器) 根据查询进行精排并取Top-5 -> 剩下5个
4. **再次转换**:使用 `LongContextReorder` (转换器) 重新排列这5个文档的顺序以优化LLM注意力
5. **生成**:将最终处理后的文档和查询一起交给LLM生成答案在这个流水线中,压缩器 (BaseDocumentCompressor) 的核心职责是“做减法”,根据查询筛选出最核心的文档子集;而转换器 (BaseDocumentTransformer) 的核心职责是“做调整”,在不改变文档集合大小的前提下优化文档的内容或顺序。
58.5.3 总结与选择建议 #
| 当你想... | 应该选择 |
|---|---|
| 去除重复、翻译文本、重新排序文档(不依赖查询) | BaseDocumentTransformer 的子类 |
| 根据特定查询,筛选最相关的文档、提取关键片段 | BaseDocumentCompressor 的子类 |
| 既要去重又要筛选,且优化顺序 | 组合使用两者,通常顺序是:先转换器去重,再压缩器筛选 |
简单来说,可以这样记忆:BaseDocumentTransformer 像是文档的“化妆师”或“整理师”,专注于美化或重组内容;而 BaseDocumentCompressor 则像是严格的“面试官”,根据你的问题(query)来决定哪些文档可以进入下一轮。
58.6. 类 #
58.6.1 类说明 #
| 类名 | 类型 | 作用 | 关键属性/方法 |
|---|---|---|---|
| Document | 数据类 | 存储文档内容和元数据 | page_content: 文档文本内容metadata: 元数据字典 |
| BaseDocumentTransformer | 抽象基类 | 定义文档转换器接口规范 | transform_documents(): 抽象方法(需子类实现) |
| LongContextReorder | 文档转换器类 | 实现 Lost-in-the-middle 重排算法 | transform_documents(): 重排文档顺序将最相关文档放在开头和结尾,不太相关的放在中间 |
58.6.2 类图 #
58.6.3 时序图 #
58.6.4 调用过程 #
阶段 1:创建文档对象列表
docs = [
Document(page_content="介绍:本文讨论人工智能的基础。", metadata={"id": 1}),
Document(page_content="细节:机器学习是人工智能的核心方法。", metadata={"id": 2}),
Document(page_content="扩展:深度学习通过神经网络取得突破。", metadata={"id": 3}),
Document(page_content="结论:人工智能正在改变世界。", metadata={"id": 4}),
]- 创建 4 个
Document对象 - 每个文档包含
page_content和metadata(包含id)
阶段 2:创建重排器实例
reorder = LongContextReorder()- 创建
LongContextReorder实例 - 该类继承自
BaseDocumentTransformer
阶段 3:执行文档重排
reordered = reorder.transform_documents(docs)调用链:
transform_documents()方法(核心重排逻辑)步骤 1:转换为列表并反转
docs = list(documents) # [doc1, doc2, doc3, doc4] docs.reverse() # [doc4, doc3, doc2, doc1]步骤 2:交替插入和追加
reordered = [] for i, doc in enumerate(docs): if i % 2 == 1: # 奇数索引 reordered.append(doc) # 追加到末尾 else: # 偶数索引 reordered.insert(0, doc) # 插入到开头重排过程示例(4 个文档):
- 初始:
[doc1, doc2, doc3, doc4] - 反转:
[doc4, doc3, doc2, doc1] - i=0(偶数):
reordered.insert(0, doc4)→[doc4] - i=1(奇数):
reordered.append(doc3)→[doc4, doc3] - i=2(偶数):
reordered.insert(0, doc2)→[doc2, doc4, doc3] - i=3(奇数):
reordered.append(doc1)→[doc2, doc4, doc3, doc1] - 最终:
[doc2, doc4, doc3, doc1]
步骤 3:返回重排后的文档列表
- 返回重排后的
Document列表
- 初始:
阶段 4:输出重排结果
print("重排后顺序:")
for i, d in enumerate(reordered, 1):
print(f"{i}. {d.metadata.get('id')}: {d.page_content}")- 遍历重排后的文档列表
- 打印每个文档的序号、id 和内容
58.6.5 Lost-in-the-middle 问题 #
问题背景
LLM 在处理长上下文时,容易出现“中间信息丢失”:
- 开头和结尾的信息更容易被关注
- 中间部分的信息容易被忽略
解决方案
LongContextReorder 实现的重排策略:
- 将最相关的文档放在开头和结尾
- 将不太相关的文档放在中间
重排算法详解
对于 n 个文档(按相关性从高到低排序):
- 反转列表:
[doc_n, doc_{n-1}, ..., doc_2, doc_1] - 交替插入:
- 偶数索引(0, 2, 4, ...):插入到开头
- 奇数索引(1, 3, 5, ...):追加到末尾
重排效果示例
假设有 4 个文档,按相关性排序为:[doc1, doc2, doc3, doc4](doc1 最相关)
| 步骤 | 操作 | 结果 |
|---|---|---|
| 初始 | [doc1, doc2, doc3, doc4] |
原始顺序 |
| 反转 | [doc4, doc3, doc2, doc1] |
反转后 |
| i=0 | insert(0, doc4) |
[doc4] |
| i=1 | append(doc3) |
[doc4, doc3] |
| i=2 | insert(0, doc2) |
[doc2, doc4, doc3] |
| i=3 | append(doc1) |
[doc2, doc4, doc3, doc1] |
最终顺序:[doc2, doc4, doc3, doc1]
- 开头:doc2(次相关)
- 中间:doc4(不太相关)、doc3(不太相关)
- 结尾:doc1(最相关)
算法优势
- 重要信息在两端:最相关文档在开头和结尾
- 中间信息不丢失:不太相关的文档放在中间,仍能被处理
- 简单高效:时间复杂度 O(n),空间复杂度 O(n)
适用场景
- RAG 系统:将检索到的文档重排后输入 LLM
- 长文档处理:处理超长上下文时优化信息分布
- 问答系统:确保关键信息不被忽略
58.6.6 与其他组件的对比 #
| 特性 | LongContextReorder | Document Compressor |
|---|---|---|
| 作用 | 重排文档顺序 | 过滤或提取文档内容 |
| 文档修改 | 不修改内容,只改变顺序 | 可能修改或过滤内容 |
| 适用阶段 | 检索后、输入 LLM 前 | 检索后、输入 LLM 前 |
| 目标 | 优化信息分布 | 提高相关性 |
| 输入 | 文档列表 | 文档列表 + 查询 |
| 输出 | 重排后的文档列表 | 压缩后的文档列表 |
LongContextReorder 专注于解决长上下文中的信息分布问题,不依赖查询,适用于所有需要优化文档顺序的场景。