1.本章目标 #
本章将介绍聊天模块的实现目标与功能,包括用户与AI的对话流程、消息存储与检索、会话管理机制以及与知识库等模块的基础集成方式。通过本章的学习,读者能够理解聊天功能的架构设计思路、关键模块结构、以及如何扩展和自定义聊天体验,为构建智能对话应用奠定基础。
2.目录结构 #
// 项目根目录
rag-lite/
// 应用主目录
├── app/
// 蓝图视图目录(用于组织不同功能的视图)
│ ├── blueprints/
// 蓝图包初始化
│ │ ├── __init__.py
// 用户认证相关视图
│ │ ├── auth.py
// 聊天相关视图
│ │ ├── chat.py
// 文档管理相关视图
│ │ ├── document.py
// 知识库相关视图
│ │ ├── knowledgebase.py
// 系统设置相关视图
│ │ ├── settings.py
// 工具函数蓝图
│ │ └── utils.py
// ORM数据模型目录
│ ├── models/
// 模型包初始化
│ │ ├── __init__.py
// 基础模型定义
│ │ ├── base.py
// 聊天消息模型
│ │ ├── chat_message.py
// 聊天会话模型
│ │ ├── chat_session.py
// 文档模型
│ │ ├── document.py
// 知识库模型
│ │ ├── knowledgebase.py
// 系统设置模型
│ │ ├── settings.py
// 用户模型
│ │ └── user.py
// 服务层目录
│ ├── services/
// 存储相关服务子目录
│ │ ├── storage/
// 存储包初始化
│ │ │ ├── __init__.py
// 存储基类
│ │ │ ├── base.py
// 存储工厂类
│ │ │ ├── factory.py
// 本地存储实现
│ │ │ ├── local_storage.py
// minio存储实现
│ │ │ └── minio_storage.py
// 向量数据库服务子目录
│ │ ├── vectordb/
// 向量数据库包初始化
│ │ │ ├── __init__.py
// 向量数据库基类
│ │ │ ├── base.py
// Chroma向量数据库实现
│ │ │ ├── chroma.py
// 向量数据库工厂
│ │ │ ├── factory.py
// Milvus向量数据库实现
│ │ │ └── milvus.py
// 基础服务类
│ │ ├── base_service.py
// 聊天服务
│ │ ├── chat_service.py
// 聊天会话管理服务
│ │ ├── chat_session_service.py
// 文档服务
│ │ ├── document_service.py
// 知识库服务
│ │ ├── knowledgebase_service.py
// 文档解析服务
│ │ ├── parser_service.py
// RAG服务
│ │ ├── rag_service.py
// 检索服务
│ │ ├── retrieval_service.py
// 系统设置服务
│ │ ├── settings_service.py
// 存储服务(入口)
│ │ ├── storage_service.py
// 用户服务
│ │ ├── user_service.py
// 向量化服务
│ │ └── vector_service.py
// 静态资源目录
│ ├── static/
// 前端模板文件目录
│ ├── templates/
// 基础模板
│ │ ├── base.html
// 聊天页面模板
│ │ ├── chat.html
// 首页模板
│ │ ├── home.html
// 知识库详情页面模板
│ │ ├── kb_detail.html
// 知识库列表页面模板
│ │ ├── kb_list.html
// 登录页面模板
│ │ ├── login.html
// 注册页面模板
│ │ ├── register.html
// 设置页面模板
│ │ └── settings.html
// 工具函数模块目录
│ ├── utils/
// 认证相关工具
│ │ ├── auth.py
// 数据库工具
│ │ ├── db.py
// 文档加载工具
│ │ ├── document_loader.py
// 嵌入模型工厂
│ │ ├── embedding_factory.py
// LLM大模型工厂
│ │ ├── llm_factory.py
// 日志工具
│ │ ├── logger.py
// 模型配置工具
│ │ ├── models_config.py
// 重排序(rerank)工厂
│ │ ├── rerank_factory.py
// 文本分割工具
│ │ └── text_splitter.py
// 应用初始化文件
│ ├── __init__.py
// 应用配置文件
│ └── config.py
// Chroma数据库目录
├── chroma_db/
// 日志文件目录
├── logs/
// RAG Lite 主日志文件
│ └── rag_lite.log
// 存储文件目录
├── storages/
// 持久化卷目录
├── volumes/
// Milvus向量数据库的数据目录
│ ├── milvus/
// Minio对象存储的数据目录
│ └── minio/
// Docker Compose服务编排配置文件
├── docker-compose.yml
// 应用启动主程序
├── main.py
// Python项目配置文件
├── pyproject.toml3.聊天页面 #
3.1. chat.py #
app/blueprints/chat.py
# 聊天相关路由(视图 + API)
"""
聊天相关路由(视图 + API)
"""
# 导入 Flask 的 Blueprint 和模板渲染函数
from flask import Blueprint, render_template
# 导入知识库服务,用于后续业务逻辑
from app.services.knowledgebase_service import kb_service
# 导入登录保护装饰器和获取当前用户辅助方法
from app.utils.auth import login_required, get_current_user
# 导入日志模块
import logging
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 创建名为 'chat' 的蓝图对象
bp = Blueprint('chat', __name__)
# 注册 /chat 路由,访问该路由需要先登录
@bp.route('/chat')
@login_required
def chat_view():
# 智能问答页面视图函数
"""智能问答页面"""
# 渲染 chat.html 模板并传递空知识库列表
return render_template('chat.html', knowledgebases=[])
3.2. chat.html #
app/templates/chat.html
{% extends "base.html" %}
{% block title %}智能问答 - RAG Lite{% endblock %}
{% block extra_css %}
<style>
.chat-container {
height: calc(100vh - 200px);
display: flex;
gap: 1rem;
}
.chat-sidebar {
width: 280px;
background: #fff;
border-radius: 0.5rem;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
display: flex;
flex-direction: column;
}
.chat-main {
flex: 1;
background: #fff;
border-radius: 0.5rem;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
display: flex;
flex-direction: column;
}
.session-list {
flex: 1;
overflow-y: auto;
padding: 0.5rem;
}
.session-item {
padding: 0.75rem;
margin-bottom: 0.5rem;
border-radius: 0.5rem;
cursor: pointer;
transition: background-color 0.2s;
position: relative;
}
.session-item:hover {
background-color: #f8f9fa;
}
.session-item.active {
background-color: #e3f2fd;
border-left: 3px solid #0d6efd;
}
.session-item .session-title {
font-weight: 500;
margin-bottom: 0.25rem;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.session-item .session-time {
font-size: 0.75rem;
color: #6c757d;
}
.session-item .session-delete {
position: absolute;
top: 0.5rem;
right: 0.5rem;
opacity: 0;
transition: opacity 0.2s;
}
.session-item:hover .session-delete {
opacity: 1;
}
.chat-messages {
flex: 1;
overflow-y: auto;
padding: 1rem;
scroll-behavior: smooth;
}
.chat-message {
padding: 1rem;
margin-bottom: 1rem;
border-radius: 0.5rem;
}
.chat-question {
background-color: #e3f2fd;
}
.chat-answer {
background-color: #f5f5f5;
}
.chat-input-area {
padding: 1rem;
border-top: 1px solid #dee2e6;
}
.empty-state {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
height: 100%;
color: #6c757d;
}
</style>
{% endblock %}
{% block content %}
<div class="chat-container">
<!-- 左侧:会话管理 -->
<div class="chat-sidebar">
<div class="p-3 border-bottom">
<div class="d-flex justify-content-between align-items-center mb-3">
<h6 class="mb-0"><i class="bi bi-chat-left-text"></i> 聊天会话</h6>
<button class="btn btn-sm btn-primary" onclick="createNewSession()">
<i class="bi bi-plus"></i> 新建
</button>
</div>
<button class="btn btn-sm btn-outline-danger w-100" onclick="clearAllSessions()">
<i class="bi bi-trash"></i> 清空所有
</button>
</div>
<div class="session-list" id="sessionList">
<div class="text-center text-muted py-5">
<i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
<p class="mt-2 small">暂无会话</p>
</div>
</div>
</div>
<!-- 右侧:对话页面 -->
<div class="chat-main">
<div class="p-3 border-bottom">
<div class="d-flex justify-content-between align-items-center">
<h6 class="mb-0"><i class="bi bi-chat-dots"></i> 对话</h6>
<select class="form-select form-select-sm" id="kbSelect" style="width: 200px;" onchange="onKbChange()">
<option value="">-- 选择知识库 --</option>
{% for kb in knowledgebases %}
<option value="{{ kb.id }}">{{ kb.name }}</option>
{% endfor %}
</select>
</div>
</div>
<div class="chat-messages" id="chatMessages">
<div class="empty-state">
<i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
<p class="mt-2">开始提问吧!</p>
</div>
</div>
<div class="chat-input-area">
<form id="chatForm" onsubmit="askQuestion(event)">
<div class="mb-2">
<textarea class="form-control" id="questionInput" rows="2"
placeholder="输入您的问题..." required></textarea>
</div>
<div class="d-flex justify-content-end">
<button type="submit" class="btn btn-primary" id="submitBtn">
<i class="bi bi-send"></i> 发送
</button>
</div>
</form>
</div>
</div>
</div>
{% endblock %}
{% block extra_js %}
<script>
</script>
{% endblock %}
3.3. init.py #
app/init.py
# RAG Lite 应用模块说明
"""
RAG Lite Application
"""
# 导入操作系统相关模块
import os
# 从 Flask 包导入 Flask 应用对象
from flask import Flask
# 导入 Flask 跨域资源共享支持
from flask_cors import CORS
# 导入应用配置类
from app.config import Config
# 导入日志工具,用于获取日志记录器
from app.utils.logger import get_logger
# 导入数据库初始化函数
from app.utils.db import init_db
# 导入蓝图模块
+from app.blueprints import auth,knowledgebase,settings,document,chat
# 导入获取当前用户信息函数
from app.utils.auth import get_current_user
# 定义创建 Flask 应用的工厂函数
def create_app(config_class=Config):
# 获取日志记录器,名称为当前模块名
logger = get_logger(__name__)
# 尝试初始化数据库
try:
# 输出日志,表示即将初始化数据库
logger.info("初始化数据库...")
# 执行数据库初始化函数
init_db()
# 输出日志,表示数据库初始化成功
logger.info("数据库初始化成功")
# 捕获任意异常
except Exception as e:
# 输出警告日志,提示数据库初始化失败,并输出异常信息
logger.warning(f"数据库初始化失败: {e}")
# 输出警告日志,提示检查数据库是否已存在,并建议手动创建数据表
logger.warning("请确认数据库已存在,或手动创建数据表")
# 创建 Flask 应用对象,并指定模板和静态文件目录
base_dir = os.path.abspath(os.path.dirname(__file__))
# 创建 Flask 应用对象,并指定模板和静态文件目录
app = Flask(
__name__,
# 指定模板文件目录
template_folder=os.path.join(base_dir, 'templates'),
# 指定静态文件目录
static_folder=os.path.join(base_dir, 'static')
)
# 从给定配置类加载配置信息到应用
app.config.from_object(config_class)
# 启用跨域请求支持
CORS(app)
# 记录应用创建日志信息
logger.info("Flask 应用已创建")
# 注册上下文处理器,使 current_user 在所有模板中可用
@app.context_processor
def inject_user():
# 返回当前用户信息字典
# 使用 get_current_user 获取当前用户信息,并将其添加到上下文字典中
# 这样在模板中可以直接使用 current_user 变量
return dict(current_user=get_current_user())
# 注册蓝图
app.register_blueprint(auth.bp)
# 注册知识库蓝图
app.register_blueprint(knowledgebase.bp)
# 注册设置蓝图
app.register_blueprint(settings.bp)
# 注册文档蓝图
app.register_blueprint(document.bp)
# 注册聊天蓝图``
+ app.register_blueprint(chat.bp)
# 定义首页路由
@app.route('/')
def index():
return "Hello, World!"
# 返回已配置的 Flask 应用对象
return app
3.4. init.py #
app/blueprints/init.py
# 定义此文件的描述信息:蓝图模块
"""
蓝图模块
"""
# 从 app.blueprints 包中分别导入 auth、settings、document、chat 模块
+from app.blueprints import auth, settings, document, chat
# 设置 __all__,用于声明对外可导出的模块名称列表
+__all__ = ['auth', 'settings', 'document', 'chat']4.聊天对话 #
4.1. chat_service.py #
app/services/chat_service.py
"""
问答服务
支持普通聊天和知识库聊天(RAG)
"""
# 导入日志模块
import logging
# 导入可选类型和迭代器类型注解
from typing import Optional, Iterator
# 导入 LLM 工厂,用于创建大语言模型实例
from app.utils.llm_factory import LLMFactory
# 导入 LangChain 的对话模板
from langchain_core.prompts import ChatPromptTemplate
# 导入设置服务,用于获取当前系统设置
from app.services.settings_service import settings_service
# 初始化日志记录器
logger = logging.getLogger(__name__)
# 定义问答服务类
class ChatService:
# 类的初始化方法
def __init__(self):
"""初始化问答服务"""
# 获取并保存当前的系统设置
self.settings = settings_service.get()
"""问答服务(支持普通聊天和RAG)"""
# 定义流式普通聊天方法,不使用知识库
def chat_stream(self, question: str, temperature: Optional[float] = None,
max_tokens: int = 1000, history: Optional[list] = None) -> Iterator[dict]:
"""
流式普通聊天接口(不使用知识库)
Args:
question: 问题
temperature: LLM 温度参数(如果为 None,则从设置中读取)
max_tokens: 最大生成 token 数
history: 历史对话记录(可选)
Yields:
流式数据块
"""
# 如果没有指定温度,则从设置中获取(默认为 0.7),并限制在 0-2 之间
if temperature is None:
temperature = float(self.settings.get('llm_temperature', '0.7'))
temperature = max(0.0, min(temperature, 2.0)) # 限制在 0-2 之间
# 获取用于普通聊天的系统提示词
chat_prompt_text = self.settings.get('chat_system_prompt')
# 如果系统提示词不存在,则使用默认的提示词
if not chat_prompt_text:
chat_prompt_text = '你是一个专业的AI助手。请友好、准确地回答用户的问题。'
# 创建支持流式输出的 LLM 实例
llm = LLMFactory.create_llm(self.settings, temperature=temperature, max_tokens=max_tokens, streaming=True)
# 构造单轮对话消息格式,包含 system 提示和用户问题
messages = [
("system", chat_prompt_text),
("human", question)
]
# 从消息创建对话提示模板
prompt = ChatPromptTemplate.from_messages(messages)
# 组装 prompt 和 llm,形成链式调用
chain = prompt | llm
# 发送流式开头信号
yield {
"type": "start",
"content": ""
}
# 初始化完整答案内容
full_answer = ""
try:
# 遍历模型生成的每一段内容
for chunk in chain.stream({}):
# 如果 chunk 有内容,提取内容并累加到full_answer
if hasattr(chunk, 'content') and chunk.content:
content = chunk.content
full_answer += content
# 输出内容块
yield {
"type": "content",
"content": content
}
# 捕获生成过程中的异常,记录日志并产出错误类型的数据块
except Exception as e:
logger.error(f"流式生成时出错: {e}")
yield {
"type": "error",
"content": f"生成答案时出错: {str(e)}"
}
return
# 发送流式结束信号,附带元数据(此处无知识库相关内容)
yield {
"type": "done",
"content": "",
"sources": [],
"metadata": {
'question': question,
'retrieved_chunks': 0,
'used_chunks': 0
}
}
# 创建全局单例 chat_service 实例
chat_service=ChatService()4.2. llm_factory.py #
app/utils/llm_factory.py
# LLM 模型工厂
# 根据设置动态创建 LLM 模型,支持扩展
"""
LLM 模型工厂
根据设置动态创建 LLM 模型,支持扩展
"""
# 导入日志模块
import logging
# 导入类型注解:可选、字典、可调用、任意类型
from typing import Optional, Dict, Callable, Any
# 导入设置服务,用于获取当前设置
from app.services.settings_service import settings_service
# 导入配置类
from app.config import Config
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 定义 LLM 工厂类
class LLMFactory:
# LLM 模型工厂(支持扩展)
"""LLM 模型工厂(支持扩展)"""
# 注册的LLM提供者,用于存储各Provider的构造函数
_providers: Dict[str, Callable] = {}
# 注册新的LLM提供者方法
@classmethod
def register_provider(cls, provider_name: str, provider_func: Callable):
# 注册新的LLM提供者
"""
注册新的LLM提供者
Args:
provider_name: 提供者名称
provider_func: 创建LLM的函数,签名应为:
func(settings: dict, temperature: float, max_tokens: int, streaming: bool) -> LLM
"""
# 将提供者函数存入_providers字典,键为小写名称
cls._providers[provider_name.lower()] = provider_func
# 日志打印已注册信息
logger.info(f"已注册 LLM 提供商: {provider_name}")
# 创建LLM实例方法
@classmethod
def create_llm(cls, settings: Optional[dict] = None, temperature: float = 0.7,
max_tokens: int = 1000, streaming: bool = False):
# 根据设置创建 LLM 模型
"""
根据设置创建 LLM 模型
Args:
settings: 设置字典,如果为 None 则从数据库读取
temperature: 温度参数
max_tokens: 最大 token 数
streaming: 是否启用流式输出
Returns:
LLM 对象
"""
# 如果未传入settings,从setting_service获取全局设置
if settings is None:
settings = settings_service.get()
# 获取llm_provider的名称(默认为deepseek)
provider = settings.get('llm_provider', 'deepseek').lower()
# 优先检查是否用户注册的 Provider
if provider in cls._providers:
# 使用自定义注册的Provider创建llm对象
return cls._providers[provider](settings, temperature, max_tokens, streaming)
# 若非自定义Provider,则按内置Provider处理
if provider == 'deepseek':
# 创建 DeepSeek LLM
return cls._create_deepseek(settings, temperature, max_tokens, streaming)
elif provider == 'openai':
# 创建 OpenAI LLM
return cls._create_openai(settings, temperature, max_tokens, streaming)
elif provider == 'ollama':
# 创建 Ollama LLM
return cls._create_ollama(settings, temperature, max_tokens, streaming)
else:
# 不支持的Provider,抛出错误
raise ValueError(
f"Unsupported LLM provider: {provider}. "
f"Available providers: {list(cls._providers.keys()) + ['deepseek', 'openai', 'ollama']}"
)
# DeepSeek LLM 创建方法
@classmethod
def _create_deepseek(cls, settings: dict, temperature: float, max_tokens: int, streaming: bool):
# 创建 DeepSeek LLM
"""创建 DeepSeek LLM"""
# 导入 DeepSeek LLM 的类
from langchain_deepseek import ChatDeepSeek
# 获取模型名,优先用settings里的值,否则用默认配置
model_name = settings.get('llm_model_name') or Config.DEEPSEEK_CHAT_MODEL
# 获取API Key,优先用settings里的值,否则用默认配置
api_key = settings.get('llm_api_key') or Config.DEEPSEEK_API_KEY
# 获取Base URL,优先用settings里的值,否则用默认配置
base_url = settings.get('llm_base_url') or Config.DEEPSEEK_BASE_URL
# 实例化 DeepSeek LLM 对象
llm = ChatDeepSeek(
model=model_name,
api_key=api_key,
base_url=base_url,
temperature=temperature,
max_tokens=max_tokens,
streaming=streaming
)
# 日志打印创建成功的信息
logger.info(f"已创建 DeepSeek LLM: {model_name}")
# 返回 DeepSeek LLM 实例
return llm
# OpenAI LLM 创建方法
@classmethod
def _create_openai(cls, settings: dict, temperature: float, max_tokens: int, streaming: bool):
# 创建 OpenAI LLM
"""创建 OpenAI LLM"""
# 导入OpenAI LLM的类
from langchain_openai import ChatOpenAI
# 从settings中获取API Key
api_key = settings.get('llm_api_key')
# 如果未设置API Key则报错
if not api_key:
raise ValueError("OpenAI API key is required")
# 获取模型名称,优先用用户配置,否则用默认gpt-4o
model_name = settings.get('llm_model_name') or 'gpt-4o'
# 实例化OpenAI LLM对象
llm = ChatOpenAI(
model=model_name,
api_key=api_key,
temperature=temperature,
max_tokens=max_tokens,
streaming=streaming
)
# 日志打印创建成功的信息
logger.info(f"已创建 OpenAI LLM: {model_name}")
# 返回OpenAI LLM实例
return llm
# Ollama LLM 创建方法
@classmethod
def _create_ollama(cls, settings: dict, temperature: float, max_tokens: int, streaming: bool):
# 创建 Ollama LLM
"""创建 Ollama LLM"""
# 导入Ollama LLM相关类
from langchain_community.chat_models import ChatOllama
# 获取基础URL(优先用setting,否则用默认本地地址)
base_url = settings.get('llm_base_url') or 'http://localhost:11434'
# 获取模型名,优先用用户配置
model_name = settings.get('llm_model_name') or 'llama2'
# 实例化Ollama LLM对象
llm = ChatOllama(
model=model_name,
base_url=base_url,
temperature=temperature,
num_predict=max_tokens
)
# 日志打印Ollama创建信息
logger.info(f"已创建 Ollama LLM: {model_name}, 地址: {base_url}")
# 返回 Ollama LLM 实例
return llm
# 注册内置提供者(可选,用于统一管理)
def _register_builtin_providers():
# 注册内置提供者
"""注册内置提供者"""
# 注册deepseek provider
LLMFactory.register_provider('deepseek', LLMFactory._create_deepseek)
# 注册openai provider
LLMFactory.register_provider('openai', LLMFactory._create_openai)
# 注册ollama provider
LLMFactory.register_provider('ollama', LLMFactory._create_ollama)
# 自动注册内置提供者
_register_builtin_providers()
4.3. chat.py #
app/blueprints/chat.py
# 聊天相关路由(视图 + API)
"""
聊天相关路由(视图 + API)
"""
# 导入 Flask 的 Blueprint 和模板渲染函数
+from flask import Blueprint, render_template, request, stream_with_context, Response
+import json
# 导入日志模块
+import logging
# 导入知识库服务,用于后续业务逻辑
from app.services.knowledgebase_service import kb_service
# 导入登录保护装饰器和获取当前用户辅助方法
+from app.utils.auth import login_required, get_current_user, api_login_required
# 导入自定义工具函数:成功响应、错误响应、获取分页参数、获取当前用户或错误、异常处理装饰器、检查所有权
+from app.blueprints.utils import (
+ success_response, error_response,
+ get_current_user_or_error, handle_api_error
+)
+from app.services.chat_service import chat_service
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 创建名为 'chat' 的蓝图对象
bp = Blueprint('chat', __name__)
# 注册 /chat 路由,访问该路由需要先登录
@bp.route('/chat')
@login_required
def chat_view():
# 智能问答页面视图函数
"""智能问答页面"""
# 渲染 chat.html 模板并传递空知识库列表
return render_template('chat.html', knowledgebases=[])
# 注册 API 路由,处理聊天接口 POST 请求
+@bp.route('/api/v1/knowledgebases/chat', methods=['POST'])
+@api_login_required
+@handle_api_error
+def api_chat():
# 普通聊天接口(不支持知识库,支持流式输出)
+ """普通聊天接口(不支持知识库,支持流式输出)"""
# 获取当前用户和错误信息
+ current_user, err = get_current_user_or_error()
# 如果有错误,直接返回错误响应
+ if err:
+ return err
# 从请求体获取 JSON 数据
+ data = request.get_json()
# 如果数据为空或不存在 'question' 字段,返回错误
+ if not data or 'question' not in data:
+ return error_response("question is required", 400)
# 去除问题文本首尾空格
+ question = data['question'].strip()
# 如果问题内容为空,返回错误
+ if not question:
+ return error_response("question cannot be empty", 400)
# 获取 max_tokens 参数,默认 1000
+ max_tokens = int(data.get('max_tokens', 1000))
# 限制最大和最小值在 1~10000 之间
+ max_tokens = max(1, min(max_tokens, 10000)) # 限制在 1-10000 之间
# 声明用于流式输出的生成器
+ @stream_with_context
+ def generate():
+ try:
# 用于缓存完整答案内容
+ full_answer = ''
# 调用服务进行流式对话
+ for chunk in chat_service.chat_stream(
+ question=question,
+ temperature=None, # 使用设置中的值
+ max_tokens=max_tokens
+ ):
# 如果是内容块,则拼接内容到 full_answer
+ if chunk.get('type') == 'content':
+ full_answer += chunk.get('content', '')
# 以 SSE 协议格式输出数据
+ yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
# 输出对话完成信号
+ yield "data: [DONE]\n\n"
+ except Exception as e:
# 发生异常记录日志
+ logger.error(f"流式输出时出错: {e}")
# 构造错误数据块
+ error_chunk = {
+ "type": "error",
+ "content": str(e)
+ }
# 输出错误数据块
+ yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n"
# 创建 Response 对象,设置必要的 SSE 响应头部
+ response = Response(
+ generate(),
+ mimetype='text/event-stream',
+ headers={
+ 'Cache-Control': 'no-cache',
+ 'Connection': 'keep-alive',
+ 'X-Accel-Buffering': 'no',
+ 'Content-Type': 'text/event-stream; charset=utf-8'
+ }
+ )
# 返回响应
+ return response4.4. chat.html #
app/templates/chat.html
{% extends "base.html" %}
{% block title %}智能问答 - RAG Lite{% endblock %}
{% block extra_css %}
<style>
.chat-container {
height: calc(100vh - 200px);
display: flex;
gap: 1rem;
}
.chat-sidebar {
width: 280px;
background: #fff;
border-radius: 0.5rem;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
display: flex;
flex-direction: column;
}
.chat-main {
flex: 1;
background: #fff;
border-radius: 0.5rem;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
display: flex;
flex-direction: column;
}
.session-list {
flex: 1;
overflow-y: auto;
padding: 0.5rem;
}
.session-item {
padding: 0.75rem;
margin-bottom: 0.5rem;
border-radius: 0.5rem;
cursor: pointer;
transition: background-color 0.2s;
position: relative;
}
.session-item:hover {
background-color: #f8f9fa;
}
.session-item.active {
background-color: #e3f2fd;
border-left: 3px solid #0d6efd;
}
.session-item .session-title {
font-weight: 500;
margin-bottom: 0.25rem;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.session-item .session-time {
font-size: 0.75rem;
color: #6c757d;
}
.session-item .session-delete {
position: absolute;
top: 0.5rem;
right: 0.5rem;
opacity: 0;
transition: opacity 0.2s;
}
.session-item:hover .session-delete {
opacity: 1;
}
.chat-messages {
flex: 1;
overflow-y: auto;
padding: 1rem;
scroll-behavior: smooth;
}
.chat-message {
padding: 1rem;
margin-bottom: 1rem;
border-radius: 0.5rem;
}
.chat-question {
background-color: #e3f2fd;
}
.chat-answer {
background-color: #f5f5f5;
}
.chat-input-area {
padding: 1rem;
border-top: 1px solid #dee2e6;
}
.empty-state {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
height: 100%;
color: #6c757d;
}
</style>
{% endblock %}
{% block content %}
<div class="chat-container">
<!-- 左侧:会话管理 -->
<div class="chat-sidebar">
<div class="p-3 border-bottom">
<div class="d-flex justify-content-between align-items-center mb-3">
<h6 class="mb-0"><i class="bi bi-chat-left-text"></i> 聊天会话</h6>
<button class="btn btn-sm btn-primary" onclick="createNewSession()">
<i class="bi bi-plus"></i> 新建
</button>
</div>
<button class="btn btn-sm btn-outline-danger w-100" onclick="clearAllSessions()">
<i class="bi bi-trash"></i> 清空所有
</button>
</div>
<div class="session-list" id="sessionList">
<div class="text-center text-muted py-5">
<i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
<p class="mt-2 small">暂无会话</p>
</div>
</div>
</div>
<!-- 右侧:对话页面 -->
<div class="chat-main">
<div class="p-3 border-bottom">
<div class="d-flex justify-content-between align-items-center">
<h6 class="mb-0"><i class="bi bi-chat-dots"></i> 对话</h6>
</div>
</div>
<div class="chat-messages" id="chatMessages">
<div class="empty-state">
<i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
<p class="mt-2">开始提问吧!</p>
</div>
</div>
<div class="chat-input-area">
<form id="chatForm" onsubmit="askQuestion(event)">
<div class="mb-2">
<textarea class="form-control" id="questionInput" rows="2"
placeholder="输入您的问题..." required></textarea>
</div>
<div class="d-flex justify-content-end">
<button type="submit" class="btn btn-primary" id="submitBtn">
<i class="bi bi-send"></i> 发送
</button>
</div>
</form>
</div>
</div>
</div>
{% endblock %}
{% block extra_js %}
<script>
+// 为字符串进行HTML转义,防止XSS攻击
+function escapeHtml(text) {
+ // 创建一个div元素作为容器
+ const div = document.createElement('div');
+ // 将待转义文本设置为div的textContent(自动完成转义)
+ div.textContent = text;
+ // 返回转义后的HTML内容
+ return div.innerHTML;
+}
+// 滚动消息框到底部
+function scrollToBottom() {
+ // 获取聊天消息区域元素
+ const chatMessages = document.getElementById('chatMessages');
+ // 设置滚动条位置到最底部
+ chatMessages.scrollTop = chatMessages.scrollHeight;
+}
+// 将Markdown内容渲染到指定元素(支持降级为普通文本)
+function renderMarkdownToElement(element, text) {
+ // 若未提供内容,显示思考中图标
+ if (!text) {
+ element.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';
+ return;
+ }
+ // 优先判断marked库是否可用(渲染markdown)
+ if (typeof marked !== 'undefined' && marked.parse) {
+ try {
+ // 使用marked进行markdown转html
+ element.innerHTML = marked.parse(text);
+ } catch (e) {
+ // 渲染失败则退化为转义+换行
+ element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
+ }
+ } else {
+ // 没有marked库则直接转义+换行显示
+ element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
+ }
+}
+// 主函数:处理用户提交问题事件
+async function askQuestion(event) {
+ // 阻止表单默认提交行为(防止页面刷新)
+ event.preventDefault();
+ // 获取输入框的用户问题并去除首尾空白
+ const question = document.getElementById('questionInput').value.trim();
+ // 获取消息显示区域元素
+ const chatMessages = document.getElementById('chatMessages');
+ // 若问题为空则直接返回
+ if (!question) return;
+ // 检查并移除初始空白提示(如有)
+ if (chatMessages.querySelector('.empty-state')) {
+ chatMessages.innerHTML = '';
+ }
+ // 创建用于展示问题的div元素
+ const questionDiv = document.createElement('div');
+ // 加上样式: 用户问题
+ questionDiv.className = 'chat-message chat-question';
+ // 构建用户气泡内容含图标、文本和时间
+ questionDiv.innerHTML = `
+ <div class="d-flex justify-content-between align-items-start">
+ <div class="flex-grow-1">
+ <strong><i class="bi bi-person-circle"></i> 问题:</strong>
+ <div class="mt-1">${escapeHtml(question)}</div>
+ </div>
+ <small class="text-muted">${new Date().toLocaleTimeString()}</small>
+ </div>
+ `;
+ // 显示到对话窗口
+ chatMessages.appendChild(questionDiv);
+ // 创建用于显示答案的div元素
+ const answerDiv = document.createElement('div');
+ // 答案样式
+ answerDiv.className = 'chat-message chat-answer';
+ // 动态生成唯一的答案内容div id(用于唯一标记)
+ const answerContentId = 'answerContent_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
+ // 创建存放答案内容的div
+ const answerContent = document.createElement('div');
+ // 为markdown渲染容器设置样式和id
+ answerContent.className = 'mt-2 markdown-content';
+ answerContent.id = answerContentId;
+ // 答案div显示机器人图标和时间
+ answerDiv.innerHTML = `
+ <div class="d-flex justify-content-between align-items-start">
+ <div class="flex-grow-1">
+ <strong><i class="bi bi-robot"></i> 答案:</strong>
+ </div>
+ <small class="text-muted">${new Date().toLocaleTimeString()}</small>
+ </div>
+ `;
+ // 将答案内容div插入到机器人气泡div的内容区
+ const flexGrowDiv = answerDiv.querySelector('.flex-grow-1');
+ flexGrowDiv.appendChild(answerContent);
+ // 插入答案div到消息区域
+ chatMessages.appendChild(answerDiv);
+ // 变量记录完整的回答内容
+ let fullAnswer = '';
+ // 标记渲染任务是否挂起(防止重复)
+ let pendingUpdate = false;
+ // 记录定时器id(去抖动用)
+ let updateTimer = null;
+ // 清空输入框内容
+ document.getElementById('questionInput').value = '';
+ // 滚动到底
+ scrollToBottom();
+ // 定义scheduleRender:将markdown渲染插入到队列合适时机执行
+ function scheduleRender() {
+ // 若当前无待渲染任务才安排渲染
+ if (!pendingUpdate) {
+ pendingUpdate = true;
+ // 下一帧渲染
+ requestAnimationFrame(() => {
+ // 将答案作为markdown渲染进dom
+ renderMarkdownToElement(answerContent, fullAnswer);
+ // 清理pending标识
+ pendingUpdate = false;
+ // 渲染后滚动到底部
+ scrollToBottom();
+ });
+ }
+ }
+ try {
+ // 组装API接口地址
+ const url = `/api/v1/knowledgebases/chat`;
+ // 请求后端,发起流式POST请求
+ const response = await fetch(url, {
+ method: 'POST',
+ headers: {'Content-Type': 'application/json'},
+ body: JSON.stringify({
+ question: question,
+ stream: true
+ })
+ });
+ // 非200响应时,抛出异常
+ if (!response.ok) throw new Error('请求失败');
+ // 使用ReadableStream reader获取response流
+ const reader = response.body.getReader();
+ // 用TextDecoder解码二进制数据
+ const decoder = new TextDecoder();
+ // 数据缓冲区(字符串)
+ let buffer = '';
+ // 初始展示“思考中...”提示
+ answerContent.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';
+ // 不断循环读取服务端推送流内容
+ while (true) {
+ // 逐步读取一段数据
+ const { done, value } = await reader.read();
+ // 若读完则结束循环
+ if (done) break;
+ // 本块新数据解码成字符串,并追加到buffer
+ buffer += decoder.decode(value, { stream: true });
+ // 按行切分(多行分别处理)
+ const lines = buffer.split('\n');
+ // 最后一行一般为数据残留,不处理,下次拼
+ buffer = lines.pop() || '';
+ // 处理每一行数据
+ for (const line of lines) {
+ // 筛选有效SSE协议"data: "数据包
+ if (line.startsWith('data: ')) {
+ // 去掉前缀,留下json内容
+ const data = line.slice(6);
+ // [DONE]信号作为流结束,仅跳过
+ if (data === '[DONE]') continue;
+ try {
+ // 解析JSON数据
+ const chunk = JSON.parse(data);
+ // 类型为流起始,清空内容
+ if (chunk.type === 'start') {
+ fullAnswer = '';
+ answerContent.innerHTML = '';
+ // 类型为内容,追加文本并安排去抖渲染
+ } else if (chunk.type === 'content') {
+ fullAnswer += chunk.content;
+ scheduleRender();
+ // 类型为done(流传输结束),直接最终渲染所有答案内容
+ } else if (chunk.type === 'done') {
+ renderMarkdownToElement(answerContent, fullAnswer);
+ // 错误类型,alert显示报错内容
+ } else if (chunk.type === 'error') {
+ answerContent.innerHTML = `<div class="alert alert-danger">${chunk.content}</div>`;
+ }
+ } catch (e) {
+ // JSON解析失败时报console
+ console.error('解析流数据失败:', e);
+ }
+ }
+ }
+ }
+ } catch (error) {
+ // 通信异常、Fetch错误等: 显示错误气泡
+ answerContent.innerHTML = `<div class="alert alert-danger"><strong>错误:</strong> ${error.message}</div>`;
+ }
+ // 结束后确保界面滚动到底
+ scrollToBottom();
+}
</script>
{% endblock %}
5.会话管理 #
5.1. chat_session_service.py #
app/services/chat_session_service.py
# 导入倒序排序工具
from sqlalchemy import desc
# 导入日期时间
from datetime import datetime
# 导入聊天会话ORM模型
from app.models.chat_session import ChatSession
# 导入聊天消息ORM模型
from app.models.chat_message import ChatMessage
# 导入基础服务类
from app.services.base_service import BaseService
# 聊天会话服务类,继承自基础服务
class ChatSessionService(BaseService[ChatSession]):
# 聊天会话服务说明文档
"""聊天会话服务"""
# 创建新的聊天会话
def create_session(self, user_id: str, kb_id: str = None, title: str = None) -> dict:
"""
创建新的聊天会话
Args:
user_id: 用户ID
kb_id: 知识库ID(可选)
title: 会话标题(可选,如果不提供则使用默认标题)
Returns:
会话信息字典
"""
# 启动数据库事务
with self.transaction() as session:
# 如果没有传标题则用默认标题
if not title:
title = "新对话"
# 构造会话对象
chat_session = ChatSession(
user_id=user_id,
title=title
)
# 新会话入库
session.add(chat_session)
# 刷新以拿到自增ID
session.flush()
# 刷新会话对象,便于获取ID等字段
session.refresh(chat_session)
# 记录日志
self.logger.info(f"已创建聊天会话: {chat_session.id}, 用户: {user_id}")
# 返回会话字典格式
return chat_session.to_dict()
# 根据ID获取会话
def get_session_by_id(self, session_id: str, user_id: str = None) -> dict:
"""
根据ID获取会话
Args:
session_id: 会话ID
user_id: 用户ID(可选,用于验证权限)
Returns:
会话信息字典,如果不存在或无权访问则返回 None
"""
# 打开数据库只读session
with self.session() as session:
# 查询指定ID的会话
query = session.query(ChatSession).filter_by(id=session_id)
# 如果提供了user_id则额外限定归属
if user_id:
query = query.filter_by(user_id=user_id)
# 拿到第一个会话记录
chat_session = query.first()
# 有则返回字典信息,没有返回None
if chat_session:
return chat_session.to_dict()
return None
# 获取用户的所有会话列表(分页)
def list_sessions(self, user_id: str, page: int = 1, page_size: int = 100) -> dict:
"""
获取用户的会话列表
Args:
user_id: 用户ID
page: 页码
page_size: 每页数量
Returns:
包含总数和会话列表的字典
"""
# 打开数据库只读session
with self.session() as session:
# 查询当前用户的所有会话
query = session.query(ChatSession).filter_by(user_id=user_id)
# 用基类的分页方法返回结构化内容,按更新时间倒序
return self.paginate_query(query, page=page, page_size=page_size,
order_by=desc(ChatSession.updated_at))
# 会话标题修改
def update_session_title(self, session_id: str, user_id: str, title: str) -> dict:
"""
更新会话标题
Args:
session_id: 会话ID
user_id: 用户ID(用于验证权限)
title: 新标题
Returns:
更新后的会话信息字典
"""
# 启动数据库事务
with self.transaction() as session:
# 查询拥有该会话的用户
chat_session = session.query(ChatSession).filter_by(id=session_id, user_id=user_id).first()
# 会话不存在则抛异常
if not chat_session:
raise ValueError("Session not found or access denied")
# 更新标题
chat_session.title = title
# 刷新会话对象(update不会提交,refresh刷新对象)
session.refresh(chat_session)
# 返回最新数据
return chat_session.to_dict()
# 删除指定会话
def delete_session(self, session_id: str, user_id: str) -> bool:
"""
删除会话(级联删除消息)
Args:
session_id: 会话ID
user_id: 用户ID(用于验证权限)
Returns:
是否删除成功
"""
# 开启数据库事务
with self.transaction() as session:
# 查询该用户的指定会话
chat_session = session.query(ChatSession).filter_by(id=session_id, user_id=user_id).first()
# 未找到会话则返回False
if not chat_session:
return False
# 删除会话(DB应有级联消息)
session.delete(chat_session)
# 记录日志
self.logger.info(f"已删除聊天会话: {session_id}")
# 返回True表示删除成功
return True
# 删除当前用户的所有会话
def delete_all_sessions(self, user_id: str) -> int:
"""
删除用户的所有会话
Args:
user_id: 用户ID
Returns:
删除的会话数量
"""
# 开启数据库事务
with self.transaction() as session:
# 批量删除本用户所有会话
count = session.query(ChatSession).filter_by(user_id=user_id).delete()
# 记录日志
self.logger.info(f"已删除用户 {user_id} 的 {count} 个聊天会话")
# 返回删除数量
return count
# 添加消息到会话
def add_message(self, session_id: str, role: str, content: str) -> dict:
"""
添加消息到会话
Args:
session_id: 会话ID
role: 角色('user' 或 'assistant')
content: 消息内容
sources: 引用来源列表(可选)
Returns:
消息信息字典
"""
# 开启数据库事务
with self.transaction() as session:
# 构造消息对象
message = ChatMessage(
session_id=session_id,
role=role,
content=content
)
# 添加消息到数据库
session.add(message)
# 查询会话对象,用于更新时间/自动生成标题
chat_session = session.query(ChatSession).filter_by(id=session_id).first()
# 如果存在会话对象
if chat_session:
# 更新会话更新时间
chat_session.updated_at = datetime.now()
# 如果是用户发的第一条消息,并且还没标题,则用内容自动命名
if role == 'user' and (not chat_session.title or chat_session.title == "新对话"):
# 会话标题截取前30字符,超长加省略号
title = content[:30] + ('...' if len(content) > 30 else '')
chat_session.title = title
# 刷新确保message有ID
session.flush()
# 刷新消息对象
session.refresh(message)
# 返回消息字典
return message.to_dict()
# 获取会话的全部消息
def get_messages(self, session_id: str, user_id: str = None) -> list:
"""
获取会话的所有消息
Args:
session_id: 会话ID
user_id: 用户ID(可选,用于验证权限)
Returns:
消息列表
"""
# 打开只读session
with self.session() as session:
# 如指定user_id,须先验证此会话是否属于该用户,不属则不给查
if user_id:
chat_session = session.query(ChatSession).filter_by(id=session_id, user_id=user_id).first()
# 如果会话不存在则返回空列表
if not chat_session:
return []
# 查询该会话下所有消息,按创建时间升序排序
messages = session.query(ChatMessage).filter_by(session_id=session_id).order_by(ChatMessage.created_at).all()
# 返回所有消息的字典列表
return [m.to_dict() for m in messages]
# 单例: 聊天会话服务对象
session_service = ChatSessionService()
5.2. chat.py #
app/blueprints/chat.py
# 聊天相关路由(视图 + API)
"""
聊天相关路由(视图 + API)
"""
# 导入 Flask 的 Blueprint 和模板渲染函数
from flask import Blueprint, render_template, request, stream_with_context, Response
import json
# 导入日志模块
import logging
# 导入登录保护装饰器和获取当前用户辅助方法
+from app.utils.auth import login_required, api_login_required
# 导入自定义工具函数:成功响应、错误响应、获取分页参数、获取当前用户或错误、异常处理装饰器、检查所有权
from app.blueprints.utils import (
success_response, error_response,
+ get_current_user_or_error, handle_api_error, get_pagination_params
)
from app.services.chat_service import chat_service
+from app.services.chat_session_service import session_service
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 创建名为 'chat' 的蓝图对象
bp = Blueprint('chat', __name__)
# 注册 /chat 路由,访问该路由需要先登录
@bp.route('/chat')
@login_required
def chat_view():
# 智能问答页面视图函数
"""智能问答页面"""
# 渲染 chat.html 模板并传递空知识库列表
return render_template('chat.html', knowledgebases=[])
# 注册 API 路由,处理聊天接口 POST 请求
@bp.route('/api/v1/knowledgebases/chat', methods=['POST'])
@api_login_required
@handle_api_error
def api_chat():
# 普通聊天接口(不支持知识库,支持流式输出)
"""普通聊天接口(不支持知识库,支持流式输出)"""
# 获取当前用户和错误信息
current_user, err = get_current_user_or_error()
# 如果有错误,直接返回错误响应
if err:
return err
# 从请求体获取 JSON 数据
data = request.get_json()
# 如果数据为空或不存在 'question' 字段,返回错误
if not data or 'question' not in data:
return error_response("question is required", 400)
# 去除问题文本首尾空格
question = data['question'].strip()
# 如果问题内容为空,返回错误
if not question:
return error_response("question cannot be empty", 400)
+ session_id = data.get('session_id') # 会话ID(可以为空,表示普通聊天)
# 获取 max_tokens 参数,默认 1000
max_tokens = int(data.get('max_tokens', 1000))
# 限制最大和最小值在 1~10000 之间
max_tokens = max(1, min(max_tokens, 10000)) # 限制在 1-10000 之间
# 从请求数据中获取'stream'字段,默认为True,表示启用流式输出
+ stream = data.get('stream', True) # 默认启用流式输出
# 初始化历史消息为None
+ history = None
# 如果请求中带有session_id,说明有现有会话
+ if session_id:
# 根据session_id和当前用户ID获取历史消息列表
+ history_messages = session_service.get_messages(session_id, current_user['id'])
# 将历史消息转换为对话格式,仅保留最近10条
+ history = [
+ {'role': msg.get('role'), 'content': msg.get('content')}
+ for msg in history_messages[-10:] # 只取最近10条
+ ]
# 如果请求中没有session_id,说明是新对话,需要新建会话
+ if not session_id:
# 创建新会话,kb_id设为None表示普通聊天
+ chat_session = session_service.create_session(
+ user_id=current_user['id']
+ )
# 使用新创建会话的ID作为本次会话ID
+ session_id = chat_session['id']
# 将用户的问题消息保存到当前会话中
+ session_service.add_message(session_id, 'user', question)
# 声明用于流式输出的生成器
@stream_with_context
def generate():
try:
# 用于缓存完整答案内容
full_answer = ''
# 调用服务进行流式对话
for chunk in chat_service.chat_stream(
question=question,
temperature=None, # 使用设置中的值
+ max_tokens=max_tokens,
+ history=history
):
# 如果是内容块,则拼接内容到 full_answer
if chunk.get('type') == 'content':
full_answer += chunk.get('content', '')
# 以 SSE 协议格式输出数据
yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
# 输出对话完成信号
yield "data: [DONE]\n\n"
# 保存助手回复
+ if full_answer:
+ session_service.add_message(session_id, 'assistant', full_answer)
except Exception as e:
# 发生异常记录日志
logger.error(f"流式输出时出错: {e}")
# 构造错误数据块
error_chunk = {
"type": "error",
"content": str(e)
}
# 输出错误数据块
yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n"
# 创建 Response 对象,设置必要的 SSE 响应头部
response = Response(
generate(),
mimetype='text/event-stream',
headers={
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
'X-Accel-Buffering': 'no',
'Content-Type': 'text/event-stream; charset=utf-8'
}
)
# 返回响应
return response
# 路由装饰器,定义 GET 方法获取会话列表的接口
+@bp.route('/api/v1/knowledgebases/sessions', methods=['GET'])
# API 登录校验装饰器,确保用户已登录
+@api_login_required
# 错误处理装饰器,统一处理接口异常
+@handle_api_error
+def api_list_sessions():
# 接口描述:获取当前用户的会话列表
+ """获取当前用户的会话列表"""
# 获取当前用户,如有错误直接返回错误响应
+ current_user, err = get_current_user_or_error()
+ if err:
+ return err
# 获取分页参数(页码和每页数量),最大单页1000
+ page, page_size = get_pagination_params(max_page_size=1000)
# 调用会话服务获取当前用户的会话列表
+ result = session_service.list_sessions(current_user['id'], page=page, page_size=page_size)
# 以统一成功响应格式返回会话列表
+ return success_response(result)
# 路由装饰器,定义 POST 方法创建会话的接口
+@bp.route('/api/v1/knowledgebases/sessions', methods=['POST'])
+@api_login_required
+@handle_api_error
+def api_create_session():
# 接口描述:创建新的聊天会话
+ """创建新的聊天会话"""
# 获取当前用户,如果有错误直接返回
+ current_user, err = get_current_user_or_error()
+ if err:
+ return err
# 获取请求体中的 JSON 数据,若无返回空字典
+ data = request.get_json() or {}
# 获取会话标题
+ title = data.get('title')
# 调用服务创建会话,传入当前用户ID、知识库ID与标题
+ session_obj = session_service.create_session(
+ user_id=current_user['id'],
+ title=title
+ )
# 返回成功响应及会话对象
+ return success_response(session_obj)
# 路由装饰器,定义 GET 方法获取单个会话详情的接口(带 session_id)
+@bp.route('/api/v1/knowledgebases/sessions/<session_id>', methods=['GET'])
+@api_login_required
+@handle_api_error
+def api_get_session(session_id):
# 接口描述:获取会话详情和消息
+ """获取会话详情和消息"""
# 获取当前用户,如有错误直接返回
+ current_user, err = get_current_user_or_error()
+ if err:
+ return err
# 根据 session_id 获取会话对象,校验所属当前用户
+ session_obj = session_service.get_session_by_id(session_id, current_user['id'])
# 如果没有找到会话,返回 404 错误
+ if not session_obj:
+ return error_response("Session not found", 404)
# 获取该会话下的所有消息
+ messages = session_service.get_messages(session_id, current_user['id'])
# 返回会话详情及消息列表
+ return success_response({
+ 'session': session_obj,
+ 'messages': messages
+ })
# 路由装饰器,定义 DELETE 方法删除单个会话接口
+@bp.route('/api/v1/knowledgebases/sessions/<session_id>', methods=['DELETE'])
+@api_login_required
+@handle_api_error
+def api_delete_session(session_id):
# 接口描述:删除会话
+ """删除会话"""
# 获取当前用户,如有错误直接返回
+ current_user, err = get_current_user_or_error()
+ if err:
+ return err
# 调用服务删除会话,校验归属当前用户
+ success = session_service.delete_session(session_id, current_user['id'])
# 若删除成功,返回成功响应,否则返回 404
+ if success:
+ return success_response(None, "Session deleted")
+ else:
+ return error_response("Session not found", 404)
# 路由装饰器,定义 DELETE 方法清空所有会话的接口
+@bp.route('/api/v1/knowledgebases/sessions', methods=['DELETE'])
+@api_login_required
+@handle_api_error
+def api_delete_all_sessions():
# 接口描述:清空所有会话
+ """清空所有会话"""
# 获取当前用户,如果有错误直接返回
+ current_user, err = get_current_user_or_error()
+ if err:
+ return err
# 调用服务删除所有属于当前用户的会话,返回删除数量
+ count = session_service.delete_all_sessions(current_user['id'])
# 返回成功响应及被删除会话数
+ return success_response({'deleted_count': count}, f"Deleted {count} sessions")
5.3. chat.html #
app/templates/chat.html
{% extends "base.html" %}
{% block title %}智能问答 - RAG Lite{% endblock %}
{% block extra_css %}
<style>
.chat-container {
height: calc(100vh - 200px);
display: flex;
gap: 1rem;
}
.chat-sidebar {
width: 280px;
background: #fff;
border-radius: 0.5rem;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
display: flex;
flex-direction: column;
}
.chat-main {
flex: 1;
background: #fff;
border-radius: 0.5rem;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
display: flex;
flex-direction: column;
}
.session-list {
flex: 1;
overflow-y: auto;
padding: 0.5rem;
}
.session-item {
padding: 0.75rem;
margin-bottom: 0.5rem;
border-radius: 0.5rem;
cursor: pointer;
transition: background-color 0.2s;
position: relative;
}
.session-item:hover {
background-color: #f8f9fa;
}
.session-item.active {
background-color: #e3f2fd;
border-left: 3px solid #0d6efd;
}
.session-item .session-title {
font-weight: 500;
margin-bottom: 0.25rem;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.session-item .session-time {
font-size: 0.75rem;
color: #6c757d;
}
.session-item .session-delete {
position: absolute;
top: 0.5rem;
right: 0.5rem;
opacity: 0;
transition: opacity 0.2s;
}
.session-item:hover .session-delete {
opacity: 1;
}
.chat-messages {
flex: 1;
overflow-y: auto;
padding: 1rem;
scroll-behavior: smooth;
}
.chat-message {
padding: 1rem;
margin-bottom: 1rem;
border-radius: 0.5rem;
}
.chat-question {
background-color: #e3f2fd;
}
.chat-answer {
background-color: #f5f5f5;
}
.chat-input-area {
padding: 1rem;
border-top: 1px solid #dee2e6;
}
.empty-state {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
height: 100%;
color: #6c757d;
}
</style>
{% endblock %}
{% block content %}
<div class="chat-container">
<!-- 左侧:会话管理 -->
<div class="chat-sidebar">
<div class="p-3 border-bottom">
<div class="d-flex justify-content-between align-items-center mb-3">
<h6 class="mb-0"><i class="bi bi-chat-left-text"></i> 聊天会话</h6>
<button class="btn btn-sm btn-primary" onclick="createNewSession()">
<i class="bi bi-plus"></i> 新建
</button>
</div>
<button class="btn btn-sm btn-outline-danger w-100" onclick="clearAllSessions()">
<i class="bi bi-trash"></i> 清空所有
</button>
</div>
<div class="session-list" id="sessionList">
<div class="text-center text-muted py-5">
<i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
<p class="mt-2 small">暂无会话</p>
</div>
</div>
</div>
<!-- 右侧:对话页面 -->
<div class="chat-main">
<div class="p-3 border-bottom">
<div class="d-flex justify-content-between align-items-center">
<h6 class="mb-0"><i class="bi bi-chat-dots"></i> 对话</h6>
</div>
</div>
<div class="chat-messages" id="chatMessages">
<div class="empty-state">
<i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
<p class="mt-2">开始提问吧!</p>
</div>
</div>
<div class="chat-input-area">
<form id="chatForm" onsubmit="askQuestion(event)">
<div class="mb-2">
<textarea class="form-control" id="questionInput" rows="2"
placeholder="输入您的问题..." required></textarea>
</div>
<div class="d-flex justify-content-end">
<button type="submit" class="btn btn-primary" id="submitBtn">
<i class="bi bi-send"></i> 发送
</button>
</div>
</form>
</div>
</div>
</div>
{% endblock %}
{% block extra_js %}
<script>
+// 当前会话的ID,初始为null,表示暂未选择任何会话
+let currentSessionId = null;
+// 会话列表,初始为空数组,用于存储所有的会话对象
+let sessions = [];
// 为字符串进行HTML转义,防止XSS攻击
function escapeHtml(text) {
// 创建一个div元素作为容器
const div = document.createElement('div');
// 将待转义文本设置为div的textContent(自动完成转义)
div.textContent = text;
// 返回转义后的HTML内容
return div.innerHTML;
}
// 滚动消息框到底部
function scrollToBottom() {
// 获取聊天消息区域元素
const chatMessages = document.getElementById('chatMessages');
// 设置滚动条位置到最底部
chatMessages.scrollTop = chatMessages.scrollHeight;
}
+// 定义一个用于渲染Markdown文本为HTML的函数
+function renderMarkdown(text) {
+ // 判断marked库是否已加载且能使用解析方法
+ if (typeof marked !== 'undefined' && marked.parse) {
+ try {
+ // 尝试用marked库将Markdown文本解析成HTML
+ return marked.parse(text);
+ } catch (e) {
+ // 如果解析出错,则进行HTML转义并换行
+ return escapeHtml(text).replace(/\n/g, '<br>');
+ }
+ }
+ // 如果marked库不可用,直接做HTML转义并处理换行
+ return escapeHtml(text).replace(/\n/g, '<br>');
+}
// 将Markdown内容渲染到指定元素(支持降级为普通文本)
function renderMarkdownToElement(element, text) {
// 若未提供内容,显示思考中图标
if (!text) {
element.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';
return;
}
// 优先判断marked库是否可用(渲染markdown)
if (typeof marked !== 'undefined' && marked.parse) {
try {
// 使用marked进行markdown转html
element.innerHTML = marked.parse(text);
} catch (e) {
// 渲染失败则退化为转义+换行
element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
}
} else {
// 没有marked库则直接转义+换行显示
element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
}
}
// 主函数:处理用户提交问题事件
async function askQuestion(event) {
// 阻止表单默认提交行为(防止页面刷新)
event.preventDefault();
// 获取输入框的用户问题并去除首尾空白
const question = document.getElementById('questionInput').value.trim();
// 获取消息显示区域元素
const chatMessages = document.getElementById('chatMessages');
// 若问题为空则直接返回
if (!question) return;
+ // 如果没有会话,创建新会话
+ if (!currentSessionId) {
+ await createNewSession();
+ }
// 检查并移除初始空白提示(如有)
if (chatMessages.querySelector('.empty-state')) {
chatMessages.innerHTML = '';
}
// 创建用于展示问题的div元素
const questionDiv = document.createElement('div');
// 加上样式: 用户问题
questionDiv.className = 'chat-message chat-question';
// 构建用户气泡内容含图标、文本和时间
questionDiv.innerHTML = `
<div class="d-flex justify-content-between align-items-start">
<div class="flex-grow-1">
<strong><i class="bi bi-person-circle"></i> 问题:</strong>
<div class="mt-1">${escapeHtml(question)}</div>
</div>
<small class="text-muted">${new Date().toLocaleTimeString()}</small>
</div>
`;
// 显示到对话窗口
chatMessages.appendChild(questionDiv);
// 创建用于显示答案的div元素
const answerDiv = document.createElement('div');
// 答案样式
answerDiv.className = 'chat-message chat-answer';
// 动态生成唯一的答案内容div id(用于唯一标记)
const answerContentId = 'answerContent_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
// 创建存放答案内容的div
const answerContent = document.createElement('div');
// 为markdown渲染容器设置样式和id
answerContent.className = 'mt-2 markdown-content';
answerContent.id = answerContentId;
// 答案div显示机器人图标和时间
answerDiv.innerHTML = `
<div class="d-flex justify-content-between align-items-start">
<div class="flex-grow-1">
<strong><i class="bi bi-robot"></i> 答案:</strong>
</div>
<small class="text-muted">${new Date().toLocaleTimeString()}</small>
</div>
`;
// 将答案内容div插入到机器人气泡div的内容区
const flexGrowDiv = answerDiv.querySelector('.flex-grow-1');
flexGrowDiv.appendChild(answerContent);
// 插入答案div到消息区域
chatMessages.appendChild(answerDiv);
// 变量记录完整的回答内容
let fullAnswer = '';
// 标记渲染任务是否挂起(防止重复)
let pendingUpdate = false;
// 记录定时器id(去抖动用)
let updateTimer = null;
// 清空输入框内容
document.getElementById('questionInput').value = '';
// 滚动到底
scrollToBottom();
// 定义scheduleRender:将markdown渲染插入到队列合适时机执行
function scheduleRender() {
// 若当前无待渲染任务才安排渲染
if (!pendingUpdate) {
pendingUpdate = true;
// 下一帧渲染
requestAnimationFrame(() => {
// 将答案作为markdown渲染进dom
renderMarkdownToElement(answerContent, fullAnswer);
// 清理pending标识
pendingUpdate = false;
// 渲染后滚动到底部
scrollToBottom();
});
}
}
try {
// 组装API接口地址
const url = `/api/v1/knowledgebases/chat`;
// 请求后端,发起流式POST请求
const response = await fetch(url, {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({
question: question,
+ session_id: currentSessionId,
stream: true
})
});
// 非200响应时,抛出异常
if (!response.ok) throw new Error('请求失败');
// 使用ReadableStream reader获取response流
const reader = response.body.getReader();
// 用TextDecoder解码二进制数据
const decoder = new TextDecoder();
// 数据缓冲区(字符串)
let buffer = '';
// 初始展示“思考中...”提示
answerContent.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';
+ // 重新加载会话列表以更新标题
+ await loadSessions();
// 不断循环读取服务端推送流内容
while (true) {
// 逐步读取一段数据
const { done, value } = await reader.read();
// 若读完则结束循环
if (done) break;
// 本块新数据解码成字符串,并追加到buffer
buffer += decoder.decode(value, { stream: true });
// 按行切分(多行分别处理)
const lines = buffer.split('\n');
// 最后一行一般为数据残留,不处理,下次拼
buffer = lines.pop() || '';
// 处理每一行数据
for (const line of lines) {
// 筛选有效SSE协议"data: "数据包
if (line.startsWith('data: ')) {
// 去掉前缀,留下json内容
const data = line.slice(6);
// [DONE]信号作为流结束,仅跳过
if (data === '[DONE]') continue;
try {
// 解析JSON数据
const chunk = JSON.parse(data);
// 类型为流起始,清空内容
if (chunk.type === 'start') {
fullAnswer = '';
answerContent.innerHTML = '';
// 类型为内容,追加文本并安排去抖渲染
} else if (chunk.type === 'content') {
fullAnswer += chunk.content;
scheduleRender();
// 类型为done(流传输结束),直接最终渲染所有答案内容
} else if (chunk.type === 'done') {
renderMarkdownToElement(answerContent, fullAnswer);
// 错误类型,alert显示报错内容
} else if (chunk.type === 'error') {
answerContent.innerHTML = `<div class="alert alert-danger">${chunk.content}</div>`;
}
} catch (e) {
// JSON解析失败时报console
console.error('解析流数据失败:', e);
}
}
}
}
} catch (error) {
// 通信异常、Fetch错误等: 显示错误气泡
answerContent.innerHTML = `<div class="alert alert-danger"><strong>错误:</strong> ${error.message}</div>`;
}
// 结束后确保界面滚动到底
scrollToBottom();
}
+// 清空所有会话
+// 异步函数,用于清空所有会话
+async function clearAllSessions() {
+ // 弹窗确认操作,若取消则直接返回
+ if (!confirm('确定要清空所有会话吗?此操作不可恢复!')) return;
+ try {
+ // 发送DELETE请求到服务器,删除所有会话
+ const response = await fetch('/api/v1/knowledgebases/sessions', {
+ method: 'DELETE'
+ });
+ // 获取服务器返回的JSON结果
+ const result = await response.json();
+ // 如果接口返回成功
+ if (result.code === 200) {
+ // 当前会话ID置空
+ currentSessionId = null;
+ // 清空聊天消息内容
+ clearChatMessages();
+ // 重新加载会话列表
+ await loadSessions();
+ }
+ } catch (error) {
+ // 捕获异常并弹窗提示失败原因
+ alert('清空会话失败: ' + error.message);
+ }
+}
+ // 删除会话
+ // 异步函数,用于删除指定的会话
+ async function deleteSession(sessionId) {
+ // 弹窗确认操作,若取消则直接返回
+ if (!confirm('确定要删除这个会话吗?此操作不可恢复!')) return;
+ try {
+ // 发送DELETE请求到服务器,删除指定会话
+ const response = await fetch(`/api/v1/knowledgebases/sessions/${sessionId}`, {
+ method: 'DELETE'
+ });
+ // 获取服务器返回的JSON结果
+ const result = await response.json();
+ // 如果接口返回成功
+ if (result.code === 200) {
+ // 如果删除的是当前会话,清空当前会话ID和聊天消息
+ if (currentSessionId === sessionId) {
+ currentSessionId = null;
+ clearChatMessages();
+ }
+ // 重新加载会话列表
+ await loadSessions();
+ } else {
+ // 如果删除失败,显示错误信息
+ alert('删除会话失败: ' + (result.message || '未知错误'));
+ }
+ } catch (error) {
+ // 捕获异常并弹窗提示失败原因
+ alert('删除会话失败: ' + error.message);
+ }
+ }
+// 格式化时间
+// 用于格式化时间字符串为友好显示
+function formatTime(timeStr) {
+ // 如果无有效时间,直接返回空字符串
+ if (!timeStr) return '';
+ // 将字符串转换为Date对象
+ const date = new Date(timeStr);
+ // 获取当前时间
+ const now = new Date();
+ // 计算时间差(毫秒)
+ const diff = now - date;
+ // 换算为分钟数
+ const minutes = Math.floor(diff / 60000);
+ // 换算为小时数
+ const hours = Math.floor(diff / 3600000);
+ // 换算为天数
+ const days = Math.floor(diff / 86400000);
+ // 1分钟内显示“刚刚”
+ if (minutes < 1) return '刚刚';
+ // 1小时内显示“x分钟前”
+ if (minutes < 60) return `${minutes}分钟前`;
+ // 24小时内显示“x小时前”
+ if (hours < 24) return `${hours}小时前`;
+ // 7天内显示“x天前”
+ if (days < 7) return `${days}天前`;
+ // 超过7天显示具体日期
+ return date.toLocaleDateString();
+}
+// 渲染消息
+function renderMessages(messages) {
+ const chatMessages = document.getElementById('chatMessages');
+ if (messages.length === 0) {
+ chatMessages.innerHTML = `
+ <div class="empty-state">
+ <i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
+ <p class="mt-2">开始提问吧!</p>
+ </div>
+ `;
+ return;
+ }
+ chatMessages.innerHTML = messages.map(msg => {
+ if (msg.role === 'user') {
+ return `
+ <div class="chat-message chat-question">
+ <div class="d-flex justify-content-between align-items-start">
+ <div class="flex-grow-1">
+ <strong><i class="bi bi-person-circle"></i> 问题:</strong>
+ <div class="mt-1">${escapeHtml(msg.content)}</div>
+ </div>
+ <small class="text-muted">${formatTime(msg.created_at)}</small>
+ </div>
+ </div>
+ `;
+ } else {
+ return `
+ <div class="chat-message chat-answer">
+ <div class="d-flex justify-content-between align-items-start">
+ <div class="flex-grow-1">
+ <strong><i class="bi bi-robot"></i> 答案:</strong>
+ <div class="mt-2 markdown-content">${renderMarkdown(msg.content)}</div>
+ </div>
+ <small class="text-muted">${formatTime(msg.created_at)}</small>
+ </div>
+ </div>
+ `;
+ }
+ }).join('');
+ scrollToBottom();
+}
+// 加载会话
+// 定义一个异步函数 loadSession,用于加载指定 sessionId 的会话
+async function loadSession(sessionId) {
+ // 异常捕获,防止请求或处理过程抛错
+ try {
+ // 发送 GET 请求,获取指定会话的详情数据
+ const response = await fetch(`/api/v1/knowledgebases/sessions/${sessionId}`);
+ // 将返回的响应数据解析为 JSON 格式
+ const result = await response.json();
+ // 如果接口请求成功,即 code 等于 200
+ if (result.code === 200) {
+ // 设置当前会话 ID 为传入的 sessionId
+ currentSessionId = sessionId;
+ // 获取会话详情数据
+ const session = result.data.session;
+ // 获取该会话包含的消息列表,如果没有则为 []
+ const messages = result.data.messages || [];
+ // 调用方法将消息渲染到页面
+ renderMessages(messages);
+ // 重新加载全部会话列表,刷新左侧会话栏
+ await loadSessions();
+ }
+ } catch (error) {
+ // 捕获异常并弹窗提示加载失败的原因
+ alert('加载会话失败: ' + error.message);
+ }
+}
+// 渲染会话列表
+// 将sessions数组渲染到界面左侧会话列表
+function renderSessions() {
+ // 获取会话列表的DOM节点
+ const sessionList = document.getElementById('sessionList');
+ // 如果没有任何会话,显示为空状态
+ if (sessions.length === 0) {
+ sessionList.innerHTML = `
+ <div class="text-center text-muted py-5">
+ <i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
+ <p class="mt-2 small">暂无会话</p>
+ </div>
+ `;
+ return;
+ }
+ // 遍历sessions并渲染为会话项
+ sessionList.innerHTML = sessions.map(session => `
+ <div class="session-item ${session.id === currentSessionId ? 'active' : ''}"
+ onclick="loadSession('${session.id}')">
+ <button class="btn btn-sm btn-link text-danger p-0 session-delete"
+ onclick="event.stopPropagation(); deleteSession('${session.id}')">
+ <i class="bi bi-x-lg"></i>
+ </button>
+ <div class="session-title">${session.title || '新对话'}</div>
+ <div class="session-time">${formatTime(session.updated_at)}</div>
+ </div>
+ `).join('');
+}
+// 清空聊天消息
+// 将聊天内容区重置为初始状态
+function clearChatMessages() {
+ document.getElementById('chatMessages').innerHTML = `
+ <div class="empty-state">
+ <i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
+ <p class="mt-2">开始提问吧!</p>
+ </div>
+ `;
+}
+// 加载会话列表
+// 异步加载会话列表并渲染到界面
+async function loadSessions() {
+ try {
+ // 发送GET请求获取会话列表
+ const response = await fetch('/api/v1/knowledgebases/sessions');
+ // 解析JSON结果
+ const result = await response.json();
+ // 如果接口返回成功
+ if (result.code === 200) {
+ // 保存会话数组
+ sessions = result.data.items || [];
+ // 渲染会话列表
+ renderSessions();
+ }
+ } catch (error) {
+ // 捕捉异常并在控制台打印
+ console.error('加载会话列表失败:', error);
+ }
+}
+// 创建新会话
+// 异步请求服务端创建新会话
+async function createNewSession() {
+ try {
+ // 发送POST请求创建新会话
+ const response = await fetch('/api/v1/knowledgebases/sessions', {
+ method: 'POST',
+ headers: {'Content-Type': 'application/json'},
+ body: JSON.stringify({})
+ });
+ // 解析返回结果
+ const result = await response.json();
+ // 如果创建成功
+ if (result.code === 200) {
+ // 切换到新会话
+ currentSessionId = result.data.id;
+ // 重新加载会话列表
+ await loadSessions();
+ // 清空聊天内容
+ clearChatMessages();
+ }
+ } catch (error) {
+ // 捕获异常并弹窗提示
+ alert('创建会话失败: ' + error.message);
+ }
+}
+// 监听 DOMContentLoaded 事件,确保页面元素加载完成后执行
+document.addEventListener('DOMContentLoaded', function() {
+ // 加载会话列表
+ loadSessions();
+ // 判断 marked.js 是否已经加载
+ if (typeof marked !== 'undefined') {
+ // 配置 marked.js 的渲染选项
+ marked.setOptions({
+ // 启用软换行
+ breaks: true,
+ // 启用 Github Flavored Markdown
+ gfm: true,
+ // 禁用标题自动生成 id
+ headerIds: false,
+ // 禁用混淆处理
+ mangle: false
+ });
+ }
+});
</script>
{% endblock %}
6.选择知识库 #
6.1. chat.py #
app/blueprints/chat.py
# 聊天相关路由(视图 + API)
"""
聊天相关路由(视图 + API)
"""
# 导入 Flask 的 Blueprint 和模板渲染函数
from flask import Blueprint, render_template, request, stream_with_context, Response
import json
# 导入日志模块
import logging
# 导入登录保护装饰器和获取当前用户辅助方法
+from app.utils.auth import login_required, api_login_required,get_current_user
# 导入知识库服务
+from app.services.knowledgebase_service import kb_service
# 导入自定义工具函数:成功响应、错误响应、获取分页参数、获取当前用户或错误、异常处理装饰器、检查所有权
from app.blueprints.utils import (
success_response, error_response,
get_current_user_or_error, handle_api_error, get_pagination_params
)
from app.services.chat_service import chat_service
from app.services.chat_session_service import session_service
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 创建名为 'chat' 的蓝图对象
bp = Blueprint('chat', __name__)
# 注册 /chat 路由,访问该路由需要先登录
@bp.route('/chat')
@login_required
def chat_view():
# 智能问答页面视图函数
"""智能问答页面"""
+ current_user = get_current_user()
# 获取所有知识库(通常用户不会有太多知识库,不需要分页)
+ result = kb_service.list(user_id=current_user['id'], page=1, page_size=1000)
# 渲染 chat.html 模板并传递空知识库列表
+ return render_template('chat.html', knowledgebases=result['items'])
# 注册 API 路由,处理聊天接口 POST 请求
@bp.route('/api/v1/knowledgebases/chat', methods=['POST'])
@api_login_required
@handle_api_error
def api_chat():
# 普通聊天接口(不支持知识库,支持流式输出)
"""普通聊天接口(不支持知识库,支持流式输出)"""
# 获取当前用户和错误信息
current_user, err = get_current_user_or_error()
# 如果有错误,直接返回错误响应
if err:
return err
# 从请求体获取 JSON 数据
data = request.get_json()
# 如果数据为空或不存在 'question' 字段,返回错误
if not data or 'question' not in data:
return error_response("question is required", 400)
# 去除问题文本首尾空格
question = data['question'].strip()
# 如果问题内容为空,返回错误
if not question:
return error_response("question cannot be empty", 400)
session_id = data.get('session_id') # 会话ID(可以为空,表示普通聊天)
# 获取 max_tokens 参数,默认 1000
max_tokens = int(data.get('max_tokens', 1000))
# 限制最大和最小值在 1~10000 之间
max_tokens = max(1, min(max_tokens, 10000)) # 限制在 1-10000 之间
# 从请求数据中获取'stream'字段,默认为True,表示启用流式输出
stream = data.get('stream', True) # 默认启用流式输出
# 初始化历史消息为None
history = None
# 如果请求中带有session_id,说明有现有会话
if session_id:
# 根据session_id和当前用户ID获取历史消息列表
history_messages = session_service.get_messages(session_id, current_user['id'])
# 将历史消息转换为对话格式,仅保留最近10条
history = [
{'role': msg.get('role'), 'content': msg.get('content')}
for msg in history_messages[-10:] # 只取最近10条
]
# 如果请求中没有session_id,说明是新对话,需要新建会话
if not session_id:
# 创建新会话,kb_id设为None表示普通聊天
chat_session = session_service.create_session(
user_id=current_user['id']
)
# 使用新创建会话的ID作为本次会话ID
session_id = chat_session['id']
# 将用户的问题消息保存到当前会话中
session_service.add_message(session_id, 'user', question)
# 声明用于流式输出的生成器
@stream_with_context
def generate():
try:
# 用于缓存完整答案内容
full_answer = ''
# 调用服务进行流式对话
for chunk in chat_service.chat_stream(
question=question,
temperature=None, # 使用设置中的值
max_tokens=max_tokens,
history=history
):
# 如果是内容块,则拼接内容到 full_answer
if chunk.get('type') == 'content':
full_answer += chunk.get('content', '')
# 以 SSE 协议格式输出数据
yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
# 输出对话完成信号
yield "data: [DONE]\n\n"
# 保存助手回复
if full_answer:
session_service.add_message(session_id, 'assistant', full_answer)
except Exception as e:
# 发生异常记录日志
logger.error(f"流式输出时出错: {e}")
# 构造错误数据块
error_chunk = {
"type": "error",
"content": str(e)
}
# 输出错误数据块
yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n"
# 创建 Response 对象,设置必要的 SSE 响应头部
response = Response(
generate(),
mimetype='text/event-stream',
headers={
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
'X-Accel-Buffering': 'no',
'Content-Type': 'text/event-stream; charset=utf-8'
}
)
# 返回响应
return response
# 路由装饰器,定义 GET 方法获取会话列表的接口
@bp.route('/api/v1/knowledgebases/sessions', methods=['GET'])
# API 登录校验装饰器,确保用户已登录
@api_login_required
# 错误处理装饰器,统一处理接口异常
@handle_api_error
def api_list_sessions():
# 接口描述:获取当前用户的会话列表
"""获取当前用户的会话列表"""
# 获取当前用户,如有错误直接返回错误响应
current_user, err = get_current_user_or_error()
if err:
return err
# 获取分页参数(页码和每页数量),最大单页1000
page, page_size = get_pagination_params(max_page_size=1000)
# 调用会话服务获取当前用户的会话列表
result = session_service.list_sessions(current_user['id'], page=page, page_size=page_size)
# 以统一成功响应格式返回会话列表
return success_response(result)
# 路由装饰器,定义 POST 方法创建会话的接口
@bp.route('/api/v1/knowledgebases/sessions', methods=['POST'])
@api_login_required
@handle_api_error
def api_create_session():
# 接口描述:创建新的聊天会话
"""创建新的聊天会话"""
# 获取当前用户,如果有错误直接返回
current_user, err = get_current_user_or_error()
if err:
return err
# 获取请求体中的 JSON 数据,若无返回空字典
data = request.get_json() or {}
# 获取会话标题
title = data.get('title')
# 调用服务创建会话,传入当前用户ID、知识库ID与标题
session_obj = session_service.create_session(
user_id=current_user['id'],
title=title
)
# 返回成功响应及会话对象
return success_response(session_obj)
# 路由装饰器,定义 GET 方法获取单个会话详情的接口(带 session_id)
@bp.route('/api/v1/knowledgebases/sessions/<session_id>', methods=['GET'])
@api_login_required
@handle_api_error
def api_get_session(session_id):
# 接口描述:获取会话详情和消息
"""获取会话详情和消息"""
# 获取当前用户,如有错误直接返回
current_user, err = get_current_user_or_error()
if err:
return err
# 根据 session_id 获取会话对象,校验所属当前用户
session_obj = session_service.get_session_by_id(session_id, current_user['id'])
# 如果没有找到会话,返回 404 错误
if not session_obj:
return error_response("Session not found", 404)
# 获取该会话下的所有消息
messages = session_service.get_messages(session_id, current_user['id'])
# 返回会话详情及消息列表
return success_response({
'session': session_obj,
'messages': messages
})
# 路由装饰器,定义 DELETE 方法删除单个会话接口
@bp.route('/api/v1/knowledgebases/sessions/<session_id>', methods=['DELETE'])
@api_login_required
@handle_api_error
def api_delete_session(session_id):
# 接口描述:删除会话
"""删除会话"""
# 获取当前用户,如有错误直接返回
current_user, err = get_current_user_or_error()
if err:
return err
# 调用服务删除会话,校验归属当前用户
success = session_service.delete_session(session_id, current_user['id'])
# 若删除成功,返回成功响应,否则返回 404
if success:
return success_response(None, "Session deleted")
else:
return error_response("Session not found", 404)
# 路由装饰器,定义 DELETE 方法清空所有会话的接口
@bp.route('/api/v1/knowledgebases/sessions', methods=['DELETE'])
@api_login_required
@handle_api_error
def api_delete_all_sessions():
# 接口描述:清空所有会话
"""清空所有会话"""
# 获取当前用户,如果有错误直接返回
current_user, err = get_current_user_or_error()
if err:
return err
# 调用服务删除所有属于当前用户的会话,返回删除数量
count = session_service.delete_all_sessions(current_user['id'])
# 返回成功响应及被删除会话数
return success_response({'deleted_count': count}, f"Deleted {count} sessions")
6.2. chat.html #
app/templates/chat.html
{% extends "base.html" %}
{% block title %}智能问答 - RAG Lite{% endblock %}
{% block extra_css %}
<style>
.chat-container {
height: calc(100vh - 200px);
display: flex;
gap: 1rem;
}
.chat-sidebar {
width: 280px;
background: #fff;
border-radius: 0.5rem;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
display: flex;
flex-direction: column;
}
.chat-main {
flex: 1;
background: #fff;
border-radius: 0.5rem;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
display: flex;
flex-direction: column;
}
.session-list {
flex: 1;
overflow-y: auto;
padding: 0.5rem;
}
.session-item {
padding: 0.75rem;
margin-bottom: 0.5rem;
border-radius: 0.5rem;
cursor: pointer;
transition: background-color 0.2s;
position: relative;
}
.session-item:hover {
background-color: #f8f9fa;
}
.session-item.active {
background-color: #e3f2fd;
border-left: 3px solid #0d6efd;
}
.session-item .session-title {
font-weight: 500;
margin-bottom: 0.25rem;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.session-item .session-time {
font-size: 0.75rem;
color: #6c757d;
}
.session-item .session-delete {
position: absolute;
top: 0.5rem;
right: 0.5rem;
opacity: 0;
transition: opacity 0.2s;
}
.session-item:hover .session-delete {
opacity: 1;
}
.chat-messages {
flex: 1;
overflow-y: auto;
padding: 1rem;
scroll-behavior: smooth;
}
.chat-message {
padding: 1rem;
margin-bottom: 1rem;
border-radius: 0.5rem;
}
.chat-question {
background-color: #e3f2fd;
}
.chat-answer {
background-color: #f5f5f5;
}
.chat-input-area {
padding: 1rem;
border-top: 1px solid #dee2e6;
}
.empty-state {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
height: 100%;
color: #6c757d;
}
</style>
{% endblock %}
{% block content %}
<div class="chat-container">
<!-- 左侧:会话管理 -->
<div class="chat-sidebar">
<div class="p-3 border-bottom">
<div class="d-flex justify-content-between align-items-center mb-3">
<h6 class="mb-0"><i class="bi bi-chat-left-text"></i> 聊天会话</h6>
<button class="btn btn-sm btn-primary" onclick="createNewSession()">
<i class="bi bi-plus"></i> 新建
</button>
</div>
<button class="btn btn-sm btn-outline-danger w-100" onclick="clearAllSessions()">
<i class="bi bi-trash"></i> 清空所有
</button>
</div>
<div class="session-list" id="sessionList">
<div class="text-center text-muted py-5">
<i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
<p class="mt-2 small">暂无会话</p>
</div>
</div>
</div>
<!-- 右侧:对话页面 -->
<div class="chat-main">
<div class="p-3 border-bottom">
<div class="d-flex justify-content-between align-items-center">
<h6 class="mb-0"><i class="bi bi-chat-dots"></i> 对话</h6>
+ <select class="form-select form-select-sm" id="kbSelect" style="width: 200px;" onchange="onKbChange()">
+ <option value="">-- 选择知识库 --</option>
+ {% for kb in knowledgebases %}
+ <option value="{{ kb.id }}">{{ kb.name }}</option>
+ {% endfor %}
+ </select>
</div>
</div>
<div class="chat-messages" id="chatMessages">
<div class="empty-state">
<i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
<p class="mt-2">开始提问吧!</p>
</div>
</div>
<div class="chat-input-area">
<form id="chatForm" onsubmit="askQuestion(event)">
<div class="mb-2">
<textarea class="form-control" id="questionInput" rows="2"
placeholder="输入您的问题..." required></textarea>
</div>
<div class="d-flex justify-content-end">
<button type="submit" class="btn btn-primary" id="submitBtn">
<i class="bi bi-send"></i> 发送
</button>
</div>
</form>
</div>
</div>
</div>
{% endblock %}
{% block extra_js %}
<script>
// 当前会话的ID,初始为null,表示暂未选择任何会话
let currentSessionId = null;
// 会话列表,初始为空数组,用于存储所有的会话对象
let sessions = [];
+// 当前知识库的ID,初始为null,表示暂未选择任何知识库
+let currentKbId = null;
// 为字符串进行HTML转义,防止XSS攻击
function escapeHtml(text) {
// 创建一个div元素作为容器
const div = document.createElement('div');
// 将待转义文本设置为div的textContent(自动完成转义)
div.textContent = text;
// 返回转义后的HTML内容
return div.innerHTML;
}
// 滚动消息框到底部
function scrollToBottom() {
// 获取聊天消息区域元素
const chatMessages = document.getElementById('chatMessages');
// 设置滚动条位置到最底部
chatMessages.scrollTop = chatMessages.scrollHeight;
}
// 定义一个用于渲染Markdown文本为HTML的函数
function renderMarkdown(text) {
// 判断marked库是否已加载且能使用解析方法
if (typeof marked !== 'undefined' && marked.parse) {
try {
// 尝试用marked库将Markdown文本解析成HTML
return marked.parse(text);
} catch (e) {
// 如果解析出错,则进行HTML转义并换行
return escapeHtml(text).replace(/\n/g, '<br>');
}
}
// 如果marked库不可用,直接做HTML转义并处理换行
return escapeHtml(text).replace(/\n/g, '<br>');
}
// 将Markdown内容渲染到指定元素(支持降级为普通文本)
function renderMarkdownToElement(element, text) {
// 若未提供内容,显示思考中图标
if (!text) {
element.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';
return;
}
// 优先判断marked库是否可用(渲染markdown)
if (typeof marked !== 'undefined' && marked.parse) {
try {
// 使用marked进行markdown转html
element.innerHTML = marked.parse(text);
} catch (e) {
// 渲染失败则退化为转义+换行
element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
}
} else {
// 没有marked库则直接转义+换行显示
element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
}
}
// 主函数:处理用户提交问题事件
async function askQuestion(event) {
// 阻止表单默认提交行为(防止页面刷新)
event.preventDefault();
// 获取输入框的用户问题并去除首尾空白
const question = document.getElementById('questionInput').value.trim();
// 获取消息显示区域元素
const chatMessages = document.getElementById('chatMessages');
// 若问题为空则直接返回
if (!question) return;
// 如果没有会话,创建新会话
if (!currentSessionId) {
await createNewSession();
}
// 检查并移除初始空白提示(如有)
if (chatMessages.querySelector('.empty-state')) {
chatMessages.innerHTML = '';
}
// 创建用于展示问题的div元素
const questionDiv = document.createElement('div');
// 加上样式: 用户问题
questionDiv.className = 'chat-message chat-question';
// 构建用户气泡内容含图标、文本和时间
questionDiv.innerHTML = `
<div class="d-flex justify-content-between align-items-start">
<div class="flex-grow-1">
<strong><i class="bi bi-person-circle"></i> 问题:</strong>
<div class="mt-1">${escapeHtml(question)}</div>
</div>
<small class="text-muted">${new Date().toLocaleTimeString()}</small>
</div>
`;
// 显示到对话窗口
chatMessages.appendChild(questionDiv);
// 创建用于显示答案的div元素
const answerDiv = document.createElement('div');
// 答案样式
answerDiv.className = 'chat-message chat-answer';
// 动态生成唯一的答案内容div id(用于唯一标记)
const answerContentId = 'answerContent_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
// 创建存放答案内容的div
const answerContent = document.createElement('div');
// 为markdown渲染容器设置样式和id
answerContent.className = 'mt-2 markdown-content';
answerContent.id = answerContentId;
// 答案div显示机器人图标和时间
answerDiv.innerHTML = `
<div class="d-flex justify-content-between align-items-start">
<div class="flex-grow-1">
<strong><i class="bi bi-robot"></i> 答案:</strong>
</div>
<small class="text-muted">${new Date().toLocaleTimeString()}</small>
</div>
`;
// 将答案内容div插入到机器人气泡div的内容区
const flexGrowDiv = answerDiv.querySelector('.flex-grow-1');
flexGrowDiv.appendChild(answerContent);
// 插入答案div到消息区域
chatMessages.appendChild(answerDiv);
// 变量记录完整的回答内容
let fullAnswer = '';
// 标记渲染任务是否挂起(防止重复)
let pendingUpdate = false;
// 记录定时器id(去抖动用)
let updateTimer = null;
// 清空输入框内容
document.getElementById('questionInput').value = '';
// 滚动到底
scrollToBottom();
// 定义scheduleRender:将markdown渲染插入到队列合适时机执行
function scheduleRender() {
// 若当前无待渲染任务才安排渲染
if (!pendingUpdate) {
pendingUpdate = true;
// 下一帧渲染
requestAnimationFrame(() => {
// 将答案作为markdown渲染进dom
renderMarkdownToElement(answerContent, fullAnswer);
// 清理pending标识
pendingUpdate = false;
// 渲染后滚动到底部
scrollToBottom();
});
}
}
try {
// 组装API接口地址
const url = `/api/v1/knowledgebases/chat`;
// 请求后端,发起流式POST请求
const response = await fetch(url, {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({
question: question,
session_id: currentSessionId,
stream: true
})
});
// 非200响应时,抛出异常
if (!response.ok) throw new Error('请求失败');
// 使用ReadableStream reader获取response流
const reader = response.body.getReader();
// 用TextDecoder解码二进制数据
const decoder = new TextDecoder();
// 数据缓冲区(字符串)
let buffer = '';
// 初始展示“思考中...”提示
answerContent.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';
// 重新加载会话列表以更新标题
await loadSessions();
// 不断循环读取服务端推送流内容
while (true) {
// 逐步读取一段数据
const { done, value } = await reader.read();
// 若读完则结束循环
if (done) break;
// 本块新数据解码成字符串,并追加到buffer
buffer += decoder.decode(value, { stream: true });
// 按行切分(多行分别处理)
const lines = buffer.split('\n');
// 最后一行一般为数据残留,不处理,下次拼
buffer = lines.pop() || '';
// 处理每一行数据
for (const line of lines) {
// 筛选有效SSE协议"data: "数据包
if (line.startsWith('data: ')) {
// 去掉前缀,留下json内容
const data = line.slice(6);
// [DONE]信号作为流结束,仅跳过
if (data === '[DONE]') continue;
try {
// 解析JSON数据
const chunk = JSON.parse(data);
// 类型为流起始,清空内容
if (chunk.type === 'start') {
fullAnswer = '';
answerContent.innerHTML = '';
// 类型为内容,追加文本并安排去抖渲染
} else if (chunk.type === 'content') {
fullAnswer += chunk.content;
scheduleRender();
// 类型为done(流传输结束),直接最终渲染所有答案内容
} else if (chunk.type === 'done') {
renderMarkdownToElement(answerContent, fullAnswer);
// 错误类型,alert显示报错内容
} else if (chunk.type === 'error') {
answerContent.innerHTML = `<div class="alert alert-danger">${chunk.content}</div>`;
}
} catch (e) {
// JSON解析失败时报console
console.error('解析流数据失败:', e);
}
}
}
}
} catch (error) {
// 通信异常、Fetch错误等: 显示错误气泡
answerContent.innerHTML = `<div class="alert alert-danger"><strong>错误:</strong> ${error.message}</div>`;
}
// 结束后确保界面滚动到底
scrollToBottom();
}
// 清空所有会话
// 异步函数,用于清空所有会话
async function clearAllSessions() {
// 弹窗确认操作,若取消则直接返回
if (!confirm('确定要清空所有会话吗?此操作不可恢复!')) return;
try {
// 发送DELETE请求到服务器,删除所有会话
const response = await fetch('/api/v1/knowledgebases/sessions', {
method: 'DELETE'
});
// 获取服务器返回的JSON结果
const result = await response.json();
// 如果接口返回成功
if (result.code === 200) {
// 当前会话ID置空
currentSessionId = null;
// 清空聊天消息内容
clearChatMessages();
// 重新加载会话列表
await loadSessions();
}
} catch (error) {
// 捕获异常并弹窗提示失败原因
alert('清空会话失败: ' + error.message);
}
}
// 格式化时间
// 用于格式化时间字符串为友好显示
function formatTime(timeStr) {
// 如果无有效时间,直接返回空字符串
if (!timeStr) return '';
// 将字符串转换为Date对象
const date = new Date(timeStr);
// 获取当前时间
const now = new Date();
// 计算时间差(毫秒)
const diff = now - date;
// 换算为分钟数
const minutes = Math.floor(diff / 60000);
// 换算为小时数
const hours = Math.floor(diff / 3600000);
// 换算为天数
const days = Math.floor(diff / 86400000);
// 1分钟内显示“刚刚”
if (minutes < 1) return '刚刚';
// 1小时内显示“x分钟前”
if (minutes < 60) return `${minutes}分钟前`;
// 24小时内显示“x小时前”
if (hours < 24) return `${hours}小时前`;
// 7天内显示“x天前”
if (days < 7) return `${days}天前`;
// 超过7天显示具体日期
return date.toLocaleDateString();
}
// 渲染消息
function renderMessages(messages) {
const chatMessages = document.getElementById('chatMessages');
if (messages.length === 0) {
chatMessages.innerHTML = `
<div class="empty-state">
<i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
<p class="mt-2">开始提问吧!</p>
</div>
`;
return;
}
chatMessages.innerHTML = messages.map(msg => {
if (msg.role === 'user') {
return `
<div class="chat-message chat-question">
<div class="d-flex justify-content-between align-items-start">
<div class="flex-grow-1">
<strong><i class="bi bi-person-circle"></i> 问题:</strong>
<div class="mt-1">${escapeHtml(msg.content)}</div>
</div>
<small class="text-muted">${formatTime(msg.created_at)}</small>
</div>
</div>
`;
} else {
return `
<div class="chat-message chat-answer">
<div class="d-flex justify-content-between align-items-start">
<div class="flex-grow-1">
<strong><i class="bi bi-robot"></i> 答案:</strong>
<div class="mt-2 markdown-content">${renderMarkdown(msg.content)}</div>
</div>
<small class="text-muted">${formatTime(msg.created_at)}</small>
</div>
</div>
`;
}
}).join('');
scrollToBottom();
}
// 加载会话
// 定义一个异步函数 loadSession,用于加载指定 sessionId 的会话
async function loadSession(sessionId) {
// 异常捕获,防止请求或处理过程抛错
try {
// 发送 GET 请求,获取指定会话的详情数据
const response = await fetch(`/api/v1/knowledgebases/sessions/${sessionId}`);
// 将返回的响应数据解析为 JSON 格式
const result = await response.json();
// 如果接口请求成功,即 code 等于 200
if (result.code === 200) {
// 设置当前会话 ID 为传入的 sessionId
currentSessionId = sessionId;
// 获取会话详情数据
const session = result.data.session;
// 获取该会话包含的消息列表,如果没有则为 []
const messages = result.data.messages || [];
// 调用方法将消息渲染到页面
renderMessages(messages);
// 重新加载全部会话列表,刷新左侧会话栏
await loadSessions();
}
} catch (error) {
// 捕获异常并弹窗提示加载失败的原因
alert('加载会话失败: ' + error.message);
}
}
// 渲染会话列表
// 将sessions数组渲染到界面左侧会话列表
function renderSessions() {
// 获取会话列表的DOM节点
const sessionList = document.getElementById('sessionList');
// 如果没有任何会话,显示为空状态
if (sessions.length === 0) {
sessionList.innerHTML = `
<div class="text-center text-muted py-5">
<i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
<p class="mt-2 small">暂无会话</p>
</div>
`;
return;
}
// 遍历sessions并渲染为会话项
sessionList.innerHTML = sessions.map(session => `
<div class="session-item ${session.id === currentSessionId ? 'active' : ''}"
onclick="loadSession('${session.id}')">
<button class="btn btn-sm btn-link text-danger p-0 session-delete"
onclick="event.stopPropagation(); deleteSession('${session.id}')">
<i class="bi bi-x-lg"></i>
</button>
<div class="session-title">${session.title || '新对话'}</div>
<div class="session-time">${formatTime(session.updated_at)}</div>
</div>
`).join('');
}
// 清空聊天消息
// 将聊天内容区重置为初始状态
function clearChatMessages() {
document.getElementById('chatMessages').innerHTML = `
<div class="empty-state">
<i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
<p class="mt-2">开始提问吧!</p>
</div>
`;
}
// 加载会话列表
// 异步加载会话列表并渲染到界面
async function loadSessions() {
try {
// 发送GET请求获取会话列表
const response = await fetch('/api/v1/knowledgebases/sessions');
// 解析JSON结果
const result = await response.json();
// 如果接口返回成功
if (result.code === 200) {
// 保存会话数组
sessions = result.data.items || [];
// 渲染会话列表
renderSessions();
}
} catch (error) {
// 捕捉异常并在控制台打印
console.error('加载会话列表失败:', error);
}
}
// 创建新会话
// 异步请求服务端创建新会话
async function createNewSession() {
try {
// 发送POST请求创建新会话
const response = await fetch('/api/v1/knowledgebases/sessions', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({})
});
// 解析返回结果
const result = await response.json();
// 如果创建成功
if (result.code === 200) {
// 切换到新会话
currentSessionId = result.data.id;
// 重新加载会话列表
await loadSessions();
// 清空聊天内容
clearChatMessages();
}
} catch (error) {
// 捕获异常并弹窗提示
alert('创建会话失败: ' + error.message);
}
}
// 监听 DOMContentLoaded 事件,确保页面元素加载完成后执行
document.addEventListener('DOMContentLoaded', function() {
// 加载会话列表
loadSessions();
// 判断 marked.js 是否已经加载
if (typeof marked !== 'undefined') {
// 配置 marked.js 的渲染选项
marked.setOptions({
// 启用软换行
breaks: true,
// 启用 Github Flavored Markdown
gfm: true,
// 禁用标题自动生成 id
headerIds: false,
// 禁用混淆处理
mangle: false
});
}
});
+// 知识库选择变化处理函数
+async function onKbChange() {
+ // 获取当前下拉选择的知识库ID
+ currentKbId = document.getElementById('kbSelect').value;;
+ // 如果已存在会话,则切换知识库时重新新建一个会话
+ if (currentSessionId) {
+ await createNewSession();
+ }
+ // 启用发送按钮(普通聊天和知识库聊天均可提问)
+ document.getElementById('submitBtn').disabled = false;
+}
</script>
{% endblock %}
7.知识库对话 #
7.1. rag_service.py #
app/services/rag_service.py
"""
RAG 服务
"""
# 导入日志模块
import logging
# 导入 LangChain 的对话提示模板模块
from langchain_core.prompts import ChatPromptTemplate
# 导入自定义 LLM 工厂
from app.utils.llm_factory import LLMFactory
# 导入设置服务
from app.services.settings_service import settings_service
# 设置日志对象
logger = logging.getLogger(__name__)
# 定义 RAGService 类
class RAGService:
"""RAG 服务"""
# 初始化函数
def __init__(self):
"""
初始化服务
Args:
settings: 设置字典,如果为 None 则从数据库读取
"""
# 从设置服务中获取配置信息
self.settings = settings_service.get()
# 定义默认系统消息提示词
default_rag_system_prompt = "你是一个专业的AI助手。请基于文档内容回答问题。"
# 定义默认查询提示词,包含 context 和 question 占位符
default_rag_query_prompt = """文档内容:
{context}
问题:{question}
请基于文档内容回答问题。如果文档中没有相关信息,请明确说明。"""
# 从设置中获取自定义系统消息提示词
rag_system_prompt_text = self.settings.get('rag_system_prompt')
# 如果没有设置,使用默认系统提示词
if not rag_system_prompt_text:
rag_system_prompt_text = default_rag_system_prompt
# 从设置中获取自定义查询提示词
rag_query_prompt_text = self.settings.get('rag_query_prompt')
# 如果没有设置,使用默认查询提示词
if not rag_query_prompt_text:
rag_query_prompt_text = default_rag_query_prompt
# 构建 RAG 的提示模板,包含系统消息和用户查询部分
self.rag_prompt = ChatPromptTemplate.from_messages([
("system", rag_system_prompt_text),
("human", rag_query_prompt_text)
])
# 定义流式问答接口
def ask_stream(self, kb_id: str, question: str):
"""
流式问答接口
Args:
kb_id: 知识库ID
question: 问题
Yields:
流式数据块
"""
# 创建带流式输出能力的 LLM 实例
llm = LLMFactory.create_llm(self.settings)
# 文档过滤后的结果,暂时为空列表
filtered_docs = []
# 发送流式开始信号
yield {
"type": "start",
"content": ""
}
# 构造用于传递给 LLM 的上下文字符串,将所有文档整合为字符串
context = "\n\n".join([
f"文档 {i+1} ({doc.metadata.get('doc_name', '未知')}):\n{doc.page_content}"
for i, doc in enumerate(filtered_docs)
])
# 创建 Rag Prompt 到 LLM 的处理链
chain = self.rag_prompt | llm
# 初始化完整答案的字符串
full_answer = ""
# 逐块流式生成答案
for chunk in chain.stream({"context": context, "question": question}):
# 获取当前输出块内容
content = chunk.content
# 如果有内容则累加并 yield 输出内容块
if content:
full_answer += content
yield {
"type": "content",
"content": content
}
# 所有内容输出结束后,发送完成信号和相关元数据
yield {
"type": "done",
"content": "",
"metadata": {
'kb_id': kb_id,
'question': question,
'retrieved_chunks': len(filtered_docs)
}
}
# 实例化 rag_service,供外部调用
rag_service = RAGService() 7.2. chat.py #
app/blueprints/chat.py
# 聊天相关路由(视图 + API)
"""
聊天相关路由(视图 + API)
"""
# 导入 Flask 的 Blueprint 和模板渲染函数
from flask import Blueprint, render_template, request, stream_with_context, Response
import json
# 导入日志模块
import logging
# 导入登录保护装饰器和获取当前用户辅助方法
from app.utils.auth import login_required, api_login_required,get_current_user
# 导入知识库服务
from app.services.knowledgebase_service import kb_service
# 导入自定义工具函数:成功响应、错误响应、获取分页参数、获取当前用户或错误、异常处理装饰器、检查所有权
from app.blueprints.utils import (
success_response, error_response,
+ get_current_user_or_error, handle_api_error, get_pagination_params,check_ownership
)
from app.services.chat_service import chat_service
from app.services.chat_session_service import session_service
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 创建名为 'chat' 的蓝图对象
bp = Blueprint('chat', __name__)
# 注册 /chat 路由,访问该路由需要先登录
@bp.route('/chat')
@login_required
def chat_view():
# 智能问答页面视图函数
"""智能问答页面"""
current_user = get_current_user()
# 获取所有知识库(通常用户不会有太多知识库,不需要分页)
result = kb_service.list(user_id=current_user['id'], page=1, page_size=1000)
# 渲染 chat.html 模板并传递空知识库列表
return render_template('chat.html', knowledgebases=result['items'])
# 注册 API 路由,处理聊天接口 POST 请求
@bp.route('/api/v1/knowledgebases/chat', methods=['POST'])
@api_login_required
@handle_api_error
def api_chat():
# 普通聊天接口(不支持知识库,支持流式输出)
"""普通聊天接口(不支持知识库,支持流式输出)"""
# 获取当前用户和错误信息
current_user, err = get_current_user_or_error()
# 如果有错误,直接返回错误响应
if err:
return err
# 从请求体获取 JSON 数据
data = request.get_json()
# 如果数据为空或不存在 'question' 字段,返回错误
if not data or 'question' not in data:
return error_response("question is required", 400)
# 去除问题文本首尾空格
question = data['question'].strip()
# 如果问题内容为空,返回错误
if not question:
return error_response("question cannot be empty", 400)
session_id = data.get('session_id') # 会话ID(可以为空,表示普通聊天)
# 获取 max_tokens 参数,默认 1000
max_tokens = int(data.get('max_tokens', 1000))
# 限制最大和最小值在 1~10000 之间
max_tokens = max(1, min(max_tokens, 10000)) # 限制在 1-10000 之间
# 从请求数据中获取'stream'字段,默认为True,表示启用流式输出
stream = data.get('stream', True) # 默认启用流式输出
# 初始化历史消息为None
history = None
# 如果请求中带有session_id,说明有现有会话
if session_id:
# 根据session_id和当前用户ID获取历史消息列表
history_messages = session_service.get_messages(session_id, current_user['id'])
# 将历史消息转换为对话格式,仅保留最近10条
history = [
{'role': msg.get('role'), 'content': msg.get('content')}
for msg in history_messages[-10:] # 只取最近10条
]
# 如果请求中没有session_id,说明是新对话,需要新建会话
if not session_id:
# 创建新会话,kb_id设为None表示普通聊天
chat_session = session_service.create_session(
user_id=current_user['id']
)
# 使用新创建会话的ID作为本次会话ID
session_id = chat_session['id']
# 将用户的问题消息保存到当前会话中
session_service.add_message(session_id, 'user', question)
# 声明用于流式输出的生成器
@stream_with_context
def generate():
try:
# 用于缓存完整答案内容
full_answer = ''
# 调用服务进行流式对话
for chunk in chat_service.chat_stream(
question=question,
temperature=None, # 使用设置中的值
max_tokens=max_tokens,
history=history
):
# 如果是内容块,则拼接内容到 full_answer
if chunk.get('type') == 'content':
full_answer += chunk.get('content', '')
# 以 SSE 协议格式输出数据
yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
# 输出对话完成信号
yield "data: [DONE]\n\n"
# 保存助手回复
if full_answer:
session_service.add_message(session_id, 'assistant', full_answer)
except Exception as e:
# 发生异常记录日志
logger.error(f"流式输出时出错: {e}")
# 构造错误数据块
error_chunk = {
"type": "error",
"content": str(e)
}
# 输出错误数据块
yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n"
# 创建 Response 对象,设置必要的 SSE 响应头部
response = Response(
generate(),
mimetype='text/event-stream',
headers={
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
'X-Accel-Buffering': 'no',
'Content-Type': 'text/event-stream; charset=utf-8'
}
)
# 返回响应
return response
# 路由装饰器,定义 GET 方法获取会话列表的接口
@bp.route('/api/v1/knowledgebases/sessions', methods=['GET'])
# API 登录校验装饰器,确保用户已登录
@api_login_required
# 错误处理装饰器,统一处理接口异常
@handle_api_error
def api_list_sessions():
# 接口描述:获取当前用户的会话列表
"""获取当前用户的会话列表"""
# 获取当前用户,如有错误直接返回错误响应
current_user, err = get_current_user_or_error()
if err:
return err
# 获取分页参数(页码和每页数量),最大单页1000
page, page_size = get_pagination_params(max_page_size=1000)
# 调用会话服务获取当前用户的会话列表
result = session_service.list_sessions(current_user['id'], page=page, page_size=page_size)
# 以统一成功响应格式返回会话列表
return success_response(result)
# 路由装饰器,定义 POST 方法创建会话的接口
@bp.route('/api/v1/knowledgebases/sessions', methods=['POST'])
@api_login_required
@handle_api_error
def api_create_session():
# 接口描述:创建新的聊天会话
"""创建新的聊天会话"""
# 获取当前用户,如果有错误直接返回
current_user, err = get_current_user_or_error()
if err:
return err
# 获取请求体中的 JSON 数据,若无返回空字典
data = request.get_json() or {}
# 获取会话标题
title = data.get('title')
# 调用服务创建会话,传入当前用户ID、知识库ID与标题
session_obj = session_service.create_session(
user_id=current_user['id'],
title=title
)
# 返回成功响应及会话对象
return success_response(session_obj)
# 路由装饰器,定义 GET 方法获取单个会话详情的接口(带 session_id)
@bp.route('/api/v1/knowledgebases/sessions/<session_id>', methods=['GET'])
@api_login_required
@handle_api_error
def api_get_session(session_id):
# 接口描述:获取会话详情和消息
"""获取会话详情和消息"""
# 获取当前用户,如有错误直接返回
current_user, err = get_current_user_or_error()
if err:
return err
# 根据 session_id 获取会话对象,校验所属当前用户
session_obj = session_service.get_session_by_id(session_id, current_user['id'])
# 如果没有找到会话,返回 404 错误
if not session_obj:
return error_response("Session not found", 404)
# 获取该会话下的所有消息
messages = session_service.get_messages(session_id, current_user['id'])
# 返回会话详情及消息列表
return success_response({
'session': session_obj,
'messages': messages
})
# 路由装饰器,定义 DELETE 方法删除单个会话接口
@bp.route('/api/v1/knowledgebases/sessions/<session_id>', methods=['DELETE'])
@api_login_required
@handle_api_error
def api_delete_session(session_id):
# 接口描述:删除会话
"""删除会话"""
# 获取当前用户,如有错误直接返回
current_user, err = get_current_user_or_error()
if err:
return err
# 调用服务删除会话,校验归属当前用户
success = session_service.delete_session(session_id, current_user['id'])
# 若删除成功,返回成功响应,否则返回 404
if success:
return success_response(None, "Session deleted")
else:
return error_response("Session not found", 404)
# 路由装饰器,定义 DELETE 方法清空所有会话的接口
@bp.route('/api/v1/knowledgebases/sessions', methods=['DELETE'])
@api_login_required
@handle_api_error
def api_delete_all_sessions():
# 接口描述:清空所有会话
"""清空所有会话"""
# 获取当前用户,如果有错误直接返回
current_user, err = get_current_user_or_error()
if err:
return err
# 调用服务删除所有属于当前用户的会话,返回删除数量
count = session_service.delete_all_sessions(current_user['id'])
# 返回成功响应及被删除会话数
return success_response({'deleted_count': count}, f"Deleted {count} sessions")
# 路由装饰器,指定POST方法用于知识库问答接口
+@bp.route('/api/v1/knowledgebases/<kb_id>/chat', methods=['POST'])
# 装饰器:需要API登录
+@api_login_required
# 装饰器:统一处理API错误
+@handle_api_error
+def api_ask(kb_id):
# 知识库问答接口(支持流式输出)
+ """知识库问答接口(支持流式输出)"""
# 获取当前用户和错误信息
+ current_user, err = get_current_user_or_error()
# 如果获取用户出错,直接返回错误
+ if err:
+ return err
# 获取指定id的知识库
+ kb = kb_service.get_by_id(kb_id)
# 检查当前用户是否有权限访问该知识库
+ has_permission, err = check_ownership(kb['user_id'], current_user['id'], "knowledgebase")
# 如果没有权限,直接返回错误
+ if not has_permission:
+ return err
# 获取请求中的JSON数据
+ data = request.get_json()
# 获取并去除问题字符串首尾空白
+ question = data['question'].strip()
# 从请求数据获取session_id,如果没有则为None
+ session_id = data.get('session_id') # 会话ID
# 获取最大token数,默认为1000
+ max_tokens = int(data.get('max_tokens', 1000))
# 限制max_tokens在1到10000之间
+ max_tokens = max(1, min(max_tokens, 10000)) # 限制在 1-10000 之间
# 如果没有提供session_id,则为用户和知识库创建一个新会话
+ if not session_id:
+ chat_session = session_service.create_session(
+ user_id=current_user['id'],
+ kb_id=kb_id
+ )
# 获取新会话的会话ID
+ session_id = chat_session['id']
# 保存用户输入的问题到消息列表
+ session_service.add_message(session_id, 'user', question)
# 内部函数:生成流式响应内容
+ @stream_with_context
+ def generate():
+ try:
# 初始化完整回复内容
+ full_answer = ''
# 初始化引用信息
+ sources = None
# 迭代chat_service.ask_stream的每个数据块
+ for chunk in chat_service.ask_stream(
+ kb_id=kb_id,
+ question=question
+ ):
# 如果块类型为内容,则将内容追加到full_answer
+ if chunk.get('type') == 'content':
+ full_answer += chunk.get('content', '')
# 如果块类型为done,则获取sources
+ elif chunk.get('type') == 'done':
+ sources = chunk.get('sources')
# 以SSE格式输出该块内容
+ yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
# 所有内容输出后发送结束标志
+ yield "data: [DONE]\n\n"
# 如果有回复内容,则保存机器人助手的回复和引用
+ if full_answer:
+ session_service.add_message(session_id, 'assistant', full_answer)
+ except Exception as e:
# 如果流式输出出错,在日志中记录错误信息
+ logger.error(f"流式输出时出错: {e}")
# 构造错误信息块
+ error_chunk = {
+ "type": "error",
+ "content": str(e)
+ }
# 以SSE格式输出错误信息
+ yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n"
# 构造SSE(服务端事件)响应对象,携带合适的头部信息
+ response = Response(
+ generate(),
+ mimetype='text/event-stream',
+ headers={
+ 'Cache-Control': 'no-cache',
+ 'Connection': 'keep-alive',
+ 'X-Accel-Buffering': 'no',
+ 'Content-Type': 'text/event-stream; charset=utf-8'
+ }
+ )
# 返回响应对象
+ return response
7.3. chat_service.py #
app/services/chat_service.py
"""
问答服务
支持普通聊天和知识库聊天(RAG)
"""
# 导入日志模块
import logging
# 导入可选类型和迭代器类型注解
from typing import Optional, Iterator
# 导入 LLM 工厂,用于创建大语言模型实例
from app.utils.llm_factory import LLMFactory
# 导入 LangChain 的对话模板
from langchain_core.prompts import ChatPromptTemplate
# 导入设置服务,用于获取当前系统设置
from app.services.settings_service import settings_service
# 导入RAG服务
+from app.services.rag_service import rag_service
# 初始化日志记录器
logger = logging.getLogger(__name__)
# 定义问答服务类
class ChatService:
# 类的初始化方法
def __init__(self):
"""初始化问答服务"""
# 获取并保存当前的系统设置
self.settings = settings_service.get()
"""问答服务(支持普通聊天和RAG)"""
# 定义流式普通聊天方法,不使用知识库
def chat_stream(self, question: str, temperature: Optional[float] = None,
max_tokens: int = 1000, history: Optional[list] = None) -> Iterator[dict]:
"""
流式普通聊天接口(不使用知识库)
Args:
question: 问题
temperature: LLM 温度参数(如果为 None,则从设置中读取)
max_tokens: 最大生成 token 数
history: 历史对话记录(可选)
Yields:
流式数据块
"""
# 如果没有指定温度,则从设置中获取(默认为 0.7),并限制在 0-2 之间
if temperature is None:
temperature = float(self.settings.get('llm_temperature', '0.7'))
temperature = max(0.0, min(temperature, 2.0)) # 限制在 0-2 之间
# 获取用于普通聊天的系统提示词
chat_prompt_text = self.settings.get('chat_system_prompt')
# 如果系统提示词不存在,则使用默认的提示词
if not chat_prompt_text:
chat_prompt_text = '你是一个专业的AI助手。请友好、准确地回答用户的问题。'
# 创建支持流式输出的 LLM 实例
llm = LLMFactory.create_llm(self.settings, temperature=temperature, max_tokens=max_tokens, streaming=True)
# 构造单轮对话消息格式,包含 system 提示和用户问题
messages = [
("system", chat_prompt_text),
("human", question)
]
# 从消息创建对话提示模板
prompt = ChatPromptTemplate.from_messages(messages)
# 组装 prompt 和 llm,形成链式调用
chain = prompt | llm
# 发送流式开头信号
yield {
"type": "start",
"content": ""
}
# 初始化完整答案内容
full_answer = ""
try:
# 遍历模型生成的每一段内容
for chunk in chain.stream({}):
# 如果 chunk 有内容,提取内容并累加到full_answer
if hasattr(chunk, 'content') and chunk.content:
content = chunk.content
full_answer += content
# 输出内容块
yield {
"type": "content",
"content": content
}
# 捕获生成过程中的异常,记录日志并产出错误类型的数据块
except Exception as e:
logger.error(f"流式生成时出错: {e}")
yield {
"type": "error",
"content": f"生成答案时出错: {str(e)}"
}
return
# 发送流式结束信号,附带元数据(此处无知识库相关内容)
yield {
"type": "done",
"content": "",
"sources": [],
"metadata": {
'question': question,
'retrieved_chunks': 0,
'used_chunks': 0
}
}
+ def ask_stream(self, kb_id: str, question: str) -> Iterator[dict]:
+ """
+ 流式知识库问答接口(使用 RAG)
+ Args:
+ kb_id: 知识库ID
+ question: 问题
+ Yields:
+ 流式数据块
+ """
+ return rag_service.ask_stream(
+ kb_id=kb_id,
+ question=question
+ )
# 创建全局单例 chat_service 实例
chat_service=ChatService()7.4. chat.html #
app/templates/chat.html
{% extends "base.html" %}
{% block title %}智能问答 - RAG Lite{% endblock %}
{% block extra_css %}
<style>
.chat-container {
height: calc(100vh - 200px);
display: flex;
gap: 1rem;
}
.chat-sidebar {
width: 280px;
background: #fff;
border-radius: 0.5rem;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
display: flex;
flex-direction: column;
}
.chat-main {
flex: 1;
background: #fff;
border-radius: 0.5rem;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
display: flex;
flex-direction: column;
}
.session-list {
flex: 1;
overflow-y: auto;
padding: 0.5rem;
}
.session-item {
padding: 0.75rem;
margin-bottom: 0.5rem;
border-radius: 0.5rem;
cursor: pointer;
transition: background-color 0.2s;
position: relative;
}
.session-item:hover {
background-color: #f8f9fa;
}
.session-item.active {
background-color: #e3f2fd;
border-left: 3px solid #0d6efd;
}
.session-item .session-title {
font-weight: 500;
margin-bottom: 0.25rem;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.session-item .session-time {
font-size: 0.75rem;
color: #6c757d;
}
.session-item .session-delete {
position: absolute;
top: 0.5rem;
right: 0.5rem;
opacity: 0;
transition: opacity 0.2s;
}
.session-item:hover .session-delete {
opacity: 1;
}
.chat-messages {
flex: 1;
overflow-y: auto;
padding: 1rem;
scroll-behavior: smooth;
}
.chat-message {
padding: 1rem;
margin-bottom: 1rem;
border-radius: 0.5rem;
}
.chat-question {
background-color: #e3f2fd;
}
.chat-answer {
background-color: #f5f5f5;
}
.chat-input-area {
padding: 1rem;
border-top: 1px solid #dee2e6;
}
.empty-state {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
height: 100%;
color: #6c757d;
}
</style>
{% endblock %}
{% block content %}
<div class="chat-container">
<!-- 左侧:会话管理 -->
<div class="chat-sidebar">
<div class="p-3 border-bottom">
<div class="d-flex justify-content-between align-items-center mb-3">
<h6 class="mb-0"><i class="bi bi-chat-left-text"></i> 聊天会话</h6>
<button class="btn btn-sm btn-primary" onclick="createNewSession()">
<i class="bi bi-plus"></i> 新建
</button>
</div>
<button class="btn btn-sm btn-outline-danger w-100" onclick="clearAllSessions()">
<i class="bi bi-trash"></i> 清空所有
</button>
</div>
<div class="session-list" id="sessionList">
<div class="text-center text-muted py-5">
<i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
<p class="mt-2 small">暂无会话</p>
</div>
</div>
</div>
<!-- 右侧:对话页面 -->
<div class="chat-main">
<div class="p-3 border-bottom">
<div class="d-flex justify-content-between align-items-center">
<h6 class="mb-0"><i class="bi bi-chat-dots"></i> 对话</h6>
<select class="form-select form-select-sm" id="kbSelect" style="width: 200px;" onchange="onKbChange()">
<option value="">-- 选择知识库 --</option>
{% for kb in knowledgebases %}
<option value="{{ kb.id }}">{{ kb.name }}</option>
{% endfor %}
</select>
</div>
</div>
<div class="chat-messages" id="chatMessages">
<div class="empty-state">
<i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
<p class="mt-2">开始提问吧!</p>
</div>
</div>
<div class="chat-input-area">
<form id="chatForm" onsubmit="askQuestion(event)">
<div class="mb-2">
<textarea class="form-control" id="questionInput" rows="2"
placeholder="输入您的问题..." required></textarea>
</div>
<div class="d-flex justify-content-end">
<button type="submit" class="btn btn-primary" id="submitBtn">
<i class="bi bi-send"></i> 发送
</button>
</div>
</form>
</div>
</div>
</div>
{% endblock %}
{% block extra_js %}
<script>
// 当前会话的ID,初始为null,表示暂未选择任何会话
let currentSessionId = null;
// 会话列表,初始为空数组,用于存储所有的会话对象
let sessions = [];
// 当前知识库的ID,初始为null,表示暂未选择任何知识库
let currentKbId = null;
// 为字符串进行HTML转义,防止XSS攻击
function escapeHtml(text) {
// 创建一个div元素作为容器
const div = document.createElement('div');
// 将待转义文本设置为div的textContent(自动完成转义)
div.textContent = text;
// 返回转义后的HTML内容
return div.innerHTML;
}
// 滚动消息框到底部
function scrollToBottom() {
// 获取聊天消息区域元素
const chatMessages = document.getElementById('chatMessages');
// 设置滚动条位置到最底部
chatMessages.scrollTop = chatMessages.scrollHeight;
}
// 定义一个用于渲染Markdown文本为HTML的函数
function renderMarkdown(text) {
// 判断marked库是否已加载且能使用解析方法
if (typeof marked !== 'undefined' && marked.parse) {
try {
// 尝试用marked库将Markdown文本解析成HTML
return marked.parse(text);
} catch (e) {
// 如果解析出错,则进行HTML转义并换行
return escapeHtml(text).replace(/\n/g, '<br>');
}
}
// 如果marked库不可用,直接做HTML转义并处理换行
return escapeHtml(text).replace(/\n/g, '<br>');
}
// 将Markdown内容渲染到指定元素(支持降级为普通文本)
function renderMarkdownToElement(element, text) {
// 若未提供内容,显示思考中图标
if (!text) {
element.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';
return;
}
// 优先判断marked库是否可用(渲染markdown)
if (typeof marked !== 'undefined' && marked.parse) {
try {
// 使用marked进行markdown转html
element.innerHTML = marked.parse(text);
} catch (e) {
// 渲染失败则退化为转义+换行
element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
}
} else {
// 没有marked库则直接转义+换行显示
element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
}
}
// 主函数:处理用户提交问题事件
async function askQuestion(event) {
// 阻止表单默认提交行为(防止页面刷新)
event.preventDefault();
// 获取输入框的用户问题并去除首尾空白
const question = document.getElementById('questionInput').value.trim();
// 获取消息显示区域元素
const chatMessages = document.getElementById('chatMessages');
// 若问题为空则直接返回
if (!question) return;
// 如果没有会话,创建新会话
if (!currentSessionId) {
await createNewSession();
}
// 检查并移除初始空白提示(如有)
if (chatMessages.querySelector('.empty-state')) {
chatMessages.innerHTML = '';
}
// 创建用于展示问题的div元素
const questionDiv = document.createElement('div');
// 加上样式: 用户问题
questionDiv.className = 'chat-message chat-question';
// 构建用户气泡内容含图标、文本和时间
questionDiv.innerHTML = `
<div class="d-flex justify-content-between align-items-start">
<div class="flex-grow-1">
<strong><i class="bi bi-person-circle"></i> 问题:</strong>
<div class="mt-1">${escapeHtml(question)}</div>
</div>
<small class="text-muted">${new Date().toLocaleTimeString()}</small>
</div>
`;
// 显示到对话窗口
chatMessages.appendChild(questionDiv);
// 创建用于显示答案的div元素
const answerDiv = document.createElement('div');
// 答案样式
answerDiv.className = 'chat-message chat-answer';
// 动态生成唯一的答案内容div id(用于唯一标记)
const answerContentId = 'answerContent_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
// 创建存放答案内容的div
const answerContent = document.createElement('div');
// 为markdown渲染容器设置样式和id
answerContent.className = 'mt-2 markdown-content';
answerContent.id = answerContentId;
// 答案div显示机器人图标和时间
answerDiv.innerHTML = `
<div class="d-flex justify-content-between align-items-start">
<div class="flex-grow-1">
<strong><i class="bi bi-robot"></i> 答案:</strong>
</div>
<small class="text-muted">${new Date().toLocaleTimeString()}</small>
</div>
`;
// 将答案内容div插入到机器人气泡div的内容区
const flexGrowDiv = answerDiv.querySelector('.flex-grow-1');
flexGrowDiv.appendChild(answerContent);
// 插入答案div到消息区域
chatMessages.appendChild(answerDiv);
// 变量记录完整的回答内容
let fullAnswer = '';
// 标记渲染任务是否挂起(防止重复)
let pendingUpdate = false;
// 记录定时器id(去抖动用)
let updateTimer = null;
// 清空输入框内容
document.getElementById('questionInput').value = '';
// 滚动到底
scrollToBottom();
// 定义scheduleRender:将markdown渲染插入到队列合适时机执行
function scheduleRender() {
// 若当前无待渲染任务才安排渲染
if (!pendingUpdate) {
pendingUpdate = true;
// 下一帧渲染
requestAnimationFrame(() => {
// 将答案作为markdown渲染进dom
renderMarkdownToElement(answerContent, fullAnswer);
// 清理pending标识
pendingUpdate = false;
// 渲染后滚动到底部
scrollToBottom();
});
}
}
try {
// 组装API接口地址
+ const url = currentKbId
+ ? `/api/v1/knowledgebases/${currentKbId}/chat`
+ : `/api/v1/knowledgebases/chat`;
// 请求后端,发起流式POST请求
const response = await fetch(url, {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({
question: question,
session_id: currentSessionId,
stream: true
})
});
// 非200响应时,抛出异常
if (!response.ok) throw new Error('请求失败');
// 使用ReadableStream reader获取response流
const reader = response.body.getReader();
// 用TextDecoder解码二进制数据
const decoder = new TextDecoder();
// 数据缓冲区(字符串)
let buffer = '';
// 初始展示“思考中...”提示
answerContent.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';
// 重新加载会话列表以更新标题
await loadSessions();
// 不断循环读取服务端推送流内容
while (true) {
// 逐步读取一段数据
const { done, value } = await reader.read();
// 若读完则结束循环
if (done) break;
// 本块新数据解码成字符串,并追加到buffer
buffer += decoder.decode(value, { stream: true });
// 按行切分(多行分别处理)
const lines = buffer.split('\n');
// 最后一行一般为数据残留,不处理,下次拼
buffer = lines.pop() || '';
// 处理每一行数据
for (const line of lines) {
// 筛选有效SSE协议"data: "数据包
if (line.startsWith('data: ')) {
// 去掉前缀,留下json内容
const data = line.slice(6);
// [DONE]信号作为流结束,仅跳过
if (data === '[DONE]') continue;
try {
// 解析JSON数据
const chunk = JSON.parse(data);
// 类型为流起始,清空内容
if (chunk.type === 'start') {
fullAnswer = '';
answerContent.innerHTML = '';
// 类型为内容,追加文本并安排去抖渲染
} else if (chunk.type === 'content') {
fullAnswer += chunk.content;
scheduleRender();
// 类型为done(流传输结束),直接最终渲染所有答案内容
} else if (chunk.type === 'done') {
renderMarkdownToElement(answerContent, fullAnswer);
// 错误类型,alert显示报错内容
} else if (chunk.type === 'error') {
answerContent.innerHTML = `<div class="alert alert-danger">${chunk.content}</div>`;
}
} catch (e) {
// JSON解析失败时报console
console.error('解析流数据失败:', e);
}
}
}
}
} catch (error) {
// 通信异常、Fetch错误等: 显示错误气泡
answerContent.innerHTML = `<div class="alert alert-danger"><strong>错误:</strong> ${error.message}</div>`;
}
// 结束后确保界面滚动到底
scrollToBottom();
}
// 清空所有会话
// 异步函数,用于清空所有会话
async function clearAllSessions() {
// 弹窗确认操作,若取消则直接返回
if (!confirm('确定要清空所有会话吗?此操作不可恢复!')) return;
try {
// 发送DELETE请求到服务器,删除所有会话
const response = await fetch('/api/v1/knowledgebases/sessions', {
method: 'DELETE'
});
// 获取服务器返回的JSON结果
const result = await response.json();
// 如果接口返回成功
if (result.code === 200) {
// 当前会话ID置空
currentSessionId = null;
// 清空聊天消息内容
clearChatMessages();
// 重新加载会话列表
await loadSessions();
}
} catch (error) {
// 捕获异常并弹窗提示失败原因
alert('清空会话失败: ' + error.message);
}
}
// 格式化时间
// 用于格式化时间字符串为友好显示
function formatTime(timeStr) {
// 如果无有效时间,直接返回空字符串
if (!timeStr) return '';
// 将字符串转换为Date对象
const date = new Date(timeStr);
// 获取当前时间
const now = new Date();
// 计算时间差(毫秒)
const diff = now - date;
// 换算为分钟数
const minutes = Math.floor(diff / 60000);
// 换算为小时数
const hours = Math.floor(diff / 3600000);
// 换算为天数
const days = Math.floor(diff / 86400000);
// 1分钟内显示“刚刚”
if (minutes < 1) return '刚刚';
// 1小时内显示“x分钟前”
if (minutes < 60) return `${minutes}分钟前`;
// 24小时内显示“x小时前”
if (hours < 24) return `${hours}小时前`;
// 7天内显示“x天前”
if (days < 7) return `${days}天前`;
// 超过7天显示具体日期
return date.toLocaleDateString();
}
// 渲染消息
function renderMessages(messages) {
const chatMessages = document.getElementById('chatMessages');
if (messages.length === 0) {
chatMessages.innerHTML = `
<div class="empty-state">
<i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
<p class="mt-2">开始提问吧!</p>
</div>
`;
return;
}
chatMessages.innerHTML = messages.map(msg => {
if (msg.role === 'user') {
return `
<div class="chat-message chat-question">
<div class="d-flex justify-content-between align-items-start">
<div class="flex-grow-1">
<strong><i class="bi bi-person-circle"></i> 问题:</strong>
<div class="mt-1">${escapeHtml(msg.content)}</div>
</div>
<small class="text-muted">${formatTime(msg.created_at)}</small>
</div>
</div>
`;
} else {
return `
<div class="chat-message chat-answer">
<div class="d-flex justify-content-between align-items-start">
<div class="flex-grow-1">
<strong><i class="bi bi-robot"></i> 答案:</strong>
<div class="mt-2 markdown-content">${renderMarkdown(msg.content)}</div>
</div>
<small class="text-muted">${formatTime(msg.created_at)}</small>
</div>
</div>
`;
}
}).join('');
scrollToBottom();
}
// 加载会话
// 定义一个异步函数 loadSession,用于加载指定 sessionId 的会话
async function loadSession(sessionId) {
// 异常捕获,防止请求或处理过程抛错
try {
// 发送 GET 请求,获取指定会话的详情数据
const response = await fetch(`/api/v1/knowledgebases/sessions/${sessionId}`);
// 将返回的响应数据解析为 JSON 格式
const result = await response.json();
// 如果接口请求成功,即 code 等于 200
if (result.code === 200) {
// 设置当前会话 ID 为传入的 sessionId
currentSessionId = sessionId;
// 获取会话详情数据
const session = result.data.session;
// 获取该会话包含的消息列表,如果没有则为 []
const messages = result.data.messages || [];
// 调用方法将消息渲染到页面
renderMessages(messages);
// 重新加载全部会话列表,刷新左侧会话栏
await loadSessions();
}
} catch (error) {
// 捕获异常并弹窗提示加载失败的原因
alert('加载会话失败: ' + error.message);
}
}
// 渲染会话列表
// 将sessions数组渲染到界面左侧会话列表
function renderSessions() {
// 获取会话列表的DOM节点
const sessionList = document.getElementById('sessionList');
// 如果没有任何会话,显示为空状态
if (sessions.length === 0) {
sessionList.innerHTML = `
<div class="text-center text-muted py-5">
<i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
<p class="mt-2 small">暂无会话</p>
</div>
`;
return;
}
// 遍历sessions并渲染为会话项
sessionList.innerHTML = sessions.map(session => `
<div class="session-item ${session.id === currentSessionId ? 'active' : ''}"
onclick="loadSession('${session.id}')">
<button class="btn btn-sm btn-link text-danger p-0 session-delete"
onclick="event.stopPropagation(); deleteSession('${session.id}')">
<i class="bi bi-x-lg"></i>
</button>
<div class="session-title">${session.title || '新对话'}</div>
<div class="session-time">${formatTime(session.updated_at)}</div>
</div>
`).join('');
}
// 清空聊天消息
// 将聊天内容区重置为初始状态
function clearChatMessages() {
document.getElementById('chatMessages').innerHTML = `
<div class="empty-state">
<i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
<p class="mt-2">开始提问吧!</p>
</div>
`;
}
// 加载会话列表
// 异步加载会话列表并渲染到界面
async function loadSessions() {
try {
// 发送GET请求获取会话列表
const response = await fetch('/api/v1/knowledgebases/sessions');
// 解析JSON结果
const result = await response.json();
// 如果接口返回成功
if (result.code === 200) {
// 保存会话数组
sessions = result.data.items || [];
// 渲染会话列表
renderSessions();
}
} catch (error) {
// 捕捉异常并在控制台打印
console.error('加载会话列表失败:', error);
}
}
// 创建新会话
// 异步请求服务端创建新会话
async function createNewSession() {
try {
// 发送POST请求创建新会话
const response = await fetch('/api/v1/knowledgebases/sessions', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({})
});
// 解析返回结果
const result = await response.json();
// 如果创建成功
if (result.code === 200) {
// 切换到新会话
currentSessionId = result.data.id;
// 重新加载会话列表
await loadSessions();
// 清空聊天内容
clearChatMessages();
}
} catch (error) {
// 捕获异常并弹窗提示
alert('创建会话失败: ' + error.message);
}
}
// 监听 DOMContentLoaded 事件,确保页面元素加载完成后执行
document.addEventListener('DOMContentLoaded', function() {
// 加载会话列表
loadSessions();
// 判断 marked.js 是否已经加载
if (typeof marked !== 'undefined') {
// 配置 marked.js 的渲染选项
marked.setOptions({
// 启用软换行
breaks: true,
// 启用 Github Flavored Markdown
gfm: true,
// 禁用标题自动生成 id
headerIds: false,
// 禁用混淆处理
mangle: false
});
}
});
// 知识库选择变化处理函数
async function onKbChange() {
// 更新全局当前知识库ID变量
currentKbId = document.getElementById('kbSelect').value;;
// 如果已存在会话,则切换知识库时重新新建一个会话
if (currentSessionId) {
await createNewSession();
}
// 启用发送按钮(普通聊天和知识库聊天均可提问)
document.getElementById('submitBtn').disabled = false;
}
</script>
{% endblock %}8.向量检索 #
8.1. retrieval_service.py #
app/services/retrieval_service.py
"""
检索服务
支持向量检索、全文检索和混合检索
支持重排序功能
"""
# 导入日志模块
import logging
# 导入类型注解
from typing import List, Optional, Tuple
# 导入Document对象
from langchain_core.documents import Document
# 导入设置信息服务
from app.services.settings_service import settings_service
# 导入向量数据库服务工厂方法
from app.services.vectordb.factory import get_vector_db_service
# 创建日志记录器
logger = logging.getLogger(__name__)
# 检索服务类
class RetrievalService:
"""检索服务"""
# 初始化方法
def __init__(self):
"""
初始化检索服务
Args:
settings: 设置字典,如果为 None 则从数据库读取
"""
# 从设置信息服务获取设置
self.settings = settings_service.get()
# 向量检索函数
def vector_search(self, collection_name: str, query: str) -> List[Tuple[Document, float]]:
"""
向量检索
Args:
collection_name: 集合名称
query: 查询文本
Returns:
(Document, similarity_score) 列表,按相似度降序排列
"""
# 获取向量数据库实例
vectordb = get_vector_db_service()
# 获取或创建指定集合的向量存储对象
vectorstore = vectordb.get_or_create_collection(collection_name)
# 从设置获取top_k,若不存在则默认为5
top_k = int(self.settings.get('top_k', '5'))
# 判断向量阈值是否已定义,若未定义则获取设定值或默认为0.2
vector_threshold = float(self.settings.get('vector_threshold', '0.2'))
# 限定向量阈值在0到1之间
vector_threshold = max(0.0, min(vector_threshold, 1.0))
# 以相似度得分方式检索,返回结果(扩大top_k数目以便后续过滤)
results = vectorstore.similarity_search_with_score(
query=query,
k=top_k*3
)
# 初始化文档以及得分的列表
docs_with_scores = []
# 遍历检索结果,将得分归一化并加入元数据
for doc, score in results:
# 计算归一化向量得分
vector_score = 1.0 / (1.0 + float(score))
# 存储向量得分到文档元数据
doc.metadata['vector_score'] = vector_score
# 标注检索类型为向量
doc.metadata['retrieval_type'] = 'vector'
# 加入列表
docs_with_scores.append((doc, vector_score))
# 按照相似度分数从高到低排序
docs_with_scores.sort(key=lambda x: x[1], reverse=True)
# 根据阈值过滤掉低于阈值的文档
filtered_docs = [(doc, score) for doc, score in docs_with_scores
if score >= vector_threshold]
# 仅保留top_k个文档用于返回
docs = [doc for doc, _ in filtered_docs[:top_k]]
# 日志打印检索到的文档个数
logger.info(f"向量搜索: 检索到 {len(docs)} 个文档")
# 返回结果
return docs
# 实例化检索服务,供外部调用
retrieval_service = RetrievalService()8.2. rag_service.py #
app/services/rag_service.py
"""
RAG 服务
"""
# 导入日志模块
import logging
# 导入 LangChain 的对话提示模板模块
from langchain_core.prompts import ChatPromptTemplate
# 导入自定义 LLM 工厂
from app.utils.llm_factory import LLMFactory
# 导入设置服务
from app.services.settings_service import settings_service
# 导入类型提示
+from typing import List
# 导入 LangChain 的 Document 类型
+from langchain_core.documents import Document
# 导入检索服务
+from app.services.retrieval_service import retrieval_service
# 设置日志对象
logger = logging.getLogger(__name__)
# 定义 RAGService 类
class RAGService:
"""RAG 服务"""
# 初始化函数
def __init__(self):
"""
初始化服务
Args:
settings: 设置字典,如果为 None 则从数据库读取
"""
# 从设置服务中获取配置信息
self.settings = settings_service.get()
# 定义默认系统消息提示词
default_rag_system_prompt = "你是一个专业的AI助手。请基于文档内容回答问题。"
# 定义默认查询提示词,包含 context 和 question 占位符
default_rag_query_prompt = """文档内容:
{context}
问题:{question}
请基于文档内容回答问题。如果文档中没有相关信息,请明确说明。"""
# 从设置中获取自定义系统消息提示词
rag_system_prompt_text = self.settings.get('rag_system_prompt')
# 如果没有设置,使用默认系统提示词
if not rag_system_prompt_text:
rag_system_prompt_text = default_rag_system_prompt
# 从设置中获取自定义查询提示词
rag_query_prompt_text = self.settings.get('rag_query_prompt')
# 如果没有设置,使用默认查询提示词
if not rag_query_prompt_text:
rag_query_prompt_text = default_rag_query_prompt
# 构建 RAG 的提示模板,包含系统消息和用户查询部分
self.rag_prompt = ChatPromptTemplate.from_messages([
("system", rag_system_prompt_text),
("human", rag_query_prompt_text)
])
# 定义流式问答接口
def ask_stream(self, kb_id: str, question: str):
"""
流式问答接口
Args:
kb_id: 知识库ID
question: 问题
Yields:
流式数据块
"""
# 创建带流式输出能力的 LLM 实例
llm = LLMFactory.create_llm(self.settings)
# 文档过滤后的结果
+ filtered_docs = self._retrieve_documents(kb_id, question)
# 发送流式开始信号
yield {
"type": "start",
"content": ""
}
# 构造用于传递给 LLM 的上下文字符串,将所有文档整合为字符串
context = "\n\n".join([
f"文档 {i+1} ({doc.metadata.get('doc_name', '未知')}):\n{doc.page_content}"
for i, doc in enumerate(filtered_docs)
])
# 创建 Rag Prompt 到 LLM 的处理链
chain = self.rag_prompt | llm
# 初始化完整答案的字符串
full_answer = ""
# 逐块流式生成答案
for chunk in chain.stream({"context": context, "question": question}):
# 获取当前输出块内容
content = chunk.content
# 如果有内容则累加并 yield 输出内容块
if content:
full_answer += content
yield {
"type": "content",
"content": content
}
# 所有内容输出结束后,发送完成信号和相关元数据
yield {
"type": "done",
"content": "",
"metadata": {
'kb_id': kb_id,
'question': question,
'retrieved_chunks': len(filtered_docs)
}
}
+ def _retrieve_documents(self, kb_id: str, question: str) -> List[Document]:
+ """
+ 检索文档(公共方法)
+ Args:
+ kb_id: 知识库ID
+ question: 查询文本
+ Returns:
+ 文档列表
+ """
+ collection_name = f"kb_{kb_id}"
+ retrieval_mode = self.settings.get('retrieval_mode', 'vector')
+ if retrieval_mode == 'vector':
+ docs = retrieval_service.vector_search(
+ collection_name=collection_name,
+ query=question
+ )
+ else:
+ logger.warning(f"未知的检索模式: {retrieval_mode}, 使用向量检索")
+ docs = retrieval_service.vector_search(
+ collection_name=collection_name,
+ query=question
+ )
+ logger.info(f"使用 {retrieval_mode} 模式检索到 {len(docs)} 个文档")
+ return docs
# 实例化 rag_service,供外部调用
rag_service = RAGService() 9.关键字检索 #
9.1. rag_service.py #
app/services/rag_service.py
"""
RAG 服务
"""
# 导入日志模块
import logging
# 导入 LangChain 的对话提示模板模块
from langchain_core.prompts import ChatPromptTemplate
# 导入自定义 LLM 工厂
from app.utils.llm_factory import LLMFactory
# 导入设置服务
from app.services.settings_service import settings_service
# 导入类型提示
from typing import List
# 导入 LangChain 的 Document 类型
from langchain_core.documents import Document
# 导入检索服务
from app.services.retrieval_service import retrieval_service
# 设置日志对象
logger = logging.getLogger(__name__)
# 定义 RAGService 类
class RAGService:
"""RAG 服务"""
# 初始化函数
def __init__(self):
"""
初始化服务
Args:
settings: 设置字典,如果为 None 则从数据库读取
"""
# 从设置服务中获取配置信息
self.settings = settings_service.get()
# 定义默认系统消息提示词
default_rag_system_prompt = "你是一个专业的AI助手。请基于文档内容回答问题。"
# 定义默认查询提示词,包含 context 和 question 占位符
default_rag_query_prompt = """文档内容:
{context}
问题:{question}
请基于文档内容回答问题。如果文档中没有相关信息,请明确说明。"""
# 从设置中获取自定义系统消息提示词
rag_system_prompt_text = self.settings.get('rag_system_prompt')
# 如果没有设置,使用默认系统提示词
if not rag_system_prompt_text:
rag_system_prompt_text = default_rag_system_prompt
# 从设置中获取自定义查询提示词
rag_query_prompt_text = self.settings.get('rag_query_prompt')
# 如果没有设置,使用默认查询提示词
if not rag_query_prompt_text:
rag_query_prompt_text = default_rag_query_prompt
# 构建 RAG 的提示模板,包含系统消息和用户查询部分
self.rag_prompt = ChatPromptTemplate.from_messages([
("system", rag_system_prompt_text),
("human", rag_query_prompt_text)
])
# 定义流式问答接口
def ask_stream(self, kb_id: str, question: str):
"""
流式问答接口
Args:
kb_id: 知识库ID
question: 问题
Yields:
流式数据块
"""
# 创建带流式输出能力的 LLM 实例
llm = LLMFactory.create_llm(self.settings)
# 文档过滤后的结果
filtered_docs = self._retrieve_documents(kb_id, question)
# 发送流式开始信号
yield {
"type": "start",
"content": ""
}
# 构造用于传递给 LLM 的上下文字符串,将所有文档整合为字符串
context = "\n\n".join([
f"文档 {i+1} ({doc.metadata.get('doc_name', '未知')}):\n{doc.page_content}"
for i, doc in enumerate(filtered_docs)
])
# 创建 Rag Prompt 到 LLM 的处理链
chain = self.rag_prompt | llm
# 初始化完整答案的字符串
full_answer = ""
# 逐块流式生成答案
for chunk in chain.stream({"context": context, "question": question}):
# 获取当前输出块内容
content = chunk.content
# 如果有内容则累加并 yield 输出内容块
if content:
full_answer += content
yield {
"type": "content",
"content": content
}
# 所有内容输出结束后,发送完成信号和相关元数据
yield {
"type": "done",
"content": "",
"metadata": {
'kb_id': kb_id,
'question': question,
'retrieved_chunks': len(filtered_docs)
}
}
def _retrieve_documents(self, kb_id: str, question: str) -> List[Document]:
"""
检索文档(公共方法)
Args:
kb_id: 知识库ID
question: 查询文本
Returns:
文档列表
"""
collection_name = f"kb_{kb_id}"
retrieval_mode = self.settings.get('retrieval_mode', 'vector')
if retrieval_mode == 'vector':
docs = retrieval_service.vector_search(
collection_name=collection_name,
query=question
)
+ elif retrieval_mode == 'keyword':
+ docs = retrieval_service.keyword_search(
+ collection_name=collection_name,
+ query=question
+ )
else:
logger.warning(f"未知的检索模式: {retrieval_mode}, 使用向量检索")
docs = retrieval_service.vector_search(
collection_name=collection_name,
query=question
)
logger.info(f"使用 {retrieval_mode} 模式检索到 {len(docs)} 个文档")
return docs
# 实例化 rag_service,供外部调用
rag_service = RAGService() 9.2. retrieval_service.py #
app/services/retrieval_service.py
"""
检索服务
支持向量检索、全文检索和混合检索
支持重排序功能
"""
# 导入日志模块
import logging
# 导入类型注解
from typing import List, Optional, Tuple
# 导入Document对象
from langchain_core.documents import Document
# 导入BM25Okapi
+from rank_bm25 import BM25Okapi
# 导入jieba
+import jieba
# 导入numpy
+import numpy as np
# 导入设置信息服务
from app.services.settings_service import settings_service
# 导入向量数据库服务工厂方法
from app.services.vectordb.factory import get_vector_db_service
# 创建日志记录器
logger = logging.getLogger(__name__)
# 检索服务类
class RetrievalService:
"""检索服务"""
# 初始化方法
def __init__(self):
"""
初始化检索服务
Args:
settings: 设置字典,如果为 None 则从数据库读取
"""
# 从设置信息服务获取设置
self.settings = settings_service.get()
# 向量检索函数
def vector_search(self, collection_name: str, query: str) -> List[Tuple[Document, float]]:
"""
向量检索
Args:
collection_name: 集合名称
query: 查询文本
Returns:
(Document, similarity_score) 列表,按相似度降序排列
"""
# 获取向量数据库实例
vectordb = get_vector_db_service()
# 获取或创建指定集合的向量存储对象
vectorstore = vectordb.get_or_create_collection(collection_name)
# 从设置获取top_k,若不存在则默认为5
top_k = int(self.settings.get('top_k', '5'))
# 判断向量阈值是否已定义,若未定义则获取设定值或默认为0.2
vector_threshold = float(self.settings.get('vector_threshold', '0.2'))
# 限定向量阈值在0到1之间
vector_threshold = max(0.0, min(vector_threshold, 1.0))
# 以相似度得分方式检索,返回结果(扩大top_k数目以便后续过滤)
results = vectorstore.similarity_search_with_score(
query=query,
k=top_k*3
)
# 初始化文档以及得分的列表
docs_with_scores = []
# 遍历检索结果,将得分归一化并加入元数据
for doc, score in results:
# 计算归一化向量得分
vector_score = 1.0 / (1.0 + float(score))
# 存储向量得分到文档元数据
doc.metadata['vector_score'] = vector_score
# 标注检索类型为向量
doc.metadata['retrieval_type'] = 'vector'
# 加入列表
docs_with_scores.append((doc, vector_score))
# 按照相似度分数从高到低排序
docs_with_scores.sort(key=lambda x: x[1], reverse=True)
# 根据阈值过滤掉低于阈值的文档
filtered_docs = [(doc, score) for doc, score in docs_with_scores
if score >= vector_threshold]
# 仅保留top_k个文档用于返回
docs = [doc for doc, _ in filtered_docs[:top_k]]
# 日志打印检索到的文档个数
logger.info(f"向量搜索: 检索到 {len(docs)} 个文档")
# 返回结果
+ return docs
+ def _tokenize_chinese(self, text: str) -> List[str]:
+ """
+ 中文分词(使用 jieba)
+ Args:
+ text: 输入文本
+ Returns:
+ 分词后的词列表
+ """
# 使用 jieba 分词
+ words = jieba.lcut(text)
# 去除停用词和单字
+ stopwords = set(['的', '了', '在', '是', '和', '有', '与', '对', '等', '为', '也', '就', '都', '要', '可以', '会', '能', '而', '及', '与', '或'])
+ tokens = [word.strip() for word in words if len(word.strip()) > 1 and word.strip() not in stopwords]
+ return tokens
# 定义关键字检索方法,使用 BM25 算法进行匹配
+ def keyword_search(self, collection_name: str, query: str) -> List[Tuple[Document, float]]:
+ """
+ 全文检索(使用 BM25 算法进行关键词匹配)
+ Args:
+ collection_name: 集合名称
+ query: 查询文本
+ Returns:
+ (Document, keyword_score) 列表,按匹配分数降序排列
+ """
+ try:
# 获取向量数据库服务实例
+ vectordb = get_vector_db_service()
# 获取或创建指定集合的向量存储对象
+ vectorstore = vectordb.get_or_create_collection(collection_name)
# 初始化用于存储所有文档的列表
+ all_docs = []
# 从底层集合获取所有内容
+ results = vectorstore._collection.get()
# 如果存在有效的检索结果且包含 'ids'
+ if results and 'ids' in results:
# 遍历所有文档 id
+ for i, _ in enumerate(results['ids']):
# 判断 'documents' 键存在且索引不越界
+ if 'documents' in results and i < len(results['documents']):
# 构建 Document 对象,提取内容和元数据
+ doc = Document(
+ page_content=results['documents'][i],
+ metadata=results.get('metadatas', [{}])[i] if 'metadatas' in results else {}
+ )
# 添加到所有文档列表
+ all_docs.append(doc)
# 提取所有文档的文本内容
+ documents = [doc.page_content for doc in all_docs]
# 对每个文档进行分词处理
+ tokenized_docs = [self._tokenize_chinese(doc) for doc in documents]
# 构建 BM25 索引
+ bm25 = BM25Okapi(tokenized_docs)
# 对查询语句进行中文分词
+ query_tokens = self._tokenize_chinese(query)
# 获取每个文档与查询的 BM25 分数
+ scores = bm25.get_scores(query_tokens)
# 计算分数的最大值,用于归一化分数到 [0, 1] 范围
+ max_score = float(np.max(scores)) if len(scores) > 0 and np.max(scores) > 0 else 1.0
# 归一化 BM25 分数
+ normalized_scores = scores / max_score if max_score > 0 else scores
# 获取关键字分数阈值,默认0.5
+ keyword_threshold = float(self.settings.get('keyword_threshold', '0.5'))
# 限定关键字阈值在0到1之间
+ keyword_threshold = max(0.0, min(keyword_threshold, 1.0))
# 获取返回文档数量 top_k,默认5
+ top_k = int(self.settings.get('top_k', '5'))
# 取分数最高的 top_k*3 个索引,便于后续过滤
+ top_indices = np.argsort(normalized_scores)[::-1][:top_k * 3]
# 初始化结果列表
+ docs_with_scores = []
# 遍历候选索引
+ for idx in top_indices:
# 取出每个文档归一化后的分数
+ normalized_score = float(normalized_scores[idx])
# 确保分数在 [0, 1] 之间
+ normalized_score = max(0.0, min(1.0, normalized_score))
# 仅保留分数高于阈值的文档
+ if normalized_score >= keyword_threshold:
# 取出对应文档
+ doc = all_docs[idx]
# 记录关键词得分到元数据
+ doc.metadata['keyword_score'] = normalized_score
# 标记检索类型
+ doc.metadata['retrieval_type'] = 'keyword'
# 添加到结果列表
+ docs_with_scores.append((doc, normalized_score))
# 按得分降序排序
+ docs_with_scores.sort(key=lambda x: x[1], reverse=True)
# 截取分数最高的前 top_k 个文档
+ docs = [doc for doc, _ in docs_with_scores[:top_k]]
# 记录 BM25 检索到的文档数量日志
+ logger.info(f"BM25 关键词搜索: 检索到 {len(docs)} 个文档")
# 返回检索结果
+ return docs
+ except Exception as e:
# 捕获异常并记录日志
+ logger.error(f"关键词搜索时出错: {e}")
# 继续抛出异常
+ raise
# 实例化检索服务,供外部调用
retrieval_service = RetrievalService()10.混合检索 #
10.1. rag_service.py #
app/services/rag_service.py
"""
RAG 服务
"""
# 导入日志模块
import logging
# 导入 LangChain 的对话提示模板模块
from langchain_core.prompts import ChatPromptTemplate
# 导入自定义 LLM 工厂
from app.utils.llm_factory import LLMFactory
# 导入设置服务
from app.services.settings_service import settings_service
# 导入类型提示
from typing import List
# 导入 LangChain 的 Document 类型
from langchain_core.documents import Document
# 导入检索服务
from app.services.retrieval_service import retrieval_service
# 设置日志对象
logger = logging.getLogger(__name__)
# 定义 RAGService 类
class RAGService:
"""RAG 服务"""
# 初始化函数
def __init__(self):
"""
初始化服务
Args:
settings: 设置字典,如果为 None 则从数据库读取
"""
# 从设置服务中获取配置信息
self.settings = settings_service.get()
# 定义默认系统消息提示词
default_rag_system_prompt = "你是一个专业的AI助手。请基于文档内容回答问题。"
# 定义默认查询提示词,包含 context 和 question 占位符
default_rag_query_prompt = """文档内容:
{context}
问题:{question}
请基于文档内容回答问题。如果文档中没有相关信息,请明确说明。"""
# 从设置中获取自定义系统消息提示词
rag_system_prompt_text = self.settings.get('rag_system_prompt')
# 如果没有设置,使用默认系统提示词
if not rag_system_prompt_text:
rag_system_prompt_text = default_rag_system_prompt
# 从设置中获取自定义查询提示词
rag_query_prompt_text = self.settings.get('rag_query_prompt')
# 如果没有设置,使用默认查询提示词
if not rag_query_prompt_text:
rag_query_prompt_text = default_rag_query_prompt
# 构建 RAG 的提示模板,包含系统消息和用户查询部分
self.rag_prompt = ChatPromptTemplate.from_messages([
("system", rag_system_prompt_text),
("human", rag_query_prompt_text)
])
# 定义流式问答接口
def ask_stream(self, kb_id: str, question: str):
"""
流式问答接口
Args:
kb_id: 知识库ID
question: 问题
Yields:
流式数据块
"""
# 创建带流式输出能力的 LLM 实例
llm = LLMFactory.create_llm(self.settings)
# 文档过滤后的结果
filtered_docs = self._retrieve_documents(kb_id, question)
# 发送流式开始信号
yield {
"type": "start",
"content": ""
}
# 构造用于传递给 LLM 的上下文字符串,将所有文档整合为字符串
context = "\n\n".join([
f"文档 {i+1} ({doc.metadata.get('doc_name', '未知')}):\n{doc.page_content}"
for i, doc in enumerate(filtered_docs)
])
# 创建 Rag Prompt 到 LLM 的处理链
chain = self.rag_prompt | llm
# 初始化完整答案的字符串
full_answer = ""
# 逐块流式生成答案
for chunk in chain.stream({"context": context, "question": question}):
# 获取当前输出块内容
content = chunk.content
# 如果有内容则累加并 yield 输出内容块
if content:
full_answer += content
yield {
"type": "content",
"content": content
}
# 所有内容输出结束后,发送完成信号和相关元数据
yield {
"type": "done",
"content": "",
"metadata": {
'kb_id': kb_id,
'question': question,
'retrieved_chunks': len(filtered_docs)
}
}
def _retrieve_documents(self, kb_id: str, question: str) -> List[Document]:
"""
检索文档(公共方法)
Args:
kb_id: 知识库ID
question: 查询文本
Returns:
文档列表
"""
collection_name = f"kb_{kb_id}"
retrieval_mode = self.settings.get('retrieval_mode', 'vector')
if retrieval_mode == 'vector':
docs = retrieval_service.vector_search(
collection_name=collection_name,
query=question
)
elif retrieval_mode == 'keyword':
docs = retrieval_service.keyword_search(
collection_name=collection_name,
query=question
+ )
+ elif retrieval_mode == 'hybrid':
+ docs = retrieval_service.hybrid_search(
+ collection_name=collection_name,
+ query=question
+ )
else:
logger.warning(f"未知的检索模式: {retrieval_mode}, 使用向量检索")
docs = retrieval_service.vector_search(
collection_name=collection_name,
query=question
)
logger.info(f"使用 {retrieval_mode} 模式检索到 {len(docs)} 个文档")
return docs
# 实例化 rag_service,供外部调用
rag_service = RAGService() 10.2. retrieval_service.py #
app/services/retrieval_service.py
"""
检索服务
支持向量检索、全文检索和混合检索
支持重排序功能
"""
# 导入日志模块
import logging
# 导入类型注解
from typing import List, Optional, Tuple
# 导入Document对象
from langchain_core.documents import Document
# 导入BM25Okapi
from rank_bm25 import BM25Okapi
# 导入jieba
import jieba
# 导入numpy
import numpy as np
# 导入设置信息服务
from app.services.settings_service import settings_service
# 导入向量数据库服务工厂方法
from app.services.vectordb.factory import get_vector_db_service
# 创建日志记录器
logger = logging.getLogger(__name__)
# 检索服务类
class RetrievalService:
"""检索服务"""
# 初始化方法
def __init__(self):
"""
初始化检索服务
Args:
settings: 设置字典,如果为 None 则从数据库读取
"""
# 从设置信息服务获取设置
self.settings = settings_service.get()
# 向量检索函数
def vector_search(self, collection_name: str, query: str) -> List[Tuple[Document, float]]:
"""
向量检索
Args:
collection_name: 集合名称
query: 查询文本
Returns:
(Document, similarity_score) 列表,按相似度降序排列
"""
# 获取向量数据库实例
vectordb = get_vector_db_service()
# 获取或创建指定集合的向量存储对象
vectorstore = vectordb.get_or_create_collection(collection_name)
# 从设置获取top_k,若不存在则默认为5
top_k = int(self.settings.get('top_k', '5'))
# 判断向量阈值是否已定义,若未定义则获取设定值或默认为0.2
vector_threshold = float(self.settings.get('vector_threshold', '0.2'))
# 限定向量阈值在0到1之间
vector_threshold = max(0.0, min(vector_threshold, 1.0))
# 以相似度得分方式检索,返回结果(扩大top_k数目以便后续过滤)
results = vectorstore.similarity_search_with_score(
query=query,
k=top_k*3
)
# 初始化文档以及得分的列表
docs_with_scores = []
# 遍历检索结果,将得分归一化并加入元数据
for doc, score in results:
# 计算归一化向量得分
vector_score = 1.0 / (1.0 + float(score))
# 存储向量得分到文档元数据
doc.metadata['vector_score'] = vector_score
# 标注检索类型为向量
doc.metadata['retrieval_type'] = 'vector'
# 加入列表
docs_with_scores.append((doc, vector_score))
# 按照相似度分数从高到低排序
docs_with_scores.sort(key=lambda x: x[1], reverse=True)
# 根据阈值过滤掉低于阈值的文档
filtered_docs = [(doc, score) for doc, score in docs_with_scores
if score >= vector_threshold]
# 仅保留top_k个文档用于返回
docs = [doc for doc, _ in filtered_docs[:top_k]]
# 日志打印检索到的文档个数
logger.info(f"向量搜索: 检索到 {len(docs)} 个文档")
# 返回结果
return docs
def _tokenize_chinese(self, text: str) -> List[str]:
"""
中文分词(使用 jieba)
Args:
text: 输入文本
Returns:
分词后的词列表
"""
# 使用 jieba 分词
words = jieba.lcut(text)
# 去除停用词和单字
stopwords = set(['的', '了', '在', '是', '和', '有', '与', '对', '等', '为', '也', '就', '都', '要', '可以', '会', '能', '而', '及', '与', '或'])
tokens = [word.strip() for word in words if len(word.strip()) > 1 and word.strip() not in stopwords]
return tokens
# 定义关键字检索方法,使用 BM25 算法进行匹配
def keyword_search(self, collection_name: str, query: str) -> List[Tuple[Document, float]]:
"""
全文检索(使用 BM25 算法进行关键词匹配)
Args:
collection_name: 集合名称
query: 查询文本
Returns:
(Document, keyword_score) 列表,按匹配分数降序排列
"""
try:
# 获取向量数据库服务实例
vectordb = get_vector_db_service()
# 获取或创建指定集合的向量存储对象
vectorstore = vectordb.get_or_create_collection(collection_name)
# 初始化用于存储所有文档的列表
all_docs = []
# 从底层集合获取所有内容
results = vectorstore._collection.get()
# 如果存在有效的检索结果且包含 'ids'
if results and 'ids' in results:
# 遍历所有文档 id
for i, _ in enumerate(results['ids']):
# 判断 'documents' 键存在且索引不越界
if 'documents' in results and i < len(results['documents']):
# 构建 Document 对象,提取内容和元数据
doc = Document(
page_content=results['documents'][i],
metadata=results.get('metadatas', [{}])[i] if 'metadatas' in results else {}
)
# 添加到所有文档列表
all_docs.append(doc)
# 提取所有文档的文本内容
documents = [doc.page_content for doc in all_docs]
# 对每个文档进行分词处理
tokenized_docs = [self._tokenize_chinese(doc) for doc in documents]
# 构建 BM25 索引
bm25 = BM25Okapi(tokenized_docs)
# 对查询语句进行中文分词
query_tokens = self._tokenize_chinese(query)
# 获取每个文档与查询的 BM25 分数
scores = bm25.get_scores(query_tokens)
# 计算分数的最大值,用于归一化分数到 [0, 1] 范围
max_score = float(np.max(scores)) if len(scores) > 0 and np.max(scores) > 0 else 1.0
# 归一化 BM25 分数
normalized_scores = scores / max_score if max_score > 0 else scores
# 获取关键字分数阈值,默认0.5
keyword_threshold = float(self.settings.get('keyword_threshold', '0.5'))
# 限定关键字阈值在0到1之间
keyword_threshold = max(0.0, min(keyword_threshold, 1.0))
# 获取返回文档数量 top_k,默认5
top_k = int(self.settings.get('top_k', '5'))
# 取分数最高的 top_k*3 个索引,便于后续过滤
top_indices = np.argsort(normalized_scores)[::-1][:top_k * 3]
# 初始化结果列表
docs_with_scores = []
# 遍历候选索引
for idx in top_indices:
# 取出每个文档归一化后的分数
normalized_score = float(normalized_scores[idx])
# 确保分数在 [0, 1] 之间
normalized_score = max(0.0, min(1.0, normalized_score))
# 仅保留分数高于阈值的文档
if normalized_score >= keyword_threshold:
# 取出对应文档
doc = all_docs[idx]
# 记录关键词得分到元数据
doc.metadata['keyword_score'] = normalized_score
# 标记检索类型
doc.metadata['retrieval_type'] = 'keyword'
# 添加到结果列表
docs_with_scores.append((doc, normalized_score))
# 按得分降序排序
docs_with_scores.sort(key=lambda x: x[1], reverse=True)
# 截取分数最高的前 top_k 个文档
docs = [doc for doc, _ in docs_with_scores[:top_k]]
# 记录 BM25 检索到的文档数量日志
logger.info(f"BM25 关键词搜索: 检索到 {len(docs)} 个文档")
# 返回检索结果
return docs
except Exception as e:
# 捕获异常并记录日志
logger.error(f"关键词搜索时出错: {e}")
# 继续抛出异常
raise
# 定义混合检索方法,使用RRF融合向量检索和全文检索
+ def hybrid_search(self, collection_name: str, query: str, rrf_k: int = 60) -> List[Tuple[Document, float]]:
+ """
+ 混合检索(使用 Reciprocal Rank Fusion (RRF) 融合向量检索和全文检索)
+ Args:
+ collection_name: 集合名称
+ query: 查询文本
+ rrf_k: RRF 常数,默认 60
+ Returns:
+ (Document, combined_score) 列表,按综合分数降序排列
+ """
+ try:
# 调用向量检索方法,得到向量检索结果
+ vector_results = self.vector_search(
+ collection_name=collection_name,
+ query=query,
+ )
# 调用关键词检索方法,得到关键词检索结果
+ keyword_results = self.keyword_search(
+ collection_name=collection_name,
+ query=query
+ )
# 创建字典用于存储文档及其排名信息
+ doc_ranks = {}
# 遍历向量检索结果,记录排名及分数
+ for rank, doc in enumerate(vector_results, start=1):
# 获取文档ID
+ doc_id = doc.metadata.get('id')
# 若该文档ID不在字典中,进行初始化
+ if doc_id not in doc_ranks:
+ doc_ranks[doc_id] = {'doc': doc}
# 记录向量排名
+ doc_ranks[doc_id]['vector_rank'] = rank
# 记录向量分数
+ doc_ranks[doc_id]['vector_score'] = doc.metadata.get('vector_score', 0.0)
# 遍历关键词检索结果,记录排名及分数
+ for rank, doc in enumerate(keyword_results, start=1):
# 获取文档ID
+ doc_id = doc.metadata.get('id')
# 若该文档ID不在字典中,进行初始化
+ if doc_id not in doc_ranks:
+ doc_ranks[doc_id] = {'doc': doc}
# 记录关键词排名
+ doc_ranks[doc_id]['keyword_rank'] = rank
# 记录关键词分数
+ doc_ranks[doc_id]['keyword_score'] = doc.metadata.get('keyword_score', 0.0)
# 从设置中读取向量权重,默认为0.3
+ vector_weight = float(self.settings.get('vector_weight', 0.3))
# 关键词权重等于1减去向量权重
+ keyword_weight = 1.0 - vector_weight
# 初始化混合结果列表
+ combined_results = []
# 遍历所有文档,计算RRF融合分数
+ for doc_id, ranks_info in doc_ranks.items():
# 获取向量排名(不存在默认0)
+ vector_rank = ranks_info.get('vector_rank', 0)
# 获取关键词排名(不存在默认0)
+ keyword_rank = ranks_info.get('keyword_rank', 0)
# 初始化RRF分数
+ rrf_score = 0.0
# 如果有向量排名,则累加向量RRF贡献
+ rrf_score += vector_weight / (rrf_k + vector_rank)
# 如果有关键词排名,则累加关键词RRF贡献
+ rrf_score += keyword_weight / (rrf_k + keyword_rank)
# 将RRF分数写入字典
+ doc_ranks[doc_id]['rrf_score'] = rrf_score
# 组装所有文档及其排名信息
+ combined_results = [(doc_id, rankInfo) for doc_id, rankInfo in doc_ranks.items()]
# 按RRF分数从高到低排序
+ combined_results.sort(key=lambda x: x[1].get('rrf_score', 0), reverse=True)
# 获取返回文档数量top_k,默认5
+ top_k = int(self.settings.get('top_k', 5))
# 遍历前top_k个文档,设置其元数据
+ docs = []
+ for doc_id, rankInfo in combined_results[:top_k]:
+ doc = rankInfo['doc']
# 更新向量分数到元数据
+ doc.metadata['vector_score'] = rankInfo.get('vector_score', 0.0)
# 更新关键词分数到元数据
+ doc.metadata['keyword_score'] = rankInfo.get('keyword_score', 0.0)
# 更新RRF分数到元数据
+ doc.metadata['rrf_score'] = rankInfo.get('rrf_score', 0.0)
# 标明检索类型为混合
+ doc.metadata['retrieval_type'] = 'hybrid'
# 添加到最终结果列表
+ docs.append(doc)
# 记录日志,输出检索到的文档数量
+ logger.info(f"混合搜索 (RRF): 检索到 {len(docs)} 个文档")
# 返回最终文档列表
+ return docs
+ except Exception as e:
# 捕获异常并打印日志
+ logger.error(f"混合搜索时出错: {e}")
# 继续抛出异常
+ raise
# 实例化检索服务,供外部调用
retrieval_service = RetrievalService()11.重排序 #
11.1. rerank_factory.py #
app/utils/rerank_factory.py
"""
Rerank 模型工厂
使用固定的 CrossEncoder 模型
"""
# 导入日志库
import logging
# 导入类型提示
from typing import Optional, List, Tuple
# 导入 Document 类
from langchain_core.documents import Document
# 导入 CrossEncoder 类
from sentence_transformers import CrossEncoder
# 创建日志记录器
logger = logging.getLogger(__name__)
# 定义 Rerank 工厂类
class RerankFactory:
"""Rerank 模型工厂"""
# 指定用于重排序的 CrossEncoder 模型名称
MODEL_NAME = 'cross-encoder/ms-marco-MiniLM-L-6-v2'
@staticmethod
def create_reranker(settings: Optional[dict] = None):
"""
创建 Rerank 模型(使用固定的 CrossEncoder 模型)
Args:
settings: 忽略(未使用)
Returns:
Reranker 对象(总是返回,因为总是启用)
"""
# 返回 LocalReranker 实例
return LocalReranker()
# 定义重排序基类
class BaseReranker:
"""Rerank 基类"""
# 定义 rerank 接口,子类需要实现
def rerank(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> List[Tuple[Document, float]]:
"""
对文档进行重排序
Args:
query: 查询文本
documents: 文档列表
top_k: 返回前N个结果,如果为None则返回所有
Returns:
(Document, score) 列表,按分数降序排列
"""
# 抛出未实现异常
raise NotImplementedError
# 本地 CrossEncoder 重排序实现类
class LocalReranker(BaseReranker):
"""本地 Rerank 模型(使用 CrossEncoder)"""
def __init__(self):
try:
# 实例化 CrossEncoder
self.reranker = CrossEncoder(RerankFactory.MODEL_NAME)
# 记录 CrossEncoder 初始化信息
logger.info(f"已创建 CrossEncoder 重排序器: {RerankFactory.MODEL_NAME}")
except ImportError:
# 没有安装依赖时抛出异常
raise ImportError("Please install sentence-transformers: pip install sentence-transformers")
except Exception as e:
# 其他初始化异常时记录日志并抛出异常
logger.error(f"初始化 CrossEncoder 重排序器时出错: {e}")
raise
def rerank(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> List[Tuple[Document, float]]:
"""使用 CrossEncoder 进行重排序"""
# 如果文档为空,直接返回空列表
if not documents:
return []
# 如果没有指定 top_k,则默认为文档总数
top_k = top_k or len(documents)
try:
# 构造输入对,每个文档和 query 组成一组
pairs = [[query, doc.page_content] for doc in documents]
# 用 CrossEncoder 模型计算每个对的相关性分数
scores = self.reranker.predict(pairs)
# 将分数转换为列表
scores = list(scores)
# 将分数转为 float 类型
scores_float = [float(score) for score in scores]
# 计算分数中的最小值
min_score = min(scores_float) if scores_float else 0.0
# 计算分数中的最大值
max_score = max(scores_float) if scores_float else 1.0
# 对分数做归一化(如果最大最小相等,则全为0)
normalized_scores = [(score - min_score) / (max_score - min_score) if max_score > min_score else 0.0
for score in scores_float]
# 构造 (Document, score) 的元组,并限制 score 在 [0, 1] 范围
doc_scores = [(doc, max(0.0, min(1.0, score)))
for doc, score in zip(documents, normalized_scores)]
# 按归一化后分数降序排序
doc_scores.sort(key=lambda x: x[1], reverse=True)
# 记录重排序日志
logger.info(f"CrossEncoder 重排序: 已重排序 {len(doc_scores)} 个文档")
# 返回分数最高的 top_k 个文档
return doc_scores[:top_k]
except Exception as e:
# 发生异常时记录日志,并返回默认分数
logger.error(f"CrossEncoder 重排序时出错: {e}")
import traceback
logger.error(traceback.format_exc())
# 所有文档均赋分 0.5 返回
return [(doc, 0.5) for doc in documents[:top_k]]
11.2. retrieval_service.py #
app/services/retrieval_service.py
"""
检索服务
支持向量检索、全文检索和混合检索
支持重排序功能
"""
# 导入日志模块
import logging
# 导入类型注解
+from typing import List, Tuple
# 导入Document对象
from langchain_core.documents import Document
# 导入BM25Okapi
from rank_bm25 import BM25Okapi
# 导入jieba
import jieba
# 导入numpy
import numpy as np
# 导入设置信息服务
from app.services.settings_service import settings_service
# 导入向量数据库服务工厂方法
from app.services.vectordb.factory import get_vector_db_service
# 导入重排序工厂
+from app.utils.rerank_factory import RerankFactory
# 创建日志记录器
logger = logging.getLogger(__name__)
# 检索服务类
class RetrievalService:
"""检索服务"""
# 初始化方法
def __init__(self):
"""
初始化检索服务
Args:
settings: 设置字典,如果为 None 则从数据库读取
"""
# 从设置信息服务获取设置
self.settings = settings_service.get()
# 初始化reranker
+ self.reranker = RerankFactory.create_reranker(self.settings)
# 向量检索函数
def vector_search(self, collection_name: str, query: str) -> List[Tuple[Document, float]]:
"""
向量检索
Args:
collection_name: 集合名称
query: 查询文本
Returns:
(Document, similarity_score) 列表,按相似度降序排列
"""
# 获取向量数据库实例
vectordb = get_vector_db_service()
# 获取或创建指定集合的向量存储对象
vectorstore = vectordb.get_or_create_collection(collection_name)
# 从设置获取top_k,若不存在则默认为5
top_k = int(self.settings.get('top_k', '5'))
# 判断向量阈值是否已定义,若未定义则获取设定值或默认为0.2
vector_threshold = float(self.settings.get('vector_threshold', '0.2'))
# 限定向量阈值在0到1之间
vector_threshold = max(0.0, min(vector_threshold, 1.0))
# 以相似度得分方式检索,返回结果(扩大top_k数目以便后续过滤)
results = vectorstore.similarity_search_with_score(
query=query,
k=top_k*3
)
# 初始化文档以及得分的列表
docs_with_scores = []
# 遍历检索结果,将得分归一化并加入元数据
for doc, score in results:
# 计算归一化向量得分
vector_score = 1.0 / (1.0 + float(score))
# 存储向量得分到文档元数据
doc.metadata['vector_score'] = vector_score
# 标注检索类型为向量
doc.metadata['retrieval_type'] = 'vector'
# 加入列表
docs_with_scores.append((doc, vector_score))
# 按照相似度分数从高到低排序
docs_with_scores.sort(key=lambda x: x[1], reverse=True)
# 根据阈值过滤掉低于阈值的文档
filtered_docs = [(doc, score) for doc, score in docs_with_scores
if score >= vector_threshold]
# 仅保留top_k个文档用于返回
docs = [doc for doc, _ in filtered_docs[:top_k]]
# 应用重排序
+ if self.reranker:
+ docs = self._apply_rerank(query, docs, top_k)
# 日志打印检索到的文档个数
logger.info(f"向量搜索: 检索到 {len(docs)} 个文档")
# 返回结果
return docs
def _tokenize_chinese(self, text: str) -> List[str]:
"""
中文分词(使用 jieba)
Args:
text: 输入文本
Returns:
分词后的词列表
"""
# 使用 jieba 分词
words = jieba.lcut(text)
# 去除停用词和单字
stopwords = set(['的', '了', '在', '是', '和', '有', '与', '对', '等', '为', '也', '就', '都', '要', '可以', '会', '能', '而', '及', '与', '或'])
tokens = [word.strip() for word in words if len(word.strip()) > 1 and word.strip() not in stopwords]
return tokens
# 定义关键字检索方法,使用 BM25 算法进行匹配
def keyword_search(self, collection_name: str, query: str) -> List[Tuple[Document, float]]:
"""
全文检索(使用 BM25 算法进行关键词匹配)
Args:
collection_name: 集合名称
query: 查询文本
Returns:
(Document, keyword_score) 列表,按匹配分数降序排列
"""
try:
# 获取向量数据库服务实例
vectordb = get_vector_db_service()
# 获取或创建指定集合的向量存储对象
vectorstore = vectordb.get_or_create_collection(collection_name)
# 初始化用于存储所有文档的列表
all_docs = []
# 从底层集合获取所有内容
results = vectorstore._collection.get()
# 如果存在有效的检索结果且包含 'ids'
if results and 'ids' in results:
# 遍历所有文档 id
for i, _ in enumerate(results['ids']):
# 判断 'documents' 键存在且索引不越界
if 'documents' in results and i < len(results['documents']):
# 构建 Document 对象,提取内容和元数据
doc = Document(
page_content=results['documents'][i],
metadata=results.get('metadatas', [{}])[i] if 'metadatas' in results else {}
)
# 添加到所有文档列表
all_docs.append(doc)
# 提取所有文档的文本内容
documents = [doc.page_content for doc in all_docs]
# 对每个文档进行分词处理
tokenized_docs = [self._tokenize_chinese(doc) for doc in documents]
# 构建 BM25 索引
bm25 = BM25Okapi(tokenized_docs)
# 对查询语句进行中文分词
query_tokens = self._tokenize_chinese(query)
# 获取每个文档与查询的 BM25 分数
scores = bm25.get_scores(query_tokens)
# 计算分数的最大值,用于归一化分数到 [0, 1] 范围
max_score = float(np.max(scores)) if len(scores) > 0 and np.max(scores) > 0 else 1.0
# 归一化 BM25 分数
normalized_scores = scores / max_score if max_score > 0 else scores
# 获取关键字分数阈值,默认0.5
keyword_threshold = float(self.settings.get('keyword_threshold', '0.5'))
# 限定关键字阈值在0到1之间
keyword_threshold = max(0.0, min(keyword_threshold, 1.0))
# 获取返回文档数量 top_k,默认5
top_k = int(self.settings.get('top_k', '5'))
# 取分数最高的 top_k*3 个索引,便于后续过滤
top_indices = np.argsort(normalized_scores)[::-1][:top_k * 3]
# 初始化结果列表
docs_with_scores = []
# 遍历候选索引
for idx in top_indices:
# 取出每个文档归一化后的分数
normalized_score = float(normalized_scores[idx])
# 确保分数在 [0, 1] 之间
normalized_score = max(0.0, min(1.0, normalized_score))
# 仅保留分数高于阈值的文档
if normalized_score >= keyword_threshold:
# 取出对应文档
doc = all_docs[idx]
# 记录关键词得分到元数据
doc.metadata['keyword_score'] = normalized_score
# 标记检索类型
doc.metadata['retrieval_type'] = 'keyword'
# 添加到结果列表
docs_with_scores.append((doc, normalized_score))
# 按得分降序排序
docs_with_scores.sort(key=lambda x: x[1], reverse=True)
# 截取分数最高的前 top_k 个文档
docs = [doc for doc, _ in docs_with_scores[:top_k]]
# 应用重排序
+ if self.reranker:
+ docs = self._apply_rerank(query, docs, top_k)
# 记录 BM25 检索到的文档数量日志
logger.info(f"BM25 关键词搜索: 检索到 {len(docs)} 个文档")
# 返回检索结果
return docs
except Exception as e:
# 捕获异常并记录日志
logger.error(f"关键词搜索时出错: {e}")
# 继续抛出异常
raise
# 定义混合检索方法,使用RRF融合向量检索和全文检索
def hybrid_search(self, collection_name: str, query: str, rrf_k: int = 60) -> List[Tuple[Document, float]]:
"""
混合检索(使用 Reciprocal Rank Fusion (RRF) 融合向量检索和全文检索)
Args:
collection_name: 集合名称
query: 查询文本
rrf_k: RRF 常数,默认 60
Returns:
(Document, combined_score) 列表,按综合分数降序排列
"""
try:
# 调用向量检索方法,得到向量检索结果
vector_results = self.vector_search(
collection_name=collection_name,
query=query,
)
# 调用关键词检索方法,得到关键词检索结果
keyword_results = self.keyword_search(
collection_name=collection_name,
query=query
)
# 创建字典用于存储文档及其排名信息
doc_ranks = {}
# 遍历向量检索结果,记录排名及分数
for rank, doc in enumerate(vector_results, start=1):
# 获取文档ID
doc_id = doc.metadata.get('id')
# 若该文档ID不在字典中,进行初始化
if doc_id not in doc_ranks:
doc_ranks[doc_id] = {'doc': doc}
# 记录向量排名
doc_ranks[doc_id]['vector_rank'] = rank
# 记录向量分数
doc_ranks[doc_id]['vector_score'] = doc.metadata.get('vector_score', 0.0)
# 遍历关键词检索结果,记录排名及分数
for rank, doc in enumerate(keyword_results, start=1):
# 获取文档ID
doc_id = doc.metadata.get('id')
# 若该文档ID不在字典中,进行初始化
if doc_id not in doc_ranks:
doc_ranks[doc_id] = {'doc': doc}
# 记录关键词排名
doc_ranks[doc_id]['keyword_rank'] = rank
# 记录关键词分数
doc_ranks[doc_id]['keyword_score'] = doc.metadata.get('keyword_score', 0.0)
# 从设置中读取向量权重,默认为0.3
vector_weight = float(self.settings.get('vector_weight', 0.3))
# 关键词权重等于1减去向量权重
keyword_weight = 1.0 - vector_weight
# 初始化混合结果列表
combined_results = []
# 遍历所有文档,计算RRF融合分数
for doc_id, ranks_info in doc_ranks.items():
# 获取向量排名(不存在默认0)
vector_rank = ranks_info.get('vector_rank', 0)
# 获取关键词排名(不存在默认0)
keyword_rank = ranks_info.get('keyword_rank', 0)
# 初始化RRF分数
rrf_score = 0.0
# 如果有向量排名,则累加向量RRF贡献
rrf_score += vector_weight / (rrf_k + vector_rank)
# 如果有关键词排名,则累加关键词RRF贡献
rrf_score += keyword_weight / (rrf_k + keyword_rank)
# 将RRF分数写入字典
doc_ranks[doc_id]['rrf_score'] = rrf_score
# 组装所有文档及其排名信息
combined_results = [(doc_id, rankInfo) for doc_id, rankInfo in doc_ranks.items()]
# 按RRF分数从高到低排序
combined_results.sort(key=lambda x: x[1].get('rrf_score', 0), reverse=True)
# 获取返回文档数量top_k,默认5
top_k = int(self.settings.get('top_k', 5))
# 遍历前top_k个文档,设置其元数据
docs = []
for doc_id, rankInfo in combined_results[:top_k]:
doc = rankInfo['doc']
# 更新向量分数到元数据
doc.metadata['vector_score'] = rankInfo.get('vector_score', 0.0)
# 更新关键词分数到元数据
doc.metadata['keyword_score'] = rankInfo.get('keyword_score', 0.0)
# 更新RRF分数到元数据
doc.metadata['rrf_score'] = rankInfo.get('rrf_score', 0.0)
# 标明检索类型为混合
doc.metadata['retrieval_type'] = 'hybrid'
# 添加到最终结果列表
docs.append(doc)
# 应用重排序
+ if self.reranker:
+ docs = self._apply_rerank(query, docs, top_k)
# 记录日志,输出检索到的文档数量
logger.info(f"混合搜索 (RRF): 检索到 {len(docs)} 个文档")
# 返回最终文档列表
return docs
except Exception as e:
# 捕获异常并打印日志
logger.error(f"混合搜索时出错: {e}")
# 继续抛出异常
raise
# 定义应用重排序的方法,参数包括查询query,带有分数的文档列表,返回前top_k结果
+ def _apply_rerank(self, query: str, docs: List[Document], top_k: int) -> List[Document]:
+ """
+ 应用重排序
+ Args:
+ query: 查询文本
+ docs: Document 列表
+ top_k: 返回前k个结果
+ Returns:
+ 重排序后的 (Document, score) 列表
+ """
# 如果没有重排序器或者输入为空,直接返回原始结果
+ if not self.reranker or not docs:
+ return docs
+ try:
# 使用重排序器对文档进行重排序,返回带有新分数的文档列表
+ reranked = self.reranker.rerank(query, docs, top_k=top_k)
# 遍历重排序后的结果,更新文档元数据
+ for doc, rerank_score in reranked:
# 新增 rerank_score 到元数据
+ doc.metadata['rerank_score'] = rerank_score
# 标记检索类型为原有类型+rerank,若没有原有类型则标记为unknown
+ doc.metadata['retrieval_type'] = doc.metadata.get('retrieval_type', 'unknown')
# 打印重排序成功的日志
+ logger.info(f"已应用重排序: {len(reranked)} 个文档已重排序")
# 返回带有重排序分数的文档列表
+ return [doc for doc, _ in reranked]
+ except Exception as e:
# 如果重排序过程中发生错误,打印日志,并返回原始结果
+ logger.error(f"应用重排序时出错: {e}")
# 如果rerank失败,返回原始结果
+ return docs
# 实例化检索服务,供外部调用
retrieval_service = RetrievalService()12.显示来源 #
12.1. chat.py #
app/blueprints/chat.py
# 聊天相关路由(视图 + API)
"""
聊天相关路由(视图 + API)
"""
# 导入 Flask 的 Blueprint 和模板渲染函数
from flask import Blueprint, render_template, request, stream_with_context, Response
import json
# 导入日志模块
import logging
# 导入登录保护装饰器和获取当前用户辅助方法
from app.utils.auth import login_required, api_login_required,get_current_user
# 导入知识库服务
from app.services.knowledgebase_service import kb_service
# 导入自定义工具函数:成功响应、错误响应、获取分页参数、获取当前用户或错误、异常处理装饰器、检查所有权
from app.blueprints.utils import (
success_response, error_response,
get_current_user_or_error, handle_api_error, get_pagination_params,check_ownership
)
from app.services.chat_service import chat_service
from app.services.chat_session_service import session_service
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 创建名为 'chat' 的蓝图对象
bp = Blueprint('chat', __name__)
# 注册 /chat 路由,访问该路由需要先登录
@bp.route('/chat')
@login_required
def chat_view():
# 智能问答页面视图函数
"""智能问答页面"""
current_user = get_current_user()
# 获取所有知识库(通常用户不会有太多知识库,不需要分页)
result = kb_service.list(user_id=current_user['id'], page=1, page_size=1000)
# 渲染 chat.html 模板并传递空知识库列表
return render_template('chat.html', knowledgebases=result['items'])
# 注册 API 路由,处理聊天接口 POST 请求
@bp.route('/api/v1/knowledgebases/chat', methods=['POST'])
@api_login_required
@handle_api_error
def api_chat():
# 普通聊天接口(不支持知识库,支持流式输出)
"""普通聊天接口(不支持知识库,支持流式输出)"""
# 获取当前用户和错误信息
current_user, err = get_current_user_or_error()
# 如果有错误,直接返回错误响应
if err:
return err
# 从请求体获取 JSON 数据
data = request.get_json()
# 如果数据为空或不存在 'question' 字段,返回错误
if not data or 'question' not in data:
return error_response("question is required", 400)
# 去除问题文本首尾空格
question = data['question'].strip()
# 如果问题内容为空,返回错误
if not question:
return error_response("question cannot be empty", 400)
session_id = data.get('session_id') # 会话ID(可以为空,表示普通聊天)
# 获取 max_tokens 参数,默认 1000
max_tokens = int(data.get('max_tokens', 1000))
# 限制最大和最小值在 1~10000 之间
max_tokens = max(1, min(max_tokens, 10000)) # 限制在 1-10000 之间
# 从请求数据中获取'stream'字段,默认为True,表示启用流式输出
stream = data.get('stream', True) # 默认启用流式输出
# 初始化历史消息为None
history = None
# 如果请求中带有session_id,说明有现有会话
if session_id:
# 根据session_id和当前用户ID获取历史消息列表
history_messages = session_service.get_messages(session_id, current_user['id'])
# 将历史消息转换为对话格式,仅保留最近10条
history = [
{'role': msg.get('role'), 'content': msg.get('content')}
for msg in history_messages[-10:] # 只取最近10条
]
# 如果请求中没有session_id,说明是新对话,需要新建会话
if not session_id:
# 创建新会话,kb_id设为None表示普通聊天
chat_session = session_service.create_session(
user_id=current_user['id']
)
# 使用新创建会话的ID作为本次会话ID
session_id = chat_session['id']
# 将用户的问题消息保存到当前会话中
session_service.add_message(session_id, 'user', question)
# 声明用于流式输出的生成器
@stream_with_context
def generate():
try:
# 用于缓存完整答案内容
full_answer = ''
# 调用服务进行流式对话
for chunk in chat_service.chat_stream(
question=question,
temperature=None, # 使用设置中的值
max_tokens=max_tokens,
history=history
):
# 如果是内容块,则拼接内容到 full_answer
if chunk.get('type') == 'content':
full_answer += chunk.get('content', '')
# 以 SSE 协议格式输出数据
yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
# 输出对话完成信号
yield "data: [DONE]\n\n"
# 保存助手回复
if full_answer:
session_service.add_message(session_id, 'assistant', full_answer)
except Exception as e:
# 发生异常记录日志
logger.error(f"流式输出时出错: {e}")
# 构造错误数据块
error_chunk = {
"type": "error",
"content": str(e)
}
# 输出错误数据块
yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n"
# 创建 Response 对象,设置必要的 SSE 响应头部
response = Response(
generate(),
mimetype='text/event-stream',
headers={
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
'X-Accel-Buffering': 'no',
'Content-Type': 'text/event-stream; charset=utf-8'
}
)
# 返回响应
return response
# 路由装饰器,定义 GET 方法获取会话列表的接口
@bp.route('/api/v1/knowledgebases/sessions', methods=['GET'])
# API 登录校验装饰器,确保用户已登录
@api_login_required
# 错误处理装饰器,统一处理接口异常
@handle_api_error
def api_list_sessions():
# 接口描述:获取当前用户的会话列表
"""获取当前用户的会话列表"""
# 获取当前用户,如有错误直接返回错误响应
current_user, err = get_current_user_or_error()
if err:
return err
# 获取分页参数(页码和每页数量),最大单页1000
page, page_size = get_pagination_params(max_page_size=1000)
# 调用会话服务获取当前用户的会话列表
result = session_service.list_sessions(current_user['id'], page=page, page_size=page_size)
# 以统一成功响应格式返回会话列表
return success_response(result)
# 路由装饰器,定义 POST 方法创建会话的接口
@bp.route('/api/v1/knowledgebases/sessions', methods=['POST'])
@api_login_required
@handle_api_error
def api_create_session():
# 接口描述:创建新的聊天会话
"""创建新的聊天会话"""
# 获取当前用户,如果有错误直接返回
current_user, err = get_current_user_or_error()
if err:
return err
# 获取请求体中的 JSON 数据,若无返回空字典
data = request.get_json() or {}
# 获取会话标题
title = data.get('title')
# 调用服务创建会话,传入当前用户ID、知识库ID与标题
session_obj = session_service.create_session(
user_id=current_user['id'],
title=title
)
# 返回成功响应及会话对象
return success_response(session_obj)
# 路由装饰器,定义 GET 方法获取单个会话详情的接口(带 session_id)
@bp.route('/api/v1/knowledgebases/sessions/<session_id>', methods=['GET'])
@api_login_required
@handle_api_error
def api_get_session(session_id):
# 接口描述:获取会话详情和消息
"""获取会话详情和消息"""
# 获取当前用户,如有错误直接返回
current_user, err = get_current_user_or_error()
if err:
return err
# 根据 session_id 获取会话对象,校验所属当前用户
session_obj = session_service.get_session_by_id(session_id, current_user['id'])
# 如果没有找到会话,返回 404 错误
if not session_obj:
return error_response("Session not found", 404)
# 获取该会话下的所有消息
messages = session_service.get_messages(session_id, current_user['id'])
# 返回会话详情及消息列表
return success_response({
'session': session_obj,
'messages': messages
})
# 路由装饰器,定义 DELETE 方法删除单个会话接口
@bp.route('/api/v1/knowledgebases/sessions/<session_id>', methods=['DELETE'])
@api_login_required
@handle_api_error
def api_delete_session(session_id):
# 接口描述:删除会话
"""删除会话"""
# 获取当前用户,如有错误直接返回
current_user, err = get_current_user_or_error()
if err:
return err
# 调用服务删除会话,校验归属当前用户
success = session_service.delete_session(session_id, current_user['id'])
# 若删除成功,返回成功响应,否则返回 404
if success:
return success_response(None, "Session deleted")
else:
return error_response("Session not found", 404)
# 路由装饰器,定义 DELETE 方法清空所有会话的接口
@bp.route('/api/v1/knowledgebases/sessions', methods=['DELETE'])
@api_login_required
@handle_api_error
def api_delete_all_sessions():
# 接口描述:清空所有会话
"""清空所有会话"""
# 获取当前用户,如果有错误直接返回
current_user, err = get_current_user_or_error()
if err:
return err
# 调用服务删除所有属于当前用户的会话,返回删除数量
count = session_service.delete_all_sessions(current_user['id'])
# 返回成功响应及被删除会话数
return success_response({'deleted_count': count}, f"Deleted {count} sessions")
# 路由装饰器,指定POST方法用于知识库问答接口
@bp.route('/api/v1/knowledgebases/<kb_id>/chat', methods=['POST'])
# 装饰器:需要API登录
@api_login_required
# 装饰器:统一处理API错误
@handle_api_error
def api_ask(kb_id):
# 知识库问答接口(支持流式输出)
"""知识库问答接口(支持流式输出)"""
# 获取当前用户和错误信息
current_user, err = get_current_user_or_error()
# 如果获取用户出错,直接返回错误
if err:
return err
# 获取指定id的知识库
kb = kb_service.get_by_id(kb_id)
# 检查当前用户是否有权限访问该知识库
has_permission, err = check_ownership(kb['user_id'], current_user['id'], "knowledgebase")
# 如果没有权限,直接返回错误
if not has_permission:
return err
# 获取请求中的JSON数据
data = request.get_json()
# 获取并去除问题字符串首尾空白
question = data['question'].strip()
# 从请求数据获取session_id,如果没有则为None
session_id = data.get('session_id') # 会话ID
# 获取最大token数,默认为1000
max_tokens = int(data.get('max_tokens', 1000))
# 限制max_tokens在1到10000之间
max_tokens = max(1, min(max_tokens, 10000)) # 限制在 1-10000 之间
# 如果没有提供session_id,则为用户和知识库创建一个新会话
if not session_id:
chat_session = session_service.create_session(
user_id=current_user['id'],
kb_id=kb_id
)
# 获取新会话的会话ID
session_id = chat_session['id']
# 保存用户输入的问题到消息列表
session_service.add_message(session_id, 'user', question)
# 内部函数:生成流式响应内容
@stream_with_context
def generate():
try:
# 初始化完整回复内容
full_answer = ''
# 初始化引用信息
sources = None
# 迭代chat_service.ask_stream的每个数据块
for chunk in chat_service.ask_stream(
kb_id=kb_id,
question=question
):
# 如果块类型为内容,则将内容追加到full_answer
if chunk.get('type') == 'content':
full_answer += chunk.get('content', '')
# 如果块类型为done,则获取sources
elif chunk.get('type') == 'done':
sources = chunk.get('sources')
# 以SSE格式输出该块内容
yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
# 所有内容输出后发送结束标志
yield "data: [DONE]\n\n"
# 如果有回复内容,则保存机器人助手的回复和引用
if full_answer:
+ session_service.add_message(session_id, 'assistant', full_answer, sources)
except Exception as e:
# 如果流式输出出错,在日志中记录错误信息
logger.error(f"流式输出时出错: {e}")
# 构造错误信息块
error_chunk = {
"type": "error",
"content": str(e)
}
# 以SSE格式输出错误信息
yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n"
# 构造SSE(服务端事件)响应对象,携带合适的头部信息
response = Response(
generate(),
mimetype='text/event-stream',
headers={
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
'X-Accel-Buffering': 'no',
'Content-Type': 'text/event-stream; charset=utf-8'
}
)
# 返回响应对象
return response
12.2. chat_session_service.py #
app/services/chat_session_service.py
# 导入倒序排序工具
from sqlalchemy import desc
# 导入日期时间
from datetime import datetime
# 导入聊天会话ORM模型
from app.models.chat_session import ChatSession
# 导入聊天消息ORM模型
from app.models.chat_message import ChatMessage
# 导入基础服务类
from app.services.base_service import BaseService
# 聊天会话服务类,继承自基础服务
class ChatSessionService(BaseService[ChatSession]):
# 聊天会话服务说明文档
"""聊天会话服务"""
# 创建新的聊天会话
def create_session(self, user_id: str, kb_id: str = None, title: str = None) -> dict:
"""
创建新的聊天会话
Args:
user_id: 用户ID
kb_id: 知识库ID(可选)
title: 会话标题(可选,如果不提供则使用默认标题)
Returns:
会话信息字典
"""
# 启动数据库事务
with self.transaction() as session:
# 如果没有传标题则用默认标题
if not title:
title = "新对话"
# 构造会话对象
chat_session = ChatSession(
user_id=user_id,
title=title
)
# 新会话入库
session.add(chat_session)
# 刷新以拿到自增ID
session.flush()
# 刷新会话对象,便于获取ID等字段
session.refresh(chat_session)
# 记录日志
self.logger.info(f"已创建聊天会话: {chat_session.id}, 用户: {user_id}")
# 返回会话字典格式
return chat_session.to_dict()
# 根据ID获取会话
def get_session_by_id(self, session_id: str, user_id: str = None) -> dict:
"""
根据ID获取会话
Args:
session_id: 会话ID
user_id: 用户ID(可选,用于验证权限)
Returns:
会话信息字典,如果不存在或无权访问则返回 None
"""
# 打开数据库只读session
with self.session() as session:
# 查询指定ID的会话
query = session.query(ChatSession).filter_by(id=session_id)
# 如果提供了user_id则额外限定归属
if user_id:
query = query.filter_by(user_id=user_id)
# 拿到第一个会话记录
chat_session = query.first()
# 有则返回字典信息,没有返回None
if chat_session:
return chat_session.to_dict()
return None
# 获取用户的所有会话列表(分页)
def list_sessions(self, user_id: str, page: int = 1, page_size: int = 100) -> dict:
"""
获取用户的会话列表
Args:
user_id: 用户ID
page: 页码
page_size: 每页数量
Returns:
包含总数和会话列表的字典
"""
# 打开数据库只读session
with self.session() as session:
# 查询当前用户的所有会话
query = session.query(ChatSession).filter_by(user_id=user_id)
# 用基类的分页方法返回结构化内容,按更新时间倒序
return self.paginate_query(query, page=page, page_size=page_size,
order_by=desc(ChatSession.updated_at))
# 会话标题修改
def update_session_title(self, session_id: str, user_id: str, title: str) -> dict:
"""
更新会话标题
Args:
session_id: 会话ID
user_id: 用户ID(用于验证权限)
title: 新标题
Returns:
更新后的会话信息字典
"""
# 启动数据库事务
with self.transaction() as session:
# 查询拥有该会话的用户
chat_session = session.query(ChatSession).filter_by(id=session_id, user_id=user_id).first()
# 会话不存在则抛异常
if not chat_session:
raise ValueError("Session not found or access denied")
# 更新标题
chat_session.title = title
# 刷新会话对象(update不会提交,refresh刷新对象)
session.refresh(chat_session)
# 返回最新数据
return chat_session.to_dict()
# 删除指定会话
def delete_session(self, session_id: str, user_id: str) -> bool:
"""
删除会话(级联删除消息)
Args:
session_id: 会话ID
user_id: 用户ID(用于验证权限)
Returns:
是否删除成功
"""
# 开启数据库事务
with self.transaction() as session:
# 查询该用户的指定会话
chat_session = session.query(ChatSession).filter_by(id=session_id, user_id=user_id).first()
# 未找到会话则返回False
if not chat_session:
return False
# 删除会话(DB应有级联消息)
session.delete(chat_session)
# 记录日志
self.logger.info(f"已删除聊天会话: {session_id}")
# 返回True表示删除成功
return True
# 删除当前用户的所有会话
def delete_all_sessions(self, user_id: str) -> int:
"""
删除用户的所有会话
Args:
user_id: 用户ID
Returns:
删除的会话数量
"""
# 开启数据库事务
with self.transaction() as session:
# 批量删除本用户所有会话
count = session.query(ChatSession).filter_by(user_id=user_id).delete()
# 记录日志
self.logger.info(f"已删除用户 {user_id} 的 {count} 个聊天会话")
# 返回删除数量
return count
# 添加消息到会话
+ def add_message(self, session_id: str, role: str, content: str, sources: list = None) -> dict:
"""
添加消息到会话
Args:
session_id: 会话ID
role: 角色('user' 或 'assistant')
content: 消息内容
sources: 引用来源列表(可选)
Returns:
消息信息字典
"""
# 开启数据库事务
with self.transaction() as session:
# 构造消息对象
message = ChatMessage(
session_id=session_id,
role=role,
+ content=content,
+ sources=sources
)
# 添加消息到数据库
session.add(message)
# 查询会话对象,用于更新时间/自动生成标题
chat_session = session.query(ChatSession).filter_by(id=session_id).first()
# 如果存在会话对象
if chat_session:
# 更新会话更新时间
chat_session.updated_at = datetime.now()
# 如果是用户发的第一条消息,并且还没标题,则用内容自动命名
if role == 'user' and (not chat_session.title or chat_session.title == "新对话"):
# 会话标题截取前30字符,超长加省略号
title = content[:30] + ('...' if len(content) > 30 else '')
chat_session.title = title
# 刷新确保message有ID
session.flush()
# 刷新消息对象
session.refresh(message)
# 返回消息字典
return message.to_dict()
# 获取会话的全部消息
def get_messages(self, session_id: str, user_id: str = None) -> list:
"""
获取会话的所有消息
Args:
session_id: 会话ID
user_id: 用户ID(可选,用于验证权限)
Returns:
消息列表
"""
# 打开只读session
with self.session() as session:
# 如指定user_id,须先验证此会话是否属于该用户,不属则不给查
if user_id:
chat_session = session.query(ChatSession).filter_by(id=session_id, user_id=user_id).first()
# 如果会话不存在则返回空列表
if not chat_session:
return []
# 查询该会话下所有消息,按创建时间升序排序
messages = session.query(ChatMessage).filter_by(session_id=session_id).order_by(ChatMessage.created_at).all()
# 返回所有消息的字典列表
return [m.to_dict() for m in messages]
# 单例: 聊天会话服务对象
session_service = ChatSessionService()
12.3. rag_service.py #
app/services/rag_service.py
"""
RAG 服务
"""
# 导入日志模块
import logging
# 导入 LangChain 的对话提示模板模块
from langchain_core.prompts import ChatPromptTemplate
# 导入自定义 LLM 工厂
from app.utils.llm_factory import LLMFactory
# 导入设置服务
from app.services.settings_service import settings_service
# 导入类型提示
from typing import List
# 导入 LangChain 的 Document 类型
from langchain_core.documents import Document
# 导入检索服务
from app.services.retrieval_service import retrieval_service
# 设置日志对象
logger = logging.getLogger(__name__)
# 定义 RAGService 类
class RAGService:
"""RAG 服务"""
# 初始化函数
def __init__(self):
"""
初始化服务
Args:
settings: 设置字典,如果为 None 则从数据库读取
"""
# 从设置服务中获取配置信息
self.settings = settings_service.get()
# 定义默认系统消息提示词
default_rag_system_prompt = "你是一个专业的AI助手。请基于文档内容回答问题。"
# 定义默认查询提示词,包含 context 和 question 占位符
default_rag_query_prompt = """文档内容:
{context}
问题:{question}
请基于文档内容回答问题。如果文档中没有相关信息,请明确说明。"""
# 从设置中获取自定义系统消息提示词
rag_system_prompt_text = self.settings.get('rag_system_prompt')
# 如果没有设置,使用默认系统提示词
if not rag_system_prompt_text:
rag_system_prompt_text = default_rag_system_prompt
# 从设置中获取自定义查询提示词
rag_query_prompt_text = self.settings.get('rag_query_prompt')
# 如果没有设置,使用默认查询提示词
if not rag_query_prompt_text:
rag_query_prompt_text = default_rag_query_prompt
# 构建 RAG 的提示模板,包含系统消息和用户查询部分
self.rag_prompt = ChatPromptTemplate.from_messages([
("system", rag_system_prompt_text),
("human", rag_query_prompt_text)
])
# 定义流式问答接口
def ask_stream(self, kb_id: str, question: str):
"""
流式问答接口
Args:
kb_id: 知识库ID
question: 问题
Yields:
流式数据块
"""
# 创建带流式输出能力的 LLM 实例
llm = LLMFactory.create_llm(self.settings)
# 文档过滤后的结果
filtered_docs = self._retrieve_documents(kb_id, question)
# 发送流式开始信号
yield {
"type": "start",
"content": ""
}
# 构造用于传递给 LLM 的上下文字符串,将所有文档整合为字符串
context = "\n\n".join([
f"文档 {i+1} ({doc.metadata.get('doc_name', '未知')}):\n{doc.page_content}"
for i, doc in enumerate(filtered_docs)
])
# 创建 Rag Prompt 到 LLM 的处理链
chain = self.rag_prompt | llm
# 初始化完整答案的字符串
full_answer = ""
# 逐块流式生成答案
for chunk in chain.stream({"context": context, "question": question}):
# 获取当前输出块内容
content = chunk.content
# 如果有内容则累加并 yield 输出内容块
if content:
full_answer += content
yield {
"type": "content",
"content": content
}
+ sources = self._extract_citations(filtered_docs)
# 所有内容输出结束后,发送完成信号和相关元数据
yield {
"type": "done",
"content": "",
+ "sources": sources,
"metadata": {
'kb_id': kb_id,
'question': question,
+ 'retrieved_chunks': len(filtered_docs),
+ 'used_chunks': len(sources)
}
}
+ def _extract_citations(self, docs: List[Document]):
+ """
+ 提取引用信息
+ Args:
+ docs: 检索到的 Document 列表(应该已经按相似度排序)
+ Returns:
+ 引用列表(按相似度从高到低排序)
+ """
+ sources = []
+ for doc in docs:
+ metadata = doc.metadata
# 获取各种分数
+ vector_score = metadata.get('vector_score')
+ keyword_score = metadata.get('keyword_score')
+ rrf_score = metadata.get('rrf_score')
+ rerank_score = metadata.get('rerank_score')
+ doc_name = metadata.get('doc_name', '未知文档')
+ sources.append({
+ 'chunk_id': metadata.get('id') ,
+ 'doc_id': metadata.get('doc_id', ''),
+ 'doc_name': doc_name,
+ 'content': doc.page_content,
+ 'vector_score': round(float(vector_score), 4) if vector_score is not None else None,
+ 'keyword_score': round(float(keyword_score), 4) if keyword_score is not None else None,
+ 'rrf_score': round(float(rrf_score), 4) if rrf_score is not None else None,
+ 'rerank_score': round(float(rerank_score), 4) if rerank_score is not None else None,
+ 'retrieval_type': metadata.get('retrieval_type', 'unknown')
+ })
+ return sources
def _retrieve_documents(self, kb_id: str, question: str) -> List[Document]:
"""
检索文档(公共方法)
Args:
kb_id: 知识库ID
question: 查询文本
Returns:
文档列表
"""
collection_name = f"kb_{kb_id}"
retrieval_mode = self.settings.get('retrieval_mode', 'vector')
if retrieval_mode == 'vector':
docs = retrieval_service.vector_search(
collection_name=collection_name,
query=question
)
elif retrieval_mode == 'keyword':
docs = retrieval_service.keyword_search(
collection_name=collection_name,
query=question
)
elif retrieval_mode == 'hybrid':
docs = retrieval_service.hybrid_search(
collection_name=collection_name,
query=question
)
else:
logger.warning(f"未知的检索模式: {retrieval_mode}, 使用向量检索")
docs = retrieval_service.vector_search(
collection_name=collection_name,
query=question
)
logger.info(f"使用 {retrieval_mode} 模式检索到 {len(docs)} 个文档")
return docs
# 实例化 rag_service,供外部调用
rag_service = RAGService() 12.4. chat.html #
app/templates/chat.html
{% extends "base.html" %}
{% block title %}智能问答 - RAG Lite{% endblock %}
{% block extra_css %}
<style>
.chat-container {
height: calc(100vh - 200px);
display: flex;
gap: 1rem;
}
.chat-sidebar {
width: 280px;
background: #fff;
border-radius: 0.5rem;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
display: flex;
flex-direction: column;
}
.chat-main {
flex: 1;
background: #fff;
border-radius: 0.5rem;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
display: flex;
flex-direction: column;
}
.session-list {
flex: 1;
overflow-y: auto;
padding: 0.5rem;
}
.session-item {
padding: 0.75rem;
margin-bottom: 0.5rem;
border-radius: 0.5rem;
cursor: pointer;
transition: background-color 0.2s;
position: relative;
}
.session-item:hover {
background-color: #f8f9fa;
}
.session-item.active {
background-color: #e3f2fd;
border-left: 3px solid #0d6efd;
}
.session-item .session-title {
font-weight: 500;
margin-bottom: 0.25rem;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.session-item .session-time {
font-size: 0.75rem;
color: #6c757d;
}
.session-item .session-delete {
position: absolute;
top: 0.5rem;
right: 0.5rem;
opacity: 0;
transition: opacity 0.2s;
}
.session-item:hover .session-delete {
opacity: 1;
}
.chat-messages {
flex: 1;
overflow-y: auto;
padding: 1rem;
scroll-behavior: smooth;
}
.chat-message {
padding: 1rem;
margin-bottom: 1rem;
border-radius: 0.5rem;
}
.chat-question {
background-color: #e3f2fd;
}
.chat-answer {
background-color: #f5f5f5;
}
.chat-input-area {
padding: 1rem;
border-top: 1px solid #dee2e6;
}
.empty-state {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
height: 100%;
color: #6c757d;
}
</style>
{% endblock %}
{% block content %}
<div class="chat-container">
<!-- 左侧:会话管理 -->
<div class="chat-sidebar">
<div class="p-3 border-bottom">
<div class="d-flex justify-content-between align-items-center mb-3">
<h6 class="mb-0"><i class="bi bi-chat-left-text"></i> 聊天会话</h6>
<button class="btn btn-sm btn-primary" onclick="createNewSession()">
<i class="bi bi-plus"></i> 新建
</button>
</div>
<button class="btn btn-sm btn-outline-danger w-100" onclick="clearAllSessions()">
<i class="bi bi-trash"></i> 清空所有
</button>
</div>
<div class="session-list" id="sessionList">
<div class="text-center text-muted py-5">
<i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
<p class="mt-2 small">暂无会话</p>
</div>
</div>
</div>
<!-- 右侧:对话页面 -->
<div class="chat-main">
<div class="p-3 border-bottom">
<div class="d-flex justify-content-between align-items-center">
<h6 class="mb-0"><i class="bi bi-chat-dots"></i> 对话</h6>
<select class="form-select form-select-sm" id="kbSelect" style="width: 200px;" onchange="onKbChange()">
<option value="">-- 选择知识库 --</option>
{% for kb in knowledgebases %}
<option value="{{ kb.id }}">{{ kb.name }}</option>
{% endfor %}
</select>
</div>
</div>
<div class="chat-messages" id="chatMessages">
<div class="empty-state">
<i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
<p class="mt-2">开始提问吧!</p>
</div>
</div>
<div class="chat-input-area">
<form id="chatForm" onsubmit="askQuestion(event)">
<div class="mb-2">
<textarea class="form-control" id="questionInput" rows="2"
placeholder="输入您的问题..." required></textarea>
</div>
<div class="d-flex justify-content-end">
<button type="submit" class="btn btn-primary" id="submitBtn">
<i class="bi bi-send"></i> 发送
</button>
</div>
</form>
</div>
</div>
</div>
{% endblock %}
{% block extra_js %}
<script>
// 当前会话的ID,初始为null,表示暂未选择任何会话
let currentSessionId = null;
// 会话列表,初始为空数组,用于存储所有的会话对象
let sessions = [];
// 当前知识库的ID,初始为null,表示暂未选择任何知识库
let currentKbId = null;
// 为字符串进行HTML转义,防止XSS攻击
function escapeHtml(text) {
// 创建一个div元素作为容器
const div = document.createElement('div');
// 将待转义文本设置为div的textContent(自动完成转义)
div.textContent = text;
// 返回转义后的HTML内容
return div.innerHTML;
}
// 滚动消息框到底部
function scrollToBottom() {
// 获取聊天消息区域元素
const chatMessages = document.getElementById('chatMessages');
// 设置滚动条位置到最底部
chatMessages.scrollTop = chatMessages.scrollHeight;
}
// 定义一个用于渲染Markdown文本为HTML的函数
function renderMarkdown(text) {
// 判断marked库是否已加载且能使用解析方法
if (typeof marked !== 'undefined' && marked.parse) {
try {
// 尝试用marked库将Markdown文本解析成HTML
return marked.parse(text);
} catch (e) {
// 如果解析出错,则进行HTML转义并换行
return escapeHtml(text).replace(/\n/g, '<br>');
}
}
// 如果marked库不可用,直接做HTML转义并处理换行
return escapeHtml(text).replace(/\n/g, '<br>');
}
// 将Markdown内容渲染到指定元素(支持降级为普通文本)
function renderMarkdownToElement(element, text) {
// 若未提供内容,显示思考中图标
if (!text) {
element.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';
return;
}
// 优先判断marked库是否可用(渲染markdown)
if (typeof marked !== 'undefined' && marked.parse) {
try {
// 使用marked进行markdown转html
element.innerHTML = marked.parse(text);
} catch (e) {
// 渲染失败则退化为转义+换行
element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
}
} else {
// 没有marked库则直接转义+换行显示
element.innerHTML = escapeHtml(text).replace(/\n/g, '<br>');
}
}
// 主函数:处理用户提交问题事件
async function askQuestion(event) {
// 阻止表单默认提交行为(防止页面刷新)
event.preventDefault();
// 获取输入框的用户问题并去除首尾空白
const question = document.getElementById('questionInput').value.trim();
// 获取消息显示区域元素
const chatMessages = document.getElementById('chatMessages');
// 若问题为空则直接返回
if (!question) return;
// 如果没有会话,创建新会话
if (!currentSessionId) {
await createNewSession();
}
// 检查并移除初始空白提示(如有)
if (chatMessages.querySelector('.empty-state')) {
chatMessages.innerHTML = '';
}
// 创建用于展示问题的div元素
const questionDiv = document.createElement('div');
// 加上样式: 用户问题
questionDiv.className = 'chat-message chat-question';
// 构建用户气泡内容含图标、文本和时间
questionDiv.innerHTML = `
<div class="d-flex justify-content-between align-items-start">
<div class="flex-grow-1">
<strong><i class="bi bi-person-circle"></i> 问题:</strong>
<div class="mt-1">${escapeHtml(question)}</div>
</div>
<small class="text-muted">${new Date().toLocaleTimeString()}</small>
</div>
`;
// 显示到对话窗口
chatMessages.appendChild(questionDiv);
// 创建用于显示答案的div元素
const answerDiv = document.createElement('div');
// 答案样式
answerDiv.className = 'chat-message chat-answer';
// 动态生成唯一的答案内容div id(用于唯一标记)
const answerContentId = 'answerContent_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
// 创建存放答案内容的div
const answerContent = document.createElement('div');
// 为markdown渲染容器设置样式和id
answerContent.className = 'mt-2 markdown-content';
answerContent.id = answerContentId;
// 答案div显示机器人图标和时间
answerDiv.innerHTML = `
<div class="d-flex justify-content-between align-items-start">
<div class="flex-grow-1">
<strong><i class="bi bi-robot"></i> 答案:</strong>
</div>
<small class="text-muted">${new Date().toLocaleTimeString()}</small>
</div>
`;
// 将答案内容div插入到机器人气泡div的内容区
const flexGrowDiv = answerDiv.querySelector('.flex-grow-1');
flexGrowDiv.appendChild(answerContent);
// 插入答案div到消息区域
chatMessages.appendChild(answerDiv);
// 变量记录完整的回答内容
let fullAnswer = '';
// 标记渲染任务是否挂起(防止重复)
let pendingUpdate = false;
// 记录定时器id(去抖动用)
let updateTimer = null;
// 清空输入框内容
document.getElementById('questionInput').value = '';
// 滚动到底
scrollToBottom();
// 定义scheduleRender:将markdown渲染插入到队列合适时机执行
function scheduleRender() {
// 若当前无待渲染任务才安排渲染
if (!pendingUpdate) {
pendingUpdate = true;
// 下一帧渲染
requestAnimationFrame(() => {
// 将答案作为markdown渲染进dom
renderMarkdownToElement(answerContent, fullAnswer);
// 清理pending标识
pendingUpdate = false;
// 渲染后滚动到底部
scrollToBottom();
});
}
}
try {
// 组装API接口地址
const url = currentKbId
? `/api/v1/knowledgebases/${currentKbId}/chat`
: `/api/v1/knowledgebases/chat`;
// 请求后端,发起流式POST请求
const response = await fetch(url, {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({
question: question,
session_id: currentSessionId,
stream: true
})
});
// 非200响应时,抛出异常
if (!response.ok) throw new Error('请求失败');
// 使用ReadableStream reader获取response流
const reader = response.body.getReader();
// 用TextDecoder解码二进制数据
const decoder = new TextDecoder();
// 数据缓冲区(字符串)
let buffer = '';
// 初始展示“思考中...”提示
answerContent.innerHTML = '<i class="bi bi-hourglass-split"></i> 思考中...';
// 重新加载会话列表以更新标题
await loadSessions();
// 不断循环读取服务端推送流内容
while (true) {
// 逐步读取一段数据
const { done, value } = await reader.read();
// 若读完则结束循环
if (done) break;
// 本块新数据解码成字符串,并追加到buffer
buffer += decoder.decode(value, { stream: true });
// 按行切分(多行分别处理)
const lines = buffer.split('\n');
// 最后一行一般为数据残留,不处理,下次拼
buffer = lines.pop() || '';
// 处理每一行数据
for (const line of lines) {
// 筛选有效SSE协议"data: "数据包
if (line.startsWith('data: ')) {
// 去掉前缀,留下json内容
const data = line.slice(6);
// [DONE]信号作为流结束,仅跳过
if (data === '[DONE]') continue;
try {
// 解析JSON数据
const chunk = JSON.parse(data);
// 类型为流起始,清空内容
if (chunk.type === 'start') {
fullAnswer = '';
answerContent.innerHTML = '';
// 类型为内容,追加文本并安排去抖渲染
} else if (chunk.type === 'content') {
fullAnswer += chunk.content;
scheduleRender();
// 类型为done(流传输结束),直接最终渲染所有答案内容
} else if (chunk.type === 'done') {
renderMarkdownToElement(answerContent, fullAnswer);
+ renderSources(chunk.sources, chunk.metadata, answerDiv, currentKbId);
// 错误类型,alert显示报错内容
} else if (chunk.type === 'error') {
answerContent.innerHTML = `<div class="alert alert-danger">${chunk.content}</div>`;
}
} catch (e) {
// JSON解析失败时报console
console.error('解析流数据失败:', e);
}
}
}
}
} catch (error) {
// 通信异常、Fetch错误等: 显示错误气泡
answerContent.innerHTML = `<div class="alert alert-danger"><strong>错误:</strong> ${error.message}</div>`;
}
// 结束后确保界面滚动到底
scrollToBottom();
}
+// 渲染引用来源到 DOM 元素
+function renderSources(sources, metadata, answerDiv, currentKbId) {
+ const html = generateSourcesHtml(sources, metadata, currentKbId);
+ if (html) {
+ const tempDiv = document.createElement('div');
+ tempDiv.innerHTML = html;
+ while (tempDiv.firstChild) {
+ answerDiv.appendChild(tempDiv.firstChild);
+ }
+ }
+}
+// 生成引用来源的 HTML 字符串
+function generateSourcesHtml(sources, metadata, currentKbId) {
+ // 如果 sources 存在、且是数组且长度大于0
+ if (sources && Array.isArray(sources) && sources.length > 0) {
+ // 对每个 source 逐个生成 html
+ const sourcesHtml = sources.map((source, idx) => {
+ // 初始化分数数组
+ const scores = [];
+ // 定义辅助函数: 安全转换分数字段
+ const getScore = (score) => {
+ // 如果为空则返回 null
+ if (score === null || score === undefined) return null;
+ // 字符串转换为浮点数,否则直接返回
+ const num = typeof score === 'string' ? parseFloat(score) : score;
+ // 非数值返回 null,否则返回数字
+ return isNaN(num) ? null : num;
+ };
+ // 获取检索类型,默认为 unknown
+ const retrievalType = source.retrieval_type || 'unknown';
+ // 是否为 vector 检索
+ const isVector = retrievalType.includes('vector');
+ // 是否为 keyword 检索
+ const isKeyword = retrievalType.includes('keyword');
+ // 是否为 hybrid 混合检索
+ const isHybrid = retrievalType.includes('hybrid');
+ // 如果是仅向量检索并且不是混合
+ if (isVector && !isHybrid) {
+ // 获取向量分数
+ const vectorScore = getScore(source.vector_score);
+ // 有效且大于等于0则添加 badge
+ if (vectorScore !== null && vectorScore >= 0) {
+ scores.push(`<span class="badge bg-info me-1">向量相似度: ${vectorScore.toFixed(4)}</span>`);
+ }
+ }
+ // 如果是仅关键词检索并且不是混合
+ if (isKeyword && !isHybrid) {
+ // 获取关键词分数
+ const keywordScore = getScore(source.keyword_score);
+ // 有效且大于等于0则添加 badge
+ if (keywordScore !== null && keywordScore >= 0) {
+ scores.push(`<span class="badge bg-success me-1">关键词相似度: ${keywordScore.toFixed(4)}</span>`);
+ }
+ }
+ // 如果是混合检索
+ if (isHybrid) {
+ // 获取向量分数
+ const vectorScore = getScore(source.vector_score);
+ // 有效且大于等于0则添加 badge
+ if (vectorScore !== null && vectorScore >= 0) {
+ scores.push(`<span class="badge bg-info me-1">向量相似度: ${vectorScore.toFixed(4)}</span>`);
+ }
+ // 获取关键词分数
+ const keywordScore = getScore(source.keyword_score);
+ // 有效且大于等于0则添加 badge
+ if (keywordScore !== null && keywordScore >= 0) {
+ scores.push(`<span class="badge bg-success me-1">关键词相似度: ${keywordScore.toFixed(4)}</span>`);
+ }
+ // 获取混合 RRF 分数
+ const rrfScore = getScore(source.rrf_score);
+ // 有效且大于等于0则添加 badge
+ if (rrfScore !== null && rrfScore >= 0) {
+ scores.push(`<span class="badge bg-primary me-1">RRF 分数: ${rrfScore.toFixed(4)}</span>`);
+ }
+ }
+ // 获取重排序分数,并显示
+ const rerankScore = getScore(source.rerank_score);
+ // 有效且大于等于0则添加 badge
+ if (rerankScore !== null && rerankScore >= 0) {
+ scores.push(`<span class="badge bg-danger me-1">重排序分数: ${rerankScore.toFixed(4)}</span>`);
+ }
+ // 返回单个引用块的 html 字符串
+ return `
+ <div class="source-item mb-3 p-2 border rounded">
+ <div class="fw-bold mb-1">
+ <i class="bi bi-file-earmark"></i> ${idx + 1}. ${source.doc_name || '未知文档'}
+ </div>
+ <div class="mb-2">
+ ${scores.length > 0 ? scores.join('') : '<span class="text-muted small">暂无分数信息</span>'}
+ </div>
+ <div class="small text-muted">
+ <strong>分块内容:</strong>
+ <div class="mt-1 p-2 bg-white rounded" style="max-height: 150px; overflow-y: auto; font-size: 0.9em;">
+ ${escapeHtml(source.content || '')}
+ </div>
+ </div>
+ </div>
+ `;
+ // 连接所有的引用来源 HTML
+ }).join('');
+ // 返回所有引用来源的 html 块
+ return `
+ <div class="mt-3 p-3 bg-light rounded">
+ <strong><i class="bi bi-link-45deg"></i> 引用来源 (${sources.length} 个):</strong>
+ <div class="mt-2">
+ ${sourcesHtml}
+ </div>
+ </div>
+ `;
+ // 如果传入了知识库ID但是没有 sources
+ } else if (currentKbId) {
+ // 从元数据获取检索块数,没有则为0
+ const retrievedChunks = (metadata && metadata.retrieved_chunks) || 0;
+ // 定义提示信息
+ let message = '';
+ // 若未检索到任何文档,提示排查信息
+ if (retrievedChunks === 0) {
+ message = '未检索到任何文档。请检查:1) 知识库中是否有文档;2) 文档是否已处理完成;3) 检索阈值是否设置过高。';
+ // 如果检索到了但没引用,提示相关性低
+ } else {
+ message = '未找到相关引用来源。可能是检索到的文档与问题相关性较低。';
+ }
+ // 返回警告气泡 html
+ return `
+ <div class="mt-3 p-2 bg-warning bg-opacity-10 rounded">
+ <small class="text-warning">
+ <i class="bi bi-exclamation-triangle"></i> ${message}
+ </small>
+ </div>
+ `;
+ }
+ // 默认返回空字符串
+ return '';
+}
// 清空所有会话
// 异步函数,用于清空所有会话
async function clearAllSessions() {
// 弹窗确认操作,若取消则直接返回
if (!confirm('确定要清空所有会话吗?此操作不可恢复!')) return;
try {
// 发送DELETE请求到服务器,删除所有会话
const response = await fetch('/api/v1/knowledgebases/sessions', {
method: 'DELETE'
});
// 获取服务器返回的JSON结果
const result = await response.json();
// 如果接口返回成功
if (result.code === 200) {
// 当前会话ID置空
currentSessionId = null;
// 清空聊天消息内容
clearChatMessages();
// 重新加载会话列表
await loadSessions();
}
} catch (error) {
// 捕获异常并弹窗提示失败原因
alert('清空会话失败: ' + error.message);
}
}
// 格式化时间
// 用于格式化时间字符串为友好显示
function formatTime(timeStr) {
// 如果无有效时间,直接返回空字符串
if (!timeStr) return '';
// 将字符串转换为Date对象
const date = new Date(timeStr);
// 获取当前时间
const now = new Date();
// 计算时间差(毫秒)
const diff = now - date;
// 换算为分钟数
const minutes = Math.floor(diff / 60000);
// 换算为小时数
const hours = Math.floor(diff / 3600000);
// 换算为天数
const days = Math.floor(diff / 86400000);
// 1分钟内显示“刚刚”
if (minutes < 1) return '刚刚';
// 1小时内显示“x分钟前”
if (minutes < 60) return `${minutes}分钟前`;
// 24小时内显示“x小时前”
if (hours < 24) return `${hours}小时前`;
// 7天内显示“x天前”
if (days < 7) return `${days}天前`;
// 超过7天显示具体日期
return date.toLocaleDateString();
}
// 渲染消息
function renderMessages(messages) {
const chatMessages = document.getElementById('chatMessages');
if (messages.length === 0) {
chatMessages.innerHTML = `
<div class="empty-state">
<i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
<p class="mt-2">开始提问吧!</p>
</div>
`;
return;
}
chatMessages.innerHTML = messages.map(msg => {
if (msg.role === 'user') {
return `
<div class="chat-message chat-question">
<div class="d-flex justify-content-between align-items-start">
<div class="flex-grow-1">
<strong><i class="bi bi-person-circle"></i> 问题:</strong>
<div class="mt-1">${escapeHtml(msg.content)}</div>
</div>
<small class="text-muted">${formatTime(msg.created_at)}</small>
</div>
</div>
`;
} else {
+ const sourcesHtml = generateSourcesHtml(msg.sources, msg.metadata, null);
return `
<div class="chat-message chat-answer">
<div class="d-flex justify-content-between align-items-start">
<div class="flex-grow-1">
<strong><i class="bi bi-robot"></i> 答案:</strong>
<div class="mt-2 markdown-content">${renderMarkdown(msg.content)}</div>
+ ${sourcesHtml}
</div>
<small class="text-muted">${formatTime(msg.created_at)}</small>
</div>
</div>
`;
}
}).join('');
scrollToBottom();
}
// 加载会话
// 定义一个异步函数 loadSession,用于加载指定 sessionId 的会话
async function loadSession(sessionId) {
// 异常捕获,防止请求或处理过程抛错
try {
// 发送 GET 请求,获取指定会话的详情数据
const response = await fetch(`/api/v1/knowledgebases/sessions/${sessionId}`);
// 将返回的响应数据解析为 JSON 格式
const result = await response.json();
// 如果接口请求成功,即 code 等于 200
if (result.code === 200) {
// 设置当前会话 ID 为传入的 sessionId
currentSessionId = sessionId;
// 获取会话详情数据
const session = result.data.session;
// 获取该会话包含的消息列表,如果没有则为 []
const messages = result.data.messages || [];
// 调用方法将消息渲染到页面
renderMessages(messages);
// 重新加载全部会话列表,刷新左侧会话栏
await loadSessions();
}
} catch (error) {
// 捕获异常并弹窗提示加载失败的原因
alert('加载会话失败: ' + error.message);
}
}
// 渲染会话列表
// 将sessions数组渲染到界面左侧会话列表
function renderSessions() {
// 获取会话列表的DOM节点
const sessionList = document.getElementById('sessionList');
// 如果没有任何会话,显示为空状态
if (sessions.length === 0) {
sessionList.innerHTML = `
<div class="text-center text-muted py-5">
<i class="bi bi-chat-left-text" style="font-size: 2rem;"></i>
<p class="mt-2 small">暂无会话</p>
</div>
`;
return;
}
// 遍历sessions并渲染为会话项
sessionList.innerHTML = sessions.map(session => `
<div class="session-item ${session.id === currentSessionId ? 'active' : ''}"
onclick="loadSession('${session.id}')">
<button class="btn btn-sm btn-link text-danger p-0 session-delete"
onclick="event.stopPropagation(); deleteSession('${session.id}')">
<i class="bi bi-x-lg"></i>
</button>
<div class="session-title">${session.title || '新对话'}</div>
<div class="session-time">${formatTime(session.updated_at)}</div>
</div>
`).join('');
}
// 清空聊天消息
// 将聊天内容区重置为初始状态
function clearChatMessages() {
document.getElementById('chatMessages').innerHTML = `
<div class="empty-state">
<i class="bi bi-chat-left-text" style="font-size: 3rem;"></i>
<p class="mt-2">开始提问吧!</p>
</div>
`;
}
// 加载会话列表
// 异步加载会话列表并渲染到界面
async function loadSessions() {
try {
// 发送GET请求获取会话列表
const response = await fetch('/api/v1/knowledgebases/sessions');
// 解析JSON结果
const result = await response.json();
// 如果接口返回成功
if (result.code === 200) {
// 保存会话数组
sessions = result.data.items || [];
// 渲染会话列表
renderSessions();
}
} catch (error) {
// 捕捉异常并在控制台打印
console.error('加载会话列表失败:', error);
}
}
// 创建新会话
// 异步请求服务端创建新会话
async function createNewSession() {
try {
// 发送POST请求创建新会话
const response = await fetch('/api/v1/knowledgebases/sessions', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({})
});
// 解析返回结果
const result = await response.json();
// 如果创建成功
if (result.code === 200) {
// 切换到新会话
currentSessionId = result.data.id;
// 重新加载会话列表
await loadSessions();
// 清空聊天内容
clearChatMessages();
}
} catch (error) {
// 捕获异常并弹窗提示
alert('创建会话失败: ' + error.message);
}
}
// 监听 DOMContentLoaded 事件,确保页面元素加载完成后执行
document.addEventListener('DOMContentLoaded', function() {
// 加载会话列表
loadSessions();
// 判断 marked.js 是否已经加载
if (typeof marked !== 'undefined') {
// 配置 marked.js 的渲染选项
marked.setOptions({
// 启用软换行
breaks: true,
// 启用 Github Flavored Markdown
gfm: true,
// 禁用标题自动生成 id
headerIds: false,
// 禁用混淆处理
mangle: false
});
}
});
// 知识库选择变化处理函数
async function onKbChange() {
// 更新全局当前知识库ID变量
currentKbId = document.getElementById('kbSelect').value;;
// 如果已存在会话,则切换知识库时重新新建一个会话
if (currentSessionId) {
await createNewSession();
}
// 启用发送按钮(普通聊天和知识库聊天均可提问)
document.getElementById('submitBtn').disabled = false;
}
</script>
{% endblock %}