导航菜单

  • 1.langchain.intro
  • 2.langchain.chat_models
  • 3.langchain.prompts
  • 4.langchain.example_selectors
  • 5.output_parsers
  • 6.Runnable
  • 7.memory
  • 8.document_loaders
  • 9.text_splitters
  • 10.embeddings
  • 11.tool
  • 12.retrievers
  • 13.optimize
  • 14.项目介绍
  • 15.启动HTTP
  • 16.数据与模型
  • 17.权限管理
  • 18.知识库管理
  • 19.设置
  • 20.文档管理
  • 21.聊天
  • 22.API文档
  • 23.RAG优化
  • 24.索引时优化
  • 25.检索前优化
  • 26.检索后优化
  • 27.系统优化
  • 28.GraphRAG
  • 29.图
  • 30.为什么选择图数据库
  • 31.什么是 Neo4j
  • 32.安装和连接 Neo4j
  • 33.Neo4j核心概念
  • 34.Cypher基础
  • 35.模式匹配
  • 36.数据CRUD操作
  • 37.GraphRAG
  • 38.查询和过滤
  • 39.结果处理和聚合
  • 40.语句组合
  • 41.子查询
  • 42.模式和约束
  • 43.日期时间处理
  • 44.Cypher内置函数
  • 45.Python操作Neo4j
  • 46.neo4j
  • 47.py2neo
  • 48.Streamlit
  • 49.Pandas
  • 50.graphRAG
  • 51.deepdoc
  • 52.deepdoc
  • 53.deepdoc
  • 55.deepdoc
  • 54.deepdoc
  • Pillow
  • 1.本章目标
  • 2.目录结构
  • 3.聊天页面
    • 3.1. chat.py
    • 3.2. chat.html
    • 3.3. init.py
    • 3.4. init.py
  • 4.聊天对话
    • 4.1. chat_service.py
    • 4.2. llm_factory.py
    • 4.3. chat.py
    • 4.4. chat.html
  • 5.会话管理
    • 5.1. chat_session_service.py
    • 5.2. chat.py
    • 5.3. chat.html
  • 6.选择知识库
    • 6.1. chat.py
    • 6.2. chat.html
  • 7.知识库对话
    • 7.1. rag_service.py
    • 7.2. chat.py
    • 7.3. chat_service.py
    • 7.4. chat.html
  • 8.向量检索
    • 8.1. retrieval_service.py
    • 8.2. rag_service.py
  • 9.关键字检索
    • 9.1. rag_service.py
    • 9.2. retrieval_service.py
  • 10.混合检索
    • 10.1. rag_service.py
    • 10.2. retrieval_service.py
  • 11.重排序
    • 11.1. rerank_factory.py
    • 11.2. retrieval_service.py
  • 12.显示来源
    • 12.1. chat.py
    • 12.2. chat_session_service.py
    • 12.3. rag_service.py
    • 12.4. chat.html

1.本章目标 #

本章将介绍聊天模块的实现目标与功能,包括用户与AI的对话流程、消息存储与检索、会话管理机制以及与知识库等模块的基础集成方式。通过本章的学习,读者能够理解聊天功能的架构设计思路、关键模块结构、以及如何扩展和自定义聊天体验,为构建智能对话应用奠定基础。

2.目录结构 #

// 项目根目录
rag-lite/
// 应用主目录
├── app/
// 蓝图视图目录(用于组织不同功能的视图)
│   ├── blueprints/
// 蓝图包初始化
│   │   ├── __init__.py
// 用户认证相关视图
│   │   ├── auth.py
// 聊天相关视图
│   │   ├── chat.py
// 文档管理相关视图
│   │   ├── document.py
// 知识库相关视图
│   │   ├── knowledgebase.py
// 系统设置相关视图
│   │   ├── settings.py
// 工具函数蓝图
│   │   └── utils.py
// ORM数据模型目录
│   ├── models/
// 模型包初始化
│   │   ├── __init__.py
// 基础模型定义
│   │   ├── base.py
// 聊天消息模型
│   │   ├── chat_message.py
// 聊天会话模型
│   │   ├── chat_session.py
// 文档模型
│   │   ├── document.py
// 知识库模型
│   │   ├── knowledgebase.py
// 系统设置模型
│   │   ├── settings.py
// 用户模型
│   │   └── user.py
// 服务层目录
│   ├── services/
// 存储相关服务子目录
│   │   ├── storage/
// 存储包初始化
│   │   │   ├── __init__.py
// 存储基类
│   │   │   ├── base.py
// 存储工厂类
│   │   │   ├── factory.py
// 本地存储实现
│   │   │   ├── local_storage.py
// minio存储实现
│   │   │   └── minio_storage.py
// 向量数据库服务子目录
│   │   ├── vectordb/
// 向量数据库包初始化
│   │   │   ├── __init__.py
// 向量数据库基类
│   │   │   ├── base.py
// Chroma向量数据库实现
│   │   │   ├── chroma.py
// 向量数据库工厂
│   │   │   ├── factory.py
// Milvus向量数据库实现
│   │   │   └── milvus.py
// 基础服务类
│   │   ├── base_service.py
// 聊天服务
│   │   ├── chat_service.py
// 聊天会话管理服务
│   │   ├── chat_session_service.py
// 文档服务
│   │   ├── document_service.py
// 知识库服务
│   │   ├── knowledgebase_service.py
// 文档解析服务
│   │   ├── parser_service.py
// RAG服务
│   │   ├── rag_service.py
// 检索服务
│   │   ├── retrieval_service.py
// 系统设置服务
│   │   ├── settings_service.py
// 存储服务(入口)
│   │   ├── storage_service.py
// 用户服务
│   │   ├── user_service.py
// 向量化服务
│   │   └── vector_service.py
// 静态资源目录
│   ├── static/
// 前端模板文件目录
│   ├── templates/
// 基础模板
│   │   ├── base.html
// 聊天页面模板
│   │   ├── chat.html
// 首页模板
│   │   ├── home.html
// 知识库详情页面模板
│   │   ├── kb_detail.html
// 知识库列表页面模板
│   │   ├── kb_list.html
// 登录页面模板
│   │   ├── login.html
// 注册页面模板
│   │   ├── register.html
// 设置页面模板
│   │   └── settings.html
// 工具函数模块目录
│   ├── utils/
// 认证相关工具
│   │   ├── auth.py
// 数据库工具
│   │   ├── db.py
// 文档加载工具
│   │   ├── document_loader.py
// 嵌入模型工厂
│   │   ├── embedding_factory.py
// LLM大模型工厂
│   │   ├── llm_factory.py
// 日志工具
│   │   ├── logger.py
// 模型配置工具
│   │   ├── models_config.py
// 重排序(rerank)工厂
│   │   ├── rerank_factory.py
// 文本分割工具
│   │   └── text_splitter.py
// 应用初始化文件
│   ├── __init__.py
// 应用配置文件
│   └── config.py
// Chroma数据库目录
├── chroma_db/
// 日志文件目录
├── logs/
// RAG Lite 主日志文件
│   └── rag_lite.log
// 存储文件目录
├── storages/
// 持久化卷目录
├── volumes/
// Milvus向量数据库的数据目录
│   ├── milvus/
// Minio对象存储的数据目录
│   └── minio/
// Docker Compose服务编排配置文件
├── docker-compose.yml
// 应用启动主程序
├── main.py
// Python项目配置文件
├── pyproject.toml

3.聊天页面 #

3.1. chat.py #

app/blueprints/chat.py

# 聊天相关路由(视图 + API)
"""
聊天相关路由(视图 + API)
"""

# 导入 Flask 的 Blueprint 和模板渲染函数
from flask import Blueprint, render_template

# 导入知识库服务,用于后续业务逻辑
from app.services.knowledgebase_service import kb_service

# 导入登录保护装饰器和获取当前用户辅助方法
from app.utils.auth import login_required, get_current_user

# 导入日志模块
import logging

# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)

# 创建名为 'chat' 的蓝图对象
bp = Blueprint('chat', __name__)

# 注册 /chat 路由,访问该路由需要先登录
@bp.route('/chat')
@login_required
def chat_view():
    # 智能问答页面视图函数
    """智能问答页面"""
    # 渲染 chat.html 模板并传递空知识库列表
    return render_template('chat.html', knowledgebases=[])

3.2. chat.html #

app/templates/chat.html

{% extends "base.html" %}

{% block title %}智能问答 - RAG Lite{% endblock %}

{% block extra_css %}
<style>
    .chat-container {
        height: calc(100vh - 200px);
        display: flex;
        gap: 1rem;
    }
    .chat-sidebar {
        width: 280px;
        background: #fff;
        border-radius: 0.5rem;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        display: flex;
        flex-direction: column;
    }
    .chat-main {
        flex: 1;
        background: #fff;
        border-radius: 0.5rem;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        display: flex;
        flex-direction: column;
    }
    .session-list {
        flex: 1;
        overflow-y: auto;
        padding: 0.5rem;
    }
    .session-item {
        padding: 0.75rem;
        margin-bottom: 0.5rem;
        border-radius: 0.5rem;
        cursor: pointer;
        transition: background-color 0.2s;
        position: relative;
    }
    .session-item:hover {
        background-color: #f8f9fa;
    }
    .session-item.active {
        background-color: #e3f2fd;
        border-left: 3px solid #0d6efd;
    }
    .session-item .session-title {
        font-weight: 500;
        margin-bottom: 0.25rem;
        overflow: hidden;
        text-overflow: ellipsis;
        white-space: nowrap;
    }
    .session-item .session-time {
        font-size: 0.75rem;
        color: #6c757d;
    }
    .session-item .session-delete {
        position: absolute;
        top: 0.5rem;
        right: 0.5rem;
        opacity: 0;
        transition: opacity 0.2s;
    }
    .session-item:hover .session-delete {
        opacity: 1;
    }
    .chat-messages {
        flex: 1;
        overflow-y: auto;
        padding: 1rem;
        scroll-behavior: smooth;
    }
    .chat-message {
        padding: 1rem;
        margin-bottom: 1rem;
        border-radius: 0.5rem;
    }
    .chat-question {
        background-color: #e3f2fd;
    }
    .chat-answer {
        background-color: #f5f5f5;
    }
    .chat-input-area {
        padding: 1rem;
        border-top: 1px solid #dee2e6;
    }
    .empty-state {
        display: flex;
        flex-direction: column;
        align-items: center;
        justify-content: center;
        height: 100%;
        color: #6c757d;
    }
</style>
{% endblock %}

{% block content %}
<div class="chat-container">
    <!-- 左侧:会话管理 -->
    <div class="chat-sidebar">
        <div class="p-3 border-bottom">
            <div class="d-flex justify-content-between align-items-center mb-3">
                <h6 class="mb-0"><i class="bi bi-chat-left-text"></i> 聊天会话</h6>
                <button class="btn btn-sm btn-primary" onclick="createNewSession()">
                    <i class="bi bi-plus"></i> 新建
                </button>
            </div>
            <button class="btn btn-sm btn-outline-danger w-100" onclick="clearAllSessions()">
                <i class="bi bi-trash"></i> 清空所有
            </button>
        </div>
        <div class="session-list" id="sessionList">
            <div class="text-center text-muted py-5">
                <i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
                <p class="mt-2 small">暂无会话</p>
            </div>
        </div>
    </div>

    <!-- 右侧:对话页面 -->
    <div class="chat-main">
        <div class="p-3 border-bottom">
            <div class="d-flex justify-content-between align-items-center">
                <h6 class="mb-0"><i class="bi bi-chat-dots"></i> 对话</h6>
                <select class="form-select form-select-sm" id="kbSelect" style="width: 200px;" onchange="onKbChange()">
                    <option value="">-- 选择知识库 --</option>
                    {% for kb in knowledgebases %}
                    <option value="{{ kb.id }}">{{ kb.name }}</option>
                    {% endfor %}
                </select>
            </div>
        </div>

        <div class="chat-messages" id="chatMessages">
            <div class="empty-state">
                <i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
                <p class="mt-2">开始提问吧!</p>
            </div>
        </div>

        <div class="chat-input-area">
           <form id="chatForm" onsubmit="askQuestion(event)">
                <div class="mb-2">
                    <textarea class="form-control" id="questionInput" rows="2" 
                              placeholder="输入您的问题..." required></textarea>
                </div>
                <div class="d-flex justify-content-end">
                    <button type="submit" class="btn btn-primary" id="submitBtn">
                        <i class="bi bi-send"></i> 发送
                    </button>
                </div>
            </form>
        </div>
    </div>
</div>
{% endblock %}

{% block extra_js %}
<script>

</script>
{% endblock %}

3.3. init.py #

app/init.py

# RAG Lite 应用模块说明
"""
RAG Lite Application
"""

# 导入操作系统相关模块
import os
# 从 Flask 包导入 Flask 应用对象
from flask import Flask
# 导入 Flask 跨域资源共享支持
from flask_cors import CORS
# 导入应用配置类
from app.config import Config
# 导入日志工具,用于获取日志记录器
from app.utils.logger import get_logger
# 导入数据库初始化函数
from app.utils.db import init_db
# 导入蓝图模块
+from app.blueprints import auth,knowledgebase,settings,document,chat
# 导入获取当前用户信息函数
from app.utils.auth import get_current_user
# 定义创建 Flask 应用的工厂函数
def create_app(config_class=Config):
    # 获取日志记录器,名称为当前模块名
    logger = get_logger(__name__)
    # 尝试初始化数据库
    try:
        # 输出日志,表示即将初始化数据库
        logger.info("初始化数据库...")
        # 执行数据库初始化函数
        init_db()
        # 输出日志,表示数据库初始化成功
        logger.info("数据库初始化成功")
    # 捕获任意异常
    except Exception as e:
        # 输出警告日志,提示数据库初始化失败,并输出异常信息
        logger.warning(f"数据库初始化失败: {e}")
        # 输出警告日志,提示检查数据库是否已存在,并建议手动创建数据表
        logger.warning("请确认数据库已存在,或手动创建数据表")

    # 创建 Flask 应用对象,并指定模板和静态文件目录
    base_dir = os.path.abspath(os.path.dirname(__file__))
    # 创建 Flask 应用对象,并指定模板和静态文件目录
    app = Flask(
        __name__,
        # 指定模板文件目录
        template_folder=os.path.join(base_dir, 'templates'),
        # 指定静态文件目录
        static_folder=os.path.join(base_dir, 'static')
    )
    # 从给定配置类加载配置信息到应用
    app.config.from_object(config_class)

    # 启用跨域请求支持
    CORS(app)

    # 记录应用创建日志信息
    logger.info("Flask 应用已创建")

    # 注册上下文处理器,使 current_user 在所有模板中可用
    @app.context_processor
    def inject_user():
        # 返回当前用户信息字典
        # 使用 get_current_user 获取当前用户信息,并将其添加到上下文字典中
        # 这样在模板中可以直接使用 current_user 变量
        return dict(current_user=get_current_user())

    # 注册蓝图
    app.register_blueprint(auth.bp)
    # 注册知识库蓝图
    app.register_blueprint(knowledgebase.bp)
    # 注册设置蓝图
    app.register_blueprint(settings.bp)
    # 注册文档蓝图
    app.register_blueprint(document.bp)
    # 注册聊天蓝图``
+   app.register_blueprint(chat.bp)
    # 定义首页路由
    @app.route('/')
    def index():
        return "Hello, World!"

    # 返回已配置的 Flask 应用对象
    return app

3.4. init.py #

app/blueprints/init.py

# 定义此文件的描述信息:蓝图模块
"""
蓝图模块
"""

# 从 app.blueprints 包中分别导入 auth、settings、document、chat 模块
+from app.blueprints import auth, settings, document, chat

# 设置 __all__,用于声明对外可导出的模块名称列表
+__all__ = ['auth', 'settings', 'document', 'chat']

4.聊天对话 #

4.1. chat_service.py #

app/services/chat_service.py

"""
问答服务
支持普通聊天和知识库聊天(RAG)
"""
# 导入日志模块
import logging
# 导入可选类型和迭代器类型注解
from typing import Optional, Iterator
# 导入 LLM 工厂,用于创建大语言模型实例
from app.utils.llm_factory import LLMFactory
# 导入 LangChain 的对话模板
from langchain_core.prompts import ChatPromptTemplate
# 导入设置服务,用于获取当前系统设置
from app.services.settings_service import settings_service
# 初始化日志记录器
logger = logging.getLogger(__name__)

# 定义问答服务类
class ChatService:
    # 类的初始化方法
    def __init__(self):
        """初始化问答服务"""
        # 获取并保存当前的系统设置
        self.settings = settings_service.get()
    """问答服务(支持普通聊天和RAG)"""
    # 定义流式普通聊天方法,不使用知识库
    def chat_stream(self, question: str, temperature: Optional[float] = None,
                   max_tokens: int = 1000, history: Optional[list] = None) -> Iterator[dict]:
        """
        流式普通聊天接口(不使用知识库)

        Args:
            question: 问题
            temperature: LLM 温度参数(如果为 None,则从设置中读取)
            max_tokens: 最大生成 token 数
            history: 历史对话记录(可选)

        Yields:
            流式数据块
        """
        # 如果没有指定温度,则从设置中获取(默认为 0.7),并限制在 0-2 之间
        if temperature is None:
            temperature = float(self.settings.get('llm_temperature', '0.7'))
            temperature = max(0.0, min(temperature, 2.0))  # 限制在 0-2 之间

        # 获取用于普通聊天的系统提示词
        chat_prompt_text = self.settings.get('chat_system_prompt')
        # 如果系统提示词不存在,则使用默认的提示词
        if not chat_prompt_text:
            chat_prompt_text = '你是一个专业的AI助手。请友好、准确地回答用户的问题。'

        # 创建支持流式输出的 LLM 实例
        llm = LLMFactory.create_llm(self.settings, temperature=temperature, max_tokens=max_tokens, streaming=True)

        # 构造单轮对话消息格式,包含 system 提示和用户问题
        messages = [
                ("system", chat_prompt_text),
                ("human", question)
        ]

        # 从消息创建对话提示模板
        prompt = ChatPromptTemplate.from_messages(messages)
        # 组装 prompt 和 llm,形成链式调用
        chain = prompt | llm

        # 发送流式开头信号
        yield {
            "type": "start",
            "content": ""
        }

        # 初始化完整答案内容
        full_answer = ""
        try:
            # 遍历模型生成的每一段内容
            for chunk in chain.stream({}):
                # 如果 chunk 有内容,提取内容并累加到full_answer
                if hasattr(chunk, 'content') and chunk.content:
                    content = chunk.content
                    full_answer += content
                    # 输出内容块
                    yield {
                        "type": "content",
                        "content": content
                    }
        # 捕获生成过程中的异常,记录日志并产出错误类型的数据块
        except Exception as e:
            logger.error(f"流式生成时出错: {e}")
            yield {
                "type": "error",
                "content": f"生成答案时出错: {str(e)}"
            }
            return

        # 发送流式结束信号,附带元数据(此处无知识库相关内容)
        yield {
            "type": "done",
            "content": "",
            "sources": [],
            "metadata": {
                'question': question,
                'retrieved_chunks': 0,
                'used_chunks': 0
            }
        }

# 创建全局单例 chat_service 实例
chat_service=ChatService()

4.2. llm_factory.py #

app/utils/llm_factory.py

# LLM 模型工厂
# 根据设置动态创建 LLM 模型,支持扩展
"""
LLM 模型工厂
根据设置动态创建 LLM 模型,支持扩展
"""
# 导入日志模块
import logging
# 导入类型注解:可选、字典、可调用、任意类型
from typing import Optional, Dict, Callable, Any
# 导入设置服务,用于获取当前设置
from app.services.settings_service import settings_service
# 导入配置类
from app.config import Config

# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)


# 定义 LLM 工厂类
class LLMFactory:
    # LLM 模型工厂(支持扩展)
    """LLM 模型工厂(支持扩展)"""

    # 注册的LLM提供者,用于存储各Provider的构造函数
    _providers: Dict[str, Callable] = {}

    # 注册新的LLM提供者方法
    @classmethod
    def register_provider(cls, provider_name: str, provider_func: Callable):
        # 注册新的LLM提供者
        """
        注册新的LLM提供者

        Args:
            provider_name: 提供者名称
            provider_func: 创建LLM的函数,签名应为:
                func(settings: dict, temperature: float, max_tokens: int, streaming: bool) -> LLM
        """
        # 将提供者函数存入_providers字典,键为小写名称
        cls._providers[provider_name.lower()] = provider_func
        # 日志打印已注册信息
        logger.info(f"已注册 LLM 提供商: {provider_name}")

    # 创建LLM实例方法
    @classmethod
    def create_llm(cls, settings: Optional[dict] = None, temperature: float = 0.7, 
                   max_tokens: int = 1000, streaming: bool = False):
        # 根据设置创建 LLM 模型
        """
        根据设置创建 LLM 模型

        Args:
            settings: 设置字典,如果为 None 则从数据库读取
            temperature: 温度参数
            max_tokens: 最大 token 数
            streaming: 是否启用流式输出

        Returns:
            LLM 对象
        """
        # 如果未传入settings,从setting_service获取全局设置
        if settings is None:
            settings = settings_service.get()

        # 获取llm_provider的名称(默认为deepseek)
        provider = settings.get('llm_provider', 'deepseek').lower()

        # 优先检查是否用户注册的 Provider
        if provider in cls._providers:
            # 使用自定义注册的Provider创建llm对象
            return cls._providers[provider](settings, temperature, max_tokens, streaming)

        # 若非自定义Provider,则按内置Provider处理
        if provider == 'deepseek':
            # 创建 DeepSeek LLM
            return cls._create_deepseek(settings, temperature, max_tokens, streaming)
        elif provider == 'openai':
            # 创建 OpenAI LLM
            return cls._create_openai(settings, temperature, max_tokens, streaming)
        elif provider == 'ollama':
            # 创建 Ollama LLM
            return cls._create_ollama(settings, temperature, max_tokens, streaming)
        else:
            # 不支持的Provider,抛出错误
            raise ValueError(
                f"Unsupported LLM provider: {provider}. "
                f"Available providers: {list(cls._providers.keys()) + ['deepseek', 'openai', 'ollama']}"
            )

    # DeepSeek LLM 创建方法
    @classmethod
    def _create_deepseek(cls, settings: dict, temperature: float, max_tokens: int, streaming: bool):
        # 创建 DeepSeek LLM
        """创建 DeepSeek LLM"""
        # 导入 DeepSeek LLM 的类
        from langchain_deepseek import ChatDeepSeek

        # 获取模型名,优先用settings里的值,否则用默认配置
        model_name = settings.get('llm_model_name') or Config.DEEPSEEK_CHAT_MODEL
        # 获取API Key,优先用settings里的值,否则用默认配置
        api_key = settings.get('llm_api_key') or Config.DEEPSEEK_API_KEY
        # 获取Base URL,优先用settings里的值,否则用默认配置
        base_url = settings.get('llm_base_url') or Config.DEEPSEEK_BASE_URL

        # 实例化 DeepSeek LLM 对象
        llm = ChatDeepSeek(
            model=model_name,
            api_key=api_key,
            base_url=base_url,
            temperature=temperature,
            max_tokens=max_tokens,
            streaming=streaming
        )
        # 日志打印创建成功的信息
        logger.info(f"已创建 DeepSeek LLM: {model_name}")
        # 返回 DeepSeek LLM 实例
        return llm

    # OpenAI LLM 创建方法
    @classmethod
    def _create_openai(cls, settings: dict, temperature: float, max_tokens: int, streaming: bool):
        # 创建 OpenAI LLM
        """创建 OpenAI LLM"""
        # 导入OpenAI LLM的类
        from langchain_openai import ChatOpenAI

        # 从settings中获取API Key
        api_key = settings.get('llm_api_key')
        # 如果未设置API Key则报错
        if not api_key:
            raise ValueError("OpenAI API key is required")

        # 获取模型名称,优先用用户配置,否则用默认gpt-4o
        model_name = settings.get('llm_model_name') or 'gpt-4o'
        # 实例化OpenAI LLM对象
        llm = ChatOpenAI(
            model=model_name,
            api_key=api_key,
            temperature=temperature,
            max_tokens=max_tokens,
            streaming=streaming
        )
        # 日志打印创建成功的信息
        logger.info(f"已创建 OpenAI LLM: {model_name}")
        # 返回OpenAI LLM实例
        return llm

    # Ollama LLM 创建方法
    @classmethod
    def _create_ollama(cls, settings: dict, temperature: float, max_tokens: int, streaming: bool):
        # 创建 Ollama LLM
        """创建 Ollama LLM"""
        # 导入Ollama LLM相关类
        from langchain_community.chat_models import ChatOllama

        # 获取基础URL(优先用setting,否则用默认本地地址)
        base_url = settings.get('llm_base_url') or 'http://localhost:11434'
        # 获取模型名,优先用用户配置
        model_name = settings.get('llm_model_name') or 'llama2'

        # 实例化Ollama LLM对象
        llm = ChatOllama(
            model=model_name,
            base_url=base_url,
            temperature=temperature,
            num_predict=max_tokens
        )
        # 日志打印Ollama创建信息
        logger.info(f"已创建 Ollama LLM: {model_name}, 地址: {base_url}")
        # 返回 Ollama LLM 实例
        return llm


# 注册内置提供者(可选,用于统一管理)
def _register_builtin_providers():
    # 注册内置提供者
    """注册内置提供者"""
    # 注册deepseek provider
    LLMFactory.register_provider('deepseek', LLMFactory._create_deepseek)
    # 注册openai provider
    LLMFactory.register_provider('openai', LLMFactory._create_openai)
    # 注册ollama provider
    LLMFactory.register_provider('ollama', LLMFactory._create_ollama)

# 自动注册内置提供者
_register_builtin_providers()

4.3. chat.py #

app/blueprints/chat.py

# 聊天相关路由(视图 + API)
"""
聊天相关路由(视图 + API)
"""

# 导入 Flask 的 Blueprint 和模板渲染函数
+from flask import Blueprint, render_template, request, stream_with_context, Response
+import json
# 导入日志模块
+import logging
# 导入知识库服务,用于后续业务逻辑
from app.services.knowledgebase_service import kb_service
# 导入登录保护装饰器和获取当前用户辅助方法
+from app.utils.auth import login_required, get_current_user, api_login_required
# 导入自定义工具函数:成功响应、错误响应、获取分页参数、获取当前用户或错误、异常处理装饰器、检查所有权
+from app.blueprints.utils import (
+   success_response, error_response,
+   get_current_user_or_error, handle_api_error
+)
+from app.services.chat_service import chat_service

# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)

# 创建名为 'chat' 的蓝图对象
bp = Blueprint('chat', __name__)

# 注册 /chat 路由,访问该路由需要先登录
@bp.route('/chat')
@login_required
def chat_view():
    # 智能问答页面视图函数
    """智能问答页面"""
    # 渲染 chat.html 模板并传递空知识库列表
    return render_template('chat.html', knowledgebases=[])

# 注册 API 路由,处理聊天接口 POST 请求
+@bp.route('/api/v1/knowledgebases/chat', methods=['POST'])
+@api_login_required
+@handle_api_error
+def api_chat():
    # 普通聊天接口(不支持知识库,支持流式输出)
+   """普通聊天接口(不支持知识库,支持流式输出)"""
    # 获取当前用户和错误信息
+   current_user, err = get_current_user_or_error()
    # 如果有错误,直接返回错误响应
+   if err:
+       return err

    # 从请求体获取 JSON 数据
+   data = request.get_json()
    # 如果数据为空或不存在 'question' 字段,返回错误
+   if not data or 'question' not in data:
+       return error_response("question is required", 400)

    # 去除问题文本首尾空格
+   question = data['question'].strip()
    # 如果问题内容为空,返回错误
+   if not question:
+       return error_response("question cannot be empty", 400)

    # 获取 max_tokens 参数,默认 1000
+   max_tokens = int(data.get('max_tokens', 1000))
    # 限制最大和最小值在 1~10000 之间
+   max_tokens = max(1, min(max_tokens, 10000))  # 限制在 1-10000 之间

    # 声明用于流式输出的生成器
+   @stream_with_context
+   def generate():
+       try:
            # 用于缓存完整答案内容
+           full_answer = ''
            # 调用服务进行流式对话
+           for chunk in chat_service.chat_stream(
+               question=question,
+               temperature=None,  # 使用设置中的值
+               max_tokens=max_tokens
+           ):
                # 如果是内容块,则拼接内容到 full_answer
+               if chunk.get('type') == 'content':
+                   full_answer += chunk.get('content', '')
                # 以 SSE 协议格式输出数据
+               yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
            # 输出对话完成信号
+           yield "data: [DONE]\n\n"
+       except Exception as e:
            # 发生异常记录日志
+           logger.error(f"流式输出时出错: {e}")
            # 构造错误数据块
+           error_chunk = {
+               "type": "error",
+               "content": str(e)
+           }
            # 输出错误数据块
+           yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n"

    # 创建 Response 对象,设置必要的 SSE 响应头部
+   response = Response(
+       generate(),
+       mimetype='text/event-stream',
+       headers={
+           'Cache-Control': 'no-cache',
+           'Connection': 'keep-alive',
+           'X-Accel-Buffering': 'no',
+           'Content-Type': 'text/event-stream; charset=utf-8'
+       }
+   )
    # 返回响应
+   return response

4.4. chat.html #

app/templates/chat.html

{% extends "base.html" %}

{% block title %}智能问答 - RAG Lite{% endblock %}

{% block extra_css %}
<style>
    .chat-container {
        height: calc(100vh - 200px);
        display: flex;
        gap: 1rem;
    }
    .chat-sidebar {
        width: 280px;
        background: #fff;
        border-radius: 0.5rem;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        display: flex;
        flex-direction: column;
    }
    .chat-main {
        flex: 1;
        background: #fff;
        border-radius: 0.5rem;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        display: flex;
        flex-direction: column;
    }
    .session-list {
        flex: 1;
        overflow-y: auto;
        padding: 0.5rem;
    }
    .session-item {
        padding: 0.75rem;
        margin-bottom: 0.5rem;
        border-radius: 0.5rem;
        cursor: pointer;
        transition: background-color 0.2s;
        position: relative;
    }
    .session-item:hover {
        background-color: #f8f9fa;
    }
    .session-item.active {
        background-color: #e3f2fd;
        border-left: 3px solid #0d6efd;
    }
    .session-item .session-title {
        font-weight: 500;
        margin-bottom: 0.25rem;
        overflow: hidden;
        text-overflow: ellipsis;
        white-space: nowrap;
    }
    .session-item .session-time {
        font-size: 0.75rem;
        color: #6c757d;
    }
    .session-item .session-delete {
        position: absolute;
        top: 0.5rem;
        right: 0.5rem;
        opacity: 0;
        transition: opacity 0.2s;
    }
    .session-item:hover .session-delete {
        opacity: 1;
    }
    .chat-messages {
        flex: 1;
        overflow-y: auto;
        padding: 1rem;
        scroll-behavior: smooth;
    }
    .chat-message {
        padding: 1rem;
        margin-bottom: 1rem;
        border-radius: 0.5rem;
    }
    .chat-question {
        background-color: #e3f2fd;
    }
    .chat-answer {
        background-color: #f5f5f5;
    }
    .chat-input-area {
        padding: 1rem;
        border-top: 1px solid #dee2e6;
    }
    .empty-state {
        display: flex;
        flex-direction: column;
        align-items: center;
        justify-content: center;
        height: 100%;
        color: #6c757d;
    }
</style>
{% endblock %}

{% block content %}
<div class="chat-container">
    <!-- 左侧:会话管理 -->
    <div class="chat-sidebar">
        <div class="p-3 border-bottom">
            <div class="d-flex justify-content-between align-items-center mb-3">
                <h6 class="mb-0"><i class="bi bi-chat-left-text"></i> 聊天会话</h6>
                <button class="btn btn-sm btn-primary" onclick="createNewSession()">
                    <i class="bi bi-plus"></i> 新建
                </button>
            </div>
            <button class="btn btn-sm btn-outline-danger w-100" onclick="clearAllSessions()">
                <i class="bi bi-trash"></i> 清空所有
            </button>
        </div>
        <div class="session-list" id="sessionList">
            <div class="text-center text-muted py-5">
                <i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
                <p class="mt-2 small">暂无会话</p>
            </div>
        </div>
    </div>

    <!-- 右侧:对话页面 -->
    <div class="chat-main">
        <div class="p-3 border-bottom">
            <div class="d-flex justify-content-between align-items-center">
                <h6 class="mb-0"><i class="bi bi-chat-dots"></i> 对话</h6>
            </div>
        </div>

        <div class="chat-messages" id="chatMessages">
            <div class="empty-state">
                <i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
                <p class="mt-2">开始提问吧!</p>
            </div>
        </div>

        <div class="chat-input-area">
            <form id="chatForm" onsubmit="askQuestion(event)">
                <div class="mb-2">
                    <textarea class="form-control" id="questionInput" rows="2" 
                              placeholder="输入您的问题..." required></textarea>
                </div>
                <div class="d-flex justify-content-end">
                    <button type="submit" class="btn btn-primary" id="submitBtn">
                        <i class="bi bi-send"></i> 发送
                    </button>
                </div>
            </form>
        </div>
    </div>
</div>
{% endblock %}

{% block extra_js %}
<script>
+// 为字符串进行HTML转义,防止XSS攻击
+function escapeHtml(text) {
+   // 创建一个div元素作为容器
+   const div = document.createElement('div');
+   // 将待转义文本设置为div的textContent(自动完成转义)
+   div.textContent = text;
+   // 返回转义后的HTML内容
+   return div.innerHTML;
+}

+// 滚动消息框到底部
+function scrollToBottom() {
+   // 获取聊天消息区域元素
+   const chatMessages = document.getElementById('chatMessages');
+   // 设置滚动条位置到最底部
+   chatMessages.scrollTop = chatMessages.scrollHeight;
+}

+// 将Markdown内容渲染到指定元素(支持降级为普通文本)
+function renderMarkdownToElement(element, text) {
+   // 若未提供内容,显示思考中图标
+   if (!text) {
+       element.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';
+       return;
+   }
+   // 优先判断marked库是否可用(渲染markdown)
+   if (typeof marked !== 'undefined' && marked.parse) {
+       try {
+           // 使用marked进行markdown转html
+           element.innerHTML = marked.parse(text);
+       } catch (e) {
+           // 渲染失败则退化为转义+换行
+           element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
+       }
+   } else {
+       // 没有marked库则直接转义+换行显示
+       element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
+   }
+}

+// 主函数:处理用户提交问题事件
+async function askQuestion(event) {
+   // 阻止表单默认提交行为(防止页面刷新)
+   event.preventDefault();
+   // 获取输入框的用户问题并去除首尾空白
+   const question = document.getElementById('questionInput').value.trim();
+   // 获取消息显示区域元素
+   const chatMessages = document.getElementById('chatMessages');
+   // 若问题为空则直接返回
+   if (!question) return;
+   // 检查并移除初始空白提示(如有)
+   if (chatMessages.querySelector('.empty-state')) {
+       chatMessages.innerHTML = '';
+   }
+   // 创建用于展示问题的div元素
+   const questionDiv = document.createElement('div');
+   // 加上样式: 用户问题
+   questionDiv.className = 'chat-message chat-question';
+   // 构建用户气泡内容含图标、文本和时间
+   questionDiv.innerHTML = `
+       <div class="d-flex justify-content-between align-items-start">
+           <div class="flex-grow-1">
+               <strong><i class="bi bi-person-circle"></i> 问题:</strong>
+               <div class="mt-1">${escapeHtml(question)}</div>
+           </div>
+           <small class="text-muted">${new Date().toLocaleTimeString()}</small>
+       </div>
+   `;
+   // 显示到对话窗口
+   chatMessages.appendChild(questionDiv);

+   // 创建用于显示答案的div元素
+   const answerDiv = document.createElement('div');
+   // 答案样式
+   answerDiv.className = 'chat-message chat-answer';
+   // 动态生成唯一的答案内容div id(用于唯一标记)
+   const answerContentId = 'answerContent_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
+   // 创建存放答案内容的div
+   const answerContent = document.createElement('div');
+   // 为markdown渲染容器设置样式和id
+   answerContent.className = 'mt-2 markdown-content';
+   answerContent.id = answerContentId;
+   // 答案div显示机器人图标和时间
+   answerDiv.innerHTML = `
+       <div class="d-flex justify-content-between align-items-start">
+           <div class="flex-grow-1">
+               <strong><i class="bi bi-robot"></i> 答案:</strong>
+           </div>
+           <small class="text-muted">${new Date().toLocaleTimeString()}</small>
+       </div>
+   `;
+   // 将答案内容div插入到机器人气泡div的内容区
+   const flexGrowDiv = answerDiv.querySelector('.flex-grow-1');
+   flexGrowDiv.appendChild(answerContent);
+   // 插入答案div到消息区域
+   chatMessages.appendChild(answerDiv);

+   // 变量记录完整的回答内容
+   let fullAnswer = '';
+   // 标记渲染任务是否挂起(防止重复)
+   let pendingUpdate = false;
+   // 记录定时器id(去抖动用)
+   let updateTimer = null;

+   // 清空输入框内容
+   document.getElementById('questionInput').value = '';
+   // 滚动到底
+   scrollToBottom();

+   // 定义scheduleRender:将markdown渲染插入到队列合适时机执行
+   function scheduleRender() {
+       // 若当前无待渲染任务才安排渲染
+       if (!pendingUpdate) {
+           pendingUpdate = true;
+           // 下一帧渲染
+           requestAnimationFrame(() => {
+               // 将答案作为markdown渲染进dom
+               renderMarkdownToElement(answerContent, fullAnswer);
+               // 清理pending标识
+               pendingUpdate = false;
+               // 渲染后滚动到底部
+               scrollToBottom();
+           });
+       }
+   }

+   try {
+       // 组装API接口地址
+       const url = `/api/v1/knowledgebases/chat`;
+       // 请求后端,发起流式POST请求
+       const response = await fetch(url, {
+           method: 'POST',
+           headers: {'Content-Type': 'application/json'},
+           body: JSON.stringify({
+               question: question,
+               stream: true
+           })
+       });
+       // 非200响应时,抛出异常
+       if (!response.ok) throw new Error('请求失败');
+       // 使用ReadableStream reader获取response流
+       const reader = response.body.getReader();
+       // 用TextDecoder解码二进制数据
+       const decoder = new TextDecoder();
+       // 数据缓冲区(字符串)
+       let buffer = '';
+       // 初始展示“思考中...”提示
+       answerContent.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';

+       // 不断循环读取服务端推送流内容
+       while (true) {
+           // 逐步读取一段数据
+           const { done, value } = await reader.read();
+           // 若读完则结束循环
+           if (done) break;
+           // 本块新数据解码成字符串,并追加到buffer
+           buffer += decoder.decode(value, { stream: true });
+           // 按行切分(多行分别处理)
+           const lines = buffer.split('\n');
+           // 最后一行一般为数据残留,不处理,下次拼
+           buffer = lines.pop() || '';
+           // 处理每一行数据
+           for (const line of lines) {
+               // 筛选有效SSE协议"data: "数据包
+               if (line.startsWith('data: ')) {
+                   // 去掉前缀,留下json内容
+                   const data = line.slice(6);
+                   // [DONE]信号作为流结束,仅跳过
+                   if (data === '[DONE]') continue;
+                   try {
+                       // 解析JSON数据
+                       const chunk = JSON.parse(data);
+                       // 类型为流起始,清空内容
+                       if (chunk.type === 'start') {
+                           fullAnswer = '';
+                           answerContent.innerHTML = '';
+                       // 类型为内容,追加文本并安排去抖渲染
+                       } else if (chunk.type === 'content') {
+                           fullAnswer += chunk.content;
+                           scheduleRender();
+                       // 类型为done(流传输结束),直接最终渲染所有答案内容
+                       } else if (chunk.type === 'done') {
+                           renderMarkdownToElement(answerContent, fullAnswer);
+                       // 错误类型,alert显示报错内容
+                       } else if (chunk.type === 'error') {
+                           answerContent.innerHTML = `<div class="alert alert-danger">${chunk.content}</div>`;
+                       }
+                   } catch (e) {
+                       // JSON解析失败时报console
+                       console.error('解析流数据失败:', e);
+                   }
+               }
+           }
+       }
+   } catch (error) {
+       // 通信异常、Fetch错误等: 显示错误气泡
+       answerContent.innerHTML = `<div class="alert alert-danger"><strong>错误:</strong> ${error.message}</div>`;
+   }
+   // 结束后确保界面滚动到底
+   scrollToBottom();
+}

</script>
{% endblock %}

5.会话管理 #

5.1. chat_session_service.py #

app/services/chat_session_service.py

# 导入倒序排序工具
from sqlalchemy import desc
# 导入日期时间
from datetime import datetime
# 导入聊天会话ORM模型
from app.models.chat_session import ChatSession
# 导入聊天消息ORM模型
from app.models.chat_message import ChatMessage
# 导入基础服务类
from app.services.base_service import BaseService

# 聊天会话服务类,继承自基础服务
class ChatSessionService(BaseService[ChatSession]):
    # 聊天会话服务说明文档
    """聊天会话服务"""

    # 创建新的聊天会话
    def create_session(self, user_id: str, kb_id: str = None, title: str = None) -> dict:
        """
        创建新的聊天会话

        Args:
            user_id: 用户ID
            kb_id: 知识库ID(可选)
            title: 会话标题(可选,如果不提供则使用默认标题)

        Returns:
            会话信息字典
        """
        # 启动数据库事务
        with self.transaction() as session:
            # 如果没有传标题则用默认标题
            if not title:
                title = "新对话"
            # 构造会话对象
            chat_session = ChatSession(
                user_id=user_id,
                title=title
            )
            # 新会话入库
            session.add(chat_session)
            # 刷新以拿到自增ID
            session.flush()
            # 刷新会话对象,便于获取ID等字段
            session.refresh(chat_session)
            # 记录日志
            self.logger.info(f"已创建聊天会话: {chat_session.id}, 用户: {user_id}")
            # 返回会话字典格式
            return chat_session.to_dict()

    # 根据ID获取会话
    def get_session_by_id(self, session_id: str, user_id: str = None) -> dict:
        """
        根据ID获取会话

        Args:
            session_id: 会话ID
            user_id: 用户ID(可选,用于验证权限)

        Returns:
            会话信息字典,如果不存在或无权访问则返回 None
        """
        # 打开数据库只读session
        with self.session() as session:
            # 查询指定ID的会话
            query = session.query(ChatSession).filter_by(id=session_id)
            # 如果提供了user_id则额外限定归属
            if user_id:
                query = query.filter_by(user_id=user_id)
            # 拿到第一个会话记录
            chat_session = query.first()
            # 有则返回字典信息,没有返回None
            if chat_session:
                return chat_session.to_dict()
            return None

    # 获取用户的所有会话列表(分页)
    def list_sessions(self, user_id: str, page: int = 1, page_size: int = 100) -> dict:
        """
        获取用户的会话列表

        Args:
            user_id: 用户ID
            page: 页码
            page_size: 每页数量

        Returns:
            包含总数和会话列表的字典
        """
        # 打开数据库只读session
        with self.session() as session:
            # 查询当前用户的所有会话
            query = session.query(ChatSession).filter_by(user_id=user_id)
            # 用基类的分页方法返回结构化内容,按更新时间倒序
            return self.paginate_query(query, page=page, page_size=page_size,
                                      order_by=desc(ChatSession.updated_at))

    # 会话标题修改
    def update_session_title(self, session_id: str, user_id: str, title: str) -> dict:
        """
        更新会话标题

        Args:
            session_id: 会话ID
            user_id: 用户ID(用于验证权限)
            title: 新标题

        Returns:
            更新后的会话信息字典
        """
        # 启动数据库事务
        with self.transaction() as session:
            # 查询拥有该会话的用户
            chat_session = session.query(ChatSession).filter_by(id=session_id, user_id=user_id).first()
            # 会话不存在则抛异常
            if not chat_session:
                raise ValueError("Session not found or access denied")
            # 更新标题
            chat_session.title = title
            # 刷新会话对象(update不会提交,refresh刷新对象)
            session.refresh(chat_session)
            # 返回最新数据
            return chat_session.to_dict()

    # 删除指定会话
    def delete_session(self, session_id: str, user_id: str) -> bool:
        """
        删除会话(级联删除消息)

        Args:
            session_id: 会话ID
            user_id: 用户ID(用于验证权限)

        Returns:
            是否删除成功
        """
        # 开启数据库事务
        with self.transaction() as session:
            # 查询该用户的指定会话
            chat_session = session.query(ChatSession).filter_by(id=session_id, user_id=user_id).first()
            # 未找到会话则返回False
            if not chat_session:
                return False
            # 删除会话(DB应有级联消息)
            session.delete(chat_session)
            # 记录日志
            self.logger.info(f"已删除聊天会话: {session_id}")
            # 返回True表示删除成功
            return True

    # 删除当前用户的所有会话
    def delete_all_sessions(self, user_id: str) -> int:
        """
        删除用户的所有会话

        Args:
            user_id: 用户ID

        Returns:
            删除的会话数量
        """
        # 开启数据库事务
        with self.transaction() as session:
            # 批量删除本用户所有会话
            count = session.query(ChatSession).filter_by(user_id=user_id).delete()
            # 记录日志
            self.logger.info(f"已删除用户 {user_id} 的 {count} 个聊天会话")
            # 返回删除数量
            return count

    # 添加消息到会话
    def add_message(self, session_id: str, role: str, content: str) -> dict:
        """
        添加消息到会话

        Args:
            session_id: 会话ID
            role: 角色('user' 或 'assistant')
            content: 消息内容
            sources: 引用来源列表(可选)

        Returns:
            消息信息字典
        """
        # 开启数据库事务
        with self.transaction() as session:
            # 构造消息对象
            message = ChatMessage(
                session_id=session_id,
                role=role,
                content=content
            )
            # 添加消息到数据库
            session.add(message)
            # 查询会话对象,用于更新时间/自动生成标题
            chat_session = session.query(ChatSession).filter_by(id=session_id).first()
            # 如果存在会话对象
            if chat_session:
                # 更新会话更新时间
                chat_session.updated_at = datetime.now()
                # 如果是用户发的第一条消息,并且还没标题,则用内容自动命名
                if role == 'user' and (not chat_session.title or chat_session.title == "新对话"):
                    # 会话标题截取前30字符,超长加省略号
                    title = content[:30] + ('...' if len(content) > 30 else '')
                    chat_session.title = title
            # 刷新确保message有ID
            session.flush()
            # 刷新消息对象
            session.refresh(message)
            # 返回消息字典
            return message.to_dict()

    # 获取会话的全部消息
    def get_messages(self, session_id: str, user_id: str = None) -> list:
        """
        获取会话的所有消息

        Args:
            session_id: 会话ID
            user_id: 用户ID(可选,用于验证权限)

        Returns:
            消息列表
        """
        # 打开只读session
        with self.session() as session:
            # 如指定user_id,须先验证此会话是否属于该用户,不属则不给查
            if user_id:
                chat_session = session.query(ChatSession).filter_by(id=session_id, user_id=user_id).first()
                # 如果会话不存在则返回空列表
                if not chat_session:
                    return []
            # 查询该会话下所有消息,按创建时间升序排序
            messages = session.query(ChatMessage).filter_by(session_id=session_id).order_by(ChatMessage.created_at).all()
            # 返回所有消息的字典列表
            return [m.to_dict() for m in messages]

# 单例: 聊天会话服务对象
session_service = ChatSessionService()

5.2. chat.py #

app/blueprints/chat.py

# 聊天相关路由(视图 + API)
"""
聊天相关路由(视图 + API)
"""

# 导入 Flask 的 Blueprint 和模板渲染函数
from flask import Blueprint, render_template, request, stream_with_context, Response
import json
# 导入日志模块
import logging
# 导入登录保护装饰器和获取当前用户辅助方法
+from app.utils.auth import login_required, api_login_required
# 导入自定义工具函数:成功响应、错误响应、获取分页参数、获取当前用户或错误、异常处理装饰器、检查所有权
from app.blueprints.utils import (
    success_response, error_response,
+   get_current_user_or_error, handle_api_error, get_pagination_params
)
from app.services.chat_service import chat_service
+from app.services.chat_session_service import session_service

# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)

# 创建名为 'chat' 的蓝图对象
bp = Blueprint('chat', __name__)

# 注册 /chat 路由,访问该路由需要先登录
@bp.route('/chat')
@login_required
def chat_view():
    # 智能问答页面视图函数
    """智能问答页面"""
    # 渲染 chat.html 模板并传递空知识库列表
    return render_template('chat.html', knowledgebases=[])

# 注册 API 路由,处理聊天接口 POST 请求
@bp.route('/api/v1/knowledgebases/chat', methods=['POST'])
@api_login_required
@handle_api_error
def api_chat():
    # 普通聊天接口(不支持知识库,支持流式输出)
    """普通聊天接口(不支持知识库,支持流式输出)"""
    # 获取当前用户和错误信息
    current_user, err = get_current_user_or_error()
    # 如果有错误,直接返回错误响应
    if err:
        return err

    # 从请求体获取 JSON 数据
    data = request.get_json()
    # 如果数据为空或不存在 'question' 字段,返回错误
    if not data or 'question' not in data:
        return error_response("question is required", 400)

    # 去除问题文本首尾空格
    question = data['question'].strip()
    # 如果问题内容为空,返回错误
    if not question:
        return error_response("question cannot be empty", 400)
+   session_id = data.get('session_id')  # 会话ID(可以为空,表示普通聊天)
    # 获取 max_tokens 参数,默认 1000
    max_tokens = int(data.get('max_tokens', 1000))
    # 限制最大和最小值在 1~10000 之间
    max_tokens = max(1, min(max_tokens, 10000))  # 限制在 1-10000 之间
    # 从请求数据中获取'stream'字段,默认为True,表示启用流式输出
+   stream = data.get('stream', True)  # 默认启用流式输出

    # 初始化历史消息为None
+   history = None
    # 如果请求中带有session_id,说明有现有会话
+   if session_id:
        # 根据session_id和当前用户ID获取历史消息列表
+       history_messages = session_service.get_messages(session_id, current_user['id'])
        # 将历史消息转换为对话格式,仅保留最近10条
+       history = [
+           {'role': msg.get('role'), 'content': msg.get('content')}
+           for msg in history_messages[-10:]  # 只取最近10条
+       ]

    # 如果请求中没有session_id,说明是新对话,需要新建会话
+   if not session_id:
        # 创建新会话,kb_id设为None表示普通聊天
+       chat_session = session_service.create_session(
+           user_id=current_user['id']
+       )
        # 使用新创建会话的ID作为本次会话ID
+       session_id = chat_session['id']

    # 将用户的问题消息保存到当前会话中
+   session_service.add_message(session_id, 'user', question)

    # 声明用于流式输出的生成器
    @stream_with_context
    def generate():
        try:
            # 用于缓存完整答案内容
            full_answer = ''
            # 调用服务进行流式对话
            for chunk in chat_service.chat_stream(
                question=question,
                temperature=None,  # 使用设置中的值
+               max_tokens=max_tokens,
+               history=history
            ):
                # 如果是内容块,则拼接内容到 full_answer
                if chunk.get('type') == 'content':
                    full_answer += chunk.get('content', '')
                # 以 SSE 协议格式输出数据
                yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
            # 输出对话完成信号
            yield "data: [DONE]\n\n"
            # 保存助手回复
+           if full_answer:
+               session_service.add_message(session_id, 'assistant', full_answer)
        except Exception as e:
            # 发生异常记录日志
            logger.error(f"流式输出时出错: {e}")
            # 构造错误数据块
            error_chunk = {
                "type": "error",
                "content": str(e)
            }
            # 输出错误数据块
            yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n"

    # 创建 Response 对象,设置必要的 SSE 响应头部
    response = Response(
        generate(),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive',
            'X-Accel-Buffering': 'no',
            'Content-Type': 'text/event-stream; charset=utf-8'
        }
    )
    # 返回响应
    return response
# 路由装饰器,定义 GET 方法获取会话列表的接口
+@bp.route('/api/v1/knowledgebases/sessions', methods=['GET'])
# API 登录校验装饰器,确保用户已登录
+@api_login_required
# 错误处理装饰器,统一处理接口异常
+@handle_api_error
+def api_list_sessions():
    # 接口描述:获取当前用户的会话列表
+   """获取当前用户的会话列表"""
    # 获取当前用户,如有错误直接返回错误响应
+   current_user, err = get_current_user_or_error()
+   if err:
+       return err

    # 获取分页参数(页码和每页数量),最大单页1000
+   page, page_size = get_pagination_params(max_page_size=1000)
    # 调用会话服务获取当前用户的会话列表
+   result = session_service.list_sessions(current_user['id'], page=page, page_size=page_size)
    # 以统一成功响应格式返回会话列表
+   return success_response(result)    


# 路由装饰器,定义 POST 方法创建会话的接口
+@bp.route('/api/v1/knowledgebases/sessions', methods=['POST'])
+@api_login_required
+@handle_api_error
+def api_create_session():
    # 接口描述:创建新的聊天会话
+   """创建新的聊天会话"""
    # 获取当前用户,如果有错误直接返回
+   current_user, err = get_current_user_or_error()
+   if err:
+       return err

    # 获取请求体中的 JSON 数据,若无返回空字典
+   data = request.get_json() or {}
    # 获取会话标题
+   title = data.get('title')

    # 调用服务创建会话,传入当前用户ID、知识库ID与标题
+   session_obj = session_service.create_session(
+       user_id=current_user['id'],
+       title=title
+   )
    # 返回成功响应及会话对象
+   return success_response(session_obj)


# 路由装饰器,定义 GET 方法获取单个会话详情的接口(带 session_id)
+@bp.route('/api/v1/knowledgebases/sessions/<session_id>', methods=['GET'])
+@api_login_required
+@handle_api_error
+def api_get_session(session_id):
    # 接口描述:获取会话详情和消息
+   """获取会话详情和消息"""
    # 获取当前用户,如有错误直接返回
+   current_user, err = get_current_user_or_error()
+   if err:
+       return err

    # 根据 session_id 获取会话对象,校验所属当前用户
+   session_obj = session_service.get_session_by_id(session_id, current_user['id'])
    # 如果没有找到会话,返回 404 错误
+   if not session_obj:
+       return error_response("Session not found", 404)

    # 获取该会话下的所有消息
+   messages = session_service.get_messages(session_id, current_user['id'])

    # 返回会话详情及消息列表
+   return success_response({
+       'session': session_obj,
+       'messages': messages
+   })


# 路由装饰器,定义 DELETE 方法删除单个会话接口
+@bp.route('/api/v1/knowledgebases/sessions/<session_id>', methods=['DELETE'])
+@api_login_required
+@handle_api_error
+def api_delete_session(session_id):
    # 接口描述:删除会话
+   """删除会话"""
    # 获取当前用户,如有错误直接返回
+   current_user, err = get_current_user_or_error()
+   if err:
+       return err

    # 调用服务删除会话,校验归属当前用户
+   success = session_service.delete_session(session_id, current_user['id'])
    # 若删除成功,返回成功响应,否则返回 404
+   if success:
+       return success_response(None, "Session deleted")
+   else:
+       return error_response("Session not found", 404)


# 路由装饰器,定义 DELETE 方法清空所有会话的接口
+@bp.route('/api/v1/knowledgebases/sessions', methods=['DELETE'])
+@api_login_required
+@handle_api_error
+def api_delete_all_sessions():
    # 接口描述:清空所有会话
+   """清空所有会话"""
    # 获取当前用户,如果有错误直接返回
+   current_user, err = get_current_user_or_error()
+   if err:
+       return err

    # 调用服务删除所有属于当前用户的会话,返回删除数量
+   count = session_service.delete_all_sessions(current_user['id'])
    # 返回成功响应及被删除会话数
+   return success_response({'deleted_count': count}, f"Deleted {count} sessions")

5.3. chat.html #

app/templates/chat.html

{% extends "base.html" %}

{% block title %}智能问答 - RAG Lite{% endblock %}

{% block extra_css %}
<style>
    .chat-container {
        height: calc(100vh - 200px);
        display: flex;
        gap: 1rem;
    }
    .chat-sidebar {
        width: 280px;
        background: #fff;
        border-radius: 0.5rem;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        display: flex;
        flex-direction: column;
    }
    .chat-main {
        flex: 1;
        background: #fff;
        border-radius: 0.5rem;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        display: flex;
        flex-direction: column;
    }
    .session-list {
        flex: 1;
        overflow-y: auto;
        padding: 0.5rem;
    }
    .session-item {
        padding: 0.75rem;
        margin-bottom: 0.5rem;
        border-radius: 0.5rem;
        cursor: pointer;
        transition: background-color 0.2s;
        position: relative;
    }
    .session-item:hover {
        background-color: #f8f9fa;
    }
    .session-item.active {
        background-color: #e3f2fd;
        border-left: 3px solid #0d6efd;
    }
    .session-item .session-title {
        font-weight: 500;
        margin-bottom: 0.25rem;
        overflow: hidden;
        text-overflow: ellipsis;
        white-space: nowrap;
    }
    .session-item .session-time {
        font-size: 0.75rem;
        color: #6c757d;
    }
    .session-item .session-delete {
        position: absolute;
        top: 0.5rem;
        right: 0.5rem;
        opacity: 0;
        transition: opacity 0.2s;
    }
    .session-item:hover .session-delete {
        opacity: 1;
    }
    .chat-messages {
        flex: 1;
        overflow-y: auto;
        padding: 1rem;
        scroll-behavior: smooth;
    }
    .chat-message {
        padding: 1rem;
        margin-bottom: 1rem;
        border-radius: 0.5rem;
    }
    .chat-question {
        background-color: #e3f2fd;
    }
    .chat-answer {
        background-color: #f5f5f5;
    }
    .chat-input-area {
        padding: 1rem;
        border-top: 1px solid #dee2e6;
    }
    .empty-state {
        display: flex;
        flex-direction: column;
        align-items: center;
        justify-content: center;
        height: 100%;
        color: #6c757d;
    }
</style>
{% endblock %}

{% block content %}
<div class="chat-container">
    <!-- 左侧:会话管理 -->
    <div class="chat-sidebar">
        <div class="p-3 border-bottom">
            <div class="d-flex justify-content-between align-items-center mb-3">
                <h6 class="mb-0"><i class="bi bi-chat-left-text"></i> 聊天会话</h6>
                <button class="btn btn-sm btn-primary" onclick="createNewSession()">
                    <i class="bi bi-plus"></i> 新建
                </button>
            </div>
            <button class="btn btn-sm btn-outline-danger w-100" onclick="clearAllSessions()">
                <i class="bi bi-trash"></i> 清空所有
            </button>
        </div>
        <div class="session-list" id="sessionList">
            <div class="text-center text-muted py-5">
                <i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
                <p class="mt-2 small">暂无会话</p>
            </div>
        </div>
    </div>

    <!-- 右侧:对话页面 -->
    <div class="chat-main">
        <div class="p-3 border-bottom">
            <div class="d-flex justify-content-between align-items-center">
                <h6 class="mb-0"><i class="bi bi-chat-dots"></i> 对话</h6>
            </div>
        </div>

        <div class="chat-messages" id="chatMessages">
            <div class="empty-state">
                <i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
                <p class="mt-2">开始提问吧!</p>
            </div>
        </div>

        <div class="chat-input-area">
            <form id="chatForm" onsubmit="askQuestion(event)">
                <div class="mb-2">
                    <textarea class="form-control" id="questionInput" rows="2" 
                              placeholder="输入您的问题..." required></textarea>
                </div>
                <div class="d-flex justify-content-end">
                    <button type="submit" class="btn btn-primary" id="submitBtn">
                        <i class="bi bi-send"></i> 发送
                    </button>
                </div>
            </form>
        </div>
    </div>
</div>
{% endblock %}

{% block extra_js %}
<script>
+// 当前会话的ID,初始为null,表示暂未选择任何会话
+let currentSessionId = null;
+// 会话列表,初始为空数组,用于存储所有的会话对象
+let sessions = [];
// 为字符串进行HTML转义,防止XSS攻击
function escapeHtml(text) {
    // 创建一个div元素作为容器
    const div = document.createElement('div');
    // 将待转义文本设置为div的textContent(自动完成转义)
    div.textContent = text;
    // 返回转义后的HTML内容
    return div.innerHTML;
}

// 滚动消息框到底部
function scrollToBottom() {
    // 获取聊天消息区域元素
    const chatMessages = document.getElementById('chatMessages');
    // 设置滚动条位置到最底部
    chatMessages.scrollTop = chatMessages.scrollHeight;
}
+// 定义一个用于渲染Markdown文本为HTML的函数
+function renderMarkdown(text) {
+   // 判断marked库是否已加载且能使用解析方法
+   if (typeof marked !== 'undefined' && marked.parse) {
+       try {
+           // 尝试用marked库将Markdown文本解析成HTML
+           return marked.parse(text);
+       } catch (e) {
+           // 如果解析出错,则进行HTML转义并换行
+           return escapeHtml(text).replace(/\n/g, '<br>');
+       }
+   }
+   // 如果marked库不可用,直接做HTML转义并处理换行
+   return escapeHtml(text).replace(/\n/g, '<br>');
+}
// 将Markdown内容渲染到指定元素(支持降级为普通文本)
function renderMarkdownToElement(element, text) {
    // 若未提供内容,显示思考中图标
    if (!text) {
        element.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';
        return;
    }
    // 优先判断marked库是否可用(渲染markdown)
    if (typeof marked !== 'undefined' && marked.parse) {
        try {
            // 使用marked进行markdown转html
            element.innerHTML = marked.parse(text);
        } catch (e) {
            // 渲染失败则退化为转义+换行
            element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
        }
    } else {
        // 没有marked库则直接转义+换行显示
        element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
    }
}

// 主函数:处理用户提交问题事件
async function askQuestion(event) {
    // 阻止表单默认提交行为(防止页面刷新)
    event.preventDefault();
    // 获取输入框的用户问题并去除首尾空白
    const question = document.getElementById('questionInput').value.trim();
    // 获取消息显示区域元素
    const chatMessages = document.getElementById('chatMessages');
    // 若问题为空则直接返回
    if (!question) return;
+    // 如果没有会话,创建新会话
+    if (!currentSessionId) {
+       await createNewSession();
+   }
    // 检查并移除初始空白提示(如有)
    if (chatMessages.querySelector('.empty-state')) {
        chatMessages.innerHTML = '';
    }
    // 创建用于展示问题的div元素
    const questionDiv = document.createElement('div');
    // 加上样式: 用户问题
    questionDiv.className = 'chat-message chat-question';
    // 构建用户气泡内容含图标、文本和时间
    questionDiv.innerHTML = `
        <div class="d-flex justify-content-between align-items-start">
            <div class="flex-grow-1">
                <strong><i class="bi bi-person-circle"></i> 问题:</strong>
                <div class="mt-1">${escapeHtml(question)}</div>
            </div>
            <small class="text-muted">${new Date().toLocaleTimeString()}</small>
        </div>
    `;
    // 显示到对话窗口
    chatMessages.appendChild(questionDiv);

    // 创建用于显示答案的div元素
    const answerDiv = document.createElement('div');
    // 答案样式
    answerDiv.className = 'chat-message chat-answer';
    // 动态生成唯一的答案内容div id(用于唯一标记)
    const answerContentId = 'answerContent_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
    // 创建存放答案内容的div
    const answerContent = document.createElement('div');
    // 为markdown渲染容器设置样式和id
    answerContent.className = 'mt-2 markdown-content';
    answerContent.id = answerContentId;
    // 答案div显示机器人图标和时间
    answerDiv.innerHTML = `
        <div class="d-flex justify-content-between align-items-start">
            <div class="flex-grow-1">
                <strong><i class="bi bi-robot"></i> 答案:</strong>
            </div>
            <small class="text-muted">${new Date().toLocaleTimeString()}</small>
        </div>
    `;
    // 将答案内容div插入到机器人气泡div的内容区
    const flexGrowDiv = answerDiv.querySelector('.flex-grow-1');
    flexGrowDiv.appendChild(answerContent);
    // 插入答案div到消息区域
    chatMessages.appendChild(answerDiv);

    // 变量记录完整的回答内容
    let fullAnswer = '';
    // 标记渲染任务是否挂起(防止重复)
    let pendingUpdate = false;
    // 记录定时器id(去抖动用)
    let updateTimer = null;

    // 清空输入框内容
    document.getElementById('questionInput').value = '';
    // 滚动到底
    scrollToBottom();

    // 定义scheduleRender:将markdown渲染插入到队列合适时机执行
    function scheduleRender() {
        // 若当前无待渲染任务才安排渲染
        if (!pendingUpdate) {
            pendingUpdate = true;
            // 下一帧渲染
            requestAnimationFrame(() => {
                // 将答案作为markdown渲染进dom
                renderMarkdownToElement(answerContent, fullAnswer);
                // 清理pending标识
                pendingUpdate = false;
                // 渲染后滚动到底部
                scrollToBottom();
            });
        }
    }


    try {
        // 组装API接口地址
        const url = `/api/v1/knowledgebases/chat`;
        // 请求后端,发起流式POST请求
        const response = await fetch(url, {
            method: 'POST',
            headers: {'Content-Type': 'application/json'},
            body: JSON.stringify({
                question: question,
+               session_id: currentSessionId,
                stream: true
            })
        });
        // 非200响应时,抛出异常
        if (!response.ok) throw new Error('请求失败');
        // 使用ReadableStream reader获取response流
        const reader = response.body.getReader();
        // 用TextDecoder解码二进制数据
        const decoder = new TextDecoder();
        // 数据缓冲区(字符串)
        let buffer = '';
        // 初始展示“思考中...”提示
        answerContent.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';
+       // 重新加载会话列表以更新标题
+       await loadSessions();
        // 不断循环读取服务端推送流内容
        while (true) {
            // 逐步读取一段数据
            const { done, value } = await reader.read();
            // 若读完则结束循环
            if (done) break;
            // 本块新数据解码成字符串,并追加到buffer
            buffer += decoder.decode(value, { stream: true });
            // 按行切分(多行分别处理)
            const lines = buffer.split('\n');
            // 最后一行一般为数据残留,不处理,下次拼
            buffer = lines.pop() || '';
            // 处理每一行数据
            for (const line of lines) {
                // 筛选有效SSE协议"data: "数据包
                if (line.startsWith('data: ')) {
                    // 去掉前缀,留下json内容
                    const data = line.slice(6);
                    // [DONE]信号作为流结束,仅跳过
                    if (data === '[DONE]') continue;
                    try {
                        // 解析JSON数据
                        const chunk = JSON.parse(data);
                        // 类型为流起始,清空内容
                        if (chunk.type === 'start') {
                            fullAnswer = '';
                            answerContent.innerHTML = '';
                        // 类型为内容,追加文本并安排去抖渲染
                        } else if (chunk.type === 'content') {
                            fullAnswer += chunk.content;
                            scheduleRender();
                        // 类型为done(流传输结束),直接最终渲染所有答案内容
                        } else if (chunk.type === 'done') {
                            renderMarkdownToElement(answerContent, fullAnswer);
                        // 错误类型,alert显示报错内容
                        } else if (chunk.type === 'error') {
                            answerContent.innerHTML = `<div class="alert alert-danger">${chunk.content}</div>`;
                        }
                    } catch (e) {
                        // JSON解析失败时报console
                        console.error('解析流数据失败:', e);
                    }
                }
            }

        }
    } catch (error) {
        // 通信异常、Fetch错误等: 显示错误气泡
        answerContent.innerHTML = `<div class="alert alert-danger"><strong>错误:</strong> ${error.message}</div>`;
    }
    // 结束后确保界面滚动到底
    scrollToBottom();
}
+// 清空所有会话
+// 异步函数,用于清空所有会话
+async function clearAllSessions() {
+   // 弹窗确认操作,若取消则直接返回
+   if (!confirm('确定要清空所有会话吗?此操作不可恢复!')) return;

+   try {
+       // 发送DELETE请求到服务器,删除所有会话
+       const response = await fetch('/api/v1/knowledgebases/sessions', {
+           method: 'DELETE'
+       });
+       // 获取服务器返回的JSON结果
+       const result = await response.json();
+       // 如果接口返回成功
+       if (result.code === 200) {
+           // 当前会话ID置空
+           currentSessionId = null;
+           // 清空聊天消息内容
+           clearChatMessages();
+           // 重新加载会话列表
+           await loadSessions();
+       }
+   } catch (error) {
+       // 捕获异常并弹窗提示失败原因
+       alert('清空会话失败: ' + error.message);
+   }
+}
+   // 删除会话
+   // 异步函数,用于删除指定的会话
+   async function deleteSession(sessionId) {
+       // 弹窗确认操作,若取消则直接返回
+       if (!confirm('确定要删除这个会话吗?此操作不可恢复!')) return;
+       try {
+           // 发送DELETE请求到服务器,删除指定会话
+           const response = await fetch(`/api/v1/knowledgebases/sessions/${sessionId}`, {
+               method: 'DELETE'
+           });
+           // 获取服务器返回的JSON结果
+           const result = await response.json();
+           // 如果接口返回成功
+           if (result.code === 200) {
+               // 如果删除的是当前会话,清空当前会话ID和聊天消息
+               if (currentSessionId === sessionId) {
+                   currentSessionId = null;
+                   clearChatMessages();
+               }
+               // 重新加载会话列表
+               await loadSessions();
+           } else {
+               // 如果删除失败,显示错误信息
+               alert('删除会话失败: ' + (result.message || '未知错误'));
+           }
+       } catch (error) {
+           // 捕获异常并弹窗提示失败原因
+           alert('删除会话失败: ' + error.message);
+       }
+   }
+// 格式化时间
+// 用于格式化时间字符串为友好显示
+function formatTime(timeStr) {
+   // 如果无有效时间,直接返回空字符串
+   if (!timeStr) return '';
+   // 将字符串转换为Date对象
+   const date = new Date(timeStr);
+   // 获取当前时间
+   const now = new Date();
+   // 计算时间差(毫秒)
+   const diff = now - date;
+   // 换算为分钟数
+   const minutes = Math.floor(diff / 60000);
+   // 换算为小时数
+   const hours = Math.floor(diff / 3600000);
+   // 换算为天数
+   const days = Math.floor(diff / 86400000);
+   // 1分钟内显示“刚刚”
+   if (minutes < 1) return '刚刚';
+   // 1小时内显示“x分钟前”
+   if (minutes < 60) return `${minutes}分钟前`;
+   // 24小时内显示“x小时前”
+   if (hours < 24) return `${hours}小时前`;
+   // 7天内显示“x天前”
+   if (days < 7) return `${days}天前`;
+   // 超过7天显示具体日期
+   return date.toLocaleDateString();
+}

+// 渲染消息
+function renderMessages(messages) {
+   const chatMessages = document.getElementById('chatMessages');

+   if (messages.length === 0) {
+       chatMessages.innerHTML = `
+           <div class="empty-state">
+               <i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
+               <p class="mt-2">开始提问吧!</p>
+           </div>
+       `;
+       return;
+   }

+   chatMessages.innerHTML = messages.map(msg => {
+       if (msg.role === 'user') {
+           return `
+               <div class="chat-message chat-question">
+                   <div class="d-flex justify-content-between align-items-start">
+                       <div class="flex-grow-1">
+                           <strong><i class="bi bi-person-circle"></i> 问题:</strong>
+                           <div class="mt-1">${escapeHtml(msg.content)}</div>
+                       </div>
+                       <small class="text-muted">${formatTime(msg.created_at)}</small>
+                   </div>
+               </div>
+           `;
+       } else {
+           return `
+               <div class="chat-message chat-answer">
+                   <div class="d-flex justify-content-between align-items-start">
+                       <div class="flex-grow-1">
+                           <strong><i class="bi bi-robot"></i> 答案:</strong>
+                           <div class="mt-2 markdown-content">${renderMarkdown(msg.content)}</div>
+                       </div>
+                       <small class="text-muted">${formatTime(msg.created_at)}</small>
+                   </div>
+               </div>
+           `;
+       }
+   }).join('');

+   scrollToBottom();
+}
+// 加载会话
+// 定义一个异步函数 loadSession,用于加载指定 sessionId 的会话
+async function loadSession(sessionId) {
+   // 异常捕获,防止请求或处理过程抛错
+   try {
+       // 发送 GET 请求,获取指定会话的详情数据
+       const response = await fetch(`/api/v1/knowledgebases/sessions/${sessionId}`);
+       // 将返回的响应数据解析为 JSON 格式
+       const result = await response.json();

+       // 如果接口请求成功,即 code 等于 200
+       if (result.code === 200) {
+           // 设置当前会话 ID 为传入的 sessionId
+           currentSessionId = sessionId;
+           // 获取会话详情数据
+           const session = result.data.session;
+           // 获取该会话包含的消息列表,如果没有则为 []
+           const messages = result.data.messages || [];
+           // 调用方法将消息渲染到页面
+           renderMessages(messages);
+           // 重新加载全部会话列表,刷新左侧会话栏
+           await loadSessions();
+       }
+   } catch (error) {
+       // 捕获异常并弹窗提示加载失败的原因
+       alert('加载会话失败: ' + error.message);
+   }
+}
+// 渲染会话列表
+// 将sessions数组渲染到界面左侧会话列表
+function renderSessions() {
+   // 获取会话列表的DOM节点
+   const sessionList = document.getElementById('sessionList');
+   // 如果没有任何会话,显示为空状态
+   if (sessions.length === 0) {
+       sessionList.innerHTML = `
+           <div class="text-center text-muted py-5">
+               <i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
+               <p class="mt-2 small">暂无会话</p>
+           </div>
+       `;
+       return;
+   }
+   // 遍历sessions并渲染为会话项
+   sessionList.innerHTML = sessions.map(session => `
+       <div class="session-item ${session.id === currentSessionId ? 'active' : ''}" 
+            onclick="loadSession('${session.id}')">
+           <button class="btn btn-sm btn-link text-danger p-0 session-delete" 
+                   onclick="event.stopPropagation(); deleteSession('${session.id}')">
+               <i class="bi bi-x-lg"></i>
+           </button>
+           <div class="session-title">${session.title || '新对话'}</div>
+           <div class="session-time">${formatTime(session.updated_at)}</div>
+       </div>
+   `).join('');
+}

+// 清空聊天消息
+// 将聊天内容区重置为初始状态
+function clearChatMessages() {
+   document.getElementById('chatMessages').innerHTML = `
+       <div class="empty-state">
+           <i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
+           <p class="mt-2">开始提问吧!</p>
+       </div>
+   `;
+}

+// 加载会话列表
+// 异步加载会话列表并渲染到界面
+async function loadSessions() {
+   try {
+       // 发送GET请求获取会话列表
+       const response = await fetch('/api/v1/knowledgebases/sessions');
+       // 解析JSON结果
+       const result = await response.json();
+       // 如果接口返回成功
+       if (result.code === 200) {
+           // 保存会话数组
+           sessions = result.data.items || [];
+           // 渲染会话列表
+           renderSessions();
+       }
+   } catch (error) {
+       // 捕捉异常并在控制台打印
+       console.error('加载会话列表失败:', error);
+   }
+}

+// 创建新会话
+// 异步请求服务端创建新会话
+async function createNewSession() {
+   try {
+       // 发送POST请求创建新会话
+       const response = await fetch('/api/v1/knowledgebases/sessions', {
+           method: 'POST',
+           headers: {'Content-Type': 'application/json'},
+           body: JSON.stringify({})
+       });
+       // 解析返回结果
+       const result = await response.json();
+       // 如果创建成功
+       if (result.code === 200) {
+           // 切换到新会话
+           currentSessionId = result.data.id;
+           // 重新加载会话列表
+           await loadSessions();
+           // 清空聊天内容
+           clearChatMessages();
+       }
+   } catch (error) {
+       // 捕获异常并弹窗提示
+       alert('创建会话失败: ' + error.message);
+   }
+}
+// 监听 DOMContentLoaded 事件,确保页面元素加载完成后执行
+document.addEventListener('DOMContentLoaded', function() {
+   // 加载会话列表
+   loadSessions();
+   // 判断 marked.js 是否已经加载
+   if (typeof marked !== 'undefined') {
+       // 配置 marked.js 的渲染选项
+       marked.setOptions({
+           // 启用软换行
+           breaks: true,
+           // 启用 Github Flavored Markdown
+           gfm: true,
+           // 禁用标题自动生成 id
+           headerIds: false,
+           // 禁用混淆处理
+           mangle: false
+       });
+   }
+});
</script>
{% endblock %}

6.选择知识库 #

6.1. chat.py #

app/blueprints/chat.py

# 聊天相关路由(视图 + API)
"""
聊天相关路由(视图 + API)
"""

# 导入 Flask 的 Blueprint 和模板渲染函数
from flask import Blueprint, render_template, request, stream_with_context, Response
import json
# 导入日志模块
import logging
# 导入登录保护装饰器和获取当前用户辅助方法
+from app.utils.auth import login_required, api_login_required,get_current_user
# 导入知识库服务
+from app.services.knowledgebase_service import kb_service
# 导入自定义工具函数:成功响应、错误响应、获取分页参数、获取当前用户或错误、异常处理装饰器、检查所有权
from app.blueprints.utils import (
    success_response, error_response,
    get_current_user_or_error, handle_api_error, get_pagination_params
)
from app.services.chat_service import chat_service
from app.services.chat_session_service import session_service

# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)

# 创建名为 'chat' 的蓝图对象
bp = Blueprint('chat', __name__)

# 注册 /chat 路由,访问该路由需要先登录
@bp.route('/chat')
@login_required
def chat_view():
    # 智能问答页面视图函数
    """智能问答页面"""
+   current_user = get_current_user()
    # 获取所有知识库(通常用户不会有太多知识库,不需要分页)
+   result = kb_service.list(user_id=current_user['id'], page=1, page_size=1000)
    # 渲染 chat.html 模板并传递空知识库列表
+   return render_template('chat.html', knowledgebases=result['items'])

# 注册 API 路由,处理聊天接口 POST 请求
@bp.route('/api/v1/knowledgebases/chat', methods=['POST'])
@api_login_required
@handle_api_error
def api_chat():
    # 普通聊天接口(不支持知识库,支持流式输出)
    """普通聊天接口(不支持知识库,支持流式输出)"""
    # 获取当前用户和错误信息
    current_user, err = get_current_user_or_error()
    # 如果有错误,直接返回错误响应
    if err:
        return err

    # 从请求体获取 JSON 数据
    data = request.get_json()
    # 如果数据为空或不存在 'question' 字段,返回错误
    if not data or 'question' not in data:
        return error_response("question is required", 400)

    # 去除问题文本首尾空格
    question = data['question'].strip()
    # 如果问题内容为空,返回错误
    if not question:
        return error_response("question cannot be empty", 400)
    session_id = data.get('session_id')  # 会话ID(可以为空,表示普通聊天)
    # 获取 max_tokens 参数,默认 1000
    max_tokens = int(data.get('max_tokens', 1000))
    # 限制最大和最小值在 1~10000 之间
    max_tokens = max(1, min(max_tokens, 10000))  # 限制在 1-10000 之间
    # 从请求数据中获取'stream'字段,默认为True,表示启用流式输出
    stream = data.get('stream', True)  # 默认启用流式输出

    # 初始化历史消息为None
    history = None
    # 如果请求中带有session_id,说明有现有会话
    if session_id:
        # 根据session_id和当前用户ID获取历史消息列表
        history_messages = session_service.get_messages(session_id, current_user['id'])
        # 将历史消息转换为对话格式,仅保留最近10条
        history = [
            {'role': msg.get('role'), 'content': msg.get('content')}
            for msg in history_messages[-10:]  # 只取最近10条
        ]

    # 如果请求中没有session_id,说明是新对话,需要新建会话
    if not session_id:
        # 创建新会话,kb_id设为None表示普通聊天
        chat_session = session_service.create_session(
            user_id=current_user['id']
        )
        # 使用新创建会话的ID作为本次会话ID
        session_id = chat_session['id']

    # 将用户的问题消息保存到当前会话中
    session_service.add_message(session_id, 'user', question)

    # 声明用于流式输出的生成器
    @stream_with_context
    def generate():
        try:
            # 用于缓存完整答案内容
            full_answer = ''
            # 调用服务进行流式对话
            for chunk in chat_service.chat_stream(
                question=question,
                temperature=None,  # 使用设置中的值
                max_tokens=max_tokens,
                history=history
            ):
                # 如果是内容块,则拼接内容到 full_answer
                if chunk.get('type') == 'content':
                    full_answer += chunk.get('content', '')
                # 以 SSE 协议格式输出数据
                yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
            # 输出对话完成信号
            yield "data: [DONE]\n\n"
            # 保存助手回复
            if full_answer:
                session_service.add_message(session_id, 'assistant', full_answer)
        except Exception as e:
            # 发生异常记录日志
            logger.error(f"流式输出时出错: {e}")
            # 构造错误数据块
            error_chunk = {
                "type": "error",
                "content": str(e)
            }
            # 输出错误数据块
            yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n"

    # 创建 Response 对象,设置必要的 SSE 响应头部
    response = Response(
        generate(),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive',
            'X-Accel-Buffering': 'no',
            'Content-Type': 'text/event-stream; charset=utf-8'
        }
    )
    # 返回响应
    return response
# 路由装饰器,定义 GET 方法获取会话列表的接口
@bp.route('/api/v1/knowledgebases/sessions', methods=['GET'])
# API 登录校验装饰器,确保用户已登录
@api_login_required
# 错误处理装饰器,统一处理接口异常
@handle_api_error
def api_list_sessions():
    # 接口描述:获取当前用户的会话列表
    """获取当前用户的会话列表"""
    # 获取当前用户,如有错误直接返回错误响应
    current_user, err = get_current_user_or_error()
    if err:
        return err

    # 获取分页参数(页码和每页数量),最大单页1000
    page, page_size = get_pagination_params(max_page_size=1000)
    # 调用会话服务获取当前用户的会话列表
    result = session_service.list_sessions(current_user['id'], page=page, page_size=page_size)
    # 以统一成功响应格式返回会话列表
    return success_response(result)    


# 路由装饰器,定义 POST 方法创建会话的接口
@bp.route('/api/v1/knowledgebases/sessions', methods=['POST'])
@api_login_required
@handle_api_error
def api_create_session():
    # 接口描述:创建新的聊天会话
    """创建新的聊天会话"""
    # 获取当前用户,如果有错误直接返回
    current_user, err = get_current_user_or_error()
    if err:
        return err

    # 获取请求体中的 JSON 数据,若无返回空字典
    data = request.get_json() or {}
    # 获取会话标题
    title = data.get('title')

    # 调用服务创建会话,传入当前用户ID、知识库ID与标题
    session_obj = session_service.create_session(
        user_id=current_user['id'],
        title=title
    )
    # 返回成功响应及会话对象
    return success_response(session_obj)


# 路由装饰器,定义 GET 方法获取单个会话详情的接口(带 session_id)
@bp.route('/api/v1/knowledgebases/sessions/<session_id>', methods=['GET'])
@api_login_required
@handle_api_error
def api_get_session(session_id):
    # 接口描述:获取会话详情和消息
    """获取会话详情和消息"""
    # 获取当前用户,如有错误直接返回
    current_user, err = get_current_user_or_error()
    if err:
        return err

    # 根据 session_id 获取会话对象,校验所属当前用户
    session_obj = session_service.get_session_by_id(session_id, current_user['id'])
    # 如果没有找到会话,返回 404 错误
    if not session_obj:
        return error_response("Session not found", 404)

    # 获取该会话下的所有消息
    messages = session_service.get_messages(session_id, current_user['id'])

    # 返回会话详情及消息列表
    return success_response({
        'session': session_obj,
        'messages': messages
    })


# 路由装饰器,定义 DELETE 方法删除单个会话接口
@bp.route('/api/v1/knowledgebases/sessions/<session_id>', methods=['DELETE'])
@api_login_required
@handle_api_error
def api_delete_session(session_id):
    # 接口描述:删除会话
    """删除会话"""
    # 获取当前用户,如有错误直接返回
    current_user, err = get_current_user_or_error()
    if err:
        return err

    # 调用服务删除会话,校验归属当前用户
    success = session_service.delete_session(session_id, current_user['id'])
    # 若删除成功,返回成功响应,否则返回 404
    if success:
        return success_response(None, "Session deleted")
    else:
        return error_response("Session not found", 404)


# 路由装饰器,定义 DELETE 方法清空所有会话的接口
@bp.route('/api/v1/knowledgebases/sessions', methods=['DELETE'])
@api_login_required
@handle_api_error
def api_delete_all_sessions():
    # 接口描述:清空所有会话
    """清空所有会话"""
    # 获取当前用户,如果有错误直接返回
    current_user, err = get_current_user_or_error()
    if err:
        return err

    # 调用服务删除所有属于当前用户的会话,返回删除数量
    count = session_service.delete_all_sessions(current_user['id'])
    # 返回成功响应及被删除会话数
    return success_response({'deleted_count': count}, f"Deleted {count} sessions")

6.2. chat.html #

app/templates/chat.html

{% extends "base.html" %}

{% block title %}智能问答 - RAG Lite{% endblock %}

{% block extra_css %}
<style>
    .chat-container {
        height: calc(100vh - 200px);
        display: flex;
        gap: 1rem;
    }
    .chat-sidebar {
        width: 280px;
        background: #fff;
        border-radius: 0.5rem;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        display: flex;
        flex-direction: column;
    }
    .chat-main {
        flex: 1;
        background: #fff;
        border-radius: 0.5rem;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        display: flex;
        flex-direction: column;
    }
    .session-list {
        flex: 1;
        overflow-y: auto;
        padding: 0.5rem;
    }
    .session-item {
        padding: 0.75rem;
        margin-bottom: 0.5rem;
        border-radius: 0.5rem;
        cursor: pointer;
        transition: background-color 0.2s;
        position: relative;
    }
    .session-item:hover {
        background-color: #f8f9fa;
    }
    .session-item.active {
        background-color: #e3f2fd;
        border-left: 3px solid #0d6efd;
    }
    .session-item .session-title {
        font-weight: 500;
        margin-bottom: 0.25rem;
        overflow: hidden;
        text-overflow: ellipsis;
        white-space: nowrap;
    }
    .session-item .session-time {
        font-size: 0.75rem;
        color: #6c757d;
    }
    .session-item .session-delete {
        position: absolute;
        top: 0.5rem;
        right: 0.5rem;
        opacity: 0;
        transition: opacity 0.2s;
    }
    .session-item:hover .session-delete {
        opacity: 1;
    }
    .chat-messages {
        flex: 1;
        overflow-y: auto;
        padding: 1rem;
        scroll-behavior: smooth;
    }
    .chat-message {
        padding: 1rem;
        margin-bottom: 1rem;
        border-radius: 0.5rem;
    }
    .chat-question {
        background-color: #e3f2fd;
    }
    .chat-answer {
        background-color: #f5f5f5;
    }
    .chat-input-area {
        padding: 1rem;
        border-top: 1px solid #dee2e6;
    }
    .empty-state {
        display: flex;
        flex-direction: column;
        align-items: center;
        justify-content: center;
        height: 100%;
        color: #6c757d;
    }
</style>
{% endblock %}

{% block content %}
<div class="chat-container">
    <!-- 左侧:会话管理 -->
    <div class="chat-sidebar">
        <div class="p-3 border-bottom">
            <div class="d-flex justify-content-between align-items-center mb-3">
                <h6 class="mb-0"><i class="bi bi-chat-left-text"></i> 聊天会话</h6>
                <button class="btn btn-sm btn-primary" onclick="createNewSession()">
                    <i class="bi bi-plus"></i> 新建
                </button>
            </div>
            <button class="btn btn-sm btn-outline-danger w-100" onclick="clearAllSessions()">
                <i class="bi bi-trash"></i> 清空所有
            </button>
        </div>
        <div class="session-list" id="sessionList">
            <div class="text-center text-muted py-5">
                <i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
                <p class="mt-2 small">暂无会话</p>
            </div>
        </div>
    </div>

    <!-- 右侧:对话页面 -->
    <div class="chat-main">
        <div class="p-3 border-bottom">
            <div class="d-flex justify-content-between align-items-center">
                <h6 class="mb-0"><i class="bi bi-chat-dots"></i> 对话</h6>
+               <select class="form-select form-select-sm" id="kbSelect" style="width: 200px;" onchange="onKbChange()">
+                   <option value="">-- 选择知识库 --</option>
+                   {% for kb in knowledgebases %}
+                   <option value="{{ kb.id }}">{{ kb.name }}</option>
+                   {% endfor %}
+               </select>
            </div>
        </div>

        <div class="chat-messages" id="chatMessages">
            <div class="empty-state">
                <i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
                <p class="mt-2">开始提问吧!</p>
            </div>
        </div>

        <div class="chat-input-area">
            <form id="chatForm" onsubmit="askQuestion(event)">
                <div class="mb-2">
                    <textarea class="form-control" id="questionInput" rows="2" 
                              placeholder="输入您的问题..." required></textarea>
                </div>
                <div class="d-flex justify-content-end">
                    <button type="submit" class="btn btn-primary" id="submitBtn">
                        <i class="bi bi-send"></i> 发送
                    </button>
                </div>
            </form>
        </div>
    </div>
</div>
{% endblock %}

{% block extra_js %}
<script>
// 当前会话的ID,初始为null,表示暂未选择任何会话
let currentSessionId = null;
// 会话列表,初始为空数组,用于存储所有的会话对象
let sessions = [];
+// 当前知识库的ID,初始为null,表示暂未选择任何知识库
+let currentKbId = null;
// 为字符串进行HTML转义,防止XSS攻击
function escapeHtml(text) {
    // 创建一个div元素作为容器
    const div = document.createElement('div');
    // 将待转义文本设置为div的textContent(自动完成转义)
    div.textContent = text;
    // 返回转义后的HTML内容
    return div.innerHTML;
}

// 滚动消息框到底部
function scrollToBottom() {
    // 获取聊天消息区域元素
    const chatMessages = document.getElementById('chatMessages');
    // 设置滚动条位置到最底部
    chatMessages.scrollTop = chatMessages.scrollHeight;
}
// 定义一个用于渲染Markdown文本为HTML的函数
function renderMarkdown(text) {
    // 判断marked库是否已加载且能使用解析方法
    if (typeof marked !== 'undefined' && marked.parse) {
        try {
            // 尝试用marked库将Markdown文本解析成HTML
            return marked.parse(text);
        } catch (e) {
            // 如果解析出错,则进行HTML转义并换行
            return escapeHtml(text).replace(/\n/g, '<br>');
        }
    }
    // 如果marked库不可用,直接做HTML转义并处理换行
    return escapeHtml(text).replace(/\n/g, '<br>');
}
// 将Markdown内容渲染到指定元素(支持降级为普通文本)
function renderMarkdownToElement(element, text) {
    // 若未提供内容,显示思考中图标
    if (!text) {
        element.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';
        return;
    }
    // 优先判断marked库是否可用(渲染markdown)
    if (typeof marked !== 'undefined' && marked.parse) {
        try {
            // 使用marked进行markdown转html
            element.innerHTML = marked.parse(text);
        } catch (e) {
            // 渲染失败则退化为转义+换行
            element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
        }
    } else {
        // 没有marked库则直接转义+换行显示
        element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
    }
}

// 主函数:处理用户提交问题事件
async function askQuestion(event) {
    // 阻止表单默认提交行为(防止页面刷新)
    event.preventDefault();
    // 获取输入框的用户问题并去除首尾空白
    const question = document.getElementById('questionInput').value.trim();
    // 获取消息显示区域元素
    const chatMessages = document.getElementById('chatMessages');
    // 若问题为空则直接返回
    if (!question) return;
     // 如果没有会话,创建新会话
     if (!currentSessionId) {
        await createNewSession();
    }
    // 检查并移除初始空白提示(如有)
    if (chatMessages.querySelector('.empty-state')) {
        chatMessages.innerHTML = '';
    }
    // 创建用于展示问题的div元素
    const questionDiv = document.createElement('div');
    // 加上样式: 用户问题
    questionDiv.className = 'chat-message chat-question';
    // 构建用户气泡内容含图标、文本和时间
    questionDiv.innerHTML = `
        <div class="d-flex justify-content-between align-items-start">
            <div class="flex-grow-1">
                <strong><i class="bi bi-person-circle"></i> 问题:</strong>
                <div class="mt-1">${escapeHtml(question)}</div>
            </div>
            <small class="text-muted">${new Date().toLocaleTimeString()}</small>
        </div>
    `;
    // 显示到对话窗口
    chatMessages.appendChild(questionDiv);

    // 创建用于显示答案的div元素
    const answerDiv = document.createElement('div');
    // 答案样式
    answerDiv.className = 'chat-message chat-answer';
    // 动态生成唯一的答案内容div id(用于唯一标记)
    const answerContentId = 'answerContent_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
    // 创建存放答案内容的div
    const answerContent = document.createElement('div');
    // 为markdown渲染容器设置样式和id
    answerContent.className = 'mt-2 markdown-content';
    answerContent.id = answerContentId;
    // 答案div显示机器人图标和时间
    answerDiv.innerHTML = `
        <div class="d-flex justify-content-between align-items-start">
            <div class="flex-grow-1">
                <strong><i class="bi bi-robot"></i> 答案:</strong>
            </div>
            <small class="text-muted">${new Date().toLocaleTimeString()}</small>
        </div>
    `;
    // 将答案内容div插入到机器人气泡div的内容区
    const flexGrowDiv = answerDiv.querySelector('.flex-grow-1');
    flexGrowDiv.appendChild(answerContent);
    // 插入答案div到消息区域
    chatMessages.appendChild(answerDiv);

    // 变量记录完整的回答内容
    let fullAnswer = '';
    // 标记渲染任务是否挂起(防止重复)
    let pendingUpdate = false;
    // 记录定时器id(去抖动用)
    let updateTimer = null;

    // 清空输入框内容
    document.getElementById('questionInput').value = '';
    // 滚动到底
    scrollToBottom();

    // 定义scheduleRender:将markdown渲染插入到队列合适时机执行
    function scheduleRender() {
        // 若当前无待渲染任务才安排渲染
        if (!pendingUpdate) {
            pendingUpdate = true;
            // 下一帧渲染
            requestAnimationFrame(() => {
                // 将答案作为markdown渲染进dom
                renderMarkdownToElement(answerContent, fullAnswer);
                // 清理pending标识
                pendingUpdate = false;
                // 渲染后滚动到底部
                scrollToBottom();
            });
        }
    }

    try {
        // 组装API接口地址
        const url = `/api/v1/knowledgebases/chat`;
        // 请求后端,发起流式POST请求
        const response = await fetch(url, {
            method: 'POST',
            headers: {'Content-Type': 'application/json'},
            body: JSON.stringify({
                question: question,
                session_id: currentSessionId,
                stream: true
            })
        });
        // 非200响应时,抛出异常
        if (!response.ok) throw new Error('请求失败');
        // 使用ReadableStream reader获取response流
        const reader = response.body.getReader();
        // 用TextDecoder解码二进制数据
        const decoder = new TextDecoder();
        // 数据缓冲区(字符串)
        let buffer = '';
        // 初始展示“思考中...”提示
        answerContent.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';
        // 重新加载会话列表以更新标题
        await loadSessions();
        // 不断循环读取服务端推送流内容
        while (true) {
            // 逐步读取一段数据
            const { done, value } = await reader.read();
            // 若读完则结束循环
            if (done) break;
            // 本块新数据解码成字符串,并追加到buffer
            buffer += decoder.decode(value, { stream: true });
            // 按行切分(多行分别处理)
            const lines = buffer.split('\n');
            // 最后一行一般为数据残留,不处理,下次拼
            buffer = lines.pop() || '';
            // 处理每一行数据
            for (const line of lines) {
                // 筛选有效SSE协议"data: "数据包
                if (line.startsWith('data: ')) {
                    // 去掉前缀,留下json内容
                    const data = line.slice(6);
                    // [DONE]信号作为流结束,仅跳过
                    if (data === '[DONE]') continue;
                    try {
                        // 解析JSON数据
                        const chunk = JSON.parse(data);
                        // 类型为流起始,清空内容
                        if (chunk.type === 'start') {
                            fullAnswer = '';
                            answerContent.innerHTML = '';
                        // 类型为内容,追加文本并安排去抖渲染
                        } else if (chunk.type === 'content') {
                            fullAnswer += chunk.content;
                            scheduleRender();
                        // 类型为done(流传输结束),直接最终渲染所有答案内容
                        } else if (chunk.type === 'done') {
                            renderMarkdownToElement(answerContent, fullAnswer);
                        // 错误类型,alert显示报错内容
                        } else if (chunk.type === 'error') {
                            answerContent.innerHTML = `<div class="alert alert-danger">${chunk.content}</div>`;
                        }
                    } catch (e) {
                        // JSON解析失败时报console
                        console.error('解析流数据失败:', e);
                    }
                }
            }

        }
    } catch (error) {
        // 通信异常、Fetch错误等: 显示错误气泡
        answerContent.innerHTML = `<div class="alert alert-danger"><strong>错误:</strong> ${error.message}</div>`;
    }
    // 结束后确保界面滚动到底
    scrollToBottom();
}
// 清空所有会话
// 异步函数,用于清空所有会话
async function clearAllSessions() {
    // 弹窗确认操作,若取消则直接返回
    if (!confirm('确定要清空所有会话吗?此操作不可恢复!')) return;

    try {
        // 发送DELETE请求到服务器,删除所有会话
        const response = await fetch('/api/v1/knowledgebases/sessions', {
            method: 'DELETE'
        });
        // 获取服务器返回的JSON结果
        const result = await response.json();
        // 如果接口返回成功
        if (result.code === 200) {
            // 当前会话ID置空
            currentSessionId = null;
            // 清空聊天消息内容
            clearChatMessages();
            // 重新加载会话列表
            await loadSessions();
        }
    } catch (error) {
        // 捕获异常并弹窗提示失败原因
        alert('清空会话失败: ' + error.message);
    }
}

// 格式化时间
// 用于格式化时间字符串为友好显示
function formatTime(timeStr) {
    // 如果无有效时间,直接返回空字符串
    if (!timeStr) return '';
    // 将字符串转换为Date对象
    const date = new Date(timeStr);
    // 获取当前时间
    const now = new Date();
    // 计算时间差(毫秒)
    const diff = now - date;
    // 换算为分钟数
    const minutes = Math.floor(diff / 60000);
    // 换算为小时数
    const hours = Math.floor(diff / 3600000);
    // 换算为天数
    const days = Math.floor(diff / 86400000);
    // 1分钟内显示“刚刚”
    if (minutes < 1) return '刚刚';
    // 1小时内显示“x分钟前”
    if (minutes < 60) return `${minutes}分钟前`;
    // 24小时内显示“x小时前”
    if (hours < 24) return `${hours}小时前`;
    // 7天内显示“x天前”
    if (days < 7) return `${days}天前`;
    // 超过7天显示具体日期
    return date.toLocaleDateString();
}

// 渲染消息
function renderMessages(messages) {
    const chatMessages = document.getElementById('chatMessages');

    if (messages.length === 0) {
        chatMessages.innerHTML = `
            <div class="empty-state">
                <i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
                <p class="mt-2">开始提问吧!</p>
            </div>
        `;
        return;
    }

    chatMessages.innerHTML = messages.map(msg => {
        if (msg.role === 'user') {
            return `
                <div class="chat-message chat-question">
                    <div class="d-flex justify-content-between align-items-start">
                        <div class="flex-grow-1">
                            <strong><i class="bi bi-person-circle"></i> 问题:</strong>
                            <div class="mt-1">${escapeHtml(msg.content)}</div>
                        </div>
                        <small class="text-muted">${formatTime(msg.created_at)}</small>
                    </div>
                </div>
            `;
        } else {
            return `
                <div class="chat-message chat-answer">
                    <div class="d-flex justify-content-between align-items-start">
                        <div class="flex-grow-1">
                            <strong><i class="bi bi-robot"></i> 答案:</strong>
                            <div class="mt-2 markdown-content">${renderMarkdown(msg.content)}</div>
                        </div>
                        <small class="text-muted">${formatTime(msg.created_at)}</small>
                    </div>
                </div>
            `;
        }
    }).join('');

    scrollToBottom();
}
// 加载会话
// 定义一个异步函数 loadSession,用于加载指定 sessionId 的会话
async function loadSession(sessionId) {
    // 异常捕获,防止请求或处理过程抛错
    try {
        // 发送 GET 请求,获取指定会话的详情数据
        const response = await fetch(`/api/v1/knowledgebases/sessions/${sessionId}`);
        // 将返回的响应数据解析为 JSON 格式
        const result = await response.json();

        // 如果接口请求成功,即 code 等于 200
        if (result.code === 200) {
            // 设置当前会话 ID 为传入的 sessionId
            currentSessionId = sessionId;
            // 获取会话详情数据
            const session = result.data.session;
            // 获取该会话包含的消息列表,如果没有则为 []
            const messages = result.data.messages || [];
            // 调用方法将消息渲染到页面
            renderMessages(messages);
            // 重新加载全部会话列表,刷新左侧会话栏
            await loadSessions();
        }
    } catch (error) {
        // 捕获异常并弹窗提示加载失败的原因
        alert('加载会话失败: ' + error.message);
    }
}
// 渲染会话列表
// 将sessions数组渲染到界面左侧会话列表
function renderSessions() {
    // 获取会话列表的DOM节点
    const sessionList = document.getElementById('sessionList');
    // 如果没有任何会话,显示为空状态
    if (sessions.length === 0) {
        sessionList.innerHTML = `
            <div class="text-center text-muted py-5">
                <i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
                <p class="mt-2 small">暂无会话</p>
            </div>
        `;
        return;
    }
    // 遍历sessions并渲染为会话项
    sessionList.innerHTML = sessions.map(session => `
        <div class="session-item ${session.id === currentSessionId ? 'active' : ''}" 
             onclick="loadSession('${session.id}')">
            <button class="btn btn-sm btn-link text-danger p-0 session-delete" 
                    onclick="event.stopPropagation(); deleteSession('${session.id}')">
                <i class="bi bi-x-lg"></i>
            </button>
            <div class="session-title">${session.title || '新对话'}</div>
            <div class="session-time">${formatTime(session.updated_at)}</div>
        </div>
    `).join('');
}

// 清空聊天消息
// 将聊天内容区重置为初始状态
function clearChatMessages() {
    document.getElementById('chatMessages').innerHTML = `
        <div class="empty-state">
            <i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
            <p class="mt-2">开始提问吧!</p>
        </div>
    `;
}

// 加载会话列表
// 异步加载会话列表并渲染到界面
async function loadSessions() {
    try {
        // 发送GET请求获取会话列表
        const response = await fetch('/api/v1/knowledgebases/sessions');
        // 解析JSON结果
        const result = await response.json();
        // 如果接口返回成功
        if (result.code === 200) {
            // 保存会话数组
            sessions = result.data.items || [];
            // 渲染会话列表
            renderSessions();
        }
    } catch (error) {
        // 捕捉异常并在控制台打印
        console.error('加载会话列表失败:', error);
    }
}

// 创建新会话
// 异步请求服务端创建新会话
async function createNewSession() {
    try {
        // 发送POST请求创建新会话
        const response = await fetch('/api/v1/knowledgebases/sessions', {
            method: 'POST',
            headers: {'Content-Type': 'application/json'},
            body: JSON.stringify({})
        });
        // 解析返回结果
        const result = await response.json();
        // 如果创建成功
        if (result.code === 200) {
            // 切换到新会话
            currentSessionId = result.data.id;
            // 重新加载会话列表
            await loadSessions();
            // 清空聊天内容
            clearChatMessages();
        }
    } catch (error) {
        // 捕获异常并弹窗提示
        alert('创建会话失败: ' + error.message);
    }
}
// 监听 DOMContentLoaded 事件,确保页面元素加载完成后执行
document.addEventListener('DOMContentLoaded', function() {
    // 加载会话列表
    loadSessions();
    // 判断 marked.js 是否已经加载
    if (typeof marked !== 'undefined') {
        // 配置 marked.js 的渲染选项
        marked.setOptions({
            // 启用软换行
            breaks: true,
            // 启用 Github Flavored Markdown
            gfm: true,
            // 禁用标题自动生成 id
            headerIds: false,
            // 禁用混淆处理
            mangle: false
        });
    }
});
+// 知识库选择变化处理函数
+async function onKbChange() {
+   // 获取当前下拉选择的知识库ID
+   currentKbId = document.getElementById('kbSelect').value;;

+   // 如果已存在会话,则切换知识库时重新新建一个会话
+   if (currentSessionId) {
+       await createNewSession();
+   }

+   // 启用发送按钮(普通聊天和知识库聊天均可提问)
+   document.getElementById('submitBtn').disabled = false;
+}
</script>
{% endblock %}

7.知识库对话 #

7.1. rag_service.py #

app/services/rag_service.py

"""
RAG 服务
"""
# 导入日志模块
import logging
# 导入 LangChain 的对话提示模板模块
from langchain_core.prompts import ChatPromptTemplate
# 导入自定义 LLM 工厂
from app.utils.llm_factory import LLMFactory
# 导入设置服务
from app.services.settings_service import settings_service

# 设置日志对象
logger = logging.getLogger(__name__)

# 定义 RAGService 类
class RAGService:
    """RAG 服务"""

    # 初始化函数
    def __init__(self):
        """
        初始化服务

        Args:
            settings: 设置字典,如果为 None 则从数据库读取
        """
        # 从设置服务中获取配置信息
        self.settings = settings_service.get()
        # 定义默认系统消息提示词
        default_rag_system_prompt = "你是一个专业的AI助手。请基于文档内容回答问题。"
        # 定义默认查询提示词,包含 context 和 question 占位符
        default_rag_query_prompt = """文档内容:
        {context}

        问题:{question}

        请基于文档内容回答问题。如果文档中没有相关信息,请明确说明。"""

        # 从设置中获取自定义系统消息提示词
        rag_system_prompt_text = self.settings.get('rag_system_prompt')
        # 如果没有设置,使用默认系统提示词
        if not rag_system_prompt_text:
            rag_system_prompt_text = default_rag_system_prompt

        # 从设置中获取自定义查询提示词
        rag_query_prompt_text = self.settings.get('rag_query_prompt')
        # 如果没有设置,使用默认查询提示词
        if not rag_query_prompt_text:
            rag_query_prompt_text = default_rag_query_prompt

        # 构建 RAG 的提示模板,包含系统消息和用户查询部分
        self.rag_prompt = ChatPromptTemplate.from_messages([
            ("system", rag_system_prompt_text),
            ("human", rag_query_prompt_text)
        ])

    # 定义流式问答接口
    def ask_stream(self, kb_id: str, question: str):
        """
        流式问答接口

        Args:
            kb_id: 知识库ID
            question: 问题

        Yields:
            流式数据块
        """
        # 创建带流式输出能力的 LLM 实例
        llm = LLMFactory.create_llm(self.settings)

        # 文档过滤后的结果,暂时为空列表
        filtered_docs  = []
        # 发送流式开始信号
        yield {
            "type": "start",
            "content": ""
        }

        # 构造用于传递给 LLM 的上下文字符串,将所有文档整合为字符串
        context = "\n\n".join([
            f"文档 {i+1} ({doc.metadata.get('doc_name', '未知')}):\n{doc.page_content}"
            for i, doc in enumerate(filtered_docs)
        ])

        # 创建 Rag Prompt 到 LLM 的处理链
        chain = self.rag_prompt | llm

        # 初始化完整答案的字符串
        full_answer = ""
        # 逐块流式生成答案
        for chunk in chain.stream({"context": context, "question": question}):
            # 获取当前输出块内容
            content = chunk.content
            # 如果有内容则累加并 yield 输出内容块
            if content:
                full_answer += content
                yield {
                    "type": "content",
                    "content": content
                }
        # 所有内容输出结束后,发送完成信号和相关元数据
        yield {
            "type": "done",
            "content": "",
            "metadata": {
                'kb_id': kb_id,
                'question': question,
                'retrieved_chunks': len(filtered_docs)
            }
        }    

# 实例化 rag_service,供外部调用
rag_service = RAGService()        

7.2. chat.py #

app/blueprints/chat.py

# 聊天相关路由(视图 + API)
"""
聊天相关路由(视图 + API)
"""

# 导入 Flask 的 Blueprint 和模板渲染函数
from flask import Blueprint, render_template, request, stream_with_context, Response
import json
# 导入日志模块
import logging
# 导入登录保护装饰器和获取当前用户辅助方法
from app.utils.auth import login_required, api_login_required,get_current_user
# 导入知识库服务
from app.services.knowledgebase_service import kb_service
# 导入自定义工具函数:成功响应、错误响应、获取分页参数、获取当前用户或错误、异常处理装饰器、检查所有权
from app.blueprints.utils import (
    success_response, error_response,
+   get_current_user_or_error, handle_api_error, get_pagination_params,check_ownership
)
from app.services.chat_service import chat_service
from app.services.chat_session_service import session_service

# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)

# 创建名为 'chat' 的蓝图对象
bp = Blueprint('chat', __name__)

# 注册 /chat 路由,访问该路由需要先登录
@bp.route('/chat')
@login_required
def chat_view():
    # 智能问答页面视图函数
    """智能问答页面"""
    current_user = get_current_user()
    # 获取所有知识库(通常用户不会有太多知识库,不需要分页)
    result = kb_service.list(user_id=current_user['id'], page=1, page_size=1000)
    # 渲染 chat.html 模板并传递空知识库列表
    return render_template('chat.html', knowledgebases=result['items'])

# 注册 API 路由,处理聊天接口 POST 请求
@bp.route('/api/v1/knowledgebases/chat', methods=['POST'])
@api_login_required
@handle_api_error
def api_chat():
    # 普通聊天接口(不支持知识库,支持流式输出)
    """普通聊天接口(不支持知识库,支持流式输出)"""
    # 获取当前用户和错误信息
    current_user, err = get_current_user_or_error()
    # 如果有错误,直接返回错误响应
    if err:
        return err

    # 从请求体获取 JSON 数据
    data = request.get_json()
    # 如果数据为空或不存在 'question' 字段,返回错误
    if not data or 'question' not in data:
        return error_response("question is required", 400)

    # 去除问题文本首尾空格
    question = data['question'].strip()
    # 如果问题内容为空,返回错误
    if not question:
        return error_response("question cannot be empty", 400)
    session_id = data.get('session_id')  # 会话ID(可以为空,表示普通聊天)
    # 获取 max_tokens 参数,默认 1000
    max_tokens = int(data.get('max_tokens', 1000))
    # 限制最大和最小值在 1~10000 之间
    max_tokens = max(1, min(max_tokens, 10000))  # 限制在 1-10000 之间
    # 从请求数据中获取'stream'字段,默认为True,表示启用流式输出
    stream = data.get('stream', True)  # 默认启用流式输出

    # 初始化历史消息为None
    history = None
    # 如果请求中带有session_id,说明有现有会话
    if session_id:
        # 根据session_id和当前用户ID获取历史消息列表
        history_messages = session_service.get_messages(session_id, current_user['id'])
        # 将历史消息转换为对话格式,仅保留最近10条
        history = [
            {'role': msg.get('role'), 'content': msg.get('content')}
            for msg in history_messages[-10:]  # 只取最近10条
        ]

    # 如果请求中没有session_id,说明是新对话,需要新建会话
    if not session_id:
        # 创建新会话,kb_id设为None表示普通聊天
        chat_session = session_service.create_session(
            user_id=current_user['id']
        )
        # 使用新创建会话的ID作为本次会话ID
        session_id = chat_session['id']

    # 将用户的问题消息保存到当前会话中
    session_service.add_message(session_id, 'user', question)

    # 声明用于流式输出的生成器
    @stream_with_context
    def generate():
        try:
            # 用于缓存完整答案内容
            full_answer = ''
            # 调用服务进行流式对话
            for chunk in chat_service.chat_stream(
                question=question,
                temperature=None,  # 使用设置中的值
                max_tokens=max_tokens,
                history=history
            ):
                # 如果是内容块,则拼接内容到 full_answer
                if chunk.get('type') == 'content':
                    full_answer += chunk.get('content', '')
                # 以 SSE 协议格式输出数据
                yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
            # 输出对话完成信号
            yield "data: [DONE]\n\n"
            # 保存助手回复
            if full_answer:
                session_service.add_message(session_id, 'assistant', full_answer)
        except Exception as e:
            # 发生异常记录日志
            logger.error(f"流式输出时出错: {e}")
            # 构造错误数据块
            error_chunk = {
                "type": "error",
                "content": str(e)
            }
            # 输出错误数据块
            yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n"

    # 创建 Response 对象,设置必要的 SSE 响应头部
    response = Response(
        generate(),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive',
            'X-Accel-Buffering': 'no',
            'Content-Type': 'text/event-stream; charset=utf-8'
        }
    )
    # 返回响应
    return response
# 路由装饰器,定义 GET 方法获取会话列表的接口
@bp.route('/api/v1/knowledgebases/sessions', methods=['GET'])
# API 登录校验装饰器,确保用户已登录
@api_login_required
# 错误处理装饰器,统一处理接口异常
@handle_api_error
def api_list_sessions():
    # 接口描述:获取当前用户的会话列表
    """获取当前用户的会话列表"""
    # 获取当前用户,如有错误直接返回错误响应
    current_user, err = get_current_user_or_error()
    if err:
        return err

    # 获取分页参数(页码和每页数量),最大单页1000
    page, page_size = get_pagination_params(max_page_size=1000)
    # 调用会话服务获取当前用户的会话列表
    result = session_service.list_sessions(current_user['id'], page=page, page_size=page_size)
    # 以统一成功响应格式返回会话列表
    return success_response(result)    


# 路由装饰器,定义 POST 方法创建会话的接口
@bp.route('/api/v1/knowledgebases/sessions', methods=['POST'])
@api_login_required
@handle_api_error
def api_create_session():
    # 接口描述:创建新的聊天会话
    """创建新的聊天会话"""
    # 获取当前用户,如果有错误直接返回
    current_user, err = get_current_user_or_error()
    if err:
        return err

    # 获取请求体中的 JSON 数据,若无返回空字典
    data = request.get_json() or {}
    # 获取会话标题
    title = data.get('title')

    # 调用服务创建会话,传入当前用户ID、知识库ID与标题
    session_obj = session_service.create_session(
        user_id=current_user['id'],
        title=title
    )
    # 返回成功响应及会话对象
    return success_response(session_obj)


# 路由装饰器,定义 GET 方法获取单个会话详情的接口(带 session_id)
@bp.route('/api/v1/knowledgebases/sessions/<session_id>', methods=['GET'])
@api_login_required
@handle_api_error
def api_get_session(session_id):
    # 接口描述:获取会话详情和消息
    """获取会话详情和消息"""
    # 获取当前用户,如有错误直接返回
    current_user, err = get_current_user_or_error()
    if err:
        return err

    # 根据 session_id 获取会话对象,校验所属当前用户
    session_obj = session_service.get_session_by_id(session_id, current_user['id'])
    # 如果没有找到会话,返回 404 错误
    if not session_obj:
        return error_response("Session not found", 404)

    # 获取该会话下的所有消息
    messages = session_service.get_messages(session_id, current_user['id'])

    # 返回会话详情及消息列表
    return success_response({
        'session': session_obj,
        'messages': messages
    })


# 路由装饰器,定义 DELETE 方法删除单个会话接口
@bp.route('/api/v1/knowledgebases/sessions/<session_id>', methods=['DELETE'])
@api_login_required
@handle_api_error
def api_delete_session(session_id):
    # 接口描述:删除会话
    """删除会话"""
    # 获取当前用户,如有错误直接返回
    current_user, err = get_current_user_or_error()
    if err:
        return err

    # 调用服务删除会话,校验归属当前用户
    success = session_service.delete_session(session_id, current_user['id'])
    # 若删除成功,返回成功响应,否则返回 404
    if success:
        return success_response(None, "Session deleted")
    else:
        return error_response("Session not found", 404)


# 路由装饰器,定义 DELETE 方法清空所有会话的接口
@bp.route('/api/v1/knowledgebases/sessions', methods=['DELETE'])
@api_login_required
@handle_api_error
def api_delete_all_sessions():
    # 接口描述:清空所有会话
    """清空所有会话"""
    # 获取当前用户,如果有错误直接返回
    current_user, err = get_current_user_or_error()
    if err:
        return err

    # 调用服务删除所有属于当前用户的会话,返回删除数量
    count = session_service.delete_all_sessions(current_user['id'])
    # 返回成功响应及被删除会话数
    return success_response({'deleted_count': count}, f"Deleted {count} sessions")

# 路由装饰器,指定POST方法用于知识库问答接口
+@bp.route('/api/v1/knowledgebases/<kb_id>/chat', methods=['POST'])
# 装饰器:需要API登录
+@api_login_required
# 装饰器:统一处理API错误
+@handle_api_error
+def api_ask(kb_id):
    # 知识库问答接口(支持流式输出)
+   """知识库问答接口(支持流式输出)"""
    # 获取当前用户和错误信息
+   current_user, err = get_current_user_or_error()
    # 如果获取用户出错,直接返回错误
+   if err:
+       return err

    # 获取指定id的知识库
+   kb = kb_service.get_by_id(kb_id)
    # 检查当前用户是否有权限访问该知识库
+   has_permission, err = check_ownership(kb['user_id'], current_user['id'], "knowledgebase")
    # 如果没有权限,直接返回错误
+   if not has_permission:
+       return err

    # 获取请求中的JSON数据
+   data = request.get_json()

    # 获取并去除问题字符串首尾空白
+   question = data['question'].strip()

    # 从请求数据获取session_id,如果没有则为None
+   session_id = data.get('session_id')  # 会话ID
    # 获取最大token数,默认为1000
+   max_tokens = int(data.get('max_tokens', 1000))
    # 限制max_tokens在1到10000之间
+   max_tokens = max(1, min(max_tokens, 10000))  # 限制在 1-10000 之间

    # 如果没有提供session_id,则为用户和知识库创建一个新会话
+   if not session_id:
+       chat_session = session_service.create_session(
+           user_id=current_user['id'],
+           kb_id=kb_id
+       )
        # 获取新会话的会话ID
+       session_id = chat_session['id']

    # 保存用户输入的问题到消息列表
+   session_service.add_message(session_id, 'user', question)

    # 内部函数:生成流式响应内容
+   @stream_with_context
+   def generate():
+       try:
            # 初始化完整回复内容
+           full_answer = ''
            # 初始化引用信息
+           sources = None

            # 迭代chat_service.ask_stream的每个数据块
+           for chunk in chat_service.ask_stream(
+               kb_id=kb_id,
+               question=question
+           ):
                # 如果块类型为内容,则将内容追加到full_answer
+               if chunk.get('type') == 'content':
+                   full_answer += chunk.get('content', '')
                # 如果块类型为done,则获取sources
+               elif chunk.get('type') == 'done':
+                   sources = chunk.get('sources')

                # 以SSE格式输出该块内容
+               yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

            # 所有内容输出后发送结束标志
+           yield "data: [DONE]\n\n"

            # 如果有回复内容,则保存机器人助手的回复和引用
+           if full_answer:
+               session_service.add_message(session_id, 'assistant', full_answer)
+       except Exception as e:
            # 如果流式输出出错,在日志中记录错误信息
+           logger.error(f"流式输出时出错: {e}")
            # 构造错误信息块
+           error_chunk = {
+               "type": "error",
+               "content": str(e)
+           }
            # 以SSE格式输出错误信息
+           yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n"

    # 构造SSE(服务端事件)响应对象,携带合适的头部信息
+   response = Response(
+       generate(),
+       mimetype='text/event-stream',
+       headers={
+           'Cache-Control': 'no-cache',
+           'Connection': 'keep-alive',
+           'X-Accel-Buffering': 'no',
+           'Content-Type': 'text/event-stream; charset=utf-8'
+       }
+   )
    # 返回响应对象
+   return response

7.3. chat_service.py #

app/services/chat_service.py

"""
问答服务
支持普通聊天和知识库聊天(RAG)
"""
# 导入日志模块
import logging
# 导入可选类型和迭代器类型注解
from typing import Optional, Iterator
# 导入 LLM 工厂,用于创建大语言模型实例
from app.utils.llm_factory import LLMFactory
# 导入 LangChain 的对话模板
from langchain_core.prompts import ChatPromptTemplate
# 导入设置服务,用于获取当前系统设置
from app.services.settings_service import settings_service
# 导入RAG服务
+from app.services.rag_service import rag_service
# 初始化日志记录器
logger = logging.getLogger(__name__)

# 定义问答服务类
class ChatService:
    # 类的初始化方法
    def __init__(self):
        """初始化问答服务"""
        # 获取并保存当前的系统设置
        self.settings = settings_service.get()
    """问答服务(支持普通聊天和RAG)"""
    # 定义流式普通聊天方法,不使用知识库
    def chat_stream(self, question: str, temperature: Optional[float] = None,
                   max_tokens: int = 1000, history: Optional[list] = None) -> Iterator[dict]:
        """
        流式普通聊天接口(不使用知识库)

        Args:
            question: 问题
            temperature: LLM 温度参数(如果为 None,则从设置中读取)
            max_tokens: 最大生成 token 数
            history: 历史对话记录(可选)

        Yields:
            流式数据块
        """
        # 如果没有指定温度,则从设置中获取(默认为 0.7),并限制在 0-2 之间
        if temperature is None:
            temperature = float(self.settings.get('llm_temperature', '0.7'))
            temperature = max(0.0, min(temperature, 2.0))  # 限制在 0-2 之间

        # 获取用于普通聊天的系统提示词
        chat_prompt_text = self.settings.get('chat_system_prompt')
        # 如果系统提示词不存在,则使用默认的提示词
        if not chat_prompt_text:
            chat_prompt_text = '你是一个专业的AI助手。请友好、准确地回答用户的问题。'

        # 创建支持流式输出的 LLM 实例
        llm = LLMFactory.create_llm(self.settings, temperature=temperature, max_tokens=max_tokens, streaming=True)

        # 构造单轮对话消息格式,包含 system 提示和用户问题
        messages = [
                ("system", chat_prompt_text),
                ("human", question)
        ]

        # 从消息创建对话提示模板
        prompt = ChatPromptTemplate.from_messages(messages)
        # 组装 prompt 和 llm,形成链式调用
        chain = prompt | llm

        # 发送流式开头信号
        yield {
            "type": "start",
            "content": ""
        }

        # 初始化完整答案内容
        full_answer = ""
        try:
            # 遍历模型生成的每一段内容
            for chunk in chain.stream({}):
                # 如果 chunk 有内容,提取内容并累加到full_answer
                if hasattr(chunk, 'content') and chunk.content:
                    content = chunk.content
                    full_answer += content
                    # 输出内容块
                    yield {
                        "type": "content",
                        "content": content
                    }
        # 捕获生成过程中的异常,记录日志并产出错误类型的数据块
        except Exception as e:
            logger.error(f"流式生成时出错: {e}")
            yield {
                "type": "error",
                "content": f"生成答案时出错: {str(e)}"
            }
            return

        # 发送流式结束信号,附带元数据(此处无知识库相关内容)
        yield {
            "type": "done",
            "content": "",
            "sources": [],
            "metadata": {
                'question': question,
                'retrieved_chunks': 0,
                'used_chunks': 0
            }
        }
+   def ask_stream(self, kb_id: str, question: str) -> Iterator[dict]:
+       """
+       流式知识库问答接口(使用 RAG)

+       Args:
+           kb_id: 知识库ID
+           question: 问题

+       Yields:
+           流式数据块
+       """
+       return rag_service.ask_stream(
+           kb_id=kb_id,
+           question=question
+       )    

# 创建全局单例 chat_service 实例
chat_service=ChatService()

7.4. chat.html #

app/templates/chat.html

{% extends "base.html" %}

{% block title %}智能问答 - RAG Lite{% endblock %}

{% block extra_css %}
<style>
    .chat-container {
        height: calc(100vh - 200px);
        display: flex;
        gap: 1rem;
    }
    .chat-sidebar {
        width: 280px;
        background: #fff;
        border-radius: 0.5rem;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        display: flex;
        flex-direction: column;
    }
    .chat-main {
        flex: 1;
        background: #fff;
        border-radius: 0.5rem;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        display: flex;
        flex-direction: column;
    }
    .session-list {
        flex: 1;
        overflow-y: auto;
        padding: 0.5rem;
    }
    .session-item {
        padding: 0.75rem;
        margin-bottom: 0.5rem;
        border-radius: 0.5rem;
        cursor: pointer;
        transition: background-color 0.2s;
        position: relative;
    }
    .session-item:hover {
        background-color: #f8f9fa;
    }
    .session-item.active {
        background-color: #e3f2fd;
        border-left: 3px solid #0d6efd;
    }
    .session-item .session-title {
        font-weight: 500;
        margin-bottom: 0.25rem;
        overflow: hidden;
        text-overflow: ellipsis;
        white-space: nowrap;
    }
    .session-item .session-time {
        font-size: 0.75rem;
        color: #6c757d;
    }
    .session-item .session-delete {
        position: absolute;
        top: 0.5rem;
        right: 0.5rem;
        opacity: 0;
        transition: opacity 0.2s;
    }
    .session-item:hover .session-delete {
        opacity: 1;
    }
    .chat-messages {
        flex: 1;
        overflow-y: auto;
        padding: 1rem;
        scroll-behavior: smooth;
    }
    .chat-message {
        padding: 1rem;
        margin-bottom: 1rem;
        border-radius: 0.5rem;
    }
    .chat-question {
        background-color: #e3f2fd;
    }
    .chat-answer {
        background-color: #f5f5f5;
    }
    .chat-input-area {
        padding: 1rem;
        border-top: 1px solid #dee2e6;
    }
    .empty-state {
        display: flex;
        flex-direction: column;
        align-items: center;
        justify-content: center;
        height: 100%;
        color: #6c757d;
    }
</style>
{% endblock %}

{% block content %}
<div class="chat-container">
    <!-- 左侧:会话管理 -->
    <div class="chat-sidebar">
        <div class="p-3 border-bottom">
            <div class="d-flex justify-content-between align-items-center mb-3">
                <h6 class="mb-0"><i class="bi bi-chat-left-text"></i> 聊天会话</h6>
                <button class="btn btn-sm btn-primary" onclick="createNewSession()">
                    <i class="bi bi-plus"></i> 新建
                </button>
            </div>
            <button class="btn btn-sm btn-outline-danger w-100" onclick="clearAllSessions()">
                <i class="bi bi-trash"></i> 清空所有
            </button>
        </div>
        <div class="session-list" id="sessionList">
            <div class="text-center text-muted py-5">
                <i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
                <p class="mt-2 small">暂无会话</p>
            </div>
        </div>
    </div>

    <!-- 右侧:对话页面 -->
    <div class="chat-main">
        <div class="p-3 border-bottom">
            <div class="d-flex justify-content-between align-items-center">
                <h6 class="mb-0"><i class="bi bi-chat-dots"></i> 对话</h6>
                <select class="form-select form-select-sm" id="kbSelect" style="width: 200px;" onchange="onKbChange()">
                    <option value="">-- 选择知识库 --</option>
                    {% for kb in knowledgebases %}
                    <option value="{{ kb.id }}">{{ kb.name }}</option>
                    {% endfor %}
                </select>
            </div>
        </div>

        <div class="chat-messages" id="chatMessages">
            <div class="empty-state">
                <i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
                <p class="mt-2">开始提问吧!</p>
            </div>
        </div>

        <div class="chat-input-area">
            <form id="chatForm" onsubmit="askQuestion(event)">
                <div class="mb-2">
                    <textarea class="form-control" id="questionInput" rows="2" 
                              placeholder="输入您的问题..." required></textarea>
                </div>
                <div class="d-flex justify-content-end">
                    <button type="submit" class="btn btn-primary" id="submitBtn">
                        <i class="bi bi-send"></i> 发送
                    </button>
                </div>
            </form>
        </div>
    </div>
</div>
{% endblock %}

{% block extra_js %}
<script>
// 当前会话的ID,初始为null,表示暂未选择任何会话
let currentSessionId = null;
// 会话列表,初始为空数组,用于存储所有的会话对象
let sessions = [];
// 当前知识库的ID,初始为null,表示暂未选择任何知识库
let currentKbId = null;
// 为字符串进行HTML转义,防止XSS攻击
function escapeHtml(text) {
    // 创建一个div元素作为容器
    const div = document.createElement('div');
    // 将待转义文本设置为div的textContent(自动完成转义)
    div.textContent = text;
    // 返回转义后的HTML内容
    return div.innerHTML;
}

// 滚动消息框到底部
function scrollToBottom() {
    // 获取聊天消息区域元素
    const chatMessages = document.getElementById('chatMessages');
    // 设置滚动条位置到最底部
    chatMessages.scrollTop = chatMessages.scrollHeight;
}
// 定义一个用于渲染Markdown文本为HTML的函数
function renderMarkdown(text) {
    // 判断marked库是否已加载且能使用解析方法
    if (typeof marked !== 'undefined' && marked.parse) {
        try {
            // 尝试用marked库将Markdown文本解析成HTML
            return marked.parse(text);
        } catch (e) {
            // 如果解析出错,则进行HTML转义并换行
            return escapeHtml(text).replace(/\n/g, '<br>');
        }
    }
    // 如果marked库不可用,直接做HTML转义并处理换行
    return escapeHtml(text).replace(/\n/g, '<br>');
}
// 将Markdown内容渲染到指定元素(支持降级为普通文本)
function renderMarkdownToElement(element, text) {
    // 若未提供内容,显示思考中图标
    if (!text) {
        element.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';
        return;
    }
    // 优先判断marked库是否可用(渲染markdown)
    if (typeof marked !== 'undefined' && marked.parse) {
        try {
            // 使用marked进行markdown转html
            element.innerHTML = marked.parse(text);
        } catch (e) {
            // 渲染失败则退化为转义+换行
            element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
        }
    } else {
        // 没有marked库则直接转义+换行显示
        element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
    }
}

// 主函数:处理用户提交问题事件
async function askQuestion(event) {
    // 阻止表单默认提交行为(防止页面刷新)
    event.preventDefault();
    // 获取输入框的用户问题并去除首尾空白
    const question = document.getElementById('questionInput').value.trim();
    // 获取消息显示区域元素
    const chatMessages = document.getElementById('chatMessages');
    // 若问题为空则直接返回
    if (!question) return;
     // 如果没有会话,创建新会话
     if (!currentSessionId) {
        await createNewSession();
    }
    // 检查并移除初始空白提示(如有)
    if (chatMessages.querySelector('.empty-state')) {
        chatMessages.innerHTML = '';
    }
    // 创建用于展示问题的div元素
    const questionDiv = document.createElement('div');
    // 加上样式: 用户问题
    questionDiv.className = 'chat-message chat-question';
    // 构建用户气泡内容含图标、文本和时间
    questionDiv.innerHTML = `
        <div class="d-flex justify-content-between align-items-start">
            <div class="flex-grow-1">
                <strong><i class="bi bi-person-circle"></i> 问题:</strong>
                <div class="mt-1">${escapeHtml(question)}</div>
            </div>
            <small class="text-muted">${new Date().toLocaleTimeString()}</small>
        </div>
    `;
    // 显示到对话窗口
    chatMessages.appendChild(questionDiv);

    // 创建用于显示答案的div元素
    const answerDiv = document.createElement('div');
    // 答案样式
    answerDiv.className = 'chat-message chat-answer';
    // 动态生成唯一的答案内容div id(用于唯一标记)
    const answerContentId = 'answerContent_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
    // 创建存放答案内容的div
    const answerContent = document.createElement('div');
    // 为markdown渲染容器设置样式和id
    answerContent.className = 'mt-2 markdown-content';
    answerContent.id = answerContentId;
    // 答案div显示机器人图标和时间
    answerDiv.innerHTML = `
        <div class="d-flex justify-content-between align-items-start">
            <div class="flex-grow-1">
                <strong><i class="bi bi-robot"></i> 答案:</strong>
            </div>
            <small class="text-muted">${new Date().toLocaleTimeString()}</small>
        </div>
    `;
    // 将答案内容div插入到机器人气泡div的内容区
    const flexGrowDiv = answerDiv.querySelector('.flex-grow-1');
    flexGrowDiv.appendChild(answerContent);
    // 插入答案div到消息区域
    chatMessages.appendChild(answerDiv);

    // 变量记录完整的回答内容
    let fullAnswer = '';
    // 标记渲染任务是否挂起(防止重复)
    let pendingUpdate = false;
    // 记录定时器id(去抖动用)
    let updateTimer = null;

    // 清空输入框内容
    document.getElementById('questionInput').value = '';
    // 滚动到底
    scrollToBottom();

    // 定义scheduleRender:将markdown渲染插入到队列合适时机执行
    function scheduleRender() {
        // 若当前无待渲染任务才安排渲染
        if (!pendingUpdate) {
            pendingUpdate = true;
            // 下一帧渲染
            requestAnimationFrame(() => {
                // 将答案作为markdown渲染进dom
                renderMarkdownToElement(answerContent, fullAnswer);
                // 清理pending标识
                pendingUpdate = false;
                // 渲染后滚动到底部
                scrollToBottom();
            });
        }
    }


    try {
        // 组装API接口地址
+       const url = currentKbId 
+           ? `/api/v1/knowledgebases/${currentKbId}/chat`
+           : `/api/v1/knowledgebases/chat`;
        // 请求后端,发起流式POST请求
        const response = await fetch(url, {
            method: 'POST',
            headers: {'Content-Type': 'application/json'},
            body: JSON.stringify({
                question: question,
                session_id: currentSessionId,
                stream: true
            })
        });
        // 非200响应时,抛出异常
        if (!response.ok) throw new Error('请求失败');
        // 使用ReadableStream reader获取response流
        const reader = response.body.getReader();
        // 用TextDecoder解码二进制数据
        const decoder = new TextDecoder();
        // 数据缓冲区(字符串)
        let buffer = '';
        // 初始展示“思考中...”提示
        answerContent.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';
        // 重新加载会话列表以更新标题
        await loadSessions();
        // 不断循环读取服务端推送流内容
        while (true) {
            // 逐步读取一段数据
            const { done, value } = await reader.read();
            // 若读完则结束循环
            if (done) break;
            // 本块新数据解码成字符串,并追加到buffer
            buffer += decoder.decode(value, { stream: true });
            // 按行切分(多行分别处理)
            const lines = buffer.split('\n');
            // 最后一行一般为数据残留,不处理,下次拼
            buffer = lines.pop() || '';
            // 处理每一行数据
            for (const line of lines) {
                // 筛选有效SSE协议"data: "数据包
                if (line.startsWith('data: ')) {
                    // 去掉前缀,留下json内容
                    const data = line.slice(6);
                    // [DONE]信号作为流结束,仅跳过
                    if (data === '[DONE]') continue;
                    try {
                        // 解析JSON数据
                        const chunk = JSON.parse(data);
                        // 类型为流起始,清空内容
                        if (chunk.type === 'start') {
                            fullAnswer = '';
                            answerContent.innerHTML = '';
                        // 类型为内容,追加文本并安排去抖渲染
                        } else if (chunk.type === 'content') {
                            fullAnswer += chunk.content;
                            scheduleRender();
                        // 类型为done(流传输结束),直接最终渲染所有答案内容
                        } else if (chunk.type === 'done') {
                            renderMarkdownToElement(answerContent, fullAnswer);
                        // 错误类型,alert显示报错内容
                        } else if (chunk.type === 'error') {
                            answerContent.innerHTML = `<div class="alert alert-danger">${chunk.content}</div>`;
                        }
                    } catch (e) {
                        // JSON解析失败时报console
                        console.error('解析流数据失败:', e);
                    }
                }
            }

        }
    } catch (error) {
        // 通信异常、Fetch错误等: 显示错误气泡
        answerContent.innerHTML = `<div class="alert alert-danger"><strong>错误:</strong> ${error.message}</div>`;
    }
    // 结束后确保界面滚动到底
    scrollToBottom();
}
// 清空所有会话
// 异步函数,用于清空所有会话
async function clearAllSessions() {
    // 弹窗确认操作,若取消则直接返回
    if (!confirm('确定要清空所有会话吗?此操作不可恢复!')) return;

    try {
        // 发送DELETE请求到服务器,删除所有会话
        const response = await fetch('/api/v1/knowledgebases/sessions', {
            method: 'DELETE'
        });
        // 获取服务器返回的JSON结果
        const result = await response.json();
        // 如果接口返回成功
        if (result.code === 200) {
            // 当前会话ID置空
            currentSessionId = null;
            // 清空聊天消息内容
            clearChatMessages();
            // 重新加载会话列表
            await loadSessions();
        }
    } catch (error) {
        // 捕获异常并弹窗提示失败原因
        alert('清空会话失败: ' + error.message);
    }
}

// 格式化时间
// 用于格式化时间字符串为友好显示
function formatTime(timeStr) {
    // 如果无有效时间,直接返回空字符串
    if (!timeStr) return '';
    // 将字符串转换为Date对象
    const date = new Date(timeStr);
    // 获取当前时间
    const now = new Date();
    // 计算时间差(毫秒)
    const diff = now - date;
    // 换算为分钟数
    const minutes = Math.floor(diff / 60000);
    // 换算为小时数
    const hours = Math.floor(diff / 3600000);
    // 换算为天数
    const days = Math.floor(diff / 86400000);
    // 1分钟内显示“刚刚”
    if (minutes < 1) return '刚刚';
    // 1小时内显示“x分钟前”
    if (minutes < 60) return `${minutes}分钟前`;
    // 24小时内显示“x小时前”
    if (hours < 24) return `${hours}小时前`;
    // 7天内显示“x天前”
    if (days < 7) return `${days}天前`;
    // 超过7天显示具体日期
    return date.toLocaleDateString();
}

// 渲染消息
function renderMessages(messages) {
    const chatMessages = document.getElementById('chatMessages');

    if (messages.length === 0) {
        chatMessages.innerHTML = `
            <div class="empty-state">
                <i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
                <p class="mt-2">开始提问吧!</p>
            </div>
        `;
        return;
    }

    chatMessages.innerHTML = messages.map(msg => {
        if (msg.role === 'user') {
            return `
                <div class="chat-message chat-question">
                    <div class="d-flex justify-content-between align-items-start">
                        <div class="flex-grow-1">
                            <strong><i class="bi bi-person-circle"></i> 问题:</strong>
                            <div class="mt-1">${escapeHtml(msg.content)}</div>
                        </div>
                        <small class="text-muted">${formatTime(msg.created_at)}</small>
                    </div>
                </div>
            `;
        } else {
            return `
                <div class="chat-message chat-answer">
                    <div class="d-flex justify-content-between align-items-start">
                        <div class="flex-grow-1">
                            <strong><i class="bi bi-robot"></i> 答案:</strong>
                            <div class="mt-2 markdown-content">${renderMarkdown(msg.content)}</div>
                        </div>
                        <small class="text-muted">${formatTime(msg.created_at)}</small>
                    </div>
                </div>
            `;
        }
    }).join('');

    scrollToBottom();
}
// 加载会话
// 定义一个异步函数 loadSession,用于加载指定 sessionId 的会话
async function loadSession(sessionId) {
    // 异常捕获,防止请求或处理过程抛错
    try {
        // 发送 GET 请求,获取指定会话的详情数据
        const response = await fetch(`/api/v1/knowledgebases/sessions/${sessionId}`);
        // 将返回的响应数据解析为 JSON 格式
        const result = await response.json();

        // 如果接口请求成功,即 code 等于 200
        if (result.code === 200) {
            // 设置当前会话 ID 为传入的 sessionId
            currentSessionId = sessionId;
            // 获取会话详情数据
            const session = result.data.session;
            // 获取该会话包含的消息列表,如果没有则为 []
            const messages = result.data.messages || [];
            // 调用方法将消息渲染到页面
            renderMessages(messages);
            // 重新加载全部会话列表,刷新左侧会话栏
            await loadSessions();
        }
    } catch (error) {
        // 捕获异常并弹窗提示加载失败的原因
        alert('加载会话失败: ' + error.message);
    }
}
// 渲染会话列表
// 将sessions数组渲染到界面左侧会话列表
function renderSessions() {
    // 获取会话列表的DOM节点
    const sessionList = document.getElementById('sessionList');
    // 如果没有任何会话,显示为空状态
    if (sessions.length === 0) {
        sessionList.innerHTML = `
            <div class="text-center text-muted py-5">
                <i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
                <p class="mt-2 small">暂无会话</p>
            </div>
        `;
        return;
    }
    // 遍历sessions并渲染为会话项
    sessionList.innerHTML = sessions.map(session => `
        <div class="session-item ${session.id === currentSessionId ? 'active' : ''}" 
             onclick="loadSession('${session.id}')">
            <button class="btn btn-sm btn-link text-danger p-0 session-delete" 
                    onclick="event.stopPropagation(); deleteSession('${session.id}')">
                <i class="bi bi-x-lg"></i>
            </button>
            <div class="session-title">${session.title || '新对话'}</div>
            <div class="session-time">${formatTime(session.updated_at)}</div>
        </div>
    `).join('');
}

// 清空聊天消息
// 将聊天内容区重置为初始状态
function clearChatMessages() {
    document.getElementById('chatMessages').innerHTML = `
        <div class="empty-state">
            <i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
            <p class="mt-2">开始提问吧!</p>
        </div>
    `;
}

// 加载会话列表
// 异步加载会话列表并渲染到界面
async function loadSessions() {
    try {
        // 发送GET请求获取会话列表
        const response = await fetch('/api/v1/knowledgebases/sessions');
        // 解析JSON结果
        const result = await response.json();
        // 如果接口返回成功
        if (result.code === 200) {
            // 保存会话数组
            sessions = result.data.items || [];
            // 渲染会话列表
            renderSessions();
        }
    } catch (error) {
        // 捕捉异常并在控制台打印
        console.error('加载会话列表失败:', error);
    }
}

// 创建新会话
// 异步请求服务端创建新会话
async function createNewSession() {
    try {
        // 发送POST请求创建新会话
        const response = await fetch('/api/v1/knowledgebases/sessions', {
            method: 'POST',
            headers: {'Content-Type': 'application/json'},
            body: JSON.stringify({})
        });
        // 解析返回结果
        const result = await response.json();
        // 如果创建成功
        if (result.code === 200) {
            // 切换到新会话
            currentSessionId = result.data.id;
            // 重新加载会话列表
            await loadSessions();
            // 清空聊天内容
            clearChatMessages();
        }
    } catch (error) {
        // 捕获异常并弹窗提示
        alert('创建会话失败: ' + error.message);
    }
}
// 监听 DOMContentLoaded 事件,确保页面元素加载完成后执行
document.addEventListener('DOMContentLoaded', function() {
    // 加载会话列表
    loadSessions();
    // 判断 marked.js 是否已经加载
    if (typeof marked !== 'undefined') {
        // 配置 marked.js 的渲染选项
        marked.setOptions({
            // 启用软换行
            breaks: true,
            // 启用 Github Flavored Markdown
            gfm: true,
            // 禁用标题自动生成 id
            headerIds: false,
            // 禁用混淆处理
            mangle: false
        });
    }
});
// 知识库选择变化处理函数
async function onKbChange() {
    // 更新全局当前知识库ID变量
    currentKbId = document.getElementById('kbSelect').value;;

    // 如果已存在会话,则切换知识库时重新新建一个会话
    if (currentSessionId) {
        await createNewSession();
    }

    // 启用发送按钮(普通聊天和知识库聊天均可提问)
    document.getElementById('submitBtn').disabled = false;
}
</script>
{% endblock %}

8.向量检索 #

8.1. retrieval_service.py #

app/services/retrieval_service.py

"""
检索服务
支持向量检索、全文检索和混合检索
支持重排序功能
"""

# 导入日志模块
import logging
# 导入类型注解
from typing import List, Optional, Tuple
# 导入Document对象
from langchain_core.documents import Document
# 导入设置信息服务
from app.services.settings_service import settings_service
# 导入向量数据库服务工厂方法
from app.services.vectordb.factory import get_vector_db_service
# 创建日志记录器
logger = logging.getLogger(__name__)

# 检索服务类
class RetrievalService:
    """检索服务"""

    # 初始化方法
    def __init__(self):
        """
        初始化检索服务

        Args:
            settings: 设置字典,如果为 None 则从数据库读取
        """
        # 从设置信息服务获取设置
        self.settings = settings_service.get()

    # 向量检索函数
    def vector_search(self, collection_name: str, query: str) -> List[Tuple[Document, float]]:
        """
        向量检索

        Args:
            collection_name: 集合名称
            query: 查询文本

        Returns:
            (Document, similarity_score) 列表,按相似度降序排列
        """
        # 获取向量数据库实例
        vectordb = get_vector_db_service()
        # 获取或创建指定集合的向量存储对象
        vectorstore = vectordb.get_or_create_collection(collection_name)

        # 从设置获取top_k,若不存在则默认为5
        top_k = int(self.settings.get('top_k', '5'))

        # 判断向量阈值是否已定义,若未定义则获取设定值或默认为0.2
        vector_threshold = float(self.settings.get('vector_threshold', '0.2'))
        # 限定向量阈值在0到1之间
        vector_threshold = max(0.0, min(vector_threshold, 1.0))
        # 以相似度得分方式检索,返回结果(扩大top_k数目以便后续过滤)
        results = vectorstore.similarity_search_with_score(
            query=query,
            k=top_k*3
        )
        # 初始化文档以及得分的列表
        docs_with_scores = []
        # 遍历检索结果,将得分归一化并加入元数据
        for doc, score in results:
            # 计算归一化向量得分
            vector_score = 1.0 / (1.0 + float(score))
            # 存储向量得分到文档元数据
            doc.metadata['vector_score'] = vector_score
            # 标注检索类型为向量
            doc.metadata['retrieval_type'] = 'vector'
            # 加入列表
            docs_with_scores.append((doc, vector_score))

        # 按照相似度分数从高到低排序
        docs_with_scores.sort(key=lambda x: x[1], reverse=True)

        # 根据阈值过滤掉低于阈值的文档
        filtered_docs = [(doc, score) for doc, score in docs_with_scores 
                       if score >= vector_threshold]

        # 仅保留top_k个文档用于返回
        docs = [doc for doc, _ in filtered_docs[:top_k]]

        # 日志打印检索到的文档个数
        logger.info(f"向量搜索: 检索到 {len(docs)} 个文档")
        # 返回结果
        return docs    

# 实例化检索服务,供外部调用
retrieval_service = RetrievalService()

8.2. rag_service.py #

app/services/rag_service.py

"""
RAG 服务
"""
# 导入日志模块
import logging
# 导入 LangChain 的对话提示模板模块
from langchain_core.prompts import ChatPromptTemplate
# 导入自定义 LLM 工厂
from app.utils.llm_factory import LLMFactory
# 导入设置服务
from app.services.settings_service import settings_service
# 导入类型提示
+from typing import List
# 导入 LangChain 的 Document 类型
+from langchain_core.documents import Document
# 导入检索服务
+from app.services.retrieval_service import retrieval_service
# 设置日志对象
logger = logging.getLogger(__name__)

# 定义 RAGService 类
class RAGService:
    """RAG 服务"""

    # 初始化函数
    def __init__(self):
        """
        初始化服务

        Args:
            settings: 设置字典,如果为 None 则从数据库读取
        """
        # 从设置服务中获取配置信息
        self.settings = settings_service.get()
        # 定义默认系统消息提示词
        default_rag_system_prompt = "你是一个专业的AI助手。请基于文档内容回答问题。"
        # 定义默认查询提示词,包含 context 和 question 占位符
        default_rag_query_prompt = """文档内容:
        {context}

        问题:{question}

        请基于文档内容回答问题。如果文档中没有相关信息,请明确说明。"""

        # 从设置中获取自定义系统消息提示词
        rag_system_prompt_text = self.settings.get('rag_system_prompt')
        # 如果没有设置,使用默认系统提示词
        if not rag_system_prompt_text:
            rag_system_prompt_text = default_rag_system_prompt

        # 从设置中获取自定义查询提示词
        rag_query_prompt_text = self.settings.get('rag_query_prompt')
        # 如果没有设置,使用默认查询提示词
        if not rag_query_prompt_text:
            rag_query_prompt_text = default_rag_query_prompt

        # 构建 RAG 的提示模板,包含系统消息和用户查询部分
        self.rag_prompt = ChatPromptTemplate.from_messages([
            ("system", rag_system_prompt_text),
            ("human", rag_query_prompt_text)
        ])

    # 定义流式问答接口
    def ask_stream(self, kb_id: str, question: str):
        """
        流式问答接口

        Args:
            kb_id: 知识库ID
            question: 问题

        Yields:
            流式数据块
        """
        # 创建带流式输出能力的 LLM 实例
        llm = LLMFactory.create_llm(self.settings)

        # 文档过滤后的结果
+       filtered_docs = self._retrieve_documents(kb_id, question)
        # 发送流式开始信号
        yield {
            "type": "start",
            "content": ""
        }

        # 构造用于传递给 LLM 的上下文字符串,将所有文档整合为字符串
        context = "\n\n".join([
            f"文档 {i+1} ({doc.metadata.get('doc_name', '未知')}):\n{doc.page_content}"
            for i, doc in enumerate(filtered_docs)
        ])

        # 创建 Rag Prompt 到 LLM 的处理链
        chain = self.rag_prompt | llm

        # 初始化完整答案的字符串
        full_answer = ""
        # 逐块流式生成答案
        for chunk in chain.stream({"context": context, "question": question}):
            # 获取当前输出块内容
            content = chunk.content
            # 如果有内容则累加并 yield 输出内容块
            if content:
                full_answer += content
                yield {
                    "type": "content",
                    "content": content
                }
        # 所有内容输出结束后,发送完成信号和相关元数据
        yield {
            "type": "done",
            "content": "",
            "metadata": {
                'kb_id': kb_id,
                'question': question,
                'retrieved_chunks': len(filtered_docs)
            }
        }    

+   def _retrieve_documents(self, kb_id: str, question: str) -> List[Document]:
+       """
+       检索文档(公共方法)

+       Args:
+           kb_id: 知识库ID
+           question: 查询文本


+       Returns:
+           文档列表
+       """
+       collection_name = f"kb_{kb_id}"
+       retrieval_mode = self.settings.get('retrieval_mode', 'vector')

+       if retrieval_mode == 'vector':
+           docs = retrieval_service.vector_search(
+               collection_name=collection_name,
+               query=question
+           )
+       else:
+           logger.warning(f"未知的检索模式: {retrieval_mode}, 使用向量检索")
+           docs = retrieval_service.vector_search(
+               collection_name=collection_name,
+               query=question
+           )

+       logger.info(f"使用 {retrieval_mode} 模式检索到 {len(docs)} 个文档")
+       return docs   

# 实例化 rag_service,供外部调用
rag_service = RAGService()        

9.关键字检索 #

9.1. rag_service.py #

app/services/rag_service.py

"""
RAG 服务
"""
# 导入日志模块
import logging
# 导入 LangChain 的对话提示模板模块
from langchain_core.prompts import ChatPromptTemplate
# 导入自定义 LLM 工厂
from app.utils.llm_factory import LLMFactory
# 导入设置服务
from app.services.settings_service import settings_service
# 导入类型提示
from typing import List
# 导入 LangChain 的 Document 类型
from langchain_core.documents import Document
# 导入检索服务
from app.services.retrieval_service import retrieval_service
# 设置日志对象
logger = logging.getLogger(__name__)

# 定义 RAGService 类
class RAGService:
    """RAG 服务"""

    # 初始化函数
    def __init__(self):
        """
        初始化服务

        Args:
            settings: 设置字典,如果为 None 则从数据库读取
        """
        # 从设置服务中获取配置信息
        self.settings = settings_service.get()
        # 定义默认系统消息提示词
        default_rag_system_prompt = "你是一个专业的AI助手。请基于文档内容回答问题。"
        # 定义默认查询提示词,包含 context 和 question 占位符
        default_rag_query_prompt = """文档内容:
        {context}

        问题:{question}

        请基于文档内容回答问题。如果文档中没有相关信息,请明确说明。"""

        # 从设置中获取自定义系统消息提示词
        rag_system_prompt_text = self.settings.get('rag_system_prompt')
        # 如果没有设置,使用默认系统提示词
        if not rag_system_prompt_text:
            rag_system_prompt_text = default_rag_system_prompt

        # 从设置中获取自定义查询提示词
        rag_query_prompt_text = self.settings.get('rag_query_prompt')
        # 如果没有设置,使用默认查询提示词
        if not rag_query_prompt_text:
            rag_query_prompt_text = default_rag_query_prompt

        # 构建 RAG 的提示模板,包含系统消息和用户查询部分
        self.rag_prompt = ChatPromptTemplate.from_messages([
            ("system", rag_system_prompt_text),
            ("human", rag_query_prompt_text)
        ])

    # 定义流式问答接口
    def ask_stream(self, kb_id: str, question: str):
        """
        流式问答接口

        Args:
            kb_id: 知识库ID
            question: 问题

        Yields:
            流式数据块
        """
        # 创建带流式输出能力的 LLM 实例
        llm = LLMFactory.create_llm(self.settings)

        # 文档过滤后的结果
        filtered_docs = self._retrieve_documents(kb_id, question)
        # 发送流式开始信号
        yield {
            "type": "start",
            "content": ""
        }

        # 构造用于传递给 LLM 的上下文字符串,将所有文档整合为字符串
        context = "\n\n".join([
            f"文档 {i+1} ({doc.metadata.get('doc_name', '未知')}):\n{doc.page_content}"
            for i, doc in enumerate(filtered_docs)
        ])

        # 创建 Rag Prompt 到 LLM 的处理链
        chain = self.rag_prompt | llm

        # 初始化完整答案的字符串
        full_answer = ""
        # 逐块流式生成答案
        for chunk in chain.stream({"context": context, "question": question}):
            # 获取当前输出块内容
            content = chunk.content
            # 如果有内容则累加并 yield 输出内容块
            if content:
                full_answer += content
                yield {
                    "type": "content",
                    "content": content
                }
        # 所有内容输出结束后,发送完成信号和相关元数据
        yield {
            "type": "done",
            "content": "",
            "metadata": {
                'kb_id': kb_id,
                'question': question,
                'retrieved_chunks': len(filtered_docs)
            }
        }    

    def _retrieve_documents(self, kb_id: str, question: str) -> List[Document]:
        """
        检索文档(公共方法)

        Args:
            kb_id: 知识库ID
            question: 查询文本


        Returns:
            文档列表
        """
        collection_name = f"kb_{kb_id}"
        retrieval_mode = self.settings.get('retrieval_mode', 'vector')

        if retrieval_mode == 'vector':
            docs = retrieval_service.vector_search(
                collection_name=collection_name,
                query=question
            )
+       elif retrieval_mode == 'keyword':
+           docs = retrieval_service.keyword_search(
+               collection_name=collection_name,
+               query=question
+           )    
        else:
            logger.warning(f"未知的检索模式: {retrieval_mode}, 使用向量检索")
            docs = retrieval_service.vector_search(
                collection_name=collection_name,
                query=question
            )

        logger.info(f"使用 {retrieval_mode} 模式检索到 {len(docs)} 个文档")
        return docs   

# 实例化 rag_service,供外部调用
rag_service = RAGService()        

9.2. retrieval_service.py #

app/services/retrieval_service.py

"""
检索服务
支持向量检索、全文检索和混合检索
支持重排序功能
"""

# 导入日志模块
import logging
# 导入类型注解
from typing import List, Optional, Tuple
# 导入Document对象
from langchain_core.documents import Document
# 导入BM25Okapi
+from rank_bm25 import BM25Okapi
# 导入jieba
+import jieba
# 导入numpy
+import numpy as np
# 导入设置信息服务
from app.services.settings_service import settings_service
# 导入向量数据库服务工厂方法
from app.services.vectordb.factory import get_vector_db_service
# 创建日志记录器
logger = logging.getLogger(__name__)

# 检索服务类
class RetrievalService:
    """检索服务"""

    # 初始化方法
    def __init__(self):
        """
        初始化检索服务

        Args:
            settings: 设置字典,如果为 None 则从数据库读取
        """
        # 从设置信息服务获取设置
        self.settings = settings_service.get()

    # 向量检索函数
    def vector_search(self, collection_name: str, query: str) -> List[Tuple[Document, float]]:
        """
        向量检索

        Args:
            collection_name: 集合名称
            query: 查询文本

        Returns:
            (Document, similarity_score) 列表,按相似度降序排列
        """
        # 获取向量数据库实例
        vectordb = get_vector_db_service()
        # 获取或创建指定集合的向量存储对象
        vectorstore = vectordb.get_or_create_collection(collection_name)

        # 从设置获取top_k,若不存在则默认为5
        top_k = int(self.settings.get('top_k', '5'))

        # 判断向量阈值是否已定义,若未定义则获取设定值或默认为0.2
        vector_threshold = float(self.settings.get('vector_threshold', '0.2'))
        # 限定向量阈值在0到1之间
        vector_threshold = max(0.0, min(vector_threshold, 1.0))
        # 以相似度得分方式检索,返回结果(扩大top_k数目以便后续过滤)
        results = vectorstore.similarity_search_with_score(
            query=query,
            k=top_k*3
        )
        # 初始化文档以及得分的列表
        docs_with_scores = []
        # 遍历检索结果,将得分归一化并加入元数据
        for doc, score in results:
            # 计算归一化向量得分
            vector_score = 1.0 / (1.0 + float(score))
            # 存储向量得分到文档元数据
            doc.metadata['vector_score'] = vector_score
            # 标注检索类型为向量
            doc.metadata['retrieval_type'] = 'vector'
            # 加入列表
            docs_with_scores.append((doc, vector_score))

        # 按照相似度分数从高到低排序
        docs_with_scores.sort(key=lambda x: x[1], reverse=True)

        # 根据阈值过滤掉低于阈值的文档
        filtered_docs = [(doc, score) for doc, score in docs_with_scores 
                       if score >= vector_threshold]

        # 仅保留top_k个文档用于返回
        docs = [doc for doc, _ in filtered_docs[:top_k]]

        # 日志打印检索到的文档个数
        logger.info(f"向量搜索: 检索到 {len(docs)} 个文档")
        # 返回结果
+       return docs  

+   def _tokenize_chinese(self, text: str) -> List[str]:
+       """
+       中文分词(使用 jieba)

+       Args:
+           text: 输入文本

+       Returns:
+           分词后的词列表
+       """
        # 使用 jieba 分词
+       words = jieba.lcut(text)
        # 去除停用词和单字
+       stopwords = set(['的', '了', '在', '是', '和', '有', '与', '对', '等', '为', '也', '就', '都', '要', '可以', '会', '能', '而', '及', '与', '或'])
+       tokens = [word.strip() for word in words if len(word.strip()) > 1 and word.strip() not in stopwords]
+       return tokens
    # 定义关键字检索方法,使用 BM25 算法进行匹配
+   def keyword_search(self, collection_name: str, query: str) -> List[Tuple[Document, float]]:
+       """
+       全文检索(使用 BM25 算法进行关键词匹配)

+       Args:
+           collection_name: 集合名称
+           query: 查询文本

+       Returns:
+           (Document, keyword_score) 列表,按匹配分数降序排列
+       """
+       try:
            # 获取向量数据库服务实例
+           vectordb = get_vector_db_service()
            # 获取或创建指定集合的向量存储对象
+           vectorstore = vectordb.get_or_create_collection(collection_name)

            # 初始化用于存储所有文档的列表
+           all_docs = []
            # 从底层集合获取所有内容
+           results = vectorstore._collection.get()

            # 如果存在有效的检索结果且包含 'ids'
+           if results and 'ids' in results:
                # 遍历所有文档 id
+               for i, _ in enumerate(results['ids']):
                    # 判断 'documents' 键存在且索引不越界
+                   if 'documents' in results and i < len(results['documents']):
                        # 构建 Document 对象,提取内容和元数据
+                       doc = Document(
+                           page_content=results['documents'][i],
+                           metadata=results.get('metadatas', [{}])[i] if 'metadatas' in results else {}
+                       )
                        # 添加到所有文档列表
+                       all_docs.append(doc)

            # 提取所有文档的文本内容
+           documents = [doc.page_content for doc in all_docs]
            # 对每个文档进行分词处理
+           tokenized_docs = [self._tokenize_chinese(doc) for doc in documents]

            # 构建 BM25 索引
+           bm25 = BM25Okapi(tokenized_docs)

            # 对查询语句进行中文分词
+           query_tokens = self._tokenize_chinese(query)

            # 获取每个文档与查询的 BM25 分数
+           scores = bm25.get_scores(query_tokens)

            # 计算分数的最大值,用于归一化分数到 [0, 1] 范围
+           max_score = float(np.max(scores)) if len(scores) > 0 and np.max(scores) > 0 else 1.0
            # 归一化 BM25 分数
+           normalized_scores = scores / max_score if max_score > 0 else scores

            # 获取关键字分数阈值,默认0.5
+           keyword_threshold = float(self.settings.get('keyword_threshold', '0.5'))
            # 限定关键字阈值在0到1之间
+           keyword_threshold = max(0.0, min(keyword_threshold, 1.0))
            # 获取返回文档数量 top_k,默认5
+           top_k = int(self.settings.get('top_k', '5'))
            # 取分数最高的 top_k*3 个索引,便于后续过滤
+           top_indices = np.argsort(normalized_scores)[::-1][:top_k * 3]

            # 初始化结果列表
+           docs_with_scores = []
            # 遍历候选索引
+           for idx in top_indices:
                # 取出每个文档归一化后的分数
+               normalized_score = float(normalized_scores[idx])
                # 确保分数在 [0, 1] 之间
+               normalized_score = max(0.0, min(1.0, normalized_score))
                # 仅保留分数高于阈值的文档
+               if normalized_score >= keyword_threshold:
                    # 取出对应文档
+                   doc = all_docs[idx]
                    # 记录关键词得分到元数据
+                   doc.metadata['keyword_score'] = normalized_score
                    # 标记检索类型
+                   doc.metadata['retrieval_type'] = 'keyword'
                    # 添加到结果列表
+                   docs_with_scores.append((doc, normalized_score))

            # 按得分降序排序
+           docs_with_scores.sort(key=lambda x: x[1], reverse=True)

            # 截取分数最高的前 top_k 个文档
+           docs = [doc for doc, _ in docs_with_scores[:top_k]]

            # 记录 BM25 检索到的文档数量日志
+           logger.info(f"BM25 关键词搜索: 检索到 {len(docs)} 个文档")
            # 返回检索结果
+           return docs
+       except Exception as e:
            # 捕获异常并记录日志
+           logger.error(f"关键词搜索时出错: {e}")
            # 继续抛出异常
+           raise      

# 实例化检索服务,供外部调用
retrieval_service = RetrievalService()

10.混合检索 #

10.1. rag_service.py #

app/services/rag_service.py

"""
RAG 服务
"""
# 导入日志模块
import logging
# 导入 LangChain 的对话提示模板模块
from langchain_core.prompts import ChatPromptTemplate
# 导入自定义 LLM 工厂
from app.utils.llm_factory import LLMFactory
# 导入设置服务
from app.services.settings_service import settings_service
# 导入类型提示
from typing import List
# 导入 LangChain 的 Document 类型
from langchain_core.documents import Document
# 导入检索服务
from app.services.retrieval_service import retrieval_service
# 设置日志对象
logger = logging.getLogger(__name__)

# 定义 RAGService 类
class RAGService:
    """RAG 服务"""

    # 初始化函数
    def __init__(self):
        """
        初始化服务

        Args:
            settings: 设置字典,如果为 None 则从数据库读取
        """
        # 从设置服务中获取配置信息
        self.settings = settings_service.get()
        # 定义默认系统消息提示词
        default_rag_system_prompt = "你是一个专业的AI助手。请基于文档内容回答问题。"
        # 定义默认查询提示词,包含 context 和 question 占位符
        default_rag_query_prompt = """文档内容:
        {context}

        问题:{question}

        请基于文档内容回答问题。如果文档中没有相关信息,请明确说明。"""

        # 从设置中获取自定义系统消息提示词
        rag_system_prompt_text = self.settings.get('rag_system_prompt')
        # 如果没有设置,使用默认系统提示词
        if not rag_system_prompt_text:
            rag_system_prompt_text = default_rag_system_prompt

        # 从设置中获取自定义查询提示词
        rag_query_prompt_text = self.settings.get('rag_query_prompt')
        # 如果没有设置,使用默认查询提示词
        if not rag_query_prompt_text:
            rag_query_prompt_text = default_rag_query_prompt

        # 构建 RAG 的提示模板,包含系统消息和用户查询部分
        self.rag_prompt = ChatPromptTemplate.from_messages([
            ("system", rag_system_prompt_text),
            ("human", rag_query_prompt_text)
        ])

    # 定义流式问答接口
    def ask_stream(self, kb_id: str, question: str):
        """
        流式问答接口

        Args:
            kb_id: 知识库ID
            question: 问题

        Yields:
            流式数据块
        """
        # 创建带流式输出能力的 LLM 实例
        llm = LLMFactory.create_llm(self.settings)

        # 文档过滤后的结果
        filtered_docs = self._retrieve_documents(kb_id, question)
        # 发送流式开始信号
        yield {
            "type": "start",
            "content": ""
        }

        # 构造用于传递给 LLM 的上下文字符串,将所有文档整合为字符串
        context = "\n\n".join([
            f"文档 {i+1} ({doc.metadata.get('doc_name', '未知')}):\n{doc.page_content}"
            for i, doc in enumerate(filtered_docs)
        ])

        # 创建 Rag Prompt 到 LLM 的处理链
        chain = self.rag_prompt | llm

        # 初始化完整答案的字符串
        full_answer = ""
        # 逐块流式生成答案
        for chunk in chain.stream({"context": context, "question": question}):
            # 获取当前输出块内容
            content = chunk.content
            # 如果有内容则累加并 yield 输出内容块
            if content:
                full_answer += content
                yield {
                    "type": "content",
                    "content": content
                }
        # 所有内容输出结束后,发送完成信号和相关元数据
        yield {
            "type": "done",
            "content": "",
            "metadata": {
                'kb_id': kb_id,
                'question': question,
                'retrieved_chunks': len(filtered_docs)
            }
        }    

    def _retrieve_documents(self, kb_id: str, question: str) -> List[Document]:
        """
        检索文档(公共方法)

        Args:
            kb_id: 知识库ID
            question: 查询文本


        Returns:
            文档列表
        """
        collection_name = f"kb_{kb_id}"
        retrieval_mode = self.settings.get('retrieval_mode', 'vector')

        if retrieval_mode == 'vector':
            docs = retrieval_service.vector_search(
                collection_name=collection_name,
                query=question
            )
        elif retrieval_mode == 'keyword':
            docs = retrieval_service.keyword_search(
                collection_name=collection_name,
                query=question
+           )   
+       elif retrieval_mode == 'hybrid':
+           docs = retrieval_service.hybrid_search(
+               collection_name=collection_name,
+               query=question
+           )     
        else:
            logger.warning(f"未知的检索模式: {retrieval_mode}, 使用向量检索")
            docs = retrieval_service.vector_search(
                collection_name=collection_name,
                query=question
            )

        logger.info(f"使用 {retrieval_mode} 模式检索到 {len(docs)} 个文档")
        return docs   

# 实例化 rag_service,供外部调用
rag_service = RAGService()        

10.2. retrieval_service.py #

app/services/retrieval_service.py

"""
检索服务
支持向量检索、全文检索和混合检索
支持重排序功能
"""

# 导入日志模块
import logging
# 导入类型注解
from typing import List, Optional, Tuple
# 导入Document对象
from langchain_core.documents import Document
# 导入BM25Okapi
from rank_bm25 import BM25Okapi
# 导入jieba
import jieba
# 导入numpy
import numpy as np
# 导入设置信息服务
from app.services.settings_service import settings_service
# 导入向量数据库服务工厂方法
from app.services.vectordb.factory import get_vector_db_service
# 创建日志记录器
logger = logging.getLogger(__name__)

# 检索服务类
class RetrievalService:
    """检索服务"""

    # 初始化方法
    def __init__(self):
        """
        初始化检索服务

        Args:
            settings: 设置字典,如果为 None 则从数据库读取
        """
        # 从设置信息服务获取设置
        self.settings = settings_service.get()

    # 向量检索函数
    def vector_search(self, collection_name: str, query: str) -> List[Tuple[Document, float]]:
        """
        向量检索

        Args:
            collection_name: 集合名称
            query: 查询文本

        Returns:
            (Document, similarity_score) 列表,按相似度降序排列
        """
        # 获取向量数据库实例
        vectordb = get_vector_db_service()
        # 获取或创建指定集合的向量存储对象
        vectorstore = vectordb.get_or_create_collection(collection_name)

        # 从设置获取top_k,若不存在则默认为5
        top_k = int(self.settings.get('top_k', '5'))

        # 判断向量阈值是否已定义,若未定义则获取设定值或默认为0.2
        vector_threshold = float(self.settings.get('vector_threshold', '0.2'))
        # 限定向量阈值在0到1之间
        vector_threshold = max(0.0, min(vector_threshold, 1.0))
        # 以相似度得分方式检索,返回结果(扩大top_k数目以便后续过滤)
        results = vectorstore.similarity_search_with_score(
            query=query,
            k=top_k*3
        )
        # 初始化文档以及得分的列表
        docs_with_scores = []
        # 遍历检索结果,将得分归一化并加入元数据
        for doc, score in results:
            # 计算归一化向量得分
            vector_score = 1.0 / (1.0 + float(score))
            # 存储向量得分到文档元数据
            doc.metadata['vector_score'] = vector_score
            # 标注检索类型为向量
            doc.metadata['retrieval_type'] = 'vector'
            # 加入列表
            docs_with_scores.append((doc, vector_score))

        # 按照相似度分数从高到低排序
        docs_with_scores.sort(key=lambda x: x[1], reverse=True)

        # 根据阈值过滤掉低于阈值的文档
        filtered_docs = [(doc, score) for doc, score in docs_with_scores 
                       if score >= vector_threshold]

        # 仅保留top_k个文档用于返回
        docs = [doc for doc, _ in filtered_docs[:top_k]]

        # 日志打印检索到的文档个数
        logger.info(f"向量搜索: 检索到 {len(docs)} 个文档")
        # 返回结果
        return docs  

    def _tokenize_chinese(self, text: str) -> List[str]:
        """
        中文分词(使用 jieba)

        Args:
            text: 输入文本

        Returns:
            分词后的词列表
        """
        # 使用 jieba 分词
        words = jieba.lcut(text)
        # 去除停用词和单字
        stopwords = set(['的', '了', '在', '是', '和', '有', '与', '对', '等', '为', '也', '就', '都', '要', '可以', '会', '能', '而', '及', '与', '或'])
        tokens = [word.strip() for word in words if len(word.strip()) > 1 and word.strip() not in stopwords]
        return tokens
    # 定义关键字检索方法,使用 BM25 算法进行匹配
    def keyword_search(self, collection_name: str, query: str) -> List[Tuple[Document, float]]:
        """
        全文检索(使用 BM25 算法进行关键词匹配)

        Args:
            collection_name: 集合名称
            query: 查询文本

        Returns:
            (Document, keyword_score) 列表,按匹配分数降序排列
        """
        try:
            # 获取向量数据库服务实例
            vectordb = get_vector_db_service()
            # 获取或创建指定集合的向量存储对象
            vectorstore = vectordb.get_or_create_collection(collection_name)

            # 初始化用于存储所有文档的列表
            all_docs = []
            # 从底层集合获取所有内容
            results = vectorstore._collection.get()

            # 如果存在有效的检索结果且包含 'ids'
            if results and 'ids' in results:
                # 遍历所有文档 id
                for i, _ in enumerate(results['ids']):
                    # 判断 'documents' 键存在且索引不越界
                    if 'documents' in results and i < len(results['documents']):
                        # 构建 Document 对象,提取内容和元数据
                        doc = Document(
                            page_content=results['documents'][i],
                            metadata=results.get('metadatas', [{}])[i] if 'metadatas' in results else {}
                        )
                        # 添加到所有文档列表
                        all_docs.append(doc)

            # 提取所有文档的文本内容
            documents = [doc.page_content for doc in all_docs]
            # 对每个文档进行分词处理
            tokenized_docs = [self._tokenize_chinese(doc) for doc in documents]

            # 构建 BM25 索引
            bm25 = BM25Okapi(tokenized_docs)

            # 对查询语句进行中文分词
            query_tokens = self._tokenize_chinese(query)

            # 获取每个文档与查询的 BM25 分数
            scores = bm25.get_scores(query_tokens)

            # 计算分数的最大值,用于归一化分数到 [0, 1] 范围
            max_score = float(np.max(scores)) if len(scores) > 0 and np.max(scores) > 0 else 1.0
            # 归一化 BM25 分数
            normalized_scores = scores / max_score if max_score > 0 else scores

            # 获取关键字分数阈值,默认0.5
            keyword_threshold = float(self.settings.get('keyword_threshold', '0.5'))
            # 限定关键字阈值在0到1之间
            keyword_threshold = max(0.0, min(keyword_threshold, 1.0))
            # 获取返回文档数量 top_k,默认5
            top_k = int(self.settings.get('top_k', '5'))
            # 取分数最高的 top_k*3 个索引,便于后续过滤
            top_indices = np.argsort(normalized_scores)[::-1][:top_k * 3]

            # 初始化结果列表
            docs_with_scores = []
            # 遍历候选索引
            for idx in top_indices:
                # 取出每个文档归一化后的分数
                normalized_score = float(normalized_scores[idx])
                # 确保分数在 [0, 1] 之间
                normalized_score = max(0.0, min(1.0, normalized_score))
                # 仅保留分数高于阈值的文档
                if normalized_score >= keyword_threshold:
                    # 取出对应文档
                    doc = all_docs[idx]
                    # 记录关键词得分到元数据
                    doc.metadata['keyword_score'] = normalized_score
                    # 标记检索类型
                    doc.metadata['retrieval_type'] = 'keyword'
                    # 添加到结果列表
                    docs_with_scores.append((doc, normalized_score))

            # 按得分降序排序
            docs_with_scores.sort(key=lambda x: x[1], reverse=True)

            # 截取分数最高的前 top_k 个文档
            docs = [doc for doc, _ in docs_with_scores[:top_k]]

            # 记录 BM25 检索到的文档数量日志
            logger.info(f"BM25 关键词搜索: 检索到 {len(docs)} 个文档")
            # 返回检索结果
            return docs
        except Exception as e:
            # 捕获异常并记录日志
            logger.error(f"关键词搜索时出错: {e}")
            # 继续抛出异常
            raise      

    # 定义混合检索方法,使用RRF融合向量检索和全文检索
+   def hybrid_search(self, collection_name: str, query: str, rrf_k: int = 60) -> List[Tuple[Document, float]]:
+       """
+       混合检索(使用 Reciprocal Rank Fusion (RRF) 融合向量检索和全文检索)

+       Args:
+           collection_name: 集合名称
+           query: 查询文本
+           rrf_k: RRF 常数,默认 60

+       Returns:
+           (Document, combined_score) 列表,按综合分数降序排列
+       """
+       try:
            # 调用向量检索方法,得到向量检索结果
+           vector_results = self.vector_search(
+               collection_name=collection_name,
+               query=query,
+           )

            # 调用关键词检索方法,得到关键词检索结果
+           keyword_results = self.keyword_search(
+               collection_name=collection_name,
+               query=query
+           )
            # 创建字典用于存储文档及其排名信息
+           doc_ranks = {}

            # 遍历向量检索结果,记录排名及分数
+           for rank, doc in enumerate(vector_results, start=1):
                # 获取文档ID
+               doc_id = doc.metadata.get('id')
                # 若该文档ID不在字典中,进行初始化
+               if doc_id not in doc_ranks:
+                   doc_ranks[doc_id] = {'doc': doc}
                # 记录向量排名
+               doc_ranks[doc_id]['vector_rank'] = rank
                # 记录向量分数
+               doc_ranks[doc_id]['vector_score'] = doc.metadata.get('vector_score', 0.0)

            # 遍历关键词检索结果,记录排名及分数
+           for rank, doc in enumerate(keyword_results, start=1):
                # 获取文档ID
+               doc_id = doc.metadata.get('id')
                # 若该文档ID不在字典中,进行初始化
+               if doc_id not in doc_ranks:
+                   doc_ranks[doc_id] = {'doc': doc}
                # 记录关键词排名
+               doc_ranks[doc_id]['keyword_rank'] = rank
                # 记录关键词分数
+               doc_ranks[doc_id]['keyword_score'] = doc.metadata.get('keyword_score', 0.0)

            # 从设置中读取向量权重,默认为0.3
+           vector_weight = float(self.settings.get('vector_weight', 0.3))
            # 关键词权重等于1减去向量权重
+           keyword_weight = 1.0 - vector_weight
            # 初始化混合结果列表
+           combined_results = []

            # 遍历所有文档,计算RRF融合分数
+           for doc_id, ranks_info in doc_ranks.items():
                # 获取向量排名(不存在默认0)
+               vector_rank = ranks_info.get('vector_rank', 0)
                # 获取关键词排名(不存在默认0)
+               keyword_rank = ranks_info.get('keyword_rank', 0)

                # 初始化RRF分数
+               rrf_score = 0.0
                # 如果有向量排名,则累加向量RRF贡献
+               rrf_score += vector_weight / (rrf_k + vector_rank)
                # 如果有关键词排名,则累加关键词RRF贡献
+               rrf_score += keyword_weight / (rrf_k + keyword_rank)

                # 将RRF分数写入字典
+               doc_ranks[doc_id]['rrf_score'] = rrf_score

            # 组装所有文档及其排名信息
+           combined_results = [(doc_id, rankInfo) for doc_id, rankInfo in doc_ranks.items()]
            # 按RRF分数从高到低排序
+           combined_results.sort(key=lambda x: x[1].get('rrf_score', 0), reverse=True)
            # 获取返回文档数量top_k,默认5
+           top_k = int(self.settings.get('top_k', 5))

            # 遍历前top_k个文档,设置其元数据
+           docs = []
+           for doc_id, rankInfo in combined_results[:top_k]:
+               doc = rankInfo['doc']
                # 更新向量分数到元数据
+               doc.metadata['vector_score'] = rankInfo.get('vector_score', 0.0)
                # 更新关键词分数到元数据
+               doc.metadata['keyword_score'] = rankInfo.get('keyword_score', 0.0)
                # 更新RRF分数到元数据
+               doc.metadata['rrf_score'] = rankInfo.get('rrf_score', 0.0)
                # 标明检索类型为混合
+               doc.metadata['retrieval_type'] = 'hybrid'
                # 添加到最终结果列表
+               docs.append(doc)
            # 记录日志,输出检索到的文档数量
+           logger.info(f"混合搜索 (RRF): 检索到 {len(docs)} 个文档")
            # 返回最终文档列表
+           return docs

+       except Exception as e:
            # 捕获异常并打印日志
+           logger.error(f"混合搜索时出错: {e}")
            # 继续抛出异常
+           raise    

# 实例化检索服务,供外部调用
retrieval_service = RetrievalService()

11.重排序 #

11.1. rerank_factory.py #

app/utils/rerank_factory.py

"""
Rerank 模型工厂
使用固定的 CrossEncoder 模型
"""
# 导入日志库
import logging
# 导入类型提示
from typing import Optional, List, Tuple
# 导入 Document 类
from langchain_core.documents import Document
# 导入 CrossEncoder 类
from sentence_transformers import CrossEncoder

# 创建日志记录器
logger = logging.getLogger(__name__)

# 定义 Rerank 工厂类
class RerankFactory:
    """Rerank 模型工厂"""

    # 指定用于重排序的 CrossEncoder 模型名称
    MODEL_NAME = 'cross-encoder/ms-marco-MiniLM-L-6-v2'

    @staticmethod
    def create_reranker(settings: Optional[dict] = None):
        """
        创建 Rerank 模型(使用固定的 CrossEncoder 模型)

        Args:
            settings: 忽略(未使用)

        Returns:
            Reranker 对象(总是返回,因为总是启用)
        """
        # 返回 LocalReranker 实例
        return LocalReranker()

# 定义重排序基类
class BaseReranker:
    """Rerank 基类"""

    # 定义 rerank 接口,子类需要实现
    def rerank(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> List[Tuple[Document, float]]:
        """
        对文档进行重排序

        Args:
            query: 查询文本
            documents: 文档列表
            top_k: 返回前N个结果,如果为None则返回所有

        Returns:
            (Document, score) 列表,按分数降序排列
        """
        # 抛出未实现异常
        raise NotImplementedError

# 本地 CrossEncoder 重排序实现类
class LocalReranker(BaseReranker):
    """本地 Rerank 模型(使用 CrossEncoder)"""

    def __init__(self):
        try:
            # 实例化 CrossEncoder
            self.reranker = CrossEncoder(RerankFactory.MODEL_NAME)
            # 记录 CrossEncoder 初始化信息
            logger.info(f"已创建 CrossEncoder 重排序器: {RerankFactory.MODEL_NAME}")
        except ImportError:
            # 没有安装依赖时抛出异常
            raise ImportError("Please install sentence-transformers: pip install sentence-transformers")
        except Exception as e:
            # 其他初始化异常时记录日志并抛出异常
            logger.error(f"初始化 CrossEncoder 重排序器时出错: {e}")
            raise

    def rerank(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> List[Tuple[Document, float]]:
        """使用 CrossEncoder 进行重排序"""
        # 如果文档为空,直接返回空列表
        if not documents:
            return []

        # 如果没有指定 top_k,则默认为文档总数
        top_k = top_k or len(documents)

        try:
            # 构造输入对,每个文档和 query 组成一组
            pairs = [[query, doc.page_content] for doc in documents]

            # 用 CrossEncoder 模型计算每个对的相关性分数
            scores = self.reranker.predict(pairs)

            # 将分数转换为列表
            scores = list(scores)
            # 将分数转为 float 类型
            scores_float = [float(score) for score in scores]
            # 计算分数中的最小值
            min_score = min(scores_float) if scores_float else 0.0
            # 计算分数中的最大值
            max_score = max(scores_float) if scores_float else 1.0

            # 对分数做归一化(如果最大最小相等,则全为0)
            normalized_scores = [(score - min_score) / (max_score - min_score) if max_score > min_score else 0.0
                                 for score in scores_float]
            # 构造 (Document, score) 的元组,并限制 score 在 [0, 1] 范围
            doc_scores = [(doc, max(0.0, min(1.0, score)))
                          for doc, score in zip(documents, normalized_scores)]

            # 按归一化后分数降序排序
            doc_scores.sort(key=lambda x: x[1], reverse=True)

            # 记录重排序日志
            logger.info(f"CrossEncoder 重排序: 已重排序 {len(doc_scores)} 个文档")
            # 返回分数最高的 top_k 个文档
            return doc_scores[:top_k]
        except Exception as e:
            # 发生异常时记录日志,并返回默认分数
            logger.error(f"CrossEncoder 重排序时出错: {e}")
            import traceback
            logger.error(traceback.format_exc())
            # 所有文档均赋分 0.5 返回
            return [(doc, 0.5) for doc in documents[:top_k]]

11.2. retrieval_service.py #

app/services/retrieval_service.py

"""
检索服务
支持向量检索、全文检索和混合检索
支持重排序功能
"""

# 导入日志模块
import logging
# 导入类型注解
+from typing import List, Tuple
# 导入Document对象
from langchain_core.documents import Document
# 导入BM25Okapi
from rank_bm25 import BM25Okapi
# 导入jieba
import jieba
# 导入numpy
import numpy as np
# 导入设置信息服务
from app.services.settings_service import settings_service
# 导入向量数据库服务工厂方法
from app.services.vectordb.factory import get_vector_db_service
# 导入重排序工厂
+from app.utils.rerank_factory import RerankFactory
# 创建日志记录器
logger = logging.getLogger(__name__)

# 检索服务类
class RetrievalService:
    """检索服务"""

    # 初始化方法
    def __init__(self):
        """
        初始化检索服务

        Args:
            settings: 设置字典,如果为 None 则从数据库读取
        """
        # 从设置信息服务获取设置
        self.settings = settings_service.get()
        # 初始化reranker
+       self.reranker = RerankFactory.create_reranker(self.settings)

    # 向量检索函数
    def vector_search(self, collection_name: str, query: str) -> List[Tuple[Document, float]]:
        """
        向量检索

        Args:
            collection_name: 集合名称
            query: 查询文本

        Returns:
            (Document, similarity_score) 列表,按相似度降序排列
        """
        # 获取向量数据库实例
        vectordb = get_vector_db_service()
        # 获取或创建指定集合的向量存储对象
        vectorstore = vectordb.get_or_create_collection(collection_name)

        # 从设置获取top_k,若不存在则默认为5
        top_k = int(self.settings.get('top_k', '5'))

        # 判断向量阈值是否已定义,若未定义则获取设定值或默认为0.2
        vector_threshold = float(self.settings.get('vector_threshold', '0.2'))
        # 限定向量阈值在0到1之间
        vector_threshold = max(0.0, min(vector_threshold, 1.0))
        # 以相似度得分方式检索,返回结果(扩大top_k数目以便后续过滤)
        results = vectorstore.similarity_search_with_score(
            query=query,
            k=top_k*3
        )
        # 初始化文档以及得分的列表
        docs_with_scores = []
        # 遍历检索结果,将得分归一化并加入元数据
        for doc, score in results:
            # 计算归一化向量得分
            vector_score = 1.0 / (1.0 + float(score))
            # 存储向量得分到文档元数据
            doc.metadata['vector_score'] = vector_score
            # 标注检索类型为向量
            doc.metadata['retrieval_type'] = 'vector'
            # 加入列表
            docs_with_scores.append((doc, vector_score))

        # 按照相似度分数从高到低排序
        docs_with_scores.sort(key=lambda x: x[1], reverse=True)

        # 根据阈值过滤掉低于阈值的文档
        filtered_docs = [(doc, score) for doc, score in docs_with_scores 
                       if score >= vector_threshold]

        # 仅保留top_k个文档用于返回
        docs = [doc for doc, _ in filtered_docs[:top_k]]

        # 应用重排序
+       if self.reranker:
+           docs = self._apply_rerank(query, docs, top_k)

        # 日志打印检索到的文档个数
        logger.info(f"向量搜索: 检索到 {len(docs)} 个文档")
        # 返回结果
        return docs  

    def _tokenize_chinese(self, text: str) -> List[str]:
        """
        中文分词(使用 jieba)

        Args:
            text: 输入文本

        Returns:
            分词后的词列表
        """
        # 使用 jieba 分词
        words = jieba.lcut(text)
        # 去除停用词和单字
        stopwords = set(['的', '了', '在', '是', '和', '有', '与', '对', '等', '为', '也', '就', '都', '要', '可以', '会', '能', '而', '及', '与', '或'])
        tokens = [word.strip() for word in words if len(word.strip()) > 1 and word.strip() not in stopwords]
        return tokens
    # 定义关键字检索方法,使用 BM25 算法进行匹配
    def keyword_search(self, collection_name: str, query: str) -> List[Tuple[Document, float]]:
        """
        全文检索(使用 BM25 算法进行关键词匹配)

        Args:
            collection_name: 集合名称
            query: 查询文本

        Returns:
            (Document, keyword_score) 列表,按匹配分数降序排列
        """
        try:
            # 获取向量数据库服务实例
            vectordb = get_vector_db_service()
            # 获取或创建指定集合的向量存储对象
            vectorstore = vectordb.get_or_create_collection(collection_name)

            # 初始化用于存储所有文档的列表
            all_docs = []
            # 从底层集合获取所有内容
            results = vectorstore._collection.get()

            # 如果存在有效的检索结果且包含 'ids'
            if results and 'ids' in results:
                # 遍历所有文档 id
                for i, _ in enumerate(results['ids']):
                    # 判断 'documents' 键存在且索引不越界
                    if 'documents' in results and i < len(results['documents']):
                        # 构建 Document 对象,提取内容和元数据
                        doc = Document(
                            page_content=results['documents'][i],
                            metadata=results.get('metadatas', [{}])[i] if 'metadatas' in results else {}
                        )
                        # 添加到所有文档列表
                        all_docs.append(doc)

            # 提取所有文档的文本内容
            documents = [doc.page_content for doc in all_docs]
            # 对每个文档进行分词处理
            tokenized_docs = [self._tokenize_chinese(doc) for doc in documents]

            # 构建 BM25 索引
            bm25 = BM25Okapi(tokenized_docs)

            # 对查询语句进行中文分词
            query_tokens = self._tokenize_chinese(query)

            # 获取每个文档与查询的 BM25 分数
            scores = bm25.get_scores(query_tokens)

            # 计算分数的最大值,用于归一化分数到 [0, 1] 范围
            max_score = float(np.max(scores)) if len(scores) > 0 and np.max(scores) > 0 else 1.0
            # 归一化 BM25 分数
            normalized_scores = scores / max_score if max_score > 0 else scores

            # 获取关键字分数阈值,默认0.5
            keyword_threshold = float(self.settings.get('keyword_threshold', '0.5'))
            # 限定关键字阈值在0到1之间
            keyword_threshold = max(0.0, min(keyword_threshold, 1.0))
            # 获取返回文档数量 top_k,默认5
            top_k = int(self.settings.get('top_k', '5'))
            # 取分数最高的 top_k*3 个索引,便于后续过滤
            top_indices = np.argsort(normalized_scores)[::-1][:top_k * 3]

            # 初始化结果列表
            docs_with_scores = []
            # 遍历候选索引
            for idx in top_indices:
                # 取出每个文档归一化后的分数
                normalized_score = float(normalized_scores[idx])
                # 确保分数在 [0, 1] 之间
                normalized_score = max(0.0, min(1.0, normalized_score))
                # 仅保留分数高于阈值的文档
                if normalized_score >= keyword_threshold:
                    # 取出对应文档
                    doc = all_docs[idx]
                    # 记录关键词得分到元数据
                    doc.metadata['keyword_score'] = normalized_score
                    # 标记检索类型
                    doc.metadata['retrieval_type'] = 'keyword'
                    # 添加到结果列表
                    docs_with_scores.append((doc, normalized_score))

            # 按得分降序排序
            docs_with_scores.sort(key=lambda x: x[1], reverse=True)

            # 截取分数最高的前 top_k 个文档
            docs = [doc for doc, _ in docs_with_scores[:top_k]]

            # 应用重排序
+           if self.reranker:
+               docs = self._apply_rerank(query, docs, top_k)

            # 记录 BM25 检索到的文档数量日志
            logger.info(f"BM25 关键词搜索: 检索到 {len(docs)} 个文档")
            # 返回检索结果
            return docs
        except Exception as e:
            # 捕获异常并记录日志
            logger.error(f"关键词搜索时出错: {e}")
            # 继续抛出异常
            raise      

    # 定义混合检索方法,使用RRF融合向量检索和全文检索
    def hybrid_search(self, collection_name: str, query: str, rrf_k: int = 60) -> List[Tuple[Document, float]]:
        """
        混合检索(使用 Reciprocal Rank Fusion (RRF) 融合向量检索和全文检索)

        Args:
            collection_name: 集合名称
            query: 查询文本
            rrf_k: RRF 常数,默认 60

        Returns:
            (Document, combined_score) 列表,按综合分数降序排列
        """
        try:
            # 调用向量检索方法,得到向量检索结果
            vector_results = self.vector_search(
                collection_name=collection_name,
                query=query,
            )

            # 调用关键词检索方法,得到关键词检索结果
            keyword_results = self.keyword_search(
                collection_name=collection_name,
                query=query
            )
            # 创建字典用于存储文档及其排名信息
            doc_ranks = {}

            # 遍历向量检索结果,记录排名及分数
            for rank, doc in enumerate(vector_results, start=1):
                # 获取文档ID
                doc_id = doc.metadata.get('id')
                # 若该文档ID不在字典中,进行初始化
                if doc_id not in doc_ranks:
                    doc_ranks[doc_id] = {'doc': doc}
                # 记录向量排名
                doc_ranks[doc_id]['vector_rank'] = rank
                # 记录向量分数
                doc_ranks[doc_id]['vector_score'] = doc.metadata.get('vector_score', 0.0)

            # 遍历关键词检索结果,记录排名及分数
            for rank, doc in enumerate(keyword_results, start=1):
                # 获取文档ID
                doc_id = doc.metadata.get('id')
                # 若该文档ID不在字典中,进行初始化
                if doc_id not in doc_ranks:
                    doc_ranks[doc_id] = {'doc': doc}
                # 记录关键词排名
                doc_ranks[doc_id]['keyword_rank'] = rank
                # 记录关键词分数
                doc_ranks[doc_id]['keyword_score'] = doc.metadata.get('keyword_score', 0.0)

            # 从设置中读取向量权重,默认为0.3
            vector_weight = float(self.settings.get('vector_weight', 0.3))
            # 关键词权重等于1减去向量权重
            keyword_weight = 1.0 - vector_weight
            # 初始化混合结果列表
            combined_results = []

            # 遍历所有文档,计算RRF融合分数
            for doc_id, ranks_info in doc_ranks.items():
                # 获取向量排名(不存在默认0)
                vector_rank = ranks_info.get('vector_rank', 0)
                # 获取关键词排名(不存在默认0)
                keyword_rank = ranks_info.get('keyword_rank', 0)

                # 初始化RRF分数
                rrf_score = 0.0
                # 如果有向量排名,则累加向量RRF贡献
                rrf_score += vector_weight / (rrf_k + vector_rank)
                # 如果有关键词排名,则累加关键词RRF贡献
                rrf_score += keyword_weight / (rrf_k + keyword_rank)

                # 将RRF分数写入字典
                doc_ranks[doc_id]['rrf_score'] = rrf_score

            # 组装所有文档及其排名信息
            combined_results = [(doc_id, rankInfo) for doc_id, rankInfo in doc_ranks.items()]
            # 按RRF分数从高到低排序
            combined_results.sort(key=lambda x: x[1].get('rrf_score', 0), reverse=True)
            # 获取返回文档数量top_k,默认5
            top_k = int(self.settings.get('top_k', 5))

            # 遍历前top_k个文档,设置其元数据
            docs = []
            for doc_id, rankInfo in combined_results[:top_k]:
                doc = rankInfo['doc']
                # 更新向量分数到元数据
                doc.metadata['vector_score'] = rankInfo.get('vector_score', 0.0)
                # 更新关键词分数到元数据
                doc.metadata['keyword_score'] = rankInfo.get('keyword_score', 0.0)
                # 更新RRF分数到元数据
                doc.metadata['rrf_score'] = rankInfo.get('rrf_score', 0.0)
                # 标明检索类型为混合
                doc.metadata['retrieval_type'] = 'hybrid'
                # 添加到最终结果列表
                docs.append(doc)
            # 应用重排序
+           if self.reranker:
+               docs = self._apply_rerank(query, docs, top_k)    
            # 记录日志,输出检索到的文档数量
            logger.info(f"混合搜索 (RRF): 检索到 {len(docs)} 个文档")
            # 返回最终文档列表
            return docs

        except Exception as e:
            # 捕获异常并打印日志
            logger.error(f"混合搜索时出错: {e}")
            # 继续抛出异常
            raise    
      # 定义应用重排序的方法,参数包括查询query,带有分数的文档列表,返回前top_k结果
+   def _apply_rerank(self, query: str, docs: List[Document], top_k: int) -> List[Document]:
+       """
+       应用重排序

+       Args:
+           query: 查询文本
+           docs: Document 列表
+           top_k: 返回前k个结果

+       Returns:
+           重排序后的 (Document, score) 列表
+       """
        # 如果没有重排序器或者输入为空,直接返回原始结果
+       if not self.reranker or not docs:
+           return docs

+       try:
            # 使用重排序器对文档进行重排序,返回带有新分数的文档列表
+           reranked = self.reranker.rerank(query, docs, top_k=top_k)

            # 遍历重排序后的结果,更新文档元数据
+           for doc, rerank_score in reranked:
                # 新增 rerank_score 到元数据
+               doc.metadata['rerank_score'] = rerank_score
                # 标记检索类型为原有类型+rerank,若没有原有类型则标记为unknown
+               doc.metadata['retrieval_type'] = doc.metadata.get('retrieval_type', 'unknown')

            # 打印重排序成功的日志
+           logger.info(f"已应用重排序: {len(reranked)} 个文档已重排序")
            # 返回带有重排序分数的文档列表
+           return [doc for doc, _ in reranked]
+       except Exception as e:
            # 如果重排序过程中发生错误,打印日志,并返回原始结果
+           logger.error(f"应用重排序时出错: {e}")
            # 如果rerank失败,返回原始结果
+           return docs    

# 实例化检索服务,供外部调用
retrieval_service = RetrievalService()

12.显示来源 #

12.1. chat.py #

app/blueprints/chat.py

# 聊天相关路由(视图 + API)
"""
聊天相关路由(视图 + API)
"""

# 导入 Flask 的 Blueprint 和模板渲染函数
from flask import Blueprint, render_template, request, stream_with_context, Response
import json
# 导入日志模块
import logging
# 导入登录保护装饰器和获取当前用户辅助方法
from app.utils.auth import login_required, api_login_required,get_current_user
# 导入知识库服务
from app.services.knowledgebase_service import kb_service
# 导入自定义工具函数:成功响应、错误响应、获取分页参数、获取当前用户或错误、异常处理装饰器、检查所有权
from app.blueprints.utils import (
    success_response, error_response,
    get_current_user_or_error, handle_api_error, get_pagination_params,check_ownership
)
from app.services.chat_service import chat_service
from app.services.chat_session_service import session_service

# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)

# 创建名为 'chat' 的蓝图对象
bp = Blueprint('chat', __name__)

# 注册 /chat 路由,访问该路由需要先登录
@bp.route('/chat')
@login_required
def chat_view():
    # 智能问答页面视图函数
    """智能问答页面"""
    current_user = get_current_user()
    # 获取所有知识库(通常用户不会有太多知识库,不需要分页)
    result = kb_service.list(user_id=current_user['id'], page=1, page_size=1000)
    # 渲染 chat.html 模板并传递空知识库列表
    return render_template('chat.html', knowledgebases=result['items'])

# 注册 API 路由,处理聊天接口 POST 请求
@bp.route('/api/v1/knowledgebases/chat', methods=['POST'])
@api_login_required
@handle_api_error
def api_chat():
    # 普通聊天接口(不支持知识库,支持流式输出)
    """普通聊天接口(不支持知识库,支持流式输出)"""
    # 获取当前用户和错误信息
    current_user, err = get_current_user_or_error()
    # 如果有错误,直接返回错误响应
    if err:
        return err

    # 从请求体获取 JSON 数据
    data = request.get_json()
    # 如果数据为空或不存在 'question' 字段,返回错误
    if not data or 'question' not in data:
        return error_response("question is required", 400)

    # 去除问题文本首尾空格
    question = data['question'].strip()
    # 如果问题内容为空,返回错误
    if not question:
        return error_response("question cannot be empty", 400)
    session_id = data.get('session_id')  # 会话ID(可以为空,表示普通聊天)
    # 获取 max_tokens 参数,默认 1000
    max_tokens = int(data.get('max_tokens', 1000))
    # 限制最大和最小值在 1~10000 之间
    max_tokens = max(1, min(max_tokens, 10000))  # 限制在 1-10000 之间
    # 从请求数据中获取'stream'字段,默认为True,表示启用流式输出
    stream = data.get('stream', True)  # 默认启用流式输出

    # 初始化历史消息为None
    history = None
    # 如果请求中带有session_id,说明有现有会话
    if session_id:
        # 根据session_id和当前用户ID获取历史消息列表
        history_messages = session_service.get_messages(session_id, current_user['id'])
        # 将历史消息转换为对话格式,仅保留最近10条
        history = [
            {'role': msg.get('role'), 'content': msg.get('content')}
            for msg in history_messages[-10:]  # 只取最近10条
        ]

    # 如果请求中没有session_id,说明是新对话,需要新建会话
    if not session_id:
        # 创建新会话,kb_id设为None表示普通聊天
        chat_session = session_service.create_session(
            user_id=current_user['id']
        )
        # 使用新创建会话的ID作为本次会话ID
        session_id = chat_session['id']

    # 将用户的问题消息保存到当前会话中
    session_service.add_message(session_id, 'user', question)

    # 声明用于流式输出的生成器
    @stream_with_context
    def generate():
        try:
            # 用于缓存完整答案内容
            full_answer = ''
            # 调用服务进行流式对话
            for chunk in chat_service.chat_stream(
                question=question,
                temperature=None,  # 使用设置中的值
                max_tokens=max_tokens,
                history=history
            ):
                # 如果是内容块,则拼接内容到 full_answer
                if chunk.get('type') == 'content':
                    full_answer += chunk.get('content', '')
                # 以 SSE 协议格式输出数据
                yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
            # 输出对话完成信号
            yield "data: [DONE]\n\n"
            # 保存助手回复
            if full_answer:
                session_service.add_message(session_id, 'assistant', full_answer)
        except Exception as e:
            # 发生异常记录日志
            logger.error(f"流式输出时出错: {e}")
            # 构造错误数据块
            error_chunk = {
                "type": "error",
                "content": str(e)
            }
            # 输出错误数据块
            yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n"

    # 创建 Response 对象,设置必要的 SSE 响应头部
    response = Response(
        generate(),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive',
            'X-Accel-Buffering': 'no',
            'Content-Type': 'text/event-stream; charset=utf-8'
        }
    )
    # 返回响应
    return response
# 路由装饰器,定义 GET 方法获取会话列表的接口
@bp.route('/api/v1/knowledgebases/sessions', methods=['GET'])
# API 登录校验装饰器,确保用户已登录
@api_login_required
# 错误处理装饰器,统一处理接口异常
@handle_api_error
def api_list_sessions():
    # 接口描述:获取当前用户的会话列表
    """获取当前用户的会话列表"""
    # 获取当前用户,如有错误直接返回错误响应
    current_user, err = get_current_user_or_error()
    if err:
        return err

    # 获取分页参数(页码和每页数量),最大单页1000
    page, page_size = get_pagination_params(max_page_size=1000)
    # 调用会话服务获取当前用户的会话列表
    result = session_service.list_sessions(current_user['id'], page=page, page_size=page_size)
    # 以统一成功响应格式返回会话列表
    return success_response(result)    


# 路由装饰器,定义 POST 方法创建会话的接口
@bp.route('/api/v1/knowledgebases/sessions', methods=['POST'])
@api_login_required
@handle_api_error
def api_create_session():
    # 接口描述:创建新的聊天会话
    """创建新的聊天会话"""
    # 获取当前用户,如果有错误直接返回
    current_user, err = get_current_user_or_error()
    if err:
        return err

    # 获取请求体中的 JSON 数据,若无返回空字典
    data = request.get_json() or {}
    # 获取会话标题
    title = data.get('title')

    # 调用服务创建会话,传入当前用户ID、知识库ID与标题
    session_obj = session_service.create_session(
        user_id=current_user['id'],
        title=title
    )
    # 返回成功响应及会话对象
    return success_response(session_obj)


# 路由装饰器,定义 GET 方法获取单个会话详情的接口(带 session_id)
@bp.route('/api/v1/knowledgebases/sessions/<session_id>', methods=['GET'])
@api_login_required
@handle_api_error
def api_get_session(session_id):
    # 接口描述:获取会话详情和消息
    """获取会话详情和消息"""
    # 获取当前用户,如有错误直接返回
    current_user, err = get_current_user_or_error()
    if err:
        return err

    # 根据 session_id 获取会话对象,校验所属当前用户
    session_obj = session_service.get_session_by_id(session_id, current_user['id'])
    # 如果没有找到会话,返回 404 错误
    if not session_obj:
        return error_response("Session not found", 404)

    # 获取该会话下的所有消息
    messages = session_service.get_messages(session_id, current_user['id'])

    # 返回会话详情及消息列表
    return success_response({
        'session': session_obj,
        'messages': messages
    })


# 路由装饰器,定义 DELETE 方法删除单个会话接口
@bp.route('/api/v1/knowledgebases/sessions/<session_id>', methods=['DELETE'])
@api_login_required
@handle_api_error
def api_delete_session(session_id):
    # 接口描述:删除会话
    """删除会话"""
    # 获取当前用户,如有错误直接返回
    current_user, err = get_current_user_or_error()
    if err:
        return err

    # 调用服务删除会话,校验归属当前用户
    success = session_service.delete_session(session_id, current_user['id'])
    # 若删除成功,返回成功响应,否则返回 404
    if success:
        return success_response(None, "Session deleted")
    else:
        return error_response("Session not found", 404)


# 路由装饰器,定义 DELETE 方法清空所有会话的接口
@bp.route('/api/v1/knowledgebases/sessions', methods=['DELETE'])
@api_login_required
@handle_api_error
def api_delete_all_sessions():
    # 接口描述:清空所有会话
    """清空所有会话"""
    # 获取当前用户,如果有错误直接返回
    current_user, err = get_current_user_or_error()
    if err:
        return err

    # 调用服务删除所有属于当前用户的会话,返回删除数量
    count = session_service.delete_all_sessions(current_user['id'])
    # 返回成功响应及被删除会话数
    return success_response({'deleted_count': count}, f"Deleted {count} sessions")

# 路由装饰器,指定POST方法用于知识库问答接口
@bp.route('/api/v1/knowledgebases/<kb_id>/chat', methods=['POST'])
# 装饰器:需要API登录
@api_login_required
# 装饰器:统一处理API错误
@handle_api_error
def api_ask(kb_id):
    # 知识库问答接口(支持流式输出)
    """知识库问答接口(支持流式输出)"""
    # 获取当前用户和错误信息
    current_user, err = get_current_user_or_error()
    # 如果获取用户出错,直接返回错误
    if err:
        return err

    # 获取指定id的知识库
    kb = kb_service.get_by_id(kb_id)
    # 检查当前用户是否有权限访问该知识库
    has_permission, err = check_ownership(kb['user_id'], current_user['id'], "knowledgebase")
    # 如果没有权限,直接返回错误
    if not has_permission:
        return err

    # 获取请求中的JSON数据
    data = request.get_json()

    # 获取并去除问题字符串首尾空白
    question = data['question'].strip()

    # 从请求数据获取session_id,如果没有则为None
    session_id = data.get('session_id')  # 会话ID
    # 获取最大token数,默认为1000
    max_tokens = int(data.get('max_tokens', 1000))
    # 限制max_tokens在1到10000之间
    max_tokens = max(1, min(max_tokens, 10000))  # 限制在 1-10000 之间

    # 如果没有提供session_id,则为用户和知识库创建一个新会话
    if not session_id:
        chat_session = session_service.create_session(
            user_id=current_user['id'],
            kb_id=kb_id
        )
        # 获取新会话的会话ID
        session_id = chat_session['id']

    # 保存用户输入的问题到消息列表
    session_service.add_message(session_id, 'user', question)

    # 内部函数:生成流式响应内容
    @stream_with_context
    def generate():
        try:
            # 初始化完整回复内容
            full_answer = ''
            # 初始化引用信息
            sources = None

            # 迭代chat_service.ask_stream的每个数据块
            for chunk in chat_service.ask_stream(
                kb_id=kb_id,
                question=question
            ):
                # 如果块类型为内容,则将内容追加到full_answer
                if chunk.get('type') == 'content':
                    full_answer += chunk.get('content', '')
                # 如果块类型为done,则获取sources
                elif chunk.get('type') == 'done':
                    sources = chunk.get('sources')

                # 以SSE格式输出该块内容
                yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

            # 所有内容输出后发送结束标志
            yield "data: [DONE]\n\n"

            # 如果有回复内容,则保存机器人助手的回复和引用
            if full_answer:
+               session_service.add_message(session_id, 'assistant', full_answer, sources)
        except Exception as e:
            # 如果流式输出出错,在日志中记录错误信息
            logger.error(f"流式输出时出错: {e}")
            # 构造错误信息块
            error_chunk = {
                "type": "error",
                "content": str(e)
            }
            # 以SSE格式输出错误信息
            yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n"

    # 构造SSE(服务端事件)响应对象,携带合适的头部信息
    response = Response(
        generate(),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive',
            'X-Accel-Buffering': 'no',
            'Content-Type': 'text/event-stream; charset=utf-8'
        }
    )
    # 返回响应对象
    return response

12.2. chat_session_service.py #

app/services/chat_session_service.py

# 导入倒序排序工具
from sqlalchemy import desc
# 导入日期时间
from datetime import datetime
# 导入聊天会话ORM模型
from app.models.chat_session import ChatSession
# 导入聊天消息ORM模型
from app.models.chat_message import ChatMessage
# 导入基础服务类
from app.services.base_service import BaseService

# 聊天会话服务类,继承自基础服务
class ChatSessionService(BaseService[ChatSession]):
    # 聊天会话服务说明文档
    """聊天会话服务"""

    # 创建新的聊天会话
    def create_session(self, user_id: str, kb_id: str = None, title: str = None) -> dict:
        """
        创建新的聊天会话

        Args:
            user_id: 用户ID
            kb_id: 知识库ID(可选)
            title: 会话标题(可选,如果不提供则使用默认标题)

        Returns:
            会话信息字典
        """
        # 启动数据库事务
        with self.transaction() as session:
            # 如果没有传标题则用默认标题
            if not title:
                title = "新对话"
            # 构造会话对象
            chat_session = ChatSession(
                user_id=user_id,
                title=title
            )
            # 新会话入库
            session.add(chat_session)
            # 刷新以拿到自增ID
            session.flush()
            # 刷新会话对象,便于获取ID等字段
            session.refresh(chat_session)
            # 记录日志
            self.logger.info(f"已创建聊天会话: {chat_session.id}, 用户: {user_id}")
            # 返回会话字典格式
            return chat_session.to_dict()

    # 根据ID获取会话
    def get_session_by_id(self, session_id: str, user_id: str = None) -> dict:
        """
        根据ID获取会话

        Args:
            session_id: 会话ID
            user_id: 用户ID(可选,用于验证权限)

        Returns:
            会话信息字典,如果不存在或无权访问则返回 None
        """
        # 打开数据库只读session
        with self.session() as session:
            # 查询指定ID的会话
            query = session.query(ChatSession).filter_by(id=session_id)
            # 如果提供了user_id则额外限定归属
            if user_id:
                query = query.filter_by(user_id=user_id)
            # 拿到第一个会话记录
            chat_session = query.first()
            # 有则返回字典信息,没有返回None
            if chat_session:
                return chat_session.to_dict()
            return None

    # 获取用户的所有会话列表(分页)
    def list_sessions(self, user_id: str, page: int = 1, page_size: int = 100) -> dict:
        """
        获取用户的会话列表

        Args:
            user_id: 用户ID
            page: 页码
            page_size: 每页数量

        Returns:
            包含总数和会话列表的字典
        """
        # 打开数据库只读session
        with self.session() as session:
            # 查询当前用户的所有会话
            query = session.query(ChatSession).filter_by(user_id=user_id)
            # 用基类的分页方法返回结构化内容,按更新时间倒序
            return self.paginate_query(query, page=page, page_size=page_size,
                                      order_by=desc(ChatSession.updated_at))

    # 会话标题修改
    def update_session_title(self, session_id: str, user_id: str, title: str) -> dict:
        """
        更新会话标题

        Args:
            session_id: 会话ID
            user_id: 用户ID(用于验证权限)
            title: 新标题

        Returns:
            更新后的会话信息字典
        """
        # 启动数据库事务
        with self.transaction() as session:
            # 查询拥有该会话的用户
            chat_session = session.query(ChatSession).filter_by(id=session_id, user_id=user_id).first()
            # 会话不存在则抛异常
            if not chat_session:
                raise ValueError("Session not found or access denied")
            # 更新标题
            chat_session.title = title
            # 刷新会话对象(update不会提交,refresh刷新对象)
            session.refresh(chat_session)
            # 返回最新数据
            return chat_session.to_dict()

    # 删除指定会话
    def delete_session(self, session_id: str, user_id: str) -> bool:
        """
        删除会话(级联删除消息)

        Args:
            session_id: 会话ID
            user_id: 用户ID(用于验证权限)

        Returns:
            是否删除成功
        """
        # 开启数据库事务
        with self.transaction() as session:
            # 查询该用户的指定会话
            chat_session = session.query(ChatSession).filter_by(id=session_id, user_id=user_id).first()
            # 未找到会话则返回False
            if not chat_session:
                return False
            # 删除会话(DB应有级联消息)
            session.delete(chat_session)
            # 记录日志
            self.logger.info(f"已删除聊天会话: {session_id}")
            # 返回True表示删除成功
            return True

    # 删除当前用户的所有会话
    def delete_all_sessions(self, user_id: str) -> int:
        """
        删除用户的所有会话

        Args:
            user_id: 用户ID

        Returns:
            删除的会话数量
        """
        # 开启数据库事务
        with self.transaction() as session:
            # 批量删除本用户所有会话
            count = session.query(ChatSession).filter_by(user_id=user_id).delete()
            # 记录日志
            self.logger.info(f"已删除用户 {user_id} 的 {count} 个聊天会话")
            # 返回删除数量
            return count

    # 添加消息到会话
+   def add_message(self, session_id: str, role: str, content: str, sources: list = None) -> dict:
        """
        添加消息到会话

        Args:
            session_id: 会话ID
            role: 角色('user' 或 'assistant')
            content: 消息内容
            sources: 引用来源列表(可选)

        Returns:
            消息信息字典
        """
        # 开启数据库事务
        with self.transaction() as session:
            # 构造消息对象
            message = ChatMessage(
                session_id=session_id,
                role=role,
+               content=content,
+               sources=sources
            )
            # 添加消息到数据库
            session.add(message)
            # 查询会话对象,用于更新时间/自动生成标题
            chat_session = session.query(ChatSession).filter_by(id=session_id).first()
            # 如果存在会话对象
            if chat_session:
                # 更新会话更新时间
                chat_session.updated_at = datetime.now()
                # 如果是用户发的第一条消息,并且还没标题,则用内容自动命名
                if role == 'user' and (not chat_session.title or chat_session.title == "新对话"):
                    # 会话标题截取前30字符,超长加省略号
                    title = content[:30] + ('...' if len(content) > 30 else '')
                    chat_session.title = title
            # 刷新确保message有ID
            session.flush()
            # 刷新消息对象
            session.refresh(message)
            # 返回消息字典
            return message.to_dict()

    # 获取会话的全部消息
    def get_messages(self, session_id: str, user_id: str = None) -> list:
        """
        获取会话的所有消息

        Args:
            session_id: 会话ID
            user_id: 用户ID(可选,用于验证权限)

        Returns:
            消息列表
        """
        # 打开只读session
        with self.session() as session:
            # 如指定user_id,须先验证此会话是否属于该用户,不属则不给查
            if user_id:
                chat_session = session.query(ChatSession).filter_by(id=session_id, user_id=user_id).first()
                # 如果会话不存在则返回空列表
                if not chat_session:
                    return []
            # 查询该会话下所有消息,按创建时间升序排序
            messages = session.query(ChatMessage).filter_by(session_id=session_id).order_by(ChatMessage.created_at).all()
            # 返回所有消息的字典列表
            return [m.to_dict() for m in messages]

# 单例: 聊天会话服务对象
session_service = ChatSessionService()

12.3. rag_service.py #

app/services/rag_service.py

"""
RAG 服务
"""
# 导入日志模块
import logging
# 导入 LangChain 的对话提示模板模块
from langchain_core.prompts import ChatPromptTemplate
# 导入自定义 LLM 工厂
from app.utils.llm_factory import LLMFactory
# 导入设置服务
from app.services.settings_service import settings_service
# 导入类型提示
from typing import List
# 导入 LangChain 的 Document 类型
from langchain_core.documents import Document
# 导入检索服务
from app.services.retrieval_service import retrieval_service
# 设置日志对象
logger = logging.getLogger(__name__)

# 定义 RAGService 类
class RAGService:
    """RAG 服务"""

    # 初始化函数
    def __init__(self):
        """
        初始化服务

        Args:
            settings: 设置字典,如果为 None 则从数据库读取
        """
        # 从设置服务中获取配置信息
        self.settings = settings_service.get()
        # 定义默认系统消息提示词
        default_rag_system_prompt = "你是一个专业的AI助手。请基于文档内容回答问题。"
        # 定义默认查询提示词,包含 context 和 question 占位符
        default_rag_query_prompt = """文档内容:
        {context}

        问题:{question}

        请基于文档内容回答问题。如果文档中没有相关信息,请明确说明。"""

        # 从设置中获取自定义系统消息提示词
        rag_system_prompt_text = self.settings.get('rag_system_prompt')
        # 如果没有设置,使用默认系统提示词
        if not rag_system_prompt_text:
            rag_system_prompt_text = default_rag_system_prompt

        # 从设置中获取自定义查询提示词
        rag_query_prompt_text = self.settings.get('rag_query_prompt')
        # 如果没有设置,使用默认查询提示词
        if not rag_query_prompt_text:
            rag_query_prompt_text = default_rag_query_prompt

        # 构建 RAG 的提示模板,包含系统消息和用户查询部分
        self.rag_prompt = ChatPromptTemplate.from_messages([
            ("system", rag_system_prompt_text),
            ("human", rag_query_prompt_text)
        ])

    # 定义流式问答接口
    def ask_stream(self, kb_id: str, question: str):
        """
        流式问答接口

        Args:
            kb_id: 知识库ID
            question: 问题

        Yields:
            流式数据块
        """
        # 创建带流式输出能力的 LLM 实例
        llm = LLMFactory.create_llm(self.settings)

        # 文档过滤后的结果
        filtered_docs = self._retrieve_documents(kb_id, question)
        # 发送流式开始信号
        yield {
            "type": "start",
            "content": ""
        }

        # 构造用于传递给 LLM 的上下文字符串,将所有文档整合为字符串
        context = "\n\n".join([
            f"文档 {i+1} ({doc.metadata.get('doc_name', '未知')}):\n{doc.page_content}"
            for i, doc in enumerate(filtered_docs)
        ])

        # 创建 Rag Prompt 到 LLM 的处理链
        chain = self.rag_prompt | llm

        # 初始化完整答案的字符串
        full_answer = ""
        # 逐块流式生成答案
        for chunk in chain.stream({"context": context, "question": question}):
            # 获取当前输出块内容
            content = chunk.content
            # 如果有内容则累加并 yield 输出内容块
            if content:
                full_answer += content
                yield {
                    "type": "content",
                    "content": content
                }
+       sources = self._extract_citations(filtered_docs)        
        # 所有内容输出结束后,发送完成信号和相关元数据
        yield {
            "type": "done",
            "content": "",
+           "sources": sources,
            "metadata": {
                'kb_id': kb_id,
                'question': question,
+               'retrieved_chunks': len(filtered_docs),
+               'used_chunks': len(sources)
            }
        }    
+   def _extract_citations(self, docs: List[Document]):
+       """
+       提取引用信息

+       Args:
+           docs: 检索到的 Document 列表(应该已经按相似度排序)

+       Returns:
+           引用列表(按相似度从高到低排序)
+       """
+       sources = []
+       for doc in docs:
+           metadata = doc.metadata

            # 获取各种分数
+           vector_score = metadata.get('vector_score')
+           keyword_score = metadata.get('keyword_score')
+           rrf_score = metadata.get('rrf_score')
+           rerank_score = metadata.get('rerank_score')
+           doc_name = metadata.get('doc_name', '未知文档')

+           sources.append({
+               'chunk_id': metadata.get('id') ,
+               'doc_id': metadata.get('doc_id', ''),
+               'doc_name': doc_name,
+               'content': doc.page_content,
+               'vector_score': round(float(vector_score), 4) if vector_score is not None else None,
+               'keyword_score': round(float(keyword_score), 4) if keyword_score is not None else None,
+               'rrf_score': round(float(rrf_score), 4) if rrf_score is not None else None,
+               'rerank_score': round(float(rerank_score), 4) if rerank_score is not None else None,
+               'retrieval_type': metadata.get('retrieval_type', 'unknown')
+           })

+       return sources
    def _retrieve_documents(self, kb_id: str, question: str) -> List[Document]:
        """
        检索文档(公共方法)

        Args:
            kb_id: 知识库ID
            question: 查询文本


        Returns:
            文档列表
        """
        collection_name = f"kb_{kb_id}"
        retrieval_mode = self.settings.get('retrieval_mode', 'vector')

        if retrieval_mode == 'vector':
            docs = retrieval_service.vector_search(
                collection_name=collection_name,
                query=question
            )
        elif retrieval_mode == 'keyword':
            docs = retrieval_service.keyword_search(
                collection_name=collection_name,
                query=question
            )   
        elif retrieval_mode == 'hybrid':
            docs = retrieval_service.hybrid_search(
                collection_name=collection_name,
                query=question
            )     
        else:
            logger.warning(f"未知的检索模式: {retrieval_mode}, 使用向量检索")
            docs = retrieval_service.vector_search(
                collection_name=collection_name,
                query=question
            )

        logger.info(f"使用 {retrieval_mode} 模式检索到 {len(docs)} 个文档")
        return docs   

# 实例化 rag_service,供外部调用
rag_service = RAGService()        

12.4. chat.html #

app/templates/chat.html

{% extends "base.html" %}

{% block title %}智能问答 - RAG Lite{% endblock %}

{% block extra_css %}
<style>
    .chat-container {
        height: calc(100vh - 200px);
        display: flex;
        gap: 1rem;
    }
    .chat-sidebar {
        width: 280px;
        background: #fff;
        border-radius: 0.5rem;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        display: flex;
        flex-direction: column;
    }
    .chat-main {
        flex: 1;
        background: #fff;
        border-radius: 0.5rem;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        display: flex;
        flex-direction: column;
    }
    .session-list {
        flex: 1;
        overflow-y: auto;
        padding: 0.5rem;
    }
    .session-item {
        padding: 0.75rem;
        margin-bottom: 0.5rem;
        border-radius: 0.5rem;
        cursor: pointer;
        transition: background-color 0.2s;
        position: relative;
    }
    .session-item:hover {
        background-color: #f8f9fa;
    }
    .session-item.active {
        background-color: #e3f2fd;
        border-left: 3px solid #0d6efd;
    }
    .session-item .session-title {
        font-weight: 500;
        margin-bottom: 0.25rem;
        overflow: hidden;
        text-overflow: ellipsis;
        white-space: nowrap;
    }
    .session-item .session-time {
        font-size: 0.75rem;
        color: #6c757d;
    }
    .session-item .session-delete {
        position: absolute;
        top: 0.5rem;
        right: 0.5rem;
        opacity: 0;
        transition: opacity 0.2s;
    }
    .session-item:hover .session-delete {
        opacity: 1;
    }
    .chat-messages {
        flex: 1;
        overflow-y: auto;
        padding: 1rem;
        scroll-behavior: smooth;
    }
    .chat-message {
        padding: 1rem;
        margin-bottom: 1rem;
        border-radius: 0.5rem;
    }
    .chat-question {
        background-color: #e3f2fd;
    }
    .chat-answer {
        background-color: #f5f5f5;
    }
    .chat-input-area {
        padding: 1rem;
        border-top: 1px solid #dee2e6;
    }
    .empty-state {
        display: flex;
        flex-direction: column;
        align-items: center;
        justify-content: center;
        height: 100%;
        color: #6c757d;
    }
</style>
{% endblock %}

{% block content %}
<div class="chat-container">
    <!-- 左侧:会话管理 -->
    <div class="chat-sidebar">
        <div class="p-3 border-bottom">
            <div class="d-flex justify-content-between align-items-center mb-3">
                <h6 class="mb-0"><i class="bi bi-chat-left-text"></i> 聊天会话</h6>
                <button class="btn btn-sm btn-primary" onclick="createNewSession()">
                    <i class="bi bi-plus"></i> 新建
                </button>
            </div>
            <button class="btn btn-sm btn-outline-danger w-100" onclick="clearAllSessions()">
                <i class="bi bi-trash"></i> 清空所有
            </button>
        </div>
        <div class="session-list" id="sessionList">
            <div class="text-center text-muted py-5">
                <i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
                <p class="mt-2 small">暂无会话</p>
            </div>
        </div>
    </div>

    <!-- 右侧:对话页面 -->
    <div class="chat-main">
        <div class="p-3 border-bottom">
            <div class="d-flex justify-content-between align-items-center">
                <h6 class="mb-0"><i class="bi bi-chat-dots"></i> 对话</h6>
                <select class="form-select form-select-sm" id="kbSelect" style="width: 200px;" onchange="onKbChange()">
                    <option value="">-- 选择知识库 --</option>
                    {% for kb in knowledgebases %}
                    <option value="{{ kb.id }}">{{ kb.name }}</option>
                    {% endfor %}
                </select>
            </div>
        </div>

        <div class="chat-messages" id="chatMessages">
            <div class="empty-state">
                <i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
                <p class="mt-2">开始提问吧!</p>
            </div>
        </div>

        <div class="chat-input-area">
            <form id="chatForm" onsubmit="askQuestion(event)">
                <div class="mb-2">
                    <textarea class="form-control" id="questionInput" rows="2" 
                              placeholder="输入您的问题..." required></textarea>
                </div>
                <div class="d-flex justify-content-end">
                    <button type="submit" class="btn btn-primary" id="submitBtn">
                        <i class="bi bi-send"></i> 发送
                    </button>
                </div>
            </form>
        </div>
    </div>
</div>
{% endblock %}

{% block extra_js %}
<script>
// 当前会话的ID,初始为null,表示暂未选择任何会话
let currentSessionId = null;
// 会话列表,初始为空数组,用于存储所有的会话对象
let sessions = [];
// 当前知识库的ID,初始为null,表示暂未选择任何知识库
let currentKbId = null;
// 为字符串进行HTML转义,防止XSS攻击
function escapeHtml(text) {
    // 创建一个div元素作为容器
    const div = document.createElement('div');
    // 将待转义文本设置为div的textContent(自动完成转义)
    div.textContent = text;
    // 返回转义后的HTML内容
    return div.innerHTML;
}

// 滚动消息框到底部
function scrollToBottom() {
    // 获取聊天消息区域元素
    const chatMessages = document.getElementById('chatMessages');
    // 设置滚动条位置到最底部
    chatMessages.scrollTop = chatMessages.scrollHeight;
}
// 定义一个用于渲染Markdown文本为HTML的函数
function renderMarkdown(text) {
    // 判断marked库是否已加载且能使用解析方法
    if (typeof marked !== 'undefined' && marked.parse) {
        try {
            // 尝试用marked库将Markdown文本解析成HTML
            return marked.parse(text);
        } catch (e) {
            // 如果解析出错,则进行HTML转义并换行
            return escapeHtml(text).replace(/\n/g, '<br>');
        }
    }
    // 如果marked库不可用,直接做HTML转义并处理换行
    return escapeHtml(text).replace(/\n/g, '<br>');
}
// 将Markdown内容渲染到指定元素(支持降级为普通文本)
function renderMarkdownToElement(element, text) {
    // 若未提供内容,显示思考中图标
    if (!text) {
        element.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';
        return;
    }
    // 优先判断marked库是否可用(渲染markdown)
    if (typeof marked !== 'undefined' && marked.parse) {
        try {
            // 使用marked进行markdown转html
            element.innerHTML = marked.parse(text);
        } catch (e) {
            // 渲染失败则退化为转义+换行
            element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
        }
    } else {
        // 没有marked库则直接转义+换行显示
        element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
    }
}

// 主函数:处理用户提交问题事件
async function askQuestion(event) {
    // 阻止表单默认提交行为(防止页面刷新)
    event.preventDefault();
    // 获取输入框的用户问题并去除首尾空白
    const question = document.getElementById('questionInput').value.trim();
    // 获取消息显示区域元素
    const chatMessages = document.getElementById('chatMessages');
    // 若问题为空则直接返回
    if (!question) return;
     // 如果没有会话,创建新会话
     if (!currentSessionId) {
        await createNewSession();
    }
    // 检查并移除初始空白提示(如有)
    if (chatMessages.querySelector('.empty-state')) {
        chatMessages.innerHTML = '';
    }
    // 创建用于展示问题的div元素
    const questionDiv = document.createElement('div');
    // 加上样式: 用户问题
    questionDiv.className = 'chat-message chat-question';
    // 构建用户气泡内容含图标、文本和时间
    questionDiv.innerHTML = `
        <div class="d-flex justify-content-between align-items-start">
            <div class="flex-grow-1">
                <strong><i class="bi bi-person-circle"></i> 问题:</strong>
                <div class="mt-1">${escapeHtml(question)}</div>
            </div>
            <small class="text-muted">${new Date().toLocaleTimeString()}</small>
        </div>
    `;
    // 显示到对话窗口
    chatMessages.appendChild(questionDiv);

    // 创建用于显示答案的div元素
    const answerDiv = document.createElement('div');
    // 答案样式
    answerDiv.className = 'chat-message chat-answer';
    // 动态生成唯一的答案内容div id(用于唯一标记)
    const answerContentId = 'answerContent_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
    // 创建存放答案内容的div
    const answerContent = document.createElement('div');
    // 为markdown渲染容器设置样式和id
    answerContent.className = 'mt-2 markdown-content';
    answerContent.id = answerContentId;
    // 答案div显示机器人图标和时间
    answerDiv.innerHTML = `
        <div class="d-flex justify-content-between align-items-start">
            <div class="flex-grow-1">
                <strong><i class="bi bi-robot"></i> 答案:</strong>
            </div>
            <small class="text-muted">${new Date().toLocaleTimeString()}</small>
        </div>
    `;
    // 将答案内容div插入到机器人气泡div的内容区
    const flexGrowDiv = answerDiv.querySelector('.flex-grow-1');
    flexGrowDiv.appendChild(answerContent);
    // 插入答案div到消息区域
    chatMessages.appendChild(answerDiv);

    // 变量记录完整的回答内容
    let fullAnswer = '';
    // 标记渲染任务是否挂起(防止重复)
    let pendingUpdate = false;
    // 记录定时器id(去抖动用)
    let updateTimer = null;

    // 清空输入框内容
    document.getElementById('questionInput').value = '';
    // 滚动到底
    scrollToBottom();

    // 定义scheduleRender:将markdown渲染插入到队列合适时机执行
    function scheduleRender() {
        // 若当前无待渲染任务才安排渲染
        if (!pendingUpdate) {
            pendingUpdate = true;
            // 下一帧渲染
            requestAnimationFrame(() => {
                // 将答案作为markdown渲染进dom
                renderMarkdownToElement(answerContent, fullAnswer);
                // 清理pending标识
                pendingUpdate = false;
                // 渲染后滚动到底部
                scrollToBottom();
            });
        }
    }

    try {
        // 组装API接口地址
        const url = currentKbId 
            ? `/api/v1/knowledgebases/${currentKbId}/chat`
            : `/api/v1/knowledgebases/chat`;
        // 请求后端,发起流式POST请求
        const response = await fetch(url, {
            method: 'POST',
            headers: {'Content-Type': 'application/json'},
            body: JSON.stringify({
                question: question,
                session_id: currentSessionId,
                stream: true
            })
        });
        // 非200响应时,抛出异常
        if (!response.ok) throw new Error('请求失败');
        // 使用ReadableStream reader获取response流
        const reader = response.body.getReader();
        // 用TextDecoder解码二进制数据
        const decoder = new TextDecoder();
        // 数据缓冲区(字符串)
        let buffer = '';
        // 初始展示“思考中...”提示
        answerContent.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';
        // 重新加载会话列表以更新标题
        await loadSessions();
        // 不断循环读取服务端推送流内容
        while (true) {
            // 逐步读取一段数据
            const { done, value } = await reader.read();
            // 若读完则结束循环
            if (done) break;
            // 本块新数据解码成字符串,并追加到buffer
            buffer += decoder.decode(value, { stream: true });
            // 按行切分(多行分别处理)
            const lines = buffer.split('\n');
            // 最后一行一般为数据残留,不处理,下次拼
            buffer = lines.pop() || '';
            // 处理每一行数据
            for (const line of lines) {
                // 筛选有效SSE协议"data: "数据包
                if (line.startsWith('data: ')) {
                    // 去掉前缀,留下json内容
                    const data = line.slice(6);
                    // [DONE]信号作为流结束,仅跳过
                    if (data === '[DONE]') continue;
                    try {
                        // 解析JSON数据
                        const chunk = JSON.parse(data);
                        // 类型为流起始,清空内容
                        if (chunk.type === 'start') {
                            fullAnswer = '';
                            answerContent.innerHTML = '';
                        // 类型为内容,追加文本并安排去抖渲染
                        } else if (chunk.type === 'content') {
                            fullAnswer += chunk.content;
                            scheduleRender();
                        // 类型为done(流传输结束),直接最终渲染所有答案内容
                        } else if (chunk.type === 'done') {
                            renderMarkdownToElement(answerContent, fullAnswer);
+                           renderSources(chunk.sources, chunk.metadata, answerDiv, currentKbId);
                        // 错误类型,alert显示报错内容
                        } else if (chunk.type === 'error') {
                            answerContent.innerHTML = `<div class="alert alert-danger">${chunk.content}</div>`;
                        }
                    } catch (e) {
                        // JSON解析失败时报console
                        console.error('解析流数据失败:', e);
                    }
                }
            }

        }
    } catch (error) {
        // 通信异常、Fetch错误等: 显示错误气泡
        answerContent.innerHTML = `<div class="alert alert-danger"><strong>错误:</strong> ${error.message}</div>`;
    }
    // 结束后确保界面滚动到底
    scrollToBottom();
}
+// 渲染引用来源到 DOM 元素
+function renderSources(sources, metadata, answerDiv, currentKbId) {
+   const html = generateSourcesHtml(sources, metadata, currentKbId);
+   if (html) {
+       const tempDiv = document.createElement('div');
+       tempDiv.innerHTML = html;
+       while (tempDiv.firstChild) {
+           answerDiv.appendChild(tempDiv.firstChild);
+       }
+   }
+}
+// 生成引用来源的 HTML 字符串
+function generateSourcesHtml(sources, metadata, currentKbId) {
+   // 如果 sources 存在、且是数组且长度大于0
+   if (sources && Array.isArray(sources) && sources.length > 0) {
+       // 对每个 source 逐个生成 html
+       const sourcesHtml = sources.map((source, idx) => {
+           // 初始化分数数组
+           const scores = [];

+           // 定义辅助函数: 安全转换分数字段
+           const getScore = (score) => {
+               // 如果为空则返回 null
+               if (score === null || score === undefined) return null;
+               // 字符串转换为浮点数,否则直接返回
+               const num = typeof score === 'string' ? parseFloat(score) : score;
+               // 非数值返回 null,否则返回数字
+               return isNaN(num) ? null : num;
+           };

+           // 获取检索类型,默认为 unknown
+           const retrievalType = source.retrieval_type || 'unknown';
+           // 是否为 vector 检索
+           const isVector = retrievalType.includes('vector');
+           // 是否为 keyword 检索
+           const isKeyword = retrievalType.includes('keyword');
+           // 是否为 hybrid 混合检索
+           const isHybrid = retrievalType.includes('hybrid');

+           // 如果是仅向量检索并且不是混合
+           if (isVector && !isHybrid) {
+               // 获取向量分数
+               const vectorScore = getScore(source.vector_score);
+               // 有效且大于等于0则添加 badge
+               if (vectorScore !== null && vectorScore >= 0) {
+                   scores.push(`<span class="badge bg-info me-1">向量相似度: ${vectorScore.toFixed(4)}</span>`);
+               }
+           }

+           // 如果是仅关键词检索并且不是混合
+           if (isKeyword && !isHybrid) {
+               // 获取关键词分数
+               const keywordScore = getScore(source.keyword_score);
+               // 有效且大于等于0则添加 badge
+               if (keywordScore !== null && keywordScore >= 0) {
+                   scores.push(`<span class="badge bg-success me-1">关键词相似度: ${keywordScore.toFixed(4)}</span>`);
+               }
+           }

+           // 如果是混合检索
+           if (isHybrid) {
+               // 获取向量分数
+               const vectorScore = getScore(source.vector_score);
+               // 有效且大于等于0则添加 badge
+               if (vectorScore !== null && vectorScore >= 0) {
+                   scores.push(`<span class="badge bg-info me-1">向量相似度: ${vectorScore.toFixed(4)}</span>`);
+               }
+               // 获取关键词分数
+               const keywordScore = getScore(source.keyword_score);
+               // 有效且大于等于0则添加 badge
+               if (keywordScore !== null && keywordScore >= 0) {
+                   scores.push(`<span class="badge bg-success me-1">关键词相似度: ${keywordScore.toFixed(4)}</span>`);
+               }
+               // 获取混合 RRF 分数
+               const rrfScore = getScore(source.rrf_score);
+               // 有效且大于等于0则添加 badge
+               if (rrfScore !== null && rrfScore >= 0) {
+                   scores.push(`<span class="badge bg-primary me-1">RRF 分数: ${rrfScore.toFixed(4)}</span>`);
+               }
+           }

+           // 获取重排序分数,并显示
+           const rerankScore = getScore(source.rerank_score);
+           // 有效且大于等于0则添加 badge
+           if (rerankScore !== null && rerankScore >= 0) {
+               scores.push(`<span class="badge bg-danger me-1">重排序分数: ${rerankScore.toFixed(4)}</span>`);
+           }

+           // 返回单个引用块的 html 字符串
+           return `
+               <div class="source-item mb-3 p-2 border rounded">
+                   <div class="fw-bold mb-1">
+                       <i class="bi bi-file-earmark"></i> ${idx + 1}. ${source.doc_name || '未知文档'}
+                   </div>
+                   <div class="mb-2">
+                       ${scores.length > 0 ? scores.join('') : '<span class="text-muted small">暂无分数信息</span>'}
+                   </div>
+                   <div class="small text-muted">
+                       <strong>分块内容:</strong>
+                       <div class="mt-1 p-2 bg-white rounded" style="max-height: 150px; overflow-y: auto; font-size: 0.9em;">
+                           ${escapeHtml(source.content || '')}
+                       </div>
+                   </div>
+               </div>
+           `;
+       // 连接所有的引用来源 HTML
+       }).join('');

+       // 返回所有引用来源的 html 块
+       return `
+           <div class="mt-3 p-3 bg-light rounded">
+               <strong><i class="bi bi-link-45deg"></i> 引用来源 (${sources.length} 个):</strong>
+               <div class="mt-2">
+                   ${sourcesHtml}
+               </div>
+           </div>
+       `;
+   // 如果传入了知识库ID但是没有 sources
+   } else if (currentKbId) {
+       // 从元数据获取检索块数,没有则为0
+       const retrievedChunks = (metadata && metadata.retrieved_chunks) || 0;

+       // 定义提示信息
+       let message = '';
+       // 若未检索到任何文档,提示排查信息
+       if (retrievedChunks === 0) {
+           message = '未检索到任何文档。请检查:1) 知识库中是否有文档;2) 文档是否已处理完成;3) 检索阈值是否设置过高。';
+       // 如果检索到了但没引用,提示相关性低
+       } else {
+           message = '未找到相关引用来源。可能是检索到的文档与问题相关性较低。';
+       }

+       // 返回警告气泡 html
+       return `
+           <div class="mt-3 p-2 bg-warning bg-opacity-10 rounded">
+               <small class="text-warning">
+                   <i class="bi bi-exclamation-triangle"></i> ${message}
+               </small>
+           </div>
+       `;
+   }
+   // 默认返回空字符串
+   return '';
+}
// 清空所有会话
// 异步函数,用于清空所有会话
async function clearAllSessions() {
    // 弹窗确认操作,若取消则直接返回
    if (!confirm('确定要清空所有会话吗?此操作不可恢复!')) return;

    try {
        // 发送DELETE请求到服务器,删除所有会话
        const response = await fetch('/api/v1/knowledgebases/sessions', {
            method: 'DELETE'
        });
        // 获取服务器返回的JSON结果
        const result = await response.json();
        // 如果接口返回成功
        if (result.code === 200) {
            // 当前会话ID置空
            currentSessionId = null;
            // 清空聊天消息内容
            clearChatMessages();
            // 重新加载会话列表
            await loadSessions();
        }
    } catch (error) {
        // 捕获异常并弹窗提示失败原因
        alert('清空会话失败: ' + error.message);
    }
}

// 格式化时间
// 用于格式化时间字符串为友好显示
function formatTime(timeStr) {
    // 如果无有效时间,直接返回空字符串
    if (!timeStr) return '';
    // 将字符串转换为Date对象
    const date = new Date(timeStr);
    // 获取当前时间
    const now = new Date();
    // 计算时间差(毫秒)
    const diff = now - date;
    // 换算为分钟数
    const minutes = Math.floor(diff / 60000);
    // 换算为小时数
    const hours = Math.floor(diff / 3600000);
    // 换算为天数
    const days = Math.floor(diff / 86400000);
    // 1分钟内显示“刚刚”
    if (minutes < 1) return '刚刚';
    // 1小时内显示“x分钟前”
    if (minutes < 60) return `${minutes}分钟前`;
    // 24小时内显示“x小时前”
    if (hours < 24) return `${hours}小时前`;
    // 7天内显示“x天前”
    if (days < 7) return `${days}天前`;
    // 超过7天显示具体日期
    return date.toLocaleDateString();
}

// 渲染消息
function renderMessages(messages) {
    const chatMessages = document.getElementById('chatMessages');

    if (messages.length === 0) {
        chatMessages.innerHTML = `
            <div class="empty-state">
                <i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
                <p class="mt-2">开始提问吧!</p>
            </div>
        `;
        return;
    }

    chatMessages.innerHTML = messages.map(msg => {
        if (msg.role === 'user') {
            return `
                <div class="chat-message chat-question">
                    <div class="d-flex justify-content-between align-items-start">
                        <div class="flex-grow-1">
                            <strong><i class="bi bi-person-circle"></i> 问题:</strong>
                            <div class="mt-1">${escapeHtml(msg.content)}</div>
                        </div>
                        <small class="text-muted">${formatTime(msg.created_at)}</small>
                    </div>
                </div>
            `;
        } else {
+           const sourcesHtml = generateSourcesHtml(msg.sources, msg.metadata, null);
            return `
                <div class="chat-message chat-answer">
                    <div class="d-flex justify-content-between align-items-start">
                        <div class="flex-grow-1">
                            <strong><i class="bi bi-robot"></i> 答案:</strong>
                            <div class="mt-2 markdown-content">${renderMarkdown(msg.content)}</div>
+                           ${sourcesHtml}
                        </div>
                        <small class="text-muted">${formatTime(msg.created_at)}</small>
                    </div>
                </div>
            `;
        }
    }).join('');

    scrollToBottom();
}
// 加载会话
// 定义一个异步函数 loadSession,用于加载指定 sessionId 的会话
async function loadSession(sessionId) {
    // 异常捕获,防止请求或处理过程抛错
    try {
        // 发送 GET 请求,获取指定会话的详情数据
        const response = await fetch(`/api/v1/knowledgebases/sessions/${sessionId}`);
        // 将返回的响应数据解析为 JSON 格式
        const result = await response.json();

        // 如果接口请求成功,即 code 等于 200
        if (result.code === 200) {
            // 设置当前会话 ID 为传入的 sessionId
            currentSessionId = sessionId;
            // 获取会话详情数据
            const session = result.data.session;
            // 获取该会话包含的消息列表,如果没有则为 []
            const messages = result.data.messages || [];
            // 调用方法将消息渲染到页面
            renderMessages(messages);
            // 重新加载全部会话列表,刷新左侧会话栏
            await loadSessions();
        }
    } catch (error) {
        // 捕获异常并弹窗提示加载失败的原因
        alert('加载会话失败: ' + error.message);
    }
}
// 渲染会话列表
// 将sessions数组渲染到界面左侧会话列表
function renderSessions() {
    // 获取会话列表的DOM节点
    const sessionList = document.getElementById('sessionList');
    // 如果没有任何会话,显示为空状态
    if (sessions.length === 0) {
        sessionList.innerHTML = `
            <div class="text-center text-muted py-5">
                <i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
                <p class="mt-2 small">暂无会话</p>
            </div>
        `;
        return;
    }
    // 遍历sessions并渲染为会话项
    sessionList.innerHTML = sessions.map(session => `
        <div class="session-item ${session.id === currentSessionId ? 'active' : ''}" 
             onclick="loadSession('${session.id}')">
            <button class="btn btn-sm btn-link text-danger p-0 session-delete" 
                    onclick="event.stopPropagation(); deleteSession('${session.id}')">
                <i class="bi bi-x-lg"></i>
            </button>
            <div class="session-title">${session.title || '新对话'}</div>
            <div class="session-time">${formatTime(session.updated_at)}</div>
        </div>
    `).join('');
}

// 清空聊天消息
// 将聊天内容区重置为初始状态
function clearChatMessages() {
    document.getElementById('chatMessages').innerHTML = `
        <div class="empty-state">
            <i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
            <p class="mt-2">开始提问吧!</p>
        </div>
    `;
}

// 加载会话列表
// 异步加载会话列表并渲染到界面
async function loadSessions() {
    try {
        // 发送GET请求获取会话列表
        const response = await fetch('/api/v1/knowledgebases/sessions');
        // 解析JSON结果
        const result = await response.json();
        // 如果接口返回成功
        if (result.code === 200) {
            // 保存会话数组
            sessions = result.data.items || [];
            // 渲染会话列表
            renderSessions();
        }
    } catch (error) {
        // 捕捉异常并在控制台打印
        console.error('加载会话列表失败:', error);
    }
}

// 创建新会话
// 异步请求服务端创建新会话
async function createNewSession() {
    try {
        // 发送POST请求创建新会话
        const response = await fetch('/api/v1/knowledgebases/sessions', {
            method: 'POST',
            headers: {'Content-Type': 'application/json'},
            body: JSON.stringify({})
        });
        // 解析返回结果
        const result = await response.json();
        // 如果创建成功
        if (result.code === 200) {
            // 切换到新会话
            currentSessionId = result.data.id;
            // 重新加载会话列表
            await loadSessions();
            // 清空聊天内容
            clearChatMessages();
        }
    } catch (error) {
        // 捕获异常并弹窗提示
        alert('创建会话失败: ' + error.message);
    }
}
// 监听 DOMContentLoaded 事件,确保页面元素加载完成后执行
document.addEventListener('DOMContentLoaded', function() {
    // 加载会话列表
    loadSessions();
    // 判断 marked.js 是否已经加载
    if (typeof marked !== 'undefined') {
        // 配置 marked.js 的渲染选项
        marked.setOptions({
            // 启用软换行
            breaks: true,
            // 启用 Github Flavored Markdown
            gfm: true,
            // 禁用标题自动生成 id
            headerIds: false,
            // 禁用混淆处理
            mangle: false
        });
    }
});
// 知识库选择变化处理函数
async function onKbChange() {
    // 更新全局当前知识库ID变量
    currentKbId = document.getElementById('kbSelect').value;;

    // 如果已存在会话,则切换知识库时重新新建一个会话
    if (currentSessionId) {
        await createNewSession();
    }

    // 启用发送按钮(普通聊天和知识库聊天均可提问)
    document.getElementById('submitBtn').disabled = false;
}
</script>
{% endblock %}
← 上一节 20.文档管理 下一节 22.API文档 →

访问验证

请输入访问令牌

Token不正确,请重新输入