1.本章目标 #
本章将介绍文档管理模块的设计与实现,包括以下目标:
- 掌握文档管理模块的整体架构与功能定位。
- 了解文档相关模型(如
document.py)的定义与组织方式。 - 熟悉文档管理相关视图及接口的实现思路。
- 学会如何上传、删除、查询和维护文档文件。
- 理解文档与知识库及其它模块的关联关系。
- 掌握文档管理流程和存储机制,便于后续扩展和运维。
2.目录结构 #
# 项目根目录
rag-lite/
# 应用主目录
├── app/
# Flask 蓝图目录,用于组织视图函数
│ ├── blueprints/
# 蓝图初始化文件,定义蓝图包
│ │ ├── __init__.py
# 用户认证相关视图
│ │ ├── auth.py
# 文档管理相关视图
│ │ ├── document.py
# 知识库相关视图
│ │ ├── knowledgebase.py
# 系统设置相关视图
│ │ ├── settings.py
# 通用工具函数
│ │ └── utils.py
# ORM 数据模型目录
│ ├── models/
# 模型包初始化
│ │ ├── __init__.py
# 基础模型定义
│ │ ├── base.py
# 聊天消息模型
│ │ ├── chat_message.py
# 聊天会话模型
│ │ ├── chat_session.py
# 文档模型
│ │ ├── document.py
# 知识库模型
│ │ ├── knowledgebase.py
# 系统设置模型
│ │ ├── settings.py
# 用户模型
│ │ └── user.py
# 服务层目录
│ ├── services/
# 存储相关服务
│ │ ├── storage/
# 存储包初始化
│ │ │ ├── __init__.py
# 存储基类
│ │ │ ├── base.py
# 存储工厂类
│ │ │ ├── factory.py
# 本地存储实现
│ │ │ ├── local_storage.py
# minio 存储实现
│ │ │ └── minio_storage.py
# 向量数据库相关服务
│ │ ├── vectordb/
# 向量数据库包初始化
│ │ │ ├── __init__.py
# 向量数据库基类
│ │ │ ├── base.py
# Chroma 数据库实现
│ │ │ ├── chroma.py
# 向量数据库工厂类
│ │ │ ├── factory.py
# Milvus 数据库实现
│ │ │ └── milvus.py
# 服务层基类
│ │ ├── base_service.py
# 文档服务
│ │ ├── document_service.py
# 知识库相关服务
│ │ ├── knowledgebase_service.py
# 文档解析服务
│ │ ├── parser_service.py
# 设置相关服务
│ │ ├── settings_service.py
# 存储服务
│ │ ├── storage_service.py
# 用户服务
│ │ ├── user_service.py
# 向量相关服务
│ │ └── vector_service.py
# 静态文件目录(前端静态资源)
│ ├── static/
# 前端模板目录
│ ├── templates/
# 基础模板
│ │ ├── base.html
# 首页模板
│ │ ├── home.html
# 知识库详情页模板
│ │ ├── kb_detail.html
# 知识库列表页模板
│ │ ├── kb_list.html
# 登录页模板
│ │ ├── login.html
# 注册页模板
│ │ ├── register.html
# 设置页模板
│ │ └── settings.html
# 工具函数目录
│ ├── utils/
# 认证相关工具
│ │ ├── auth.py
# 数据库工具
│ │ ├── db.py
# 文档加载工具
│ │ ├── document_loader.py
# 向量模型工厂
│ │ ├── embedding_factory.py
# 日志工具
│ │ ├── logger.py
# 模型配置
│ │ ├── models_config.py
# 文本分割工具
│ │ └── text_splitter.py
# app 包初始化文件
│ ├── __init__.py
# 应用配置文件
│ └── config.py
# Chroma 向量数据库数据存储目录
├── chroma_db/
# 日志目录
├── logs/
# 主日志文件
│ └── rag_lite.log
# 持久化存储目录
├── storages/
# 卷挂载目录
├── volumes/
# Milvus 数据目录
│ ├── milvus/
# Minio 数据目录
│ └── minio/
# Docker Compose 配置文件
├── docker-compose.yml
# 文档管理脚本
├── documents.py
# 应用主入口
├── main.py
# Python 项目配置文件
├── pyproject.toml3.知识库详情 #
3.1. kb_detail.html #
app/templates/kb_detail.html
{% extends "base.html" %}
{% block title %}{{ kb.name }} - RAG Lite{% endblock %}
{% block content %}
<div class="row">
<div class="col-12">
<nav aria-label="breadcrumb">
<ol class="breadcrumb">
<li class="breadcrumb-item"><a href="/">首页</a></li>
<li class="breadcrumb-item"><a href="/kb">知识库</a></li>
<li class="breadcrumb-item active">{{ kb.name }}</li>
</ol>
</nav>
</div>
</div>
<div class="row">
<div class="col-12">
<div class="card">
<div class="card-header d-flex justify-content-between align-items-center">
<h5 class="mb-0"><i class="bi bi-file-earmark"></i> 文档管理</h5>
<button class="btn btn-sm btn-primary" data-bs-toggle="modal" data-bs-target="#uploadModal">
<i class="bi bi-upload"></i> 上传文档
</button>
</div>
<div class="card-body">
</div>
</div>
</div>
</div>
{% endblock %}
{% block extra_js %}
<script>
</script>
{% endblock %}3.2. knowledgebase.py #
app/blueprints/knowledgebase.py
# 知识库相关路由(视图 + API)
"""
知识库相关路由(视图 + API)
"""
# 导入Flask中的Blueprint和request
+from flask import Blueprint,request,render_template,send_file,abort,redirect, url_for
# 使用BytesIO将图片数据包装为文件流
from io import BytesIO
# 导入logging模块
import logging
# 导入mimetypes、os模块用于类型判断
import mimetypes
# 导入os模块用于路径操作
import os
# 导入自定义工具函数:异常处理装饰器、错误响应、成功响应
from app.blueprints.utils import (handle_api_error,error_response,success_response,get_current_user_or_error)
# 导入知识库服务
from app.services.knowledgebase_service import kb_service
# 导入认证工具函数:登录认证装饰器、获取当前用户、API登录认证装饰器
from app.utils.auth import login_required, get_current_user,api_login_required
# 导入分页工具函数
from app.blueprints.utils import (get_pagination_params,check_ownership)
# 导入存储服务
from app.services.storage_service import storage_service
# 配置logger
logger = logging.getLogger(__name__)
# 创建Blueprint实例,注册在Flask应用下
bp = Blueprint('knowledgebase', __name__)
# 定义路由:POST请求到/api/v1/kb
@bp.route('/api/v1/kb', methods=['POST'])
# 应用API登录认证装饰器
@api_login_required
# 应用自定义异常处理装饰器
@handle_api_error
# 定义创建知识库的视图函数
# 定义用于创建知识库的API接口
def api_create():
# 接口用途说明文档字符串
"""创建知识库"""
# 获取当前用户,如未登录则返回错误响应
current_user, err = get_current_user_or_error()
if err:
return err
# 检查请求是否为multipart/form-data(用于文件上传的表单方式)
if request.content_type and 'multipart/form-data' in request.content_type:
# 从表单数据中获取知识库名称
name = request.form.get('name')
# 如果未传入name参数,返回错误
if not name:
return error_response("name is required", 400)
# 获取描述字段,没有则为None
description = request.form.get('description') or None
# 获取分块大小,默认为512
chunk_size = int(request.form.get('chunk_size', 512))
# 获取分块重叠,默认为50
chunk_overlap = int(request.form.get('chunk_overlap', 50))
# 设置封面图片数据变量初值为None
cover_image_data = None
# 设置封面图片文件名变量初值为None
cover_image_filename = None
# 判断请求中是否包含'cover_image'文件
if 'cover_image' in request.files:
# 获取上传的封面图片文件对象
cover_file = request.files['cover_image']
# 如果上传的文件存在且有文件名
if cover_file and cover_file.filename:
# 读取文件内容为二进制数据
cover_image_data = cover_file.read()
# 获取上传文件的文件名
cover_image_filename = cover_file.filename
# 记录封面图片上传的信息到日志,包括文件名、字节大小和内容类型
logger.info(f"收到新知识库的封面图片上传: 文件名={cover_image_filename}, 大小={len(cover_image_data)} 字节, 内容类型={cover_file.content_type}")
else:
# 如果是json请求数据(向后兼容旧用法)
data = request.get_json()
# 判断是否存在name字段,不存在则报错
if not data or 'name' not in data:
return error_response("name is required", 400)
# 获取知识库名称
name = data['name']
# 获取描述
description = data.get('description')
# 获取分块大小,默认为512
chunk_size = data.get('chunk_size', 512)
# 获取分块重叠,默认为50
chunk_overlap = data.get('chunk_overlap', 50)
# 设置封面图片数据变量初值为None
cover_image_data = None
# 设置封面图片文件名变量初值为None
cover_image_filename = None
# 调用知识库服务,创建知识库,返回知识库信息字典
kb_dict = kb_service.create(
name=name, # 知识库名称
user_id=current_user['id'], # 用户ID
description=description, # 知识库描述
chunk_size=chunk_size, # 分块大小
chunk_overlap=chunk_overlap, # 分块重叠
cover_image_data=cover_image_data, # 封面图片数据
cover_image_filename=cover_image_filename # 封面图片文件名
)
# 返回成功响应,包含知识库信息
return success_response(kb_dict)
# 注册'/kb'路由,处理GET请求,显示知识库列表页面
@bp.route('/kb')
# 要求登录用户才能访问该视图,用于Web页面
@login_required
# 定义kb_list函数,渲染知识库列表页面
def kb_list():
# 设置本函数用途说明(文档字符串)
"""知识库列表页面"""
# 获取当前登录用户信息
current_user = get_current_user()
# 获取分页参数(页码和每页大小),最大每页100
page, page_size = get_pagination_params(max_page_size=100)
# 获取搜索和排序参数
search = request.args.get('search', '').strip() or None
sort_by = request.args.get('sort_by', 'created_at')
sort_order = request.args.get('sort_order', 'desc')
# 验证排序参数
if sort_by not in ['created_at', 'name', 'updated_at']:
sort_by = 'created_at'
if sort_order not in ['asc', 'desc']:
sort_order = 'desc'
# 调用知识库服务,获取分页后的知识库列表结果
result = kb_service.list(
user_id=current_user['id'], # 用户ID
page=page, # 页码
page_size=page_size, # 每页大小
search=search, # 搜索关键词
sort_by=sort_by, # 排序字段
sort_order=sort_order # 排序方向
)
# 渲染知识库列表页面模板,传递数据,包括知识库列表、分页信息
return render_template('kb_list.html',
kbs=result['items'],
pagination=result,
search=search or '',
sort_by=sort_by,
sort_order=sort_order)
# 注册DELETE方法的API路由,用于删除知识库
@bp.route('/api/v1/kb/<kb_id>', methods=['DELETE'])
# 要求API登录
@api_login_required
# 处理API错误的装饰器
@handle_api_error
def api_delete(kb_id):
"""删除知识库"""
# 获取当前用户信息,如果未登录则返回错误
current_user, err = get_current_user_or_error()
if err:
return err
# 根据知识库ID获取知识库信息
kb_dict = kb_service.get_by_id(kb_id)
# 如果知识库不存在,返回404错误
if not kb_dict:
return error_response("未找到知识库", 404)
# 验证当前用户是否拥有该知识库的操作权限
has_permission, err = check_ownership(kb_dict['user_id'], current_user['id'], "knowledgebase")
if not has_permission:
return err
# 调用服务删除知识库
success = kb_service.delete(kb_id)
# 如果删除失败,返回404错误
if not success:
return error_response("未找到知识库", 404)
# 返回删除成功的响应
return success_response("知识库删除成功")
# 注册PUT方法的API路由(用于更新知识库)
@bp.route('/api/v1/kb/<kb_id>', methods=['PUT'])
# 要求API登录
@api_login_required
# 捕获API内部错误的装饰器
@handle_api_error
def api_update(kb_id):
# 定义API用于更新知识库(含封面图片)
"""更新知识库(支持封面图片更新)"""
# 获取当前登录用户信息,如果未登录则返回错误响应
current_user, err = get_current_user_or_error()
if err:
return err
# 获取指定ID的知识库记录,验证其是否存在
kb_dict = kb_service.get_by_id(kb_id)
if not kb_dict:
return error_response("未找到知识库", 404)
# 校验当前用户是否有操作该知识库的权限
has_permission, err = check_ownership(kb_dict['user_id'], current_user['id'], "knowledgebase")
if not has_permission:
return err
# 判断请求内容类型是否为multipart/form-data(一般用于带文件上传的表单提交)
if request.content_type and 'multipart/form-data' in request.content_type:
# 从表单中获取普通字段
name = request.form.get('name')
description = request.form.get('description') or None
chunk_size = request.form.get('chunk_size')
chunk_overlap = request.form.get('chunk_overlap')
# 初始化封面图片相关变量
cover_image_data = None
cover_image_filename = None
# 获得delete_cover字段(类型字符串,需判断是否为'true')
delete_cover = request.form.get('delete_cover') == 'true'
# 如果有上传封面图片,则读取文件内容
if 'cover_image' in request.files:
cover_file = request.files['cover_image']
if cover_file and cover_file.filename:
cover_image_data = cover_file.read()
cover_image_filename = cover_file.filename
# 记录上传日志
logger.info(f"收到知识库 {kb_id} 的封面图片上传: 文件名={cover_image_filename}, 大小={len(cover_image_data)} 字节, 内容类型={cover_file.content_type}")
# 构建待更新的数据
update_data = {}
if name:
update_data['name'] = name
if description is not None:
update_data['description'] = description
if chunk_size:
update_data['chunk_size'] = int(chunk_size)
if chunk_overlap:
update_data['chunk_overlap'] = int(chunk_overlap)
else:
# 非表单上传,则按JSON结构解析请求内容
data = request.get_json()
# 如果请求体是空的,直接返回错误
if not data:
return error_response("请求体不能为空", 400)
# 构建可更新的数据字典
update_data = {}
if 'name' in data:
update_data['name'] = data['name']
if 'description' in data:
update_data['description'] = data.get('description')
if 'chunk_size' in data:
update_data['chunk_size'] = data['chunk_size']
if 'chunk_overlap' in data:
update_data['chunk_overlap'] = data['chunk_overlap']
# JSON请求时,cover_image相关变量置空
cover_image_data = None
cover_image_filename = None
delete_cover = data.get('delete_cover', False)
# 调用服务更新知识库,传入各字段及封面参数
updated_kb = kb_service.update(
kb_id=kb_id, # 知识库ID
cover_image_data=cover_image_data, # 封面图片的二进制内容
cover_image_filename=cover_image_filename, # 封面图片文件名
delete_cover=delete_cover, # 是否删除封面图片
**update_data # 其它可变字段
)
# 更新后如果找不到,返回404
if not updated_kb:
return error_response("未找到知识库", 404)
# 更新成功后,将最新的知识库数据返回给前端
return success_response(updated_kb, "知识库更新成功")
# 定义路由,获取指定知识库ID的封面图片,仅限登录用户访问
@bp.route('/kb/<kb_id>/cover')
@login_required
def kb_cover(kb_id):
"""获取知识库封面图片"""
# 获取当前已登录用户的信息
current_user = get_current_user()
# 根据知识库ID从知识库服务获取对应的知识库信息
kb = kb_service.get_by_id(kb_id)
# 检查知识库是否存在
if not kb:
# 如果知识库不存在,记录警告日志
logger.warning(f"知识库不存在: {kb_id}")
abort(404)
# 检查是否有权限访问(只能查看自己的知识库封面)
if kb.get('user_id') != current_user['id']:
# 如果不是当前用户的知识库,记录警告日志
logger.warning(f"用户 {current_user['id']} 尝试访问知识库 {kb_id} 的封面,但该知识库属于用户 {kb.get('user_id')}")
abort(403)
# 获取知识库的封面图片路径
cover_path = kb.get('cover_image')
# 检查是否有封面图片
if not cover_path:
# 如果没有封面,记录调试日志
logger.debug(f"知识库 {kb_id} 没有封面图片")
abort(404)
try:
# 通过存储服务下载封面图片数据
image_data = storage_service.download_file(cover_path)
# 如果未能获取到图片数据,记录错误日志并返回404
if not image_data:
logger.error(f"从路径下载封面图片失败: {cover_path}")
abort(404)
# 根据文件扩展名判断图片MIME类型
file_ext = os.path.splitext(cover_path)[1].lower()
# 自定义映射,优先根据文件扩展名判断图片MIME类型
mime_type_map = {
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.png': 'image/png',
'.gif': 'image/gif',
'.webp': 'image/webp'
}
# 优先根据自定义映射获取MIME类型
mime_type = mime_type_map.get(file_ext)
if not mime_type:
# 如果没有命中自定义映射,则使用mimetypes猜测类型
mime_type, _ = mimetypes.guess_type(cover_path)
if not mime_type:
# 如果还未识别出类型,则默认用JPEG
mime_type = 'image/jpeg'
# 通过send_file响应图片数据和MIME类型,不以附件形式发送
return send_file(
BytesIO(image_data),#图片数据
mimetype=mime_type,#MIME类型
as_attachment=False#不以附件形式发送
)
except FileNotFoundError as e:
# 捕获文件未找到异常,记录错误日志
logger.error(f"封面图片文件未找到: {cover_path}, 错误: {e}")
abort(404)
except Exception as e:
# 捕获其他未预期异常,记录错误日志(包含堆栈信息)
logger.error(f"提供知识库 {kb_id} 的封面图片时出错, 路径: {cover_path}, 错误: {e}", exc_info=True)
abort(404)
# 设置路由,访问 /kb/<kb_id> 时触发此视图函数
+@bp.route('/kb/<kb_id>')
# 要求用户登录后才能访问该视图
+@login_required
# 定义知识库详情页视图函数,接收 kb_id 作为参数
+def kb_detail(kb_id):
# 知识库详情页面(文档管理)
+ """知识库详情页面(文档管理)"""
# 根据 kb_id 查询知识库信息
+ kb = kb_service.get_by_id(kb_id)
# 如果未找到知识库,则重定向回知识库列表页面
+ if not kb:
+ return redirect(url_for('knowledgebase.kb_list'))
# 渲染知识库详情模板,将知识库对象、文档列表和分页信息传递给模板
+ return render_template('kb_detail.html', kb=kb)3.3. kb_list.html #
app/templates/kb_list.html
{% extends "base.html" %}
{% block title %}知识库管理 - RAG Lite{% endblock %}
{% block content %}
<style>
@media (min-width: 992px) {
#kbList > div {
flex: 0 0 20%;
max-width: 20%;
}
}
</style>
<div class="row">
<div class="col-12">
<nav aria-label="breadcrumb" class="mb-3">
<ol class="breadcrumb">
<li class="breadcrumb-item"><a href="/">首页</a></li>
<li class="breadcrumb-item active">知识库管理</li>
</ol>
</nav>
<div class="d-flex justify-content-between align-items-center mb-4">
<h2><i class="bi bi-collection"></i> 知识库管理</h2>
<button class="btn btn-primary" data-bs-toggle="modal" data-bs-target="#createKbModal">
<i class="bi bi-plus-circle"></i> 创建知识库
</button>
</div>
<!-- 搜索和排序工具栏 -->
<div class="card mb-4">
<div class="card-body">
<form method="GET" action="/kb" id="searchForm" class="row g-3 align-items-end">
<div class="col-md-6">
<label for="searchInput" class="form-label">搜索</label>
<div class="input-group">
<span class="input-group-text"><i class="bi bi-search"></i></span>
<input type="text" class="form-control" id="searchInput" name="search"
placeholder="搜索知识库名称或描述..." value="{{ search }}">
{% if search %}
<button type="button" class="btn btn-outline-secondary" onclick="clearSearch()">
<i class="bi bi-x"></i>
</button>
{% endif %}
</div>
</div>
<div class="col-md-3">
<label for="sortBySelect" class="form-label">排序字段</label>
<select class="form-select" id="sortBySelect" name="sort_by" onchange="updateSearch()">
<option value="created_at" {% if sort_by == 'created_at' %}selected{% endif %}>创建时间</option>
<option value="name" {% if sort_by == 'name' %}selected{% endif %}>名称</option>
<option value="updated_at" {% if sort_by == 'updated_at' %}selected{% endif %}>更新时间</option>
</select>
</div>
<div class="col-md-2">
<label for="sortOrderSelect" class="form-label">排序方向</label>
<select class="form-select" id="sortOrderSelect" name="sort_order" onchange="updateSearch()">
<option value="desc" {% if sort_order == 'desc' %}selected{% endif %}>降序</option>
<option value="asc" {% if sort_order == 'asc' %}selected{% endif %}>升序</option>
</select>
</div>
<div class="col-md-1">
<button type="submit" class="btn btn-primary w-100">
<i class="bi bi-search"></i> 搜索
</button>
</div>
<input type="hidden" name="page" value="1">
<input type="hidden" name="page_size" value="{{ pagination.page_size if pagination else 10 }}">
</form>
</div>
</div>
<!-- 知识库列表 -->
<div class="row" id="kbList">
{% if kbs %}
{% for kb in kbs %}
<div class="col-12 col-sm-6 col-md-4 col-lg mb-4">
<div class="card h-100">
{% if kb.cover_image %}
<img src="/kb/{{ kb.id }}/cover" class="card-img-top" alt="{{ kb.name|e }}" style="height: 150px; object-fit: scale-down;">
{% else %}
<div class="card-img-top bg-light d-flex align-items-center justify-content-center" style="height: 150px;">
<i class="bi bi-folder" style="font-size: 3rem; color: #6c757d;"></i>
</div>
{% endif %}
<div class="card-body">
<h5 class="card-title">
<i class="bi bi-folder"></i> {{ kb.name }}
</h5>
<p class="card-text text-muted small">{{ kb.description or '无描述' }}</p>
</div>
<div class="card-footer bg-transparent">
+ <a href="/kb/{{ kb.id }}" class="btn btn-sm btn-primary">
+ <i class="bi bi-arrow-right"></i> 进入
+ </a>
<button class="btn btn-sm btn-warning"
data-kb-id="{{ kb.id }}"
data-kb-name="{{ kb.name }}"
data-kb-description="{{ kb.description or '' }}"
data-kb-chunk-size="{{ kb.chunk_size }}"
data-kb-chunk-overlap="{{ kb.chunk_overlap }}"
data-kb-cover-image="{{ kb.cover_image or '' }}"
onclick="editKbFromButton(this)">
<i class="bi bi-pencil"></i> 编辑
</button>
<button class="btn btn-sm btn-danger" onclick="deleteKb('{{ kb.id }}', '{{ kb.name }}')">
<i class="bi bi-trash"></i> 删除
</button>
</div>
</div>
</div>
{% endfor %}
{% else %}
<div class="col-12">
<div class="alert alert-info">
<i class="bi bi-info-circle"></i> 还没有知识库,点击上方按钮创建一个吧!
</div>
</div>
{% endif %}
</div>
<!-- 分页控件 -->
{% if pagination and pagination.total > pagination.page_size %}
<nav aria-label="知识库列表分页" class="mt-4">
<ul class="pagination justify-content-center">
{% set current_page = pagination.page %}
{% set total_pages = (pagination.total + pagination.page_size - 1) // pagination.page_size %}
<!-- 上一页 -->
<li class="page-item {% if current_page <= 1 %}disabled{% endif %}">
<a class="page-link" href="?page={{ current_page - 1 }}&page_size={{ pagination.page_size }}{% if search %}&search={{ search|urlencode }}{% endif %}&sort_by={{ sort_by }}&sort_order={{ sort_order }}"
{% if current_page <= 1 %}tabindex="-1" aria-disabled="true"{% endif %}>
<i class="bi bi-chevron-left"></i> 上一页
</a>
</li>
<!-- 页码 -->
{% set start_page = [1, current_page - 2] | max %}
{% set end_page = [total_pages, current_page + 2] | min %}
{% if start_page > 1 %}
<li class="page-item">
<a class="page-link" href="?page=1&page_size={{ pagination.page_size }}{% if search %}&search={{ search|urlencode }}{% endif %}&sort_by={{ sort_by }}&sort_order={{ sort_order }}">1</a>
</li>
{% if start_page > 2 %}
<li class="page-item disabled">
<span class="page-link">...</span>
</li>
{% endif %}
{% endif %}
{% for page_num in range(start_page, end_page + 1) %}
<li class="page-item {% if page_num == current_page %}active{% endif %}">
<a class="page-link" href="?page={{ page_num }}&page_size={{ pagination.page_size }}{% if search %}&search={{ search|urlencode }}{% endif %}&sort_by={{ sort_by }}&sort_order={{ sort_order }}">
{{ page_num }}
</a>
</li>
{% endfor %}
{% if end_page < total_pages %}
{% if end_page < total_pages - 1 %}
<li class="page-item disabled">
<span class="page-link">...</span>
</li>
{% endif %}
<li class="page-item">
<a class="page-link" href="?page={{ total_pages }}&page_size={{ pagination.page_size }}{% if search %}&search={{ search|urlencode }}{% endif %}&sort_by={{ sort_by }}&sort_order={{ sort_order }}">{{ total_pages }}</a>
</li>
{% endif %}
<!-- 下一页 -->
<li class="page-item {% if current_page >= total_pages %}disabled{% endif %}">
<a class="page-link" href="?page={{ current_page + 1 }}&page_size={{ pagination.page_size }}{% if search %}&search={{ search|urlencode }}{% endif %}&sort_by={{ sort_by }}&sort_order={{ sort_order }}"
{% if current_page >= total_pages %}tabindex="-1" aria-disabled="true"{% endif %}>
下一页 <i class="bi bi-chevron-right"></i>
</a>
</li>
</ul>
<div class="text-center text-muted small mt-2">
共 {{ pagination.total }} 个知识库{% if search %}(搜索: "{{ search }}"){% endif %},第 {{ current_page }} / {{ total_pages }} 页
</div>
</nav>
{% endif %}
</div>
</div>
<!-- 创建知识库模态框 -->
<div class="modal fade" id="createKbModal" tabindex="-1">
<div class="modal-dialog">
<div class="modal-content">
<div class="modal-header">
<h5 class="modal-title">创建知识库</h5>
<button type="button" class="btn-close" data-bs-dismiss="modal"></button>
</div>
<form id="createKbForm" onsubmit="createKb(event)" enctype="multipart/form-data">
<div class="modal-body">
<div class="mb-3">
<label class="form-label">名称 <span class="text-danger">*</span></label>
<input type="text" class="form-control" name="name" required>
</div>
<div class="mb-3">
<label class="form-label">描述</label>
<textarea class="form-control" name="description" rows="3"></textarea>
</div>
<div class="mb-3">
<label class="form-label">封面图片(可选)</label>
<input type="file" class="form-control" name="cover_image" accept="image/jpeg,image/png,image/gif,image/webp" id="coverImageInput">
<div class="form-text">支持 JPG、PNG、GIF、WEBP 格式,最大 5MB</div>
<div id="coverImagePreview" class="mt-2" style="display: none;">
<img id="coverPreviewImg" src="" alt="封面预览" class="img-thumbnail" style="max-width: 200px; max-height: 200px;">
</div>
</div>
<div class="row">
<div class="col-md-6 mb-3">
<label class="form-label">分块大小</label>
<input type="number" class="form-control" name="chunk_size" value="512" min="100" max="2000">
<div class="form-text">每个文本块的最大字符数,建议 512-1024</div>
</div>
<div class="col-md-6 mb-3">
<label class="form-label">分块重叠</label>
<input type="number" class="form-control" name="chunk_overlap" value="50" min="0" max="200">
<div class="form-text">相邻块之间的重叠字符数,建议 50-100</div>
</div>
</div>
</div>
<div class="modal-footer">
<button type="button" class="btn btn-secondary" data-bs-dismiss="modal">取消</button>
<button type="submit" class="btn btn-primary">创建</button>
</div>
</form>
</div>
</div>
</div>
<!-- 编辑知识库模态框 -->
<div class="modal fade" id="editKbModal" tabindex="-1">
<div class="modal-dialog">
<div class="modal-content">
<div class="modal-header">
<h5 class="modal-title">编辑知识库</h5>
<button type="button" class="btn-close" data-bs-dismiss="modal"></button>
</div>
<form id="editKbForm" onsubmit="updateKb(event)" enctype="multipart/form-data">
<input type="hidden" name="kb_id" id="editKbId">
<div class="modal-body">
<div class="mb-3">
<label class="form-label">名称 <span class="text-danger">*</span></label>
<input type="text" class="form-control" name="name" id="editKbName" required>
</div>
<div class="mb-3">
<label class="form-label">描述</label>
<textarea class="form-control" name="description" id="editKbDescription" rows="3"></textarea>
</div>
<div class="mb-3">
<label class="form-label">封面图片</label>
<div id="editCoverPreview" class="mb-2">
<img id="editCoverPreviewImg" src="" alt="当前封面" class="img-thumbnail" style="max-width: 200px; max-height: 200px; display: none;">
<div id="editCoverNoImage" class="text-muted small" style="display: none;">暂无封面</div>
</div>
<input type="file" class="form-control" name="cover_image" accept="image/jpeg,image/png,image/gif,image/webp" id="editCoverImageInput">
<div class="form-text">支持 JPG、PNG、GIF、WEBP 格式,最大 5MB。留空则不修改封面。</div>
<div class="form-check mt-2">
<input class="form-check-input" type="checkbox" name="delete_cover" id="editDeleteCover" value="true">
<label class="form-check-label" for="editDeleteCover">
删除封面图片
</label>
</div>
<div id="editCoverNewPreview" class="mt-2" style="display: none;">
<img id="editCoverNewPreviewImg" src="" alt="新封面预览" class="img-thumbnail" style="max-width: 200px; max-height: 200px;">
</div>
</div>
<div class="row">
<div class="col-md-6 mb-3">
<label class="form-label">分块大小</label>
<input type="number" class="form-control" name="chunk_size" id="editKbChunkSize" value="512" min="100" max="2000">
<div class="form-text">每个文本块的最大字符数,建议 512-1024</div>
</div>
<div class="col-md-6 mb-3">
<label class="form-label">分块重叠</label>
<input type="number" class="form-control" name="chunk_overlap" id="editKbChunkOverlap" value="50" min="0" max="200">
<div class="form-text">相邻块之间的重叠字符数,建议 50-100</div>
</div>
</div>
</div>
<div class="modal-footer">
<button type="button" class="btn btn-secondary" data-bs-dismiss="modal">取消</button>
<button type="submit" class="btn btn-primary">保存</button>
</div>
</form>
</div>
</div>
</div>
{% endblock %}
{% block extra_js %}
<script>
// 异步函数,用于创建知识库
async function createKb(event) {
// 阻止表单默认提交
event.preventDefault();
// 获取表单对象
const form = event.target;
// 构造 FormData,收集表单数据
const formData = new FormData(form);
try {
// 发送 POST 请求到后端 API,提交表单数据
const response = await fetch('/api/v1/kb', {
method: 'POST',
// body为FormData,浏览器会自动设置Content-Type
body: formData
});
// 如果响应成功,刷新页面
if (response.ok) {
location.reload();
} else {
// 否则获取错误信息并弹窗提示
const error = await response.json();
alert('创建失败: ' + error.message);
}
} catch (error) {
// 捕获异常并弹窗提示用户
alert('创建失败: ' + error.message);
}
}
// 异步函数,用于删除知识库
async function deleteKb(kbId, kbName) {
// 弹窗确认是否删除知识库
if (!confirm(`确定要删除知识库 "${kbName}" 吗?此操作不可恢复!`)) {
return;
}
try {
// 发送 DELETE 请求到后端 API
const response = await fetch(`/api/v1/kb/${kbId}`, {
method: 'DELETE'
});
// 如果响应成功,刷新页面
if (response.ok) {
location.reload();
} else {
// 否则弹窗提示错误信息
const error = await response.json();
alert('删除失败: ' + error.message);
}
} catch (error) {
// 捕获异常并弹窗提示
alert('删除失败: ' + error.message);
}
}
// 从按钮的 data 属性读取知识库数据,然后打开编辑界面
function editKbFromButton(button) {
// 获取知识库ID
const kbId = button.getAttribute('data-kb-id');
// 获取知识库名称
const name = button.getAttribute('data-kb-name');
// 获取描述,默认为空字符串
const description = button.getAttribute('data-kb-description') || '';
// 获取分块大小,默认512
const chunkSize = parseInt(button.getAttribute('data-kb-chunk-size')) || 512;
// 获取分块重叠,默认50
const chunkOverlap = parseInt(button.getAttribute('data-kb-chunk-overlap')) || 50;
// 获取封面图片路径,默认为空
const coverImage = button.getAttribute('data-kb-cover-image') || '';
// 调用编辑函数,填充数据到表单
editKb(kbId, name, description, chunkSize, chunkOverlap,coverImage);
}
// 编辑知识库时弹出模态框并初始化数据
function editKb(kbId, name, description, chunkSize, chunkOverlap,coverImage) {
// 设置表单的知识库ID
document.getElementById('editKbId').value = kbId;
// 设置知识库名称
document.getElementById('editKbName').value = name;
// 设置描述
document.getElementById('editKbDescription').value = description || '';
// 设置分块大小
document.getElementById('editKbChunkSize').value = chunkSize;
// 设置分块重叠
document.getElementById('editKbChunkOverlap').value = chunkOverlap;
// 初始化不勾选删除封面
document.getElementById('editDeleteCover').checked = false;
// 清空已选择的新封面文件
document.getElementById('editCoverImageInput').value = '';
// 获取当前封面预览图片元素
const previewImg = document.getElementById('editCoverPreviewImg');
// 获取“暂无封面”提示元素
const noImageDiv = document.getElementById('editCoverNoImage');
// 获取新封面预览div
const newPreview = document.getElementById('editCoverNewPreview');
// 获取新封面预览图片元素
const newPreviewImg = document.getElementById('editCoverNewPreviewImg');
// 隐藏新图片预览
if (newPreview) {
newPreview.style.display = 'none';
}
// 清空新图片预览src
if (newPreviewImg) {
newPreviewImg.src = '';
}
// 如果有旧封面,则显示
if (coverImage) {
previewImg.src = `/kb/${kbId}/cover`;
previewImg.style.display = 'block';
noImageDiv.style.display = 'none';
} else {
// 否则显示“暂无封面”
previewImg.style.display = 'none';
noImageDiv.style.display = 'block';
}
// 展示编辑知识库的模态框
const modal = new bootstrap.Modal(document.getElementById('editKbModal'));
modal.show();
}
// 获取编辑时选择封面图片的input
const editCoverImageInput = document.getElementById('editCoverImageInput');
// 如果找到了input,则监听change事件
if (editCoverImageInput) {
editCoverImageInput.addEventListener('change', function(e) {
// 获取用户选择的第一个文件
const file = e.target.files[0];
// 获取新封面预览容器
const newPreview = document.getElementById('editCoverNewPreview');
// 获取新封面预览图片元素
const newPreviewImg = document.getElementById('editCoverNewPreviewImg');
// 获取删除封面复选框
const deleteCheckbox = document.getElementById('editDeleteCover');
// 如果有选文件
if (file) {
// 定义允许的图片类型
const validTypes = ['image/jpeg', 'image/jpg', 'image/png', 'image/gif', 'image/webp'];
// 如果不符合要求的格式,弹窗提示并重置
if (!validTypes.includes(file.type)) {
alert('不支持的图片格式,请选择 JPG、PNG、GIF 或 WEBP 格式的图片');
e.target.value = '';
if (newPreview) {
newPreview.style.display = 'none';
}
return;
}
// 如果图片超过5MB,弹窗提示并重置
if (file.size > 5 * 1024 * 1024) {
alert('图片文件大小超过 5MB 限制');
e.target.value = '';
if (newPreview) {
newPreview.style.display = 'none';
}
return;
}
// 文件读取器,用于预览图片
const reader = new FileReader();
// 文件读取完成后显示预览
reader.onload = function(event) {
if (newPreviewImg) {
newPreviewImg.src = event.target.result;
}
if (newPreview) {
newPreview.style.display = 'block';
}
// 选择新图片时自动取消删除封面的选项
if (deleteCheckbox) {
deleteCheckbox.checked = false;
}
};
// 读取图片失败时弹窗提示
reader.onerror = function() {
alert('读取图片文件失败,请重试');
e.target.value = '';
if (newPreview) {
newPreview.style.display = 'none';
}
};
// 以DataURL形式读取以便预览
reader.readAsDataURL(file);
} else {
// 未选择文件则隐藏新封面预览
if (newPreview) {
newPreview.style.display = 'none';
}
}
});
}
// 监听删除封面复选框的变化
document.getElementById('editDeleteCover')?.addEventListener('change', function(e) {
// 获取封面图片上传input
const fileInput = document.getElementById('editCoverImageInput');
// 获取新封面预览div
const newPreview = document.getElementById('editCoverNewPreview');
// 如果选中“删除封面”,则清空图片和隐藏新封面预览
if (e.target.checked) {
fileInput.value = ''; // 清空文件选择
newPreview.style.display = 'none';
}
});
// 异步函数,用于更新知识库
async function updateKb(event) {
// 阻止表单默认提交
event.preventDefault();
// 获取表单对象
const form = event.target;
// 获取表单数据
const formData = new FormData(form);
// 从formData中获取知识库ID
const kbId = formData.get('kb_id');
try {
// 发送PUT请求到后端API更新知识库
const response = await fetch(`/api/v1/kb/${kbId}`, {
method: 'PUT',
body: formData
});
// 如果更新成功,刷新页面
if (response.ok) {
location.reload();
} else {
// 否则弹窗显示错误
const error = await response.json();
alert('更新失败: ' + error.message);
}
} catch (error) {
// 捕获异常弹窗显示
alert('更新失败: ' + error.message);
}
}
// 获取封面图片文件输入框并监听change事件(当用户选择文件时触发)
document.getElementById('coverImageInput')?.addEventListener('change', function(e) {
// 获取用户选中的第一个文件
const file = e.target.files[0];
// 获取用于显示预览的容器
const preview = document.getElementById('coverImagePreview');
// 获取显示图片的img标签
const previewImg = document.getElementById('coverPreviewImg');
// 如果用户选择了文件
if (file) {
// 创建文件读取器
const reader = new FileReader();
// 文件读取完成后回调
reader.onload = function(e) {
// 将img标签的src设置为读取得到的图片数据
previewImg.src = e.target.result;
// 显示预览容器
preview.style.display = 'block';
};
// 以DataURL的形式读取图片文件
reader.readAsDataURL(file);
} else {
// 如果没有选择文件,则隐藏预览
preview.style.display = 'none';
}
});
// 搜索和排序功能
function updateSearch() {
document.getElementById('searchForm').submit();
}
function clearSearch() {
document.getElementById('searchInput').value = '';
// 移除搜索参数,保留排序参数
const url = new URL(window.location.href);
url.searchParams.delete('search');
url.searchParams.set('page', '1');
window.location.href = url.toString();
}
// 搜索框回车提交
document.getElementById('searchInput')?.addEventListener('keypress', function(e) {
if (e.key === 'Enter') {
e.preventDefault();
updateSearch();
}
});
</script>
{% endblock %}4.上传文档 #
4.1. document.py #
app/blueprints/document.py
"""
文档相关路由(视图 + API)
"""
# 从 Flask 导入 Blueprint 和 request,用于定义蓝图和处理请求
from flask import Blueprint,request
# 导入 os 模块,用于文件名后缀等操作
import os
# 导入 logging 模块,用于日志记录
import logging
# 从工具模块导入通用的响应和错误处理函数
from app.blueprints.utils import success_response, error_response, handle_api_error
# 导入文档服务,用于处理文档业务逻辑
from app.services.document_service import document_service
# 导入配置文件,获取相关配置参数
from app.config import Config
# 设置日志对象
logger = logging.getLogger(__name__)
# 创建一个名为 'document' 的蓝图
bp = Blueprint('document', __name__)
# 定义一个检查文件扩展名是否合法的辅助函数
def allowed_file(filename):
# 检查文件名中有小数点并且扩展名在允许的扩展名列表中
return '.' in filename and \
os.path.splitext(filename)[1][1:].lower() in Config.ALLOWED_EXTENSIONS
# 定义上传文档的 API 路由,POST 方法
@bp.route('/api/v1/knowledgebases/<kb_id>/documents', methods=['POST'])
# 使用装饰器统一捕获和处理 API 错误
@handle_api_error
def api_upload(kb_id):
# 上传文档接口
"""上传文档"""
# 如果请求中没有 'file' 字段,返回参数错误
if 'file' not in request.files:
return error_response("No file part", 400)
# 从请求中获取上传的文件对象
file = request.files['file']
# 如果用户未选择文件(文件名为空),返回错误
if file.filename == '':
return error_response("No file selected", 400)
# 如果文件类型不被允许(扩展名校验),返回错误
if not allowed_file(file.filename):
return error_response(f"File type not allowed. Allowed: {', '.join(Config.ALLOWED_EXTENSIONS)}", 400)
# 读取上传的文件内容
file_data = file.read()
# 检查文件大小是否超过上限
if len(file_data) > Config.MAX_FILE_SIZE:
return error_response(f"File size exceeds maximum {Config.MAX_FILE_SIZE} bytes", 400)
# 获取前端自定义的文件名(可选)
custom_name = request.form.get('name')
# 如果指定了自定义文件名
if custom_name:
# 获取原始文件扩展名
original_ext = os.path.splitext(file.filename)[1]
# 如果自定义文件名没有扩展名,自动补上原扩展名
if not os.path.splitext(custom_name)[1] and original_ext:
filename = custom_name + original_ext
else:
filename = custom_name
# 否则使用原始文件名
else:
filename = file.filename
# 校验最终得到的文件名:为空或只包含空白则报错
if not filename or not filename.strip():
return error_response("Filename is required", 400)
# 校验最终文件名是否包含扩展名
if '.' not in filename:
return error_response("Filename must have an extension", 400)
# 调用文档服务上传,返回文档信息字典
doc_dict = document_service.upload(kb_id, file_data, filename)
# 返回成功响应及新文档信息
return success_response(doc_dict)
4.2. document_service.py #
app/services/document_service.py
# 导入os模块,用于处理文件和路径
import os
# 导入uuid模块,用于生成唯一ID
import uuid
# 导入BaseService基类
from app.services.base_service import BaseService
# 导入Document模型,重命名为DocumentModel
from app.models.document import Document as DocumentModel
# 导入Knowledgebase知识库模型
from app.models.knowledgebase import Knowledgebase
# 导入存储服务
from app.services.storage_service import storage_service
# 导入配置项
from app.config import Config
# 定义DocumentService服务类,继承自BaseService
class DocumentService(BaseService[DocumentModel]):
"""文档服务"""
# 上传文档方法
def upload(self, kb_id: str, file_data: bytes, filename: str) -> dict:
"""
上传文档
参数:
kb_id: 知识库ID
file_data: 文件数据
filename: 文件名
返回:
创建的文档字典
"""
# 初始化变量,标识文件是否已经上传
file_uploaded = False
# 初始化文件路径
file_path = None
try:
# 使用数据库会话,判断知识库是否存在
with self.session() as session:
kb = session.query(Knowledgebase).filter(Knowledgebase.id == kb_id).first()
# 如果知识库不存在,抛出异常
if not kb:
raise ValueError(f"知识库 {kb_id} 不存在")
# 检查文件名是否为空或者没有扩展名
if not filename or '.' not in filename:
raise ValueError(f"文件名必须包含扩展名: {filename}")
# 获取文件扩展名
file_ext = os.path.splitext(filename)[1]
# 如果有后缀,去掉点号并转为小写
if file_ext:
file_ext = file_ext[1:].lower()
else:
# 如果没有文件后缀,抛出异常
raise ValueError(f"文件名必须包含扩展名: {filename}")
# 检查文件类型是否合法
if not file_ext or file_ext not in Config.ALLOWED_EXTENSIONS:
raise ValueError(f"不支持的文件类型: '{file_ext}'。允许类型: {', '.join(Config.ALLOWED_EXTENSIONS)}")
# 生成文档ID,用于标识唯一文档
doc_id = uuid.uuid4().hex[:32]
# 构建文件存储路径,便于后续文件操作
file_path = f"documents/{kb_id}/{doc_id}/{filename}"
# 优先将文件上传到本地/云存储,保证文件存在再创建记录
try:
storage_service.upload_file(file_path, file_data)
file_uploaded = True
except Exception as storage_error:
# 上传存储失败时写入日志并抛出异常
self.logger.error(f"上传文件到存储时发生错误: {storage_error}")
raise ValueError(f"文件上传失败: {str(storage_error)}")
# 在数据库中创建文档记录
with self.transaction() as session:
doc = DocumentModel(
id=doc_id,
kb_id=kb_id,
name=filename,
file_path=file_path,
file_type=file_ext,
file_size=len(file_data),
status='pending'
)
# 添加文档记录到会话
session.add(doc)
# Flush 保证数据同步到数据库
session.flush()
# 刷新文档对象,获得新写入的数据
session.refresh(doc)
# 对象转dict,避免对象分离后后续属性访问失败
doc_dict = doc.to_dict()
# 日志记录上传成功
self.logger.info(f"文档上传成功: {doc.id}")
# 返回新创建的文档字典信息
return doc_dict
except Exception as e:
# 异常处理,如果文件已上传但事务失败,则尝试删除已上传的文件
if file_uploaded and file_path:
try:
storage_service.delete_file(file_path)
except Exception as delete_error:
self.logger.warning(f"删除已上传文件时出错: {delete_error}")
# 重新抛出异常
raise
# 实例化DocumentService
document_service = DocumentService()4.3. init.py #
app/init.py
# RAG Lite 应用模块说明
"""
RAG Lite Application
"""
# 导入操作系统相关模块
import os
# 从 Flask 包导入 Flask 应用对象
from flask import Flask
# 导入 Flask 跨域资源共享支持
from flask_cors import CORS
# 导入应用配置类
from app.config import Config
# 导入日志工具,用于获取日志记录器
from app.utils.logger import get_logger
# 导入数据库初始化函数
from app.utils.db import init_db
# 导入蓝图模块
+from app.blueprints import auth,knowledgebase,settings,document
# 导入获取当前用户信息函数
from app.utils.auth import get_current_user
# 定义创建 Flask 应用的工厂函数
def create_app(config_class=Config):
# 获取日志记录器,名称为当前模块名
logger = get_logger(__name__)
# 尝试初始化数据库
try:
# 输出日志,表示即将初始化数据库
logger.info("初始化数据库...")
# 执行数据库初始化函数
init_db()
# 输出日志,表示数据库初始化成功
logger.info("数据库初始化成功")
# 捕获任意异常
except Exception as e:
# 输出警告日志,提示数据库初始化失败,并输出异常信息
logger.warning(f"数据库初始化失败: {e}")
# 输出警告日志,提示检查数据库是否已存在,并建议手动创建数据表
logger.warning("请确认数据库已存在,或手动创建数据表")
# 创建 Flask 应用对象,并指定模板和静态文件目录
base_dir = os.path.abspath(os.path.dirname(__file__))
# 创建 Flask 应用对象,并指定模板和静态文件目录
app = Flask(
__name__,
# 指定模板文件目录
template_folder=os.path.join(base_dir, 'templates'),
# 指定静态文件目录
static_folder=os.path.join(base_dir, 'static')
)
# 从给定配置类加载配置信息到应用
app.config.from_object(config_class)
# 启用跨域请求支持
CORS(app)
# 记录应用创建日志信息
logger.info("Flask 应用已创建")
# 注册上下文处理器,使 current_user 在所有模板中可用
@app.context_processor
def inject_user():
# 返回当前用户信息字典
# 使用 get_current_user 获取当前用户信息,并将其添加到上下文字典中
# 这样在模板中可以直接使用 current_user 变量
return dict(current_user=get_current_user())
# 注册蓝图
app.register_blueprint(auth.bp)
# 注册知识库蓝图
app.register_blueprint(knowledgebase.bp)
# 注册设置蓝图
app.register_blueprint(settings.bp)
# 注册文档蓝图
+ app.register_blueprint(document.bp)
# 定义首页路由
@app.route('/')
def index():
return "Hello, World!"
# 返回已配置的 Flask 应用对象
return app
4.4. kb_detail.html #
app/templates/kb_detail.html
{% extends "base.html" %}
{% block title %}{{ kb.name }} - RAG Lite{% endblock %}
{% block content %}
<div class="row">
<div class="col-12">
<nav aria-label="breadcrumb">
<ol class="breadcrumb">
<li class="breadcrumb-item"><a href="/">首页</a></li>
<li class="breadcrumb-item"><a href="/kb">知识库</a></li>
<li class="breadcrumb-item active">{{ kb.name }}</li>
</ol>
</nav>
</div>
</div>
<div class="row">
<div class="col-12">
<div class="card">
<div class="card-header d-flex justify-content-between align-items-center">
<h5 class="mb-0"><i class="bi bi-file-earmark"></i> 文档管理</h5>
<button class="btn btn-sm btn-primary" data-bs-toggle="modal" data-bs-target="#uploadModal">
<i class="bi bi-upload"></i> 上传文档
</button>
</div>
<div class="card-body">
</div>
</div>
</div>
</div>
+<!-- 上传文档模态框 -->
+<div class="modal fade" id="uploadModal" tabindex="-1">
+ <div class="modal-dialog">
+ <div class="modal-content">
+ <div class="modal-header">
+ <h5 class="modal-title">上传文档</h5>
+ <button type="button" class="btn-close" data-bs-dismiss="modal"></button>
+ </div>
+ <form id="uploadForm" onsubmit="uploadDocument(event)" enctype="multipart/form-data">
+ <div class="modal-body">
+ <div class="mb-3">
+ <label class="form-label">选择文件</label>
+ <input type="file" class="form-control" name="file" accept=".pdf,.docx,.txt,.md" required>
+ <div class="form-text">支持 PDF、DOCX、TXT、MD 格式</div>
+ </div>
+ <div class="mb-3">
+ <label class="form-label">文档名称(可选)</label>
+ <input type="text" class="form-control" name="name" placeholder="留空使用文件名">
+ </div>
+ <div id="uploadProgress" class="progress d-none">
+ <div class="progress-bar progress-bar-striped progress-bar-animated"
+ role="progressbar" style="width: 100%"></div>
+ </div>
+ </div>
+ <div class="modal-footer">
+ <button type="button" class="btn btn-secondary" data-bs-dismiss="modal">取消</button>
+ <button type="submit" class="btn btn-primary">上传</button>
+ </div>
+ </form>
+ </div>
+ </div>
+</div>
{% endblock %}
{% block extra_js %}
<script>
+// 定义当前知识库的ID
+const kbId = '{{ kb.id }}';
+// 上传文档的异步函数
+async function uploadDocument(event) {
+ // 阻止表单的默认提交行为
+ event.preventDefault();
+ // 获取表单元素
+ const form = event.target;
+ // 创建包含表单数据的 FormData 对象
+ const formData = new FormData(form);
+ // 获取上传进度的进度条元素
+ const progressDiv = document.getElementById('uploadProgress');
+ // 获取提交按钮
+ const submitBtn = form.querySelector('button[type="submit"]');
+ // 显示进度条
+ progressDiv.classList.remove('d-none');
+ // 禁用提交按钮,防止重复提交
+ submitBtn.disabled = true;
+ try {
+ // 向后端接口发送 POST 请求上传文档
+ const response = await fetch(`/api/v1/knowledgebases/${kbId}/documents`, {
+ method: 'POST',
+ body: formData
+ });
+ // 如果响应成功
+ if (response.ok) {
+ // 解析返回的 JSON 数据
+ const result = await response.json();
+ // 获取新文档的 ID
+ const docId = result.data.id;
+ // 关闭上传模态框
+ bootstrap.Modal.getInstance(document.getElementById('uploadModal')).hide();
+ // 重置表单
+ form.reset();
+ // 刷新页面以便显示新上传的文档
+ location.reload();
+ } else {
+ // 如果上传失败,解析错误信息并显示提示
+ const error = await response.json();
+ alert('上传失败: ' + error.message);
+ }
+ } catch (error) {
+ // 捕获代码块内发生的异常并提示失败信息
+ alert('上传失败: ' + error.message);
+ } finally {
+ // 上传完成或失败后隐藏进度条,恢复提交按钮
+ progressDiv.classList.add('d-none');
+ submitBtn.disabled = false;
+ }
+}
</script>
{% endblock %}5.文档列表 #
5.1. knowledgebase.py #
app/blueprints/knowledgebase.py
# 知识库相关路由(视图 + API)
"""
知识库相关路由(视图 + API)
"""
# 导入Flask中的Blueprint和request
from flask import Blueprint,request,render_template,send_file,abort,redirect, url_for
# 使用BytesIO将图片数据包装为文件流
from io import BytesIO
# 导入logging模块
import logging
# 导入mimetypes、os模块用于类型判断
import mimetypes
# 导入os模块用于路径操作
import os
# 导入自定义工具函数:异常处理装饰器、错误响应、成功响应
from app.blueprints.utils import (handle_api_error,error_response,success_response,get_current_user_or_error)
# 导入知识库服务
from app.services.knowledgebase_service import kb_service
# 导入认证工具函数:登录认证装饰器、获取当前用户、API登录认证装饰器
from app.utils.auth import login_required, get_current_user,api_login_required
# 导入分页工具函数
from app.blueprints.utils import (get_pagination_params,check_ownership)
# 导入存储服务
from app.services.storage_service import storage_service
# 导入文档服务
+from app.services.document_service import document_service
# 配置logger
logger = logging.getLogger(__name__)
# 创建Blueprint实例,注册在Flask应用下
bp = Blueprint('knowledgebase', __name__)
# 定义路由:POST请求到/api/v1/kb
@bp.route('/api/v1/kb', methods=['POST'])
# 应用API登录认证装饰器
@api_login_required
# 应用自定义异常处理装饰器
@handle_api_error
# 定义创建知识库的视图函数
# 定义用于创建知识库的API接口
def api_create():
# 接口用途说明文档字符串
"""创建知识库"""
# 获取当前用户,如未登录则返回错误响应
current_user, err = get_current_user_or_error()
if err:
return err
# 检查请求是否为multipart/form-data(用于文件上传的表单方式)
if request.content_type and 'multipart/form-data' in request.content_type:
# 从表单数据中获取知识库名称
name = request.form.get('name')
# 如果未传入name参数,返回错误
if not name:
return error_response("name is required", 400)
# 获取描述字段,没有则为None
description = request.form.get('description') or None
# 获取分块大小,默认为512
chunk_size = int(request.form.get('chunk_size', 512))
# 获取分块重叠,默认为50
chunk_overlap = int(request.form.get('chunk_overlap', 50))
# 设置封面图片数据变量初值为None
cover_image_data = None
# 设置封面图片文件名变量初值为None
cover_image_filename = None
# 判断请求中是否包含'cover_image'文件
if 'cover_image' in request.files:
# 获取上传的封面图片文件对象
cover_file = request.files['cover_image']
# 如果上传的文件存在且有文件名
if cover_file and cover_file.filename:
# 读取文件内容为二进制数据
cover_image_data = cover_file.read()
# 获取上传文件的文件名
cover_image_filename = cover_file.filename
# 记录封面图片上传的信息到日志,包括文件名、字节大小和内容类型
logger.info(f"收到新知识库的封面图片上传: 文件名={cover_image_filename}, 大小={len(cover_image_data)} 字节, 内容类型={cover_file.content_type}")
else:
# 如果是json请求数据(向后兼容旧用法)
data = request.get_json()
# 判断是否存在name字段,不存在则报错
if not data or 'name' not in data:
return error_response("name is required", 400)
# 获取知识库名称
name = data['name']
# 获取描述
description = data.get('description')
# 获取分块大小,默认为512
chunk_size = data.get('chunk_size', 512)
# 获取分块重叠,默认为50
chunk_overlap = data.get('chunk_overlap', 50)
# 设置封面图片数据变量初值为None
cover_image_data = None
# 设置封面图片文件名变量初值为None
cover_image_filename = None
# 调用知识库服务,创建知识库,返回知识库信息字典
kb_dict = kb_service.create(
name=name, # 知识库名称
user_id=current_user['id'], # 用户ID
description=description, # 知识库描述
chunk_size=chunk_size, # 分块大小
chunk_overlap=chunk_overlap, # 分块重叠
cover_image_data=cover_image_data, # 封面图片数据
cover_image_filename=cover_image_filename # 封面图片文件名
)
# 返回成功响应,包含知识库信息
return success_response(kb_dict)
# 注册'/kb'路由,处理GET请求,显示知识库列表页面
@bp.route('/kb')
# 要求登录用户才能访问该视图,用于Web页面
@login_required
# 定义kb_list函数,渲染知识库列表页面
def kb_list():
# 设置本函数用途说明(文档字符串)
"""知识库列表页面"""
# 获取当前登录用户信息
current_user = get_current_user()
# 获取分页参数(页码和每页大小),最大每页100
page, page_size = get_pagination_params(max_page_size=100)
# 获取搜索和排序参数
search = request.args.get('search', '').strip() or None
sort_by = request.args.get('sort_by', 'created_at')
sort_order = request.args.get('sort_order', 'desc')
# 验证排序参数
if sort_by not in ['created_at', 'name', 'updated_at']:
sort_by = 'created_at'
if sort_order not in ['asc', 'desc']:
sort_order = 'desc'
# 调用知识库服务,获取分页后的知识库列表结果
result = kb_service.list(
user_id=current_user['id'], # 用户ID
page=page, # 页码
page_size=page_size, # 每页大小
search=search, # 搜索关键词
sort_by=sort_by, # 排序字段
sort_order=sort_order # 排序方向
)
# 渲染知识库列表页面模板,传递数据,包括知识库列表、分页信息
return render_template('kb_list.html',
kbs=result['items'],
pagination=result,
search=search or '',
sort_by=sort_by,
sort_order=sort_order)
# 注册DELETE方法的API路由,用于删除知识库
@bp.route('/api/v1/kb/<kb_id>', methods=['DELETE'])
# 要求API登录
@api_login_required
# 处理API错误的装饰器
@handle_api_error
def api_delete(kb_id):
"""删除知识库"""
# 获取当前用户信息,如果未登录则返回错误
current_user, err = get_current_user_or_error()
if err:
return err
# 根据知识库ID获取知识库信息
kb_dict = kb_service.get_by_id(kb_id)
# 如果知识库不存在,返回404错误
if not kb_dict:
return error_response("未找到知识库", 404)
# 验证当前用户是否拥有该知识库的操作权限
has_permission, err = check_ownership(kb_dict['user_id'], current_user['id'], "knowledgebase")
if not has_permission:
return err
# 调用服务删除知识库
success = kb_service.delete(kb_id)
# 如果删除失败,返回404错误
if not success:
return error_response("未找到知识库", 404)
# 返回删除成功的响应
return success_response("知识库删除成功")
# 注册PUT方法的API路由(用于更新知识库)
@bp.route('/api/v1/kb/<kb_id>', methods=['PUT'])
# 要求API登录
@api_login_required
# 捕获API内部错误的装饰器
@handle_api_error
def api_update(kb_id):
# 定义API用于更新知识库(含封面图片)
"""更新知识库(支持封面图片更新)"""
# 获取当前登录用户信息,如果未登录则返回错误响应
current_user, err = get_current_user_or_error()
if err:
return err
# 获取指定ID的知识库记录,验证其是否存在
kb_dict = kb_service.get_by_id(kb_id)
if not kb_dict:
return error_response("未找到知识库", 404)
# 校验当前用户是否有操作该知识库的权限
has_permission, err = check_ownership(kb_dict['user_id'], current_user['id'], "knowledgebase")
if not has_permission:
return err
# 判断请求内容类型是否为multipart/form-data(一般用于带文件上传的表单提交)
if request.content_type and 'multipart/form-data' in request.content_type:
# 从表单中获取普通字段
name = request.form.get('name')
description = request.form.get('description') or None
chunk_size = request.form.get('chunk_size')
chunk_overlap = request.form.get('chunk_overlap')
# 初始化封面图片相关变量
cover_image_data = None
cover_image_filename = None
# 获得delete_cover字段(类型字符串,需判断是否为'true')
delete_cover = request.form.get('delete_cover') == 'true'
# 如果有上传封面图片,则读取文件内容
if 'cover_image' in request.files:
cover_file = request.files['cover_image']
if cover_file and cover_file.filename:
cover_image_data = cover_file.read()
cover_image_filename = cover_file.filename
# 记录上传日志
logger.info(f"收到知识库 {kb_id} 的封面图片上传: 文件名={cover_image_filename}, 大小={len(cover_image_data)} 字节, 内容类型={cover_file.content_type}")
# 构建待更新的数据
update_data = {}
if name:
update_data['name'] = name
if description is not None:
update_data['description'] = description
if chunk_size:
update_data['chunk_size'] = int(chunk_size)
if chunk_overlap:
update_data['chunk_overlap'] = int(chunk_overlap)
else:
# 非表单上传,则按JSON结构解析请求内容
data = request.get_json()
# 如果请求体是空的,直接返回错误
if not data:
return error_response("请求体不能为空", 400)
# 构建可更新的数据字典
update_data = {}
if 'name' in data:
update_data['name'] = data['name']
if 'description' in data:
update_data['description'] = data.get('description')
if 'chunk_size' in data:
update_data['chunk_size'] = data['chunk_size']
if 'chunk_overlap' in data:
update_data['chunk_overlap'] = data['chunk_overlap']
# JSON请求时,cover_image相关变量置空
cover_image_data = None
cover_image_filename = None
delete_cover = data.get('delete_cover', False)
# 调用服务更新知识库,传入各字段及封面参数
updated_kb = kb_service.update(
kb_id=kb_id, # 知识库ID
cover_image_data=cover_image_data, # 封面图片的二进制内容
cover_image_filename=cover_image_filename, # 封面图片文件名
delete_cover=delete_cover, # 是否删除封面图片
**update_data # 其它可变字段
)
# 更新后如果找不到,返回404
if not updated_kb:
return error_response("未找到知识库", 404)
# 更新成功后,将最新的知识库数据返回给前端
return success_response(updated_kb, "知识库更新成功")
# 定义路由,获取指定知识库ID的封面图片,仅限登录用户访问
@bp.route('/kb/<kb_id>/cover')
@login_required
def kb_cover(kb_id):
"""获取知识库封面图片"""
# 获取当前已登录用户的信息
current_user = get_current_user()
# 根据知识库ID从知识库服务获取对应的知识库信息
kb = kb_service.get_by_id(kb_id)
# 检查知识库是否存在
if not kb:
# 如果知识库不存在,记录警告日志
logger.warning(f"知识库不存在: {kb_id}")
abort(404)
# 检查是否有权限访问(只能查看自己的知识库封面)
if kb.get('user_id') != current_user['id']:
# 如果不是当前用户的知识库,记录警告日志
logger.warning(f"用户 {current_user['id']} 尝试访问知识库 {kb_id} 的封面,但该知识库属于用户 {kb.get('user_id')}")
abort(403)
# 获取知识库的封面图片路径
cover_path = kb.get('cover_image')
# 检查是否有封面图片
if not cover_path:
# 如果没有封面,记录调试日志
logger.debug(f"知识库 {kb_id} 没有封面图片")
abort(404)
try:
# 通过存储服务下载封面图片数据
image_data = storage_service.download_file(cover_path)
# 如果未能获取到图片数据,记录错误日志并返回404
if not image_data:
logger.error(f"从路径下载封面图片失败: {cover_path}")
abort(404)
# 根据文件扩展名判断图片MIME类型
file_ext = os.path.splitext(cover_path)[1].lower()
# 自定义映射,优先根据文件扩展名判断图片MIME类型
mime_type_map = {
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.png': 'image/png',
'.gif': 'image/gif',
'.webp': 'image/webp'
}
# 优先根据自定义映射获取MIME类型
mime_type = mime_type_map.get(file_ext)
if not mime_type:
# 如果没有命中自定义映射,则使用mimetypes猜测类型
mime_type, _ = mimetypes.guess_type(cover_path)
if not mime_type:
# 如果还未识别出类型,则默认用JPEG
mime_type = 'image/jpeg'
# 通过send_file响应图片数据和MIME类型,不以附件形式发送
return send_file(
BytesIO(image_data),#图片数据
mimetype=mime_type,#MIME类型
as_attachment=False#不以附件形式发送
)
except FileNotFoundError as e:
# 捕获文件未找到异常,记录错误日志
logger.error(f"封面图片文件未找到: {cover_path}, 错误: {e}")
abort(404)
except Exception as e:
# 捕获其他未预期异常,记录错误日志(包含堆栈信息)
logger.error(f"提供知识库 {kb_id} 的封面图片时出错, 路径: {cover_path}, 错误: {e}", exc_info=True)
abort(404)
# 设置路由,访问 /kb/<kb_id> 时触发此视图函数
@bp.route('/kb/<kb_id>')
# 要求用户登录后才能访问该视图
@login_required
# 定义知识库详情页视图函数,接收 kb_id 作为参数
def kb_detail(kb_id):
# 知识库详情页面(文档管理)
"""知识库详情页面(文档管理)"""
# 根据 kb_id 查询知识库信息
kb = kb_service.get_by_id(kb_id)
# 如果未找到知识库,则重定向回知识库列表页面
if not kb:
return redirect(url_for('knowledgebase.kb_list'))
# 获取分页参数,最大每页100条
+ page, page_size = get_pagination_params(max_page_size=100)
# 获取指定知识库下的文档列表,带分页
+ result = document_service.list_by_kb(kb_id, page=page, page_size=page_size)
# 渲染知识库详情模板,将知识库对象、文档列表和分页信息传递给模板
return render_template('kb_detail.html', kb=kb,
+ documents=result['items'],
+ pagination=result)5.2. base_service.py #
app/services/base_service.py
# 基础服务类
"""
基础服务类
"""
# 导入日志库
import logging
# 导入可选类型、泛型、类型变量和类型别名
from typing import Optional, TypeVar, Generic, Dict, Any
# 导入数据库会话和事务管理工具
from app.utils.db import db_session, db_transaction
# 创建日志记录器
logger = logging.getLogger(__name__)
# 定义泛型的类型变量T
T = TypeVar('T')
# 定义基础服务类,支持泛型
class BaseService(Generic[T]):
# 基础服务类,提供通用的数据库操作方法
# 初始化方法
def __init__(self):
# 初始化服务的日志记录器
self.logger = logging.getLogger(self.__class__.__name__)
# 数据库会话上下文管理器(只读)
def session(self):
"""
数据库会话上下文管理器(只读操作,不自动提交)
使用示例:
with self.session() as db:
result = db.query(Model).all()
# 不需要手动关闭 session
"""
# 返回数据库会话
return db_session()
# 数据库事务上下文管理器(自动提交)
def transaction(self):
"""
数据库事务上下文管理器(自动提交,出错时回滚)
使用示例:
with self.transaction() as db:
obj = Model(...)
db.add(obj)
# 自动提交,出错时自动回滚
"""
# 返回数据库事务
return db_transaction()
def get_by_id(self, model_class: type, entity_id: str) -> Optional[T]:
"""
根据ID获取实体(通用方法)
Args:
model_class: 模型类
entity_id: 实体ID
Returns:
实体对象,如果不存在则返回 None
"""
with self.session() as session:
try:
return session.query(model_class).filter(model_class.id == entity_id).first()
except Exception as e:
self.logger.error(f"Error getting {model_class.__name__} by id {entity_id}: {e}")
return None
# 定义通用的分页查询方法
+ def paginate_query(self, query, page: int = 1, page_size: int = 10,
+ order_by=None) -> Dict[str, Any]:
+ """
+ 通用分页查询方法
+ Args:
+ query: SQLAlchemy 查询对象(必须在 session 上下文中调用)
+ page: 页码
+ page_size: 每页数量
+ order_by: 排序字段(可选,SQLAlchemy 表达式)
+ Returns:
+ 包含 items, total, page, page_size 的字典
+ """
# 判断是否传入排序字段(注意不能直接 if order_by,否则部分 SQLAlchemy 表达式会抛异常)
+ if order_by is not None:
# 如传入排序条件则按该条件排序
+ query = query.order_by(order_by)
# 获取查询结果的总条数
+ total = query.count()
# 计算偏移量,例如第2页 offset=10 (从第11条开始)
+ offset = (page - 1) * page_size
# 查询当前页的数据
+ items = query.offset(offset).limit(page_size).all()
# 返回结果,items 为对象列表(支持自动 to_dict 转换),同时返回 total, page, page_size
+ return {
+ 'items': [item.to_dict() if hasattr(item, 'to_dict') else item for item in items],
+ 'total': total,
+ 'page': page,
+ 'page_size': page_size
+ }
5.3. document_service.py #
app/services/document_service.py
# 导入os模块,用于处理文件和路径
import os
# 导入uuid模块,用于生成唯一ID
import uuid
# 导入类型提示
+from typing import List, Optional, Dict
# 导入BaseService基类
from app.services.base_service import BaseService
# 导入Document模型,重命名为DocumentModel
from app.models.document import Document as DocumentModel
# 导入Knowledgebase知识库模型
from app.models.knowledgebase import Knowledgebase
# 导入存储服务
from app.services.storage_service import storage_service
# 导入配置项
from app.config import Config
# 定义DocumentService服务类,继承自BaseService
class DocumentService(BaseService[DocumentModel]):
"""文档服务"""
# 上传文档方法
def upload(self, kb_id: str, file_data: bytes, filename: str) -> dict:
"""
上传文档
参数:
kb_id: 知识库ID
file_data: 文件数据
filename: 文件名
返回:
创建的文档字典
"""
# 初始化变量,标识文件是否已经上传
file_uploaded = False
# 初始化文件路径
file_path = None
try:
# 使用数据库会话,判断知识库是否存在
with self.session() as session:
kb = session.query(Knowledgebase).filter(Knowledgebase.id == kb_id).first()
# 如果知识库不存在,抛出异常
if not kb:
raise ValueError(f"知识库 {kb_id} 不存在")
# 检查文件名是否为空或者没有扩展名
if not filename or '.' not in filename:
raise ValueError(f"文件名必须包含扩展名: {filename}")
# 获取文件扩展名
file_ext = os.path.splitext(filename)[1]
# 如果有后缀,去掉点号并转为小写
if file_ext:
file_ext = file_ext[1:].lower()
else:
# 如果没有文件后缀,抛出异常
raise ValueError(f"文件名必须包含扩展名: {filename}")
# 检查文件类型是否合法
if not file_ext or file_ext not in Config.ALLOWED_EXTENSIONS:
raise ValueError(f"不支持的文件类型: '{file_ext}'。允许类型: {', '.join(Config.ALLOWED_EXTENSIONS)}")
# 生成文档ID,用于标识唯一文档
doc_id = uuid.uuid4().hex[:32]
# 构建文件存储路径,便于后续文件操作
file_path = f"documents/{kb_id}/{doc_id}/{filename}"
# 优先将文件上传到本地/云存储,保证文件存在再创建记录
try:
storage_service.upload_file(file_path, file_data)
file_uploaded = True
except Exception as storage_error:
# 上传存储失败时写入日志并抛出异常
self.logger.error(f"上传文件到存储时发生错误: {storage_error}")
raise ValueError(f"文件上传失败: {str(storage_error)}")
# 在数据库中创建文档记录
with self.transaction() as session:
doc = DocumentModel(
id=doc_id,
kb_id=kb_id,
name=filename,
file_path=file_path,
file_type=file_ext,
file_size=len(file_data),
status='pending'
)
# 添加文档记录到会话
session.add(doc)
# Flush 保证数据同步到数据库
session.flush()
# 刷新文档对象,获得新写入的数据
session.refresh(doc)
# 对象转dict,避免对象分离后后续属性访问失败
doc_dict = doc.to_dict()
# 日志记录上传成功
self.logger.info(f"文档上传成功: {doc.id}")
# 返回新创建的文档字典信息
return doc_dict
except Exception as e:
# 异常处理,如果文件已上传但事务失败,则尝试删除已上传的文件
if file_uploaded and file_path:
try:
storage_service.delete_file(file_path)
except Exception as delete_error:
self.logger.warning(f"删除已上传文件时出错: {delete_error}")
# 重新抛出异常
raise
# 定义根据知识库ID获取文档列表的方法
+ def list_by_kb(self, kb_id: str, page: int = 1, page_size: int = 10,
+ status: Optional[str] = None) -> Dict:
# 方法文档,说明该方法用于获取知识库的文档列表
+ """获取知识库的文档列表"""
# 创建数据库会话
+ with self.session() as session:
# 查询指定知识库ID下的所有文档
+ query = session.query(DocumentModel).filter(DocumentModel.kb_id == kb_id)
# 如果设置了文档状态,则对状态进行过滤
+ if status:
+ query = query.filter(DocumentModel.status == status)
# 使用分页方法返回结果,按创建时间倒序排列
+ return self.paginate_query(query, page=page, page_size=page_size,
+ order_by=DocumentModel.created_at.desc())
# 实例化DocumentService
document_service = DocumentService()5.4. kb_detail.html #
app/templates/kb_detail.html
{% extends "base.html" %}
{% block title %}{{ kb.name }} - RAG Lite{% endblock %}
{% block content %}
<div class="row">
<div class="col-12">
<nav aria-label="breadcrumb">
<ol class="breadcrumb">
<li class="breadcrumb-item"><a href="/">首页</a></li>
<li class="breadcrumb-item"><a href="/kb">知识库</a></li>
<li class="breadcrumb-item active">{{ kb.name }}</li>
</ol>
</nav>
</div>
</div>
<div class="row">
<div class="col-12">
<div class="card">
<div class="card-header d-flex justify-content-between align-items-center">
<h5 class="mb-0"><i class="bi bi-file-earmark"></i> 文档管理</h5>
<button class="btn btn-sm btn-primary" data-bs-toggle="modal" data-bs-target="#uploadModal">
<i class="bi bi-upload"></i> 上传文档
</button>
</div>
<div class="card-body">
+ <div id="docList">
+ {% if documents %}
+ <div class="table-responsive">
+ <table class="table table-hover">
+ <thead>
+ <tr>
+ <th>文档名称</th>
+ <th>状态</th>
+ <th>块数</th>
+ <th>文件大小</th>
+ <th>操作</th>
+ </tr>
+ </thead>
+ <tbody>
+ {% for doc in documents %}
+ <tr>
+ <td>
+ <i class="bi bi-file-earmark-{{ 'pdf' if doc.file_type == 'pdf' else 'word' if doc.file_type == 'docx' else 'text' }}"></i>
+ {{ doc.name }}
+ </td>
+ <td>
+ {% set status_map = {
+ 'completed': '已完成',
+ 'processing': '处理中',
+ 'failed': '失败',
+ 'pending': '待处理'
+ } %}
+ <span class="badge bg-{{ 'success' if doc.status == 'completed' else 'warning' if doc.status == 'processing' else 'danger' if doc.status == 'failed' else 'secondary' }}">
+ {{ status_map.get(doc.status, doc.status) }}
+ </span>
+ </td>
+ <td>{{ doc.chunk_count or 0 }}</td>
+ <td>{{ "%.2f"|format(doc.file_size / 1024) }} KB</td>
+ <td>
+ {% if doc.status == 'completed' %}
+ <a href="/documents/{{ doc.id }}/chunks" class="btn btn-sm btn-info me-1">
+ <i class="bi bi-list-ul"></i> 查看分块
+ </a>
+ {% endif %}
+ {% if doc.status == 'pending' %}
+ <button class="btn btn-sm btn-primary me-1" onclick="processDoc('{{ doc.id }}', '{{ doc.name }}')">
+ <i class="bi bi-play-circle"></i> 处理
+ </button>
+ {% elif doc.status in ['completed', 'failed'] %}
+ <button class="btn btn-sm btn-warning me-1" onclick="processDoc('{{ doc.id }}', '{{ doc.name }}')">
+ <i class="bi bi-arrow-clockwise"></i> 重新处理
+ </button>
+ {% elif doc.status == 'processing' %}
+ <button class="btn btn-sm btn-secondary me-1" disabled>
+ <i class="bi bi-hourglass-split"></i> 处理中
+ </button>
+ {% endif %}
+ <button class="btn btn-sm btn-danger" onclick="deleteDoc('{{ doc.id }}', '{{ doc.name }}')">
+ <i class="bi bi-trash"></i> 删除
+ </button>
+ </td>
+ </tr>
+ {% endfor %}
+ </tbody>
+ </table>
+ </div>
+ <!-- 分页控件 -->
+ {% if pagination and pagination.total > pagination.page_size %}
+ <nav aria-label="文档列表分页" class="mt-3">
+ <ul class="pagination justify-content-center">
+ {% set current_page = pagination.page %}
+ {% set total_pages = (pagination.total + pagination.page_size - 1) // pagination.page_size %}
+ <!-- 上一页 -->
+ <li class="page-item {% if current_page <= 1 %}disabled{% endif %}">
+ <a class="page-link" href="?page={{ current_page - 1 }}&page_size={{ pagination.page_size }}"
+ {% if current_page <= 1 %}tabindex="-1" aria-disabled="true"{% endif %}>
+ <i class="bi bi-chevron-left"></i> 上一页
+ </a>
+ </li>
+ <!-- 页码 -->
+ {% set start_page = [1, current_page - 2] | max %}
+ {% set end_page = [total_pages, current_page + 2] | min %}
+ {% if start_page > 1 %}
+ <li class="page-item">
+ <a class="page-link" href="?page=1&page_size={{ pagination.page_size }}">1</a>
+ </li>
+ {% if start_page > 2 %}
+ <li class="page-item disabled">
+ <span class="page-link">...</span>
+ </li>
+ {% endif %}
+ {% endif %}
+ {% for page_num in range(start_page, end_page + 1) %}
+ <li class="page-item {% if page_num == current_page %}active{% endif %}">
+ <a class="page-link" href="?page={{ page_num }}&page_size={{ pagination.page_size }}">
+ {{ page_num }}
+ </a>
+ </li>
+ {% endfor %}
+ {% if end_page < total_pages %}
+ {% if end_page < total_pages - 1 %}
+ <li class="page-item disabled">
+ <span class="page-link">...</span>
+ </li>
+ {% endif %}
+ <li class="page-item">
+ <a class="page-link" href="?page={{ total_pages }}&page_size={{ pagination.page_size }}">{{ total_pages }}</a>
+ </li>
+ {% endif %}
+ <!-- 下一页 -->
+ <li class="page-item {% if current_page >= total_pages %}disabled{% endif %}">
+ <a class="page-link" href="?page={{ current_page + 1 }}&page_size={{ pagination.page_size }}"
+ {% if current_page >= total_pages %}tabindex="-1" aria-disabled="true"{% endif %}>
+ 下一页 <i class="bi bi-chevron-right"></i>
+ </a>
+ </li>
+ </ul>
+ <div class="text-center text-muted small mt-2">
+ 共 {{ pagination.total }} 个文档,第 {{ current_page }} / {{ total_pages }} 页
+ </div>
+ </nav>
+ {% endif %}
+ {% else %}
+ <div class="text-center text-muted py-5">
+ <i class="bi bi-inbox" style="font-size: 3rem;"></i>
+ <p class="mt-2">还没有文档,点击上方按钮上传文档</p>
+ </div>
+ {% endif %}
+ </div>
</div>
</div>
</div>
</div>
<!-- 上传文档模态框 -->
<div class="modal fade" id="uploadModal" tabindex="-1">
<div class="modal-dialog">
<div class="modal-content">
<div class="modal-header">
<h5 class="modal-title">上传文档</h5>
<button type="button" class="btn-close" data-bs-dismiss="modal"></button>
</div>
<form id="uploadForm" onsubmit="uploadDocument(event)" enctype="multipart/form-data">
<div class="modal-body">
<div class="mb-3">
<label class="form-label">选择文件</label>
<input type="file" class="form-control" name="file" accept=".pdf,.docx,.txt,.md" required>
<div class="form-text">支持 PDF、DOCX、TXT、MD 格式</div>
</div>
<div class="mb-3">
<label class="form-label">文档名称(可选)</label>
<input type="text" class="form-control" name="name" placeholder="留空使用文件名">
</div>
<div id="uploadProgress" class="progress d-none">
<div class="progress-bar progress-bar-striped progress-bar-animated"
role="progressbar" style="width: 100%"></div>
</div>
</div>
<div class="modal-footer">
<button type="button" class="btn btn-secondary" data-bs-dismiss="modal">取消</button>
<button type="submit" class="btn btn-primary">上传</button>
</div>
</form>
</div>
</div>
</div>
{% endblock %}
{% block extra_js %}
<script>
// 定义当前知识库的ID
const kbId = '{{ kb.id }}';
// 上传文档的异步函数
async function uploadDocument(event) {
// 阻止表单的默认提交行为
event.preventDefault();
// 获取表单元素
const form = event.target;
// 创建包含表单数据的 FormData 对象
const formData = new FormData(form);
// 获取上传进度的进度条元素
const progressDiv = document.getElementById('uploadProgress');
// 获取提交按钮
const submitBtn = form.querySelector('button[type="submit"]');
// 显示进度条
progressDiv.classList.remove('d-none');
// 禁用提交按钮,防止重复提交
submitBtn.disabled = true;
try {
// 向后端接口发送 POST 请求上传文档
const response = await fetch(`/api/v1/knowledgebases/${kbId}/documents`, {
method: 'POST',
body: formData
});
// 如果响应成功
if (response.ok) {
// 解析返回的 JSON 数据
const result = await response.json();
// 获取新文档的 ID
const docId = result.data.id;
// 关闭上传模态框
bootstrap.Modal.getInstance(document.getElementById('uploadModal')).hide();
// 重置表单
form.reset();
// 刷新页面以便显示新上传的文档
location.reload();
} else {
// 如果上传失败,解析错误信息并显示提示
const error = await response.json();
alert('上传失败: ' + error.message);
}
} catch (error) {
// 捕获代码块内发生的异常并提示失败信息
alert('上传失败: ' + error.message);
} finally {
// 上传完成或失败后隐藏进度条,恢复提交按钮
progressDiv.classList.add('d-none');
submitBtn.disabled = false;
}
}
</script>
{% endblock %}6.文档分块 #
6.1. parser_service.py #
app/services/parser_service.py
"""
文档解析服务
使用 LangChain 的文档加载器
"""
# 导入日志模块
import logging
# 导入类型注解 List
from typing import List
# 导入 LangChain 的 Document 类型
from langchain_core.documents import Document
# 导入自定义的文档加载器
from app.utils.document_loader import DocumentLoader
# 获取日志记录器
logger = logging.getLogger(__name__)
# 定义文档解析服务类
class ParserService:
"""文档解析服务(使用 LangChain)"""
# 统一解析接口,返回 LangChain Document 列表
def parse(self, file_data: bytes, file_type: str) -> List[Document]:
"""
统一解析接口(返回 LangChain Document 列表)
Args:
file_data: 文件数据(bytes)
file_type: 文件类型(pdf/docx/txt/md)
Returns:
LangChain Document 列表
"""
# 将文件类型转换成小写,便于统一处理
file_type = file_type.lower()
# 调用文档加载器,加载文档并返回 Document 列表
return DocumentLoader.load(file_data, file_type)
# 创建 ParserService 的单例实例
parser_service = ParserService()
6.2. document_loader.py #
app/utils/document_loader.py
"""
文档加载器
"""
# 导入日志模块
import logging
# 导入类型注解
from typing import List
# 导入LangChain社区的文档加载器
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
# 导入LangChain核心Document类型
from langchain_core.documents import Document
# 导入临时文件工具
from tempfile import NamedTemporaryFile
# 导入os模块
import os
# 获取日志记录器
logger = logging.getLogger(__name__)
# 定义文档加载器类
class DocumentLoader:
"""文档加载器封装"""
@staticmethod
def load_pdf(file_data: bytes) -> List[Document]:
"""
加载 PDF 文档
Args:
file_data: PDF 文件数据(bytes)
Returns:
Document 列表
"""
# 异常捕获,避免加载出错崩溃
try:
# 创建一个临时pdf文件,将内容写入
with NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
tmp_file.write(file_data)
tmp_path = tmp_file.name
# 内部try/finally保证文件最终被删除
try:
# 加载PDF文件为Document对象
loader = PyPDFLoader(tmp_path)
documents = loader.load()
return documents
finally:
# 删除临时文件
if os.path.exists(tmp_path):
os.unlink(tmp_path)
except Exception as e:
# 打印并抛出加载错误
logger.error(f"加载 PDF 时出错: {e}")
raise ValueError(f"Failed to load PDF: {str(e)}")
@staticmethod
def load_docx(file_data: bytes) -> List[Document]:
"""
加载 DOCX 文档
Args:
file_data: DOCX 文件数据(bytes)
Returns:
Document 列表
"""
# 异常捕获
try:
# 再次导入本地模块,保险起见
from tempfile import NamedTemporaryFile
import os
# 创建临时docx文件,写入内容
with NamedTemporaryFile(delete=False, suffix='.docx') as tmp_file:
tmp_file.write(file_data)
tmp_path = tmp_file.name
# 加载和清理临时文件
try:
# 加载DOCX为Document对象列表
loader = Docx2txtLoader(tmp_path)
documents = loader.load()
return documents
finally:
# 删除临时文件
if os.path.exists(tmp_path):
os.unlink(tmp_path)
except Exception as e:
# 日志并抛出异常
logger.error(f"加载 DOCX 时出错: {e}")
raise ValueError(f"Failed to load DOCX: {str(e)}")
@staticmethod
def load_text(file_data: bytes, encoding: str = 'utf-8') -> List[Document]:
"""
加载文本文件
Args:
file_data: 文本文件数据(bytes)
encoding: 文件编码
Returns:
Document 列表
"""
# 总体异常捕获
try:
# 再次导入临时文件/OS模块
from tempfile import NamedTemporaryFile
import os
# 写入临时txt文件(二进制写模式)
with NamedTemporaryFile(delete=False, suffix='.txt', mode='wb') as tmp_file:
tmp_file.write(file_data)
tmp_path = tmp_file.name
# 加载过程及编码兜底
try:
# 优先尝试指定的编码
loader = TextLoader(tmp_path, encoding=encoding)
documents = loader.load()
return documents
except UnicodeDecodeError:
# 编码失败自动用gbk重试
try:
loader = TextLoader(tmp_path, encoding='gbk')
documents = loader.load()
return documents
except Exception as e:
# 日志并抛错
logger.error(f"加载文本文件时出错: {e}")
raise ValueError(f"Failed to load text file: {str(e)}")
finally:
# 删除临时文件
if os.path.exists(tmp_path):
os.unlink(tmp_path)
except Exception as e:
# 加载文本总异常
logger.error(f"加载文本时出错: {e}")
raise ValueError(f"Failed to load text: {str(e)}")
@staticmethod
def load(file_data: bytes, file_type: str) -> List[Document]:
"""
统一加载接口
Args:
file_data: 文件数据(bytes)
file_type: 文件类型(pdf/docx/txt/md)
Returns:
Document 列表
"""
# 文件类型小写化,统一处理
file_type = file_type.lower()
# PDF文件
if file_type == 'pdf':
return DocumentLoader.load_pdf(file_data)
# DOCX文件
elif file_type == 'docx':
return DocumentLoader.load_docx(file_data)
# 文本文件/markdown
elif file_type in ['txt', 'md']:
return DocumentLoader.load_text(file_data)
else:
# 不支持文件类型抛异常
raise ValueError(f"Unsupported file type: {file_type}")
6.3. text_splitter.py #
app/utils/text_splitter.py
# 文本分割器
"""
文本分割器
"""
# 导入日志模块
import logging
# 导入类型提示
from typing import List
# 导入递归字符切分器
from langchain_text_splitters import RecursiveCharacterTextSplitter
# 导入文档对象
from langchain_core.documents import Document
# 获取日志记录器
logger = logging.getLogger(__name__)
# 定义文本分割器类
class TextSplitter:
# 文本分割器封装
# 初始化方法,设置块大小和重叠字符数
def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
"""
初始化文本分割器
Args:
chunk_size: 每个块的最大字符数
chunk_overlap: 块之间的重叠字符数
"""
# 设置块的大小
self.chunk_size = chunk_size
# 设置块之间的重叠字符数
self.chunk_overlap = chunk_overlap
# 创建递归字符分割器实例,并设置分割参数
self.splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,#块大小
chunk_overlap=chunk_overlap,#块之间的重叠字符数
length_function=len,#长度函数
separators=["\n\n", "\n", "。", "!", "?", ". ", "! ", "? ", " ", ""]#分隔符
)
# 分割文档列表为若干块
def split_documents(self, documents: List[Document], doc_id: str = None) -> List[dict]:
"""
分割文档列表
Args:
documents: Document 列表
doc_id: 文档ID(用于生成块ID)
Returns:
块列表,每个块包含 id, text, chunk_index, metadata
"""
# 如果文档为空,直接返回空列表
if not documents:
return []
# 使用分割器分割文档
chunks = self.splitter.split_documents(documents)
# 结果列表
result = []
# 遍历每个分割块
for i, chunk in enumerate(chunks):
# 生成块ID,包含文档ID和序号
chunk_id = f"{doc_id}_{i}" if doc_id else str(i)
# 把块信息添加到结果中
result.append({
"id": chunk_id,# 块ID
"text": chunk.page_content,# 块文本内容
"chunk_index": i,# 块索引
"metadata": chunk.metadata# 块元数据
})
# 返回所有分割块
return result
6.4. document.py #
app/blueprints/document.py
"""
文档相关路由(视图 + API)
"""
# 从 Flask 导入 Blueprint 和 request,用于定义蓝图和处理请求
from flask import Blueprint,request
# 导入 os 模块,用于文件名后缀等操作
import os
# 导入 logging 模块,用于日志记录
import logging
# 从工具模块导入通用的响应和错误处理函数
from app.blueprints.utils import success_response, error_response, handle_api_error
# 导入文档服务,用于处理文档业务逻辑
from app.services.document_service import document_service
# 导入配置文件,获取相关配置参数
from app.config import Config
# 设置日志对象
logger = logging.getLogger(__name__)
# 创建一个名为 'document' 的蓝图
bp = Blueprint('document', __name__)
# 定义一个检查文件扩展名是否合法的辅助函数
def allowed_file(filename):
# 检查文件名中有小数点并且扩展名在允许的扩展名列表中
return '.' in filename and \
os.path.splitext(filename)[1][1:].lower() in Config.ALLOWED_EXTENSIONS
# 定义上传文档的 API 路由,POST 方法
@bp.route('/api/v1/knowledgebases/<kb_id>/documents', methods=['POST'])
# 使用装饰器统一捕获和处理 API 错误
@handle_api_error
def api_upload(kb_id):
# 上传文档接口
"""上传文档"""
# 如果请求中没有 'file' 字段,返回参数错误
if 'file' not in request.files:
return error_response("No file part", 400)
# 从请求中获取上传的文件对象
file = request.files['file']
# 如果用户未选择文件(文件名为空),返回错误
if file.filename == '':
return error_response("No file selected", 400)
# 如果文件类型不被允许(扩展名校验),返回错误
if not allowed_file(file.filename):
return error_response(f"File type not allowed. Allowed: {', '.join(Config.ALLOWED_EXTENSIONS)}", 400)
# 读取上传的文件内容
file_data = file.read()
# 检查文件大小是否超过上限
if len(file_data) > Config.MAX_FILE_SIZE:
return error_response(f"File size exceeds maximum {Config.MAX_FILE_SIZE} bytes", 400)
# 获取前端自定义的文件名(可选)
custom_name = request.form.get('name')
# 如果指定了自定义文件名
if custom_name:
# 获取原始文件扩展名
original_ext = os.path.splitext(file.filename)[1]
# 如果自定义文件名没有扩展名,自动补上原扩展名
if not os.path.splitext(custom_name)[1] and original_ext:
filename = custom_name + original_ext
else:
filename = custom_name
# 否则使用原始文件名
else:
filename = file.filename
# 校验最终得到的文件名:为空或只包含空白则报错
if not filename or not filename.strip():
return error_response("Filename is required", 400)
# 校验最终文件名是否包含扩展名
if '.' not in filename:
return error_response("Filename must have an extension", 400)
# 调用文档服务上传,返回文档信息字典
doc_dict = document_service.upload(kb_id, file_data, filename)
# 返回成功响应及新文档信息
return success_response(doc_dict)
# 定义路由,处理 POST 请求,API 路径包含待处理文档的ID
+@bp.route('/api/v1/documents/<doc_id>/process', methods=['POST'])
# 使用自定义的错误处理装饰器,以统一接口异常响应格式
+@handle_api_error
+def api_process(doc_id):
# 处理文档(API接口),doc_id是传入的文档ID
+ """处理文档"""
+ try:
# 调用文档服务方法,提交文档处理任务(异步)
+ document_service.process(doc_id)
# 返回成功响应信息,提示任务已提交
+ return success_response({"message": "文档处理任务已提交"})
+ except ValueError as e:
# 捕获业务异常(如找不到文档),返回404和错误信息
+ return error_response(str(e), 404)
+ except Exception as e:
# 捕获其他未知异常,记录错误日志(带堆栈信息)
+ logger.error(f"处理文档时出错: {e}", exc_info=True)
# 返回500通用错误响应,附带错误原因
+ return error_response(f"处理文档失败: {str(e)}", 500)6.5. document_service.py #
app/services/document_service.py
# 导入os模块,用于处理文件和路径
import os
# 导入uuid模块,用于生成唯一ID
import uuid
# 导入类型提示
from typing import List, Optional, Dict
# 导入线程池,用于异步处理文档
+from concurrent.futures import ThreadPoolExecutor
# 导入BaseService基类
from app.services.base_service import BaseService
# 导入Document模型,重命名为DocumentModel
from app.models.document import Document as DocumentModel
# 导入Knowledgebase知识库模型
from app.models.knowledgebase import Knowledgebase
# 导入存储服务
from app.services.storage_service import storage_service
# 导入解析服务
+from app.services.parser_service import parser_service
# 导入配置项
from app.config import Config
# 导入文本分块器
+from app.utils.text_splitter import TextSplitter
# 导入LangChain文档对象
+from langchain_core.documents import Document
# 定义DocumentService服务类,继承自BaseService
class DocumentService(BaseService[DocumentModel]):
"""文档服务"""
+ def __init__(self):
+ """初始化服务"""
+ super().__init__()
+ self.executor = ThreadPoolExecutor(max_workers=4)
# 上传文档方法
def upload(self, kb_id: str, file_data: bytes, filename: str) -> dict:
"""
上传文档
参数:
kb_id: 知识库ID
file_data: 文件数据
filename: 文件名
返回:
创建的文档字典
"""
# 初始化变量,标识文件是否已经上传
file_uploaded = False
# 初始化文件路径
file_path = None
try:
# 使用数据库会话,判断知识库是否存在
with self.session() as session:
kb = session.query(Knowledgebase).filter(Knowledgebase.id == kb_id).first()
# 如果知识库不存在,抛出异常
if not kb:
raise ValueError(f"知识库 {kb_id} 不存在")
# 检查文件名是否为空或者没有扩展名
if not filename or '.' not in filename:
raise ValueError(f"文件名必须包含扩展名: {filename}")
# 获取文件扩展名
file_ext = os.path.splitext(filename)[1]
# 如果有后缀,去掉点号并转为小写
if file_ext:
file_ext = file_ext[1:].lower()
else:
# 如果没有文件后缀,抛出异常
raise ValueError(f"文件名必须包含扩展名: {filename}")
# 检查文件类型是否合法
if not file_ext or file_ext not in Config.ALLOWED_EXTENSIONS:
raise ValueError(f"不支持的文件类型: '{file_ext}'。允许类型: {', '.join(Config.ALLOWED_EXTENSIONS)}")
# 生成文档ID,用于标识唯一文档
doc_id = uuid.uuid4().hex[:32]
# 构建文件存储路径,便于后续文件操作
file_path = f"documents/{kb_id}/{doc_id}/{filename}"
# 优先将文件上传到本地/云存储,保证文件存在再创建记录
try:
storage_service.upload_file(file_path, file_data)
file_uploaded = True
except Exception as storage_error:
# 上传存储失败时写入日志并抛出异常
self.logger.error(f"上传文件到存储时发生错误: {storage_error}")
raise ValueError(f"文件上传失败: {str(storage_error)}")
# 在数据库中创建文档记录
with self.transaction() as session:
doc = DocumentModel(
id=doc_id,#文档ID
kb_id=kb_id,#知识库ID
name=filename,#文件名
file_path=file_path,#文件路径
file_type=file_ext,#文件类型
file_size=len(file_data),#文件大小
status='pending'#文档状态
)
# 添加文档记录到会话
session.add(doc)
# Flush 保证数据同步到数据库
session.flush()
# 刷新文档对象,获得新写入的数据
session.refresh(doc)
# 对象转dict,避免对象分离后后续属性访问失败
doc_dict = doc.to_dict()
# 日志记录上传成功
self.logger.info(f"文档上传成功: {doc.id}")
# 返回新创建的文档字典信息
return doc_dict
except Exception as e:
# 异常处理,如果文件已上传但事务失败,则尝试删除已上传的文件
if file_uploaded and file_path:
try:
storage_service.delete_file(file_path)
except Exception as delete_error:
self.logger.warning(f"删除已上传文件时出错: {delete_error}")
# 重新抛出异常
raise
# 定义根据知识库ID获取文档列表的方法
def list_by_kb(self, kb_id: str, page: int = 1, page_size: int = 10,
status: Optional[str] = None) -> Dict:
# 方法文档,说明该方法用于获取知识库的文档列表
"""获取知识库的文档列表"""
# 创建数据库会话
with self.session() as session:
# 查询指定知识库ID下的所有文档
query = session.query(DocumentModel).filter(DocumentModel.kb_id == kb_id)
# 如果设置了文档状态,则对状态进行过滤
if status:
query = query.filter(DocumentModel.status == status)
# 使用分页方法返回结果,按创建时间倒序排列
return self.paginate_query(query, page=page, page_size=page_size,
order_by=DocumentModel.created_at.desc())
# 定义处理单个文档的方法(手动触发处理)
+ def process(self, doc_id: str):
+ """
+ 处理文档(手动触发)
+ Args:
+ doc_id: 文档ID
+ """
# 创建数据库会话,检查文档是否存在
+ with self.session() as session:
# 根据doc_id查询文档对象
+ doc = session.query(DocumentModel).filter(DocumentModel.id == doc_id).first()
# 如果文档不存在则抛出异常
+ if not doc:
+ raise ValueError(f"Document {doc_id} not found")
# 记录已经提交文档处理任务的日志
+ self.logger.info(f"提交文档处理任务: {doc_id}")
# 在线程池中异步提交处理任务
+ future = self.executor.submit(self._process_document, doc_id)
# 定义异常回调函数,用于捕获子线程中的异常
+ def exception_callback(future):
+ try:
# 获取异步线程的执行结果,发生异常会在此抛出
+ future.result()
+ except Exception as e:
# 记录错误日志及异常堆栈信息
+ self.logger.error(f"文档处理任务异常: {doc_id}, 错误: {e}", exc_info=True)
# 给future对象添加回调函数,任务结束时自动处理异常
+ future.add_done_callback(exception_callback)
# 文档实际处理方法(在子线程中执行,异步)
+ def _process_document(self, doc_id: str):
+ """
+ 处理文档(异步)
+ Args:
+ doc_id: 文档ID
+ """
+ try:
# 日志:文档处理任务开始
+ self.logger.info(f"开始处理文档: {doc_id}")
# 首先开启事务,获取文档信息和知识库配置并更新文档初始状态
+ with self.transaction() as session:
# 查询文档对象
+ doc = session.query(DocumentModel).filter(DocumentModel.id == doc_id).first()
# 未找到文档则输出日志并返回
+ if not doc:
+ self.logger.error(f"未找到文档 {doc_id}")
+ return
# 若文档已被处理过(完成或失败),则需重置状态
+ need_cleanup = doc.status in ['completed', 'failed']
+ if need_cleanup:
# 重置状态为待处理,分块数归零、错误信息清除
+ doc.status = 'pending'
+ doc.chunk_count = 0
+ doc.error_message = None
# 更新为正在处理状态
+ doc.status = 'processing'
# 刷新到数据库(写入未提交)
+ session.flush()
# 提前取出相关数据,避免事务作用域外对象失效
+ kb_id = doc.kb_id
+ file_path = doc.file_path
+ file_type = doc.file_type
+ doc_name = doc.name
# 查询知识库配置
+ kb = session.query(Knowledgebase).filter(Knowledgebase.id == kb_id).first()
# 未找到知识库抛出异常
+ if not kb:
+ raise ValueError(f"知识库 {kb_id} 未找到")
# 获取知识库分块参数
+ kb_chunk_size = kb.chunk_size
+ kb_chunk_overlap = kb.chunk_overlap
# 日志:文档已标记为处理中
+ self.logger.info(f"文档 {doc_id} 状态已更新为 processing(处理中)")
# 从存储中下载文件内容
+ file_data = storage_service.download_file(file_path)
# 解析文件,根据类型抽取原始文本内容
+ langchain_docs = parser_service.parse(file_data, file_type)
# 若抽取出的文本为空,则抛出异常
+ if not langchain_docs:
+ raise ValueError("未能抽取到任何文本内容")
# 创建文本分块器,指定知识库参数
+ splitter = TextSplitter(
+ chunk_size=kb_chunk_size,#块大小
+ chunk_overlap=kb_chunk_overlap,#块之间的重叠字符数
+ )
# 将文档内容分块
+ chunks = splitter.split_documents(langchain_docs, doc_id=doc_id)
# 如果分块失败,抛出异常
+ if not chunks:
+ raise ValueError("文档未能成功分块")
# 再次开启事务,更新文档状态为完成,记录分块数
+ with self.transaction() as session:
+ doc = session.query(DocumentModel).filter(DocumentModel.id == doc_id).first()
+ if doc:
+ doc.status = 'completed'#完成状态
+ doc.chunk_count = len(chunks)#分块数
# 日志:处理完成,输出分块数
+ self.logger.info(f"文档处理完成: {doc_id}, 分块数量: {len(chunks)}")
+ except Exception as e:
# 捕获异常后,更新文档状态为失败,并记录错误信息(限长)
+ with self.transaction() as session:
# 查询文档对象
+ doc = session.query(DocumentModel).filter(DocumentModel.id == doc_id).first()
# 如果文档存在,则更新状态为失败,并记录错误信息(限长)
+ if doc:
+ doc.status = 'failed'#失败状态
+ doc.error_message = str(e)[:500]#错误信息
# 记录处理失败的日志
+ self.logger.error(f"处理文档 {doc_id} 时发生错误: {e}")
# 实例化DocumentService
document_service = DocumentService()6.6. kb_detail.html #
app/templates/kb_detail.html
{% extends "base.html" %}
{% block title %}{{ kb.name }} - RAG Lite{% endblock %}
{% block content %}
<div class="row">
<div class="col-12">
<nav aria-label="breadcrumb">
<ol class="breadcrumb">
<li class="breadcrumb-item"><a href="/">首页</a></li>
<li class="breadcrumb-item"><a href="/kb">知识库</a></li>
<li class="breadcrumb-item active">{{ kb.name }}</li>
</ol>
</nav>
</div>
</div>
<div class="row">
<div class="col-12">
<div class="card">
<div class="card-header d-flex justify-content-between align-items-center">
<h5 class="mb-0"><i class="bi bi-file-earmark"></i> 文档管理</h5>
<button class="btn btn-sm btn-primary" data-bs-toggle="modal" data-bs-target="#uploadModal">
<i class="bi bi-upload"></i> 上传文档
</button>
</div>
<div class="card-body">
<div id="docList">
{% if documents %}
<div class="table-responsive">
<table class="table table-hover">
<thead>
<tr>
<th>文档名称</th>
<th>状态</th>
<th>块数</th>
<th>文件大小</th>
<th>操作</th>
</tr>
</thead>
<tbody>
{% for doc in documents %}
<tr>
<td>
<i class="bi bi-file-earmark-{{ 'pdf' if doc.file_type == 'pdf' else 'word' if doc.file_type == 'docx' else 'text' }}"></i>
{{ doc.name }}
</td>
<td>
{% set status_map = {
'completed': '已完成',
'processing': '处理中',
'failed': '失败',
'pending': '待处理'
} %}
<span class="badge bg-{{ 'success' if doc.status == 'completed' else 'warning' if doc.status == 'processing' else 'danger' if doc.status == 'failed' else 'secondary' }}">
{{ status_map.get(doc.status, doc.status) }}
</span>
</td>
<td>{{ doc.chunk_count or 0 }}</td>
<td>{{ "%.2f"|format(doc.file_size / 1024) }} KB</td>
<td>
{% if doc.status == 'completed' %}
<a href="/documents/{{ doc.id }}/chunks" class="btn btn-sm btn-info me-1">
<i class="bi bi-list-ul"></i> 查看分块
</a>
{% endif %}
{% if doc.status == 'pending' %}
<button class="btn btn-sm btn-primary me-1" onclick="processDoc('{{ doc.id }}', '{{ doc.name }}')">
<i class="bi bi-play-circle"></i> 处理
</button>
{% elif doc.status in ['completed', 'failed'] %}
<button class="btn btn-sm btn-warning me-1" onclick="processDoc('{{ doc.id }}', '{{ doc.name }}')">
<i class="bi bi-arrow-clockwise"></i> 重新处理
</button>
{% elif doc.status == 'processing' %}
<button class="btn btn-sm btn-secondary me-1" disabled>
<i class="bi bi-hourglass-split"></i> 处理中
</button>
{% endif %}
<button class="btn btn-sm btn-danger" onclick="deleteDoc('{{ doc.id }}', '{{ doc.name }}')">
<i class="bi bi-trash"></i> 删除
</button>
</td>
</tr>
{% endfor %}
</tbody>
</table>
</div>
<!-- 分页控件 -->
{% if pagination and pagination.total > pagination.page_size %}
<nav aria-label="文档列表分页" class="mt-3">
<ul class="pagination justify-content-center">
{% set current_page = pagination.page %}
{% set total_pages = (pagination.total + pagination.page_size - 1) // pagination.page_size %}
<!-- 上一页 -->
<li class="page-item {% if current_page <= 1 %}disabled{% endif %}">
<a class="page-link" href="?page={{ current_page - 1 }}&page_size={{ pagination.page_size }}"
{% if current_page <= 1 %}tabindex="-1" aria-disabled="true"{% endif %}>
<i class="bi bi-chevron-left"></i> 上一页
</a>
</li>
<!-- 页码 -->
{% set start_page = [1, current_page - 2] | max %}
{% set end_page = [total_pages, current_page + 2] | min %}
{% if start_page > 1 %}
<li class="page-item">
<a class="page-link" href="?page=1&page_size={{ pagination.page_size }}">1</a>
</li>
{% if start_page > 2 %}
<li class="page-item disabled">
<span class="page-link">...</span>
</li>
{% endif %}
{% endif %}
{% for page_num in range(start_page, end_page + 1) %}
<li class="page-item {% if page_num == current_page %}active{% endif %}">
<a class="page-link" href="?page={{ page_num }}&page_size={{ pagination.page_size }}">
{{ page_num }}
</a>
</li>
{% endfor %}
{% if end_page < total_pages %}
{% if end_page < total_pages - 1 %}
<li class="page-item disabled">
<span class="page-link">...</span>
</li>
{% endif %}
<li class="page-item">
<a class="page-link" href="?page={{ total_pages }}&page_size={{ pagination.page_size }}">{{ total_pages }}</a>
</li>
{% endif %}
<!-- 下一页 -->
<li class="page-item {% if current_page >= total_pages %}disabled{% endif %}">
<a class="page-link" href="?page={{ current_page + 1 }}&page_size={{ pagination.page_size }}"
{% if current_page >= total_pages %}tabindex="-1" aria-disabled="true"{% endif %}>
下一页 <i class="bi bi-chevron-right"></i>
</a>
</li>
</ul>
<div class="text-center text-muted small mt-2">
共 {{ pagination.total }} 个文档,第 {{ current_page }} / {{ total_pages }} 页
</div>
</nav>
{% endif %}
{% else %}
<div class="text-center text-muted py-5">
<i class="bi bi-inbox" style="font-size: 3rem;"></i>
<p class="mt-2">还没有文档,点击上方按钮上传文档</p>
</div>
{% endif %}
</div>
</div>
</div>
</div>
</div>
<!-- 上传文档模态框 -->
<div class="modal fade" id="uploadModal" tabindex="-1">
<div class="modal-dialog">
<div class="modal-content">
<div class="modal-header">
<h5 class="modal-title">上传文档</h5>
<button type="button" class="btn-close" data-bs-dismiss="modal"></button>
</div>
<form id="uploadForm" onsubmit="uploadDocument(event)" enctype="multipart/form-data">
<div class="modal-body">
<div class="mb-3">
<label class="form-label">选择文件</label>
<input type="file" class="form-control" name="file" accept=".pdf,.docx,.txt,.md" required>
<div class="form-text">支持 PDF、DOCX、TXT、MD 格式</div>
</div>
<div class="mb-3">
<label class="form-label">文档名称(可选)</label>
<input type="text" class="form-control" name="name" placeholder="留空使用文件名">
</div>
<div id="uploadProgress" class="progress d-none">
<div class="progress-bar progress-bar-striped progress-bar-animated"
role="progressbar" style="width: 100%"></div>
</div>
</div>
<div class="modal-footer">
<button type="button" class="btn btn-secondary" data-bs-dismiss="modal">取消</button>
<button type="submit" class="btn btn-primary">上传</button>
</div>
</form>
</div>
</div>
</div>
{% endblock %}
{% block extra_js %}
<script>
// 定义当前知识库的ID
const kbId = '{{ kb.id }}';
// 上传文档的异步函数
async function uploadDocument(event) {
// 阻止表单的默认提交行为
event.preventDefault();
// 获取表单元素
const form = event.target;
// 创建包含表单数据的 FormData 对象
const formData = new FormData(form);
// 获取上传进度的进度条元素
const progressDiv = document.getElementById('uploadProgress');
// 获取提交按钮
const submitBtn = form.querySelector('button[type="submit"]');
// 显示进度条
progressDiv.classList.remove('d-none');
// 禁用提交按钮,防止重复提交
submitBtn.disabled = true;
try {
// 向后端接口发送 POST 请求上传文档
const response = await fetch(`/api/v1/knowledgebases/${kbId}/documents`, {
method: 'POST',
body: formData
});
// 如果响应成功
if (response.ok) {
// 解析返回的 JSON 数据
const result = await response.json();
// 获取新文档的 ID
const docId = result.data.id;
// 关闭上传模态框
bootstrap.Modal.getInstance(document.getElementById('uploadModal')).hide();
// 重置表单
form.reset();
// 刷新页面以便显示新上传的文档
location.reload();
} else {
// 如果上传失败,解析错误信息并显示提示
const error = await response.json();
alert('上传失败: ' + error.message);
}
} catch (error) {
// 捕获代码块内发生的异常并提示失败信息
alert('上传失败: ' + error.message);
} finally {
// 上传完成或失败后隐藏进度条,恢复提交按钮
progressDiv.classList.add('d-none');
submitBtn.disabled = false;
}
}
+// 定义一个异步函数,处理文档的流程,接收文档ID和文档名作为参数
+async function processDoc(docId, docName) {
+ // 使用 try-catch 捕获异常
+ try {
+ // 向后端发送处理文档的 POST 请求
+ const response = await fetch(`/api/v1/documents/${docId}/process`, {
+ method: 'POST'
+ });
+ // 判断响应是否成功
+ if (response.ok) {
+ // 弹出提示,告知用户任务已提交
+ alert('文档处理任务已提交,请稍候刷新页面查看状态');
+ // 设置延迟,1秒后刷新页面,以便服务器有时间处理
+ setTimeout(() => {
+ location.reload();
+ }, 1000);
+ } else {
+ // 若处理失败,解析响应中的错误信息
+ const error = await response.json();
+ // 弹窗提示用户错误信息
+ alert('处理失败: ' + error.message);
+ }
+ } catch (error) {
+ // 捕获异常,弹窗提示错误
+ alert('处理失败: ' + error.message);
+ }
+}
</script>
{% endblock %}7.Chroma向量 #
7.1. vector_service.py #
app/services/vector_service.py
"""
向量数据库服务
提供统一的向量数据库访问接口
"""
from app.services.vectordb.factory import get_vector_db_service
# 创建默认实例(使用全局设置)
vector_service = get_vector_db_service()7.2. init.py #
app/services/vectordb/init.py
"""
向量数据库模块
提供统一的向量数据库接口,支持多种向量数据库后端
"""
7.3. base.py #
app/services/vectordb/base.py
"""
向量数据库抽象接口
"""
# 导入 abc 库,定义抽象基类和抽象方法
from abc import ABC, abstractmethod
# 导入类型提示:列表、字典、可选和任意类型
from typing import List, Dict, Optional, Any
# 导入 LangChain 的 Document 类,用于文档对象
from langchain_core.documents import Document
# 定义一个向量数据库的抽象接口,继承自 ABC 抽象基类
class VectorDBInterface(ABC):
"""向量数据库抽象接口"""
# 定义抽象方法:获取或创建集合
@abstractmethod
def get_or_create_collection(self, collection_name: str) -> Any:
"""
获取或创建集合
Args:
collection_name: 集合名称
Returns:
向量存储对象
"""
# 子类需要实现具体逻辑
pass
# 定义抽象方法:向向量存储中添加文档
@abstractmethod
def add_documents(self, collection_name: str, documents: List[Document],
ids: Optional[List[str]] = None) -> List[str]:
"""
添加文档到向量存储
Args:
collection_name: 集合名称
documents: Document 列表
ids: 文档ID列表(可选)
Returns:
添加的文档ID列表
"""
# 子类需要实现具体逻辑
pass
# 定义抽象方法:删除指定的文档
@abstractmethod
def delete_documents(self, collection_name: str, ids: Optional[List[str]] = None,
filter: Optional[Dict] = None) -> None:
"""
删除文档
Args:
collection_name: 集合名称
ids: 要删除的文档ID列表(可选)
filter: 过滤条件(可选)
"""
# 子类需要实现具体逻辑
pass
7.4. chroma.py #
app/services/vectordb/chroma.py
# ChromaDB 向量数据库实现
"""
ChromaDB 向量数据库实现
"""
# 导入日志模块
import logging
# 导入需要的类型提示
from typing import List, Dict, Optional, Any
# 导入 LangChain 的 Chroma 类
from langchain_chroma import Chroma
# 导入 Document 类
from langchain_core.documents import Document
# 导入向量数据库接口基类
from app.services.vectordb.base import VectorDBInterface
# 导入全局配置
from app.config import Config
# 导入嵌入模型工厂
from app.utils.embedding_factory import EmbeddingFactory
# 获取日志记录器
logger = logging.getLogger(__name__)
# 定义 Chroma 向量数据库实现类
class ChromaVectorDB(VectorDBInterface):
# ChromaDB 向量数据库实现
# 初始化方法
def __init__(self, persist_directory: Optional[str] = None,):
"""
初始化 ChromaDB 服务
Args:
persist_directory: 持久化目录,如果为None则使用配置中的值
settings: 设置字典,用于创建Embedding模型
"""
# 如果持久化目录为 None,则从配置读取
if persist_directory is None:
persist_directory = Config.CHROMA_PERSIST_DIRECTORY
# 设置持久化目录属性
self.persist_directory = persist_directory
# 动态创建 Embedding 模型
self.embeddings = EmbeddingFactory.create_embeddings()
# 记录 ChromaDB 初始化信息
logger.info(f"ChromaDB 已初始化, 持久化目录: {persist_directory}")
# 获取或创建集合(向量存储)
def get_or_create_collection(self, collection_name: str) -> Chroma:
# 获取或创建向量存储对象
vectorstore = Chroma(
collection_name=collection_name,
embedding_function=self.embeddings,
persist_directory=self.persist_directory
)
# 返回向量存储对象
return vectorstore
# 向向量存储添加文档
def add_documents(self, collection_name: str, documents: List[Document],
ids: Optional[List[str]] = None) -> List[str]:
# 获取集合
vectorstore = self.get_or_create_collection(collection_name)
# 如果指定了 ids
if ids:
# 添加文档,指定 ids
result_ids = vectorstore.add_documents(documents=documents, ids=ids)
else:
# 添加文档,不指定 ids
result_ids = vectorstore.add_documents(documents=documents)
# 记录日志,添加文档
logger.info(f"已向 ChromaDB 集合 {collection_name} 添加 {len(documents)} 个文档")
# 返回已添加文档的 id 列表
return result_ids
# 删除文档
def delete_documents(self, collection_name: str, ids: Optional[List[str]] = None,
vectorstore = self.get_or_create_collection(collection_name)
if ids:
vectorstore.delete(ids=ids)
elif filter:
# ChromaDB 不支持直接使用 filter 参数删除
# 需要先查询出符合条件的文档 IDs,然后删除
try:
collection = vectorstore._collection
# 使用 where 条件查询匹配的文档
# filter 格式: {"doc_id": "xxx"}
where = filter
results = collection.get(where=where)
if results and 'ids' in results and results['ids']:
matched_ids = results['ids']
vectorstore.delete(ids=matched_ids)
logger.info(f"已通过filter条件删除{len(matched_ids)}个文档")
else:
logger.info(f"未找到匹配filter条件的文档,无需删除")
except Exception as e:
logger.error(f"使用filter删除文档时出错: {e}", exc_info=True)
raise
else:
raise ValueError(f"你既没有传ids,也没有传filter")
logger.info(f"已经从ChromDB集合{collection_name}删除文档")
7.5. factory.py #
app/services/vectordb/factory.py
"""
向量数据库工厂
"""
# 导入日志模块
import logging
# 导入可选类型提示
from typing import Optional
# 导入向量数据库接口基类
from app.services.vectordb.base import VectorDBInterface
# 导入 Chroma 的向量数据库实现
from app.services.vectordb.chroma import ChromaVectorDB
# 导入全局配置
from app.config import Config
# 获取日志记录器
logger = logging.getLogger(__name__)
# 定义向量数据库工厂类
class VectorDBFactory:
"""向量数据库工厂"""
# 类属性:用于保存单例实例
_instance: Optional[VectorDBInterface] = None
# 类方法:创建向量数据库实例
@classmethod
def create_vector_db(cls, vectordb_type: Optional[str] = None, **kwargs) -> VectorDBInterface:
"""
创建向量数据库实例
Args:
vectordb_type: 向量数据库类型 ('chroma' 或 'milvus'),如果为None则从配置读取
settings: 设置字典,用于创建Embedding模型
**kwargs: 向量数据库的初始化参数
Returns:
向量数据库实例
"""
# 如果未传入 vectordb_type,则从配置读取,默认 'chroma'
if vectordb_type is None:
vectordb_type = getattr(Config, 'VECTORDB_TYPE', 'chroma')
# 转为小写,统一判断
vectordb_type = vectordb_type.lower()
# 如果选择的是 Chroma
if vectordb_type == 'chroma':
# 获取持久化目录(可选参数)
persist_directory = kwargs.get('persist_directory')
# 创建 ChromaVectorDB 实例并返回
return ChromaVectorDB(
persist_directory=persist_directory
)
else:
# 其他类型暂不支持,抛出异常
raise ValueError(f"Unsupported vector database type: {vectordb_type}")
# 类方法:获取单例向量数据库实例(懒加载)
@classmethod
def get_instance(cls) -> VectorDBInterface:
"""
获取单例向量数据库实例(懒加载)
Args:
settings: 设置字典,用于创建Embedding模型
Returns:
向量数据库实例
"""
# 如果单例实例为空,则创建
if cls._instance is None:
cls._instance = cls.create_vector_db()
# 返回单例实例
return cls._instance
# 工厂方法:便捷获取向量数据库服务实例
def get_vector_db_service() -> VectorDBInterface:
"""
便捷函数:获取向量数据库服务实例
Args:
settings: 设置字典,用于创建Embedding模型
Returns:
向量数据库服务实例
"""
# 返回单例向量数据库实例
return VectorDBFactory.get_instance()
7.6. embedding_factory.py #
app/utils/embedding_factory.py
"""
Embedding 模型工厂
支持多种 Embedding 模型提供商
"""
# 导入日志模块
import logging
# 导入 HuggingFace Embeddings 类
from langchain_huggingface import HuggingFaceEmbeddings
# 导入 OpenAI Embeddings 类
from langchain_openai import OpenAIEmbeddings
# 导入 Ollama Embeddings 类
from langchain_community.embeddings import OllamaEmbeddings
# 导入全局设置服务
from app.services.settings_service import settings_service
# 获取 logger 对象
logger = logging.getLogger(__name__)
# 定义 Embedding 工厂类
class EmbeddingFactory:
"""Embedding 模型工厂"""
# 默认模型名称
DEFAULT_MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'
# 定义静态方法,用于创建 embedding 对象
@staticmethod
def create_embeddings():
"""
创建 Embedding 模型
Args:
settings: 设置字典,包含 embedding_provider, embedding_model_name, embedding_api_key, embedding_base_url
Returns:
Embeddings 对象
"""
# 从 settings_service 获取嵌入设置
settings = settings_service.get()
# 获取 embedding 提供商,默认为 huggingface
provider = settings.get('embedding_provider', 'huggingface')
# 获取 embedding 模型名称
model_name = settings.get('embedding_model_name')
# 获取 embedding api key
api_key = settings.get('embedding_api_key')
# 获取 embedding base url
base_url = settings.get('embedding_base_url')
try:
# 如果提供商为 huggingface
if provider == 'huggingface':
# 创建 HuggingFace Embeddings 对象
embeddings = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs={'device': 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
# 记录日志
logger.info(f"创建 HuggingFace Embeddings: {model_name}")
# 如果提供商为 openai
elif provider == 'openai':
# 如果没有 api_key 抛出异常
if not api_key:
raise ValueError("OpenAI Embeddings 需要 API Key")
# 创建 OpenAI Embeddings 对象
embeddings = OpenAIEmbeddings(
model=model_name,
openai_api_key=api_key
)
# 记录日志
logger.info(f"创建 OpenAI Embeddings: {model_name}")
# 如果提供商为 ollama
elif provider == 'ollama':
# 如果没有 base_url 抛出异常
if not base_url:
raise ValueError("Ollama Embeddings 需要 Base URL")
# 创建 Ollama Embeddings 对象
embeddings = OllamaEmbeddings(
model=model_name,
base_url=base_url
)
# 记录日志
logger.info(f"创建 Ollama Embeddings: {model_name}, base_url: {base_url}")
else:
# 未知的提供商,警告日志,使用默认 huggingface
logger.warning(f"未知的 Embedding 提供商: {provider},使用默认的 HuggingFace")
embeddings = HuggingFaceEmbeddings(
model_name=EmbeddingFactory.DEFAULT_MODEL_NAME,
model_kwargs={'device': 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
# 返回构造的 embeddings 对象
return embeddings
except Exception as e:
# 出现异常时记录错误日志
logger.error(f"创建 Embedding 模型失败: {e}", exc_info=True)
# 失败时回退到默认模型并记录警告
logger.warning(f"回退到默认 HuggingFace 模型: {EmbeddingFactory.DEFAULT_MODEL_NAME}")
return HuggingFaceEmbeddings(
model_name=EmbeddingFactory.DEFAULT_MODEL_NAME,
model_kwargs={'device': 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
7.7. documents.py #
documents.py
# 导入chromadb库
import chromadb
# 连接到ChromaDB的持久化存储,指定数据库路径
client = chromadb.PersistentClient(path="./chroma_db") # 你的数据库路径
# 获取数据库中所有的集合(Collection)
collections = client.list_collections()
# 打印分隔线
print("=" * 40)
# 打印所有集合的名称
print(f"所有集合: {[c.name for c in collections]}")
# 再次打印分隔线
print("=" * 40)
# 遍历每一个集合,输出集合的详细内容
for collection in collections:
# 打印当前集合的名称
print(f"集合名: {collection.name}")
# 打印集合内容分隔线
print("-" * 30)
# 获取集合中的内容,并明确要求包含documents, metadatas, embeddings
results = collection.get(include=['documents', 'metadatas', 'embeddings'])
# 获取所有条目的id
ids = results['ids']
# 获取所有条目的文本内容
docs = results['documents']
# 获取所有条目的元数据
metadatas = results['metadatas']
# 获取所有条目的向量(embedding),用get防止KeyError
embeddings = results.get('embeddings') # 使用 get 方法,避免 KeyError
# 遍历每一条数据,同时获得编号
for i, (id_val, doc, meta, embedding) in enumerate(zip(ids, docs, metadatas, embeddings), 1):
# 打印当前文档的名称和向量维度信息
print(f"{meta.get('doc_name', '')} 第{i}条 向量维度: {len(embedding)}")7.8. config.py #
app/config.py
"""
配置管理模块
"""
# 导入操作系统相关模块
import os
# 导入 Path,处理路径
from pathlib import Path
# 导入 dotenv,用于加载 .env 文件中的环境变量
from dotenv import load_dotenv
# 加载 .env 文件中的环境变量到系统环境变量
load_dotenv()
# 定义应用配置类
class Config:
"""应用配置类"""
# 基础配置
# 项目根目录路径(取上级目录)
BASE_DIR = Path(__file__).parent.parent
# 加载环境变量 SECRET_KEY,若未设置则使用默认开发密钥
SECRET_KEY = os.environ.get('SECRET_KEY') or 'dev-secret-key-change-in-production'
# 应用配置
# 读取应用监听的主机地址,默认为本地所有地址
APP_HOST = os.environ.get('APP_HOST', '0.0.0.0')
# 读取应用监听的端口,默认为 5000,类型为 int
APP_PORT = int(os.environ.get('APP_PORT', 5000))
# 读取 debug 模式配置,字符串转小写等于 'true' 则为 True(开启调试)
APP_DEBUG = os.environ.get('APP_DEBUG', 'false').lower() == 'true'
# 读取允许上传的最大文件大小,默认为 100MB,类型为 int
MAX_FILE_SIZE = int(os.environ.get('MAX_FILE_SIZE', 104857600)) # 100MB
# 允许上传的文件扩展名集合
ALLOWED_EXTENSIONS = {'pdf', 'docx', 'txt', 'md'}
# 允许上传的图片扩展名集合
ALLOWED_IMAGE_EXTENSIONS = {'jpg', 'jpeg', 'png', 'gif', 'webp'}
# 允许上传的图片最大大小,默认为 5MB,类型为 int
MAX_IMAGE_SIZE = int(os.environ.get('MAX_IMAGE_SIZE', 5242880)) # 5MB
# 日志配置
# 日志目录,默认 './logs'
LOG_DIR = os.environ.get('LOG_DIR', './logs')
# 日志文件名,默认 'rag_lite.log'
LOG_FILE = os.environ.get('LOG_FILE', 'rag_lite.log')
# 日志等级,默认 'INFO'
LOG_LEVEL = os.environ.get('LOG_LEVEL', 'INFO')
# 是否启用控制台日志,默认 True
LOG_ENABLE_CONSOLE = os.environ.get('LOG_ENABLE_CONSOLE', 'true').lower() == 'true'
# 是否启用文件日志,默认 True
LOG_ENABLE_FILE = os.environ.get('LOG_ENABLE_FILE', 'true').lower() == 'true'
# 数据库配置
# 数据库主机地址,默认为 'localhost'
DB_HOST = os.environ.get('DB_HOST', 'localhost')
# 数据库端口号,默认为 3306
DB_PORT = int(os.environ.get('DB_PORT', 3306))
# 数据库用户名,默认为 'root'
DB_USER = os.environ.get('DB_USER', 'root')
# 数据库密码,默认为 'root'
DB_PASSWORD = os.environ.get('DB_PASSWORD', 'root')
# 数据库名称,默认为 'rag-lite'
DB_NAME = os.environ.get('DB_NAME', 'rag-lite')
# 数据库字符集,默认为 'utf8mb4'
DB_CHARSET = os.environ.get('DB_CHARSET', 'utf8mb4')
# 存储配置
STORAGE_TYPE = os.environ.get('STORAGE_TYPE', 'local') # 'local' 或 'minio'
STORAGE_DIR = os.environ.get('STORAGE_DIR', './storages')
# MinIO 配置(当 STORAGE_TYPE='minio' 时使用)
MINIO_ENDPOINT = os.environ.get('MINIO_ENDPOINT', '')
MINIO_ACCESS_KEY = os.environ.get('MINIO_ACCESS_KEY', '')
MINIO_SECRET_KEY = os.environ.get('MINIO_SECRET_KEY', '')
MINIO_BUCKET_NAME = os.environ.get('MINIO_BUCKET_NAME', 'rag-lite')
MINIO_SECURE = os.environ.get('MINIO_SECURE', 'false').lower() == 'true'
MINIO_REGION = os.environ.get('MINIO_REGION', None)
# 向量数据库配置
+ VECTORDB_TYPE = os.environ.get('VECTORDB_TYPE', 'chroma') # 'chroma' 或 'milvus'
+ CHROMA_PERSIST_DIRECTORY = os.environ.get('CHROMA_PERSIST_DIRECTORY', './chroma_db')
# 深度求索配置
+ DEEPSEEK_CHAT_MODEL = os.environ.get('DEEPSEEK_CHAT_MODEL', 'deepseek-chat')
+ DEEPSEEK_API_KEY = os.environ.get('DEEPSEEK_API_KEY', 'sk-c4e682d07ed643e0bce7bb66f24c5720')
+ DEEPSEEK_BASE_URL = os.environ.get('DEEPSEEK_BASE_URL', 'https://api.deepseek.com')7.9 .env #
+DEEPSEEK_CHAT_MODEL=deepseek-chat
+DEEPSEEK_API_KEY=sk-c4e682d07ed643e0bce7bb66f24c5720
+DEEPSEEK_BASE_URL=https://api.deepseek.com7.10. document_service.py #
app/services/document_service.py
# 导入os模块,用于处理文件和路径
import os
# 导入uuid模块,用于生成唯一ID
import uuid
# 导入类型提示
from typing import List, Optional, Dict
# 导入线程池,用于异步处理文档
from concurrent.futures import ThreadPoolExecutor
# 导入BaseService基类
from app.services.base_service import BaseService
# 导入Document模型,重命名为DocumentModel
from app.models.document import Document as DocumentModel
# 导入Knowledgebase知识库模型
from app.models.knowledgebase import Knowledgebase
# 导入存储服务
from app.services.storage_service import storage_service
# 导入解析服务
from app.services.parser_service import parser_service
# 导入配置项
from app.config import Config
# 导入文本分块器
from app.utils.text_splitter import TextSplitter
# 导入LangChain文档对象
from langchain_core.documents import Document
+from app.services.vector_service import vector_service
# 定义DocumentService服务类,继承自BaseService
class DocumentService(BaseService[DocumentModel]):
"""文档服务"""
def __init__(self):
"""初始化服务"""
super().__init__()
self.executor = ThreadPoolExecutor(max_workers=4)
# 上传文档方法
def upload(self, kb_id: str, file_data: bytes, filename: str) -> dict:
"""
上传文档
参数:
kb_id: 知识库ID
file_data: 文件数据
filename: 文件名
返回:
创建的文档字典
"""
# 初始化变量,标识文件是否已经上传
file_uploaded = False
# 初始化文件路径
file_path = None
try:
# 使用数据库会话,判断知识库是否存在
with self.session() as session:
kb = session.query(Knowledgebase).filter(Knowledgebase.id == kb_id).first()
# 如果知识库不存在,抛出异常
if not kb:
raise ValueError(f"知识库 {kb_id} 不存在")
# 检查文件名是否为空或者没有扩展名
if not filename or '.' not in filename:
raise ValueError(f"文件名必须包含扩展名: {filename}")
# 获取文件扩展名
file_ext = os.path.splitext(filename)[1]
# 如果有后缀,去掉点号并转为小写
if file_ext:
file_ext = file_ext[1:].lower()
else:
# 如果没有文件后缀,抛出异常
raise ValueError(f"文件名必须包含扩展名: {filename}")
# 检查文件类型是否合法
if not file_ext or file_ext not in Config.ALLOWED_EXTENSIONS:
raise ValueError(f"不支持的文件类型: '{file_ext}'。允许类型: {', '.join(Config.ALLOWED_EXTENSIONS)}")
# 生成文档ID,用于标识唯一文档
doc_id = uuid.uuid4().hex[:32]
# 构建文件存储路径,便于后续文件操作
file_path = f"documents/{kb_id}/{doc_id}/{filename}"
# 优先将文件上传到本地/云存储,保证文件存在再创建记录
try:
storage_service.upload_file(file_path, file_data)
file_uploaded = True
except Exception as storage_error:
# 上传存储失败时写入日志并抛出异常
self.logger.error(f"上传文件到存储时发生错误: {storage_error}")
raise ValueError(f"文件上传失败: {str(storage_error)}")
# 在数据库中创建文档记录
with self.transaction() as session:
doc = DocumentModel(
id=doc_id,#文档ID
kb_id=kb_id,#知识库ID
name=filename,#文件名
file_path=file_path,#文件路径
file_type=file_ext,#文件类型
file_size=len(file_data),#文件大小
status='pending'#文档状态
)
# 添加文档记录到会话
session.add(doc)
# Flush 保证数据同步到数据库
session.flush()
# 刷新文档对象,获得新写入的数据
session.refresh(doc)
# 对象转dict,避免对象分离后后续属性访问失败
doc_dict = doc.to_dict()
# 日志记录上传成功
self.logger.info(f"文档上传成功: {doc.id}")
# 返回新创建的文档字典信息
return doc_dict
except Exception as e:
# 异常处理,如果文件已上传但事务失败,则尝试删除已上传的文件
if file_uploaded and file_path:
try:
storage_service.delete_file(file_path)
except Exception as delete_error:
self.logger.warning(f"删除已上传文件时出错: {delete_error}")
# 重新抛出异常
raise
# 定义根据知识库ID获取文档列表的方法
def list_by_kb(self, kb_id: str, page: int = 1, page_size: int = 10,
status: Optional[str] = None) -> Dict:
# 方法文档,说明该方法用于获取知识库的文档列表
"""获取知识库的文档列表"""
# 创建数据库会话
with self.session() as session:
# 查询指定知识库ID下的所有文档
query = session.query(DocumentModel).filter(DocumentModel.kb_id == kb_id)
# 如果设置了文档状态,则对状态进行过滤
if status:
query = query.filter(DocumentModel.status == status)
# 使用分页方法返回结果,按创建时间倒序排列
return self.paginate_query(query, page=page, page_size=page_size,
order_by=DocumentModel.created_at.desc())
# 定义处理单个文档的方法(手动触发处理)
def process(self, doc_id: str):
"""
处理文档(手动触发)
Args:
doc_id: 文档ID
"""
# 创建数据库会话,检查文档是否存在
with self.session() as session:
# 根据doc_id查询文档对象
doc = session.query(DocumentModel).filter(DocumentModel.id == doc_id).first()
# 如果文档不存在则抛出异常
if not doc:
raise ValueError(f"Document {doc_id} not found")
# 记录已经提交文档处理任务的日志
self.logger.info(f"提交文档处理任务: {doc_id}")
# 在线程池中异步提交处理任务
future = self.executor.submit(self._process_document, doc_id)
# 定义异常回调函数,用于捕获子线程中的异常
def exception_callback(future):
try:
# 获取异步线程的执行结果,发生异常会在此抛出
future.result()
except Exception as e:
# 记录错误日志及异常堆栈信息
self.logger.error(f"文档处理任务异常: {doc_id}, 错误: {e}", exc_info=True)
# 给future对象添加回调函数,任务结束时自动处理异常
future.add_done_callback(exception_callback)
# 文档实际处理方法(在子线程中执行,异步)
def _process_document(self, doc_id: str):
"""
处理文档(异步)
Args:
doc_id: 文档ID
"""
try:
# 日志:文档处理任务开始
self.logger.info(f"开始处理文档: {doc_id}")
# 首先开启事务,获取文档信息和知识库配置并更新文档初始状态
with self.transaction() as session:
# 查询文档对象
doc = session.query(DocumentModel).filter(DocumentModel.id == doc_id).first()
# 未找到文档则输出日志并返回
if not doc:
self.logger.error(f"未找到文档 {doc_id}")
return
# 若文档已被处理过(完成或失败),则需重置状态
need_cleanup = doc.status in ['completed', 'failed']
if need_cleanup:
# 重置状态为待处理,分块数归零、错误信息清除
doc.status = 'pending'
doc.chunk_count = 0
doc.error_message = None
# 更新为正在处理状态
doc.status = 'processing'
# 刷新到数据库(写入未提交)
session.flush()
# 提前取出相关数据,避免事务作用域外对象失效
kb_id = doc.kb_id
file_path = doc.file_path
file_type = doc.file_type
doc_name = doc.name
+ collection_name = f"kb_{doc.kb_id}" if need_cleanup else None
# 查询知识库配置
kb = session.query(Knowledgebase).filter(Knowledgebase.id == kb_id).first()
# 未找到知识库抛出异常
if not kb:
raise ValueError(f"知识库 {kb_id} 未找到")
# 获取知识库分块参数
kb_chunk_size = kb.chunk_size
kb_chunk_overlap = kb.chunk_overlap
# 如果需要清理旧分块和向量,则在事务作用域外先进行删除,避免占用数据库连接
+ if need_cleanup:
+ try:
# 调用向量服务,删除指定集合、指定文档ID下的所有向量数据
+ vector_service.delete_documents(
+ collection_name=collection_name,
+ filter={"doc_id": doc_id}
+ )
# 输出信息日志,标明文档的旧向量已被删除
+ self.logger.info(f"已删除文档 {doc_id} 的旧向量")
+ except Exception as e:
# 如果删除操作出错,则记录警告日志
+ self.logger.warning(f"删除向量时出错: {e}")
# 日志:文档已标记为处理中
self.logger.info(f"文档 {doc_id} 状态已更新为 processing(处理中)")
# 从存储中下载文件内容
file_data = storage_service.download_file(file_path)
# 解析文件,根据类型抽取原始文本内容
langchain_docs = parser_service.parse(file_data, file_type)
# 若抽取出的文本为空,则抛出异常
if not langchain_docs:
raise ValueError("未能抽取到任何文本内容")
# 创建文本分块器,指定知识库参数
splitter = TextSplitter(
chunk_size=kb_chunk_size,#块大小
chunk_overlap=kb_chunk_overlap,#块之间的重叠字符数
)
# 将文档内容分块
chunks = splitter.split_documents(langchain_docs, doc_id=doc_id)
# 如果分块失败,抛出异常
if not chunks:
raise ValueError("文档未能成功分块")
# 初始化一个列表用于存放转换后的 LangChain Document 对象
+ documents = []
# 遍历所有分块,将每个分块转换为 LangChain Document 对象
+ for chunk in chunks:
# 创建一个 LangChain Document,包含文本内容和相关元数据
+ doc_obj = Document(
+ page_content=chunk['text'],
+ metadata={
+ 'doc_id': doc_id, # 文档ID
+ 'doc_name': doc_name, # 文档名称
+ 'chunk_index': chunk['chunk_index'], # 分块索引
+ 'id': chunk['id'], # 分块ID
+ 'chunk_id': chunk['id'] # 分块ID
+ }
+ )
# 将生成的 Document 对象加入到 documents 列表
+ documents.append(doc_obj)
# 构造向量库集合名称,格式为 kb_知识库ID
+ collection_name = f"kb_{kb_id}"
# 提取所有分块的ID,用于向量存储
+ ids = [chunk['id'] for chunk in chunks]
# 调用向量服务,将分块后的文档写入向量库
+ vector_service.add_documents(
+ collection_name=collection_name, #集合名称
+ documents=documents, #文档列表
+ ids=ids #分块ID列表
+ )
# 再次开启事务,更新文档状态为完成,记录分块数
with self.transaction() as session:
doc = session.query(DocumentModel).filter(DocumentModel.id == doc_id).first()
if doc:
doc.status = 'completed'#完成状态
doc.chunk_count = len(chunks)#分块数
# 日志:处理完成,输出分块数
self.logger.info(f"文档处理完成: {doc_id}, 分块数量: {len(chunks)}")
except Exception as e:
# 捕获异常后,更新文档状态为失败,并记录错误信息(限长)
with self.transaction() as session:
# 查询文档对象
doc = session.query(DocumentModel).filter(DocumentModel.id == doc_id).first()
# 如果文档存在,则更新状态为失败,并记录错误信息(限长)
if doc:
doc.status = 'failed'#失败状态
doc.error_message = str(e)[:500]#错误信息
# 记录处理失败的日志
self.logger.error(f"处理文档 {doc_id} 时发生错误: {e}")
# 实例化DocumentService
document_service = DocumentService()8.Milvus向量 #
8.1. milvus.py #
app/services/vectordb/milvus.py
# Milvus 向量数据库实现说明
"""
Milvus 向量数据库实现
"""
# 导入日志模块
import logging
# 导入Milvus类
from langchain_milvus import Milvus
# 导入类型提示相关模块
from typing import List, Dict, Optional, Any
# 导入LangChain的文档类型
from langchain_core.documents import Document
# 导入向量数据库接口基类
from app.services.vectordb.base import VectorDBInterface
# 导入Embedding工厂方法
from app.utils.embedding_factory import EmbeddingFactory
# 获取日志记录器
logger = logging.getLogger(__name__)
# 定义 Milvus 向量数据库实现类,继承接口基类
class MilvusVectorDB(VectorDBInterface):
"""Milvus 向量数据库实现"""
# 初始化方法,设置连接参数和embedding模型
def __init__(self, connection_args: Optional[Dict] = None):
"""
初始化 Milvus 服务
Args:
connection_args: Milvus连接参数,例如:
{'host': 'localhost', 'port': '19530'}
"""
# 如果没有传递连接参数,则使用默认主机和端口
if connection_args is None:
connection_args = {'host': 'localhost', 'port': '19530'}
# 若端口是整数,转为字符串,以防止类型冲突(Milvus要求端口是字符串)
if 'port' in connection_args and isinstance(connection_args['port'], int):
connection_args = connection_args.copy()
connection_args['port'] = str(connection_args['port'])
# 保存连接参数到实例
self.connection_args = connection_args
# 动态创建 Embedding 模型
self.embeddings = EmbeddingFactory.create_embeddings()
# 打印初始化日志
logger.info(f"Milvus 已初始化, 连接参数: {connection_args}")
# 获取或创建 Milvus 向量集合的方法
def get_or_create_collection(self, collection_name: str) -> Any:
"""获取或创建向量存储"""
# 拷贝一份连接参数,检查端口类型
connection_args = self.connection_args.copy()
if 'port' in connection_args and isinstance(connection_args['port'], int):
connection_args['port'] = str(connection_args['port'])
# 创建 Milvus 向量存储对象,如果集合不存在会自动创建
# LangChain Milvus 会自动处理集合、索引创建和加载
vectorstore = Milvus(
collection_name=collection_name,#集合名称
embedding_function=self.embeddings,#embedding模型
connection_args=connection_args,#连接参数
)
# 如果集合对象有 _collection 属性,尝试加载已有集合(如果已经存在)
if hasattr(vectorstore, '_collection'):
try:
# 尝试加载集合
vectorstore._collection.load()
logger.debug(f"已加载 Milvus 集合 {collection_name}")
except Exception as e:
# 集合不存在或为空时输出Debug日志
logger.debug(f"集合 {collection_name} 可能不存在或为空: {e}")
# 返回vectorstore对象
return vectorstore
# 添加文档到 Milvus 的方法
def add_documents(self, collection_name: str, documents: List[Document],
ids: Optional[List[str]] = None) -> List[str]:
"""添加文档到向量存储"""
# 获取(或创建)对应的集合
vectorstore = self.get_or_create_collection(collection_name)
try:
# 指定了ID则传入,否则直接添加
if ids:
result_ids = vectorstore.add_documents(documents=documents, ids=ids)
else:
result_ids = vectorstore.add_documents(documents=documents)
# 确保数据写入并刷新到磁盘
# 通过内部 _collection 对象手动刷新
if hasattr(vectorstore, '_collection'):
vectorstore._collection.flush()#刷新集合
logger.debug(f"已刷新 Milvus 集合 {collection_name}")
# 记录添加文档的日志
logger.info(f"已向 Milvus 集合 {collection_name} 添加 {len(documents)} 个文档")
return result_ids#返回结果
except Exception as e:
# 添加失败时打印错误日志并抛出异常
logger.error(f"向 Milvus 集合 {collection_name} 添加文档时出错: {e}", exc_info=True)
raise
# 删除文档的方法
def delete_documents(self, collection_name: str, ids: Optional[List[str]] = None,
filter: Optional[Dict] = None) -> None:
"""删除文档"""
# 获取(或创建)指定名称的向量集合
vectorstore = self.get_or_create_collection(collection_name)
# 如果传入了ids,按id删除文档
if ids:
vectorstore.delete(ids=ids)
# 如果传入了filter,根据过滤条件删除文档
elif filter:
# 构造Milvus的表达式字符串,这里只处理doc_id的情况
expr = f'doc_id=="{filter["doc_id"]}"'
vectorstore.delete(expr=expr)
# ids和filter都未传,抛出异常
else:
raise ValueError(f"你既没有传ids,也没有传filter")
# 如果集合对象有_collection属性,则刷新集合,保证删除操作落盘
if hasattr(vectorstore, "_collection"):
vectorstore._collection.flush()
# 记录删除操作的日志
logger.info(f"已经从ChromDB集合{collection_name}删除文档")
8.2. .env #
.env
# 应用配置
APP_HOST=0.0.0.0
APP_PORT=5000
APP_DEBUG=True
MAX_FILE_SIZE=104857600
SECRET_KEY=dev-secret-key-change-in-production
# 日志配置
LOG_DIR=./logs
LOG_FILE=rag_lite.log
LOG_LEVEL=INFO
LOG_ENABLE_FILE=True
LOG_ENABLE_CONSOLE=True
# 数据库配置
DB_HOST=localhost
DB_PORT=3306
DB_USER=root
DB_PASSWORD=root
DB_NAME=rag-lite
DB_CHARSET=utf8mb4
# 存储配置
STORAGE_TYPE=local
STORAGE_DIR=./storages
MINIO_ENDPOINT=127.0.0.1:9000
MINIO_ACCESS_KEY=minioadmin
MINIO_SECRET_KEY=minioadmin
+VECTORDB_TYPE=milvus8.3. factory.py #
app/services/vectordb/factory.py
"""
向量数据库工厂
"""
# 导入日志模块
import logging
# 导入可选类型提示
from typing import Optional
# 导入向量数据库接口基类
from app.services.vectordb.base import VectorDBInterface
# 导入 Chroma 的向量数据库实现
from app.services.vectordb.chroma import ChromaVectorDB
# 导入 Milvus 的向量数据库实现
+from app.services.vectordb.milvus import MilvusVectorDB
# 导入全局配置
from app.config import Config
# 获取日志记录器
logger = logging.getLogger(__name__)
# 定义向量数据库工厂类
class VectorDBFactory:
"""向量数据库工厂"""
# 类属性:用于保存单例实例
_instance: Optional[VectorDBInterface] = None
# 类方法:创建向量数据库实例
@classmethod
def create_vector_db(cls, vectordb_type: Optional[str] = None, **kwargs) -> VectorDBInterface:
"""
创建向量数据库实例
Args:
vectordb_type: 向量数据库类型 ('chroma' 或 'milvus'),如果为None则从配置读取
settings: 设置字典,用于创建Embedding模型
**kwargs: 向量数据库的初始化参数
Returns:
向量数据库实例
"""
# 如果未传入 vectordb_type,则从配置读取,默认 'chroma'
if vectordb_type is None:
vectordb_type = getattr(Config, 'VECTORDB_TYPE', 'chroma')
# 转为小写,统一判断
vectordb_type = vectordb_type.lower()
# 如果选择的是 Chroma
if vectordb_type == 'chroma':
# 获取持久化目录(可选参数)
persist_directory = kwargs.get('persist_directory')
# 创建 ChromaVectorDB 实例并返回
return ChromaVectorDB(
persist_directory=persist_directory
)
+ elif vectordb_type == 'milvus':
# 从 kwargs 获取连接参数,如果没有则使用配置中的默认值(主机和端口)
+ connection_args = kwargs.get('connection_args') or {
+ 'host': getattr(Config, 'MILVUS_HOST', 'localhost'),
+ 'port': getattr(Config, 'MILVUS_PORT', '19530')
+ }
# 创建 MilvusVectorDB 实例,并传入连接参数
+ return MilvusVectorDB(
+ connection_args=connection_args
+ )
else:
# 其他类型暂不支持,抛出异常
raise ValueError(f"Unsupported vector database type: {vectordb_type}")
# 类方法:获取单例向量数据库实例(懒加载)
@classmethod
def get_instance(cls) -> VectorDBInterface:
"""
获取单例向量数据库实例(懒加载)
Args:
settings: 设置字典,用于创建Embedding模型
Returns:
向量数据库实例
"""
# 如果单例实例为空,则创建
if cls._instance is None:
cls._instance = cls.create_vector_db()
# 返回单例实例
return cls._instance
# 工厂方法:便捷获取向量数据库服务实例
def get_vector_db_service() -> VectorDBInterface:
"""
便捷函数:获取向量数据库服务实例
Args:
settings: 设置字典,用于创建Embedding模型
Returns:
向量数据库服务实例
"""
# 返回单例向量数据库实例
return VectorDBFactory.get_instance()9.查看分块 #
9.1. document_chunks.html #
app/templates/document_chunks.html
{% extends "base.html" %}
{% block title %}{{ document.name }} - 分块列表 - RAG Lite{% endblock %}
{% block content %}
<div class="row">
<div class="col-12">
<nav aria-label="breadcrumb" class="mb-3">
<ol class="breadcrumb">
<li class="breadcrumb-item"><a href="/">首页</a></li>
<li class="breadcrumb-item"><a href="/kb">知识库</a></li>
<li class="breadcrumb-item"><a href="/kb/{{ kb.id }}">{{ kb.name }}</a></li>
<li class="breadcrumb-item active">{{ document.name }} - 分块列表</li>
</ol>
</nav>
</div>
</div>
<div class="row">
<div class="col-12">
<div class="card">
<div class="card-header d-flex justify-content-between align-items-center">
<h5 class="mb-0">
<i class="bi bi-file-earmark-text"></i> {{ document.name }} - 分块列表
</h5>
<a href="/kb/{{ kb.id }}" class="btn btn-sm btn-secondary">
<i class="bi bi-arrow-left"></i> 返回文档列表
</a>
</div>
<div class="card-body">
<div class="mb-3">
<div class="row">
<div class="col-md-6">
<p class="mb-1"><strong>文档名称:</strong>{{ document.name }}</p>
<p class="mb-1"><strong>文档状态:</strong>
{% set status_map = {
'completed': '已完成',
'processing': '处理中',
'failed': '失败',
'pending': '待处理'
} %}
<span class="badge bg-{{ 'success' if document.status == 'completed' else 'warning' if document.status == 'processing' else 'danger' if document.status == 'failed' else 'secondary' }}">
{{ status_map.get(document.status, document.status) }}
</span>
</p>
</div>
<div class="col-md-6">
<p class="mb-1"><strong>分块数量:</strong>{{ chunks|length }}</p>
<p class="mb-1"><strong>文件大小:</strong>{{ "%.2f"|format(document.file_size / 1024) }} KB</p>
</div>
</div>
</div>
{% if chunks %}
<div class="table-responsive">
<table class="table table-hover">
<thead>
<tr>
<th style="width: 80px;">序号</th>
<th style="width: 200px;">分块ID</th>
<th>内容</th>
</tr>
</thead>
<tbody>
{% for chunk in chunks %}
<tr>
<td>{{ chunk.chunk_index + 1 }}</td>
<td>
<code class="text-muted" style="font-size: 0.85em;">{{ chunk.id }}</code>
</td>
<td>
<div class="chunk-content" style="max-height: 150px; overflow-y: auto;">
{{ chunk.content }}
</div>
<button class="btn btn-sm btn-link p-0 mt-1" data-chunk-id="{{ chunk.id }}" data-chunk-content="{{ chunk.content|replace('"', '"')|replace("'", "'")|replace('\n', ' ') }}" onclick="copyChunkContent(event)">
<i class="bi bi-clipboard"></i> 复制
</button>
</td>
</tr>
{% endfor %}
</tbody>
</table>
</div>
{% else %}
<div class="text-center text-muted py-5">
<i class="bi bi-inbox" style="font-size: 3rem;"></i>
<p class="mt-2">该文档还没有分块数据</p>
<p class="text-muted small">请确保文档已处理完成(状态为 completed)</p>
</div>
{% endif %}
</div>
</div>
</div>
</div>
{% endblock %}
{% block extra_js %}
<script>
// 定义一个函数用于复制分块内容
function copyChunkContent(event) {
// 获取被点击的按钮元素
const btn = event.target.closest('button');
// 获取分块ID(可用于后续拓展)
const chunkId = btn.getAttribute('data-chunk-id');
// 获取分块内容(含HTML实体编码)
let content = btn.getAttribute('data-chunk-content');
// 将HTML实体还原为原始字符(换行、单引号、双引号)
content = content.replace(/ /g, '\n').replace(/'/g, "'").replace(/"/g, '"');
// 使用浏览器Clipboard API将内容写入剪贴板
navigator.clipboard.writeText(content).then(function() {
// 成功复制后:先保存按钮原始显示内容
const originalHTML = btn.innerHTML;
// 修改按钮显示为已复制样式
btn.innerHTML = '<i class="bi bi-check"></i> 已复制';
btn.classList.add('text-success');
// 2秒后恢复按钮为原始内容和样式
setTimeout(function() {
btn.innerHTML = originalHTML;
btn.classList.remove('text-success');
}, 2000);
}).catch(function(err) {
// 如果写入剪贴板失败,打印错误并弹窗提示
console.error('复制失败:', err);
alert('复制失败,请手动复制');
});
}
</script>
{% endblock %}
9.2. document.py #
app/blueprints/document.py
"""
文档相关路由(视图 + API)
"""
# 从 Flask 导入 Blueprint 和 request,用于定义蓝图和处理请求
+from flask import Blueprint,request,flash,redirect,url_for,render_template
# 导入 os 模块,用于文件名后缀等操作
import os
# 导入 logging 模块,用于日志记录
import logging
# 从工具模块导入通用的响应和错误处理函数
from app.blueprints.utils import success_response, error_response, handle_api_error
# 导入文档服务,用于处理文档业务逻辑
from app.services.document_service import document_service
# 导入配置文件,获取相关配置参数
from app.config import Config
+from app.utils.auth import login_required
# 导入知识库服务
+from app.services.knowledgebase_service import kb_service
+from app.models.document import Document as DocumentModel
# 设置日志对象
logger = logging.getLogger(__name__)
# 创建一个名为 'document' 的蓝图
bp = Blueprint('document', __name__)
# 定义一个检查文件扩展名是否合法的辅助函数
def allowed_file(filename):
# 检查文件名中有小数点并且扩展名在允许的扩展名列表中
return '.' in filename and \
os.path.splitext(filename)[1][1:].lower() in Config.ALLOWED_EXTENSIONS
# 定义上传文档的 API 路由,POST 方法
@bp.route('/api/v1/knowledgebases/<kb_id>/documents', methods=['POST'])
# 使用装饰器统一捕获和处理 API 错误
@handle_api_error
def api_upload(kb_id):
# 上传文档接口
"""上传文档"""
# 如果请求中没有 'file' 字段,返回参数错误
if 'file' not in request.files:
return error_response("No file part", 400)
# 从请求中获取上传的文件对象
file = request.files['file']
# 如果用户未选择文件(文件名为空),返回错误
if file.filename == '':
return error_response("No file selected", 400)
# 如果文件类型不被允许(扩展名校验),返回错误
if not allowed_file(file.filename):
return error_response(f"File type not allowed. Allowed: {', '.join(Config.ALLOWED_EXTENSIONS)}", 400)
# 读取上传的文件内容
file_data = file.read()
# 检查文件大小是否超过上限
if len(file_data) > Config.MAX_FILE_SIZE:
return error_response(f"File size exceeds maximum {Config.MAX_FILE_SIZE} bytes", 400)
# 获取前端自定义的文件名(可选)
custom_name = request.form.get('name')
# 如果指定了自定义文件名
if custom_name:
# 获取原始文件扩展名
original_ext = os.path.splitext(file.filename)[1]
# 如果自定义文件名没有扩展名,自动补上原扩展名
if not os.path.splitext(custom_name)[1] and original_ext:
filename = custom_name + original_ext
else:
filename = custom_name
# 否则使用原始文件名
else:
filename = file.filename
# 校验最终得到的文件名:为空或只包含空白则报错
if not filename or not filename.strip():
return error_response("Filename is required", 400)
# 校验最终文件名是否包含扩展名
if '.' not in filename:
return error_response("Filename must have an extension", 400)
# 调用文档服务上传,返回文档信息字典
doc_dict = document_service.upload(kb_id, file_data, filename)
# 返回成功响应及新文档信息
return success_response(doc_dict)
# 定义路由,处理 POST 请求,API 路径包含待处理文档的ID
@bp.route('/api/v1/documents/<doc_id>/process', methods=['POST'])
# 使用自定义的错误处理装饰器,以统一接口异常响应格式
@handle_api_error
def api_process(doc_id):
# 处理文档(API接口),doc_id是传入的文档ID
"""处理文档"""
try:
# 调用文档服务方法,提交文档处理任务(异步)
document_service.process(doc_id)
# 返回成功响应信息,提示任务已提交
return success_response({"message": "文档处理任务已提交"})
except ValueError as e:
# 捕获业务异常(如找不到文档),返回404和错误信息
return error_response(str(e), 404)
except Exception as e:
# 捕获其他未知异常,记录错误日志(带堆栈信息)
logger.error(f"处理文档时出错: {e}", exc_info=True)
# 返回500通用错误响应,附带错误原因
return error_response(f"处理文档失败: {str(e)}", 500)
# 路由:文档分块列表页面,URL中包含doc_id
+@bp.route('/documents/<doc_id>/chunks')
# 需要用户登录
+@login_required
+def document_chunks(doc_id):
# 文档分块列表页面
+ """文档分块列表页面"""
# 根据文档ID获取文档对象
+ doc = document_service.get_by_id(DocumentModel,doc_id)
# 如果文档不存在,提示错误并跳转到知识库列表页面
+ if not doc:
+ flash('文档不存在', 'error')
+ return redirect(url_for('knowledgebase.kb_list'))
# 根据文档的kb_id获取对应的知识库对象
+ kb = kb_service.get_by_id(doc.kb_id)
# 如果知识库不存在,提示错误并跳转到知识库列表页面
+ if not kb:
+ flash('知识库不存在', 'error')
+ return redirect(url_for('knowledgebase.kb_list'))
# 获取分块数据
+ try:
# 动态导入向量数据库服务工厂方法
+ from app.services.vectordb.factory import get_vector_db_service
# 组合向量数据库中的 collection 名称
+ collection_name = f"kb_{doc.kb_id}"
# 获取向量数据库服务实例
+ vectordb = get_vector_db_service()
# 构建过滤条件,只获取当前文档的分块
+ filter_dict = {"doc_id": doc_id}
# 通过similarity_search_with_score方法,检索所有属于该文档的分块
+ results = vectordb.similarity_search_with_score(
+ collection_name=collection_name,
+ query=" ", # 用空格作为查询内容,目的是获取所有分块
+ k=100000, # 设置很大的k值来避免遗漏分块
+ filter=filter_dict # 应用文档ID过滤
+ )
# 提取查询结果中的Document对象,准备排序
+ chunks = [doc for doc, _ in results]
# 按chunk_index进行排序,保证顺序和原始文档一致
+ chunks.sort(key=lambda d: d.metadata.get('chunk_index', 0))
# 初始化分块列表
+ chunks_data = []
# 遍历所有分块对象,整理为前端模板使用的字典格式
+ for chunk in chunks:
+ chunks_data.append({
# 分块ID,如果不存在则尝试兼容chunk_id字段
+ 'id': chunk.metadata.get('id') or chunk.metadata.get('chunk_id', ''),
# 分块文本内容
+ 'content': chunk.page_content,
# 分块在文档中的序号
+ 'chunk_index': chunk.metadata.get('chunk_index', 0),
# 分块的原始元数据
+ 'metadata': chunk.metadata
+ })
# 捕获异常,如果出错则记录日志,并展示空分块列表
+ except Exception as e:
+ logger.error(f"获取分块数据失败: {e}")
+ chunks_data = []
# 渲染模板,传递知识库、文档及分块列表数据给页面
+ return render_template(
+ 'document_chunks.html',
+ kb=kb,
+ document=doc.to_dict(),
+ chunks=chunks_data
+ ) 9.3. base.py #
app/services/vectordb/base.py
"""
向量数据库抽象接口
"""
# 导入 abc 库,定义抽象基类和抽象方法
from abc import ABC, abstractmethod
# 导入类型提示:列表、字典、可选和任意类型
from typing import List, Dict, Optional, Any
# 导入 LangChain 的 Document 类,用于文档对象
from langchain_core.documents import Document
# 定义一个向量数据库的抽象接口,继承自 ABC 抽象基类
class VectorDBInterface(ABC):
"""向量数据库抽象接口"""
# 定义抽象方法:获取或创建集合
@abstractmethod
def get_or_create_collection(self, collection_name: str) -> Any:
"""
获取或创建集合
Args:
collection_name: 集合名称
Returns:
向量存储对象
"""
# 子类需要实现具体逻辑
pass
# 定义抽象方法:向向量存储中添加文档
@abstractmethod
def add_documents(self, collection_name: str, documents: List[Document],
ids: Optional[List[str]] = None) -> List[str]:
"""
添加文档到向量存储
Args:
collection_name: 集合名称
documents: Document 列表
ids: 文档ID列表(可选)
Returns:
添加的文档ID列表
"""
# 子类需要实现具体逻辑
pass
# 定义抽象方法:删除指定的文档
@abstractmethod
def delete_documents(self, collection_name: str, ids: Optional[List[str]] = None,
filter: Optional[Dict] = None) -> None:
"""
删除文档
Args:
collection_name: 集合名称
ids: 要删除的文档ID列表(可选)
filter: 过滤条件(可选)
"""
# 子类需要实现具体逻辑
pass
+ @abstractmethod
+ def similarity_search(self, collection_name: str, query: str,
+ k: int = 5, filter: Optional[Dict] = None) -> List[Document]:
+ """
+ 相似度搜索
+ Args:
+ collection_name: 集合名称
+ query: 查询文本
+ k: 返回结果数量
+ filter: 元数据过滤条件
+ Returns:
+ 检索到的 Document 列表
+ """
+ pass
+ @abstractmethod
+ def similarity_search_with_score(self, collection_name: str, query: str,
+ k: int = 5, filter: Optional[Dict] = None) -> List[tuple]:
+ """
+ 相似度搜索(带分数)
+ Args:
+ collection_name: 集合名称
+ query: 查询文本
+ k: 返回结果数量
+ filter: 元数据过滤条件
+ Returns:
+ (Document, score) 元组列表
+ """
+ pass
9.4. chroma.py #
app/services/vectordb/chroma.py
# ChromaDB 向量数据库实现
"""
ChromaDB 向量数据库实现
"""
# 导入日志模块
import logging
# 导入需要的类型提示
from typing import List, Dict, Optional, Any
# 导入 LangChain 的 Chroma 类
from langchain_chroma import Chroma
# 导入 Document 类
from langchain_core.documents import Document
# 导入向量数据库接口基类
from app.services.vectordb.base import VectorDBInterface
# 导入全局配置
from app.config import Config
# 导入嵌入模型工厂
from app.utils.embedding_factory import EmbeddingFactory
# 获取日志记录器
logger = logging.getLogger(__name__)
# 定义 Chroma 向量数据库实现类
class ChromaVectorDB(VectorDBInterface):
# ChromaDB 向量数据库实现
# 初始化方法
def __init__(self, persist_directory: Optional[str] = None,):
"""
初始化 ChromaDB 服务
Args:
persist_directory: 持久化目录,如果为None则使用配置中的值
settings: 设置字典,用于创建Embedding模型
"""
# 如果持久化目录为 None,则从配置读取
if persist_directory is None:
persist_directory = Config.CHROMA_PERSIST_DIRECTORY
# 设置持久化目录属性
self.persist_directory = persist_directory
# 动态创建 Embedding 模型
self.embeddings = EmbeddingFactory.create_embeddings()
# 记录 ChromaDB 初始化信息
logger.info(f"ChromaDB 已初始化, 持久化目录: {persist_directory}")
# 获取或创建集合(向量存储)
def get_or_create_collection(self, collection_name: str) -> Chroma:
# 获取或创建向量存储对象
vectorstore = Chroma(
collection_name=collection_name,
embedding_function=self.embeddings,
persist_directory=self.persist_directory
)
# 返回向量存储对象
return vectorstore
# 向向量存储添加文档
def add_documents(self, collection_name: str, documents: List[Document],
ids: Optional[List[str]] = None) -> List[str]:
# 获取集合
vectorstore = self.get_or_create_collection(collection_name)
# 如果指定了 ids
if ids:
# 添加文档,指定 ids
result_ids = vectorstore.add_documents(documents=documents, ids=ids)
else:
# 添加文档,不指定 ids
result_ids = vectorstore.add_documents(documents=documents)
# 记录日志,添加文档
logger.info(f"已向 ChromaDB 集合 {collection_name} 添加 {len(documents)} 个文档")
# 返回已添加文档的 id 列表
return result_ids
# 删除文档
def delete_documents(self, collection_name: str, ids: Optional[List[str]] = None,
filter: Optional[Dict] = None) -> None:
vectorstore = self.get_or_create_collection(collection_name)
if ids:
vectorstore.delete(ids=ids)
elif filter:
# ChromaDB 不支持直接使用 filter 参数删除
# 需要先查询出符合条件的文档 IDs,然后删除
try:
collection = vectorstore._collection
# 使用 where 条件查询匹配的文档
# filter 格式: {"doc_id": "xxx"}
where = filter
results = collection.get(where=where)
if results and 'ids' in results and results['ids']:
matched_ids = results['ids']
vectorstore.delete(ids=matched_ids)
logger.info(f"已通过filter条件删除{len(matched_ids)}个文档")
else:
logger.info(f"未找到匹配filter条件的文档,无需删除")
except Exception as e:
logger.error(f"使用filter删除文档时出错: {e}", exc_info=True)
raise
else:
raise ValueError(f"你既没有传ids,也没有传filter")
logger.info(f"已经从ChromDB集合{collection_name}删除文档")
# 定义相似度搜索方法
+ def similarity_search(self, collection_name: str, query: str,
+ k: int = 5, filter: Optional[Dict] = None) -> List[Document]:
# 方法说明,用于执行相似度搜索
+ """相似度搜索"""
# 获取或创建集合对应的向量存储对象
+ vectorstore = self.get_or_create_collection(collection_name)
# 如果指定了过滤条件
+ if filter:
# 带过滤条件地执行相似度搜索
+ results = vectorstore.similarity_search(query=query, k=k, filter=filter)
+ else:
# 不带过滤条件地执行相似度搜索
+ results = vectorstore.similarity_search(query=query, k=k)
# 返回搜索结果
+ return results
# 定义带分数的相似度搜索方法
+ def similarity_search_with_score(self, collection_name: str, query: str,
+ k: int = 5, filter: Optional[Dict] = None) -> List[tuple]:
+ # 方法说明,用于执行带分数的相似度搜索
+ vectorstore = self.get_or_create_collection(collection_name)
+ if filter:
+ results = vectorstore.similarity_search_with_score(
+ query=query, k=k, filter=filter
+ )
+ else:
+ results = vectorstore.similarity_search_with_score(query=query, k=k)
+ return results
9.5. milvus.py #
app/services/vectordb/milvus.py
# Milvus 向量数据库实现说明
"""
Milvus 向量数据库实现
"""
# 导入日志模块
import logging
# 导入Milvus类
from langchain_milvus import Milvus
# 导入类型提示相关模块
from typing import List, Dict, Optional, Any
# 导入LangChain的文档类型
from langchain_core.documents import Document
# 导入向量数据库接口基类
from app.services.vectordb.base import VectorDBInterface
# 导入Embedding工厂方法
from app.utils.embedding_factory import EmbeddingFactory
# 获取日志记录器
logger = logging.getLogger(__name__)
# 定义 Milvus 向量数据库实现类,继承接口基类
class MilvusVectorDB(VectorDBInterface):
"""Milvus 向量数据库实现"""
# 初始化方法,设置连接参数和embedding模型
def __init__(self, connection_args: Optional[Dict] = None):
"""
初始化 Milvus 服务
Args:
connection_args: Milvus连接参数,例如:
{'host': 'localhost', 'port': '19530'}
"""
# 如果没有传递连接参数,则使用默认主机和端口
if connection_args is None:
connection_args = {'host': 'localhost', 'port': '19530'}
# 若端口是整数,转为字符串,以防止类型冲突(Milvus要求端口是字符串)
if 'port' in connection_args and isinstance(connection_args['port'], int):
connection_args = connection_args.copy()
connection_args['port'] = str(connection_args['port'])
# 保存连接参数到实例
self.connection_args = connection_args
# 动态创建 Embedding 模型
self.embeddings = EmbeddingFactory.create_embeddings()
# 打印初始化日志
logger.info(f"Milvus 已初始化, 连接参数: {connection_args}")
# 获取或创建 Milvus 向量集合的方法
def get_or_create_collection(self, collection_name: str) -> Any:
"""获取或创建向量存储"""
# 拷贝一份连接参数,检查端口类型
connection_args = self.connection_args.copy()
if 'port' in connection_args and isinstance(connection_args['port'], int):
connection_args['port'] = str(connection_args['port'])
# 创建 Milvus 向量存储对象,如果集合不存在会自动创建
# LangChain Milvus 会自动处理集合、索引创建和加载
vectorstore = Milvus(
collection_name=collection_name,#集合名称
embedding_function=self.embeddings,#embedding模型
connection_args=connection_args,#连接参数
)
# 如果集合对象有 _collection 属性,尝试加载已有集合(如果已经存在)
if hasattr(vectorstore, '_collection'):
try:
# 尝试加载集合
vectorstore._collection.load()
logger.debug(f"已加载 Milvus 集合 {collection_name}")
except Exception as e:
# 集合不存在或为空时输出Debug日志
logger.debug(f"集合 {collection_name} 可能不存在或为空: {e}")
# 返回vectorstore对象
return vectorstore
# 添加文档到 Milvus 的方法
def add_documents(self, collection_name: str, documents: List[Document],
ids: Optional[List[str]] = None) -> List[str]:
"""添加文档到向量存储"""
# 获取(或创建)对应的集合
vectorstore = self.get_or_create_collection(collection_name)
try:
# 指定了ID则传入,否则直接添加
if ids:
result_ids = vectorstore.add_documents(documents=documents, ids=ids)
else:
result_ids = vectorstore.add_documents(documents=documents)
# 确保数据写入并刷新到磁盘
# 通过内部 _collection 对象手动刷新
if hasattr(vectorstore, '_collection'):
vectorstore._collection.flush()#刷新集合
logger.debug(f"已刷新 Milvus 集合 {collection_name}")
# 记录添加文档的日志
logger.info(f"已向 Milvus 集合 {collection_name} 添加 {len(documents)} 个文档")
return result_ids#返回结果
except Exception as e:
# 添加失败时打印错误日志并抛出异常
logger.error(f"向 Milvus 集合 {collection_name} 添加文档时出错: {e}", exc_info=True)
raise
# 删除文档的方法
def delete_documents(self, collection_name: str, ids: Optional[List[str]] = None,
filter: Optional[Dict] = None) -> None:
"""删除文档"""
vectorstore = self.get_or_create_collection(collection_name)
if ids:
vectorstore.delete(ids=ids)
elif filter:
expr = f'doc_id=="{filter["doc_id"]}"'
vectorstore.delete(expr=expr)
else:
raise ValueError(f"你既没有传ids,也没有传filter")
if hasattr(vectorstore, "_collection"):
# 刷新集合
vectorstore._collection.flush()
logger.info(f"已经从ChromDB集合{collection_name}删除文档")
# 定义相似度搜索方法
+ def similarity_search(self, collection_name: str, query: str,
+ k: int = 5, filter: Optional[Dict] = None) -> List[Document]:
# 方法说明(可选)
+ """相似度搜索"""
# 获取或创建名为 collection_name 的集合
+ vectorstore = self.get_or_create_collection(collection_name)
# 检查集合是否已被加载(LangChain Milvus 一般自动处理,这里显式保证)
+ if hasattr(vectorstore, '_collection'):
+ try:
# 显式加载集合
+ vectorstore._collection.load()
+ except Exception as e:
# 如果加载失败或已加载,记录到 debug 日志
+ logger.debug(f"集合可能已加载或加载失败: {e}")
# 如果指定了过滤条件
+ if filter:
# 使用带过滤表达式的相似度搜索
+ results = vectorstore.similarity_search(query=query, k=k, expr=filter)
+ else:
# 不带过滤条件,直接搜索
+ results = vectorstore.similarity_search(query=query, k=k)
# 返回搜索结果
+ return results
# 定义带分数的相似度搜索方法
+ def similarity_search_with_score(self, collection_name: str, query: str,
+ k: int = 5, filter: Optional[Dict] = None) -> List[tuple]:
# 获取对应名称的向量存储集合
+ vectorstore = self.get_or_create_collection(collection_name)
+ # 判断集合对象是否有"_collection"属性(Milvus 后端可能有,需要显式加载)
+ if hasattr(vectorstore, "_collection"):
+ try:
+ # 显式加载集合(加速后续搜索操作)
+ vectorstore._collection.load()
+ logger.info(f"已经加载集合{collection_name}")
+ except Exception as e:
+ # 加载失败/集合不存在时记录日志,不影响流程
+ logger.info(f"集合可能不存在:{e}")
+ # 如果传递了过滤条件
+ if filter:
+ # 根据过滤条件构造Milvus的过滤表达式,只支持doc_id精准查询
+ expr = f'doc_id=="{filter["doc_id"]}"'
+ # 带过滤表达式执行相似度检索,并拿到分数
+ results = vectorstore.similarity_search_with_score(
+ query=query, k=k, expr=expr
+ )
+ print("filter_results", len(results))
+ else:
+ # 如果没有过滤条件,直接执行检索
+ results = vectorstore.similarity_search_with_score(query=query, k=k)
+ # 返回(检索文档, 分数)结果列表
+ return results10.删除文档 #
10.1. document.py #
app/blueprints/document.py
"""
文档相关的路由
"""
from flask import Blueprint, request, flash, render_template, redirect, url_for
import os
from app.blueprints.utils import success_response, error_response, handle_api_error
from app.utils.logger import get_logger
from app.utils.file import allowed_file
from app.services.document_service import document_service
from app.services.vector_service import vector_service
from app.services.knowledgebase_service import kb_service
from app.config import Config
from app.utils.auth import login_required
from app.models.document import Document as DocumentModel
logger = get_logger(__name__)
bp = Blueprint("document", __name__)
@bp.route("/api/v1/knowledgebases/<kb_id>/documents", methods=["POST"])
@handle_api_error
def api_upload(kb_id):
if "file" not in request.files:
return error_response("没有文件字段", 400)
file = request.files["file"]
if file.filename == "":
return error_response("没有选中任何文件", 400)
if not allowed_file(file.filename):
return error_response(
f"文件类型不允许上传,只允许:{', '.join(Config.ALLOWED_EXTENSIONS)}", 400
)
# 读取上传的文件内容
file_data = file.read()
if len(file_data) > Config.MAX_FILE_SIZE:
return error_response(
f"文件大小超过了最大的大小:{', '.join(Config.MAX_FILE_SIZE)} bytes", 400
)
# 这是用户在前端自定义的文件名
custom_name = request.form.get("name")
if custom_name:
# 获取 原始的文件扩展名 .pdf
original_ext = os.path.splitext(file.filename)[1]
if not os.path.splitext(custom_name)[1] and original_ext:
# filename = store.pdf
filename = custom_name + original_ext
else:
filename = custom_name
else:
filename = file.filename
if not filename or not filename.strip():
return error_response(f"文件名必须存在", 400)
# 调用文档上传服务,返回文档信息的字典
doc_dict = document_service.upload(kb_id, file_data, filename)
return success_response(doc_dict)
@bp.route("/api/v1/documents/<doc_id>/process", methods=["POST"])
@handle_api_error
def api_process(doc_id):
try:
document_service.process(doc_id)
return success_response({"message": "文档处理任务已经提交"})
except ValueError as e:
return error_response(str(e), 400)
except Exception as e:
return error_response("处理文档失:{str(e)}", 500)
+@bp.route("/api/v1/documents/<doc_id>", methods=["DELETE"])
+@handle_api_error
+def api_delete(doc_id):
+ """删除文档API"""
+ try:
+ document_service.delete(doc_id)
+ return success_response({"message": "文档删除成功"})
+ except ValueError as e:
+ return error_response(str(e), 400)
+ except Exception as e:
+ logger.error(f"删除文档失败:{e}", exc_info=True)
+ return error_response(f"删除文档失败:{str(e)}", 500)
@bp.route("/documents/<doc_id>/chunks")
@login_required
def document_chunks(doc_id):
doc = document_service.get_by_id(DocumentModel, doc_id)
if not doc:
flash("文档不存在", "error")
return redirect(url_for("knowledgebase.kb_list"))
kb = kb_service.get_by_id(doc.kb_id)
if not kb:
flash("知识库不存在", "error")
return redirect(url_for("knowledgebase.kb_list"))
# 获取分块列表
try:
# 组合向量数据库的集合名称
collection_name = f"kb_{doc.kb_id}"
# 构建过滤条件,按文档ID进行过滤
filter_dict = {"doc_id": doc_id}
results = vector_service.similarity_search_with_score(
collection_name=collection_name, query="", k=10000, filter=filter_dict
)
print(f"vector_service查询到的结果:{len(results)}")
# 这个doc指的chromdb里的Document对象
document_vectors = [doc for doc, _ in results]
# 按chunk_index 进行排序,保证顺序和原始文档顺序是一样的
document_vectors.sort(key=lambda d: d.metadata.get("chunk_index", 0))
chunks_data = []
for document_vector in document_vectors:
chunks_data.append(
{
"id": document_vector.metadata.get("id"), # 文本分块ID
"content": document_vector.page_content, # 分块的文本内容
"chunk_index": document_vector.metadata.get(
"chunk_index"
), # 分块在文档中的索引
"metadata": document_vector.metadata,
}
)
except Exception as e:
logger.error(f"获取分块数据失败:{e}")
chunks_data = []
return render_template(
"document_chunks.html", kb=kb, document=doc.to_dict(), chunks=chunks_data
)
10.2. document_service.py #
app/services/document_service.py
from app.models.document import Document as DocumentModel
from app.services.base_service import BaseService
import os
from app.services.storage_service import storage_service
from app.services.parser_service import parser_service
from app.services.vector_service import vector_service
from app.config import Config
from app.models.knowledgebase import Knowledgebase
import uuid
from app.utils.text_splitter import TextSplitter
from concurrent.futures import ThreadPoolExecutor
from langchain_core.documents import Document
class DocumentService(BaseService[DocumentModel]):
def __init__(self):
super().__init__()
self.executor = ThreadPoolExecutor(max_workers=4)
def upload(self, kb_id, file_data, filename):
file_uploaded = False
file_path = None
try:
with self.session() as session:
kb = (
session.query(Knowledgebase)
.filter(Knowledgebase.id == kb_id)
.first()
)
if not kb:
raise ValueError(f"知识库{kb_id}不存在")
# 获取文件的扩展名
file_ext = os.path.splitext(filename)[1]
if file_ext:
file_ext = file_ext[1:].lower()
else:
raise ValueError(f"文件名必须包含扩展名:{filename}")
# 生成文档ID,用于标识唯一的文档
doc_id = uuid.uuid4().hex[:32]
# 构建文档存储的路径,以便于后续操作
file_path = f"documents/{kb_id}/{doc_id}/{filename}"
try:
storage_service.upload_file(file_path, file_data)
file_uploaded = True
except Exception as storage_error:
self.logger.error(f"上传文件到存储时发生了错误:{storage_error}")
raise ValueError(f"文件上传失败:{str(storage_error)}")
with self.transaction() as session:
doc = DocumentModel(
id=doc_id,
kb_id=kb_id,
name=filename,
file_path=file_path,
file_type=file_ext,
file_size=len(file_data),
status="pending", # 文档状态默认值为正在处理中
)
session.add(doc)
session.flush()
session.refresh(doc)
doc_dict = doc.to_dict()
self.logger.info(f"文档上传成功:{doc_id}")
return doc_dict
except Exception as e:
if file_uploaded and file_path:
try:
storage_service.delete_file(file_path)
except Exception as delete_error:
self.logger.warning(f"删除已经上传的文件时出错:{delete_error}")
raise
def list_by_kb(self, kb_id, page, page_size, status=None):
with self.session() as session:
query = session.query(DocumentModel).filter(DocumentModel.kb_id == kb_id)
if status:
query.filter(DocumentModel.status == status)
return self.paginate_query(
query, # 查询条件
page=page, # 当前 页码
page_size=page_size, # 返回多少条
order_by=DocumentModel.created_at.desc(), # 排序字段
)
def process(self, doc_id):
with self.session() as session:
doc = (
session.query(DocumentModel).filter(DocumentModel.id == doc_id).first()
)
if not doc:
raise ValueError("文档{doc_id}不存在")
self.logger.info(f"提交文档处理任务:{doc_id}")
# 在线程池中异步提交处理任务
future = self.executor.submit(self._process_document, doc_id)
def exception_callback(future):
try:
# 获取异步线程的执行结果,发生异常的时候会在此抛出异常
future.result()
except Exception as e:
self.logger.error(f"文档处理任务异常:{doc_id},错误:{e}", exc_info=True)
future.add_done_callback(exception_callback)
def _process_document(self, doc_id):
try:
self.logger.info(f"开始处理文档:{doc_id}")
with self.transaction() as session:
doc = (
session.query(DocumentModel)
.filter(DocumentModel.id == doc_id)
.first()
)
if not doc:
self.logger.error(f"未找到文档:{doc_id}")
return
# 如果文档已经处理过了,需要重置为原始的状态
need_cleanup = doc.status in ["completed", "failed"]
# 如果要恢复为原始状态的话,需要重置 状态,分块大小,错误处理清空
if need_cleanup:
doc.chunk_count = 0
doc.error_message = ""
# 把状态更新为处理中
doc.status = "processing"
# 刷新到数据库,这个时候写入未提交的修改
session.flush()
# 提前取出相关的数据
kb_id = doc.kb_id
file_path = doc.file_path
file_type = doc.file_type
doc_name = doc.name
# 一个知识库对应一个chromadb的向量数据库的集合
collection_name = f"kb_{doc.kb_id}"
kb = (
session.query(Knowledgebase)
.filter(Knowledgebase.id == kb_id)
.first()
)
if not kb:
raise ValueError(f"知识库不存在")
kb_chunk_size = kb.chunk_size
kb_chunk_overlap = kb.chunk_overlap
# 如果要清理旧数据
if need_cleanup:
try:
# 调用向量服务,删除指定集合中指定向档ID下面的所有的向量数据
vector_service.delete_documents(
collection_name=collection_name, filter={"doc_id": doc_id}
)
except Exception as e:
self.logger.warning(f"删除向量数据库失败:{e}")
self.logger.info(f"文档{doc_id}状态已经更新为processing状态了")
# 从存储中下载文件内容
file_data = storage_service.download_file(file_path)
# 解析文件,根据文件类型按不同的方法得到文本内容
langchain_docs = parser_service.parse(file_data, file_type)
self.logger.info(f"加载到{len(langchain_docs)}个文档")
if not langchain_docs:
raise ValueError(f"未能抽取到任何文本内容")
# 创建文本的分块器,指定知识库参数
splitter = TextSplitter(
chunk_size=kb_chunk_size, chunk_overlap=kb_chunk_overlap
)
# 将文档进行分块
chunks = splitter.split_documents(langchain_docs, doc_id=doc_id)
if not chunks:
raise ValueError(f"文档{doc_id}未能成功分块")
self.logger.info(f"加载到{len(chunks)}个分块")
# 初始化一个列表用于存放默认换后的langchain document对象
documents = []
for chunk in chunks:
# 创建一个langchain document对象
doc_obj = Document(
page_content=chunk["text"],
metadata={
"doc_id": doc_id, # 文档ID
"doc_name": doc_name, # 文档名称
"chunk_index": chunk["chunk_index"], # 分块索引
"id": chunk["id"], # 分块ID
"chunk_id": chunk["id"], # 分块ID
},
)
documents.append(doc_obj)
# 提取所有分块的ID,用于向量存储 chunk["id"]=它对就应的文档ID_index索引
chunk_ids = [chunk["id"] for chunk in chunks]
# 调用向量服务,将分块后的文档对象写入向量数据库
vector_service.add_documents(
collection_name=collection_name, documents=documents, ids=chunk_ids
)
with self.transaction() as session:
doc = (
session.query(DocumentModel)
.filter(DocumentModel.id == doc_id)
.first()
)
if doc:
doc.status = "completed"
doc.chunk_count = len(chunks)
self.logger.info(f"文档{doc_id}处理完成,分块数量为{len(chunks)}")
except Exception as e:
# 如果文档处理了,则需要更新文档的状态为失败,并且记录错误信息
with self.transaction() as session:
doc = (
session.query(DocumentModel)
.filter(DocumentModel.id == doc_id)
.first()
)
if doc:
doc.status = "failed"
doc.error_message = str(e)[:500]
session.flush()
session.refresh(doc)
self.logger.error(f"处理文档{doc_id}时发生了错误:{e}")
+ def delete(self, doc_id):
+ """
+ 删除文档
+ 包括:删除向量数据库中的向量数据、删除存储中的文件、删除数据库记录
+ """
+ with self.session() as session:
+ doc = (
+ session.query(DocumentModel)
+ .filter(DocumentModel.id == doc_id)
+ .first()
+ )
+ if not doc:
+ raise ValueError(f"文档{doc_id}不存在")
# 保存需要删除的信息
+ kb_id = doc.kb_id
+ file_path = doc.file_path
+ collection_name = f"kb_{kb_id}"
# 1. 删除向量数据库中的相关向量数据
+ try:
+ vector_service.delete_documents(
+ collection_name=collection_name,
+ filter={"doc_id": doc_id}
+ )
+ self.logger.info(f"已删除文档{doc_id}的向量数据")
+ except Exception as e:
+ self.logger.warning(f"删除向量数据失败:{e}")
# 2. 删除存储中的文件
+ if file_path:
+ try:
+ storage_service.delete_file(file_path)
+ self.logger.info(f"已删除文档{doc_id}的存储文件:{file_path}")
+ except Exception as e:
+ self.logger.warning(f"删除存储文件失败:{e}")
# 3. 删除数据库记录
+ with self.transaction() as session:
+ doc = (
+ session.query(DocumentModel)
+ .filter(DocumentModel.id == doc_id)
+ .first()
+ )
+ if doc:
+ session.delete(doc)
+ self.logger.info(f"已删除文档{doc_id}的数据库记录")
document_service = DocumentService()
10.3. kb_detail.html #
app/templates/kb_detail.html
{% extends "base.html" %}
{% block title %}{{ kb.name }} - RAG Lite{% endblock %}
{% block content %}
<div class="row">
<div class="col-12">
<nav aria-label="breadcrumb">
<ol class="breadcrumb">
<li class="breadcrumb-item"><a href="/">首页</a></li>
<li class="breadcrumb-item"><a href="/kb">知识库</a></li>
<li class="breadcrumb-item active">{{ kb.name }}</li>
</ol>
</nav>
</div>
</div>
<div class="row">
<div class="col-12">
<div class="card">
<div class="card-header d-flex justify-content-between align-items-center">
<h5 class="mb-0"><i class="bi bi-file-earmark"></i> 文档管理</h5>
<button class="btn btn-sm btn-primary" data-bs-toggle="modal" data-bs-target="#uploadModal">
<i class="bi bi-upload"></i> 上传文档
</button>
</div>
<div class="card-body">
<div id="docList">
{% if documents %}
<div class="table-responsive">
<table class="table table-hover">
<thead>
<tr>
<th>文档名称</th>
<th>状态</th>
<th>块数</th>
<th>文件大小</th>
<th>操作</th>
</tr>
</thead>
<tbody>
{% for doc in documents %}
<tr>
<td>
<i
class="bi bi-file-earmark-{{ 'pdf' if doc.file_type == 'pdf' else 'word' if doc.file_type == 'docx' else 'text' }}"></i>
{{ doc.name }}
</td>
<td>
{% set status_map = {
'completed': '已完成',
'processing': '处理中',
'failed': '失败',
'pending': '待处理'
} %}
<span
class="badge bg-{{ 'success' if doc.status == 'completed' else 'warning' if doc.status == 'processing' else 'danger' if doc.status == 'failed' else 'secondary' }}">
{{ status_map.get(doc.status, doc.status) }}
</span>
</td>
<td>{{ doc.chunk_count or 0 }}</td>
<td>{{ "%.2f"|format(doc.file_size / 1024) }} KB</td>
<td>
{% if doc.status == 'completed' %}
<a href="/documents/{{ doc.id }}/chunks" class="btn btn-sm btn-info me-1">
<i class="bi bi-list-ul"></i> 查看分块
</a>
{% endif %}
{% if doc.status == 'pending' %}
<button class="btn btn-sm btn-primary me-1"
onclick="processDoc('{{ doc.id }}', '{{ doc.name }}')">
<i class="bi bi-play-circle"></i> 处理
</button>
{% elif doc.status in ['completed', 'failed'] %}
<button class="btn btn-sm btn-warning me-1"
onclick="processDoc('{{ doc.id }}', '{{ doc.name }}')">
<i class="bi bi-arrow-clockwise"></i> 重新处理
</button>
{% elif doc.status == 'processing' %}
<button class="btn btn-sm btn-secondary me-1" disabled>
<i class="bi bi-hourglass-split"></i> 处理中
</button>
{% endif %}
<button class="btn btn-sm btn-danger"
onclick="deleteDoc('{{ doc.id }}', '{{ doc.name }}')">
<i class="bi bi-trash"></i> 删除
</button>
</td>
</tr>
{% endfor %}
</tbody>
</table>
</div>
<!-- 分页控件 -->
{% if pagination and pagination.total > pagination.page_size %}
<nav aria-label="文档列表分页" class="mt-3">
<ul class="pagination justify-content-center">
{% set current_page = pagination.page %}
{% set total_pages = (pagination.total + pagination.page_size - 1) // pagination.page_size
%}
<!-- 上一页 -->
<li class="page-item {% if current_page <= 1 %}disabled{% endif %}">
<a class="page-link"
href="?page={{ current_page - 1 }}&page_size={{ pagination.page_size }}" {% if
current_page <=1 %}tabindex="-1" aria-disabled="true" {% endif %}>
<i class="bi bi-chevron-left"></i> 上一页
</a>
</li>
<!-- 页码 -->
{% set start_page = [1, current_page - 2] | max %}
{% set end_page = [total_pages, current_page + 2] | min %}
{% if start_page > 1 %}
<li class="page-item">
<a class="page-link" href="?page=1&page_size={{ pagination.page_size }}">1</a>
</li>
{% if start_page > 2 %}
<li class="page-item disabled">
<span class="page-link">...</span>
</li>
{% endif %}
{% endif %}
{% for page_num in range(start_page, end_page + 1) %}
<li class="page-item {% if page_num == current_page %}active{% endif %}">
<a class="page-link" href="?page={{ page_num }}&page_size={{ pagination.page_size }}">
{{ page_num }}
</a>
</li>
{% endfor %}
{% if end_page < total_pages %} {% if end_page < total_pages - 1 %} <li
class="page-item disabled">
<span class="page-link">...</span>
</li>
{% endif %}
<li class="page-item">
<a class="page-link"
href="?page={{ total_pages }}&page_size={{ pagination.page_size }}">{{
total_pages }}</a>
</li>
{% endif %}
<!-- 下一页 -->
<li class="page-item {% if current_page >= total_pages %}disabled{% endif %}">
<a class="page-link"
href="?page={{ current_page + 1 }}&page_size={{ pagination.page_size }}" {% if
current_page>= total_pages %}tabindex="-1" aria-disabled="true"{% endif %}>
下一页 <i class="bi bi-chevron-right"></i>
</a>
</li>
</ul>
<div class="text-center text-muted small mt-2">
共 {{ pagination.total }} 个文档,第 {{ current_page }} / {{ total_pages }} 页
</div>
</nav>
{% endif %}
{% else %}
<div class="text-center text-muted py-5">
<i class="bi bi-inbox" style="font-size: 3rem;"></i>
<p class="mt-2">还没有文档,点击上方按钮上传文档</p>
</div>
{% endif %}
</div>
</div>
</div>
</div>
</div>
<!-- 上传文档模态框 -->
<div class="modal fade" id="uploadModal" tabindex="-1">
<div class="modal-dialog">
<div class="modal-content">
<div class="modal-header">
<h5 class="modal-title">上传文档</h5>
<button type="button" class="btn-close" data-bs-dismiss="modal"></button>
</div>
<form id="uploadForm" onsubmit="uploadDocument(event)" enctype="multipart/form-data">
<div class="modal-body">
<div class="mb-3">
<label class="form-label">选择文件</label>
<input type="file" class="form-control" name="file" accept=".pdf,.docx,.txt,.md" required>
<div class="form-text">支持 PDF、DOCX、TXT、MD 格式</div>
</div>
<div class="mb-3">
<label class="form-label">文档名称(可选)</label>
<input type="text" class="form-control" name="name" placeholder="留空使用文件名">
</div>
<div id="uploadProgress" class="progress d-none">
<div class="progress-bar progress-bar-striped progress-bar-animated" role="progressbar"
style="width: 100%"></div>
</div>
</div>
<div class="modal-footer">
<button type="button" class="btn btn-secondary" data-bs-dismiss="modal">取消</button>
<button type="submit" class="btn btn-primary">上传</button>
</div>
</form>
</div>
</div>
</div>
{% endblock %}
{% block extra_js %}
<script>
const kbId = '{{kb.id}}'
async function uploadDocument(event) {
event.preventDefault()
const form = event.target;
const formData = new FormData(form)
const processDiv = document.getElementById("uploadProgress")
const submitBtn = form.querySelector('button[type="submit"]')
processDiv.classList.remove('d-none')
submitBtn.disabled = true
try {
const response = await fetch(`/api/v1/knowledgebases/${kbId}/documents`, {
method: 'POST',
body: formData
})
if (response.ok) {
const result = await response.json()
const docId = result.data.id
bootstrap.Modal.getInstance(document.getElementById("uploadModal")).hide()
form.reset()
location.reload()
} else {
const error = await response.json()
alert("上传失败:" + error.message)
}
} catch (error) {
alert("上传失败:" + error.message)
} finally {
processDiv.classList.add('d-none')
submitBtn.disabled = false
}
}
async function processDoc(docId,docName){
try{
const response = await fetch(`/api/v1/documents/${docId}/process`,{method:"POST"})
if (response.ok){
alert("文件处理处理已经提交,请稍后刷新页面查看处理状态")
setTimeout(()=>location.reload(),1000)
}else{
const error = await response.json()
alert("处理失败:"+error.message)
}
}catch(error){
alert("处理失败"+error.message)
}
}
+ async function deleteDoc(docId, docName){
+ if (!confirm(`确定要删除文档"${docName}"吗?此操作不可恢复!`)){
+ return
+ }
try{
+ const response = await fetch(`/api/v1/documents/${docId}`, {
+ method: 'DELETE'
+ })
if (response.ok){
+ alert("文档删除成功")
+ location.reload()
}else{
const error = await response.json()
+ alert("删除失败:" + error.message)
}
}catch(error){
+ alert("删除失败:" + error.message)
}
}
</script>
{% endblock %}11.删除知识库 #
11.1. knowledgebase_service.py #
app/services/knowledgebase_service.py
from app.models.knowledgebase import Knowledgebase
+from app.models.document import Document as DocumentModel
from app.services.base_service import BaseService
import os
from app.services.storage_service import storage_service
+from app.services.vector_service import vector_service
from app.config import Config
class KnowledgebaseService(BaseService[Knowledgebase]):
def create(
self,
name,
user_id,
description,
chunk_size,
chunk_overlap,
cover_image_data,
cover_image_filename,
):
if cover_image_data and cover_image_filename:
# 获取不带.的文件扩展名
file_ext_without_dot = (
os.path.splitext(cover_image_filename)[1][1:].lower()
if "." in cover_image_filename
else ""
)
if not file_ext_without_dot:
raise ValueError(f"文件缺少扩展名:{cover_image_filename}")
if file_ext_without_dot not in Config.ALLOWED_IMAGE_EXTENSIONS:
raise ValueError(
f"不支持的图片格式:{file_ext_without_dot},支持的格式为{', '.join(Config.ALLOWED_IMAGE_EXTENSIONS)}"
)
if len(cover_image_data) == 0:
raise ValueError(f"上传的图片为空")
if len(cover_image_data) > Config.MAX_IMAGE_SIZE:
raise ValueError(
f"图片大小已经超过了最大限制:{Config.MAX_IMAGE_SIZE/1024/1024}M"
)
with self.transaction() as session:
kb = Knowledgebase(
name=name,
user_id=user_id,
description=description,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
# 将知识库对象添加到session
session.add(kb)
# 刷新session,生成知识库的ID
session.flush()
if cover_image_data and cover_image_filename:
# 构建封面图片的路径 统一使用小写扩展名 .png .jpg .gif 带点的文件扩展名
file_ext_with_dot = os.path.splitext(cover_image_filename)[1].lower()
cover_image_path = f"covers/{kb.id}{file_ext_with_dot}"
self.logger.info(
f"正在为新的知识库{kb.id}上传封面图片,文件名:{cover_image_filename},路径:{cover_image_path}"
)
storage_service.upload_file(cover_image_path, cover_image_data)
self.logger.info(f"成功上传知识库的封面图片:{cover_image_path}")
kb.cover_image = cover_image_path
session.flush()
# 刷新kb对象的数据库状态
session.refresh(kb)
# 把模型实例转成字典
kb_dict = kb.to_dict()
self.logger.info("创建知识库成功:ID:{kb.id}")
return kb_dict
def list(self, user_id, page, page_size, search, sort_by, sort_order):
with self.session() as session:
query = session.query(Knowledgebase)
# 如果指定了user_id,则只查找该 用户的知识库
if user_id:
query = query.filter(Knowledgebase.user_id == user_id)
if search:
search_pattern = f"%{search}%"
query = query.filter(
(Knowledgebase.name.like(search_pattern))
| (Knowledgebase.description.like(search_pattern))
)
sort_field = None
if sort_by == "name":
sort_field = Knowledgebase.name
elif sort_by == "updated_at":
sort_field = Knowledgebase.updated_at
else:
sort_field = Knowledgebase.created_at
if sort_order == "asc":
query = query.order_by(sort_field.asc())
else:
query = query.order_by(sort_field.desc())
# 统计总记录数
total = query.count()
# 计算分页的偏移量
offset = (page - 1) * page_size
kbs = query.offset(offset).limit(page_size).all()
items = []
for kb in kbs:
items.append(kb.to_dict())
return {
"items": items,
"total": total,
"page": page,
"page_size": page_size,
}
def delete(self, kb_id):
+ """
+ 删除知识库
+ 包括:删除所有文档的向量数据、删除所有文档的存储文件、删除封面图片、删除数据库记录
+ """
# 1. 获取知识库信息
+ with self.session() as session:
kb = session.query(Knowledgebase).filter(Knowledgebase.id == kb_id).first()
if not kb:
+ raise ValueError(f"知识库{kb_id}不存在")
# 保存需要删除的信息
+ kb_name = kb.name
+ cover_image_path = kb.cover_image if kb.cover_image else None
# 获取知识库下的所有文档
+ documents = session.query(DocumentModel).filter(DocumentModel.kb_id == kb_id).all()
+ doc_ids = [doc.id for doc in documents]
+ doc_file_paths = [doc.file_path for doc in documents if doc.file_path]
+ collection_name = f"kb_{kb_id}"
# 2. 删除向量数据库集合中的所有向量数据
+ if doc_ids:
+ try:
# 逐个删除每个文档的向量数据
+ for doc_id in doc_ids:
+ try:
+ vector_service.delete_documents(
+ collection_name=collection_name,
+ filter={"doc_id": doc_id}
+ )
+ except Exception as e:
+ self.logger.warning(f"删除文档{doc_id}的向量数据失败:{e}")
+ self.logger.info(f"已删除知识库{kb_id}的向量数据")
+ except Exception as e:
+ self.logger.warning(f"删除向量数据失败:{e}")
# 3. 删除所有文档的存储文件
+ for file_path in doc_file_paths:
+ if file_path:
+ try:
+ storage_service.delete_file(file_path)
+ self.logger.info(f"已删除文档存储文件:{file_path}")
+ except Exception as e:
+ self.logger.warning(f"删除文档存储文件失败:{file_path}, 错误:{e}")
# 4. 删除知识库的封面图片
+ if cover_image_path:
+ try:
+ storage_service.delete_file(cover_image_path)
+ self.logger.info(f"已删除知识库封面图片:{cover_image_path}")
+ except Exception as e:
+ self.logger.warning(f"删除封面图片失败:{e}")
# 5. 删除知识库的数据库记录(由于 CASCADE,文档记录会自动删除)
+ with self.transaction() as session:
+ kb = session.query(Knowledgebase).filter(Knowledgebase.id == kb_id).first()
+ if kb:
+ session.delete(kb)
+ self.logger.info(f"已删除知识库:{kb_id} {kb_name}")
+ return True
# def get_by_id(self, kb_id):
# kb = super().get_by_id(Knowledgebase, kb_id)
# if kb:
# return kb.to_dict()
# return None
def get_by_id(self, kb_id: str):
with self.session() as db_session:
try:
return (
db_session.query(Knowledgebase)
.filter(Knowledgebase.id == kb_id)
.first()
.to_dict()
)
except Exception as e:
self.logger.error("获取ID对应的对象失败:{e}")
return None
def update(
self, kb_id, cover_image_data, cover_image_filename, delete_cover, **kwargs
):
with self.transaction() as session:
kb = session.query(Knowledgebase).filter(Knowledgebase.id == kb_id).first()
if not kb:
return None
# 老的图片路径
old_cover_path = kb.cover_image if kb.cover_image else None
if delete_cover:
if old_cover_path:
# 如果有旧的封面图片,并且需要删除的话
storage_service.delete_file(old_cover_path)
self.logger.info(f"已成功删除旧的封面图片:{old_cover_path}")
# 更新数据库中的cover_image为None
setattr(kb, "cover_image", None)
elif cover_image_data and cover_image_filename:
file_ext_with_dot = os.path.splitext(cover_image_filename)[1]
file_ext_with_dot = file_ext_with_dot.lower()
# 构建新的图片路径
new_cover_path = f"covers/{kb_id}{file_ext_with_dot}"
if old_cover_path:
storage_service.delete_file(old_cover_path)
storage_service.upload_file(new_cover_path, cover_image_data)
setattr(kb, "cover_image", new_cover_path)
for key, value in kwargs.items():
if hasattr(kb, key) and value is not None:
setattr(kb, key, value)
# flush指是使用我们提供的kb的值去更新数据库
session.flush()
# 刷新对象,避免未提交前读到旧的数据
session.refresh(kb)
kb_dict = kb.to_dict()
self.logger.info(f"更新知识库:{kb_id} {kb.name}")
return kb_dict
kb_service = KnowledgebaseService()
12.设置联动 #
12.1. settings.html #
app/templates/settings.html
{% extends "base.html" %}
{% block title %}设置 - RAG Lite{% endblock %}
{% block content %}
<div class="row">
<div class="col-12">
<nav aria-label="breadcrumb" class="mb-3">
<ol class="breadcrumb">
<li class="breadcrumb-item"><a href="/">首页</a></li>
<li class="breadcrumb-item active">设置</li>
</ol>
</nav>
<h2><i class="bi bi-gear"></i> 系统设置</h2>
<p class="text-muted">配置模型、提示词和检索参数</p>
<form id="settingsForm" onsubmit="saveSettings(event)">
<!-- 标签页导航 -->
<ul class="nav nav-tabs mb-4" id="settingsTabs" role="tablist">
<li class="nav-item" role="presentation">
<button class="nav-link active" id="embedding-tab" data-bs-toggle="tab" data-bs-target="#embedding"
type="button" role="tab">
<i class="bi bi-diagram-3"></i> 向量嵌入模型
</button>
</li>
<li class="nav-item" role="presentation">
<button class="nav-link" id="llm-tab" data-bs-toggle="tab" data-bs-target="#llm" type="button"
role="tab">
<i class="bi bi-robot"></i> 大语言模型
</button>
</li>
<li class="nav-item" role="presentation">
<button class="nav-link" id="prompt-tab" data-bs-toggle="tab" data-bs-target="#prompt" type="button"
role="tab">
<i class="bi bi-chat-quote"></i> 提示词设置
</button>
</li>
<li class="nav-item" role="presentation">
<button class="nav-link" id="retrieval-tab" data-bs-toggle="tab" data-bs-target="#retrieval"
type="button" role="tab">
<i class="bi bi-search"></i> 检索设置
</button>
</li>
</ul>
<!-- 标签页内容 -->
<div class="tab-content" id="settingsTabContent">
<!-- 标签1: 向量嵌入模型 -->
<div class="tab-pane fade show active" id="embedding" role="tabpanel">
<div class="card">
<div class="card-header">
<h5 class="mb-0"><i class="bi bi-diagram-3"></i> 向量嵌入模型(Embedding)</h5>
</div>
<div class="card-body">
<div class="mb-3">
<label class="form-label">提供商 <span class="text-danger">*</span></label>
<select class="form-select" id="embeddingProvider" name="embedding_provider"
onchange="updateEmbeddingForm()" required>
<option value="huggingface">HuggingFace</option>
<option value="openai">OpenAI</option>
<option value="ollama">Ollama</option>
</select>
<div class="form-text">选择向量嵌入模型提供商</div>
</div>
<div class="mb-3" id="embeddingModelNameGroup">
<label class="form-label">模型名称</label>
<select class="form-select" id="embeddingModelName" name="embedding_model_name">
<option value="">请选择模型</option>
<!-- 动态填充 -->
</select>
<div class="form-text">选择模型名称或路径</div>
</div>
<div class="mb-3" id="embeddingApiKeyGroup" style="display: none;">
<label class="form-label">API Key</label>
<input type="password" class="form-control" id="embeddingApiKey"
name="embedding_api_key" placeholder="输入 API Key">
<div class="form-text">某些提供商需要 API Key</div>
</div>
<div class="mb-3" id="embeddingBaseUrlGroup" style="display: none;">
<label class="form-label">Base URL</label>
<input type="text" class="form-control" id="embeddingBaseUrl" name="embedding_base_url"
placeholder="例如: http://localhost:11434">
<div class="form-text">API Base URL(Ollama 需要)</div>
</div>
</div>
</div>
</div>
<!-- 标签2: 大语言模型 -->
<div class="tab-pane fade" id="llm" role="tabpanel">
<div class="card">
<div class="card-header">
<h5 class="mb-0"><i class="bi bi-robot"></i> 大语言模型(LLM)</h5>
</div>
<div class="card-body">
<div class="mb-3">
<label class="form-label">提供商 <span class="text-danger">*</span></label>
<select class="form-select" id="llmProvider" name="llm_provider"
onchange="updateLLMForm()" required>
<option value="deepseek">DeepSeek</option>
<option value="openai">OpenAI</option>
<option value="ollama">Ollama</option>
</select>
<div class="form-text">选择大语言模型提供商</div>
</div>
<div class="mb-3" id="llmModelNameGroup">
<label class="form-label">模型名称</label>
<select class="form-select" id="llmModelName" name="llm_model_name">
<option value="1" selected>请选择模型</option>
<!-- 动态填充 -->
</select>
<div class="form-text">选择模型名称</div>
</div>
<div class="mb-3" id="llmApiKeyGroup">
<label class="form-label">API Key</label>
<input type="password" class="form-control" id="llmApiKey" name="llm_api_key"
placeholder="输入 API Key">
<div class="form-text">某些提供商需要 API Key</div>
</div>
<div class="mb-3" id="llmBaseUrlGroup">
<label class="form-label">Base URL</label>
<input type="text" class="form-control" id="llmBaseUrl" name="llm_base_url"
placeholder="例如: https://api.deepseek.com">
<div class="form-text">API Base URL</div>
</div>
<div class="mb-3">
<label class="form-label">温度 (Temperature)</label>
<input type="number" class="form-control" id="llmTemperature" name="llm_temperature"
value="0.7" step="0.1" min="0" max="2" placeholder="0.7">
<div class="form-text">控制输出的随机性,值越大越随机(0-2)</div>
</div>
</div>
</div>
</div>
<!-- 标签3: 提示词设置 -->
<div class="tab-pane fade" id="prompt" role="tabpanel">
<div class="card">
<div class="card-body">
<!-- 子标签页导航 -->
<ul class="nav nav-tabs mb-4" id="promptSubTabs" role="tablist">
<li class="nav-item" role="presentation">
<button class="nav-link active" id="chat-prompt-sub-tab" data-bs-toggle="tab"
data-bs-target="#chat-prompt-sub" type="button" role="tab">
<i class="bi bi-chat"></i> 普通聊天提示词
</button>
</li>
<li class="nav-item" role="presentation">
<button class="nav-link" id="rag-prompt-sub-tab" data-bs-toggle="tab"
data-bs-target="#rag-prompt-sub" type="button" role="tab">
<i class="bi bi-book"></i> 知识库聊天提示词
</button>
</li>
</ul>
<!-- 子标签页内容 -->
<div class="tab-content" id="promptSubTabContent">
<!-- 子标签1: 普通聊天提示词 -->
<div class="tab-pane fade show active" id="chat-prompt-sub" role="tabpanel">
<div class="mb-4">
<label class="form-label">普通聊天系统提示词</label>
<textarea class="form-control" id="chatSystemPrompt" name="chat_system_prompt"
rows="10" placeholder="输入普通聊天提示词..."></textarea>
<div class="form-text mt-2">
<p class="mb-0">普通聊天提示词用于指导AI助手在普通聊天(未选择知识库)时的回答风格和行为。</p>
<p class="mb-0 mt-2"><strong>注意:</strong>这是系统消息的内容,不能使用变量。</p>
</div>
</div>
</div>
<!-- 子标签2: 知识库聊天提示词 -->
<div class="tab-pane fade" id="rag-prompt-sub" role="tabpanel">
<div class="mb-4">
<label class="form-label">知识库聊天系统提示词</label>
<textarea class="form-control" id="ragSystemPrompt" name="rag_system_prompt"
rows="6" placeholder="输入知识库聊天系统提示词..."></textarea>
<div class="form-text mt-2">
<p class="mb-0">知识库聊天系统提示词用于在会话开始时设置AI助手的角色和行为。</p>
<p class="mb-0 mt-2"><strong>注意:</strong>这是系统消息的内容,不能使用变量(如 {context} 或
{question})。</p>
</div>
</div>
<hr class="my-4">
<div class="mb-3">
<label class="form-label">知识库聊天查询提示词</label>
<textarea class="form-control" id="ragQueryPrompt" name="rag_query_prompt"
rows="10"
placeholder="例如:文档内容: {context} 问题:{question} 请基于文档内容回答问题。如果文档中没有相关信息,请明确说明。"></textarea>
<div class="form-text mt-2">
<p class="mb-1">知识库聊天查询提示词用于每次提问时构建提示,指导AI助手如何基于文档内容回答问题。</p>
<p class="mb-0"><strong>必须使用以下变量:</strong></p>
<ul class="mb-0">
<li><code>{context}</code> - 检索到的文档内容(必需)</li>
<li><code>{question}</code> - 用户的问题(必需)</li>
</ul>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
<!-- 标签3: 检索设置 -->
<div class="tab-pane fade" id="retrieval" role="tabpanel">
<div class="card">
<div class="card-header">
<h5 class="mb-0"><i class="bi bi-search"></i> 检索设置</h5>
</div>
<div class="card-body">
<div class="mb-3">
<label class="form-label">检索模式 <span class="text-danger">*</span></label>
+ <select class="form-select" id="retrievalMode" name="retrieval_mode"
+ onchange="updateRetrievalForm()" required>
<option value="vector">向量检索</option>
<option value="keyword">全文检索</option>
<option value="hybrid">混合检索</option>
</select>
<div class="form-text">选择文档检索方式</div>
</div>
<div class="mb-3" id="vectorThresholdGroup">
<label class="form-label">向量检索阈值</label>
<input type="number" class="form-control" id="vectorThreshold" name="vector_threshold"
value="0.2" step="0.1" min="0" max="1" placeholder="0.2">
<div class="form-text">向量相似度阈值,低于此值的文档将被过滤(0-1)</div>
</div>
<div class="mb-3" id="keywordThresholdGroup" style="display: none;">
<label class="form-label">全文检索阈值</label>
<input type="number" class="form-control" id="keywordThreshold" name="keyword_threshold"
value="0.5" step="0.1" min="0" max="1" placeholder="0.5">
<div class="form-text">关键词匹配阈值(0-1)</div>
</div>
<div class="mb-3" id="vectorWeightGroup" style="display: none;">
<label class="form-label">向量检索权重</label>
<input type="number" class="form-control" id="vectorWeight" name="vector_weight"
value="0.7" step="0.1" min="0" max="1" placeholder="0.7">
<div class="form-text">混合检索时向量检索的权重(0-1),关键词检索权重 = 1 - 向量权重</div>
</div>
<div class="mb-3">
<label class="form-label">TopK 结果数量</label>
<input type="number" class="form-control" id="topK" name="top_k" value="5" min="1"
max="50" placeholder="5">
<div class="form-text">返回的文档数量(1-50)</div>
</div>
<div class="alert alert-info">
<i class="bi bi-info-circle"></i> <strong>说明:</strong>
<ul class="mb-0 mt-2">
<li><strong>向量检索:</strong>基于语义相似度检索,适合理解问题意图</li>
<li><strong>全文检索:</strong>基于关键词匹配检索,适合精确匹配</li>
<li><strong>混合检索:</strong>结合向量和关键词检索,综合两者的优势</li>
</ul>
</div>
</div>
</div>
</div>
</div>
<div class="d-flex justify-content-end gap-2 mt-4">
<button type="button" class="btn btn-secondary" onclick="resetSettings()">重置</button>
<button type="submit" class="btn btn-primary">
<i class="bi bi-save"></i> 保存设置
</button>
</div>
</form>
</div>
</div>
{% endblock %}
{% block extra_js %}
<script>
let availableModels = {}
function updateEmbeddingForm() {
const provider = document.getElementById('embeddingProvider').value
const modelSelect = document.getElementById('embeddingModelName')
const apiKeyGroup = document.getElementById('embeddingApiKeyGroup')
const baseUrlGroup = document.getElementById('embeddingBaseUrlGroup')
//先记住当前选中的模型
const currentValue = modelSelect.value
modelSelect.innerHTML = '<option value="">请选择模型</options>'
if (availableModels && availableModels.embedding_models[provider]) {
const providerInfo = availableModels.embedding_models[provider]
providerInfo.models.forEach(model => {
const option = document.createElement('option')
const optionValue = model.path || model.name
option.value = optionValue
const displayText = model.name + (model.dimension ? `(维度:${model.dimension})` : "")
option.textContent = displayText
modelSelect.appendChild(option)
})
//恢复刷新前选中的值
if (currentValue) {
const optionExists = Array.from(modelSelect.options).some(opt => opt.value === currentValue)
if (optionExists) {
modelSelect.value = currentValue
}
}
if (providerInfo.requires_api_key) {
apiKeyGroup.style.display = 'block'
}else{
apiKeyGroup.style.display = 'none'
}
if (providerInfo.requires_base_url) {
baseUrlGroup.style.display = 'block'
}else{
baseUrlGroup.style.display = 'none'
}
}
}
function updateLLMForm() {
const provider = document.getElementById('llmProvider').value
const modelSelect = document.getElementById('llmModelName')
const apiKeyGroup = document.getElementById('llmApiKeyGroup')
const baseUrlGroup = document.getElementById('llmBaseUrlGroup')
//先记住当前选中的模型
const currentValue = modelSelect.value
modelSelect.innerHTML = '<option value="">请选择模型</options>'
if (availableModels && availableModels.llm_models[provider]) {
const providerInfo = availableModels.llm_models[provider]
providerInfo.models.forEach(model => {
const option = document.createElement('option')
const optionValue = model.name
option.value = optionValue
const displayText = model.name;
option.textContent = displayText
modelSelect.appendChild(option)
})
if (providerInfo.requires_api_key) {
apiKeyGroup.style.display = 'block'
}else{
apiKeyGroup.style.display = 'none'
}
if (providerInfo.requires_base_url) {
baseUrlGroup.style.display = 'block'
}else{
baseUrlGroup.style.display = 'none'
}
+ //恢复刷新前选中的值
+ if (currentValue) {
+ const optionExists = Array.from(modelSelect.options).some(opt => opt.value === currentValue)
+ if (optionExists) {
+ modelSelect.value = currentValue
+ }
+ }
}
}
function loadSettings(settings) {
+ // 向量嵌入模型:先设置提供商和模型名称,然后更新表单
document.getElementById('embeddingProvider').value = settings.embedding_provider
+ document.getElementById('embeddingModelName').value = settings.embedding_model_name || ''
+ updateEmbeddingForm() // 更新模型列表和显示/隐藏字段(会尝试恢复之前的值)
+ document.getElementById('embeddingApiKey').value = settings.embedding_api_key || ''
+ document.getElementById('embeddingBaseUrl').value = settings.embedding_base_url || ''
+ // 大语言模型:先设置提供商和模型名称,然后更新表单
document.getElementById('llmProvider').value = settings.llm_provider
+ document.getElementById('llmModelName').value = settings.llm_model_name || ''
+ updateLLMForm() // 更新模型列表和显示/隐藏字段(会尝试恢复之前的值)
+ document.getElementById('llmApiKey').value = settings.llm_api_key || ''
+ document.getElementById('llmBaseUrl').value = settings.llm_base_url || ''
+ document.getElementById('llmTemperature').value = settings.llm_temperature || '0.7'
+ document.getElementById('chatSystemPrompt').value = settings.chat_system_prompt || ''
+ document.getElementById('ragSystemPrompt').value = settings.rag_system_prompt || ''
+ document.getElementById('ragQueryPrompt').value = settings.rag_query_prompt || ''
+ document.getElementById('retrievalMode').value = settings.retrieval_mode || 'vector'
+ document.getElementById('vectorThreshold').value = settings.vector_threshold || '0.2'
+ document.getElementById('keywordThreshold').value = settings.keyword_threshold || '0.5'
+ document.getElementById('vectorWeight').value = settings.vector_weight || '0.7'
+ document.getElementById('topK').value = settings.top_k || '5'
+ updateRetrievalForm() // 更新检索模式的显示/隐藏
}
document.addEventListener('DOMContentLoaded', async function () {
try {
const modelsResponse = await fetch(`/api/v1/settings/models`)
const modelsResult = await modelsResponse.json()
if (modelsResult.code == 200) {
availableModels = modelsResult.data
updateEmbeddingForm()
updateLLMForm()
updateRetrievalForm()
const settingsResponse = await fetch('/api/v1/settings')
const settingsResult = await settingsResponse.json()
if (settingsResult.code == 200) {
loadSettings(settingsResult.data)
}
}
} catch (error) {
alert("加载设置失败:" + error.message)
}
})
function resetSettings() {
if (confirm('确定要重置为默认设置吗?')) {
loadSettings({
embedding_provider: 'huggingface',
embedding_model_name: 'sentence-transformers/all-MiniLM-L6-v2',
embedding_api_key: 'embedding_api_key',
embedding_base_url: 'embedding_base_url',
llm_provider: 'deepseek',
llm_model_name: 'deepseek-chat',
llm_api_key: 'deepseek_api_key',
llm_base_url: 'https://api.deepseek.com',
llm_temperature: 0.7,
chat_system_prompt: '你是一个专业的AI助手。请友好、准确地回答用户的问题。',
rag_system_prompt: '你是一个专业的AI助手。请基于文档内容回答问题。',
rag_query_prompt: '文档内容:\n{context}\n\n问题:{question}\n\n请基于文档内容回答问题。如果文档中没有相关信息,请明确说明。',
retrieval_mode: 'vector',
vector_threshold: 0.2,
keyword_threshold: 0.2,
vector_weight: 0.5,
top_k: 5
});
}
}
async function saveSettings(event){
event.preventDefault();
const form = event.target;
const formData = new FormData(form)
const data = {
embedding_provider:formData.get('embedding_provider'),
embedding_model_name:formData.get('embedding_model_name'),
embedding_api_key:formData.get('embedding_api_key'),
embedding_base_url:formData.get('embedding_base_url'),
llm_provider:formData.get('llm_provider'),
llm_model_name:formData.get('llm_model_name'),
llm_api_key:formData.get('llm_api_key'),
llm_base_url:formData.get('llm_base_url'),
llm_temperature:formData.get('llm_temperature'),
chat_system_prompt:formData.get('chat_system_prompt'),
rag_system_prompt:formData.get('rag_system_prompt'),
rag_query_prompt:formData.get('rag_query_prompt'),
retrieval_mode:formData.get('retrieval_mode'),
vector_threshold:formData.get('vector_threshold'),
keyword_threshold:formData.get('keyword_threshold'),
vector_weight:formData.get('vector_weight'),
top_k:formData.get('top_k')
}
try{
const response = await fetch('/api/v1/settings',{
method:'PUT',
headers:{"Content-Type":"application/json"},
body:JSON.stringify(data)
})
const result = await response.json()
if (response.ok){
alert("保存设置成功")
location.reload()
}else{
alert("保存失败:"+result.message)
}
}catch(error){
alert("保存设置失败:"+error.message)
}
}
function updateRetrievalForm(){
const retrievalMode = document.getElementById('retrievalMode').value
const vectorThresholdGroup = document.getElementById('vectorThresholdGroup')
const keywordThresholdGroup = document.getElementById('keywordThresholdGroup')
const vectorWeightGroup = document.getElementById('vectorWeightGroup')
+ if(retrievalMode === 'vector'){
+ // 向量检索:只显示向量检索阈值
+ vectorThresholdGroup.style.display = 'block'
+ keywordThresholdGroup.style.display = 'none'
+ vectorWeightGroup.style.display = 'none'
+ }else if (retrievalMode === 'keyword'){
+ // 全文检索:只显示全文检索阈值
+ vectorThresholdGroup.style.display = 'none'
+ keywordThresholdGroup.style.display = 'block'
+ vectorWeightGroup.style.display = 'none'
+ }else if (retrievalMode === 'hybrid'){
+ // 混合检索:显示向量检索阈值、全文检索阈值和向量检索权重
+ vectorThresholdGroup.style.display = 'block'
+ keywordThresholdGroup.style.display = 'block'
+ vectorWeightGroup.style.display = 'block'
}
}
</script>
{% endblock %}