导航菜单

  • 1.langchain.intro
  • 2.langchain.chat_models
  • 3.langchain.prompts
  • 4.langchain.example_selectors
  • 5.output_parsers
  • 6.Runnable
  • 7.memory
  • 8.document_loaders
  • 9.text_splitters
  • 10.embeddings
  • 11.tool
  • 12.retrievers
  • 13.optimize
  • 14.项目介绍
  • 15.启动HTTP
  • 16.数据与模型
  • 17.权限管理
  • 18.知识库管理
  • 19.设置
  • 20.文档管理
  • 21.聊天
  • 22.API文档
  • 23.RAG优化
  • 24.索引时优化
  • 25.检索前优化
  • 26.检索后优化
  • 27.系统优化
  • 28.GraphRAG
  • 29.图
  • 30.为什么选择图数据库
  • 31.什么是 Neo4j
  • 32.安装和连接 Neo4j
  • 33.Neo4j核心概念
  • 34.Cypher基础
  • 35.模式匹配
  • 36.数据CRUD操作
  • 37.GraphRAG
  • 38.查询和过滤
  • 39.结果处理和聚合
  • 40.语句组合
  • 41.子查询
  • 42.模式和约束
  • 43.日期时间处理
  • 44.Cypher内置函数
  • 45.Python操作Neo4j
  • 46.neo4j
  • 47.py2neo
  • 48.Streamlit
  • 49.Pandas
  • 50.graphRAG
  • 51.deepdoc
  • 52.deepdoc
  • 53.deepdoc
  • 55.deepdoc
  • 54.deepdoc
  • Pillow
  • 50.as_retriever
    • 50.1. 53.as_retriever.py
    • 50.2. init.py
    • 50.3. base.py
    • 50.4. vector_store.py
    • 50.5. vectorstores.py
    • 50.6. 类
      • 50.6.1 类说明
      • 50.6.2 类图
      • 50.6.3 时序图
      • 50.6.4 调用过程
      • 50.6.5 搜索类型说明
  • 51.TFIDFRetriever
    • 51.1. 51.TFIDFRetriever.py
    • 51.2. tfidf.py
    • 51.3. init.py
    • 51.4. 类
      • 51.4.1 类说明
      • 51.4.2 类图
      • 51.4.3 时序图
      • 51.4.4 调用过程
    • 51.5 TF-IDF 算法
    • 51.6. 余弦相似度
  • 52.BM25Retriever
    • 52.1. 52.BM25Retriever.py
    • 52.2. bm25.py
    • 52.3. init.py
    • 52.4.类
      • 52.4.1 类说明
      • 52.4.2 类图
      • 52.4.3 时序图
      • 52.4.4 调用过程
      • 52.4.5 BM25 算法
        • 52.4.5.1 BM25 公式
        • 52.4.5.2 IDF 计算
  • 53.VectorSimilarityRetriever
    • 53.1. 53.SimilarityRetriever.py
    • 53.2. vector.py
    • 53.3. init.py
    • 53.4.类
      • 53.4.1 类说明
      • 53.4.2 类图
      • 53.4.3 时序图
      • 53.4.4 调用过程
      • 53.4.5 向量嵌入与余弦相似度
      • 53.4.6 设计模式
      • 53.4.7 与 TFIDF/BM25 的区别
  • 54.EnsembleRetriever
    • 54.1. 54.EnsembleRetriever.py
    • 54.2. ensemble.py
    • 54.3. init.py
    • 54.4. 类
      • 54.4.1 类说明
      • 54.4.2 类图
      • 54.4.3 时序图
      • 54.4.4 调用过程
      • 54.4.5 RRF 算法
  • 55.LLMChainExtractor
    • 55.1. 55.LLMChainExtractor.py
    • 55.2. document_compressors.py
    • 55.3. contextual.py
    • 55.4. init.py
    • 55.5. 类
      • 55.5.1 类说明
      • 55.5.2 类图
      • 55.5.3 时序图
      • 55.5.4 调用过程
      • 55.5.5 文档压缩流程
  • 56.EmbeddingsFilter
    • 56.1. 56.EmbeddingsFilter.py
    • 56.2. document_compressors.py
    • 56.3. 类
      • 56.3.1 类说明
      • 56.3.2 类图
      • 56.3.3 时序图
      • 56.3.4 调用过程
      • 56.3.5 EmbeddingsFilter 工作原理
      • 56.3.6 EmbeddingsFilter vs LLMChainExtractor
  • 57.CrossEncoderReranker
    • 57.1. 57.CrossEncoderReranker.py
    • 57.2. cross_encoders.py
    • 57.3. init.py
    • 57.4. base.py
    • 57.5. cross_encoder.py
    • 57.6. embeddings.py
    • 57.7. llm_chain.py
    • 57.8. 类
      • 57.8.1 类说明
      • 57.8.2 类图
      • 57.8.3 时序图
      • 57.8.4 调用过程
      • 57.8.5 CrossEncoder 工作原理
      • 57.8.6 CrossEncoderReranker vs 其他压缩器
  • 58.LongContextReorder
    • 58.1. 58.LongContextReorder.py
    • 58.2. init.py
    • 58.3. base.py
    • 58.4. long_context.py
    • 58.5. 对比
      • 58.5.1 关键区别
      • 58.5.2 工作原理
        • 58.5.2.1. BaseDocumentTransformer(转换文档)
        • 58.5.2.2. BaseDocumentCompressor(压缩/筛选文档)
      • 58.5.3 在RAG流程中的典型协作
      • 58.5.3 总结与选择建议
    • 58.6. 类
      • 58.6.1 类说明
      • 58.6.2 类图
      • 58.6.3 时序图
      • 58.6.4 调用过程
      • 58.6.5 Lost-in-the-middle 问题
      • 58.6.6 与其他组件的对比
  • 59.参考

50.as_retriever #

as_retriever 是 Chroma 向量数据库中的一个方法,用于将向量数据库包装为兼容 LangChain 检索接口(Retriever)的对象。这样就可以方便地与各类LangChain链路、Agent等组件进行拼接,实现如RAG、问答等复杂应用。

常见用法如下:

  • search_type:设定检索方式,如 "similarity"(相似度检索)、"mmr"(最大边际相关性),可根据实际场景选择。
  • search_kwargs:用于传递检索参数,如返回文档条数 k。

调用 as_retriever 后返回的对象支持直接 .get_relevant_documents(query) 方法,从而获取与 query 最相关的文本内容,便于下游调用。

该方法极大地简化了数据与检索逻辑的衔接,实现了与大模型问答等能力的无缝连接。

50.1. 53.as_retriever.py #

53.as_retriever.py

#from langchain_chroma import Chroma
#from langchain_huggingface import HuggingFaceEmbeddings

from smartchain.embeddings import HuggingFaceEmbeddings
from smartchain.vectorstores import Chroma

# 模型路径(本地)
model_path = "C:/Users/Administrator/.cache/modelscope/hub/models/sentence-transformers/all-MiniLM-L6-v2"
# 初始化嵌入模型
embeddings = HuggingFaceEmbeddings(
    model_name=model_path,
    model_kwargs={"device": "cpu"}
)

# 初始化 Chroma 向量数据库
chroma_db = Chroma(
    persist_directory="chroma_database",
    embedding_function=embeddings,
    collection_name="test",
    collection_metadata={"hnsw:space": "cosine"}
)

# 检查数据库是否已包含数据
if not chroma_db._collection.count():
    # 待入库的文本
    texts = [
        "你好,世界!",
        "人工智能非常有趣。",
        "机器学习是人工智能的重要领域。",
        "深度学习通过神经网络模拟人脑。",
        "欢迎使用Chroma向量数据库。",
    ]

    # 每条文本对应的元数据
    metadatas = [
        {"lang": "en", "category": "greeting"},
        {"lang": "en", "category": "tech"},
        {"lang": "zh", "category": "tech"},
        {"lang": "en", "category": "tech"},
        {"lang": "zh", "category": "demo"},
    ]

    # 向数据库中批量添加文本及元信息
    chroma_db.add_texts(texts, metadatas)

# 创建检索器,指定检索类型和返回条数
retriever = chroma_db.as_retriever(search_type="similarity", search_kwargs={"k": 2})
# 用中文/英文查询检索相关文档
results = retriever.invoke("什么是人工智能?")  # 可替换为英文或中文查询
for i, doc in enumerate(results):
    print(f"检索结果{i}:{doc.page_content}")

50.2. init.py #

smartchain/retrievers/init.py

from .vector_store import VectorStoreRetriever
from .base import BaseRetriever

__all__ = [
    "VectorStoreRetriever",
    "BaseRetriever",
]

50.3. base.py #

smartchain/retrievers/base.py

# 导入抽象基类ABC和抽象方法装饰器abstractmethod
from abc import ABC, abstractmethod

# 定义检索器抽象基类
class BaseRetriever(ABC):
    # 检索器抽象基类说明文档字符串
    """检索器抽象基类

    所有检索器都应该继承此类并实现 _get_relevant_documents 方法。
    """

    # 构造方法,初始化检索器
    def __init__(self, **kwargs):
        # 初始化检索器,这里不做具体实现
        pass

    # 抽象方法,子类必须实现该方法用于获取相关文档
    @abstractmethod
    def _get_relevant_documents(self, query, **kwargs):
        # 获取相关文档(抽象方法,由子类实现)
        # 参数:
        #   query: 查询字符串
        #   **kwargs: 其他参数
        # 返回:
        #   相关文档列表
        pass

    # 对外接口方法,调用检索器获取相关文档
    def invoke(self, query, **kwargs):
        # 调用检索器获取相关文档
        # 参数:
        #   query: 查询字符串
        #   **kwargs: 其他参数
        # 返回:
        #   相关文档列表
        return self._get_relevant_documents(query, **kwargs)

50.4. vector_store.py #

smartchain/retrievers/vector_store.py

from .base import BaseRetriever
# 向量存储检索器类,继承自BaseRetriever
class VectorStoreRetriever(BaseRetriever):
    # 向量存储检索器说明文档字符串
    """向量存储检索器

    用于从向量存储中检索文档,支持多种搜索类型。
    """

    # 允许的搜索类型枚举
    ALLOWED_SEARCH_TYPES = ("similarity", "similarity_score_threshold", "mmr")

    # 构造方法,初始化向量存储检索器
    def __init__(
        self,
        vectorstore,
        search_type="similarity",
        search_kwargs=None,
        **kwargs
    ):
        # 初始化向量存储检索器
        # 参数:
        #   vectorstore: 向量存储实例
        #   search_type: 搜索类型,可选值:'similarity', 'similarity_score_threshold', 'mmr'
        #   search_kwargs: 搜索参数,如 {"k": 4, "score_threshold": 0.8}
        #   **kwargs: 其他参数
        super().__init__(**kwargs)
        # 保存向量存储实例
        self.vectorstore = vectorstore
        # 保存搜索类型
        self.search_type = search_type
        # 保存搜索参数字典,若为None则赋空字典
        self.search_kwargs = search_kwargs or {}

        # 检查搜索类型是否合法
        if search_type not in self.ALLOWED_SEARCH_TYPES:
            # 抛出ValueError异常,提示搜索类型不允许
            raise ValueError(
                f"search_type '{search_type}' 不允许。"
                f"允许的值: {self.ALLOWED_SEARCH_TYPES}"
            )

        # 若选择 similarity_score_threshold,必须提供 score_threshold 参数
        if search_type == "similarity_score_threshold":
            # 获取分数阈值
            score_threshold = self.search_kwargs.get("score_threshold")
            # 判断分数阈值是否为数值类型且非None
            if score_threshold is None or not isinstance(score_threshold, (int, float)):
                # 抛出异常提示必须提供score_threshold参数
                raise ValueError(
                    "使用 'similarity_score_threshold' 搜索类型时,"
                    "必须在 search_kwargs 中提供 'score_threshold' (float, 0~1)"
                )

    # 获取相关文档主方法
    def _get_relevant_documents(self, query, **kwargs):
        # 获取相关文档
        # 参数:
        #   query: 查询字符串
        #   **kwargs: 其他参数,将与search_kwargs合并
        # 返回:
        #   相关文档列表
        # 合并父类构造时的search_kwargs与本次调用传入的kwargs
        search_kwargs = {**self.search_kwargs, **kwargs}

        # 根据当前search_type分支调用不同的检索方法
        if self.search_type == "similarity":
            # 相似度搜索,调用vectorstore的similarity_search方法
            docs = self.vectorstore.similarity_search(query, **search_kwargs)
        elif self.search_type == "similarity_score_threshold":
            # 带分数阈值的相似度搜索,先获取(doc, score)对
            docs_and_scores = self.vectorstore.similarity_search_with_score(
                query, **search_kwargs
            )
            # 获取分数阈值(最大距离阈值)
            score_threshold = search_kwargs.get("score_threshold", 0.0)
            # 注意:Chroma返回的是距离,距离越小越相似
            # 这里直接过滤距离大于阈值的文档,假设score_threshold是最大距离阈值
            docs = [
                doc for doc, score in docs_and_scores
                if score <= score_threshold
            ]
        elif self.search_type == "mmr":
            # 最大边际相关性检索,调用vectorstore的max_marginal_relevance_search方法
            docs = self.vectorstore.max_marginal_relevance_search(
                query, **search_kwargs
            )
        else:
            # 不支持的检索类型,抛出异常
            raise ValueError(f"不支持的搜索类型: {self.search_type}")

        # 返回检索到的相关文档
        return docs

50.5. vectorstores.py #

smartchain/vectorstores.py

import os
import numpy as np
from abc import ABC, abstractmethod
import faiss
import chromadb
import uuid
+from .retrievers import VectorStoreRetriever

# 计算一个向量与多个向量的余弦相似度
def cosine_similarity(from_vec, to_vecs):
    # 将from_vec转换为numpy数组并且强制类型为float
    from_vec = np.array(from_vec, dtype=float)
    to_vecs = np.array(to_vecs, dtype=float)
    # 计算from_vec模长
    norm1 = np.linalg.norm(from_vec)
    similarities = []
    for to_vec in to_vecs:
        dot_product = np.sum(from_vec * to_vec)
        norm_vec = np.linalg.norm(to_vec)
        similarity = dot_product / (norm1 * norm_vec)
        similarities.append(similarity)
    return np.array(similarities)


def mmr_select(query_vector, doc_vectors, k=3, lambda_mult=0.5):
    # 计算每个文档向量与查询向量(Query)的余弦相似度。这代表了文档的“相关性”。
    quer_similarities = cosine_similarity(query_vector, doc_vectors)
    # 选择与查询向量相似度最高的文档:文档 1。
    # 找到与查询向量最相关的文档的下标,作为初始的已选文档 S:选择的结果集=selected=[0]
    selected = [int(np.argmax(quer_similarities))]
    while len(selected) < k:
        # 存放每个候选文档的MMR分数
        mmr_scores = []
        for i in range(len(doc_vectors)):
            if i not in selected:
                # 相关性,指的是i对应的候选文档和查询文档这间的相似性
                relevance = quer_similarities[i]
                # 获取当前已选文档的向量集合
                selected_vecs = doc_vectors[selected]  # - S:选择的结果集
                # 计算当前文档与所有的已选文档的余弦相似度
                sims = cosine_similarity(doc_vectors[i], selected_vecs)
                # 获取对已选中的文档最最大相似度 最不多样性的那个
                # 与已选节点有最大相似度的那个就是最不具有多样性的节点
                max_sim = np.max(sims)
                mmr_score = lambda_mult * relevance - (1 - lambda_mult) * max_sim
                mmr_scores.append((i, mmr_score))
        # 选出MMR分数最高的文档索引
        best_idx, best_score = max(mmr_scores, key=lambda x: x[1])
        # 将选中的文档添加到已选文档中
        selected.append(best_idx)
    return selected


# 定义向量存储的抽象的基类 faiss=chromdb
class VectorStore(ABC):
    # 添加文本到向量存储
    @abstractmethod
    def add_texts(self, texts, metadatas=None):
        pass

    # 最大边际相关性检索
    def max_marginal_relevance_search(self, query, k=3, fetch_k=20):
        pass

    # 从文本批量构建向量存储
    @abstractmethod
    def from_texts(self, texts, embeddings, metadatas=None):
        pass

    @abstractmethod
    def similarity_search(self, query: str, k: int = 4):
        """相似度搜索"""
        pass


class Document:
    def __init__(self, page_content, metadata, embedding_value=None):
        self.page_content = page_content  # 文档内容
        self.metadata = metadata or {}  # 文档的元数据
        self.embedding_value = embedding_value  # 文档对应的嵌入向量


class FAISS(VectorStore):
    def __init__(self, embeddings):
        self.embeddings = embeddings
        # 初始化索引为空
        self.index = None
        # 初始化文档字典,键为文档的ID,值为一个文档Document对象
        self.documents_by_id = {}

    def add_texts(self, texts, metadatas):
        if metadatas is None:
            metadatas = [{}] * len(texts)
        embedding_values = self.embeddings.embed_documents(texts)
        embedding_values = np.array(embedding_values, dtype=np.float32)
        # 如果还没有建立FAISS索引库,则新建它
        if self.index is None:
            dimension = len(embedding_values[0])
            self.index = faiss.IndexFlatL2(dimension)
        # 添加嵌入向量到FAISS索引库中
        self.index.add(embedding_values)
        # 获取已知的文档数量,用于新文档的编号
        start_index = len(self.documents_by_id)
        for i, (text, metadata, embedding_value) in enumerate(
            zip(texts, metadatas, embedding_values)
        ):
            # 构建文档ID
            doc_id = str(start_index + i)
            # 构建文档对象
            doc = Document(
                page_content=text, metadata=metadata, embedding_value=embedding_value
            )
            # 保存到字典里
            self.documents_by_id[doc_id] = doc

    # 类方法,通过文本批量创建FAISS向量数据库的实例
    @classmethod
    def from_texts(cls, texts, embeddings, metadatas=None):
        instance = cls(embeddings=embeddings)
        instance.add_texts(texts, metadatas=metadatas)
        return instance

    def max_marginal_relevance_search(self, query, k, fetch_k, lambda_mult=0.5):
        # 获取查询文本的嵌入向量
        query_embedding = self.embeddings.embed_query(query)
        # 转换为二维numpy数组
        query_vectors = np.array([query_embedding], dtype=np.float32)
        if isinstance(self.index, faiss.Index):
            # 执行索引库中的检索,返回索引及距离
            _, indices = self.index.search(query_vectors, fetch_k)
            # 获取候选文档对应的索引列表
            candidate_indices = indices[0]
        else:
            raise RuntimeError("FAISS 不可用")
        # 如果说候选文档数不足K个,则直接返回这些文档,不需要再走MMR了
        if len(candidate_indices) <= k:
            docs = []
            for idx in candidate_indices:
                doc_id = str(idx)
                if doc_id in self.documents_by_id:
                    docs.append(self.documents_by_id[doc_id])
            return docs
        # 从candidate_indices用MMR算法挑选出K个元素
        # 从字典中提取候选文档的嵌入向量 5个
        candidate_vectors = np.array(
            [self.documents_by_id[str(i)].embedding_value for i in candidate_indices],
            dtype=np.float32,
        )
        # 通过MMR算法获取MMR选出的下标
        selected_indices = mmr_select(
            query_embedding, candidate_vectors, k=k, lambda_mult=lambda_mult
        )
        # 根据下标选出最终的文档对象
        docs = []
        # 遍历选中的索引,这个索引candidate_indices列表中的索引
        for idx in selected_indices:
            # 获取真实的文档索引或者说文档ID
            doc_id = str(candidate_indices[idx])
            # 通过真实的文档ID找到对应的文档对象
            if doc_id in self.documents_by_id:
                docs.append(self.documents_by_id[doc_id])
        return docs

    # 定义相似度检索方法,返回与查询最近的k个文档
    def similarity_search(self, query: str, k: int = 4):
        """
        相似度搜索

        Args:
            query: 查询文本
            k: 返回的文档数量

        Returns:
            List[Document]: 最相似的文档列表
        """
        # 获取查询文本的嵌入向量
        query_embedding = self.embeddings.embed_query(query)
        # 将嵌入向量转换为NumPy二维数组(形状为1行,d维)
        query_vector = np.array([query_embedding], dtype=np.float32)
        # 用FAISS索引执行k近邻检索,得到距离最近的k个索引
        _, indices = self.index.search(query_vector, k)
        # 创建用于存放检索到文档对象的列表
        docs = []
        # 遍历返回的每个文档索引
        for idx in indices[0]:
            # 把数字索引转为字符串形式的文档id
            doc_id = str(idx)
            # 只有字典中存在这个id的文档才加入最终结果
            if doc_id in self.documents_by_id:
                docs.append(self.documents_by_id[doc_id])
        # 返回最终的相似文档列表
        return docs


class Chroma(VectorStore):
    # 设置默认 的集合名称为langchain
    LANCHAIN_DEFAULT_COLLECTION_NAME = "langchain"

    def __init__(
        self,
        collection_name,
        embedding_function,
        persist_directory,
        collection_metadata,
    ):
        self.embedding_function = embedding_function
        self.collection_name = collection_name
        self.collection_metadata = collection_metadata
        if persist_directory is not None:
            # 如果指定了持久化目录,则使用执行化客户端
            self._client = chromadb.PersistentClient(path=persist_directory)
        else:
            # 使用默认的内存客户端
            self._client = chromadb.Client()
        # 使用客户端获取或新建集合
        self._collection = self._client.get_or_create_collection(
            name=self.collection_name, metadata=self.collection_metadata
        )

    def add_texts(self, texts, metadatas, ids=None, **kwargs):
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in texts]
        else:
            ids = [str(_id) if _id is not None else str(uuid.uuid4()) for _id in ids]

        if metadatas is None:
            metadatas = [{}] * len(texts)
        else:
            # 如果元数据数量少于文本数量,则补齐元数据
            length_diff = len(texts) - len(metadatas)
            if length_diff > 0:
                metadatas = metadatas + [{}] * length_diff

        embedding_values = self.embedding_function.embed_documents(texts)
        # 将文本 ID 嵌入 添加到集合中 upsert=update +insert
        self._collection.upsert(
            ids=ids, documents=texts, embeddings=embedding_values, metadatas=metadatas
        )
        return ids

    # 类方法,通过文本批量创建FAISS向量数据库的实例
    @classmethod
    def from_texts(
        cls,
        texts,
        embedding_function,
        metadatas,
        collection_name,
        persist_directory,
        collection_metadata,
    ):
        instance = cls(
            collection_name=collection_name,
            embedding_function=embedding_function,
            persist_directory=persist_directory,
            collection_metadata=collection_metadata,
        )
        instance.add_texts(texts, metadatas=metadatas)
        return instance

    # 定义相似度检索方法,返回与查询最近的k个文档
    def similarity_search(self, query: str, k: int = 4, filter=None, **kwargs):
        docs_and_scores = self.similarity_search_with_score(
            query=query, k=k, filter=filter, **kwargs
        )
        return [doc for doc, _ in docs_and_scores]

    def _results_to_docs_and_scores(self, results):
        docs_and_scores = []
        if not results.get("ids") or not results["ids"][0]:
            return docs_and_scores
        ids = results["ids"][0]
        documents = results["documents"][0]
        metadatas = results["metadatas"][0]
        distances = results["distances"][0]
        for i, doc_id in enumerate(ids):
            if documents[i] is not None:
                doc = Document(
                    page_content=documents[i],
                    metadata=(
                        metadatas[i] if i < len(metadatas) and metadatas[i] else {}
                    ),
                )
                score = distances[i] if i < len(distances) else 0.0
                docs_and_scores.append((doc, score))
        return docs_and_scores

    # 处理chroma检索结果,转为(doc,score)元组列表
    def similarity_search_with_score(
        self, query: str, k: int = 4, filter=None, **kwargs
    ):
        # query_embedding是一个向量,当然一个向量就是一个小数的列表 [0.1,0.2] 一维列表
        query_embedding = self.embedding_function.embed_query(query)
        if query_embedding is not None:
            results = self._collection.query(
                query_embeddings=[query_embedding], n_results=k, where=filter
            )
        else:
            results = self._collection.query(
                query_texts=[query], n_results=k, where=filter
            )
        return self._results_to_docs_and_scores(results)
+   def as_retriever(self, search_type="similarity", search_kwargs=None, **kwargs):
        # """
        # 创建向量存储检索器
        # 
        # 参数:
        #     search_type: 搜索类型,可选值:
        #         - "similarity": 相似度搜索(默认)
        #         - "similarity_score_threshold": 带分数阈值的相似度搜索
        #         - "mmr": 最大边际相关性搜索
        #     search_kwargs: 搜索参数,例如:
        #         - {"k": 4}: 返回4个文档
        #         - {"score_threshold": 0.8}: 相似度阈值(用于 similarity_score_threshold)
        #         - {"fetch_k": 20, "lambda_mult": 0.5}: MMR参数
        #     **kwargs: 其他附加参数
        #
        # 返回:
        #     VectorStoreRetriever: 检索器实例
        # """
        # 如果没有提供search_kwargs则默认使用空字典
+      return VectorStoreRetriever(
+          vectorstore=self,                       # 传入当前vectorstore实例
+          search_type=search_type,                # 搜索类型
+          search_kwargs=search_kwargs or {},      # 搜索参数字典,默认为空字典
+          **kwargs                                # 其他命名参数
+      )    

50.6. 类 #

50.6.1 类说明 #

类名 类型 作用 关键属性/方法
Document 数据类 存储文档内容和元数据 page_content: 文档文本内容
metadata: 元数据字典
HuggingFaceEmbeddings 嵌入模型类 使用 HuggingFace 模型生成文本嵌入向量 model_name: 模型名称
embed_query(): 嵌入单个查询文本
embed_documents(): 批量嵌入文档文本
Chroma 向量存储类 Chroma 向量数据库封装 embedding_function: 嵌入函数
_collection: Chroma 集合
_client: Chroma 客户端
add_texts(): 添加文本到数据库
similarity_search(): 相似度搜索
as_retriever(): 创建检索器
VectorStoreRetriever 检索器类 从向量存储中检索文档 vectorstore: 向量存储实例
search_type: 搜索类型
search_kwargs: 搜索参数
invoke(): 执行检索
_get_relevant_documents(): 核心检索逻辑
BaseRetriever 抽象基类 定义检索器接口规范 invoke(): 对外调用接口
_get_relevant_documents(): 抽象方法(需子类实现)
VectorStore 抽象基类 定义向量存储接口规范 add_texts(): 抽象方法
similarity_search(): 抽象方法

50.6.2 类图 #

classDiagram class Document { +str page_content +Dict[str, Any] metadata +__init__(page_content, metadata) } class BaseRetriever { <<abstract>> +__init__(**kwargs) +invoke(query, **kwargs) List[Document] #_get_relevant_documents(query, **kwargs)* List[Document] } class VectorStoreRetriever { -VectorStore vectorstore -str search_type -dict search_kwargs +__init__(vectorstore, search_type, search_kwargs, **kwargs) +invoke(query, **kwargs) List[Document] -_get_relevant_documents(query, **kwargs) List[Document] } class VectorStore { <<abstract>> +add_texts(texts, metadatas, **kwargs) +similarity_search(query, k, **kwargs)* List[Document] +as_retriever(search_type, search_kwargs)** Retriever } class Chroma { -Embedding embedding_function -Collection _collection -Client _client +__init__(collection_name, embedding_function, persist_directory, collection_metadata) +add_texts(texts, metadatas, ids, **kwargs) List[str] +similarity_search(query, k, filter, **kwargs) List[Document] +similarity_search_with_score(query, k, filter, **kwargs) List[Tuple] +as_retriever(search_type, search_kwargs, **kwargs) VectorStoreRetriever } class Embedding { <<abstract>> +embed_query(text)* List[float] +embed_documents(texts)* List[List[float]] } class HuggingFaceEmbeddings { -str model_name -LangchainHuggingFaceEmbeddings embeddings +__init__(model_name, **kwargs) +embed_query(text) List[float] +embed_documents(texts) List[List[float]] } BaseRetriever <|-- VectorStoreRetriever : 继承实现 VectorStore <|-- Chroma : 继承实现 Embedding <|-- HuggingFaceEmbeddings : 继承实现 VectorStoreRetriever "1" --> "1" VectorStore : 使用向量存储 Chroma "1" --> "1" Embedding : 使用嵌入模型 Chroma "1" --> "1" VectorStoreRetriever : 创建检索器 VectorStoreRetriever "1" *-- "0..*" Document : 返回文档列表 note for BaseRetriever "抽象基类\n定义检索器标准接口\n使用模板方法模式" note for VectorStoreRetriever "向量存储检索器\n支持多种搜索类型\nsimilarity/similarity_score_threshold/mmr" note for Chroma "Chroma向量数据库\n持久化存储向量\n支持相似度搜索" note for HuggingFaceEmbeddings "HuggingFace嵌入模型\n默认使用all-MiniLM-L6-v2模型"

50.6.3 时序图 #

sequenceDiagram participant Main as 53.as_retriever.py participant Embed as HuggingFaceEmbeddings participant Chroma as Chroma participant ChromaDB as ChromaDB(外部) participant VSR as VectorStoreRetriever participant Base as BaseRetriever Note over Main: 阶段1: 初始化嵌入模型 Main->>Embed: HuggingFaceEmbeddings(model_name, model_kwargs) activate Embed Embed-->>Main: embeddings实例 deactivate Embed Note over Main: 阶段2: 初始化向量数据库 Main->>Chroma: Chroma(persist_directory, embedding_function, collection_name, collection_metadata) activate Chroma Chroma->>ChromaDB: PersistentClient(path=persist_directory) activate ChromaDB ChromaDB-->>Chroma: client实例 deactivate ChromaDB Chroma->>ChromaDB: get_or_create_collection(name, metadata) activate ChromaDB ChromaDB-->>Chroma: collection实例 deactivate ChromaDB Chroma-->>Main: chroma_db实例 deactivate Chroma Note over Main: 阶段3: 检查并添加数据 Main->>Chroma: _collection.count() activate Chroma Chroma->>ChromaDB: count() activate ChromaDB ChromaDB-->>Chroma: count结果 deactivate ChromaDB Chroma-->>Main: count结果 deactivate Chroma alt 数据库为空 Main->>Chroma: add_texts(texts, metadatas) activate Chroma Chroma->>Chroma: 生成UUID作为ids Chroma->>Embed: embed_documents(texts) activate Embed Embed-->>Chroma: embedding_values (文档嵌入向量列表) deactivate Embed Chroma->>ChromaDB: upsert(ids, documents, embeddings, metadatas) activate ChromaDB ChromaDB-->>Chroma: 添加成功 deactivate ChromaDB Chroma-->>Main: ids列表 deactivate Chroma end Note over Main: 阶段4: 创建检索器 Main->>Chroma: as_retriever(search_type="similarity", search_kwargs={"k": 2}) activate Chroma Chroma->>VSR: VectorStoreRetriever(vectorstore=self, search_type, search_kwargs) activate VSR VSR->>Base: super().__init__() activate Base Base-->>VSR: 初始化完成 deactivate Base VSR->>VSR: 验证search_type和search_kwargs VSR-->>Chroma: retriever实例 deactivate VSR Chroma-->>Main: retriever实例 deactivate Chroma Note over Main: 阶段5: 执行检索 Main->>VSR: invoke(query="什么是人工智能?") activate VSR VSR->>VSR: _get_relevant_documents(query) activate VSR VSR->>VSR: 合并search_kwargs和kwargs VSR->>Chroma: similarity_search(query, k=2) activate Chroma Chroma->>Chroma: similarity_search_with_score(query, k=2) activate Chroma Chroma->>Embed: embed_query(query) activate Embed Embed-->>Chroma: query_embedding (查询嵌入向量) deactivate Embed Chroma->>ChromaDB: query(query_embeddings=[query_embedding], n_results=2) activate ChromaDB ChromaDB-->>Chroma: results (ids, documents, metadatas, distances) deactivate ChromaDB Chroma->>Chroma: _results_to_docs_and_scores(results) activate Chroma Chroma->>Chroma: 创建Document对象列表 Chroma-->>Chroma: docs_and_scores列表 deactivate Chroma Chroma-->>Chroma: [doc for doc, _ in docs_and_scores] deactivate Chroma Chroma-->>VSR: docs (Document列表) deactivate Chroma VSR-->>VSR: 返回docs deactivate VSR VSR-->>Main: results (Document列表) deactivate VSR Note over Main: 阶段6: 输出结果 Main->>Main: 遍历results并打印

50.6.4 调用过程 #

阶段 1:初始化嵌入模型

embeddings = HuggingFaceEmbeddings(
    model_name=model_path,
    model_kwargs={"device": "cpu"}
)
  • 创建 HuggingFaceEmbeddings 实例
  • 指定本地模型路径和设备(CPU)
  • 用于后续文本向量化

阶段 2:初始化向量数据库

chroma_db = Chroma(
    persist_directory="chroma_database",
    embedding_function=embeddings,
    collection_name="test",
    collection_metadata={"hnsw:space": "cosine"}
)

调用链:

  1. __init__() 方法
    • 保存 embedding_function、collection_name、collection_metadata
    • 创建 ChromaDB 客户端(持久化或内存)
    • 获取或创建集合(collection)

阶段 3:检查并添加数据

if not chroma_db._collection.count():
    chroma_db.add_texts(texts, metadatas)

调用链:

  1. _collection.count() 检查集合是否为空
  2. add_texts() 方法(如果为空)
    • 为每个文本生成 UUID 作为 ID
    • 补齐元数据(如果数量不足)
    • 调用 embedding_function.embed_documents(texts) 批量计算嵌入向量
    • 调用 _collection.upsert() 将文本、嵌入向量和元数据存入数据库

阶段 4:创建检索器

retriever = chroma_db.as_retriever(search_type="similarity", search_kwargs={"k": 2})

调用链:

  1. as_retriever() 方法
    • 创建 VectorStoreRetriever 实例
    • 传入当前 Chroma 实例、搜索类型和搜索参数
  2. VectorStoreRetriever.__init__() 方法
    • 调用 super().__init__() 初始化基类
    • 保存 vectorstore、search_type、search_kwargs
    • 验证搜索类型是否合法

阶段 5:执行检索

results = retriever.invoke("什么是人工智能?")

调用链:

  1. invoke() 方法(继承自 BaseRetriever)
    • 调用 _get_relevant_documents(query)
  2. _get_relevant_documents() 方法(核心检索逻辑)

    根据 search_type 选择检索方式:

    similarity 模式(本例使用):

    docs = self.vectorstore.similarity_search(query, **search_kwargs)

    内部调用链:

    • similarity_search() → similarity_search_with_score()
    • embed_query(query) 计算查询嵌入向量
    • _collection.query() 在 ChromaDB 中查询相似文档
    • _results_to_docs_and_scores() 将结果转换为 Document 对象列表
    • 返回 Document 列表

阶段 6:输出结果

for i, doc in enumerate(results):
    print(f"检索结果{i}:{doc.page_content}")
  • 遍历 results(Document 列表)
  • 打印每个文档的 page_content

50.6.5 搜索类型说明 #

VectorStoreRetriever 支持三种搜索类型:

搜索类型 说明 参数
similarity 相似度搜索(默认) {"k": 4} - 返回前k个最相似的文档
similarity_score_threshold 带分数阈值的相似度搜索 {"k": 4, "score_threshold": 0.8} - 返回相似度大于阈值的文档
mmr 最大边际相关性搜索 {"k": 4, "fetch_k": 20, "lambda_mult": 0.5} - 平衡相关性和多样性

51.TFIDFRetriever #

  • TF-IDF

TFIDFRetriever 是一种基于传统文本信息检索技术(词频-逆文档频率,Term Frequency-Inverse Document Frequency, 简称 TF-IDF)的检索器,无需依赖深度学习模型或嵌入模型,即可对一批文本文档实现高效的关键词相关性检索。其核心优势在于不需要下载任何模型,速度快,适用于基于关键词匹配的小型场景或无 GPU、无法联网时的轻量化检索需求。

工作机制

  • 对所有文档内容进行分词(中英文均支持),统计每个词的词频(TF)和逆文档频率(IDF)。
  • 输入查询时,将查询内容与所有文档分别计算 tf-idf 向量。
  • 通过计算查询和文档的 tf-idf 相似度(通常为余弦相似度)进行排序,返回得分最高的若干文档。

典型应用场合

  • 数据量较小,或本地调试时的快速检索。
  • 没有嵌入模型,或不适合向量检索的关键词场合。
  • 文本为中文、英文或多语言均可,中文采用简易分词算法。
  • 当 VectorStore、Embeddings 暂不可用时的备选方案。

接口一览

  • TFIDFRetriever.from_documents(docs):根据输入的 Document 列表构建检索器,并自动计算 tf-idf 索引。
  • retriever.invoke(query, *, k=4):根据用户输入的查询内容,返回相关性最高的前 k 条 Document(默认通常返回全部排序后的文档)。
  • 支持自定义分词、去停用词、返回数量等参数(参见 smartchain/retrievers/TFIDFRetriever 代码)。

注意事项

  • 适合结构较为规范、文本内容不长的场景。
  • 对于复杂语义、同义词、语序变化的检索,TFIDF 通常不如深度语义模型,但响应速度极快。
  • 中文文本依赖内置分词算法,无需额外安装分词包。

51.1. 51.TFIDFRetriever.py #

51.TFIDFRetriever.py

#from langchain_community.retrievers import TFIDFRetriever
#from langchain_core.documents import Document

# 导入自定义Document类
from smartchain.documents import Document
# 导入自定义的TFIDFRetriever检索器
from smartchain.retrievers import TFIDFRetriever

# 构造一个包含4条文本的示例文档列表
docs = [
    Document(page_content="深度学习通过神经网络取得突破。"),
    Document(page_content="机器人学结合机械工程与人工智能。"),
    Document(page_content="美食点评:这家餐厅的川菜很正宗。"),
    Document(page_content="人工智能正推动机器人自主学习与创新。"),
]

# 用文档列表创建一个TFIDF检索器实例
retriever = TFIDFRetriever.from_documents(docs)

# 设置检索的查询语句
query = "机器人与人工智能"
# 调用检索器进行查询,返回与query最相关的2条文档
results = retriever.invoke(query,k=2)

# 打印查询内容
print(f"查询:{query}")
# 打印返回的文档数量
print(f"返回 {len(results)} 条:")
# 遍历输出每条检索结果的内容
for i, doc in enumerate(results, 1):
    print(f"{i}. {doc.page_content}")

51.2. tfidf.py #

smartchain/retrievers/tfidf.py

# 导入基础检索器基类
from smartchain.retrievers.base import BaseRetriever
# uv add scikit-learn
# 导入TF-IDF向量化器
from sklearn.feature_extraction.text import TfidfVectorizer
# 导入余弦相似度计算函数
from sklearn.metrics.pairwise import cosine_similarity
# 导入numpy库用于数组处理
import numpy as np
# 导入jieba用于中文分词
import jieba
import re

# 定义中文分词函数,作为tokenizer传递给TfidfVectorizer
def _chinese_tokenizer(text):
    # 使用jieba的搜索引擎模式进行分词,会对长词再次切分,提高召回率
    tokens = list(jieba.cut_for_search(text))
    # 过滤掉标点符号、空白字符和单字符词,只保留有效词
    tokens = [t for t in tokens if re.match(r'^[\u4e00-\u9fa5a-zA-Z0-9]+$', t) and len(t) > 1]
    # 用空格拼接成字符串返回
    return " ".join(tokens)

# 定义TFIDFRetriever检索器,继承自BaseRetriever
class TFIDFRetriever(BaseRetriever):
    # TFIDF 检索器说明文档字符串
    """
    TFIDF 检索器

    使用 TFIDF (Term Frequency-Inverse Document Frequency) 算法进行文档检索。
    通过计算查询文本与文档集合的 TFIDF 向量,使用余弦相似度找到最相关的文档。
    """

    # 初始化方法
    def __init__(self, vectorizer=None, documents=None, **kwargs):
        """
        初始化 TFIDF 检索器

        参数:
            vectorizer: TfidfVectorizer 实例,如果为 None 则创建默认实例
            documents: Document 列表,用于训练向量化器
            **kwargs: 传递给 TfidfVectorizer 的其他参数
        """
        # 调用父类初始化方法,传递可能的关键字参数
        super().__init__(**kwargs)
        # 默认采用中文tokenizer,如果未手动指定,则设置为_chinese_tokenizer
        if 'tokenizer' not in kwargs:
                kwargs['tokenizer'] = _chinese_tokenizer
            # 初始化TfidfVectorizer
        self.vectorizer = TfidfVectorizer(**kwargs)

        # 保存文档列表,如果为None则赋值为空列表
        self.documents = documents or []

        # 如果文档列表不为空,则训练向量化器
        if self.documents:
            self._fit_vectorizer()

    # 训练向量化器并生成所有文档TF-IDF向量
    def _fit_vectorizer(self):
        """训练 TFIDF 向量化器并计算所有文档的向量"""
        # 提取所有文档的内容,组成文本列表
        texts = [doc.page_content for doc in self.documents]
        # 用所有文档内容训练TF-IDF向量化器,像学习字典,先看所有文档,建立词汇表和规则
        self.vectorizer.fit(texts)
        # 将所有文档内容转换为TF-IDF向量矩阵,像查字典,根据词汇表和规则,将文档转为向量
        self.document_vectors = self.vectorizer.transform(texts)

    # 通过类方法根据文档列表创建检索器对象
    @classmethod
    def from_documents(cls, documents, **kwargs):
        """
        从文档列表创建 TFIDFRetriever 实例

        参数:
            documents: Document 列表
            **kwargs: 其他参数,可传递给 TfidfVectorizer

        返回:
            TFIDFRetriever 实例
        """
        # 创建检索器实例,传递文档列表及其他参数
        instance = cls(documents=documents, **kwargs)
        return instance

    # 主体检索方法,返回与查询相关的文档
    def _get_relevant_documents(self, query, k=4, **kwargs):
        """
        获取相关文档

        参数:
            query: 查询字符串
            k: 返回的文档数量,默认 4
            **kwargs: 其他参数

        返回:
            Document 列表,按相似度从高到低排序
        """
        # 如果没有文档,则直接返回空列表
        if not self.documents:
            return []

        # 对查询进行分词,用于检查是否有匹配的词
        query_tokens = set(_chinese_tokenizer(query).split())

        # 将查询转为向量
        query_vector = self.vectorizer.transform([query])

        # 计算查询向量和所有文档向量的余弦相似度
        # 取相似度矩阵的第一行(唯一一行),得到所有文档与查询的相似度分数
        similarities = cosine_similarity(query_vector, self.document_vectors)[0]

        # 检查每个文档是否包含至少一个查询词
        doc_has_query_word = []
        for doc in self.documents:
            doc_tokens = set(_chinese_tokenizer(doc.page_content).split())
            # 文档必须包含至少一个查询词,且相似度大于0
            doc_has_query_word.append(len(query_tokens & doc_tokens) > 0)

        # 获取相似度排序后的索引(降序)
        sorted_indices = np.argsort(similarities)[::-1]

        # 只保留相似度大于0且包含至少一个查询词的文档
        valid_indices = [i for i in sorted_indices if similarities[i] > 0 and doc_has_query_word[i]]

        # 取前k个
        top_indices = valid_indices[:k]

        # 返回对应的Document对象列表
        return [self.documents[i] for i in top_indices]

    # 对外调用的检索接口
    def invoke(self, query, **kwargs):
        """
        调用检索器获取相关文档

        参数:
            query: 查询字符串
            **kwargs: 其他参数,如 k(返回文档数量)

        返回:
            Document 列表
        """
        # 调用内部实际检索函数
        return self._get_relevant_documents(query, **kwargs)

51.3. init.py #

smartchain/retrievers/init.py

from .vector_store import VectorStoreRetriever
from .base import BaseRetriever
+from .tfidf import TFIDFRetriever
__all__ = [
    "VectorStoreRetriever",
    "BaseRetriever",
+   "TFIDFRetriever"
]

51.4. 类 #

51.4.1 类说明 #

类名 类型 作用 关键属性/方法
Document 数据类 存储文档内容和元数据 page_content: 文本内容
metadata: 元数据字典
BaseRetriever 抽象基类 定义检索器接口规范 invoke(): 对外调用接口
_get_relevant_documents(): 抽象方法(需子类实现)
TFIDFRetriever 具体实现类 实现基于 TF-IDF 的文档检索 vectorizer: TfidfVectorizer 向量化器
documents: 文档列表
document_vectors: 文档向量矩阵
from_documents(): 类方法创建实例
invoke(): 执行检索
_get_relevant_documents(): 核心检索逻辑
_fit_vectorizer(): 训练向量化器
_chinese_tokenizer 函数 中文分词预处理 使用 jieba 搜索引擎模式分词
过滤标点和单字符词

51.4.2 类图 #

classDiagram class Document { +str page_content +Dict[str, Any] metadata +__init__(page_content, metadata) } class BaseRetriever { <<abstract>> +__init__(**kwargs) +invoke(query, **kwargs) List[Document] #_get_relevant_documents(query, **kwargs)* List[Document] } class TFIDFRetriever { -TfidfVectorizer vectorizer -List[Document] documents -ndarray document_vectors +__init__(vectorizer, documents, **kwargs) +from_documents(cls, documents, **kwargs)$ TFIDFRetriever +invoke(query, **kwargs) List[Document] -_get_relevant_documents(query, k=4, **kwargs) List[Document] -_fit_vectorizer() } class TfidfVectorizer { <<external>> +fit(texts) +transform(texts) } class cosine_similarity { <<function>> +cosine_similarity(query_vector, document_vectors) } class _chinese_tokenizer { <<function>> +_chinese_tokenizer(text) str } BaseRetriever <|-- TFIDFRetriever : 继承实现 TFIDFRetriever "1" *-- "0..*" Document : 包含文档列表 TFIDFRetriever "1" --> "1" TfidfVectorizer : 使用向量化器 TFIDFRetriever ..> cosine_similarity : 调用 TFIDFRetriever ..> _chinese_tokenizer : 使用分词函数 note for BaseRetriever "抽象基类\n定义检索器标准接口\n使用模板方法模式" note for TFIDFRetriever "具体实现类\n实现TF-IDF检索算法\n使用余弦相似度计算相关性\n过滤不包含查询词的文档" note for Document "数据类\n存储文档内容和元数据\n作为检索的输入和输出" note for _chinese_tokenizer "中文分词函数\n使用jieba搜索引擎模式\n过滤标点和单字符词"

51.4.3 时序图 #

sequenceDiagram participant Main as 51.TFIDFRetriever.py participant Doc as Document participant TFIDF as TFIDFRetriever participant Base as BaseRetriever participant Tokenizer as _chinese_tokenizer participant Vec as TfidfVectorizer participant Cosine as cosine_similarity Note over Main: 阶段1: 创建文档对象 Main->>Doc: Document(page_content="...") Doc-->>Main: doc1, doc2, doc3, doc4 Note over Main: 阶段2: 创建检索器实例 Main->>TFIDF: from_documents(docs) activate TFIDF TFIDF->>TFIDF: __init__(documents=docs) TFIDF->>Base: super().__init__() activate Base Base-->>TFIDF: 初始化完成 deactivate Base TFIDF->>TFIDF: 设置tokenizer=_chinese_tokenizer TFIDF->>Vec: TfidfVectorizer(tokenizer=_chinese_tokenizer) Vec-->>TFIDF: vectorizer实例 TFIDF->>TFIDF: _fit_vectorizer() activate TFIDF TFIDF->>TFIDF: 提取doc.page_content TFIDF->>Vec: fit(texts) activate Vec Vec->>Tokenizer: _chinese_tokenizer(text) for each text activate Tokenizer Tokenizer-->>Vec: 分词后的文本 deactivate Tokenizer Vec-->>TFIDF: 训练完成(建立词汇表) deactivate Vec TFIDF->>Vec: transform(texts) activate Vec Vec->>Tokenizer: _chinese_tokenizer(text) for each text activate Tokenizer Tokenizer-->>Vec: 分词后的文本 deactivate Tokenizer Vec-->>TFIDF: document_vectors矩阵 deactivate Vec deactivate TFIDF TFIDF-->>Main: retriever实例 deactivate TFIDF Note over Main: 阶段3: 执行检索查询 Main->>TFIDF: invoke(query="机器人与人工智能", k=2) activate TFIDF TFIDF->>TFIDF: _get_relevant_documents(query, k=2) activate TFIDF TFIDF->>Tokenizer: _chinese_tokenizer(query) activate Tokenizer Tokenizer-->>TFIDF: query_tokens (分词集合) deactivate Tokenizer TFIDF->>Vec: transform([query]) activate Vec Vec->>Tokenizer: _chinese_tokenizer(query) activate Tokenizer Tokenizer-->>Vec: 分词后的查询文本 deactivate Tokenizer Vec-->>TFIDF: query_vector deactivate Vec TFIDF->>Cosine: cosine_similarity(query_vector, document_vectors) activate Cosine Cosine-->>TFIDF: similarities数组 deactivate Cosine TFIDF->>TFIDF: 检查每个文档是否包含查询词 loop 对每个文档 TFIDF->>Tokenizer: _chinese_tokenizer(doc.page_content) activate Tokenizer Tokenizer-->>TFIDF: doc_tokens deactivate Tokenizer TFIDF->>TFIDF: 检查query_tokens & doc_tokens是否非空 end TFIDF->>TFIDF: np.argsort(similarities)[::-1] TFIDF->>TFIDF: 筛选相似度>0且包含查询词的文档 TFIDF->>TFIDF: 取前k个 TFIDF-->>TFIDF: [self.documents[i] for i in top_indices] deactivate TFIDF TFIDF-->>Main: results (Document列表) deactivate TFIDF Note over Main: 阶段4: 输出结果 Main->>Main: 遍历results并打印

51.4.4 调用过程 #

阶段 1:文档对象创建

docs = [
    Document(page_content="深度学习通过神经网络取得突破。"),
    Document(page_content="机器人学结合机械工程与人工智能。"),
    Document(page_content="美食点评:这家餐厅的川菜很正宗。"),
    Document(page_content="人工智能正推动机器人自主学习与创新。"),
]
  • 创建 4 个 Document 实例
  • 每个对象包含 page_content 和默认的 metadata

阶段 2:检索器初始化

retriever = TFIDFRetriever.from_documents(docs)

调用链:

  1. from_documents() 类方法
    • 调用 cls(documents=docs, **kwargs),即 __init__()
  2. __init__() 方法
    • 调用 super().__init__() 初始化基类
    • 设置默认 tokenizer 为 _chinese_tokenizer
    • 创建 TfidfVectorizer 实例
    • 保存 documents 列表
    • 若 documents 非空,调用 _fit_vectorizer()
  3. _fit_vectorizer() 方法
    • 提取所有文档的 page_content
    • 调用 vectorizer.fit(texts) 训练(构建词汇表和 IDF)
    • 调用 vectorizer.transform(texts) 生成文档向量矩阵

阶段 3:执行检索

results = retriever.invoke(query, k=2)

调用链:

  1. invoke() 方法(继承自 BaseRetriever)
    • 调用 _get_relevant_documents(query, k=2)
  2. _get_relevant_documents() 方法(核心检索逻辑)

    步骤 1:预处理

    • 检查文档列表是否为空
    • 对查询进行分词:query_tokens = set(_chinese_tokenizer(query).split())

    步骤 2:向量化查询

    • 调用 vectorizer.transform([query]) 将查询转换为 TF-IDF 向量

    步骤 3:计算余弦相似度

    • 调用 cosine_similarity(query_vector, self.document_vectors) 计算相似度数组

    步骤 4:检查查询词匹配

    for doc in self.documents:
        doc_tokens = set(_chinese_tokenizer(doc.page_content).split())
        doc_has_query_word.append(len(query_tokens & doc_tokens) > 0)

    确保文档至少包含一个查询词

    步骤 5:排序和筛选

    • 按相似度降序排序:sorted_indices = np.argsort(similarities)[::-1]
    • 筛选有效文档:valid_indices = [i for i in sorted_indices if similarities[i] > 0 and doc_has_query_word[i]]
    • 取前 k 个:top_indices = valid_indices[:k]
    • 返回对应的文档对象列表:return [self.documents[i] for i in top_indices]

阶段 4:结果输出

print(f"查询:{query}")
print(f"返回 {len(results)} 条:")
for i, doc in enumerate(results, 1):
    print(f"{i}. {doc.page_content}")
  • 遍历 results(Document 列表)
  • 打印每个文档的 page_content

51.5 TF-IDF 算法 #

TF-IDF(Term Frequency-Inverse Document Frequency)用于衡量词在文档中的重要性。

TF-IDF 公式

对于词 $t$ 和文档 $d$:

$$\text{TF-IDF}(t, d) = \text{TF}(t, d) \times \text{IDF}(t)$$

其中:

  • $\text{TF}(t, d) = \frac{\text{词t在文档d中的出现次数}}{\text{文档d的总词数}}$(词频)
  • $\text{IDF}(t) = \log\frac{N}{df(t)}$(逆文档频率)
    • $N$:文档总数
    • $df(t)$:包含词 $t$ 的文档数量

51.6. 余弦相似度 #

查询向量 $\vec{q}$ 和文档向量 $\vec{d}$ 的余弦相似度:

$$\cos(\theta) = \frac{\vec{q} \cdot \vec{d}}{|\vec{q}| \times |\vec{d}|} = \frac{\sum_{i} q_i \times d_i}{\sqrt{\sum_{i} q_i^2} \times \sqrt{\sum_{i} d_i^2}}$$

相似度范围:$[-1, 1]$,值越大越相似。

52.BM25Retriever #

  • rank_bm25

BM25 检索器(BM25Retriever)是一种经典的信息检索算法,在不依赖向量嵌入的情况下实现关键词检索。其主要适用于文本检索、FAQ 系统、知识库初步召回等场景。

主要功能

  • 高效关键词检索:BM25 根据 query 与文档的关键词匹配度为文档打分,返回最相关的文本。
  • 免向量化依赖:无需复杂的 embedding,仅依赖分词和排名,部署简单,召回快。
  • 可定制预处理:支持自定义分词、预处理函数,适配不同的语言或业务数据。

52.1. 52.BM25Retriever.py #

52.BM25Retriever.py

# 导入 Document 类(文档数据结构)
from smartchain.documents import Document
# 导入 BM25Retriever 检索器类
from smartchain.retrievers import BM25Retriever

# 构造示例文档列表,每个 Document 对象代表一段文本
docs = [
    # 文档1:关于深度学习的描述
    Document(page_content="深度学习通过神经网络取得突破。"),
    # 文档2:机器人学与人工智能关系
    Document(page_content="机器人学结合机械工程与人工智能。"),
    # 文档3:美食点评内容
    Document(page_content="美食点评:这家餐厅的川菜很正宗。"),
    # 文档4:人工智能正推动机器人自主学习与创新
    Document(page_content="人工智能正推动机器人自主学习与创新。"),
]

# 使用 from_documents 类方法创建 BM25 检索器实例
retriever = BM25Retriever.from_documents(docs)

# 设置检索时的查询字符串
query = "机器人与人工智能"
# 调用 invoke 方法执行检索,k=2 表示返回前2个相关文档
results = retriever.invoke(query, k=2)

# 打印本次检索的查询内容
print(f"查询:{query}")
# 打印检索结果的文档数量
print(f"返回 {len(results)} 条:")
# 遍历检索出的文档,逐条输出其内容
for i, doc in enumerate(results, 1):
    print(f"{i}. {doc.page_content}")

52.2. bm25.py #

smartchain/retrievers/bm25.py

# 导入BM25Okapi用于BM25算法实现
from rank_bm25 import BM25Okapi
# 导入 BaseRetriever 基类
from .base import BaseRetriever
# 导入自定义 Document 类
from ..documents import Document
# 导入jieba用于中文分词
import jieba
import re

# 定义中文分词函数,返回分词列表
def _chinese_tokenizer(text: str):
    # 使用jieba的搜索引擎模式进行分词
    tokens = list(jieba.cut_for_search(text))
    # 过滤掉标点符号、空白字符和单字符词,只保留有效词
    tokens = [t for t in tokens if re.match(r'^[\u4e00-\u9fa5a-zA-Z0-9]+$', t) and len(t) > 1]
    # 返回分词结果列表
    return tokens

# 定义默认的分词预处理函数(按空格切分文本)
def default_preprocessing_func(text: str):
    # 按空格将文本切分为词组成的列表
    return text.split()

# 定义BM25检索器类,继承自BaseRetriever
class BM25Retriever(BaseRetriever):
    # 初始化BM25Retriever类
    def __init__(self, vectorizer=None, docs=None, k=4, preprocess_func=None):
        # 调用父类的初始化方法
        super().__init__()
        # 保存BM25Okapi向量化器对象
        self.vectorizer = vectorizer
        # 保存文档列表,如果未指定则为空列表
        self.docs = docs or []
        # 保存每次检索返回的文档条数
        self.k = k
        # 保存分词预处理函数,默认使用中文分词器
        self.preprocess_func = preprocess_func or _chinese_tokenizer

    # 类方法:通过文本批量创建BM25检索器实例
    @classmethod
    def from_texts(
        cls,
        texts,
        metadatas=None,
        ids=None,
        bm25_params=None,
        preprocess_func=None,
        k=4,
        **kwargs
    ):
        # 如果未指定分词预处理函数,默认使用中文分词器
        preprocess_func = preprocess_func or _chinese_tokenizer
        # 对每个文本应用分词处理,得到分词后的文本列表
        texts_processed = [preprocess_func(t) for t in texts]
        # 如果未指定BM25参数,设置为空字典
        bm25_params = bm25_params or {}
        # 用分词后的文本初始化BM25Okapi向量化器
        vectorizer = BM25Okapi(texts_processed, **bm25_params)
        # 如果未指定元数据,自动生成空字典生成器
        metadatas = metadatas or ({} for _ in texts)
        # 判断是否有ids字段,构建Document对象列表(id目前未使用)
        if ids:
            docs = [
                Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)
            ]
        else:
            docs = [
                Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)
            ]
        # 创建并返回BM25Retriever实例
        return cls(
            vectorizer=vectorizer,
            docs=docs,
            preprocess_func=preprocess_func,
            k=k,
            **kwargs
        )

    # 类方法:通过Document对象列表创建BM25检索器实例
    @classmethod
    def from_documents(
        cls,
        documents,
        bm25_params=None,
        preprocess_func=None,
        k=4,
        **kwargs
    ):
        # 从每个Document对象提取文本内容组成列表
        texts = [doc.page_content for doc in documents]
        # 从每个Document对象提取元数据,没有则使用空字典
        metadatas = [doc.metadata if hasattr(doc, 'metadata') else {} for doc in documents]
        # 调用from_texts构建检索器实例
        return cls.from_texts(
            texts=texts,
            metadatas=metadatas,
            bm25_params=bm25_params,
            preprocess_func=preprocess_func,
            k=k,
            **kwargs
        )

    # 内部方法:检索最相关的文档
    def _get_relevant_documents(self, query, **kwargs):
        # 优先使用参数k,否则用默认的k值
        k = kwargs.get("k", self.k)
        # 对查询字符串进行分词预处理
        processed_query = self.preprocess_func(query)
        # 如果文档列表为空,则直接返回空列表
        if not self.docs:
            return []
        # 对每个文档内容进行分词,得到文档分词列表
        docs_tokenized = [self.preprocess_func(doc.page_content) for doc in self.docs]
        # 文档总数
        N = len(docs_tokenized)
        # 计算文档平均长度
        avgdl = sum(len(doc_tokens) for doc_tokens in docs_tokenized) / N if N > 0 else 0
        # 设置BM25参数k1和b
        k1, b = 1.5, 0.75
        # 计算查询词的IDF(逆文档频率)
        query_word_idf = {}
        for word in processed_query:
            # 统计包含该词的文档数
            df = sum(1 for doc_tokens in docs_tokenized if word in doc_tokens)
            if df > 0:
                # 按BM25公式计算IDF
                query_word_idf[word] = (N - df + 0.5) / (df + 0.5)
            else:
                # 若不存在该词,IDF记为0
                query_word_idf[word] = 0
        # 初始化文档分数列表
        doc_scores = []
        # 逐个文档计算得分
        for i, doc_tokens in enumerate(docs_tokenized):
            score = 0
            # 遍历每个查询词
            for word in processed_query:
                # 仅对文档中存在的词计算分数
                if word in doc_tokens:
                    # 词频统计
                    tf = doc_tokens.count(word)
                    # 查询词IDF
                    idf = query_word_idf[word]
                    # BM25核心公式
                    score += idf * (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * len(doc_tokens) / avgdl))
            # 保存本条文档的得分及索引
            doc_scores.append((score, i))
        # 按照分数降序排列,取前k条
        doc_scores.sort(reverse=True, key=lambda x: x[0])
        # 选择分数大于0的前k个文档索引
        top_indices = [idx for score, idx in doc_scores[:k] if score > 0]
        # 返回对应的文档对象列表
        return [self.docs[i] for i in top_indices]

52.3. init.py #

smartchain/retrievers/init.py

from .vector_store import VectorStoreRetriever
from .base import BaseRetriever
from .tfidf import TFIDFRetriever
+from .bm25 import BM25Retriever
__all__ = [
    "VectorStoreRetriever",
    "BaseRetriever",
+   "TFIDFRetriever",
+   "BM25Retriever"
]

52.4.类 #

52.4.1 类说明 #

类名 类型 作用 关键属性/方法
Document 数据类 存储文档内容和元数据 page_content: 文本内容
metadata: 元数据字典
BaseRetriever 抽象基类 定义检索器接口规范 invoke(): 对外调用接口
_get_relevant_documents(): 抽象方法
BM25Retriever 具体实现类 实现基于 BM25 算法的文档检索 vectorizer: BM25Okapi 向量化器
docs: 文档列表
preprocess_func: 分词函数
from_documents(): 类方法
invoke(): 检索
_get_relevant_documents(): BM25分数计算
_chinese_tokenizer 函数 中文分词预处理 使用 jieba 搜索引擎模式分词
过滤标点和单字符词

52.4.2 类图 #

classDiagram class Document { +str page_content +Dict[str, Any] metadata +__init__(page_content, metadata) } class BaseRetriever { <<abstract>> +__init__(**kwargs) +invoke(query, **kwargs) List[Document] #_get_relevant_documents(query, **kwargs)* List[Document] } class BM25Retriever { -BM25Okapi vectorizer -List[Document] docs -int k -Callable preprocess_func +__init__(vectorizer, docs, k, preprocess_func) +from_documents(cls, documents, **kwargs)$ BM25Retriever +from_texts(cls, texts, **kwargs)$ BM25Retriever +invoke(query, **kwargs) List[Document] -_get_relevant_documents(query, **kwargs) List[Document] } class BM25Okapi { <<external>> +__init__(corpus, k1, b) +get_scores(query) +get_top_n(query, corpus, n) } class _chinese_tokenizer { <<function>> +_chinese_tokenizer(text) List[str] } BaseRetriever <|-- BM25Retriever : 继承实现 BM25Retriever "1" *-- "0..*" Document : 包含文档列表 BM25Retriever "1" --> "1" BM25Okapi : 使用向量化器 BM25Retriever ..> _chinese_tokenizer : 使用分词函数 note for BaseRetriever "抽象基类\n定义检索器标准接口\n使用模板方法模式" note for BM25Retriever "具体实现类\n实现BM25检索算法\n手动计算BM25分数并排序" note for Document "数据类\n存储文档内容和元数据\n作为检索的输入和输出" note for _chinese_tokenizer "中文分词函数\n使用jieba搜索引擎模式\n过滤标点和单字符词"

52.4.3 时序图 #

sequenceDiagram participant Main as 52.BM25Retriever.py participant Doc as Document participant BM25 as BM25Retriever participant Base as BaseRetriever participant Tokenizer as _chinese_tokenizer participant BM25Okapi as BM25Okapi Note over Main: 阶段1: 创建文档对象 Main->>Doc: Document(page_content="...") Doc-->>Main: doc1, doc2, doc3, doc4 Note over Main: 阶段2: 创建检索器实例 Main->>BM25: from_documents(docs) activate BM25 BM25->>BM25: from_texts(texts, metadatas) activate BM25 BM25->>Tokenizer: _chinese_tokenizer(text) for each text activate Tokenizer Tokenizer->>Tokenizer: jieba.cut_for_search(text) Tokenizer->>Tokenizer: 过滤标点和单字符词 Tokenizer-->>BM25: texts_processed (分词列表) deactivate Tokenizer BM25->>BM25Okapi: BM25Okapi(texts_processed) activate BM25Okapi BM25Okapi-->>BM25: vectorizer实例 deactivate BM25Okapi BM25->>Doc: Document(page_content, metadata) Doc-->>BM25: docs列表 BM25->>BM25: __init__(vectorizer, docs, k, preprocess_func) BM25->>Base: super().__init__() activate Base Base-->>BM25: 初始化完成 deactivate Base BM25-->>BM25: 返回实例 deactivate BM25 BM25-->>Main: retriever实例 deactivate BM25 Note over Main: 阶段3: 执行检索查询 Main->>BM25: invoke(query="机器人与人工智能", k=2) activate BM25 BM25->>BM25: _get_relevant_documents(query, k=2) activate BM25 BM25->>Tokenizer: preprocess_func(query) activate Tokenizer Tokenizer-->>BM25: processed_query (分词列表) deactivate Tokenizer BM25->>BM25: 检查docs是否为空 BM25->>Tokenizer: preprocess_func(doc.page_content) for each doc activate Tokenizer Tokenizer-->>BM25: docs_tokenized (文档分词列表) deactivate Tokenizer BM25->>BM25: 计算N(文档总数)和avgdl(平均长度) BM25->>BM25: 计算查询词的IDF loop 对每个查询词 BM25->>BM25: 统计包含该词的文档数(df) BM25->>BM25: IDF = (N - df + 0.5) / (df + 0.5) end BM25->>BM25: 计算每个文档的BM25分数 loop 对每个文档 loop 对每个查询词 BM25->>BM25: 计算词频(tf) BM25->>BM25: BM25公式计算分数 Note right of BM25: score += idf * (tf * (k1+1)) /<br/>(tf + k1 * (1-b + b*len/avgdl)) end BM25->>BM25: 保存(score, index) end BM25->>BM25: 按分数降序排序 BM25->>BM25: 取前k个分数>0的文档 BM25-->>BM25: [self.docs[i] for i in top_indices] deactivate BM25 BM25-->>Main: results (Document列表) deactivate BM25 Note over Main: 阶段4: 输出结果 Main->>Main: 遍历results并打印

52.4.4 调用过程 #

阶段 1:文档对象创建

docs = [
    Document(page_content="深度学习通过神经网络取得突破。"),
    Document(page_content="机器人学结合机械工程与人工智能。"),
    Document(page_content="美食点评:这家餐厅的川菜很正宗。"),
    Document(page_content="人工智能技术正推动机器人自主学习与创新。"),
]
  • 创建 4 个 Document 实例
  • 每个对象包含 page_content 和默认的 metadata

阶段 2:检索器初始化

retriever = BM25Retriever.from_documents(docs)

调用链:

  1. from_documents() 类方法
    • 提取所有文档的 page_content 和 metadata
    • 调用 from_texts()
  2. from_texts() 类方法
    • 使用 _chinese_tokenizer 对每个文本分词
    • 创建 BM25Okapi 向量化器(用于后续可能的功能)
    • 创建 Document 对象列表
    • 调用 __init__() 创建实例
  3. __init__() 方法
    • 调用 super().__init__() 初始化基类
    • 保存 vectorizer、docs、k、preprocess_func

阶段 3:执行检索

results = retriever.invoke(query, k=2)

调用链:

  1. invoke() 方法(继承自 BaseRetriever)
    • 调用 _get_relevant_documents(query, k=2)
  2. _get_relevant_documents() 方法(核心检索逻辑)

    步骤 1:预处理

    • 获取 k 值(优先使用参数,否则使用默认值)
    • 对查询进行分词:processed_query = self.preprocess_func(query)
    • 检查文档列表是否为空

    步骤 2:文档分词

    • 对每个文档内容进行分词:docs_tokenized = [self.preprocess_func(doc.page_content) for doc in self.docs]

    步骤 3:计算统计量

    • N = len(docs_tokenized)(文档总数)
    • avgdl = sum(len(doc_tokens) for doc_tokens in docs_tokenized) / N(平均文档长度)
    • 设置 BM25 参数:k1=1.5, b=0.75

    步骤 4:计算查询词的 IDF(逆文档频率)

    for word in processed_query:
        df = sum(1 for doc_tokens in docs_tokenized if word in doc_tokens)
        if df > 0:
            query_word_idf[word] = (N - df + 0.5) / (df + 0.5)

    步骤 5:计算每个文档的 BM25 分数

    for i, doc_tokens in enumerate(docs_tokenized):
        score = 0
        for word in processed_query:
            if word in doc_tokens:
                tf = doc_tokens.count(word)  # 词频
                idf = query_word_idf[word]   # 逆文档频率
                # BM25核心公式
                score += idf * (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * len(doc_tokens) / avgdl))
        doc_scores.append((score, i))

    步骤 6:排序和筛选

    • 按分数降序排序:doc_scores.sort(reverse=True, key=lambda x: x[0])
    • 取前 k 个分数 > 0 的文档:top_indices = [idx for score, idx in doc_scores[:k] if score > 0]
    • 返回对应的文档对象列表:return [self.docs[i] for i in top_indices]

阶段 4:结果输出

print(f"查询:{query}")
print(f"返回 {len(results)} 条:")
for i, doc in enumerate(results, 1):
    print(f"{i}. {doc.page_content}")
  • 遍历 results(Document 列表)
  • 打印每个文档的 page_content

52.4.5 BM25 算法 #

BM25(Best Matching 25)是用于信息检索的排序函数。

52.4.5.1 BM25 公式 #

对于查询词 $q$ 和文档 $d$,BM25 分数计算为:

$$\text{BM25}(q, d) = \sum_{i=1}^{n} \text{IDF}(q_i) \cdot \frac{f(q_i, d) \cdot (k_1 + 1)}{f(q_i, d) + k_1 \cdot \left(1 - b + b \cdot \frac{|d|}{\text{avgdl}}\right)}$$

其中:

  • $q_i$:查询中的第 $i$ 个词
  • $f(q_i, d)$:词 $q_i$ 在文档 $d$ 中的词频(TF)
  • $\text{IDF}(q_i)$:词 $q_i$ 的逆文档频率
  • $|d|$:文档 $d$ 的长度(词数)
  • $\text{avgdl}$:所有文档的平均长度
  • $k_1$:词频饱和度参数(通常为 1.5)
  • $b$:长度归一化参数(通常为 0.75)
52.4.5.2 IDF 计算 #

$$\text{IDF}(q_i) = \frac{N - df(q_i) + 0.5}{df(q_i) + 0.5}$$

其中:

  • $N$:文档总数
  • $df(q_i)$:包含词 $q_i$ 的文档数量

53.VectorSimilarityRetriever #

向量相似度检索器(VectorSimilarityRetriever)是一种基于向量化表示和余弦相似度的经典检索工具。它和之前介绍的向量数据库型检索器不同,VectorSimilarityRetriever 常用于本地小规模场景下的内存型向量匹配,特点如下:

  • 无需外部依赖数据库:直接将文档转为嵌入向量后,全部保存在内存中,适用于文档规模不大、运行在轻量系统或教学场景。
  • 原理朴素直观:每当检索时,对查询文本进行向量化,与所有已存文档向量做 1:N 的余弦相似度比较,返回最相似的前 k 个文档。
  • 用法简明:只需提供文档列表,自动完成嵌入、存储和检索,不依赖外部检索服务、数据库或分布式系统。
  • 局限性:由于所有向量均保存在内存,随着文档增多性能和内存消耗会迅速变大,因此主要用于实验、演示、测试或小样本系统。

在 smartchain.retrievers.vector.VectorSimilarityRetriever 中,典型接口为:

  • from_documents(docs):类方法,初始化时传入 Document 列表,自动完成向量化和存储。
  • invoke(query, k=None):执行检索,返回与 query 语义最相似的前 k 个文档,k 可选,默认值如4。

底层实现流程如下:

  1. 把每个文档文本通过预定义/自带的嵌入器转为向量,保存在内存数组中。
  2. 用户输入查询时,同样用嵌入器转向量。
  3. 通过余弦相似度计算 query 向量与所有文档向量的相似度,筛选出最大 top_k 的文档返回。
  4. 可针对 zero-shot、few-shot 场景使用,也可作为向量数据库方案的 lightweight baseline 对比。

53.1. 53.SimilarityRetriever.py #

53.SimilarityRetriever.py

#from langchain_classic.retrievers.bm25 import BM25Retriever
#from langchain_core.documents import Document
# 导入 Document 类(文档表示类)
from smartchain.documents import Document
# 导入 BM25Retriever 检索器
from smartchain.retrievers import VectorSimilarityRetriever

# 准备示例文档,构造 Document 列表,每个包含一段文本
docs = [
    Document(page_content="深度学习通过神经网络取得突破。"),
    Document(page_content="机器人学结合机械工程与人工智能。"),
    Document(page_content="美食点评:这家餐厅的川菜很正宗。"),
    Document(page_content="人工智能技术正推动机器人自主学习与创新。"),
]

# 使用 from_documents 类方法创建 BM25 检索器,传入文档列表
retriever = VectorSimilarityRetriever.from_documents(docs)

# 设置检索的查询字符串
query = "机器人与人工智能"
# 调用 invoke 方法执行检索,获取相关文档
results = retriever.invoke(query)

# 打印查询内容
print(f"查询:{query}")
# 打印返回的相关文档数量
print(f"返回 {len(results)} 条:")
# 遍历检索结果并逐条打印文档内容
for i, doc in enumerate(results, 1):
    print(f"{i}. {doc.page_content}")

53.2. vector.py #

smartchain/retrievers/vector.py

# 导入 BaseRetriever 基类
from .base import BaseRetriever
# 导入自定义 Document 类
from ..documents import Document
# 导入numpy用于数组操作
import numpy as np

# 定义余弦相似度计算函数
def cosine_similarity(query_embedding, doc_embeddings):
    """
    计算查询向量与文档向量的余弦相似度

    参数:
        query_embedding: 查询的嵌入向量(一维数组)
        doc_embeddings: 文档的嵌入向量列表(二维数组)

    返回:
        相似度分数数组
    """
    # 将查询嵌入转换为numpy数组
    query_emb = np.array(query_embedding)
    # 将文档嵌入转换为numpy数组
    doc_embs = np.array(doc_embeddings)
    # 若查询为一维,转换为二维以便向量批量操作
    if query_emb.ndim == 1:
        query_emb = query_emb.reshape(1, -1)
    # 计算查询和所有文档嵌入的点积
    dot_product = np.dot(query_emb, doc_embs.T)
    # 计算查询向量的范数
    query_norm = np.linalg.norm(query_emb, axis=1, keepdims=True)
    # 计算所有文档嵌入的范数,并转置方便后续相乘
    doc_norms = np.linalg.norm(doc_embs, axis=1, keepdims=True).T
    # 使用点积/范数乘积得到余弦相似度
    similarity = dot_product / (query_norm * doc_norms)
    # 将因除零产生的nan、正负无穷转为0
    similarity = np.nan_to_num(similarity, nan=0.0, posinf=0.0, neginf=0.0)
    # 返回第一行的相似度结果(只支持一个query)
    return similarity[0]

# 定义VectorSimilarityRetriever检索器,继承自BaseRetriever
class VectorSimilarityRetriever(BaseRetriever):
    # VectorSimilarity 检索器说明文档字符串
    """
    VectorSimilarity 检索器

    使用向量嵌入进行文档检索。
    通过计算查询文本与文档集合的嵌入向量,使用余弦相似度找到最相关的文档。
    """

    # 初始化方法
    def __init__(self, embeddings, documents=None, k=4, **kwargs):
        """
        初始化 VectorSimilarity 检索器

        参数:
            embeddings: 嵌入模型实例,必须实现 embed_query 和 embed_documents 方法
            documents: Document 列表,用于计算嵌入向量
            k: 默认返回的文档数量,默认 4
            **kwargs: 其他参数
        """
        # 调用父类初始化方法
        super().__init__(**kwargs)
        # 保存嵌入模型
        self.embeddings = embeddings
        # 保存返回文档数量
        self.k = k
        # 保存文档列表,如果为None则赋值为空列表
        self.documents = documents or []
        # 文档嵌入向量列表
        self.doc_embeddings = None

        # 如果文档列表不为空,则计算文档嵌入向量
        if self.documents:
            self._compute_document_embeddings()

    # 计算所有文档的嵌入向量
    def _compute_document_embeddings(self):
        """计算所有文档的嵌入向量"""
        # 提取所有文档的内容,组成文本列表
        texts = [doc.page_content for doc in self.documents]
        # 批量计算所有文档的嵌入向量
        self.doc_embeddings = self.embeddings.embed_documents(texts)

    # 通过类方法根据文档列表创建检索器对象
    @classmethod
    def from_documents(cls, documents, embeddings=None, k=4, **kwargs):
        """
        从文档列表创建 VectorSimilarityRetriever 实例

        参数:
            documents: Document 列表
            embeddings: 嵌入模型实例,如果为None则使用默认的HuggingFaceEmbeddings
            k: 默认返回的文档数量,默认 4
            **kwargs: 其他参数,可传递给嵌入模型

        返回:
            VectorSimilarityRetriever 实例
        """
        # 如果没有提供嵌入模型,使用默认的HuggingFaceEmbeddings
        if embeddings is None:
            from ..embeddings import HuggingFaceEmbeddings
            embeddings = HuggingFaceEmbeddings(**kwargs)

        # 创建检索器实例,传递文档列表及嵌入模型
        instance = cls(embeddings=embeddings, documents=documents, k=k)
        return instance

    # 主体检索方法,返回与查询相关的文档
    def _get_relevant_documents(self, query, k=None, **kwargs):
        """
        获取相关文档

        参数:
            query: 查询字符串
            k: 返回的文档数量,如果为None则使用默认值
            **kwargs: 其他参数

        返回:
            Document 列表,按相似度从高到低排序
        """
        # 如果没有文档,则直接返回空列表
        if not self.documents or self.doc_embeddings is None:
            return []

        # 获取k值,优先使用传入参数,否则使用默认值
        k = k if k is not None else self.k

        # 计算查询的嵌入向量
        query_embedding = self.embeddings.embed_query(query)

        # 计算查询向量和所有文档向量的余弦相似度
        similarities = cosine_similarity(query_embedding, self.doc_embeddings)

        # 获取相似度排序后的索引(降序),并取前k个
        top_indices = np.argsort(similarities)[::-1][:k]

        # 只保留相似度大于0的文档,返回对应的Document对象列表
        return [self.documents[i] for i in top_indices if similarities[i] > 0]

53.3. init.py #

smartchain/retrievers/init.py

from .vector_store import VectorStoreRetriever
from .base import BaseRetriever
from .tfidf import TFIDFRetriever
from .bm25 import BM25Retriever
+from .vector import VectorSimilarityRetriever
__all__ = [
    "VectorStoreRetriever",
    "BaseRetriever",
    "TFIDFRetriever",
+   "BM25Retriever",
+   "VectorSimilarityRetriever"
]

53.4.类 #

53.4.1 类说明 #

类名 类型 作用 关键属性/方法
Document 数据类 存储文档内容和元数据 page_content: 文档文本内容
metadata: 元数据字典
BaseRetriever 抽象基类 定义检索器接口规范 invoke(): 对外调用接口
_get_relevant_documents(): 抽象方法(需子类实现)
VectorSimilarityRetriever 具体实现类 实现基于向量嵌入的文档检索 embeddings: 嵌入模型实例
documents: 文档列表
doc_embeddings: 文档嵌入向量列表
k: 默认返回文档数量
from_documents(): 类方法创建实例
invoke(): 执行检索
_get_relevant_documents(): 核心检索逻辑
_compute_document_embeddings(): 计算文档嵌入向量
HuggingFaceEmbeddings 嵌入模型类 使用 HuggingFace 模型生成文本嵌入向量 model_name: 模型名称
embed_query(): 嵌入单个查询文本
embed_documents(): 批量嵌入文档文本
Embedding 抽象基类 定义嵌入模型接口规范 embed_query(): 抽象方法
embed_documents(): 抽象方法
cosine_similarity 函数 计算余弦相似度 计算查询向量与文档向量的余弦相似度

53.4.2 类图 #

classDiagram class Document { +str page_content +Dict[str, Any] metadata +__init__(page_content, metadata) } class BaseRetriever { <<abstract>> +__init__(**kwargs) +invoke(query, **kwargs) List[Document] #_get_relevant_documents(query, **kwargs)* List[Document] } class VectorSimilarityRetriever { -Embedding embeddings -List[Document] documents -List[List[float]] doc_embeddings -int k +__init__(embeddings, documents, k, **kwargs) +from_documents(cls, documents, embeddings, k, **kwargs)$ VectorSimilarityRetriever +invoke(query, **kwargs) List[Document] -_get_relevant_documents(query, k, **kwargs) List[Document] -_compute_document_embeddings() } class Embedding { <<abstract>> +embed_query(text)* List[float] +embed_documents(texts)* List[List[float]] } class HuggingFaceEmbeddings { -str model_name -LangchainHuggingFaceEmbeddings embeddings +__init__(model_name, **kwargs) +embed_query(text) List[float] +embed_documents(texts) List[List[float]] } class cosine_similarity { <<function>> +cosine_similarity(query_embedding, doc_embeddings) ndarray } BaseRetriever <|-- VectorSimilarityRetriever : 继承实现 Embedding <|-- HuggingFaceEmbeddings : 继承实现 VectorSimilarityRetriever "1" *-- "0..*" Document : 包含文档列表 VectorSimilarityRetriever "1" --> "1" Embedding : 使用嵌入模型 VectorSimilarityRetriever ..> cosine_similarity : 调用 VectorSimilarityRetriever ..> HuggingFaceEmbeddings : 默认创建 note for BaseRetriever "抽象基类\n定义检索器标准接口\n使用模板方法模式" note for VectorSimilarityRetriever "具体实现类\n实现基于向量嵌入的检索\n使用余弦相似度计算相关性" note for Embedding "嵌入模型抽象基类\n定义嵌入接口规范" note for HuggingFaceEmbeddings "HuggingFace嵌入模型\n默认使用all-MiniLM-L6-v2模型"

53.4.3 时序图 #

sequenceDiagram participant Main as 53.SimilarityRetriever.py participant Doc as Document participant VSR as VectorSimilarityRetriever participant Base as BaseRetriever participant Embed as HuggingFaceEmbeddings participant Cosine as cosine_similarity Note over Main: 阶段1: 创建文档对象 Main->>Doc: Document(page_content="...") Doc-->>Main: doc1, doc2, doc3, doc4 Note over Main: 阶段2: 创建检索器实例 Main->>VSR: from_documents(docs) activate VSR VSR->>VSR: 检查embeddings是否为None VSR->>Embed: HuggingFaceEmbeddings() (默认模型) activate Embed Embed-->>VSR: embeddings实例 deactivate Embed VSR->>VSR: __init__(embeddings, documents) VSR->>Base: super().__init__() activate Base Base-->>VSR: 初始化完成 deactivate Base VSR->>VSR: _compute_document_embeddings() activate VSR VSR->>VSR: 提取doc.page_content VSR->>Embed: embed_documents(texts) activate Embed Embed-->>VSR: doc_embeddings (文档嵌入向量列表) deactivate Embed deactivate VSR VSR-->>Main: retriever实例 deactivate VSR Note over Main: 阶段3: 执行检索查询 Main->>VSR: invoke(query="机器人与人工智能") activate VSR VSR->>VSR: _get_relevant_documents(query, k=None) activate VSR VSR->>VSR: 检查documents和doc_embeddings VSR->>VSR: k = k if k is not None else self.k VSR->>Embed: embed_query(query) activate Embed Embed-->>VSR: query_embedding (查询嵌入向量) deactivate Embed VSR->>Cosine: cosine_similarity(query_embedding, doc_embeddings) activate Cosine Cosine->>Cosine: 计算点积和范数 Cosine->>Cosine: similarity = dot_product / (query_norm * doc_norms) Cosine-->>VSR: similarities数组 deactivate Cosine VSR->>VSR: np.argsort(similarities)[::-1][:k] VSR->>VSR: 筛选相似度>0的文档 VSR-->>VSR: [self.documents[i] for i in top_indices] deactivate VSR VSR-->>Main: results (Document列表) deactivate VSR Note over Main: 阶段4: 输出结果 Main->>Main: 遍历results并打印

53.4.4 调用过程 #

阶段 1:文档对象创建

docs = [
    Document(page_content="深度学习通过神经网络取得突破。"),
    Document(page_content="机器人学结合机械工程与人工智能。"),
    Document(page_content="美食点评:这家餐厅的川菜很正宗。"),
    Document(page_content="人工智能技术正推动机器人自主学习与创新。"),
]
  • 创建 4 个 Document 实例
  • 每个对象包含 page_content 和默认的 metadata

阶段 2:检索器初始化

retriever = VectorSimilarityRetriever.from_documents(docs)

调用链:

  1. from_documents() 类方法
    • 检查 embeddings 是否为 None
    • 若为 None,创建默认的 HuggingFaceEmbeddings 实例(使用 all-MiniLM-L6-v2 模型)
    • 调用 cls(embeddings=embeddings, documents=documents, k=k),即 __init__()
  2. __init__() 方法
    • 调用 super().__init__() 初始化基类
    • 保存 embeddings、documents、k
    • 初始化 doc_embeddings = None
    • 若 documents 非空,调用 _compute_document_embeddings()
  3. _compute_document_embeddings() 方法
    • 提取所有文档的 page_content
    • 调用 embeddings.embed_documents(texts) 批量计算文档嵌入向量
    • 保存到 self.doc_embeddings

阶段 3:执行检索

results = retriever.invoke(query)

调用链:

  1. invoke() 方法(继承自 BaseRetriever)
    • 调用 _get_relevant_documents(query, k=None)
  2. _get_relevant_documents() 方法(核心检索逻辑)

    步骤 1:预处理

    • 检查 documents 和 doc_embeddings 是否为空
    • 确定 k 值(优先使用参数,否则使用默认值)

    步骤 2:计算查询嵌入向量

    • 调用 embeddings.embed_query(query) 将查询转换为嵌入向量

    步骤 3:计算余弦相似度

    • 调用 cosine_similarity(query_embedding, self.doc_embeddings) 计算相似度
    • 内部计算:
      dot_product = np.dot(query_emb, doc_embs.T)
      query_norm = np.linalg.norm(query_emb, axis=1, keepdims=True)
      doc_norms = np.linalg.norm(doc_embs, axis=1, keepdims=True).T
      similarity = dot_product / (query_norm * doc_norms)

    步骤 4:排序和筛选

    • 按相似度降序排序:top_indices = np.argsort(similarities)[::-1][:k]
    • 筛选相似度 > 0 的文档:return [self.documents[i] for i in top_indices if similarities[i] > 0]

阶段 4:结果输出

print(f"查询:{query}")
print(f"返回 {len(results)} 条:")
for i, doc in enumerate(results, 1):
    print(f"{i}. {doc.page_content}")
  • 遍历 results(Document 列表)
  • 打印每个文档的 page_content

53.4.5 向量嵌入与余弦相似度 #

向量嵌入(Embedding)

将文本转换为数值向量,捕获语义信息:

  • 语义相似的文本在向量空间中距离更近
  • 使用预训练模型(如 all-MiniLM-L6-v2)生成固定维度的向量

余弦相似度公式

对于查询向量 $\vec{q}$ 和文档向量 $\vec{d}$:

$$\cos(\theta) = \frac{\vec{q} \cdot \vec{d}}{|\vec{q}| \times |\vec{d}|} = \frac{\sum_{i=1}^{n} q_i \times d_i}{\sqrt{\sum_{i=1}^{n} q_i^2} \times \sqrt{\sum_{i=1}^{n} d_i^2}}$$

其中:

  • $\vec{q} \cdot \vec{d}$:点积
  • $|\vec{q}|$、$|\vec{d}|$:向量的 L2 范数
  • 相似度范围:$[-1, 1]$,值越大越相似

优势

  1. 语义理解:能捕获语义相似性,不依赖关键词匹配
  2. 多语言支持:预训练模型支持多种语言
  3. 上下文感知:考虑上下文信息

53.4.6 设计模式 #

  1. 模板方法模式
    • BaseRetriever.invoke() 定义调用流程
    • VectorSimilarityRetriever._get_relevant_documents() 实现具体检索逻辑
  2. 工厂方法模式
    • from_documents() 作为类方法工厂,简化实例创建
  3. 策略模式
    • BaseRetriever 定义接口,VectorSimilarityRetriever 提供具体策略
    • Embedding 定义嵌入接口,HuggingFaceEmbeddings 提供具体实现

该设计便于扩展新的检索器和嵌入模型实现。

53.4.7 与 TFIDF/BM25 的区别 #

特性 VectorSimilarityRetriever TFIDFRetriever BM25Retriever
算法基础 向量嵌入(深度学习) TF-IDF(统计) BM25(统计)
语义理解 ✅ 支持 ❌ 不支持 ❌ 不支持
关键词匹配 ❌ 不依赖 ✅ 依赖 ✅ 依赖
计算复杂度 较高(需要模型推理) 较低 较低
多语言支持 ✅ 好 需分词器 需分词器
上下文理解 ✅ 强 ❌ 弱 ❌ 弱

VectorSimilarityRetriever 更适合需要语义理解的场景,而 TFIDF/BM25 更适合关键词匹配场景。

54.EnsembleRetriever #

  • RRF

EnsembleRetriever 是一种通过集成多个独立检索器(如不同数据源或算法)来提升整体检索效果的高级用法。

EnsembleRetriever 采用加权 Reciprocal Rank Fusion (RRF) 算法,将多个检索器的返回结果进行融合,并根据设定的权重对每个检索源影响力加权排序。其主要特点和典型应用场景如下:

  • 灵活集成:可同时结合多个不同来源或不同算法的检索器,提高召回率和鲁棒性。
  • 权重可调:支持自定义各子检索器的权重,便于根据实际效果调整信息融合策略。
  • 去重与排序:自动对不同检索器的结果按唯一键去重,并结合权重做最终排序,兼顾多样性与准确性。
  • 适用场景:适合需要整合多种索引、异构数据源或不同检索策略的复合应用。

54.1. 54.EnsembleRetriever.py #

54.EnsembleRetriever.py

# 导入 Document 类,用于表示文档对象
from smartchain.documents import Document
# 导入集成检索器(EnsembleRetriever),用于多检索器融合
from smartchain.retrievers import EnsembleRetriever

# 构建示例文档列表,每个元素是一个 Document 对象,包含一段文本内容
docs = [
    Document(page_content="深度学习通过神经网络取得突破。"),
    Document(page_content="机器人学结合机械工程与人工智能。"),
    Document(page_content="美食点评:这家餐厅的川菜很正宗。"),
    Document(page_content="人工智能技术正推动机器人自主学习与创新。"),
]

# 使用类方法 from_documents 创建集成检索器实例,传入文档列表
# weights=[0.5, 0.5] 表示 BM25 检索器和向量相似度检索器权重各占 0.5
retriever = EnsembleRetriever.from_documents(
    docs,
    weights=[0.5, 0.5]  # BM25 和向量检索权重设置
)

# 设置检索的查询语句
query = "机器人与人工智能"
# 调用 invoke 方法执行检索操作,k=2 表示取出前2个相关文档
results = retriever.invoke(query, k=2)

# 打印查询内容
print(f"查询:{query}")
# 打印返回的相关文档数量
print(f"返回 {len(results)} 条:")
# 遍历所有检索结果,依次打印每条文档的内容
for i, doc in enumerate(results, 1):
    print(f"{i}. {doc.page_content}")

54.2. ensemble.py #

smartchain/retrievers/ensemble.py

# 导入基础检索器基类
from .base import BaseRetriever
# 导入BM25检索器
from .bm25 import BM25Retriever
# 导入向量相似度检索器
from .vector import VectorSimilarityRetriever
# 导入HuggingFace嵌入模型
from ..embeddings import HuggingFaceEmbeddings

# 定义集成检索器类(融合多种检索器)
class EnsembleRetriever(BaseRetriever):
    # 初始化方法
    def __init__(
        self,
        retrievers=None,  # 传入的检索器列表
        weights=None,     # 各检索器的权重
        k=60,             # RRF分数计算时的k值
    ):
        # 初始化父类
        super().__init__()
        # 设置检索器列表,如果未指定则为空列表
        self.retrievers = retrievers or []
        # 设置权重
        self.weights = weights
        # 设置RRF算法的k参数
        self.k = k

    # 类方法:根据文档自动构建集成检索器
    @classmethod
    def from_documents(
        cls,
        documents,
        weights=None,
    ):
        # 创建BM25检索器
        bm25_retriever = BM25Retriever.from_documents(documents)
        # 用指定模型创建文本嵌入实例(默认使用CPU)
        embeddings = HuggingFaceEmbeddings(
            model_name="BAAI/bge-base-zh-v1.5",
            model_kwargs={"device": "cpu"}
        )
        # 创建向量相似度检索器
        vector_retriever = VectorSimilarityRetriever.from_documents(
            documents,
            embeddings=embeddings,
        )
        # 如果没有指定权重,则默认均分
        if weights is None:
            weights = [0.5, 0.5]
        # 打包检索器列表
        retrievers = [bm25_retriever, vector_retriever]
        # 返回集成检索器实例
        return cls(retrievers=retrievers, weights=weights, k=60)

    # 获取文档唯一标识(此处直接用内容)
    def _get_doc_id(self, doc):
        return doc.page_content

    # 根据所有检索器得到的排位计算RRF融合分数
    def _calculate_rrf_score(self, doc_ranks):
        # 存储每个文档的RRF分数
        rrf_scores = {}
        # 遍历每个文档的各检索器排名
        for doc_id, ranks in doc_ranks.items():
            score = 0.0
            # 依次遍历各检索器排名
            for i, rank in enumerate(ranks):
                # 如果有排名(未命中为None)
                if rank is not None:
                    # 使用指定权重,如果没有,则为1.0
                    weight = self.weights[i] if self.weights else 1.0
                    # RRF分数累加
                    score += weight / (self.k + rank)
            # 汇总该文档的总分
            rrf_scores[doc_id] = score
        # 返回所有文档分数
        return rrf_scores

    # 获取最终融合后的相关文档
    def _get_relevant_documents(self, query, k=4, **kwargs):
        # 如果检索器列表为空,直接返回空结果
        if not self.retrievers:
            return []
        # 如果权重未设置,则自动平均分配
        if self.weights is None:
            self.weights = [1.0 / len(self.retrievers)] * len(self.retrievers)
        # 权重数量必须等于检索器数量,否则报错
        if len(self.weights) != len(self.retrievers):
            raise ValueError(
                f"权重数量 ({len(self.weights)}) 必须与检索器数量 ({len(self.retrievers)}) 一致"
            )

        # all_results用于存储每个检索器的返回结果
        all_results = []
        # 分别用每个检索器检索,大k值,提高召回
        for retriever in self.retrievers:
            results = retriever.invoke(query, k=k * 2, **kwargs)
            all_results.append(results)

        # 建立从ID到文档对象的映射
        doc_id_to_doc = {}
        # 存储每个文档在各检索器中的排名
        doc_ranks = {}

        # 遍历所有检索器的召回结果
        for retriever_idx, results in enumerate(all_results):
            for rank, doc in enumerate(results, start=1):
                # 获取文档ID
                doc_id = self._get_doc_id(doc)
                # 首次出现该文档,初始化
                if doc_id not in doc_id_to_doc:
                    doc_id_to_doc[doc_id] = doc
                    doc_ranks[doc_id] = [None] * len(self.retrievers)
                # 仅记录第一次出现的rank
                if doc_ranks[doc_id][retriever_idx] is None:
                    doc_ranks[doc_id][retriever_idx] = rank

        # 融合所有排名,返回RRF融合分数
        rrf_scores = self._calculate_rrf_score(doc_ranks)
        # 按分数降序排列
        sorted_docs = sorted(
            rrf_scores.items(),
            key=lambda x: x[1],
            reverse=True
        )
        # 收集分数大于0的top k文档
        top_k_docs = []
        for doc_id, score in sorted_docs[:k]:
            if score > 0:
                top_k_docs.append(doc_id_to_doc[doc_id])
        # 返回最终文档列表
        return top_k_docs

54.3. init.py #

smartchain/retrievers/init.py

from .vector_store import VectorStoreRetriever
from .base import BaseRetriever
from .tfidf import TFIDFRetriever
from .bm25 import BM25Retriever
from .vector import VectorSimilarityRetriever
+from .ensemble import EnsembleRetriever
__all__ = [
    "VectorStoreRetriever",
    "BaseRetriever",
    "TFIDFRetriever",
    "BM25Retriever",
+   "VectorSimilarityRetriever",
+   "EnsembleRetriever"
]

54.4. 类 #

54.4.1 类说明 #

类名 类型 作用 关键属性/方法
Document 数据类 存储文档内容和元数据 page_content: 文档文本内容
metadata: 元数据字典
BaseRetriever 抽象基类 定义检索器接口规范 invoke(): 对外调用接口
_get_relevant_documents(): 抽象方法(需子类实现)
EnsembleRetriever 集成检索器类 融合多个检索器的结果 retrievers: 检索器列表
weights: 各检索器权重
k: RRF算法k值
from_documents(): 类方法创建实例
invoke(): 执行检索
_get_relevant_documents(): 核心检索逻辑
_calculate_rrf_score(): 计算RRF融合分数
_get_doc_id(): 获取文档唯一标识
BM25Retriever 检索器类 基于BM25算法的文档检索 from_documents(): 类方法创建实例
invoke(): 执行检索
VectorSimilarityRetriever 检索器类 基于向量嵌入的文档检索 embeddings: 嵌入模型实例
from_documents(): 类方法创建实例
invoke(): 执行检索
HuggingFaceEmbeddings 嵌入模型类 使用HuggingFace模型生成文本嵌入向量 model_name: 模型名称
embed_query(): 嵌入单个查询文本
embed_documents(): 批量嵌入文档文本

54.4.2 类图 #

classDiagram class Document { +str page_content +Dict[str, Any] metadata +__init__(page_content, metadata) } class BaseRetriever { <<abstract>> +__init__(**kwargs) +invoke(query, **kwargs) List[Document] #_get_relevant_documents(query, **kwargs)* List[Document] } class EnsembleRetriever { -List[BaseRetriever] retrievers -List[float] weights -int k +__init__(retrievers, weights, k) +from_documents(cls, documents, weights)$ EnsembleRetriever +invoke(query, **kwargs) List[Document] -_get_relevant_documents(query, k, **kwargs) List[Document] -_calculate_rrf_score(doc_ranks) Dict[str, float] -_get_doc_id(doc) str } class BM25Retriever { +from_documents(cls, documents, **kwargs)$ BM25Retriever +invoke(query, **kwargs) List[Document] -_get_relevant_documents(query, **kwargs) List[Document] } class VectorSimilarityRetriever { -Embedding embeddings -List[Document] documents +from_documents(cls, documents, embeddings, **kwargs)$ VectorSimilarityRetriever +invoke(query, **kwargs) List[Document] -_get_relevant_documents(query, **kwargs) List[Document] } class HuggingFaceEmbeddings { -str model_name +__init__(model_name, **kwargs) +embed_query(text) List[float] +embed_documents(texts) List[List[float]] } BaseRetriever <|-- EnsembleRetriever : 继承实现 BaseRetriever <|-- BM25Retriever : 继承实现 BaseRetriever <|-- VectorSimilarityRetriever : 继承实现 EnsembleRetriever "1" *-- "2..*" BaseRetriever : 包含多个检索器 EnsembleRetriever "1" --> "1" BM25Retriever : 使用 EnsembleRetriever "1" --> "1" VectorSimilarityRetriever : 使用 VectorSimilarityRetriever "1" --> "1" HuggingFaceEmbeddings : 使用 EnsembleRetriever "1" *-- "0..*" Document : 返回文档列表 note for BaseRetriever "抽象基类\n定义检索器标准接口\n使用模板方法模式" note for EnsembleRetriever "集成检索器\n融合多个检索器结果\n使用RRF算法进行分数融合" note for BM25Retriever "BM25检索器\n基于关键词匹配\n适合精确匹配场景" note for VectorSimilarityRetriever "向量相似度检索器\n基于语义理解\n适合语义匹配场景"

54.4.3 时序图 #

sequenceDiagram participant Main as 54.EnsembleRetriever.py participant Doc as Document participant ER as EnsembleRetriever participant Base as BaseRetriever participant BM25 as BM25Retriever participant VSR as VectorSimilarityRetriever participant Embed as HuggingFaceEmbeddings Note over Main: 阶段1: 创建文档对象 Main->>Doc: Document(page_content="...") Doc-->>Main: doc1, doc2, doc3, doc4 Note over Main: 阶段2: 创建集成检索器 Main->>ER: from_documents(docs, weights=[0.5, 0.5]) activate ER ER->>BM25: from_documents(documents) activate BM25 BM25-->>ER: bm25_retriever实例 deactivate BM25 ER->>Embed: HuggingFaceEmbeddings(model_name="BAAI/bge-base-zh-v1.5") activate Embed Embed-->>ER: embeddings实例 deactivate Embed ER->>VSR: from_documents(documents, embeddings) activate VSR VSR->>VSR: 计算文档嵌入向量 VSR-->>ER: vector_retriever实例 deactivate VSR ER->>ER: cls(retrievers=[bm25, vector], weights=[0.5, 0.5], k=60) ER->>Base: super().__init__() activate Base Base-->>ER: 初始化完成 deactivate Base ER-->>Main: retriever实例 deactivate ER Note over Main: 阶段3: 执行检索查询 Main->>ER: invoke(query="机器人与人工智能", k=2) activate ER ER->>ER: _get_relevant_documents(query, k=2) activate ER ER->>ER: 验证retrievers和weights Note over ER: 并行调用多个检索器 par 调用BM25检索器 ER->>BM25: invoke(query, k=4) activate BM25 BM25-->>ER: results1 (Document列表) deactivate BM25 and 调用向量检索器 ER->>VSR: invoke(query, k=4) activate VSR VSR->>Embed: embed_query(query) activate Embed Embed-->>VSR: query_embedding deactivate Embed VSR->>VSR: 计算余弦相似度 VSR-->>ER: results2 (Document列表) deactivate VSR end ER->>ER: 收集所有检索结果 ER->>ER: 建立doc_id_to_doc映射 ER->>ER: 记录每个文档在各检索器中的排名 loop 遍历所有检索器结果 loop 遍历每个文档 ER->>ER: _get_doc_id(doc) ER->>ER: 记录文档排名 end end ER->>ER: _calculate_rrf_score(doc_ranks) activate ER loop 遍历每个文档 loop 遍历每个检索器 ER->>ER: score += weight / (k + rank) end end ER-->>ER: rrf_scores字典 deactivate ER ER->>ER: 按RRF分数降序排序 ER->>ER: 取前k个分数>0的文档 ER-->>ER: top_k_docs列表 deactivate ER ER-->>Main: results (Document列表) deactivate ER Note over Main: 阶段4: 输出结果 Main->>Main: 遍历results并打印

54.4.4 调用过程 #

阶段 1:文档对象创建

docs = [
    Document(page_content="深度学习通过神经网络取得突破。"),
    Document(page_content="机器人学结合机械工程与人工智能。"),
    Document(page_content="美食点评:这家餐厅的川菜很正宗。"),
    Document(page_content="人工智能技术正推动机器人自主学习与创新。"),
]
  • 创建 4 个 Document 实例
  • 每个对象包含 page_content 和默认的 metadata

阶段 2:创建集成检索器

retriever = EnsembleRetriever.from_documents(
    docs,
    weights=[0.5, 0.5]  # BM25 和向量检索权重设置
)

调用链:

  1. from_documents() 类方法
    • 创建 BM25Retriever 实例:bm25_retriever = BM25Retriever.from_documents(documents)
    • 创建 HuggingFaceEmbeddings 实例(使用 BAAI/bge-base-zh-v1.5 模型)
    • 创建 VectorSimilarityRetriever 实例:vector_retriever = VectorSimilarityRetriever.from_documents(documents, embeddings=embeddings)
    • 如果没有指定权重,默认设置为 [0.5, 0.5]
    • 打包检索器列表:retrievers = [bm25_retriever, vector_retriever]
    • 调用 cls(retrievers=retrievers, weights=weights, k=60) 创建实例
  2. __init__() 方法
    • 调用 super().__init__() 初始化基类
    • 保存 retrievers、weights、k(RRF 算法参数,默认 60)

阶段 3:执行检索

results = retriever.invoke(query, k=2)

调用链:

  1. invoke() 方法(继承自 BaseRetriever)
    • 调用 _get_relevant_documents(query, k=2)
  2. _get_relevant_documents() 方法(核心检索逻辑)

    步骤 1:验证和初始化

    • 检查检索器列表是否为空
    • 如果权重未设置,自动平均分配
    • 验证权重数量与检索器数量一致

    步骤 2:并行调用多个检索器

    for retriever in self.retrievers:
        results = retriever.invoke(query, k=k * 2, **kwargs)  # k*2 提高召回率
        all_results.append(results)
    • 每个检索器返回 k*2 个文档(提高召回率)
    • 收集所有检索器的结果

    步骤 3:建立文档映射和排名记录

    doc_id_to_doc = {}  # ID到文档对象的映射
    doc_ranks = {}      # 每个文档在各检索器中的排名
    
    for retriever_idx, results in enumerate(all_results):
        for rank, doc in enumerate(results, start=1):
            doc_id = self._get_doc_id(doc)  # 获取文档ID(使用page_content)
            if doc_id not in doc_id_to_doc:
                doc_id_to_doc[doc_id] = doc
                doc_ranks[doc_id] = [None] * len(self.retrievers)
            if doc_ranks[doc_id][retriever_idx] is None:
                doc_ranks[doc_id][retriever_idx] = rank

    步骤 4:计算 RRF 融合分数

    rrf_scores = self._calculate_rrf_score(doc_ranks)

    内部计算逻辑:

    for doc_id, ranks in doc_ranks.items():
        score = 0.0
        for i, rank in enumerate(ranks):
            if rank is not None:
                weight = self.weights[i] if self.weights else 1.0
                score += weight / (self.k + rank)  # RRF公式
        rrf_scores[doc_id] = score

    步骤 5:排序和筛选

    • 按 RRF 分数降序排序
    • 取前 k 个分数 > 0 的文档
    • 返回对应的 Document 对象列表

阶段 4:结果输出

print(f"查询:{query}")
print(f"返回 {len(results)} 条:")
for i, doc in enumerate(results, 1):
    print(f"{i}. {doc.page_content}")
  • 遍历 results(Document 列表)
  • 打印每个文档的 page_content

54.4.5 RRF 算法 #

RRF(Reciprocal Rank Fusion)用于融合多个检索器的排名结果。

RRF 公式

对于文档 $d$,其 RRF 分数计算为:

$$\text{RRF}(d) = \sum_{i=1}^{n} \frac{w_i}{k + \text{rank}_i(d)}$$

其中:

  • $n$:检索器数量
  • $w_i$:第 $i$ 个检索器的权重
  • $k$:RRF 算法参数(通常为 60)
  • $\text{rank}_i(d)$:文档 $d$ 在第 $i$ 个检索器中的排名(如果未出现则为 None,不参与计算)

RRF 算法优势

  1. 无需归一化:不同检索器的分数无需归一化即可融合
  2. 处理缺失:某个检索器未返回的文档仍可参与融合
  3. 权重可调:通过 weights 参数调整各检索器的重要性
  4. 提高召回:融合多个检索器,提高召回率

示例计算

假设:

  • 文档 A 在 BM25 中排名 1,在向量检索中排名 2
  • 文档 B 在 BM25 中排名 2,在向量检索中排名 1
  • 权重:[0.5, 0.5],k=60

文档 A 的 RRF 分数: $$\text{RRF}(A) = \frac{0.5}{60 + 1} + \frac{0.5}{60 + 2} = \frac{0.5}{61} + \frac{0.5}{62} \approx 0.0164$$

文档 B 的 RRF 分数: $$\text{RRF}(B) = \frac{0.5}{60 + 2} + \frac{0.5}{60 + 1} = \frac{0.5}{62} + \frac{0.5}{61} \approx 0.0164$$

两者分数接近,但文档 B 在两个检索器中排名更靠前,最终排名可能更高。

55.LLMChainExtractor #

本节介绍如何利用大语言模型(LLM)提升检索问答系统的相关性和答案精度,具体通过文档压缩与上下文压缩检索技术实现。
在传统检索系统中,仅靠相似度召回的文档常常包含冗余内容或部分不相关片段,导致后续大模型生成的答案长且噪声多。
LLM驱动的文档压缩器(如LLMChainExtractor)能够智能地从每个召回文档中,仅抽取与用户问题直接相关的段落,极大提升答案简洁性和相关度。

  1. LLMChainExtractor文档压缩器

    • 该压缩器会对每个检索到的文档,结合问题,通过大模型抽取出最相关的部分。如果判定文档与问题无关,则自动过滤该文档。
    • 适用于中/英文场景,支持按需自定义抽取模板。
    • 能有效剔除冗余、弱相关内容,为问答系统后续推理提供精准上下文。
  2. ContextualCompressionRetriever上下文压缩检索器

    • 这是一个高级检索管道,将“基础检索器”与“文档压缩器”无缝衔接。
    • 流程为:首先使用相似度等方式初步召回一批文档,然后对这些文档应用LLMChainExtractor压缩,仅保留高相关内容或段落。
    • 这样,不仅提升了下游大模型生成答案的准确性,也显著降低无用信息干扰。
  3. 典型应用场景

    • 非结构化知识库问答、企业内部知识检索、FAQ系统等,尤其适用于原始文档较长、信息密度低、但只需部分内容即可作答的问题型场景。
  4. 一体化示例

    • 下方示例代码展示了如何集成向量数据库、嵌入模型、检索器和LLM压缩器,构建高效高相关性的智能问答系统。
    • 其中核心语句为:
      compressor = LLMChainExtractor.from_llm(llm)
      compression_retriever = ContextualCompressionRetriever(
          base_compressor=compressor, base_retriever=base_retriever)
      results = compression_retriever.invoke(query)
    • 只需更换“llm”、“嵌入模型”或检索参数,即可灵活适配不同业务需求。

55.1. 55.LLMChainExtractor.py #

55.LLMChainExtractor.py

#from langchain_chroma import Chroma
#from langchain_huggingface import HuggingFaceEmbeddings
#from langchain_classic.retrievers.contextual_compression import ContextualCompressionRetriever
#from langchain_classic.retrievers.document_compressors.chain_extract import LLMChainExtractor
#from langchain_deepseek import ChatDeepSeek
# 导入 Chroma 向量数据库类
from smartchain.vectorstores import Chroma
# 导入 HuggingFaceEmbeddings 嵌入模型类
from smartchain.embeddings import HuggingFaceEmbeddings
# 导入上下文压缩型检索器
from smartchain.retrievers import ContextualCompressionRetriever
# 导入 LLMChainExtractor 文档压缩器
from smartchain.document_compressors import LLMChainExtractor
# 导入 ChatDeepSeek LLM 模型
from smartchain.chat_models import ChatDeepSeek

# 设置本地嵌入模型的路径
model_path = "C:/Users/Administrator/.cache/modelscope/hub/models/sentence-transformers/all-MiniLM-L6-v2"
# 初始化嵌入模型,指定模型路径与推理设备
embeddings = HuggingFaceEmbeddings(
    model_name=model_path,
    model_kwargs={"device": "cpu"}
)

# 初始化 Chroma 向量数据库实例,指定持久化目录、嵌入函数、集合名称和元数据
chroma_db = Chroma(
    persist_directory="chroma_db",
    embedding_function=embeddings,
    collection_name="test",
    collection_metadata={"hnsw:space": "cosine"}
)

# 检查数据库是否为空,若为空批量插入初始文本及其元数据
if not chroma_db._collection.count():
    # 定义插入的文本列表
    texts = [
        "人工智能(AI)是一种让机器模拟人类智能行为的技术。",
        "深度学习是人工智能的一个重要分支,通过多层神经网络学习数据。",
        "ChatGPT是OpenAI开发的强大自然语言模型。",
        "向量数据库可以高效地存储和检索文本的嵌入向量。",
        "机器人学结合了人工智能和机械工程,推动自动化发展。",
        "AI可以辅助医生进行医学影像分析。",
        "大模型在对话、问答、摘要等领域不断取得突破。",
        "知识库问答系统常用于企业信息检索场景。",
    ]
    # 对应的元数据列表
    metadatas = [
        {"topic": "ai"},
        {"topic": "ai"},
        {"topic": "nlp"},
        {"topic": "vector_db"},
        {"topic": "robotics"},
        {"topic": "healthcare"},
        {"topic": "llm"},
        {"topic": "retrieval"},
    ]
    # 向向量数据库批量插入文本和元数据
    chroma_db.add_texts(texts, metadatas)

# 创建 ChatDeepSeek LLM 对象,使用 deepseek-chat 模型
llm = ChatDeepSeek(model="deepseek-chat")

# 创建基础检索器,指定检索类型为 similarity 和返回数量 k=20
# 建议 k 取较大值,因为压缩器后续会过滤部分不相关的结果
base_retriever = chroma_db.as_retriever(search_type="similarity", search_kwargs={"k": 20})

# 创建 LLMChainExtractor 文档压缩器,利用 LLM 从文档抽取相关部分
compressor = LLMChainExtractor.from_llm(llm=llm)

# 创建上下文压缩检索器,结合基础检索器与文档压缩器
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=base_retriever
)

# 指定检索查询内容(支持中文或英文)
query = "人工智能"  # 可以更换为任意中英文问题
# 调用上下文压缩检索器进行检索
results = compression_retriever.invoke(query)
# 打印查询内容
print(f"查询:{query}")
# 打印检索到的文档数量
print(f"共检索到 {len(results)} 个文档")
# 遍历输出每条检索结果及其元数据
for i, doc in enumerate(results):
    print(f"检索结果{i}:{doc.page_content}")
    print(f"  元数据:{doc.metadata}\n")

55.2. document_compressors.py #

smartchain/document_compressors.py

# 从abc模块导入抽象基类支持
from abc import ABC, abstractmethod

# 定义一个默认的中文抽取Prompt模板
CHINESE_EXTRACT_PROMPT = """给定以下问题和上下文,提取上下文中与回答问题相关的任何部分(保持原样)。如果上下文都不相关,返回 NO_OUTPUT。

记住,*不要*编辑提取的上下文部分。

> 问题:{question}
> 上下文:
>>>
{context}
>>>
提取的相关部分:"""

# 定义文档压缩器的抽象基类
class BaseDocumentCompressor(ABC):
    # 设置文档压缩器的类说明文档
    """文档压缩器抽象基类
    用于对检索到的文档进行后处理,例如提取相关内容、过滤不相关文档等。
    """

    # 声明子类必须实现的抽象方法
    @abstractmethod
    def compress_documents(
        self,
        documents,
        query: str,
        callbacks=None,
    ):
        # 方法说明文档
        """压缩文档
          参数:
            documents: 检索到的文档列表
            query: 查询字符串
            callbacks: 可选的回调函数
          返回:
            Sequence: 压缩后的文档列表
        """
        # 抽象方法体不实现
        pass

# 定义LLM链文档提取器类,继承自基类
class LLMChainExtractor(BaseDocumentCompressor):
    # 设置LLM链提取器的类说明文档
    """LLM链提取器
    使用 LLM 链从文档中提取与查询相关的部分。
    """
    # 构造函数
    def __init__(
        self,
        llm_chain=None,
        llm=None,
        prompt=None,
        get_input=None,
    ):
        # 如果传入llm_chain则直接赋值
        if llm_chain is not None:
            self.llm_chain = llm_chain
        # 如果传入llm而未传llm_chain,则根据参数生成链
        elif llm is not None:
            # 延迟导入PromptTemplate工具
            from .prompts import PromptTemplate
            # 如果未自定义提示词模板,使用默认中文抽取Prompt
            if prompt is None:
                prompt_template = PromptTemplate.from_template(CHINESE_EXTRACT_PROMPT)
            else:
                prompt_template = prompt
            # 用私有方法构造链
            self.llm_chain = self._create_chain(prompt_template, llm)
        # 两者都没给时报错
        else:
            raise ValueError("必须提供 llm_chain 或 llm 参数")
        # 设置get_input函数(用于将query/doc打包成模型输入),可以自定义
        self.get_input = get_input or self._default_get_input

    # 默认输入转换函数,把query和doc打包成dict
    def _default_get_input(self, query: str, doc) -> dict:
        # 返回一个包含问题和文本内容的字典
        return {"question": query, "context": doc.page_content}

    # 私有方法,实现用prompt_template和llm创建简单链对象
    def _create_chain(self, prompt_template, llm):
        # 定义内部类SimpleChain
        class SimpleChain:
            # 构造函数,保存prompt模板和llm对象
            def __init__(self, prompt_template, llm):
                self.prompt_template = prompt_template
                self.llm = llm
            # 实现invoke方法,将输入dict格式化prompt并调用llm
            def invoke(self, input_dict, **kwargs):
                # 格式化prompt
                formatted_prompt = self.prompt_template.format(**input_dict)
                # 用llm推理,得到响应
                response = self.llm.invoke(formatted_prompt, **kwargs)
                # 如果响应有content属性,则用其内容
                if hasattr(response, 'content'):
                    text = response.content
                # 如果响应直接为字符串
                elif isinstance(response, str):
                    text = response
                # 否则强转为字符串
                else:
                    text = str(response)
                # 去掉文本首尾空格
                text = text.strip()
                # 若内容为NO_OUTPUT或空,则返回空字符串
                if text == "NO_OUTPUT" or text == "":
                    return ""
                # 返回抽取的文本
                return text
        # 返回SimpleChain对象实例
        return SimpleChain(prompt_template, llm)

    # 类方法,支持通过llm快速初始化LLMChainExtractor
    @classmethod
    def from_llm(
        cls,
        llm,
        prompt=None,
        get_input=None,
    ):
        # 返回初始化好的提取器对象
        return cls(
            llm=llm,
            prompt=prompt,
            get_input=get_input
        )

    # 实现文档压缩主接口
    def compress_documents(
        self,
        documents,
        query: str,
        callbacks=None,
    ):
        # 用于存储压缩结果的列表
        compressed_docs = []
        # 遍历所有输入文档
        for doc in documents:
            # 构造当前doc与query的输入结构
            _input = self.get_input(query, doc)
            # 用llm链抽取相关内容
            output = self.llm_chain.invoke(_input)
            # 如果没抽取到任何内容则跳过该文档
            if not output or len(output) == 0:
                continue
            # 动态导入Document类
            from .vectorstores import Document
            # 用抽取得到内容与原有元数据构造新文档
            compressed_doc = Document(
                page_content=output,
                metadata=doc.metadata if hasattr(doc, 'metadata') else {}
            )
            # 加入压缩结果列表
            compressed_docs.append(compressed_doc)
        # 返回压缩结果文档列表
        return compressed_docs

55.3. contextual.py #

smartchain/retrievers/contextual.py

# 导入基础检索器基类
from .base import BaseRetriever
# 导入基础文档压缩器基类
from ..document_compressors import BaseDocumentCompressor

# 定义上下文压缩检索器类,继承自BaseRetriever
class ContextualCompressionRetriever(BaseRetriever):
    # 设置类文档字符串,说明该类包装基础检索器和文档压缩器
    """上下文压缩检索器
    包装基础检索器,使用文档压缩器对检索结果进行压缩。
    """

    # 定义初始化方法
    def __init__(
        self,
        base_compressor: BaseDocumentCompressor,
        base_retriever: BaseRetriever,
    ):
        # 方法文档字符串,说明参数
        """初始化上下文压缩检索器
          参数:
            base_compressor: 文档压缩器实例
            base_retriever: 基础检索器实例
        """
        # 调用父类构造函数进行初始化
        super().__init__()
        # 保存文档压缩器对象
        self.base_compressor = base_compressor
        # 保存基础检索器对象
        self.base_retriever = base_retriever

    # 重写获取相关文档的钩子方法
    def _get_relevant_documents(self, query, **kwargs):
        # 方法文档,阐述参数与返回值
        """获取相关文档
          参数:
            query: 查询字符串
            **kwargs: 其他参数
          返回:
            List: 压缩后的文档列表
        """
        # 使用基础检索器进行初步检索,获取原始文档
        docs = self.base_retriever.invoke(query, **kwargs)
        # 如果没有检索到任何文档,则直接返回空列表
        if not docs:
            return []
        # 使用文档压缩器对检索到的文档进行压缩处理
        compressed_docs = self.base_compressor.compress_documents(
            docs,
            query
        )
        # 返回压缩处理后的文档列表
        return list(compressed_docs)

55.4. init.py #

smartchain/retrievers/init.py

from .vector_store import VectorStoreRetriever
from .base import BaseRetriever
from .tfidf import TFIDFRetriever
from .bm25 import BM25Retriever
from .vector import VectorSimilarityRetriever
from .ensemble import EnsembleRetriever
+from .contextual import ContextualCompressionRetriever
__all__ = [
    "VectorStoreRetriever",
    "BaseRetriever",
    "TFIDFRetriever",
    "BM25Retriever",
    "VectorSimilarityRetriever",
+   "EnsembleRetriever",
+   "ContextualCompressionRetriever"
]

55.5. 类 #

55.5.1 类说明 #

类名 类型 作用 关键属性/方法
Document 数据类 存储文档内容和元数据 page_content: 文本内容
metadata: 元数据字典
HuggingFaceEmbeddings 嵌入模型类 使用 HuggingFace 模型生成文本嵌入向量 model_name: 模型名称
embed_query(): 嵌入单个查询
embed_documents(): 批量嵌入文档
Chroma 向量存储类 Chroma 向量数据库封装 embedding_function: 嵌入函数
_collection: Chroma 集合
add_texts(): 添加文本
similarity_search(): 相似度搜索
as_retriever(): 创建检索器
VectorStoreRetriever 检索器类 从向量存储中检索文档 vectorstore: 向量存储实例
search_type: 搜索类型
invoke(): 执行检索
BaseRetriever 抽象基类 定义检索器接口规范 invoke(): 对外调用接口
_get_relevant_documents(): 抽象方法(需子类实现)
ContextualCompressionRetriever 检索器类 包装基础检索器和文档压缩器 base_retriever: 基础检索器实例
base_compressor: 文档压缩器实例
invoke(): 先检索再压缩
_get_relevant_documents(): 核心逻辑
BaseDocumentCompressor 抽象基类 定义文档压缩器接口规范 compress_documents(): 抽象方法(需子类实现)
LLMChainExtractor 文档压缩器类 使用 LLM 提取文档相关内容 llm_chain: LLM 链对象
get_input: 输入转换
from_llm(): 类方法创建实例
compress_documents(): 压缩文档
ChatDeepSeek LLM 模型类 DeepSeek 大语言模型 model: 模型名称
invoke(): 生成回复
PromptTemplate 提示词模板类 格式化提示词模板 from_template(): 创建实例
format(): 格式化模板

55.5.2 类图 #

classDiagram class Document { +str page_content +Dict[str, Any] metadata +__init__(page_content, metadata) } class BaseRetriever { <<abstract>> +__init__(**kwargs) +invoke(query, **kwargs) List[Document] #_get_relevant_documents(query, **kwargs)* List[Document] } class VectorStoreRetriever { -VectorStore vectorstore -str search_type -dict search_kwargs +__init__(vectorstore, search_type, search_kwargs, **kwargs) +invoke(query, **kwargs) List[Document] -_get_relevant_documents(query, **kwargs) List[Document] } class ContextualCompressionRetriever { -BaseDocumentCompressor base_compressor -BaseRetriever base_retriever +__init__(base_compressor, base_retriever) +invoke(query, **kwargs) List[Document] -_get_relevant_documents(query, **kwargs) List[Document] } class Chroma { -Embedding embedding_function -Collection _collection +__init__(collection_name, embedding_function, persist_directory, collection_metadata) +add_texts(texts, metadatas, **kwargs) List[str] +similarity_search(query, k, **kwargs) List[Document] +as_retriever(search_type, search_kwargs, **kwargs) VectorStoreRetriever } class BaseDocumentCompressor { <<abstract>> +compress_documents(documents, query, callbacks)* List[Document] } class LLMChainExtractor { -SimpleChain llm_chain -Callable get_input +__init__(llm_chain, llm, prompt, get_input) +from_llm(cls, llm, prompt, get_input)$ LLMChainExtractor +compress_documents(documents, query, callbacks) List[Document] -_default_get_input(query, doc) dict -_create_chain(prompt_template, llm) SimpleChain } class ChatDeepSeek { -str model -str api_key -OpenAI client +__init__(model, **kwargs) +invoke(input, **kwargs) AIMessage } class HuggingFaceEmbeddings { -str model_name +__init__(model_name, **kwargs) +embed_query(text) List[float] +embed_documents(texts) List[List[float]] } class SimpleChain { -PromptTemplate prompt_template -ChatDeepSeek llm +__init__(prompt_template, llm) +invoke(input_dict, **kwargs) str } BaseRetriever <|-- VectorStoreRetriever : 继承实现 BaseRetriever <|-- ContextualCompressionRetriever : 继承实现 BaseDocumentCompressor <|-- LLMChainExtractor : 继承实现 ContextualCompressionRetriever "1" --> "1" BaseRetriever : 使用基础检索器 ContextualCompressionRetriever "1" --> "1" BaseDocumentCompressor : 使用文档压缩器 LLMChainExtractor "1" --> "1" SimpleChain : 使用LLM链 SimpleChain "1" --> "1" ChatDeepSeek : 使用LLM模型 Chroma "1" --> "1" HuggingFaceEmbeddings : 使用嵌入模型 Chroma "1" --> "1" VectorStoreRetriever : 创建检索器 ContextualCompressionRetriever "1" *-- "0..*" Document : 返回文档列表 note for BaseRetriever "抽象基类\n定义检索器标准接口\n使用模板方法模式" note for ContextualCompressionRetriever "上下文压缩检索器\n先检索后压缩\n提升结果相关性" note for LLMChainExtractor "LLM链提取器\n使用LLM提取文档中相关内容\n过滤不相关文档" note for SimpleChain "简单链对象\n封装Prompt模板和LLM\n实现链式调用"

55.5.3 时序图 #

sequenceDiagram participant Main as 55.LLMChainExtractor.py participant Embed as HuggingFaceEmbeddings participant Chroma as Chroma participant ChromaDB as ChromaDB(外部) participant VSR as VectorStoreRetriever participant CCR as ContextualCompressionRetriever participant LCE as LLMChainExtractor participant Chain as SimpleChain participant LLM as ChatDeepSeek participant Prompt as PromptTemplate Note over Main: 阶段1: 初始化嵌入模型和向量数据库 Main->>Embed: HuggingFaceEmbeddings(model_name, model_kwargs) activate Embed Embed-->>Main: embeddings实例 deactivate Embed Main->>Chroma: Chroma(persist_directory, embedding_function, collection_name, collection_metadata) activate Chroma Chroma->>ChromaDB: PersistentClient(path=persist_directory) ChromaDB-->>Chroma: client实例 Chroma->>ChromaDB: get_or_create_collection(name, metadata) ChromaDB-->>Chroma: collection实例 Chroma-->>Main: chroma_db实例 deactivate Chroma Note over Main: 阶段2: 检查并添加数据 Main->>Chroma: _collection.count() activate Chroma Chroma->>ChromaDB: count() ChromaDB-->>Chroma: count结果 Chroma-->>Main: count结果 deactivate Chroma alt 数据库为空 Main->>Chroma: add_texts(texts, metadatas) activate Chroma Chroma->>Embed: embed_documents(texts) activate Embed Embed-->>Chroma: embedding_values deactivate Embed Chroma->>ChromaDB: upsert(ids, documents, embeddings, metadatas) ChromaDB-->>Chroma: 添加成功 Chroma-->>Main: ids列表 deactivate Chroma end Note over Main: 阶段3: 创建LLM和压缩器 Main->>LLM: ChatDeepSeek(model="deepseek-chat") activate LLM LLM-->>Main: llm实例 deactivate LLM Main->>LCE: LLMChainExtractor.from_llm(llm=llm) activate LCE LCE->>Prompt: PromptTemplate.from_template(CHINESE_EXTRACT_PROMPT) activate Prompt Prompt-->>LCE: prompt_template实例 deactivate Prompt LCE->>LCE: _create_chain(prompt_template, llm) activate LCE LCE->>Chain: SimpleChain(prompt_template, llm) activate Chain Chain-->>LCE: llm_chain实例 deactivate Chain deactivate LCE LCE-->>Main: compressor实例 deactivate LCE Note over Main: 阶段4: 创建基础检索器和压缩检索器 Main->>Chroma: as_retriever(search_type="similarity", search_kwargs={"k": 20}) activate Chroma Chroma->>VSR: VectorStoreRetriever(vectorstore=self, search_type, search_kwargs) activate VSR VSR-->>Chroma: retriever实例 deactivate VSR Chroma-->>Main: base_retriever实例 deactivate Chroma Main->>CCR: ContextualCompressionRetriever(base_compressor=compressor, base_retriever=base_retriever) activate CCR CCR->>CCR: super().__init__() CCR-->>Main: compression_retriever实例 deactivate CCR Note over Main: 阶段5: 执行检索和压缩 Main->>CCR: invoke(query="人工智能") activate CCR CCR->>CCR: _get_relevant_documents(query) activate CCR CCR->>VSR: base_retriever.invoke(query) activate VSR VSR->>VSR: _get_relevant_documents(query) activate VSR VSR->>Chroma: similarity_search(query, k=20) activate Chroma Chroma->>Embed: embed_query(query) activate Embed Embed-->>Chroma: query_embedding deactivate Embed Chroma->>ChromaDB: query(query_embeddings=[query_embedding], n_results=20) activate ChromaDB ChromaDB-->>Chroma: results (ids, documents, metadatas, distances) deactivate ChromaDB Chroma->>Chroma: _results_to_docs_and_scores(results) Chroma-->>VSR: docs (Document列表,最多20个) deactivate Chroma VSR-->>VSR: 返回docs deactivate VSR VSR-->>CCR: docs (Document列表) deactivate VSR CCR->>LCE: base_compressor.compress_documents(docs, query) activate LCE loop 遍历每个文档 LCE->>LCE: _default_get_input(query, doc) LCE->>LCE: {"question": query, "context": doc.page_content} LCE->>Chain: llm_chain.invoke(_input) activate Chain Chain->>Prompt: prompt_template.format(**input_dict) activate Prompt Prompt-->>Chain: formatted_prompt (格式化后的提示词) deactivate Prompt Chain->>LLM: llm.invoke(formatted_prompt) activate LLM LLM-->>Chain: response (AIMessage) deactivate LLM Chain->>Chain: 提取response.content Chain->>Chain: 检查是否为"NO_OUTPUT"或空 Chain-->>LCE: output (提取的相关内容或空字符串) deactivate Chain alt 有输出内容 LCE->>LCE: Document(page_content=output, metadata=doc.metadata) LCE->>LCE: compressed_docs.append(compressed_doc) else 无输出内容 LCE->>LCE: 跳过该文档(continue) end end LCE-->>CCR: compressed_docs (压缩后的Document列表) deactivate LCE CCR-->>CCR: 返回compressed_docs deactivate CCR CCR-->>Main: results (Document列表) deactivate CCR Note over Main: 阶段6: 输出结果 Main->>Main: 遍历results并打印

55.5.4 调用过程 #

阶段 1:初始化嵌入模型和向量数据库

embeddings = HuggingFaceEmbeddings(
    model_name=model_path,
    model_kwargs={"device": "cpu"}
)

chroma_db = Chroma(
    persist_directory="chroma_db",
    embedding_function=embeddings,
    collection_name="test",
    collection_metadata={"hnsw:space": "cosine"}
)
  • 创建 HuggingFaceEmbeddings 实例
  • 创建 Chroma 向量数据库实例,连接持久化存储

阶段 2:检查并添加数据

if not chroma_db._collection.count():
    chroma_db.add_texts(texts, metadatas)
  • 检查数据库是否为空
  • 如果为空,批量添加文本和元数据到数据库

阶段 3:创建 LLM 和压缩器

llm = ChatDeepSeek(model="deepseek-chat")
compressor = LLMChainExtractor.from_llm(llm=llm)

调用链:

  1. ChatDeepSeek.__init__() 创建 LLM 实例
  2. LLMChainExtractor.from_llm() 类方法
    • 调用 cls(llm=llm, prompt=None, get_input=None)
  3. LLMChainExtractor.__init__() 方法
    • 创建默认的 PromptTemplate(使用 CHINESE_EXTRACT_PROMPT)
    • 调用 _create_chain(prompt_template, llm) 创建 SimpleChain 对象
    • 设置 get_input 函数(默认使用 _default_get_input)

阶段 4:创建基础检索器和压缩检索器

base_retriever = chroma_db.as_retriever(search_type="similarity", search_kwargs={"k": 20})
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=base_retriever
)

调用链:

  1. Chroma.as_retriever() 创建 VectorStoreRetriever 实例
  2. ContextualCompressionRetriever.__init__() 方法
    • 调用 super().__init__() 初始化基类
    • 保存 base_compressor 和 base_retriever

阶段 5:执行检索和压缩

results = compression_retriever.invoke(query)

调用链:

  1. invoke() 方法(继承自 BaseRetriever)
    • 调用 _get_relevant_documents(query)
  2. _get_relevant_documents() 方法(核心检索逻辑)

    步骤 1:基础检索

    docs = self.base_retriever.invoke(query, **kwargs)
    • 调用 VectorStoreRetriever.invoke() 从向量数据库检索文档(返回 k=20 个文档)

    步骤 2:文档压缩

    compressed_docs = self.base_compressor.compress_documents(docs, query)

    内部处理流程(LLMChainExtractor.compress_documents()):

    for doc in documents:
        _input = self.get_input(query, doc)  # {"question": query, "context": doc.page_content}
        output = self.llm_chain.invoke(_input)  # 调用LLM提取相关内容
        if not output or len(output) == 0:
            continue  # 跳过不相关文档
        compressed_doc = Document(page_content=output, metadata=doc.metadata)
        compressed_docs.append(compressed_doc)

    LLM 链调用流程(SimpleChain.invoke()):

    formatted_prompt = self.prompt_template.format(**input_dict)
    response = self.llm.invoke(formatted_prompt)
    text = response.content.strip()
    if text == "NO_OUTPUT" or text == "":
        return ""  # 文档不相关
    return text  # 返回提取的相关内容

    步骤 3:返回压缩后的文档列表

    • 返回过滤和压缩后的 Document 列表

阶段 6:结果输出

print(f"查询:{query}")
print(f"共检索到 {len(results)} 个文档")
for i, doc in enumerate(results):
    print(f"检索结果{i}:{doc.page_content}")
    print(f"  元数据:{doc.metadata}\n")
  • 遍历 results(压缩后的 Document 列表)
  • 打印每个文档的 page_content 和 metadata

55.5.5 文档压缩流程 #

LLMChainExtractor 工作原理

  1. 输入:检索到的文档列表 + 查询字符串
  2. 处理:对每个文档
    • 构造提示词:将查询和文档内容组合成提示词
    • LLM 提取:调用 LLM 提取文档中与查询相关的部分
    • 过滤:如果 LLM 返回 "NO_OUTPUT" 或空字符串,则过滤该文档
  3. 输出:压缩后的文档列表(只包含相关内容)

默认提示词模板

CHINESE_EXTRACT_PROMPT = """给定以下问题和上下文,提取上下文中与回答问题相关的任何部分(保持原样)。如果上下文都不相关,返回 NO_OUTPUT。

记住,*不要*编辑提取的上下文部分。

> 问题:{question}
> 上下文:
>>>
{context}
>>>
提取的相关部分:"""

56.EmbeddingsFilter #

本节将介绍如何使用 EmbeddingsFilter 实现向量检索结果的相关性过滤与压缩。
EmbeddingsFilter 属于文档压缩器的一种,常用于“上下文压缩型检索器”场景 —— 即先使用向量/关键词等手段召回一定数量的候选文档,再利用 EmbeddingsFilter 对这些候选文档根据语义相关性进行二次筛选,提升下游LLM问答的准确性与精简性。

核心思路如下:

  • 先召回一批候选文档(如向量数据库筛选top50)。
  • 使用 EmbeddingFilter 计算每个文档与查询(query)的语义相似度。
  • 可以选择只保留前k个最相关文档,或者设置相似度阈值只保留相关性高于阈值的文档。
  • 支持自定义相似度函数,默认使用余弦相似度。

典型应用流程如下:

  1. 初始化嵌入模型和Chroma向量数据库;
  2. 构建 EmbeddingsFilter,传入 embeddings、top-k 或相关性阈值等参数;
  3. 可结合 ContextualCompressionRetriever,实现“检索 + 嵌入压缩”端到端问答。

具体实现与调优建议:

  • 若侧重精度,建议适当调小k或设置较高阈值,只让最相关文档进入大模型;
  • 若注重召回广度,可放宽阈值或加大k值。
  • 支持自定义 embedding 模型、相似度函数等,方便适配不同业务场景。

56.1. 56.EmbeddingsFilter.py #

56.EmbeddingsFilter.py

#from langchain_chroma import Chroma
#from langchain_huggingface import HuggingFaceEmbeddings
#from langchain_classic.retrievers.contextual_compression import ContextualCompressionRetriever
#from langchain_classic.retrievers.document_compressors.embeddings_filter import EmbeddingsFilter
#from langchain_deepseek import ChatDeepSeek

# 导入 Chroma 向量数据库类
from smartchain.vectorstores import Chroma
# 导入 HuggingFaceEmbeddings 嵌入模型类
from smartchain.embeddings import HuggingFaceEmbeddings
# 导入上下文压缩型检索器
from smartchain.retrievers import ContextualCompressionRetriever
# 导入 LLMChainExtractor 文档压缩器
from smartchain.document_compressors import EmbeddingsFilter
# 导入 ChatDeepSeek LLM 模型
from smartchain.chat_models import ChatDeepSeek

# 设置本地嵌入模型的路径
model_path = "C:/Users/Administrator/.cache/modelscope/hub/models/sentence-transformers/all-MiniLM-L6-v2"
# 初始化嵌入模型,指定模型路径与推理设备
embeddings = HuggingFaceEmbeddings(
    model_name=model_path,
    model_kwargs={"device": "cpu"}
)

# 初始化 Chroma 向量数据库实例,指定持久化目录、嵌入函数、集合名称和元数据
chroma_db = Chroma(
    persist_directory="chroma_db",
    embedding_function=embeddings,
    collection_name="test",
    collection_metadata={"hnsw:space": "cosine"}
)

# 检查数据库是否为空,若为空批量插入初始文本及其元数据
if not chroma_db._collection.count():
    # 定义插入的文本列表
    texts = [
        "人工智能(AI)是一种让机器模拟人类智能行为的技术。",
        "深度学习是人工智能的一个重要分支,通过多层神经网络学习数据。",
        "ChatGPT是OpenAI开发的强大自然语言模型。",
        "向量数据库可以高效地存储和检索文本的嵌入向量。",
        "机器人学结合了人工智能和机械工程,推动自动化发展。",
        "AI可以辅助医生进行医学影像分析。",
        "大模型在对话、问答、摘要等领域不断取得突破。",
        "知识库问答系统常用于企业信息检索场景。",
    ]
    # 对应的元数据列表
    metadatas = [
        {"topic": "ai"},
        {"topic": "ai"},
        {"topic": "nlp"},
        {"topic": "vector_db"},
        {"topic": "robotics"},
        {"topic": "healthcare"},
        {"topic": "llm"},
        {"topic": "retrieval"},
    ]
    # 向向量数据库批量插入文本和元数据
    chroma_db.add_texts(texts, metadatas)

# 创建 ChatDeepSeek LLM 对象,使用 deepseek-chat 模型
llm = ChatDeepSeek(model="deepseek-chat")

# 创建基础检索器,指定检索类型为 similarity 和返回数量 k=20
# 建议 k 取较大值,因为压缩器后续会过滤部分不相关的结果
base_retriever = chroma_db.as_retriever(search_type="similarity", search_kwargs={"k": 20})

compressor = EmbeddingsFilter(embeddings=embeddings,k=5)

# 创建上下文压缩检索器,结合基础检索器与文档压缩器
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=base_retriever
)

# 指定检索查询内容(支持中文或英文)
query = "人工智能"  # 可以更换为任意中英文问题
# 调用上下文压缩检索器进行检索
results = compression_retriever.invoke(query)
# 打印查询内容
print(f"查询:{query}")
# 打印检索到的文档数量
print(f"共检索到 {len(results)} 个文档")
# 遍历输出每条检索结果及其元数据
for i, doc in enumerate(results):
    print(f"检索结果{i}:{doc.page_content}")
    print(f"  元数据:{doc.metadata}\n")

56.2. document_compressors.py #

smartchain/document_compressors.py

# 从abc模块导入抽象基类支持
from abc import ABC, abstractmethod
+import numpy as np
# 定义一个默认的中文抽取Prompt模板
CHINESE_EXTRACT_PROMPT = """给定以下问题和上下文,提取上下文中与回答问题相关的任何部分(保持原样)。如果上下文都不相关,返回 NO_OUTPUT。

记住,*不要*编辑提取的上下文部分。

> 问题:{question}
> 上下文:
>>>
{context}
>>>
提取的相关部分:"""

# 定义文档压缩器的抽象基类
class BaseDocumentCompressor(ABC):
    # 设置文档压缩器的类说明文档
    """文档压缩器抽象基类
    用于对检索到的文档进行后处理,例如提取相关内容、过滤不相关文档等。
    """

    # 声明子类必须实现的抽象方法
    @abstractmethod
    def compress_documents(
        self,
        documents,
        query: str,
        callbacks=None,
    ):
        # 方法说明文档
        """压缩文档
          参数:
            documents: 检索到的文档列表
            query: 查询字符串
            callbacks: 可选的回调函数
          返回:
            Sequence: 压缩后的文档列表
        """
        # 抽象方法体不实现
        pass

# 定义LLM链文档提取器类,继承自基类
class LLMChainExtractor(BaseDocumentCompressor):
    # 设置LLM链提取器的类说明文档
    """LLM链提取器
    使用 LLM 链从文档中提取与查询相关的部分。
    """
    # 构造函数
    def __init__(
        self,
        llm_chain=None,
        llm=None,
        prompt=None,
        get_input=None,
    ):
        # 如果传入llm_chain则直接赋值
        if llm_chain is not None:
            self.llm_chain = llm_chain
        # 如果传入llm而未传llm_chain,则根据参数生成链
        elif llm is not None:
            # 延迟导入PromptTemplate工具
            from .prompts import PromptTemplate
            # 如果未自定义提示词模板,使用默认中文抽取Prompt
            if prompt is None:
                prompt_template = PromptTemplate.from_template(CHINESE_EXTRACT_PROMPT)
            else:
                prompt_template = prompt
            # 用私有方法构造链
            self.llm_chain = self._create_chain(prompt_template, llm)
        # 两者都没给时报错
        else:
            raise ValueError("必须提供 llm_chain 或 llm 参数")
        # 设置get_input函数(用于将query/doc打包成模型输入),可以自定义
        self.get_input = get_input or self._default_get_input

    # 默认输入转换函数,把query和doc打包成dict
    def _default_get_input(self, query: str, doc) -> dict:
        # 返回一个包含问题和文本内容的字典
        return {"question": query, "context": doc.page_content}

    # 私有方法,实现用prompt_template和llm创建简单链对象
    def _create_chain(self, prompt_template, llm):
        # 定义内部类SimpleChain
        class SimpleChain:
            # 构造函数,保存prompt模板和llm对象
            def __init__(self, prompt_template, llm):
                self.prompt_template = prompt_template
                self.llm = llm
            # 实现invoke方法,将输入dict格式化prompt并调用llm
            def invoke(self, input_dict, **kwargs):
                # 格式化prompt
                formatted_prompt = self.prompt_template.format(**input_dict)
                # 用llm推理,得到响应
                response = self.llm.invoke(formatted_prompt, **kwargs)
                # 如果响应有content属性,则用其内容
                if hasattr(response, 'content'):
                    text = response.content
                # 如果响应直接为字符串
                elif isinstance(response, str):
                    text = response
                # 否则强转为字符串
                else:
                    text = str(response)
                # 去掉文本首尾空格
                text = text.strip()
                # 若内容为NO_OUTPUT或空,则返回空字符串
                if text == "NO_OUTPUT" or text == "":
                    return ""
                # 返回抽取的文本
                return text
        # 返回SimpleChain对象实例
        return SimpleChain(prompt_template, llm)

    # 类方法,支持通过llm快速初始化LLMChainExtractor
    @classmethod
    def from_llm(
        cls,
        llm,
        prompt=None,
        get_input=None,
    ):
        # 返回初始化好的提取器对象
        return cls(
            llm=llm,
            prompt=prompt,
            get_input=get_input
        )

    # 实现文档压缩主接口
    def compress_documents(
        self,
        documents,
        query: str,
        callbacks=None,
    ):
        # 用于存储压缩结果的列表
        compressed_docs = []
        # 遍历所有输入文档
        for doc in documents:
            # 构造当前doc与query的输入结构
            _input = self.get_input(query, doc)
            # 用llm链抽取相关内容
            output = self.llm_chain.invoke(_input)
            # 如果没抽取到任何内容则跳过该文档
            if not output or len(output) == 0:
                continue
            # 动态导入Document类
            from .vectorstores import Document
            # 用抽取得到内容与原有元数据构造新文档
            compressed_doc = Document(
                page_content=output,
                metadata=doc.metadata if hasattr(doc, 'metadata') else {}
            )
            # 加入压缩结果列表
            compressed_docs.append(compressed_doc)
        # 返回压缩结果文档列表
        return compressed_docs


# 定义余弦相似度计算函数
+def cosine_similarity(query_embeddings, doc_embeddings):
    # 将查询嵌入转换为numpy数组
+   query_emb = np.array(query_embeddings)
    # 将文档嵌入转换为numpy数组
+   doc_embs = np.array(doc_embeddings)
    # 若查询为一维,转换为二维以便向量批量操作
+   if query_emb.ndim == 1:
+       query_emb = query_emb.reshape(1, -1)
    # 计算查询和所有文档嵌入的点积
+   dot_product = np.dot(query_emb, doc_embs.T)
    # 计算查询向量的范数
+   query_norm = np.linalg.norm(query_emb, axis=1, keepdims=True)
    # 计算所有文档嵌入的范数,并转置方便后续相乘
+   doc_norms = np.linalg.norm(doc_embs, axis=1, keepdims=True).T
    # 使用点积/范数乘积得到余弦相似度
+   similarity = dot_product / (query_norm * doc_norms)
    # 将因除零产生的nan、正负无穷转为0
+   similarity = np.nan_to_num(similarity, nan=0.0, posinf=0.0, neginf=0.0)
    # 返回第一行的相似度结果(只支持一个query)
+   return similarity[0]

# 定义嵌入过滤器类,继承自BaseDocumentCompressor
+class EmbeddingsFilter(BaseDocumentCompressor):
    # 类文档字符串,说明该类用途
+   """嵌入过滤器

+   使用嵌入模型计算查询与文档的相似度,根据相似度过滤文档。
+   """
    # 构造方法定义
+   def __init__(
+       self,
+       embeddings,
+       similarity_fn=None,
+       k = 20,
+       similarity_threshold = None,
+   ):
        # 参数校验,如k和阈值都没传则报错
+       if k is None and similarity_threshold is None:
+           raise ValueError("必须指定 k 或 similarity_threshold 之一")
        # 存储嵌入对象
+       self.embeddings = embeddings
        # 如果没传相似度函数,默认使用余弦相似度
+       self.similarity_fn = similarity_fn or cosine_similarity
        # 存储k值(返回多少最相似文档)
+       self.k = k
        # 存储相似度阈值
+       self.similarity_threshold = similarity_threshold

    # 文档压缩/过滤主入口
+   def compress_documents(
+       self,
+       documents,
+       query: str,
+       callbacks=None,
+   ):
        # 若输入文档为空,直接返回空列表
+       if not documents:
+           return []
        # 收集每个文档的文本内容
+       doc_texts = [doc.page_content if hasattr(doc, 'page_content') else str(doc) for doc in documents]
        # 计算所有文档嵌入
+       doc_embeddings = self.embeddings.embed_documents(doc_texts)
        # 计算查询嵌入
+       query_embedding = self.embeddings.embed_query(query)
        # 计算查询和所有文档的相似度分数
+       similarities = self.similarity_fn([query_embedding], doc_embeddings)
        # 初始包含所有文档的索引
+       included_idxs = np.arange(len(documents))
        # 如果指定了k值,保留相似度最高的k篇文档
+       if self.k is not None:
+           sorted_indices = np.argsort(similarities)[::-1]
+           included_idxs = sorted_indices[:self.k]
        # 如果指定了相似度阈值,进一步只保留超过阈值的文档
+       if self.similarity_threshold is not None:
+           similar_enough = similarities[included_idxs] > self.similarity_threshold
+           included_idxs = included_idxs[similar_enough]
        # 返回最终保留的文档列表
+       return [documents[i] for i in included_idxs]

56.3. 类 #

56.3.1 类说明 #

类名 类型 作用 关键属性/方法
Document 数据类 存储文档内容和元数据 page_content: 文档文本内容
metadata: 元数据字典
HuggingFaceEmbeddings 嵌入模型类 使用 HuggingFace 模型生成文本嵌入向量 model_name: 模型名称
embed_query(): 嵌入单个查询文本
embed_documents(): 批量嵌入文档文本
Chroma 向量存储类 Chroma 向量数据库封装 embedding_function: 嵌入函数
_collection: Chroma 集合
add_texts(): 添加文本到数据库
similarity_search(): 相似度搜索
as_retriever(): 创建检索器
VectorStoreRetriever 检索器类 从向量存储中检索文档 vectorstore: 向量存储实例
search_type: 搜索类型
invoke(): 执行检索
BaseRetriever 抽象基类 定义检索器接口规范 invoke(): 对外调用接口
_get_relevant_documents(): 抽象方法(需子类实现)
ContextualCompressionRetriever 检索器类 包装基础检索器和文档压缩器 base_retriever: 基础检索器实例
base_compressor: 文档压缩器实例
invoke(): 执行检索(先检索后压缩)
_get_relevant_documents(): 核心检索逻辑
BaseDocumentCompressor 抽象基类 定义文档压缩器接口规范 compress_documents(): 抽象方法(需子类实现)
EmbeddingsFilter 文档压缩器类 使用嵌入相似度过滤文档 embeddings: 嵌入模型实例
similarity_fn: 相似度计算函数
k: 返回文档数量
similarity_threshold: 相似度阈值
compress_documents(): 压缩文档(基于相似度过滤)
cosine_similarity 函数 计算余弦相似度 计算查询向量与文档向量的余弦相似度

56.3.2 类图 #

classDiagram class Document { +str page_content +Dict[str, Any] metadata +__init__(page_content, metadata) } class BaseRetriever { <<abstract>> +__init__(**kwargs) +invoke(query, **kwargs) List[Document] #_get_relevant_documents(query, **kwargs)* List[Document] } class VectorStoreRetriever { -VectorStore vectorstore -str search_type -dict search_kwargs +__init__(vectorstore, search_type, search_kwargs, **kwargs) +invoke(query, **kwargs) List[Document] -_get_relevant_documents(query, **kwargs) List[Document] } class ContextualCompressionRetriever { -BaseDocumentCompressor base_compressor -BaseRetriever base_retriever +__init__(base_compressor, base_retriever) +invoke(query, **kwargs) List[Document] -_get_relevant_documents(query, **kwargs) List[Document] } class Chroma { -Embedding embedding_function -Collection _collection +__init__(collection_name, embedding_function, persist_directory, collection_metadata) +add_texts(texts, metadatas, **kwargs) List[str] +similarity_search(query, k, **kwargs) List[Document] +as_retriever(search_type, search_kwargs, **kwargs) VectorStoreRetriever } class BaseDocumentCompressor { <<abstract>> +compress_documents(documents, query, callbacks)* List[Document] } class EmbeddingsFilter { -Embedding embeddings -Callable similarity_fn -int k -float similarity_threshold +__init__(embeddings, similarity_fn, k, similarity_threshold) +compress_documents(documents, query, callbacks) List[Document] } class HuggingFaceEmbeddings { -str model_name +__init__(model_name, **kwargs) +embed_query(text) List[float] +embed_documents(texts) List[List[float]] } class cosine_similarity { <<function>> +cosine_similarity(query_embeddings, doc_embeddings) ndarray } BaseRetriever <|-- VectorStoreRetriever : 继承实现 BaseRetriever <|-- ContextualCompressionRetriever : 继承实现 BaseDocumentCompressor <|-- EmbeddingsFilter : 继承实现 ContextualCompressionRetriever "1" --> "1" BaseRetriever : 使用基础检索器 ContextualCompressionRetriever "1" --> "1" BaseDocumentCompressor : 使用文档压缩器 EmbeddingsFilter "1" --> "1" HuggingFaceEmbeddings : 使用嵌入模型 EmbeddingsFilter ..> cosine_similarity : 调用 Chroma "1" --> "1" HuggingFaceEmbeddings : 使用嵌入模型 Chroma "1" --> "1" VectorStoreRetriever : 创建检索器 ContextualCompressionRetriever "1" *-- "0..*" Document : 返回文档列表 note for BaseRetriever "抽象基类\n定义检索器标准接口\n使用模板方法模式" note for ContextualCompressionRetriever "上下文压缩检索器\n先检索后压缩\n提升结果相关性" note for EmbeddingsFilter "嵌入过滤器\n基于嵌入相似度过滤文档\n支持k值和阈值两种过滤方式" note for HuggingFaceEmbeddings "HuggingFace嵌入模型\n用于计算文本嵌入向量"

56.3.3 时序图 #

sequenceDiagram participant Main as 56.EmbeddingsFilter.py participant Embed as HuggingFaceEmbeddings participant Chroma as Chroma participant ChromaDB as ChromaDB(外部) participant VSR as VectorStoreRetriever participant CCR as ContextualCompressionRetriever participant EF as EmbeddingsFilter participant Cosine as cosine_similarity Note over Main: 阶段1: 初始化嵌入模型和向量数据库 Main->>Embed: HuggingFaceEmbeddings(model_name, model_kwargs) activate Embed Embed-->>Main: embeddings实例 deactivate Embed Main->>Chroma: Chroma(persist_directory, embedding_function, collection_name, collection_metadata) activate Chroma Chroma->>ChromaDB: PersistentClient(path=persist_directory) ChromaDB-->>Chroma: client实例 Chroma->>ChromaDB: get_or_create_collection(name, metadata) ChromaDB-->>Chroma: collection实例 Chroma-->>Main: chroma_db实例 deactivate Chroma Note over Main: 阶段2: 检查并添加数据 Main->>Chroma: _collection.count() activate Chroma Chroma->>ChromaDB: count() ChromaDB-->>Chroma: count结果 Chroma-->>Main: count结果 deactivate Chroma alt 数据库为空 Main->>Chroma: add_texts(texts, metadatas) activate Chroma Chroma->>Embed: embed_documents(texts) activate Embed Embed-->>Chroma: embedding_values deactivate Embed Chroma->>ChromaDB: upsert(ids, documents, embeddings, metadatas) ChromaDB-->>Chroma: 添加成功 Chroma-->>Main: ids列表 deactivate Chroma end Note over Main: 阶段3: 创建基础检索器和压缩器 Main->>Chroma: as_retriever(search_type="similarity", search_kwargs={"k": 20}) activate Chroma Chroma->>VSR: VectorStoreRetriever(vectorstore=self, search_type, search_kwargs) activate VSR VSR-->>Chroma: retriever实例 deactivate VSR Chroma-->>Main: base_retriever实例 deactivate Chroma Main->>EF: EmbeddingsFilter(embeddings=embeddings, k=5) activate EF EF->>EF: 验证k和similarity_threshold参数 EF-->>Main: compressor实例 deactivate EF Note over Main: 阶段4: 创建压缩检索器 Main->>CCR: ContextualCompressionRetriever(base_compressor=compressor, base_retriever=base_retriever) activate CCR CCR->>CCR: super().__init__() CCR-->>Main: compression_retriever实例 deactivate CCR Note over Main: 阶段5: 执行检索和压缩 Main->>CCR: invoke(query="人工智能") activate CCR CCR->>VSR: base_retriever.invoke(query) activate VSR VSR->>Chroma: similarity_search(query, k=20) activate Chroma Chroma->>Embed: embed_query(query) activate Embed Embed-->>Chroma: query_embedding deactivate Embed Chroma->>ChromaDB: query(query_embeddings=[query_embedding], n_results=20) ChromaDB-->>Chroma: results (ids, documents, metadatas, distances) Chroma->>Chroma: _results_to_docs_and_scores(results) Chroma-->>VSR: docs (Document列表,最多20个) deactivate Chroma VSR-->>CCR: docs (Document列表) deactivate VSR CCR->>EF: base_compressor.compress_documents(docs, query) activate EF EF->>EF: 提取doc.page_content EF->>Embed: embed_documents(doc_texts) activate Embed Embed-->>EF: doc_embeddings (文档嵌入向量列表) deactivate Embed EF->>Embed: embed_query(query) activate Embed Embed-->>EF: query_embedding (查询嵌入向量) deactivate Embed EF->>Cosine: similarity_fn([query_embedding], doc_embeddings) activate Cosine Cosine->>Cosine: 计算点积和范数 Cosine->>Cosine: similarity = dot_product / (query_norm * doc_norms) Cosine-->>EF: similarities数组 deactivate Cosine EF->>EF: np.argsort(similarities)[::-1][:k] EF->>EF: 取前k=5个最相似的文档 alt 如果设置了similarity_threshold EF->>EF: 进一步过滤相似度>阈值的文档 end EF->>EF: [documents[i] for i in included_idxs] EF-->>CCR: compressed_docs (过滤后的Document列表) deactivate EF CCR-->>Main: results (Document列表) deactivate CCR Note over Main: 阶段6: 输出结果 Main->>Main: 遍历results并打印

56.3.4 调用过程 #

阶段 1:初始化嵌入模型和向量数据库

embeddings = HuggingFaceEmbeddings(
    model_name=model_path,
    model_kwargs={"device": "cpu"}
)

chroma_db = Chroma(
    persist_directory="chroma_db",
    embedding_function=embeddings,
    collection_name="test",
    collection_metadata={"hnsw:space": "cosine"}
)
  • 创建 HuggingFaceEmbeddings 实例
  • 创建 Chroma 向量数据库实例,连接持久化存储

阶段 2:检查并添加数据

if not chroma_db._collection.count():
    chroma_db.add_texts(texts, metadatas)
  • 检查数据库是否为空
  • 如果为空,批量添加文本和元数据到数据库

阶段 3:创建基础检索器和压缩器

base_retriever = chroma_db.as_retriever(search_type="similarity", search_kwargs={"k": 20})
compressor = EmbeddingsFilter(embeddings=embeddings, k=5)

调用链:

  1. Chroma.as_retriever() 创建 VectorStoreRetriever 实例
  2. EmbeddingsFilter.__init__() 方法
    • 验证参数:必须指定 k 或 similarity_threshold 之一
    • 保存 embeddings、similarity_fn(默认使用 cosine_similarity)、k、similarity_threshold

阶段 4:创建压缩检索器

compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=base_retriever
)
  • 创建 ContextualCompressionRetriever 实例
  • 保存 base_compressor 和 base_retriever

阶段 5:执行检索和压缩

results = compression_retriever.invoke(query)

调用链:

  1. invoke() 方法(继承自 BaseRetriever)
    • 调用 _get_relevant_documents(query)
  2. _get_relevant_documents() 方法(核心检索逻辑)

    步骤 1:基础检索

    docs = self.base_retriever.invoke(query, **kwargs)
    • 调用 VectorStoreRetriever.invoke() 从向量数据库检索文档(返回 k=20 个文档)

    步骤 2:文档压缩

    compressed_docs = self.base_compressor.compress_documents(docs, query)

    内部处理流程(EmbeddingsFilter.compress_documents()):

    步骤 2.1:计算嵌入向量

    doc_texts = [doc.page_content for doc in documents]
    doc_embeddings = self.embeddings.embed_documents(doc_texts)  # 批量计算文档嵌入
    query_embedding = self.embeddings.embed_query(query)  # 计算查询嵌入

    步骤 2.2:计算相似度

    similarities = self.similarity_fn([query_embedding], doc_embeddings)

    内部调用 cosine_similarity() 函数:

    dot_product = np.dot(query_emb, doc_embs.T)
    query_norm = np.linalg.norm(query_emb, axis=1, keepdims=True)
    doc_norms = np.linalg.norm(doc_embs, axis=1, keepdims=True).T
    similarity = dot_product / (query_norm * doc_norms)

    步骤 2.3:过滤文档

    included_idxs = np.arange(len(documents))
    if self.k is not None:
        sorted_indices = np.argsort(similarities)[::-1]  # 降序排序
        included_idxs = sorted_indices[:self.k]  # 取前k个
    if self.similarity_threshold is not None:
        similar_enough = similarities[included_idxs] > self.similarity_threshold
        included_idxs = included_idxs[similar_enough]  # 进一步过滤

    步骤 2.4:返回过滤后的文档

    return [documents[i] for i in included_idxs]

    步骤 3:返回压缩后的文档列表

    • 返回过滤后的 Document 列表(最多 k=5 个)

阶段 6:结果输出

print(f"查询:{query}")
print(f"共检索到 {len(results)} 个文档")
for i, doc in enumerate(results):
    print(f"检索结果{i}:{doc.page_content}")
    print(f"  元数据:{doc.metadata}\n")
  • 遍历 results(过滤后的 Document 列表)
  • 打印每个文档的 page_content 和 metadata

56.3.5 EmbeddingsFilter 工作原理 #

过滤机制

EmbeddingsFilter 使用嵌入相似度对文档进行二次过滤:

  1. 计算相似度:对检索到的文档重新计算与查询的嵌入相似度
  2. 排序和筛选:
    • 如果指定了 k:保留相似度最高的 k 个文档
    • 如果指定了 similarity_threshold:只保留相似度超过阈值的文档
    • 两者可以同时使用:先按 k 筛选,再按阈值过滤

为什么需要二次过滤?

  1. 向量数据库的相似度计算可能不够精确
  2. 使用不同的嵌入模型可以更准确地评估相关性
  3. 可以设置更严格的相似度阈值,提高结果质量

余弦相似度公式

对于查询向量 $\vec{q}$ 和文档向量 $\vec{d}$:

$$\cos(\theta) = \frac{\vec{q} \cdot \vec{d}}{|\vec{q}| \times |\vec{d}|} = \frac{\sum_{i=1}^{n} q_i \times d_i}{\sqrt{\sum_{i=1}^{n} q_i^2} \times \sqrt{\sum_{i=1}^{n} d_i^2}}$$

相似度范围:$[-1, 1]$,值越大越相似。

56.3.6 EmbeddingsFilter vs LLMChainExtractor #

特性 EmbeddingsFilter LLMChainExtractor
过滤方式 基于嵌入相似度 基于 LLM 语义理解
计算成本 较低(向量计算) 较高(LLM 推理)
速度 快 较慢
准确性 中等(依赖嵌入质量) 高(LLM 理解能力强)
文档修改 不修改文档内容 提取相关内容,可能修改内容
适用场景 快速过滤、大批量文档 需要精确提取、小批量文档

57.CrossEncoderReranker #

  • CrossEncoder

CrossEncoderReranker(交叉编码器重排器)是一种基于大模型或深度学习模型的文档相关性重排序方法。它通常在初步召回了一批候选文档后,对文档与查询的匹配度做更精细的打分和排序,从而过滤掉不相关的内容,仅保留与查询最相关的若干条信息。

典型流程如下:

  1. 基础检索器(如向量数据库)召回一批与查询(query)内容相似的文档,通常数量较多(例如 k=20)。
  2. CrossEncoderReranker 将候选文档与查询两两拼接,输入到交叉编码器模型(如 BAAI/bge-reranker-base)中,获得每个文档与查询的相关性分数。
  3. 按分数降序排列,仅返回 top_n 个最相关的文档。

这种双阶段(先召回再重排)的方法,兼顾了“检索速度”与“结果相关性”。其中,基础召回用较快的向量检索,重排用较精细的神经网络模型。该方法常见于智能问答、知识检索、RAG等自然语言应用。

代码结构要点:

  • CrossEncoderReranker 通常作为 ContextualCompressionRetriever 的 base_compressor 参数,负责文档的重排与压缩。
  • 需要配合如 HuggingFaceCrossEncoder 这样的模型,使用如 BGE-Reranker 等开箱即用的重排权重文件。
  • 可以设置 top_n 控制最终输出的文档数,以防止无关内容干扰下游任务。

使用场景举例:

  • 你检索了20条可能相关的知识库文档,但只希望下游看到与提问“最紧密相关”的3条,此时就应在检索管道中加入CrossEncoderReranker。
  • 特别适用于多轮问答、长文档检索等上下文相关性要求高的场合。

这样可以大幅提升检索相关性的精度和用户体验。

57.1. 57.CrossEncoderReranker.py #

57.CrossEncoderReranker.py

#from langchain_chroma import Chroma
#from langchain_huggingface import HuggingFaceEmbeddings
#from langchain_classic.retrievers.contextual_compression import ContextualCompressionRetriever
#from langchain_classic.retrievers.document_compressors import CrossEncoderReranker
#from langchain_community.cross_encoders import HuggingFaceCrossEncoder
#from langchain_deepseek import ChatDeepSeek

# 导入 Chroma 向量数据库类
from smartchain.vectorstores import Chroma
# 导入 HuggingFaceEmbeddings 嵌入模型类
from smartchain.embeddings import HuggingFaceEmbeddings
# 导入上下文压缩型检索器
from smartchain.retrievers import ContextualCompressionRetriever
# 导入 CrossEncoderReranker 文档压缩器
from smartchain.document_compressors import CrossEncoderReranker
# 导入 HuggingFaceCrossEncoder 交叉编码器
from smartchain.cross_encoders import HuggingFaceCrossEncoder
# 导入 ChatDeepSeek LLM 模型
from smartchain.chat_models import ChatDeepSeek

# 设置本地嵌入模型的路径
model_path = "C:/Users/Administrator/.cache/modelscope/hub/models/sentence-transformers/all-MiniLM-L6-v2"
# 初始化嵌入模型,指定模型路径与推理设备
embeddings = HuggingFaceEmbeddings(
    model_name=model_path,
    model_kwargs={"device": "cpu"}
)

# 初始化 Chroma 向量数据库实例,指定持久化目录、嵌入函数、集合名称和元数据
chroma_db = Chroma(
    persist_directory="chroma_db",
    embedding_function=embeddings,
    collection_name="test",
    collection_metadata={"hnsw:space": "cosine"}
)

# 检查数据库是否为空,若为空批量插入初始文本及其元数据
if not chroma_db._collection.count():
    # 定义插入的文本列表
    texts = [
        "人工智能(AI)是一种让机器模拟人类智能行为的技术。",
        "深度学习是人工智能的一个重要分支,通过多层神经网络学习数据。",
        "ChatGPT是OpenAI开发的强大自然语言模型。",
        "向量数据库可以高效地存储和检索文本的嵌入向量。",
        "机器人学结合了人工智能和机械工程,推动自动化发展。",
        "AI可以辅助医生进行医学影像分析。",
        "大模型在对话、问答、摘要等领域不断取得突破。",
        "知识库问答系统常用于企业信息检索场景。",
    ]
    # 对应的元数据列表
    metadatas = [
        {"topic": "ai"},
        {"topic": "ai"},
        {"topic": "nlp"},
        {"topic": "vector_db"},
        {"topic": "robotics"},
        {"topic": "healthcare"},
        {"topic": "llm"},
        {"topic": "retrieval"},
    ]
    # 向向量数据库批量插入文本和元数据
    chroma_db.add_texts(texts, metadatas)

# 创建 ChatDeepSeek LLM 对象,使用 deepseek-chat 模型
llm = ChatDeepSeek(model="deepseek-chat")

# 创建基础检索器,指定检索类型为 similarity 和返回数量 k=20
# 建议 k 取较大值,因为压缩器后续会过滤部分不相关的结果
base_retriever = chroma_db.as_retriever(search_type="similarity", search_kwargs={"k": 20})

# 指定 CrossEncoder 重排模型路径
reranker_model_path = "C:/Users/Administrator/.cache/modelscope/hub/models/BAAI/bge-reranker-base"
# 加载 HuggingFaceCrossEncoder 进行交叉编码
cross_encoder = HuggingFaceCrossEncoder(model_name=reranker_model_path)
# 用 CrossEncoderReranker 封装,并指定重排后返回 top 3
compressor = CrossEncoderReranker(model=cross_encoder, top_n=3)

# 创建上下文压缩检索器,结合基础检索器与文档压缩器
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=base_retriever
)

# 指定检索查询内容(支持中文或英文)
query = "人工智能"  # 可以更换为任意中英文问题
# 调用上下文压缩检索器进行检索
results = compression_retriever.invoke(query)
# 打印查询内容
print(f"查询:{query}")
# 打印检索到的文档数量
print(f"共检索到 {len(results)} 个文档")
# 遍历输出每条检索结果及其元数据
for i, doc in enumerate(results):
    print(f"检索结果{i}:{doc.page_content}")
    print(f"  元数据:{doc.metadata}\n")

57.2. cross_encoders.py #

smartchain/cross_encoders.py

# 导入 sentence_transformers 包中的 CrossEncoder 类
from sentence_transformers import CrossEncoder

# 定义一个轻量包装的 HuggingFaceCrossEncoder 类
class HuggingFaceCrossEncoder:
    """轻量封装 sentence_transformers.CrossEncoder"""

    # 初始化方法,接收模型名称、设备和其它模型参数
    def __init__(self, model_name: str, device: str | None = None, **model_kwargs):
        # 加载指定的 CrossEncoder 模型,如果未指定设备则默认用 "cpu"
        self.model = CrossEncoder(model_name_or_path=model_name, device=device or "cpu", **model_kwargs)

    # 定义预测方法,输入为句对组成的列表
    def predict(self, pairs):
        # 调用 CrossEncoder 的 predict 方法,返回分数
        return self.model.predict(pairs)

57.3. init.py #

smartchain/document_compressors/init.py

from .base import BaseDocumentCompressor
from .llm_chain import LLMChainExtractor
from .embeddings import EmbeddingsFilter
from .cross_encoder import CrossEncoderReranker

__all__ = [
    "BaseDocumentCompressor",
    "LLMChainExtractor",
    "EmbeddingsFilter",
    "CrossEncoderReranker",
]

57.4. base.py #

smartchain/document_compressors/base.py

# 从abc模块导入抽象基类支持
from abc import ABC, abstractmethod

# 定义文档压缩器的抽象基类
class BaseDocumentCompressor(ABC):
    # 设置文档压缩器的类说明文档
    """文档压缩器抽象基类
    用于对检索到的文档进行后处理,例如提取相关内容、过滤不相关文档等。
    """

    # 声明子类必须实现的抽象方法
    @abstractmethod
    def compress_documents(
        self,
        documents,
        query: str,
        callbacks=None,
    ):
        # 方法说明文档
        """压缩文档
          参数:
            documents: 检索到的文档列表
            query: 查询字符串
            callbacks: 可选的回调函数
          返回:
            Sequence: 压缩后的文档列表
        """
        # 抽象方法体不实现
        pass

57.5. cross_encoder.py #

smartchain/document_compressors/cross_encoder.py

from ..cross_encoders import HuggingFaceCrossEncoder
from .base import BaseDocumentCompressor

# 定义 CrossEncoderReranker 类,用于通过交叉编码器对检索结果打分并重排序
class CrossEncoderReranker(BaseDocumentCompressor):
    """用 cross-encoder 打分并按分数截断"""

    # 初始化方法,接收封装过的 cross-encoder 模型对象和 top_n 截断值
    def __init__(self, model: HuggingFaceCrossEncoder, top_n: int = 3):
        # 保存 cross-encoder 模型实例
        self.model = model
        # 保存需要返回的文档数量 top_n
        self.top_n = top_n

    # 定义 compress_documents 方法,对输入的文档进行重排序和截断
    def compress_documents(self, documents, query, callbacks=None):
        # 如果传入文档列表为空,直接返回空列表
        if not documents:
            return []
        # 以 (query, 文档内容) 形式为每个文档构造输入对
        pairs = [(query, doc.page_content) for doc in documents]
        # 用 cross-encoder 对所有句对进行评分
        scores = self.model.predict(pairs)
        # 将文档及其分数打包为元组,并按分数从高到低排序
        ranked = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
        # 仅返回 top_n 个分数最高的文档
        return [doc for doc, _ in ranked[: self.top_n]]

57.6. embeddings.py #

smartchain/document_compressors/embeddings.py

import numpy as np
from .base import BaseDocumentCompressor
# 定义余弦相似度计算函数
def cosine_similarity(query_embeddings, doc_embeddings):
    # 将查询嵌入转换为numpy数组
    query_emb = np.array(query_embeddings)
    # 将文档嵌入转换为numpy数组
    doc_embs = np.array(doc_embeddings)
    # 若查询为一维,转换为二维以便向量批量操作
    if query_emb.ndim == 1:
        query_emb = query_emb.reshape(1, -1)
    # 计算查询和所有文档嵌入的点积
    dot_product = np.dot(query_emb, doc_embs.T)
    # 计算查询向量的范数
    query_norm = np.linalg.norm(query_emb, axis=1, keepdims=True)
    # 计算所有文档嵌入的范数,并转置方便后续相乘
    doc_norms = np.linalg.norm(doc_embs, axis=1, keepdims=True).T
    # 使用点积/范数乘积得到余弦相似度
    similarity = dot_product / (query_norm * doc_norms)
    # 将因除零产生的nan、正负无穷转为0
    similarity = np.nan_to_num(similarity, nan=0.0, posinf=0.0, neginf=0.0)
    # 返回第一行的相似度结果(只支持一个query)
    return similarity[0]

# 定义嵌入过滤器类,继承自BaseDocumentCompressor
class EmbeddingsFilter(BaseDocumentCompressor):
    # 类文档字符串,说明该类用途
    """嵌入过滤器

    使用嵌入模型计算查询与文档的相似度,根据相似度过滤文档。
    """
    # 构造方法定义
    def __init__(
        self,
        embeddings,
        similarity_fn=None,
        k = 20,
        similarity_threshold = None,
    ):
        # 参数校验,如k和阈值都没传则报错
        if k is None and similarity_threshold is None:
            raise ValueError("必须指定 k 或 similarity_threshold 之一")
        # 存储嵌入对象
        self.embeddings = embeddings
        # 如果没传相似度函数,默认使用余弦相似度
        self.similarity_fn = similarity_fn or cosine_similarity
        # 存储k值(返回多少最相似文档)
        self.k = k
        # 存储相似度阈值
        self.similarity_threshold = similarity_threshold

    # 文档压缩/过滤主入口
    def compress_documents(
        self,
        documents,
        query: str,
        callbacks=None,
    ):
        # 若输入文档为空,直接返回空列表
        if not documents:
            return []
        # 收集每个文档的文本内容
        doc_texts = [doc.page_content if hasattr(doc, 'page_content') else str(doc) for doc in documents]
        # 计算所有文档嵌入
        doc_embeddings = self.embeddings.embed_documents(doc_texts)
        # 计算查询嵌入
        query_embedding = self.embeddings.embed_query(query)
        # 计算查询和所有文档的相似度分数
        similarities = self.similarity_fn([query_embedding], doc_embeddings)
        # 初始包含所有文档的索引
        included_idxs = np.arange(len(documents))
        # 如果指定了k值,保留相似度最高的k篇文档
        if self.k is not None:
            sorted_indices = np.argsort(similarities)[::-1]
            included_idxs = sorted_indices[:self.k]
        # 如果指定了相似度阈值,进一步只保留超过阈值的文档
        if self.similarity_threshold is not None:
            similar_enough = similarities[included_idxs] > self.similarity_threshold
            included_idxs = included_idxs[similar_enough]
        # 返回最终保留的文档列表
        return [documents[i] for i in included_idxs]

57.7. llm_chain.py #

smartchain/document_compressors/llm_chain.py

from smartchain.document_compressors.base import BaseDocumentCompressor
from ..prompts import PromptTemplate
from ..documents import Document
# 定义一个默认的中文抽取Prompt模板
CHINESE_EXTRACT_PROMPT = """给定以下问题和上下文,提取上下文中与回答问题相关的任何部分(保持原样)。如果上下文都不相关,返回 NO_OUTPUT。

记住,*不要*编辑提取的上下文部分。

> 问题:{question}
> 上下文:
>>>
{context}
>>>
提取的相关部分:"""
# 定义LLM链文档提取器类,继承自基类
class LLMChainExtractor(BaseDocumentCompressor):
    # 设置LLM链提取器的类说明文档
    """LLM链提取器
    使用 LLM 链从文档中提取与查询相关的部分。
    """
    # 构造函数
    def __init__(
        self,
        llm_chain=None,
        llm=None,
        prompt=None,
        get_input=None,
    ):
        # 如果传入llm_chain则直接赋值
        if llm_chain is not None:
            self.llm_chain = llm_chain
        # 如果传入llm而未传llm_chain,则根据参数生成链
        elif llm is not None:

            # 如果未自定义提示词模板,使用默认中文抽取Prompt
            if prompt is None:
                prompt_template = PromptTemplate.from_template(CHINESE_EXTRACT_PROMPT)
            else:
                prompt_template = prompt
            # 用私有方法构造链
            self.llm_chain = self._create_chain(prompt_template, llm)
        # 两者都没给时报错
        else:
            raise ValueError("必须提供 llm_chain 或 llm 参数")
        # 设置get_input函数(用于将query/doc打包成模型输入),可以自定义
        self.get_input = get_input or self._default_get_input

    # 默认输入转换函数,把query和doc打包成dict
    def _default_get_input(self, query: str, doc) -> dict:
        # 返回一个包含问题和文本内容的字典
        return {"question": query, "context": doc.page_content}

    # 私有方法,实现用prompt_template和llm创建简单链对象
    def _create_chain(self, prompt_template, llm):
        # 定义内部类SimpleChain
        class SimpleChain:
            # 构造函数,保存prompt模板和llm对象
            def __init__(self, prompt_template, llm):
                self.prompt_template = prompt_template
                self.llm = llm
            # 实现invoke方法,将输入dict格式化prompt并调用llm
            def invoke(self, input_dict, **kwargs):
                # 格式化prompt
                formatted_prompt = self.prompt_template.format(**input_dict)
                # 用llm推理,得到响应
                response = self.llm.invoke(formatted_prompt, **kwargs)
                # 如果响应有content属性,则用其内容
                if hasattr(response, 'content'):
                    text = response.content
                # 如果响应直接为字符串
                elif isinstance(response, str):
                    text = response
                # 否则强转为字符串
                else:
                    text = str(response)
                # 去掉文本首尾空格
                text = text.strip()
                # 若内容为NO_OUTPUT或空,则返回空字符串
                if text == "NO_OUTPUT" or text == "":
                    return ""
                # 返回抽取的文本
                return text
        # 返回SimpleChain对象实例
        return SimpleChain(prompt_template, llm)

    # 类方法,支持通过llm快速初始化LLMChainExtractor
    @classmethod
    def from_llm(
        cls,
        llm,
        prompt=None,
        get_input=None,
    ):
        # 返回初始化好的提取器对象
        return cls(
            llm=llm,
            prompt=prompt,
            get_input=get_input
        )

    # 实现文档压缩主接口
    def compress_documents(
        self,
        documents,
        query: str,
        callbacks=None,
    ):
        # 用于存储压缩结果的列表
        compressed_docs = []
        # 遍历所有输入文档
        for doc in documents:
            # 构造当前doc与query的输入结构
            _input = self.get_input(query, doc)
            # 用llm链抽取相关内容
            output = self.llm_chain.invoke(_input)
            # 如果没抽取到任何内容则跳过该文档
            if not output or len(output) == 0:
                continue
            # 用抽取得到内容与原有元数据构造新文档
            compressed_doc = Document(
                page_content=output,
                metadata=doc.metadata if hasattr(doc, 'metadata') else {}
            )
            # 加入压缩结果列表
            compressed_docs.append(compressed_doc)
        # 返回压缩结果文档列表
        return compressed_docs

57.8. 类 #

57.8.1 类说明 #

类名 类型 作用 关键属性/方法
Document 数据类 存储文档内容和元数据 page_content: 文本内容
metadata: 元数据字典
HuggingFaceEmbeddings 嵌入模型类 使用 HuggingFace 模型生成文本嵌入向量 model_name: 模型名称
embed_query(): 嵌入查询文本
embed_documents(): 批量嵌入文本
Chroma 向量存储类 Chroma 向量数据库封装 embedding_function: 嵌入函数
_collection: Chroma 集合
add_texts(): 添加文本
similarity_search(): 相似度搜索
as_retriever(): 创建检索器
VectorStoreRetriever 检索器类 从向量存储中检索文档 vectorstore: 向量存储实例
search_type: 搜索类型
invoke(): 执行检索
BaseRetriever 抽象基类 定义检索器接口规范 invoke(): 对外调用接口
_get_relevant_documents(): 抽象方法
ContextualCompressionRetriever 检索器类 包装基础检索器和文档压缩器 base_retriever: 基础检索器
base_compressor: 文档压缩器
invoke(): 执行检索
_get_relevant_documents(): 检索逻辑
BaseDocumentCompressor 抽象基类 定义文档压缩器接口规范 compress_documents(): 抽象方法
CrossEncoderReranker 文档压缩器类 使用交叉编码器对文档重排序 model: 交叉编码器模型
top_n: 返回文档数量
compress_documents(): 重排序
HuggingFaceCrossEncoder 交叉编码器类 封装 sentence-transformers 的 CrossEncoder model: CrossEncoder 实例
predict(): 相关性分数预测

57.8.2 类图 #

classDiagram class Document { +str page_content +Dict[str, Any] metadata +__init__(page_content, metadata) } class BaseRetriever { <<abstract>> +__init__(**kwargs) +invoke(query, **kwargs) List[Document] #_get_relevant_documents(query, **kwargs)* List[Document] } class VectorStoreRetriever { -VectorStore vectorstore -str search_type -dict search_kwargs +__init__(vectorstore, search_type, search_kwargs, **kwargs) +invoke(query, **kwargs) List[Document] -_get_relevant_documents(query, **kwargs) List[Document] } class ContextualCompressionRetriever { -BaseDocumentCompressor base_compressor -BaseRetriever base_retriever +__init__(base_compressor, base_retriever) +invoke(query, **kwargs) List[Document] -_get_relevant_documents(query, **kwargs) List[Document] } class Chroma { -Embedding embedding_function -Collection _collection +__init__(collection_name, embedding_function, persist_directory, collection_metadata) +add_texts(texts, metadatas, **kwargs) List[str] +similarity_search(query, k, **kwargs) List[Document] +as_retriever(search_type, search_kwargs, **kwargs) VectorStoreRetriever } class BaseDocumentCompressor { <<abstract>> +compress_documents(documents, query, callbacks)* List[Document] } class CrossEncoderReranker { -HuggingFaceCrossEncoder model -int top_n +__init__(model, top_n) +compress_documents(documents, query, callbacks) List[Document] } class HuggingFaceCrossEncoder { -CrossEncoder model +__init__(model_name, device, **model_kwargs) +predict(pairs) List[float] } class HuggingFaceEmbeddings { -str model_name +__init__(model_name, **kwargs) +embed_query(text) List[float] +embed_documents(texts) List[List[float]] } BaseRetriever <|-- VectorStoreRetriever : 继承实现 BaseRetriever <|-- ContextualCompressionRetriever : 继承实现 BaseDocumentCompressor <|-- CrossEncoderReranker : 继承实现 ContextualCompressionRetriever "1" --> "1" BaseRetriever : 使用基础检索器 ContextualCompressionRetriever "1" --> "1" BaseDocumentCompressor : 使用文档压缩器 CrossEncoderReranker "1" --> "1" HuggingFaceCrossEncoder : 使用交叉编码器 Chroma "1" --> "1" HuggingFaceEmbeddings : 使用嵌入模型 Chroma "1" --> "1" VectorStoreRetriever : 创建检索器 ContextualCompressionRetriever "1" *-- "0..*" Document : 返回文档列表 note for BaseRetriever "抽象基类\n定义检索器标准接口\n使用模板方法模式" note for ContextualCompressionRetriever "上下文压缩检索器\n先检索后压缩\n提升结果相关性" note for CrossEncoderReranker "交叉编码器重排器\n使用CrossEncoder对文档重排序\n返回top_n个最相关的文档" note for HuggingFaceCrossEncoder "交叉编码器模型\n对查询-文档对进行精确评分\n比双编码器更准确"

57.8.3 时序图 #

sequenceDiagram participant Main as 57.CrossEncoderReranker.py participant Embed as HuggingFaceEmbeddings participant Chroma as Chroma participant ChromaDB as ChromaDB(外部) participant VSR as VectorStoreRetriever participant CCR as ContextualCompressionRetriever participant CER as CrossEncoderReranker participant HCE as HuggingFaceCrossEncoder Note over Main: 阶段1: 初始化嵌入模型和向量数据库 Main->>Embed: HuggingFaceEmbeddings(model_name, model_kwargs) activate Embed Embed-->>Main: embeddings实例 deactivate Embed Main->>Chroma: Chroma(persist_directory, embedding_function, collection_name, collection_metadata) activate Chroma Chroma->>ChromaDB: PersistentClient(path=persist_directory) ChromaDB-->>Chroma: client实例 Chroma->>ChromaDB: get_or_create_collection(name, metadata) ChromaDB-->>Chroma: collection实例 Chroma-->>Main: chroma_db实例 deactivate Chroma Note over Main: 阶段2: 检查并添加数据 Main->>Chroma: _collection.count() activate Chroma Chroma->>ChromaDB: count() ChromaDB-->>Chroma: count结果 Chroma-->>Main: count结果 deactivate Chroma alt 数据库为空 Main->>Chroma: add_texts(texts, metadatas) activate Chroma Chroma->>Embed: embed_documents(texts) activate Embed Embed-->>Chroma: embedding_values deactivate Embed Chroma->>ChromaDB: upsert(ids, documents, embeddings, metadatas) ChromaDB-->>Chroma: 添加成功 Chroma-->>Main: ids列表 deactivate Chroma end Note over Main: 阶段3: 创建基础检索器和交叉编码器 Main->>Chroma: as_retriever(search_type="similarity", search_kwargs={"k": 20}) activate Chroma Chroma->>VSR: VectorStoreRetriever(vectorstore=self, search_type, search_kwargs) activate VSR VSR-->>Chroma: retriever实例 deactivate VSR Chroma-->>Main: base_retriever实例 deactivate Chroma Main->>HCE: HuggingFaceCrossEncoder(model_name=reranker_model_path) activate HCE HCE->>HCE: CrossEncoder(model_name_or_path=model_name, device="cpu") HCE-->>Main: cross_encoder实例 deactivate HCE Main->>CER: CrossEncoderReranker(model=cross_encoder, top_n=3) activate CER CER-->>Main: compressor实例 deactivate CER Note over Main: 阶段4: 创建压缩检索器 Main->>CCR: ContextualCompressionRetriever(base_compressor=compressor, base_retriever=base_retriever) activate CCR CCR->>CCR: super().__init__() CCR-->>Main: compression_retriever实例 deactivate CCR Note over Main: 阶段5: 执行检索和重排序 Main->>CCR: invoke(query="人工智能") activate CCR CCR->>VSR: base_retriever.invoke(query) activate VSR VSR->>Chroma: similarity_search(query, k=20) activate Chroma Chroma->>Embed: embed_query(query) activate Embed Embed-->>Chroma: query_embedding deactivate Embed Chroma->>ChromaDB: query(query_embeddings=[query_embedding], n_results=20) ChromaDB-->>Chroma: results (ids, documents, metadatas, distances) Chroma->>Chroma: _results_to_docs_and_scores(results) Chroma-->>VSR: docs (Document列表,最多20个) deactivate Chroma VSR-->>CCR: docs (Document列表) deactivate VSR CCR->>CER: base_compressor.compress_documents(docs, query) activate CER CER->>CER: 检查documents是否为空 CER->>CER: pairs = [(query, doc.page_content) for doc in documents] CER->>HCE: model.predict(pairs) activate HCE loop 对每个(query, doc)对 HCE->>HCE: CrossEncoder.predict([(query, doc)]) HCE->>HCE: 计算相关性分数 end HCE-->>CER: scores (分数数组) deactivate HCE CER->>CER: ranked = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True) CER->>CER: 取前top_n=3个文档 CER->>CER: [doc for doc, _ in ranked[:top_n]] CER-->>CCR: compressed_docs (重排序后的Document列表,最多3个) deactivate CER CCR-->>Main: results (Document列表) deactivate CCR Note over Main: 阶段6: 输出结果 Main->>Main: 遍历results并打印

57.8.4 调用过程 #

阶段 1:初始化嵌入模型和向量数据库

embeddings = HuggingFaceEmbeddings(
    model_name=model_path,
    model_kwargs={"device": "cpu"}
)

chroma_db = Chroma(
    persist_directory="chroma_db",
    embedding_function=embeddings,
    collection_name="test",
    collection_metadata={"hnsw:space": "cosine"}
)
  • 创建 HuggingFaceEmbeddings 实例
  • 创建 Chroma 向量数据库实例,连接持久化存储

阶段 2:检查并添加数据

if not chroma_db._collection.count():
    chroma_db.add_texts(texts, metadatas)
  • 检查数据库是否为空
  • 如果为空,批量添加文本和元数据到数据库

阶段 3:创建基础检索器和交叉编码器

base_retriever = chroma_db.as_retriever(search_type="similarity", search_kwargs={"k": 20})
cross_encoder = HuggingFaceCrossEncoder(model_name=reranker_model_path)
compressor = CrossEncoderReranker(model=cross_encoder, top_n=3)

调用链:

  1. Chroma.as_retriever() 创建 VectorStoreRetriever 实例
  2. HuggingFaceCrossEncoder.__init__() 方法
    • 加载 sentence_transformers.CrossEncoder 模型
    • 保存模型实例
  3. CrossEncoderReranker.__init__() 方法
    • 保存 model(交叉编码器实例)和 top_n(返回文档数量,默认 3)

阶段 4:创建压缩检索器

compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=base_retriever
)
  • 创建 ContextualCompressionRetriever 实例
  • 保存 base_compressor 和 base_retriever

阶段 5:执行检索和重排序

results = compression_retriever.invoke(query)

调用链:

  1. invoke() 方法(继承自 BaseRetriever)
    • 调用 _get_relevant_documents(query)
  2. _get_relevant_documents() 方法(核心检索逻辑)

    步骤 1:基础检索

    docs = self.base_retriever.invoke(query, **kwargs)
    • 调用 VectorStoreRetriever.invoke() 从向量数据库检索文档(返回 k=20 个文档)

    步骤 2:文档重排序

    compressed_docs = self.base_compressor.compress_documents(docs, query)

    内部处理流程(CrossEncoderReranker.compress_documents()):

    步骤 2.1:构造查询-文档对

    pairs = [(query, doc.page_content) for doc in documents]
    • 为每个文档构造 (查询, 文档内容) 对

    步骤 2.2:计算相关性分数

    scores = self.model.predict(pairs)
    • 调用 HuggingFaceCrossEncoder.predict() 对所有句对进行评分
    • 内部调用 sentence_transformers.CrossEncoder.predict(),返回每个文档与查询的相关性分数

    步骤 2.3:排序和截断

    ranked = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
    return [doc for doc, _ in ranked[:self.top_n]]
    • 按分数降序排序
    • 取前 top_n=3 个分数最高的文档

    步骤 3:返回重排序后的文档列表

    • 返回重排序后的 Document 列表(最多 3 个)

阶段 6:结果输出

print(f"查询:{query}")
print(f"共检索到 {len(results)} 个文档")
for i, doc in enumerate(results):
    print(f"检索结果{i}:{doc.page_content}")
    print(f"  元数据:{doc.metadata}\n")
  • 遍历 results(重排序后的 Document 列表)
  • 打印每个文档的 page_content 和 metadata

57.8.5 CrossEncoder 工作原理 #

CrossEncoder vs BiEncoder

特性 CrossEncoder BiEncoder(如向量检索)
计算方式 查询和文档同时输入模型 查询和文档分别编码后计算相似度
准确性 高(考虑交互) 中等(独立编码)
速度 较慢(需要成对计算) 快(可预先编码)
适用场景 重排序、小批量文档 大规模检索、召回阶段

CrossEncoder 优势

  1. 精确评分:同时考虑查询和文档的交互,评分更准确
  2. 上下文理解:能理解查询与文档的语义关系
  3. 适合重排序:在召回阶段后对候选文档进行精确排序

工作流程

  1. 召回阶段:使用向量数据库快速召回 k=20 个候选文档
  2. 重排序阶段:使用 CrossEncoder 对所有候选文档进行精确评分
  3. 截断阶段:按分数排序,返回 top_n=3 个最相关的文档

双阶段检索的优势

阶段 方法 目标 特点
召回阶段 向量检索 快速召回大量候选文档 速度快,覆盖广
重排序阶段 CrossEncoder 精确排序候选文档 准确度高,结果精

这种两阶段方法兼顾了速度和准确性。

57.8.6 CrossEncoderReranker vs 其他压缩器 #

特性 CrossEncoderReranker EmbeddingsFilter LLMChainExtractor
计算方式 CrossEncoder 模型评分 嵌入向量余弦相似度 LLM 语义理解
准确性 高 中等 高
速度 中等 快 慢
文档修改 不修改 不修改 提取相关内容
适用场景 重排序、中等规模 快速过滤、大规模 精确提取、小规模
成本 中等(模型推理) 低(向量计算) 高(LLM API调用)

CrossEncoderReranker 在准确性和速度之间取得平衡,适合需要精确重排序的场景。

58.LongContextReorder #

  • Lost-in-the-middle

本节介绍了 LongContextReorder 的用法和实现原理。LongContextReorder 是一种“长上下文重排器”,用于解决大语言模型处理长文档时信息容易丢失在中间部分(lost-in-the-middle)的问题。其核心思想是将最重要、与查询最相关的文档片段放在序列的开头和结尾,而把相关性较低的内容安排在中间,以提升模型的整体理解能力。
重排流程一般如下:

  1. 先将输入的文档列表反转;
  2. 交替地将文档插入新列表的头部或尾部——偶数位的插到头部,奇数位的加到尾部,实现一种“两头包夹中间”的顺序;
  3. 返回重排后的文档序列。

举个例子,原始文档顺序 [A, B, C, D],经过 LongContextReorder 后,顺序会变为 [B, D, A, C]。

具体如何使用,可参考上方代码示例:初始化后调用 transform_documents(docs) 即可得到重排结果。

这种方法在大模型多文档问答、检索增强生成(RAG)等场景中颇为常见,有助于最大化命中重要信息,提高模型生成答案的准确性。

58.1. 58.LongContextReorder.py #

58.LongContextReorder.py

#from langchain_community.document_transformers  import LongContextReorder
#from langchain_core.documents import Document

# 导入 Document 类,用于创建文档对象
from smartchain.documents import Document
# 导入 LongContextReorder 类,用于实现 Lost-in-the-middle 重排
from smartchain.document_transformers import LongContextReorder

# 构建一个包含 4 个文档的列表,每个文档附带不同的内容和元数据 id
docs = [
    # 第1个文档,介绍部分
    Document(page_content="介绍:本文讨论人工智能的基础。", metadata={"id": 1}),
    # 第2个文档,细节部分
    Document(page_content="细节:机器学习是人工智能的核心方法。", metadata={"id": 2}),
    # 第3个文档,扩展部分
    Document(page_content="扩展:深度学习通过神经网络取得突破。", metadata={"id": 3}),
    # 第4个文档,结论部分
    Document(page_content="结论:人工智能正在改变世界。", metadata={"id": 4}),
]

# 创建长上下文重排器实例
reorder = LongContextReorder()
# 对文档列表进行重排,得到重排后的文档顺序
reordered = reorder.transform_documents(docs)

# 输出重排后文档的顺序和内容
print("重排后顺序:")
# 枚举遍历重排后的文档,i 为顺序编号,从 1 开始
for i, d in enumerate(reordered, 1):
    # 按格式输出:序号. 文档id: 文档内容
    print(f"{i}. {d.metadata.get('id')}: {d.page_content}")

58.2. init.py #

smartchain/document_transformers/init.py

# 从 base 模块导入文档转换器基类
from .base import BaseDocumentTransformer
# 从 long_context 模块导入长上下文重排器
from .long_context import LongContextReorder

# 定义模块对外暴露的类名列表
__all__ = [
    "BaseDocumentTransformer",
    "LongContextReorder",
]

58.3. base.py #

smartchain/document_transformers/base.py

# 导入 Sequence 类型用于类型标注
from typing import Sequence
# 从上级目录导入 Document 类
from ..documents import Document

# 定义文档转换器基类
class BaseDocumentTransformer:
    # 类的文档字符串,说明这是文档转换器基类
    """文档转换器基类"""

    # 定义抽象方法,用于转换文档列表
    def transform_documents(self, documents: Sequence[Document]) -> Sequence[Document]:
        # 方法文档字符串,说明此方法用于转换文档列表
        """转换文档列表"""
        # 抛出未实现异常,子类需实现此方法
        raise NotImplementedError

58.4. long_context.py #

smartchain/document_transformers/long_context.py

# 导入类型注解 Sequence,用于类型标注
from typing import Sequence
# 导入基础文档转换器基类
from .base import BaseDocumentTransformer
# 导入文档对象类
from ..documents import Document


# 定义长上下文重排器,继承自基础文档转换器
class LongContextReorder(BaseDocumentTransformer):
    # 类文档字符串,说明用途及对应论文
    """长上下文重排器

    实现 Lost-in-the-middle 重排算法:
    将最相关的文档放在开头和结尾,不太相关的文档放在中间。
    参考论文: https://arxiv.org/abs/2307.03172
    """

    # 重写 transform_documents 方法,实现文档重排
    def transform_documents(self, documents: Sequence[Document]) -> Sequence[Document]:
        # 将输入的文档序列转换为列表,便于操作
        docs = list(documents)
        # 反转文档列表
        docs.reverse()
        # 初始化重排后的文档列表
        reordered = []
        # 遍历反转后的文档列表,枚举索引和文档
        for i, doc in enumerate(docs):
            # 如果索引为奇数
            if i % 2 == 1:
                # 将当前文档追加到重排列表的末尾
                reordered.append(doc)
            # 如果索引为偶数
            else:
                # 将当前文档插入到重排列表的开头
                reordered.insert(0, doc)
        # 返回重排后的文档列表
        return reordered

58.5. 对比 #

BaseDocumentTransformer 和 BaseDocumentCompressor 都是 LangChain 中处理检索文档的抽象基类,但它们的核心区别在于目的和操作方式:一个负责“变换”文档内容本身,另一个负责“压缩/筛选”文档集。

58.5.1 关键区别 #

维度 BaseDocumentTransformer (文档转换器) BaseDocumentCompressor (文档压缩器)
核心目的 转换文档的内容或格式 减少文档的数量或长度
主要操作对象 单个文档的内容(如文本) 文档集合(多个Document对象)
典型输出 修改后的文档列表(数量通常不变) 过滤或精简后的文档子集(数量减少)
关键方法 transform_documents(documents, **kwargs) compress_documents(documents, query, **kwargs)
类比 文档编辑器:重写、翻译、总结文本 文档过滤器:按相关性筛选、截断文本
常见子类 LongContextReorder (重排顺序)、EmbeddingsRedundantFilter (去重) CrossEncoderReranker (重排序并筛选)、LLMChainExtractor (LLM提取摘要)
是否依赖查询(query) 通常否(对文档进行独立处理) 是(压缩过程依赖于具体的查询)

58.5.2 工作原理 #

58.5.2.1. BaseDocumentTransformer(转换文档) #

核心思想:接收一批文档,对它们逐一进行某种变换,返回数量相同但内容可能已被修改的文档列表。

from langchain.schema import BaseDocumentTransformer
from langchain.schema import Document

class SimpleTextCleaner(BaseDocumentTransformer):
    """一个简单的转换器示例:清理文档文本"""
    def transform_documents(self, documents, **kwargs):
        transformed = []
        for doc in documents:
            # 转换操作:例如,移除多余空格和特殊字符
            cleaned_text = " ".join(doc.page_content.split())
            transformed.append(Document(page_content=cleaned_text, metadata=doc.metadata))
        return transformed

# 使用
original_docs = [Document(page_content="Hello   World!  ")]
cleaner = SimpleTextCleaner()
transformed_docs = cleaner.transform_documents(original_docs)
print(transformed_docs[0].page_content) # 输出: "Hello World!"
58.5.2.2. BaseDocumentCompressor(压缩/筛选文档) #

核心思想:接收一批文档和一个查询,根据查询的相关性来减少文档的数量或长度。

from langchain.retrievers.document_compressors import BaseDocumentCompressor
from langchain.schema import Document

class SimpleTopKCompressor(BaseDocumentCompressor):
    """一个简单的压缩器示例:只保留前K个文档(假设已按相关性排序)"""
    def __init__(self, k=3):
        self.k = k

    def compress_documents(self, documents, query, **kwargs):
        # 压缩操作:这里简单截取前K个
        return documents[:self.k]

# 使用
original_docs = [Document(page_content=f"Doc {i}") for i in range(10)] # 10个文档
query = "test query"
compressor = SimpleTopKCompressor(k=2)
compressed_docs = compressor.compress_documents(original_docs, query)
print(len(compressed_docs)) # 输出: 2

58.5.3 在RAG流程中的典型协作 #

在实际的RAG系统中,它们常被串联使用,形成一个处理流水线(Pipeline):

# 假设在一个完整的检索增强流程中:
1. **检索**:从向量数据库获取100个相关文档(粗召回)
2. **转换**:使用 `EmbeddingsRedundantFilter` (转换器) 去除重复文档 -> 剩下80个
3. **压缩**:使用 `CrossEncoderReranker` (压缩器) 根据查询进行精排并取Top-5 -> 剩下5个
4. **再次转换**:使用 `LongContextReorder` (转换器) 重新排列这5个文档的顺序以优化LLM注意力
5. **生成**:将最终处理后的文档和查询一起交给LLM生成答案

在这个流水线中,压缩器 (BaseDocumentCompressor) 的核心职责是“做减法”,根据查询筛选出最核心的文档子集;而转换器 (BaseDocumentTransformer) 的核心职责是“做调整”,在不改变文档集合大小的前提下优化文档的内容或顺序。

58.5.3 总结与选择建议 #

当你想... 应该选择
去除重复、翻译文本、重新排序文档(不依赖查询) BaseDocumentTransformer 的子类
根据特定查询,筛选最相关的文档、提取关键片段 BaseDocumentCompressor 的子类
既要去重又要筛选,且优化顺序 组合使用两者,通常顺序是:先转换器去重,再压缩器筛选

简单来说,可以这样记忆:BaseDocumentTransformer 像是文档的“化妆师”或“整理师”,专注于美化或重组内容;而 BaseDocumentCompressor 则像是严格的“面试官”,根据你的问题(query)来决定哪些文档可以进入下一轮。

58.6. 类 #

58.6.1 类说明 #

类名 类型 作用 关键属性/方法
Document 数据类 存储文档内容和元数据 page_content: 文档文本内容
metadata: 元数据字典
BaseDocumentTransformer 抽象基类 定义文档转换器接口规范 transform_documents(): 抽象方法(需子类实现)
LongContextReorder 文档转换器类 实现 Lost-in-the-middle 重排算法 transform_documents(): 重排文档顺序
将最相关文档放在开头和结尾,不太相关的放在中间

58.6.2 类图 #

classDiagram class Document { +str page_content +Dict[str, Any] metadata +__init__(page_content, metadata) } class BaseDocumentTransformer { <<abstract>> +transform_documents(documents)* Sequence[Document] } class LongContextReorder { +__init__() +transform_documents(documents) Sequence[Document] } BaseDocumentTransformer <|-- LongContextReorder : 继承实现 LongContextReorder "1" *-- "0..*" Document : 转换文档列表 note for BaseDocumentTransformer "文档转换器基类\n定义转换接口规范\n使用模板方法模式" note for LongContextReorder "长上下文重排器\n实现Lost-in-the-middle算法\n解决LLM长上下文中的中间信息丢失问题\n参考论文: https://arxiv.org/abs/2307.03172" note for Document "文档数据类\n存储文本内容和元数据"

58.6.3 时序图 #

sequenceDiagram participant Main as 58.LongContextReorder.py participant Doc as Document participant LCR as LongContextReorder participant Base as BaseDocumentTransformer Note over Main: 阶段1: 创建文档对象列表 Main->>Doc: Document(page_content="介绍...", metadata={"id": 1}) Doc-->>Main: doc1 Main->>Doc: Document(page_content="细节...", metadata={"id": 2}) Doc-->>Main: doc2 Main->>Doc: Document(page_content="扩展...", metadata={"id": 3}) Doc-->>Main: doc3 Main->>Doc: Document(page_content="结论...", metadata={"id": 4}) Doc-->>Main: doc4 Main->>Main: docs = [doc1, doc2, doc3, doc4] Note over Main: 阶段2: 创建重排器实例 Main->>LCR: LongContextReorder() activate LCR LCR->>Base: 继承BaseDocumentTransformer activate Base Base-->>LCR: 初始化完成 deactivate Base LCR-->>Main: reorder实例 deactivate LCR Note over Main: 阶段3: 执行文档重排 Main->>LCR: transform_documents(docs) activate LCR LCR->>LCR: docs = list(documents) # 转换为列表 LCR->>LCR: docs.reverse() # 反转列表: [doc4, doc3, doc2, doc1] LCR->>LCR: reordered = [] # 初始化重排列表 loop 遍历反转后的文档列表 alt i=0 (偶数索引) LCR->>LCR: reordered.insert(0, doc4) # 插入开头 Note over LCR: reordered = [doc4] else i=1 (奇数索引) LCR->>LCR: reordered.append(doc3) # 追加末尾 Note over LCR: reordered = [doc4, doc3] else i=2 (偶数索引) LCR->>LCR: reordered.insert(0, doc2) # 插入开头 Note over LCR: reordered = [doc2, doc4, doc3] else i=3 (奇数索引) LCR->>LCR: reordered.append(doc1) # 追加末尾 Note over LCR: reordered = [doc2, doc4, doc3, doc1] end end LCR-->>Main: reordered (重排后的Document列表) deactivate LCR Note over Main: 阶段4: 输出重排结果 Main->>Main: 遍历reordered并打印 loop 遍历每个文档 Main->>Main: print(f"{i}. {d.metadata.get('id')}: {d.page_content}") end

58.6.4 调用过程 #

阶段 1:创建文档对象列表

docs = [
    Document(page_content="介绍:本文讨论人工智能的基础。", metadata={"id": 1}),
    Document(page_content="细节:机器学习是人工智能的核心方法。", metadata={"id": 2}),
    Document(page_content="扩展:深度学习通过神经网络取得突破。", metadata={"id": 3}),
    Document(page_content="结论:人工智能正在改变世界。", metadata={"id": 4}),
]
  • 创建 4 个 Document 对象
  • 每个文档包含 page_content 和 metadata(包含 id)

阶段 2:创建重排器实例

reorder = LongContextReorder()
  • 创建 LongContextReorder 实例
  • 该类继承自 BaseDocumentTransformer

阶段 3:执行文档重排

reordered = reorder.transform_documents(docs)

调用链:

  1. transform_documents() 方法(核心重排逻辑)

    步骤 1:转换为列表并反转

    docs = list(documents)  # [doc1, doc2, doc3, doc4]
    docs.reverse()           # [doc4, doc3, doc2, doc1]

    步骤 2:交替插入和追加

    reordered = []
    for i, doc in enumerate(docs):
        if i % 2 == 1:  # 奇数索引
            reordered.append(doc)      # 追加到末尾
        else:            # 偶数索引
            reordered.insert(0, doc)   # 插入到开头

    重排过程示例(4 个文档):

    • 初始:[doc1, doc2, doc3, doc4]
    • 反转:[doc4, doc3, doc2, doc1]
    • i=0(偶数):reordered.insert(0, doc4) → [doc4]
    • i=1(奇数):reordered.append(doc3) → [doc4, doc3]
    • i=2(偶数):reordered.insert(0, doc2) → [doc2, doc4, doc3]
    • i=3(奇数):reordered.append(doc1) → [doc2, doc4, doc3, doc1]
    • 最终:[doc2, doc4, doc3, doc1]

    步骤 3:返回重排后的文档列表

    • 返回重排后的 Document 列表

阶段 4:输出重排结果

print("重排后顺序:")
for i, d in enumerate(reordered, 1):
    print(f"{i}. {d.metadata.get('id')}: {d.page_content}")
  • 遍历重排后的文档列表
  • 打印每个文档的序号、id 和内容

58.6.5 Lost-in-the-middle 问题 #

问题背景

LLM 在处理长上下文时,容易出现“中间信息丢失”:

  • 开头和结尾的信息更容易被关注
  • 中间部分的信息容易被忽略

解决方案

LongContextReorder 实现的重排策略:

  1. 将最相关的文档放在开头和结尾
  2. 将不太相关的文档放在中间

重排算法详解

对于 n 个文档(按相关性从高到低排序):

  1. 反转列表:[doc_n, doc_{n-1}, ..., doc_2, doc_1]
  2. 交替插入:
    • 偶数索引(0, 2, 4, ...):插入到开头
    • 奇数索引(1, 3, 5, ...):追加到末尾

重排效果示例

假设有 4 个文档,按相关性排序为:[doc1, doc2, doc3, doc4](doc1 最相关)

步骤 操作 结果
初始 [doc1, doc2, doc3, doc4] 原始顺序
反转 [doc4, doc3, doc2, doc1] 反转后
i=0 insert(0, doc4) [doc4]
i=1 append(doc3) [doc4, doc3]
i=2 insert(0, doc2) [doc2, doc4, doc3]
i=3 append(doc1) [doc2, doc4, doc3, doc1]

最终顺序:[doc2, doc4, doc3, doc1]

  • 开头:doc2(次相关)
  • 中间:doc4(不太相关)、doc3(不太相关)
  • 结尾:doc1(最相关)

算法优势

  1. 重要信息在两端:最相关文档在开头和结尾
  2. 中间信息不丢失:不太相关的文档放在中间,仍能被处理
  3. 简单高效:时间复杂度 O(n),空间复杂度 O(n)

适用场景

  • RAG 系统:将检索到的文档重排后输入 LLM
  • 长文档处理:处理超长上下文时优化信息分布
  • 问答系统:确保关键信息不被忽略

58.6.6 与其他组件的对比 #

特性 LongContextReorder Document Compressor
作用 重排文档顺序 过滤或提取文档内容
文档修改 不修改内容,只改变顺序 可能修改或过滤内容
适用阶段 检索后、输入 LLM 前 检索后、输入 LLM 前
目标 优化信息分布 提高相关性
输入 文档列表 文档列表 + 查询
输出 重排后的文档列表 压缩后的文档列表

LongContextReorder 专注于解决长上下文中的信息分布问题,不依赖查询,适用于所有需要优化文档顺序的场景。

59.参考 #

  • 余弦相似度
  • TF-IDF
  • rank_bm25
  • RRF
  • CrossEncoder
  • Lost-in-the-middle
  • Flask
  • SqlAlchemy
  • asyncio
  • Starlette
  • FastAPI
  • uvicorn
  • argparse
← 上一节 11.tool 下一节 13.optimize →

访问验证

请输入访问令牌

Token不正确,请重新输入