1.本章目标 #
本章将介绍 rag-lite 项目中数据与模型层的设计与实现,包括以下目标:
- 理解项目中各类数据模型(如用户、知识库、文档、聊天会话与消息、系统配置等)的结构、作用及其关系。
- 掌握如何使用 SQLAlchemy 定义 ORM 基础模型类(BaseModel),实现通用数据表操作功能(如 to_dict、repr 等)。
- 理解包结构和模型组件的组织方式,掌握模型包的初始化 (init.py) 设计方法。
- 学会扩展和编写新的数据模型,为业务需求(如知识管理、聊天记录、文档存储等)提供数据支撑。
- 为后续数据持久化与交互操作(如增删改查、与数据库集成等)打下基础。
2.目录结构 #
# 项目根目录
rag-lite/
# 应用主目录
├── app/
# 数据模型目录
│ ├── models/
# 数据模型包的初始化文件
│ │ ├── __init__.py
# ORM模型基础类定义
│ │ ├── base.py
# 聊天消息模型
│ │ ├── chat_message.py
# 聊天会话模型
│ │ ├── chat_session.py
# 文档模型
│ │ ├── document.py
# 知识库模型
│ │ ├── knowledgebase.py
# 系统配置模型
│ │ ├── settings.py
# 用户模型
│ │ └── user.py
# 静态资源目录
├── static/
# 模板文件目录
├── templates/
# 工具类目录
├── utils/
# 数据库工具
│ ├── db.py
# 日志工具
│ └── logger.py
# app包初始化文件
├── __init__.py
# 应用配置文件
└── config.py
# 日志目录
├── logs/
# 日志文件
│ └── rag_lite.log
# 应用入口文件
├── main.py
# Python项目管理配置文件
└── pyproject.toml3.数据结构 #
classDiagram
class User {
<<用户模型>>
+String id (PK) "用户主键"
+String username (UQ) "用户名"
+String email (UQ) "邮箱"
+String password_hash "密码哈希"
+Boolean is_active "是否激活"
+DateTime created_at "创建时间"
+DateTime updated_at "更新时间"
+Knowledgebase[] knowledgebases "拥有的知识库"
+ChatSession[] chat_sessions "创建的聊天会话"
}
class Knowledgebase {
<<知识库模型>>
+String id (PK) "知识库主键"
+String user_id (FK) "用户ID(外键)"
+String name "知识库名称"
+Text description "描述"
+String cover_image "封面图片路径"
+Integer chunk_size "分块大小"
+Integer chunk_overlap "分块重叠大小"
+DateTime created_at "创建时间"
+DateTime updated_at "更新时间"
+User user "所属用户"
+Document[] documents "包含的文档"
+ChatSession[] chat_sessions "使用的聊天会话"
}
class Document {
<<文档模型>>
+String id (PK) "文档主键"
+String kb_id (FK) "知识库ID(外键)"
+String name "文档名称"
+String file_path "文件路径"
+String file_type "文件类型"
+BigInt file_size "文件大小"
+String status "文档状态"
+Integer chunk_count "文档分块数"
+Text error_message "错误信息"
+DateTime created_at "创建时间"
+DateTime updated_at "更新时间"
+Knowledgebase knowledgebase "所属知识库"
}
class ChatSession {
<<聊天会话模型>>
+String id (PK) "会话主键"
+String user_id (FK) "用户ID(外键)"
+String kb_id (FK) "知识库ID(外键)"
+String title "会话标题"
+DateTime created_at "创建时间"
+DateTime updated_at "更新时间"
+User user "所属用户"
+Knowledgebase knowledgebase "关联的知识库"
+ChatMessage[] chat_messages "包含的聊天消息"
}
class ChatMessage {
<<聊天消息模型>>
+String id (PK) "消息主键"
+String session_id (FK) "会话ID(外键)"
+String role "角色(user或assistant)"
+Text content "消息内容"
+JSON sources "引用来源(JSON格式)"
+DateTime created_at "创建时间"
+ChatSession chat_session "所属会话"
}
class Settings {
<<设置模型>>
+String id (PK) "设置主键(默认为'global')"
+String embedding_provider "Embedding提供商"
+String embedding_model_name "Embedding模型名称"
+String embedding_api_key "Embedding API Key"
+String embedding_base_url "Embedding Base URL"
+String llm_provider "LLM提供商"
+String llm_model_name "LLM模型名称"
+String llm_api_key "LLM API Key"
+String llm_base_url "LLM Base URL"
+String llm_temperature "LLM温度参数"
+Text chat_system_prompt "普通聊天系统提示词"
+Text rag_system_prompt "知识库聊天系统提示词"
+Text rag_query_prompt "知识库聊天查询提示词"
+String retrieval_mode "检索模式(vector/keyword/hybrid)"
+String vector_threshold "向量检索阈值"
+String keyword_threshold "全文检索阈值"
+String vector_weight "向量检索权重(混合检索时使用)"
+String top_n "TopN结果数量"
+DateTime created_at "创建时间"
+DateTime updated_at "更新时间"
}
User "1" -- "*" Knowledgebase : "拥有"
User "1" -- "*" ChatSession : "创建"
Knowledgebase "1" -- "*" Document : "包含"
Knowledgebase "1" -- "*" ChatSession : "用于"
ChatSession "1" -- "*" ChatMessage : "包含"
%% Foreign key relationships
Knowledgebase --> User : "user_id"
Document --> Knowledgebase : "kb_id"
ChatSession --> User : "user_id"
ChatSession --> Knowledgebase : "kb_id"
ChatMessage --> ChatSession : "session_id"
graph TB
%% 定义样式
classDef userClass fill:#e1f5fe,stroke:#01579b,stroke-width:2px
classDef kbClass fill:#f3e5f5,stroke:#4a148c,stroke-width:2px
classDef docClass fill:#e8f5e8,stroke:#1b5e20,stroke-width:2px
classDef sessionClass fill:#fff3e0,stroke:#e65100,stroke-width:2px
classDef messageClass fill:#fce4ec,stroke:#880e4f,stroke-width:2px
classDef settingsClass fill:#e0f2f1,stroke:#004d40,stroke-width:2px
%% 节点定义
subgraph "核心实体关系"
U[用户<br/>User<br/>id, username, email, ...]
class U userClass
KB[知识库<br/>Knowledgebase<br/>id, name, description, chunk_size, ...]
class KB kbClass
D[文档<br/>Document<br/>id, name, file_path, status, chunk_count, ...]
class D docClass
CS[聊天会话<br/>ChatSession<br/>id, title, kb_id, ...]
class CS sessionClass
CM[聊天消息<br/>ChatMessage<br/>id, role, content, sources, ...]
class CM messageClass
S[系统设置<br/>Settings<br/>id, embedding_provider, llm_provider, ...]
class S settingsClass
end
%% 关系连接
U -- "1:n<br/>拥有" --> KB
U -- "1:n<br/>创建" --> CS
KB -- "1:n<br/>包含" --> D
KB -- "1:n<br/>用于" --> CS
CS -- "1:n<br/>包含" --> CM
%% 外键关系
KB -. "user_id (FK)" .-> U
D -. "kb_id (FK)" .-> KB
CS -. "user_id (FK)" .-> U
CS -. "kb_id (FK)" .-> KB
CM -. "session_id (FK)" .-> CS
%% 独立实体
S
%% 图例说明
subgraph "图例说明"
L1[主实体]:::userClass
L2[关系线:业务关系]
L3[虚线:外键引用]
L4[箭头方向:1对多关系]
end
4.功能实现 #
4.1. init.py #
app/models/init.py
"""
数据模型模块
"""
# 导入数据库模型基类和自定义基类
from app.models.base import Base, BaseModel
# 导入知识库相关模型
from app.models.knowledgebase import Knowledgebase
# 导入文档模型
from app.models.document import Document
# 导入系统配置模型
from app.models.settings import Settings
# 导入用户模型
from app.models.user import User
# 导入聊天会话模型
from app.models.chat_session import ChatSession
# 导入聊天消息模型
from app.models.chat_message import ChatMessage
# 定义当前模块对外可用的成员列表
__all__ = [
'Base',
'BaseModel',
'Knowledgebase',
'Document',
'Settings',
'User',
'ChatSession',
'ChatMessage'
]
4.2. base.py #
app/models/base.py
# 数据库模型基类说明文档字符串
"""
数据库模型基类
"""
# 导入SQLAlchemy的declarative_base用于创建模型基类
from sqlalchemy.ext.declarative import declarative_base
# 导入inspect用于反射获取模型信息
from sqlalchemy.inspection import inspect
# 导入类型标注用的类型
from typing import Dict, Any, Optional
# 创建统一的Base类,所有ORM模型都应继承自该Base
Base = declarative_base()
# 定义所有模型的基类
class BaseModel(Base):
# 说明此类为抽象类,不会创建表
"""
模型基类,提供通用方法
所有模型应该继承此类而不是直接继承 Base
"""
__abstract__ = True # 标记为抽象类,不会创建表
# 将模型对象转为字典的方法
def to_dict(self, exclude: Optional[list] = None, **kwargs) -> Dict[str, Any]:
# 通用的to_dict方法说明文档
"""
转换为字典(通用实现)
Args:
exclude: 要排除的字段列表(如 ['password_hash'])
**kwargs: 额外的参数,用于特殊处理(如 include_password=True)
Returns:
字典格式的数据
"""
# 如果没有传入exclude,则用空列表
exclude = exclude or []
# 初始化结果字典
result = {}
# 获取当前模型类的所有列定义
mapper = inspect(self.__class__)
for column in mapper.columns:
# 获取列名
col_name = column.name
# 排除要忽略的字段
if col_name in exclude:
continue
# 获取字段值
value = getattr(self, col_name, None)
# 如果是日期时间类型,调用isoformat转换为字符串
if hasattr(value, 'isoformat'):
result[col_name] = value.isoformat() if value else None
else:
result[col_name] = value
# 返回字典化后的结果
return result
# 统一的repr方法,便于调试打印
def __repr__(self) -> str:
# repr方法说明文档
"""
通用 __repr__ 实现
子类可以定义 __repr_fields__ 来指定要显示的字段
如果没有定义,则显示 id 字段
"""
# 如果子类定义了__repr_fields__,优先显示这些字段
if hasattr(self, '__repr_fields__'):
fields = getattr(self, '__repr_fields__')
attrs = ', '.join(f"{field}={getattr(self, field, None)}" for field in fields)
else:
# 默认显示id字段
attrs = f"id={getattr(self, 'id', None)}"
# 返回格式化后的字符串
return f"<{self.__class__.__name__}({attrs})>"
4.3. chat_message.py #
app/models/chat_message.py
# 聊天消息模型
"""
聊天消息模型
"""
# 导入 SQLAlchemy 的列类型、外键约束、JSON 类型等
from sqlalchemy import Column, String, Text, DateTime, ForeignKey, JSON
# 导入 SQLAlchemy 的时间函数
from sqlalchemy.sql import func
# 导入 uuid 用于生成唯一的主键
import uuid
# 导入自定义的基础模型 BaseModel
from app.models.base import BaseModel
# 定义 ChatMessage 类,继承自 BaseModel
class ChatMessage(BaseModel):
# 为类添加说明文档
"""聊天消息模型"""
# 指定在数据库中的表名为 chat_message
__tablename__ = 'chat_message'
# 指定 __repr__ 时显示的字段为 id, session_id, role
__repr_fields__ = ['id', 'session_id', 'role'] # 指定 __repr__ 显示的字段
# 定义主键 id,使用 32 位字符串,默认值为 uuid 的前 32 位
id = Column(String(32), primary_key=True, default=lambda: uuid.uuid4().hex[:32])
# 定义 session_id,外键关联 chat_session.id,级联删除,不可为空,加索引
session_id = Column(String(32), ForeignKey('chat_session.id', ondelete='CASCADE'), nullable=False, index=True)
# 定义角色 role,为字符串类型,不可为空,区分 'user' 或 'assistant'
role = Column(String(16), nullable=False) # 'user' 或 'assistant'
# 定义消息内容 content,为 Text 类型,不可为空
content = Column(Text, nullable=False) # 消息内容
# 定义引用来源 sources,为 JSON 类型,可以为空
sources = Column(JSON, nullable=True) # 引用来源(JSON格式)
# 定义创建时间 created_at,默认为当前时间,加索引
created_at = Column(DateTime, default=func.now(), index=True)
4.4. chat_session.py #
app/models/chat_session.py
# 聊天会话模型的说明文档字符串
"""
聊天会话模型
"""
# 导入SQLAlchemy的基础类Column、String、DateTime、ForeignKey
from sqlalchemy import Column, String, DateTime, ForeignKey
# 导入SQLAlchemy的时间函数
from sqlalchemy.sql import func
# 导入uuid库用于生成唯一id
import uuid
# 导入自定义基础模型
from app.models.base import BaseModel
# 定义ChatSession类,继承自BaseModel
class ChatSession(BaseModel):
# 设置模型说明(docstring)
"""聊天会话模型"""
# 指定数据库表名为 chat_session
__tablename__ = 'chat_session'
# 指定实例输出时,显示的字段
__repr_fields__ = ['id', 'user_id', 'title'] # 指定 __repr__ 显示的字段
# 定义主键id,类型为字符串(32),默认用uuid生成前32位字符串
id = Column(String(32), primary_key=True, default=lambda: uuid.uuid4().hex[:32])
# 用户id,外键关联到user表,删除用户后会级联删除,非空并建立索引
user_id = Column(String(32), ForeignKey('user.id', ondelete='CASCADE'), nullable=False, index=True)
# 关联知识库id,外键关联到knowledgebase表,删除知识库后设为NULL,可为空并建立索引
kb_id = Column(String(32), ForeignKey('knowledgebase.id', ondelete='SET NULL'), nullable=True, index=True)
# 会话标题,类型为字符串(255),可以为空(自动生成或用户设置)
title = Column(String(255), nullable=True) # 会话标题(自动生成或用户设置)
# 创建时间,默认为当前时间,并建立索引
created_at = Column(DateTime, default=func.now(), index=True)
# 更新时间,默认为当前时间,更新时自动修改
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
4.5. document.py #
app/models/document.py
# 文档模型
"""
文档模型
"""
# 导入 SQLAlchemy 的列类型
from sqlalchemy import Column, String, Integer, BigInteger, DateTime, Text, ForeignKey
# 导入 SQLAlchemy 的 func 用于生成时间戳
from sqlalchemy.sql import func
# 导入 uuid 用于生成唯一 id
import uuid
# 导入基础模型 BaseModel
from app.models.base import BaseModel
# 定义 Document 文档模型,继承自 BaseModel
class Document(BaseModel):
# 为模型添加说明文档
"""文档模型"""
# 指定表名为 'document'
__tablename__ = 'document'
# 指定 __repr__ 方法中展示的字段
__repr_fields__ = ['id', 'name', 'status'] # 指定 __repr__ 显示的字段
# 文档主键 id,类型为字符串,默认使用 uuid
id = Column(String(32), primary_key=True, default=lambda: uuid.uuid4().hex[:32])
# 知识库 id,外键,指向 knowledgebase.id,删除时级联,不能为空,建立索引
kb_id = Column(String(32), ForeignKey('knowledgebase.id', ondelete='CASCADE'), nullable=False, index=True)
# 文档名称,最大 255 字符,不能为空
name = Column(String(255), nullable=False)
# 文件路径,最大 512 字符,不能为空
file_path = Column(String(512), nullable=False)
# 文件类型,最大 16 字符,不能为空
file_type = Column(String(16), nullable=False)
# 文件大小,类型为 BigInteger,不能为空,默认 0
file_size = Column(BigInteger, nullable=False, default=0)
# 文档状态,最大 16 字符,不能为空,默认 'pending',建立索引
status = Column(String(16), nullable=False, default='pending', index=True)
# 文档分块数,类型为整数,不能为空,默认 0
chunk_count = Column(Integer, nullable=False, default=0)
# 错误信息,允许为空
error_message = Column(Text, nullable=True)
# 创建时间,默认当前时间,建立索引
created_at = Column(DateTime, default=func.now(), index=True)
# 更新时间,默认当前时间,每次更新自动修改,建立索引
updated_at = Column(DateTime, default=func.now(), onupdate=func.now(), index=True)
4.6. knowledgebase.py #
app/models/knowledgebase.py
# 知识库模型文件头部文档字符串
"""
知识库模型
"""
# 导入SQLAlchemy的字段类型和相关功能
from sqlalchemy import Column, String, Text, Integer, DateTime, ForeignKey
# 导入SQLAlchemy的SQL函数
from sqlalchemy.sql import func
# 导入uuid模块用于生成唯一ID
import uuid
# 导入基础模型类
from app.models.base import BaseModel
# 定义知识库模型,继承自基础模型
class Knowledgebase(BaseModel):
# 知识库模型的说明文档
"""知识库模型"""
# 指定模型对应的数据库表名
__tablename__ = 'knowledgebase'
# 指定 __repr__ 方法显示的字段
__repr_fields__ = ['id', 'name'] # 指定 __repr__ 显示的字段
# 主键id字段,类型为32位字符串,默认值为uuid的前32位
id = Column(String(32), primary_key=True, default=lambda: uuid.uuid4().hex[:32])
# 用户id字段,外键关联到user表的id,删除用户时级联删除,不可为空,并且建有索引
user_id = Column(String(32), ForeignKey('user.id', ondelete='CASCADE'), nullable=False, index=True)
# 知识库名称字段,不可为空并建有索引,最大长度128
name = Column(String(128), nullable=False, index=True)
# 描述字段,类型为Text,可以为空
description = Column(Text, nullable=True)
# 封面图片路径字段,类型为字符串,最大长度512,可以为空
cover_image = Column(String(512), nullable=True, comment='封面图片路径')
# 分块大小字段,类型为整数,不能为空,默认值为512
chunk_size = Column(Integer, nullable=False, default=512)
# 分块重叠大小字段,类型为整数,不能为空,默认值为50
chunk_overlap = Column(Integer, nullable=False, default=50)
# 创建时间字段,类型为DateTime,默认为当前时间,并建立索引
created_at = Column(DateTime, default=func.now(), index=True)
# 更新时间字段,类型为DateTime,默认为当前时间,更新时自动变为当前时间
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
4.7. settings.py #
app/models/settings.py
# 设置模型的文档字符串
"""
设置模型
"""
# 引入 SQLAlchemy 的 Column、String、Text、DateTime 类型
from sqlalchemy import Column, String, Text, DateTime
# 引入 SQLAlchemy 的 func,用于时间戳
from sqlalchemy.sql import func
# 引入自定义的 BaseModel 作为父类
from app.models.base import BaseModel
# 定义设置模型类,继承自 BaseModel
class Settings(BaseModel):
# 设置模型说明(单例模式,只存储一条记录)
"""设置模型(单例模式,只存储一条记录)"""
# 指定数据表名为 settings
__tablename__ = 'settings'
# 指定 __repr__ 方法显示的字段为 id
__repr_fields__ = ['id'] # 指定 __repr__ 显示的字段
# 主键,类型为 String(32),默认值为 'global'
id = Column(String(32), primary_key=True, default='global')
# Embedding 配置
embedding_provider = Column(String(64), nullable=False, default='huggingface') # huggingface, openai, ollama
embedding_model_name = Column(String(255), nullable=True) # 模型名称或路径
embedding_api_key = Column(String(255), nullable=True) # API Key(OpenAI 需要)
embedding_base_url = Column(String(512), nullable=True) # Base URL(Ollama 需要)
# LLM 配置
# LLM 提供商,类型为 String(64),不能为空,默认值为 'deepseek'
llm_provider = Column(String(64), nullable=False, default='deepseek') # deepseek, openai, ollama
# LLM 模型名称,类型为 String(255),可为空
llm_model_name = Column(String(255), nullable=True) # 模型名称
# LLM API Key,类型为 String(255),可为空
llm_api_key = Column(String(255), nullable=True) # API Key
# LLM Base URL,类型为 String(512),可为空
llm_base_url = Column(String(512), nullable=True) # Base URL
# LLM 温度参数,类型为 String(16),可为空,默认值为 '0.7'
llm_temperature = Column(String(16), nullable=True, default='0.7') # LLM 温度参数
# 提示词配置
# 普通聊天系统提示词,类型为 Text,可为空
chat_system_prompt = Column(Text, nullable=True) # 普通聊天系统提示词
# 知识库聊天系统提示词,类型为 Text,可为空(会话开始时设置)
rag_system_prompt = Column(Text, nullable=True) # 知识库聊天系统提示词(会话开始时设置)
# 知识库聊天查询提示词,类型为 Text,可为空(每次提问时使用,可包含{context}和{question})
rag_query_prompt = Column(Text, nullable=True) # 知识库聊天查询提示词(每次提问时使用,可包含{context}和{question})
# 检索配置
# 检索模式,类型为 String(32),不能为空,默认值为 'vector'(可选值:vector, keyword, hybrid)
retrieval_mode = Column(String(32), nullable=False, default='vector') # vector, keyword, hybrid
# 向量检索阈值,类型为 String(16),可为空,默认值为 '0.2'
vector_threshold = Column(String(16), nullable=True, default='0.2') # 向量检索阈值
# 全文检索阈值,类型为 String(16),可为空,默认值为 '0.5'
keyword_threshold = Column(String(16), nullable=True, default='0.5') # 全文检索阈值
# 向量检索权重,类型为 String(16),可为空,默认值为 '0.7'(混合检索时使用)
vector_weight = Column(String(16), nullable=True, default='0.7') # 向量检索权重(混合检索时使用)
# TopN 结果数量,类型为 String(16),可为空,默认值为 '5'
top_n = Column(String(16), nullable=True, default='5') # TopN 结果数量
# 创建时间,类型为 DateTime,默认值为当前时间
created_at = Column(DateTime, default=func.now())
# 更新时间,类型为 DateTime,默认值为当前时间,更新时自动更改
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
4.8. user.py #
app/models/user.py
# 用户模型说明文档字符串
"""
用户模型
"""
# 从 SQLAlchemy 导入 Column、String、DateTime、Boolean 类型
from sqlalchemy import Column, String, DateTime, Boolean
# 从 SQLAlchemy 导入 func 用于处理时间戳
from sqlalchemy.sql import func
# 导入 uuid 模块用于生成唯一标识符
import uuid
# 从项目模型基类导入 BaseModel
from app.models.base import BaseModel
# 声明 User 用户模型,继承自 BaseModel
class User(BaseModel):
# 用户模型的类说明
"""用户模型"""
# 指定数据表名称为 'user'
__tablename__ = 'user'
# 指定 __repr__ 时显示的字段为 id 和 username
__repr_fields__ = ['id', 'username'] # 指定 __repr__ 显示的字段
# 用户主键 id,使用 uuid 生成,字符串32位
id = Column(String(32), primary_key=True, default=lambda: uuid.uuid4().hex[:32])
# 用户名字段,必填,唯一且建立索引,最大长度64
username = Column(String(64), nullable=False, unique=True, index=True)
# 邮箱字段,可以为空,唯一且建立索引,最大长度128
email = Column(String(128), nullable=True, unique=True, index=True)
# 密码哈希字段,必填,最大长度255
password_hash = Column(String(255), nullable=False) # 存储密码哈希
# 用户是否激活,默认为激活(True),不可为空
is_active = Column(Boolean, nullable=False, default=True)
# 创建时间字段,默认当前时间,建立索引
created_at = Column(DateTime, default=func.now(), index=True)
# 更新时间字段,默认当前时间,更新时自动刷新
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
# 将当前用户对象转换为字典
def to_dict(self, include_password=False, **kwargs):
# 转换为字典,如果未包含密码,排除 password_hash 字段
"""转换为字典"""
exclude = ['password_hash'] if not include_password else []
# 调用父类的 to_dict 方法,传入要排除的字段
return super().to_dict(exclude=exclude, **kwargs)
4.9. db.py #
app/utils/db.py
# 数据库连接管理(使用上下文管理器封装)的文档字符串
"""
数据库连接管理(使用上下文管理器封装)
"""
# 导入 logging,用于日志记录
import logging
# 导入 contextmanager,用于自定义上下文管理器
from contextlib import contextmanager
# 导入 create_engine,用于创建数据库引擎
from sqlalchemy import create_engine
# 导入 sessionmaker,用于创建会话工厂
from sqlalchemy.orm import sessionmaker
# 导入 SQLAlchemyError,用于捕获 SQLAlchemy 抛出的异常
from sqlalchemy.exc import SQLAlchemyError
# 导入 QueuePool,用于连接池管理
from sqlalchemy.pool import QueuePool
# 导入配置信息
from app.config import Config
# 导入 Base 用于初始化数据库表
from app.models import Base
# 获取当前模块的 logger 对象,用于日志输出
logger = logging.getLogger(__name__)
# 构建数据库连接 URL 的函数
def get_database_url() -> str:
# 返回一个格式化的数据库连接字符串
return (
f"mysql+pymysql://{Config.DB_USER}:{Config.DB_PASSWORD}@"
f"{Config.DB_HOST}:{Config.DB_PORT}/{Config.DB_NAME}?charset={Config.DB_CHARSET}"
)
# 创建数据库引擎,指定连接池、大小、回收等参数
engine = create_engine(
get_database_url(), # 获取数据库连接 URL
poolclass=QueuePool, # 使用 QueuePool 作为连接池
pool_size=10, # 连接池的最大连接数为 10
max_overflow=20, # 允许最大溢出连接数为 20
pool_pre_ping=True, # 连接每次获取前先检查可用性
pool_recycle=3600, # 3600 秒后回收连接
echo=False # 不输出 SQL 日志
)
# 创建 SQLAlchemy 会话工厂,用于生成 session 对象
SessionLocal = sessionmaker(bind=engine, autocommit=False, autoflush=False)
# 定义数据库只读会话的上下文管理器
@contextmanager
def db_session():
"""
数据库会话上下文管理器(适用于只读操作,不自动提交)
使用示例:
>>> from app.utils.db import db_session
>>> with db_session() as db:
... users = db.query(User).all()
... # 只读操作,不会自动提交
返回:
SQLAlchemy Session 对象
"""
# 创建数据库 session 实例
session = SessionLocal()
try:
# 将 session 交给调用方使用
yield session
except Exception as e:
# 出现异常时记录错误日志
logger.error(f"数据库会话错误: {e}")
# 重新抛出异常
raise
finally:
# 会话使用完毕后关闭连接
session.close()
# 定义数据库事务型会话的上下文管理器
@contextmanager
def db_transaction():
"""
数据库事务上下文管理器(显式事务,自动提交,出错自动回滚)
使用示例:
>>> from app.utils.db import db_transaction
>>> with db_transaction() as db:
... user = User(name="test")
... db.add(user)
... # 自动提交,异常时自动回滚
返回:
SQLAlchemy Session 对象
"""
# 创建数据库 session 实例
session = SessionLocal()
try:
# 将 session 交给调用方使用
yield session
# 事务正常结束时自动提交
session.commit()
except SQLAlchemyError as e:
# 捕获 SQLAlchemy 相关异常自动回滚
session.rollback()
# 记录数据库事务相关错误
logger.error(f"数据库事务错误: {e}")
# 重新抛出异常
raise
except Exception as e:
# 捕获其他异常同样回滚
session.rollback()
# 记录数据库操作异常
logger.error(f"数据库操作异常: {e}")
# 重新抛出异常
raise
finally:
# 会话使用完毕后关闭连接
session.close()
# 初始化数据库表结构的函数
def init_db():
# 尝试创建所有基于 Base 的表结构
try:
Base.metadata.create_all(engine)
# 记录数据库表结构初始化完成
logger.info("数据库表结构初始化完成")
except Exception as e:
# 初始化出错时记录错误信息
logger.error(f"数据库初始化失败: {e}")
# 重新抛出异常
raise
4.10. .env #
.env
# 应用配置
APP_HOST=0.0.0.0
APP_PORT=5000
APP_DEBUG=True
MAX_FILE_SIZE=104857600
SECRET_KEY=dev-secret-key-change-in-production
# 日志配置
LOG_DIR=./logs
LOG_FILE=rag_lite.log
LOG_LEVEL=INFO
LOG_ENABLE_FILE=True
LOG_ENABLE_CONSOLE=True
# 数据库配置
+DB_HOST=localhost
+DB_PORT=3306
+DB_USER=root
+DB_PASSWORD=root
+DB_NAME=rag-lite
+DB_CHARSET=utf8mb44.11. .gitignore #
.gitignore
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv
docs
.DS_Store
storage
chroma_db
logs
+uv.lock
+pyproject.toml
+__pycache__4.12. init.py #
app/init.py
# RAG Lite 应用模块说明
"""
RAG Lite Application
"""
# 导入操作系统相关模块
import os
# 从 Flask 包导入 Flask 应用对象
from flask import Flask
# 导入 Flask 跨域资源共享支持
from flask_cors import CORS
# 导入应用配置类
from app.config import Config
# 导入日志工具,用于获取日志记录器
from app.utils.logger import get_logger
# 导入数据库初始化函数
+from app.utils.db import init_db
# 定义创建 Flask 应用的工厂函数
def create_app(config_class=Config):
# 获取日志记录器,名称为当前模块名
logger = get_logger(__name__)
# 尝试初始化数据库
+ try:
# 输出日志,表示即将初始化数据库
+ logger.info("初始化数据库...")
# 执行数据库初始化函数
+ init_db()
# 输出日志,表示数据库初始化成功
+ logger.info("数据库初始化成功")
# 捕获任意异常
+ except Exception as e:
# 输出警告日志,提示数据库初始化失败,并输出异常信息
+ logger.warning(f"数据库初始化失败: {e}")
# 输出警告日志,提示检查数据库是否已存在,并建议手动创建数据表
+ logger.warning("请确认数据库已存在,或手动创建数据表")
# 创建 Flask 应用对象,并指定模板和静态文件目录
base_dir = os.path.abspath(os.path.dirname(__file__))
# 创建 Flask 应用对象,并指定模板和静态文件目录
app = Flask(
__name__,
# 指定模板文件目录
template_folder=os.path.join(base_dir, 'templates'),
# 指定静态文件目录
static_folder=os.path.join(base_dir, 'static')
)
# 从给定配置类加载配置信息到应用
app.config.from_object(config_class)
# 启用跨域请求支持
CORS(app)
# 记录应用创建日志信息
logger.info("Flask 应用已创建")
# 定义首页路由
@app.route('/')
def index():
return "Hello, World!"
# 返回已配置的 Flask 应用对象
return app
4.13. config.py #
app/config.py
"""
配置管理模块
"""
# 导入操作系统相关模块
import os
# 导入 Path,处理路径
from pathlib import Path
# 导入 dotenv,用于加载 .env 文件中的环境变量
from dotenv import load_dotenv
# 加载 .env 文件中的环境变量到系统环境变量
load_dotenv()
# 定义应用配置类
class Config:
"""应用配置类"""
# 基础配置
# 项目根目录路径(取上级目录)
BASE_DIR = Path(__file__).parent.parent
# 加载环境变量 SECRET_KEY,若未设置则使用默认开发密钥
SECRET_KEY = os.environ.get('SECRET_KEY') or 'dev-secret-key-change-in-production'
# 应用配置
# 读取应用监听的主机地址,默认为本地所有地址
APP_HOST = os.environ.get('APP_HOST', '0.0.0.0')
# 读取应用监听的端口,默认为 5000,类型为 int
APP_PORT = int(os.environ.get('APP_PORT', 5000))
# 读取 debug 模式配置,字符串转小写等于 'true' 则为 True(开启调试)
APP_DEBUG = os.environ.get('APP_DEBUG', 'false').lower() == 'true'
# 读取允许上传的最大文件大小,默认为 100MB,类型为 int
MAX_FILE_SIZE = int(os.environ.get('MAX_FILE_SIZE', 104857600)) # 100MB
# 允许上传的文件扩展名集合
ALLOWED_EXTENSIONS = {'pdf', 'docx', 'txt', 'md'}
# 日志配置
# 日志目录,默认 './logs'
LOG_DIR = os.environ.get('LOG_DIR', './logs')
# 日志文件名,默认 'rag_lite.log'
LOG_FILE = os.environ.get('LOG_FILE', 'rag_lite.log')
# 日志等级,默认 'INFO'
LOG_LEVEL = os.environ.get('LOG_LEVEL', 'INFO')
# 是否启用控制台日志,默认 True
LOG_ENABLE_CONSOLE = os.environ.get('LOG_ENABLE_CONSOLE', 'true').lower() == 'true'
# 是否启用文件日志,默认 True
LOG_ENABLE_FILE = os.environ.get('LOG_ENABLE_FILE', 'true').lower() == 'true'
# 数据库配置
# 数据库主机地址,默认为 'localhost'
+ DB_HOST = os.environ.get('DB_HOST', 'localhost')
# 数据库端口号,默认为 3306
+ DB_PORT = int(os.environ.get('DB_PORT', 3306))
# 数据库用户名,默认为 'root'
+ DB_USER = os.environ.get('DB_USER', 'root')
# 数据库密码,默认为 'root'
+ DB_PASSWORD = os.environ.get('DB_PASSWORD', 'root')
# 数据库名称,默认为 'rag-lite'
+ DB_NAME = os.environ.get('DB_NAME', 'rag-lite')
# 数据库字符集,默认为 'utf8mb4'
+ DB_CHARSET = os.environ.get('DB_CHARSET', 'utf8mb4')