1. 环境准备 #
1.1 启动Neo4j #
- 下载地址:https://neo4j.com/download/
- 或使用 Docker:
docker run -d -p 7474:7474 -p 7687:7687 -e NEO4J_AUTH=neo4j/12345678 neo4j:latest
1.2 创建项目目录 #
mkdir graphrag
cd graphrag
uv init
uv add dotenv langchain langchain-deepseek pandas py2neo streamlit requests openai 2. 数据导入 #
2.1. 0_数据导入.py #
pages/0_数据导入.py
# 数据导入页面模块注释
"""数据导入页面"""
# 导入streamlit库并简写为st
import streamlit as st
# 定义主函数
def main():
# 数据导入页面主函数的文档字符串
"""数据导入页面主函数"""
# 配置Streamlit页面,设置标题和布局
st.set_page_config(
page_title="数据导入 - 图书知识图谱",
layout="wide"
)
# 显示页面主标题
st.title("数据导入管理")
# 插入分割线
st.markdown("---")
# 判断是否为主程序入口
if __name__ == "__main__":
# 调用主函数
main() 3. 读取csv文件 #
3.1. constants.py #
constants.py
# 定义一个列表,包含CSV文件导入时必需的列名
REQUIRED_CSV_COLUMNS = [
# 书名
"name",
# 作者
"author",
# 出版社
"publisher",
# 类别
"category",
# 出版年份
"publish_year",
# 简介
"summary",
# 关键词(用分号分隔)
"keywords"
]3.2. import_service.py #
services/import_service.py
"""数据导入服务"""
# 导入io模块用于处理内存中的流数据
import io
# 导入pandas库用于数据分析处理
import pandas as pd
# 导入类型提示
from typing import List, Dict, Any, Optional, Callable
# 从constants模块导入必需的CSV列名常量
from constants import (REQUIRED_CSV_COLUMNS)
# 定义数据导入服务类(单例模式)
class ImportService:
"""数据导入服务(单例模式)"""
# 类变量:存储单例实例
_instance = None
# 重写__new__方法实现单例模式
def __new__(cls):
"""创建单例实例"""
# 如果实例不存在,则创建新实例
if cls._instance is None:
cls._instance = super().__new__(cls)
# 返回单例实例
return cls._instance
# 解析CSV内容为DataFrame的方法
def parse_csv(self, csv_content: str) -> pd.DataFrame:
"""解析CSV内容"""
# 使用pandas读取内存中的CSV内容并返回DataFrame
return pd.read_csv(io.StringIO(csv_content))
# 验证DataFrame格式的方法
def validate_csv(self, df: pd.DataFrame) -> tuple[bool, Optional[str]]:
"""验证CSV格式"""
# 如果DataFrame为空,返回False和错误信息
if df.empty:
return False, "CSV文件为空"
# 检查是否有缺失的必需字段
missing = [col for col in REQUIRED_CSV_COLUMNS if col not in df.columns]
# 如果没有缺失则返回True,否则返回缺失列信息
return (True, None) if not missing else (False, f"缺少列: {', '.join(missing)}")
# 获取导入服务实例的便捷方法
def get_import_service() -> ImportService:
"""获取导入服务单例实例"""
# 返回单例实例(无论调用多少次都返回同一个实例)
return ImportService()3.3. 0_数据导入.py #
pages/0_数据导入.py
# 数据导入页面模块注释
"""数据导入页面"""
# 导入streamlit库并简写为st
import streamlit as st
+from services.import_service import get_import_service
# 定义主函数
def main():
# 数据导入页面主函数的文档字符串
"""数据导入页面主函数"""
# 配置Streamlit页面,设置标题和布局
st.set_page_config(
page_title="数据导入 - 图书知识图谱",
layout="wide"
)
# 显示页面主标题
st.title("数据导入管理")
# 插入分割线
st.markdown("---")
# 获取服务实例
+ import_service = get_import_service()
# 创建一个标签页,只包含“CSV数据导入”标签,并将其赋值给tab1变量
+ tab1, = st.tabs(["CSV数据导入"])
# 在tab1标签页下进行后续操作
+ with tab1:
# 显示“CSV数据导入”二级标题
+ st.header("CSV数据导入")
# 展示有关上传CSV文件的说明及字段要求
+ st.markdown("""
+ 请上传包含图书信息的CSV文件。CSV文件应包含以下列:
+ - `name`: 书名
+ - `author`: 作者
+ - `publisher`: 出版社
+ - `category`: 类别
+ - `publish_year`: 出版年份
+ - `summary`: 简介
+ - `keywords`: 关键词(用分号分隔)
+ """)
# 显示文件上传控件,限定仅可上传csv类型文件,并显示帮助提示
+ uploaded_file = st.file_uploader(
+ "选择CSV文件",
+ type=["csv"],
+ help="上传包含图书信息的CSV文件"
+ )
# 如果用户已经上传了文件
+ if uploaded_file is not None:
+ try:
# 读取上传的文件内容,并解码为utf-8格式的字符串
+ csv_content = uploaded_file.read().decode("utf-8")
# 调用服务将CSV内容解析为DataFrame
+ df = import_service.parse_csv(csv_content)
# 验证上传的DataFrame是否符合CSV格式要求
+ is_valid, error_msg = import_service.validate_csv(df)
# 如果格式校验未通过,显示错误
+ if not is_valid:
+ st.error(f"CSV格式错误: {error_msg}")
# 如果校验通过,显示成功消息,并告知记录数量
+ else:
+ st.success(f"CSV文件验证通过!共 {len(df)} 条记录")
# 捕获任何异常,显示读取文件失败的错误信息
+ except Exception as e:
+ st.error(f"读取文件失败: {str(e)}")
# 判断是否为主程序入口
if __name__ == "__main__":
# 调用主函数
main() 4. 数据导入 #
4.1. .env #
.env
NEO4J_URI="bolt://localhost:7687"
NEO4J_USER="neo4j"
NEO4J_PASSWORD="12345678"
4.2. config.py #
config.py
"""配置管理模块"""
# 导入os模块,用于获取环境变量
import os
# 导入dotenv模块,用于加载.env文件中的环境变量
import dotenv
# 从dataclasses模块导入dataclass,用于简化数据类定义
from dataclasses import dataclass
# 加载.env文件中的所有环境变量
dotenv.load_dotenv()
# 定义Neo4j数据库配置的数据类
@dataclass
class Neo4jConfig:
"""Neo4j数据库配置"""
# 数据库URI
uri: str
# 用户名
user: str
# 密码
password: str
# 定义应用整体配置的数据类
@dataclass
class AppConfig:
"""应用配置"""
# Neo4j数据库配置属性
neo4j: Neo4jConfig
config=AppConfig(
neo4j=Neo4jConfig(
uri=os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
user=os.environ.get("NEO4J_USER", "neo4j"),
password=os.environ.get("NEO4J_PASSWORD", "12345678"),
)
)4.3. init.py #
database/init.py
"""数据库模块"""
4.4. connection.py #
database/connection.py
# 这是Neo4j数据库连接管理的模块说明
"""Neo4j数据库连接管理"""
# 从py2neo库导入Graph类,用于连接和操作Neo4j数据库
from py2neo import Graph
# 从config模块导入config对象,用于获取配置信息
from config import config
# 使用配置信息创建Neo4j数据库连接实例
graph = Graph(config.neo4j.uri, auth=(config.neo4j.user, config.neo4j.password))
4.5. constants.py #
constants.py
# 定义一个列表,包含CSV文件导入时必需的列名
REQUIRED_CSV_COLUMNS = [
# 书名
"name",
# 作者
"author",
# 出版社
"publisher",
# 类别
"category",
# 出版年份
"publish_year",
# 简介
"summary",
# 关键词(用分号分隔)
"keywords"
]
# 定义图书节点标签
+NODE_LABEL_BOOK = "Book"
# 定义作者节点标签
+NODE_LABEL_AUTHOR = "Author"
# 定义出版社节点标签
+NODE_LABEL_PUBLISHER = "Publisher"
# 定义类别节点标签
+NODE_LABEL_CATEGORY = "Category"
# 定义关键词节点标签
+NODE_LABEL_KEYWORD = "Keyword"
# 定义“书籍-作者”关系类型
+RELATIONSHIP_TYPE_WRITTEN_BY = "written_by"
# 定义“书籍-出版社”关系类型
+RELATIONSHIP_TYPE_PUBLISHED_BY = "published_by"
# 定义“书籍-类别”关系类型
+RELATIONSHIP_TYPE_HAS_CATEGORY = "has_category"
# 定义“书籍-关键词”关系类型
+RELATIONSHIP_TYPE_HAS_KEYWORD = "has_keyword"4.6. 0_数据导入.py #
pages/0_数据导入.py
# 数据导入页面模块注释
"""数据导入页面"""
# 导入streamlit库并简写为st
import streamlit as st
+from services.import_service import import_service
# 定义主函数
def main():
# 数据导入页面主函数的文档字符串
"""数据导入页面主函数"""
# 配置Streamlit页面,设置标题和布局
st.set_page_config(
page_title="数据导入 - 图书知识图谱",
layout="wide"
)
# 显示页面主标题
st.title("数据导入管理")
# 插入分割线
st.markdown("---")
# 创建一个标签页,只包含“CSV数据导入”标签,并将其赋值给tab1变量
tab1, = st.tabs(["CSV数据导入"])
# 在tab1标签页下进行后续操作
with tab1:
# 显示“CSV数据导入”二级标题
st.header("CSV数据导入")
# 展示有关上传CSV文件的说明及字段要求
st.markdown("""
请上传包含图书信息的CSV文件。CSV文件应包含以下列:
- `name`: 书名
- `author`: 作者
- `publisher`: 出版社
- `category`: 类别
- `publish_year`: 出版年份
- `summary`: 简介
- `keywords`: 关键词(用分号分隔)
""")
# 显示文件上传控件,限定仅可上传csv类型文件,并显示帮助提示
uploaded_file = st.file_uploader(
"选择CSV文件",
type=["csv"],
help="上传包含图书信息的CSV文件"
)
# 如果用户已经上传了文件
if uploaded_file is not None:
try:
# 读取上传的文件内容,并解码为utf-8格式的字符串
csv_content = uploaded_file.read().decode("utf-8")
# 调用服务将CSV内容解析为DataFrame
df = import_service.parse_csv(csv_content)
# 验证上传的DataFrame是否符合CSV格式要求
is_valid, error_msg = import_service.validate_csv(df)
# 如果格式校验未通过,显示错误
if not is_valid:
st.error(f"CSV格式错误: {error_msg}")
# 如果校验通过,显示成功消息,并告知记录数量
else:
st.success(f"CSV文件验证通过!共 {len(df)} 条记录")
# 显示预览
+ with st.expander("数据预览", expanded=True):
+ st.dataframe(df)
# 导入选项
+ clear_existing = st.checkbox(
+ "清空现有数据",
+ value=False,
+ help="导入前清空数据库中的所有数据"
+ )
# 当用户点击"开始导入"按钮时执行
+ if st.button("开始导入", type="primary", width='stretch'):
# 创建一个显示日志信息的空容器
+ log_container = st.empty()
# 用于存放日志消息的列表
+ log_messages = []
# 定义进度回调函数
+ def progress_callback(message, current, total):
# 添加当前消息到日志列表(包含进度信息)
+ log_messages.append(f"- [{current}/{total}] {message}")
# 刷新日志容器内容
+ log_container.text("\n".join(log_messages))
+ try:
# 调用导入服务执行数据导入
+ stats = import_service.import_books(
+ df,#数据源
+ clear_existing=clear_existing,# 是否清空现有数据
+ progress_callback=progress_callback# 进度回调函数
+ )
# 显示最终统计信息
+ if stats:
+ st.info(
+ f"完成: 图书 {stats.get('books', 0)} | "
+ f"作者 {stats.get('authors', 0)} | "
+ f"出版社 {stats.get('publishers', 0)} | "
+ f"类别 {stats.get('categories', 0)} | "
+ f"关键词 {stats.get('keywords', 0)} | "
+ f"关系 {stats.get('relationships', 0)}"
+ )
# 如果有错误信息,显示出来
+ if stats.get("errors") and len(stats["errors"]) > 0:
+ with st.expander("⚠️ 导入过程中的错误", expanded=False):
+ for error in stats["errors"]:
+ st.error(error)
+ except Exception as e:
# 导入过程中如有异常则显示错误信息
+ st.error(f"导入失败: {str(e)}")
# 导入traceback模块以显示详细错误
+ import traceback
# 展开详细错误信息
+ with st.expander("详细错误信息"):
# 格式化并显示完整的异常调用栈
+ st.code(traceback.format_exc())
# 捕获任何异常,显示读取文件失败的错误信息
except Exception as e:
st.error(f"读取文件失败: {str(e)}")
# 判断是否为主程序入口
if __name__ == "__main__":
# 调用主函数
main() 4.8. import_service.py #
services/import_service.py
# 数据导入服务说明
"""数据导入服务"""
# 导入io模块,用于内存数据流操作
import io
# 导入pandas进行数据处理
import pandas as pd
# 从typing模块导入类型提示
from typing import List, Dict, Any, Optional, Callable
# 从database.connection导入数据库实例graph
+from database.connection import graph
# 从constants模块导入各类常量
+from constants import (
+ REQUIRED_CSV_COLUMNS, # 必需的CSV列名
+ NODE_LABEL_BOOK, # 图书节点标签
+ NODE_LABEL_AUTHOR, # 作者节点标签
+ NODE_LABEL_PUBLISHER, # 出版社节点标签
+ NODE_LABEL_CATEGORY, # 类别节点标签
+ NODE_LABEL_KEYWORD, # 关键词节点标签
+ RELATIONSHIP_TYPE_WRITTEN_BY, # 书-作者关系类型
+ RELATIONSHIP_TYPE_PUBLISHED_BY, # 书-出版社关系类型
+ RELATIONSHIP_TYPE_HAS_CATEGORY, # 书-类别关系类型
+ RELATIONSHIP_TYPE_HAS_KEYWORD # 书-关键词关系类型
+)
# 导入py2neo中的Node、Relationship、NodeMatcher
+from py2neo import Node, Relationship, NodeMatcher
# 定义数据导入服务类,采用单例模式
class ImportService:
# 类说明:数据导入服务(单例)
"""数据导入服务(单例模式)"""
# 初始化方法
+ def __init__(self):
# 创建NodeMatcher对象用于节点查找
+ self.node_matcher = NodeMatcher(graph)
# 解析CSV内容的函数
def parse_csv(self, csv_content: str) -> pd.DataFrame:
# 方法说明:解析CSV内容为DataFrame
"""解析CSV内容"""
# 利用pandas读取内存中的CSV字符串,返回DataFrame
return pd.read_csv(io.StringIO(csv_content))
# 验证CSV格式合法性的方法
def validate_csv(self, df: pd.DataFrame) -> tuple[bool, Optional[str]]:
# 方法说明:验证CSV格式
"""验证CSV格式"""
# 判断DataFrame是否为空,为空则返回错误消息
if df.empty:
return False, "CSV文件为空"
# 检查必需CSV列是否全部存在
missing = [col for col in REQUIRED_CSV_COLUMNS if col not in df.columns]
# 若无漏列返回True,否则返回缺失列名
return (True, None) if not missing else (False, f"缺少列: {', '.join(missing)}")
# 清空全部数据库的方法
+ def clear_database(self) -> None:
# 方法说明:清空数据库,删除所有节点和关系
+ """清空数据库"""
# 执行Cypher语句,删除所有节点及其关系
+ graph.run("MATCH (n) DETACH DELETE n")
# 内部方法:创建节点
+ def _create_nodes(
+ self, label, node_list, stats, stat_key,
+ progress_callback, current, total
+ ) -> int:
# 方法说明:创建指定类型的节点并更新进度
+ """创建节点
+ Args:
+ label: 节点标签
+ node_list: 节点列表,每项为字典
+ stats: 统计计数与错误信息字典
+ stat_key: 统计键名
+ progress_callback: 进度回调函数
+ current: 当前进度值
+ total: 总进度步数
+ """
# 有进度回调时,通知正在创建节点
+ if progress_callback:
+ progress_callback(f"正在创建{label}节点...", current, total)
# 初始化进度步数
+ step = current
# 遍历所有节点数据逐一创建
+ for node in node_list:
+ try:
# 取节点名称用于反馈与出错描述,默认“未知”
+ node_name = node.get("name", "未知")
# 构建节点对象
+ node = Node(label, **node)
# 写入数据库
+ graph.create(node)
# 节点计数器+1
+ stats[stat_key] += 1
# 进度步数+1
+ step += 1
# 回调进度反馈
+ if progress_callback:
+ progress_callback(f"创建{label}: {node_name}", step, total)
+ except Exception as e:
# 异常情况将出错信息记录到stats["errors"]
+ node_name = node.get("name", "未知")
+ stats["errors"].append(f"创建{label} {node_name} 失败: {str(e)}")
+ step += 1
# 返回当前进度步数
+ return step
# 内部方法:创建节点间关系
+ def _create_relationships(
+ self, start_label, end_label, edges, rel_type,
+ progress_callback, current, total, stats=None
+ ) -> int:
+ """
+ 创建关系
+ Args:
+ start_label: 起始节点标签
+ end_label: 终止节点标签
+ edges: 关系对列表,每项为[起始节点名称, 终止节点名称]
+ rel_type: 关系类型
+ progress_callback: 进度回调函数
+ current: 当前进度值
+ total: 总进度步数
+ stats: 统计计数与错误信息字典
+ """
# 初始化关系计数
+ count = 0
# 初始化进度值
+ step = current
# 遍历所有关系对
+ for edge in edges:
+ try:
# 查找起始节点对象
+ start_node = self.node_matcher.match(start_label, name=str(edge[0])).first()
# 查找终止节点对象
+ end_node = self.node_matcher.match(end_label, name=str(edge[1])).first()
# 两端节点均存在后创建关系
+ if start_node and end_node:
+ rel = Relationship(start_node, rel_type, end_node)#创建关系对象
+ graph.create(rel)#写入数据库
+ count += 1#关系计数器+1
+ step += 1#进度步数+1
# 调用进度回调反馈
+ if progress_callback:# 有进度回调时,通知正在创建关系
+ progress_callback(f"创建关系: {edge[0]}-{rel_type}->{edge[1]}", step, total)
+ else:
# 如有节点不存在,将详细信息加到stats["errors"]
+ if stats:
+ missing = []#缺失节点列表
+ if not start_node:
+ missing.append(f"{start_label}:{edge[0]}")#添加起始节点
+ if not end_node:
+ missing.append(f"{end_label}:{edge[1]}")#添加终止节点
+ stats["errors"].append(f"关系创建失败: 找不到节点 {', '.join(missing)}")#添加错误信息
+ step += 1#进度步数+1
+ except Exception as e:
# 其他异常也写入stats["errors"]
+ if stats:# 有统计信息时,添加错误信息
+ stats["errors"].append(f"创建关系 {edge[0]}-{rel_type}->{edge[1]} 时出错: {str(e)}")
+ step += 1
# 返回成功创建的关系数
+ return count
# 图书数据导入的主方法
+ def import_books(
+ self, df: pd.DataFrame, clear_existing: bool = False,
+ progress_callback: Optional[Callable] = None
+ ) -> Dict[str, Any]:
# 方法说明:将DataFrame批量导入为图数据库节点与关系
+ """导入图书数据"""
# 初始化统计信息字典
+ stats = {
+ "books": 0, "authors": 0, "publishers": 0,
+ "categories": 0, "keywords": 0, "relationships": 0, "errors": []
+ }
+ try:
# 是否需要清空库
+ if clear_existing:
# 有进度回调时先通知
+ if progress_callback:
+ progress_callback("正在清空数据库...", 0, 0)
# 调用清库
+ self.clear_database()
# 生成图书节点数据列表
+ books = []
+ for _, row in df.iterrows():
+ if pd.notna(row["name"]):
# 创建基本字典(书名)
+ book = {"name": row["name"]}
# 补充出版年份
+ if pd.notna(row.get("publish_year")):
+ book["publish_year"] = int(row["publish_year"])
# 补充简介
+ if pd.notna(row.get("summary")):
+ book["summary"] = str(row["summary"])
# 处理关键字列表
+ if pd.notna(row.get("keywords")):
+ book["keywords"] = [kw.strip() for kw in str(row["keywords"]).split(";") if kw.strip()]
+ books.append(book)
# authors节点只创建唯一且非空的作者名
+ authors = [{"name": name} for name in set(df["author"].dropna())]
# publishers节点只创建唯一且非空的出版社名
+ publishers = [{"name": name} for name in set(df["publisher"].dropna())]
# categories节点只创建唯一且非空的类别名
+ categories = [{"name": name} for name in set(df["category"].dropna())]
# 提取所有keywords列的分号分割项,并去重空
+ keywords_set = set()
# 遍历所有关键词,并去重空格
+ for keyword in df["keywords"].dropna():
# 将关键词按分号分割,并去重空格
+ keywords_set.update([kw.strip() for kw in keyword.split(";") if kw.strip()])
# 将关键词列表转换为字典列表
+ keywords = [{"name": keyword.strip()} for keyword in keywords_set if keyword.strip()]
# 构建关系对列表
+ rels_written_by = [] # 书-作者
+ rels_published_by = [] # 书-出版社
+ rels_has_category = [] # 书-类别
+ rels_has_keyword = [] # 书-关键词
# 遍历每行收集上述所有关系
+ for _, row in df.iterrows():
# 书-作者
+ if pd.notna(row["name"]) and pd.notna(row["author"]):
+ rels_written_by.append([row["name"], row["author"]])
# 书-出版社
+ if pd.notna(row["name"]) and pd.notna(row["publisher"]):
+ rels_published_by.append([row["name"], row["publisher"]])
# 书-类别
+ if pd.notna(row["name"]) and pd.notna(row["category"]):
+ rels_has_category.append([row["name"], row["category"]])
# 书-关键词(支持多关键词)
+ if pd.notna(row["name"]) and pd.notna(row["keywords"]):
+ for kw in row["keywords"].split(";"):
+ rels_has_keyword.append([row["name"], kw.strip()])
# 统计全部节点和关系数,便于进度反馈
+ total_steps = (
+ len(books) + len(authors) + len(publishers) +
+ len(categories) + len(keywords) +
+ len(rels_written_by) + len(rels_published_by) +
+ len(rels_has_category) + len(rels_has_keyword)
+ )
# 初始化进度指针
+ current = 0
# 创建所有图书节点
+ current = self._create_nodes(
+ NODE_LABEL_BOOK, books, stats, "books",
+ progress_callback, current, total_steps
+ )
# 创建作者节点
+ current = self._create_nodes(
+ NODE_LABEL_AUTHOR, authors, stats, "authors",
+ progress_callback, current, total_steps
+ )
# 创建出版社节点
+ current = self._create_nodes(
+ NODE_LABEL_PUBLISHER, publishers, stats, "publishers",
+ progress_callback, current, total_steps
+ )
# 创建类别节点
+ current = self._create_nodes(
+ NODE_LABEL_CATEGORY, categories, stats, "categories",
+ progress_callback, current, total_steps
+ )
# 创建关键词节点
+ current = self._create_nodes(
+ NODE_LABEL_KEYWORD, keywords, stats, "keywords",
+ progress_callback, current, total_steps
+ )
# 所有节点创建后,反馈将创建关系
+ if progress_callback:
+ progress_callback("正在创建关系...", current, total_steps)
# 创建书-作者关系并累计关系数
+ stats["relationships"] += self._create_relationships(
+ NODE_LABEL_BOOK, NODE_LABEL_AUTHOR, rels_written_by,
+ RELATIONSHIP_TYPE_WRITTEN_BY, progress_callback, current, total_steps, stats
+ )
# 更新进度
+ current += len(rels_written_by)
# 创建书-出版社关系
+ stats["relationships"] += self._create_relationships(
+ NODE_LABEL_BOOK, NODE_LABEL_PUBLISHER, rels_published_by,
+ RELATIONSHIP_TYPE_PUBLISHED_BY, progress_callback, current, total_steps, stats
+ )
+ current += len(rels_published_by)
# 创建书-类别关系
+ stats["relationships"] += self._create_relationships(
+ NODE_LABEL_BOOK, NODE_LABEL_CATEGORY, rels_has_category,
+ RELATIONSHIP_TYPE_HAS_CATEGORY, progress_callback, current, total_steps, stats
+ )
+ current += len(rels_has_category)
# 创建书-关键词关系
+ stats["relationships"] += self._create_relationships(
+ NODE_LABEL_BOOK, NODE_LABEL_KEYWORD, rels_has_keyword,
+ RELATIONSHIP_TYPE_HAS_KEYWORD, progress_callback, current, total_steps, stats
+ )
# 如果整个导入过程中出现异常,记录到stats["errors"]
+ except Exception as e:
+ stats["errors"].append(f"导入过程出错: {str(e)}")
# 返回统计字典
+ return stats
# 创建ImportService的单例实例
+import_service = ImportService()5. 向量嵌入 #
5.1. embedding_service.py #
services/embedding_service.py
"""向量初始化服务"""
# 导入类型提示所需的类型List、Optional、Callable
from typing import List, Optional, Callable
# 导入HTTP请求库requests
import requests
# 导入py2neo包中的NodeMatcher用于节点查询
from py2neo import NodeMatcher
# 导入常量:节点标签(图书、作者)
from constants import NODE_LABEL_BOOK, NODE_LABEL_AUTHOR
# 导入Neo4j数据库连接
from database.connection import graph
# 导入配置对象
from config import config
# 定义向量初始化服务类
class EmbeddingInitService:
"""向量初始化服务"""
# 构造函数,初始化必要的成员变量
def __init__(self):
# 创建NodeMatcher对象,用于查找节点
self.node_matcher = NodeMatcher(graph)
# 嵌入API的URL,从配置获取
self.api_url = config.embedding.api_url
# API密钥,从配置获取
self.api_key = config.embedding.api_key
# 嵌入模型名,从配置获取
self.model = config.embedding.model
# 为节点批量生成和更新嵌入向量
def update_embeddings(
self,
progress_callback: Optional[Callable] = None
) -> dict:
"""为节点生成嵌入向量"""
# 初始化统计信息字典
stats = {"total_nodes": 0, "processed": 0, "failed": 0, "errors": []}
# 要处理的节点标签列表(图书和作者)
node_labels = [NODE_LABEL_BOOK, NODE_LABEL_AUTHOR]
# 统计所有目标节点的总数
total_nodes = 0
for label in node_labels:
# 查找指定标签的所有节点
nodes = list(self.node_matcher.match(label))
# 累加节点数量
total_nodes += len(nodes)
# 将总节点数写入统计信息
stats["total_nodes"] = total_nodes
# 当前已处理节点数
current = 0
# 遍历每个标签
for node_label in node_labels:
try:
# 获取该标签下所有节点
nodes = list(self.node_matcher.match(node_label))
# 遍历每个节点
for node in nodes:
# 获取节点name属性
name = node.get("name")
try:
# 获取文本的向量
embedding = self.get_embedding(name)
# 将嵌入向量写入节点属性
node["embedding"] = embedding
# 更新节点到数据库
graph.push(node)
# 成功处理节点数量加一
stats["processed"] += 1
# 当前计数加一
current += 1
# 调用进度回调函数(已成功处理该节点)
progress_callback(f" 已为 {name} 设置嵌入向量", current, total_nodes)
except Exception as e:
# 处理该节点失败时,记录错误信息
stats["errors"].append(f"处理 {name} 时出错: {str(e)}")
# 失败计数加一
stats["failed"] += 1
# 当前计数加一
current += 1
# 调用进度回调函数(处理失败)
progress_callback(f"处理 {name} 失败", current, total_nodes)
except Exception as e:
# 标签下所有节点处理发生异常时,记录错误
stats["errors"].append(f"处理{node_label}节点时出错: {str(e)}")
# 调用进度回调函数(标签整体失败)
progress_callback(f"处理{node_label}节点时出错: {str(e)}", current, total_nodes)
# 返回统计信息
return stats
# 获取单条文本的嵌入向量
def get_embedding(self, text: str) -> List[float]:
"""获取文本的嵌入向量"""
# 构造HTTP请求头
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
# 构造请求体
payload = {"model": self.model, "input": text}
# 输出请求体到控制台(调试用)
print(payload)
try:
# 发送post请求到嵌入API
response = requests.post(
self.api_url,
json=payload,
headers=headers
)
# 若请求成功,且返回状态码200
if response.status_code == 200:
# 获取响应体中的数据内容
data = response.json()
# 从响应体获取嵌入向量
embedding = data["data"][0].get("embedding")
# 若未获得向量内容,则抛出异常
if not embedding:
raise Exception("API返回数据格式错误")
# 返回嵌入向量
return embedding
else:
# 若HTTP响应非200,抛出异常包含错误状态码
error_msg = f"API错误: HTTP {response.status_code}"
raise Exception(error_msg)
# 网络异常处理
except requests.exceptions.RequestException as e:
# 抛出网络相关异常
raise Exception(f"网络错误: {e}") from e
# 创建EmbeddingInitService实例并赋值给embedding_service变量,供外部调用
embedding_service = EmbeddingInitService()5.2. .env #
.env
NEO4J_URI="bolt://localhost:7687"
NEO4J_USER="neo4j"
NEO4J_PASSWORD="12345678"
+VOLC_EMBEDDINGS_API_URL=https://ark.cn-beijing.volces.com/api/v3/embeddings
+VOLC_EMBEDDING_MODEL=doubao-embedding-text-240715
+VOLC_API_KEY=d52e49a1-36ea-44bb-bc6e-65ce789a72f65.3. config.py #
config.py
"""配置管理模块"""
# 导入os模块,用于获取环境变量
import os
# 导入dotenv模块,用于加载.env文件中的环境变量
import dotenv
# 从dataclasses模块导入dataclass,用于简化数据类定义
from dataclasses import dataclass
# 加载.env文件中的所有环境变量
dotenv.load_dotenv()
# 定义Neo4j数据库配置的数据类
@dataclass
class Neo4jConfig:
"""Neo4j数据库配置"""
# 数据库URI
uri: str
# 用户名
user: str
# 密码
password: str
+@dataclass
+class EmbeddingConfig:
+ """火山方舟嵌入API配置"""
+ api_url: str
+ api_key: str
+ model: str
# 定义应用整体配置的数据类
@dataclass
class AppConfig:
"""应用配置"""
# Neo4j数据库配置属性
neo4j: Neo4jConfig
# 火山方舟嵌入API配置属性
+ embedding: EmbeddingConfig
# 创建一个AppConfig对象,保存应用的整体配置
+config = AppConfig(
# Neo4j数据库配置,使用Neo4jConfig类初始化
+ neo4j=Neo4jConfig(
# 从环境变量获取数据库URI,如果没有则使用默认值"bolt://localhost:7687"
+ uri=os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
# 从环境变量获取数据库用户名,如果没有则使用默认值"neo4j"
+ user=os.environ.get("NEO4J_USER", "neo4j"),
# 从环境变量获取数据库密码,如果没有则使用默认值"12345678"
+ password=os.environ.get("NEO4J_PASSWORD", "12345678"),
+ ),
# 嵌入API配置,使用EmbeddingConfig类初始化
+ embedding=EmbeddingConfig(
# 从环境变量获取嵌入API的URL,没有则使用默认值
+ api_url=os.environ.get(
+ "VOLC_EMBEDDINGS_API_URL",
+ "https://ark.cn-beijing.volces.com/api/v3/embeddings"
+ ),
# 从环境变量获取API密钥,如果没有则为空字符串
+ api_key=os.environ.get("VOLC_API_KEY", ""),
# 从环境变量获取嵌入模型名,没有则使用默认值"doubao-embedding-text-240715"
+ model=os.environ.get("VOLC_EMBEDDING_MODEL", "doubao-embedding-text-240715"),
)
+)5.4. connection.py #
database/connection.py
# 这是Neo4j数据库连接管理的模块说明
"""Neo4j数据库连接管理"""
# 从py2neo库导入Graph类,用于连接和操作Neo4j数据库
from py2neo import Graph
# 从config模块导入get_config函数,用于获取配置信息
+from config import config
# 使用配置信息创建Neo4j数据库连接实例
graph = Graph(config.neo4j.uri, auth=(config.neo4j.user, config.neo4j.password))
5.5. 0_数据导入.py #
pages/0_数据导入.py
# 数据导入页面模块注释
"""数据导入页面"""
# 导入streamlit库并简写为st
import streamlit as st
from services.import_service import import_service
+from services.embedding_service import embedding_service
# 定义主函数
def main():
# 数据导入页面主函数的文档字符串
"""数据导入页面主函数"""
# 配置Streamlit页面,设置标题和布局
st.set_page_config(
page_title="数据导入 - 图书知识图谱",
layout="wide"
)
# 显示页面主标题
st.title("数据导入管理")
# 插入分割线
st.markdown("---")
# 创建一个标签页,只包含“CSV数据导入”标签,并将其赋值给tab1变量
+ tab1, tab2 = st.tabs(["CSV数据导入","向量初始化"])
# 在tab1标签页下进行后续操作
with tab1:
# 显示“CSV数据导入”二级标题
st.header("CSV数据导入")
# 展示有关上传CSV文件的说明及字段要求
st.markdown("""
请上传包含图书信息的CSV文件。CSV文件应包含以下列:
- `name`: 书名
- `author`: 作者
- `publisher`: 出版社
- `category`: 类别
- `publish_year`: 出版年份
- `summary`: 简介
- `keywords`: 关键词(用分号分隔)
""")
# 显示文件上传控件,限定仅可上传csv类型文件,并显示帮助提示
uploaded_file = st.file_uploader(
"选择CSV文件",
type=["csv"],
help="上传包含图书信息的CSV文件"
)
# 如果用户已经上传了文件
if uploaded_file is not None:
try:
# 读取上传的文件内容,并解码为utf-8格式的字符串
csv_content = uploaded_file.read().decode("utf-8")
# 调用服务将CSV内容解析为DataFrame
df = import_service.parse_csv(csv_content)
# 验证上传的DataFrame是否符合CSV格式要求
is_valid, error_msg = import_service.validate_csv(df)
# 如果格式校验未通过,显示错误
if not is_valid:
st.error(f"CSV格式错误: {error_msg}")
# 如果校验通过,显示成功消息,并告知记录数量
else:
st.success(f"CSV文件验证通过!共 {len(df)} 条记录")
# 显示预览
with st.expander("数据预览", expanded=True):
st.dataframe(df)
# 导入选项
clear_existing = st.checkbox(
"清空现有数据",
value=False,
help="导入前清空数据库中的所有数据"
)
# 当用户点击"开始导入"按钮时执行
if st.button("开始导入", type="primary", width='stretch'):
# 创建一个显示日志信息的空容器
log_container = st.empty()
# 用于存放日志消息的列表
log_messages = []
# 定义进度回调函数
def progress_callback(message, current, total):
# 添加当前消息到日志列表(包含进度信息)
log_messages.append(f"- [{current}/{total}] {message}")
# 刷新日志容器内容
log_container.text("\n".join(log_messages))
try:
# 调用导入服务执行数据导入
stats = import_service.import_books(
df,#数据源
clear_existing=clear_existing,# 是否清空现有数据
progress_callback=progress_callback# 进度回调函数
)
# 显示最终统计信息
if stats:
st.info(
f"完成: 图书 {stats.get('books', 0)} | "
f"作者 {stats.get('authors', 0)} | "
f"出版社 {stats.get('publishers', 0)} | "
f"类别 {stats.get('categories', 0)} | "
f"关键词 {stats.get('keywords', 0)} | "
f"关系 {stats.get('relationships', 0)}"
)
# 如果有错误信息,显示出来
if stats.get("errors") and len(stats["errors"]) > 0:
with st.expander("⚠️ 导入过程中的错误", expanded=False):
for error in stats["errors"]:
st.error(error)
except Exception as e:
# 导入过程中如有异常则显示错误信息
st.error(f"导入失败: {str(e)}")
# 导入traceback模块以显示详细错误
import traceback
# 展开详细错误信息
with st.expander("详细错误信息"):
# 格式化并显示完整的异常调用栈
st.code(traceback.format_exc())
# 捕获任何异常,显示读取文件失败的错误信息
except Exception as e:
st.error(f"读取文件失败: {str(e)}")
# 使用tab2选项卡
+ with tab2:
# 设置页面标题为“向量初始化”
+ st.header("向量初始化")
# 显示多行文字介绍向量初始化的功能
+ st.markdown("""
+ 为数据库中的节点生成嵌入向量。此操作会为所有Book和Author节点生成向量,
+ 用于后续的向量检索功能。
+ """)
# 如果用户点击“开始生成向量”按钮
+ if st.button("开始生成向量", type="primary", width='stretch'):
# 创建一个空的日志容器用于显示进度
+ log_container = st.empty()
# 用于存放日志消息的列表
+ log_messages = []
# 定义进度回调函数
+ def progress_callback(message,current, total):
# 向日志列表追加当前进度消息
+ log_messages.append(f"- [{current}/{total}] {message}")
# 在日志容器中刷新显示所有日志消息
+ log_container.text("\n".join(log_messages))
+ try:
# 调用embedding_service的update_embeddings方法,开始进行向量初始化
+ stats = embedding_service.update_embeddings(progress_callback=progress_callback)
# 在页面上显示向量生成成功的提示
+ st.success("向量生成完成!")
# 显示统计信息标题
+ st.markdown("### 生成统计")
# 将页面分为三列用于展示不同统计项
+ col1, col2, col3 = st.columns(3)
# 第一列显示总节点数
+ with col1:
+ st.metric("总节点数", stats["total_nodes"])
# 第二列显示已处理成功的数量
+ with col2:
+ st.metric("成功", stats["processed"])
# 第三列显示处理失败的数量
+ with col3:
+ st.metric("失败", stats["failed"])
# 如果存在错误信息,则在可展开区域中展示所有错误
+ if stats["errors"]:
+ with st.expander("错误信息", expanded=False):
+ for error in stats["errors"]:
+ st.error(error)
# 捕获异常并显示错误信息
+ except Exception as e:
# 在页面上显示向量生成失败的错误信息
+ st.error(f"向量生成失败: {str(e)}")
# 导入traceback模块用于获取详细出错信息
+ import traceback
# 可展开区域显示完整Traceback
+ with st.expander("详细错误信息"):
+ st.code(traceback.format_exc())
# 判断是否为主程序入口
if __name__ == "__main__":
# 调用主函数
main() 6. 提问 #
6.1. 1_问答系统.py #
pages/1_问答系统.py
"""Streamlit图书知识图谱问答系统主应用"""
import streamlit as st
from config import config
def main():
"""主函数"""
st.set_page_config(
page_title="问答系统 - 图书知识图谱",
layout="wide"
)
if "messages" not in st.session_state:
st.session_state.messages = []
with st.sidebar:
st.markdown("### 参数设置")
st.markdown(
"<h2 style='text-align: center; color: blue;'>图书知识图谱查询系统</h1>",
unsafe_allow_html=True,
)
for message in st.session_state.messages:
st.chat_message(message["role"]).write(message["content"])
if query := st.chat_input("输入图书相关问题", key="query_input"):
st.session_state.messages.append({"role": "user", "content": query})
st.chat_message("user").write(query)
if __name__ == "__main__":
main()
7. 向量化查询 #
7.1. queries.py #
database/queries.py
"""图数据库查询"""
# 导入类型提示
from typing import List, Dict, Any, Optional
# 导入py2neo的Node类型
from py2neo import Node
# 导入Neo4j连接对象
from database.connection import graph
# 导入常量:向量索引名
from constants import VECTOR_INDEX_BOOK, VECTOR_INDEX_AUTHOR
# 定义图书查询类
class BookQuery:
"""图书查询"""
# 基于嵌入向量检索图书的方法
def query_by_embedding(
self, query_embedding: List[float], top_k: int = 3
) -> List[Dict[str, Any]]:
"""基于向量检索图书"""
# 定义Cypher向量检索查询语句
query = f"""
CALL db.index.vector.queryNodes('{VECTOR_INDEX_BOOK}', $top_k, $query_embedding)
YIELD node, score
MATCH (node:Book)
OPTIONAL MATCH (node)-[:written_by]->(author:Author)
OPTIONAL MATCH (node)-[:published_by]->(publisher:Publisher)
OPTIONAL MATCH (node)-[:has_category]->(category:Category)
OPTIONAL MATCH (node)-[:has_keyword]->(keyword:Keyword)
WITH node, score, author, publisher, category, COLLECT(DISTINCT keyword.name) AS keyword_names
RETURN node, score, author, publisher, category, keyword_names
ORDER BY score DESC
"""
try:
# 执行Cypher查询,传入top_k和query_embedding参数
results = graph.run(query, top_k=top_k, query_embedding=query_embedding)
# 初始化图书结果列表
books = []
# 遍历查询结果
for record in results:
# 获取图书节点
node: Node = record["node"]
# 获取相似度分数
score = float(record["score"])
# 获取作者节点
author_node: Optional[Node] = record.get("author")
# 获取出版社节点
publisher_node: Optional[Node] = record.get("publisher")
# 获取类别节点
category_node: Optional[Node] = record.get("category")
# 获取关键词名称列表
keyword_names = record.get("keyword_names", [])
# 如果关系中存在关键词则使用,否则优先节点属性keywords
keywords = keyword_names if keyword_names else (node.get("keywords", []) or [])
# 构建图书结果字典
books.append({
"name": node.get("name", ""),
"similarity": score,
"作者": author_node.get("name") if author_node else None,
"出版社": publisher_node.get("name") if publisher_node else None,
"类别": category_node.get("name") if category_node else None,
"出版年份": node.get("publish_year"),
"简介": node.get("summary"),
"关键词": keywords,
})
# 返回图书结果列表
return books
except Exception as e:
# 查询出错时抛出异常并提示
raise Exception(f"图书查询失败: {e}") from e
# 定义作者查询类
class AuthorQuery:
"""作者查询"""
# 基于嵌入向量检索作者的方法
def query_by_embedding(
self, query_embedding: List[float], top_k: int = 3
) -> List[Dict[str, Any]]:
"""基于向量检索作者"""
# 定义Cypher向量检索查询语句
query = f"""
CALL db.index.vector.queryNodes('{VECTOR_INDEX_AUTHOR}', $top_k, $query_embedding)
YIELD node, score
MATCH (node:Author)
OPTIONAL MATCH (node)<-[:written_by]-(book:Book)
RETURN node, score, COLLECT(DISTINCT book.name) AS book_names
ORDER BY score DESC
"""
try:
# 执行Cypher查询,传入top_k和query_embedding参数
results = graph.run(query, top_k=top_k, query_embedding=query_embedding)
# 初始化作者结果列表
authors = []
# 遍历查询结果
for record in results:
# 获取作者节点
node: Node = record["node"]
# 获取相似度分数
score = float(record["score"])
# 获取相关图书名称列表
book_names = record.get("book_names", [])
# 构建作者结果字典
authors.append({
"name": node.get("name", ""),
"similarity": score,
"相关图书": book_names,
})
# 返回作者结果列表
return authors
except Exception as e:
# 查询出错时抛出异常并提示
raise Exception(f"作者查询失败: {e}") from e
7.2. retrieval_service.py #
services/retrieval_service.py
"""检索服务"""
# 导入类型提示
from typing import List, Dict, Any
# 导入图书和作者的查询类
from database.queries import BookQuery, AuthorQuery
# 导入嵌入服务(虽然此文件中未直接用到)
from services.embedding_service import embedding_service
# 导入用于区分查询类型和默认top_k值的常量
from constants import QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR, DEFAULT_TOP_K
class RetrievalService:
"""检索服务"""
# 初始化方法,分别实例化图书和作者查询对象
def __init__(self):
self.book_query = BookQuery()
self.author_query = AuthorQuery()
# 基于嵌入向量查询图书或作者的主方法
def query_by_embedding(
self,
query_embedding: List[float], # 问题的嵌入向量
query_type: str, # 查询类型(图书或作者)
top_k: int = DEFAULT_TOP_K, # 返回结果数量,默认值为常量
) -> List[Dict[str, Any]]:
"""使用嵌入向量查询"""
# 若查询类型为“图书”,则使用book_query执行查询
if query_type == QUERY_TYPE_BOOK:
return self.book_query.query_by_embedding(query_embedding, top_k)
# 若查询类型为“作者”,则使用author_query执行查询
elif query_type == QUERY_TYPE_AUTHOR:
return self.author_query.query_by_embedding(query_embedding, top_k)
# 其他类型(不支持的类型)抛出异常
else:
raise ValueError(f"不支持的查询类型: {query_type}")
# 创建RetrievalService的单例,便于外部调用
retrieval_service = RetrievalService()7.3. constants.py #
constants.py
# 定义一个列表,包含CSV文件导入时必需的列名
REQUIRED_CSV_COLUMNS = [
# 书名
"name",
# 作者
"author",
# 出版社
"publisher",
# 类别
"category",
# 出版年份
"publish_year",
# 简介
"summary",
# 关键词(用分号分隔)
"keywords"
]
# 定义图书节点标签
NODE_LABEL_BOOK = "Book"
# 定义作者节点标签
NODE_LABEL_AUTHOR = "Author"
# 定义出版社节点标签
NODE_LABEL_PUBLISHER = "Publisher"
# 定义类别节点标签
NODE_LABEL_CATEGORY = "Category"
# 定义关键词节点标签
NODE_LABEL_KEYWORD = "Keyword"
# 定义“书籍-作者”关系类型
RELATIONSHIP_TYPE_WRITTEN_BY = "written_by"
# 定义“书籍-出版社”关系类型
RELATIONSHIP_TYPE_PUBLISHED_BY = "published_by"
# 定义“书籍-类别”关系类型
RELATIONSHIP_TYPE_HAS_CATEGORY = "has_category"
# 定义“书籍-关键词”关系类型
RELATIONSHIP_TYPE_HAS_KEYWORD = "has_keyword"
# 查询类型
# 图书查询类型
+QUERY_TYPE_BOOK = "图书"
# 作者查询类型
+QUERY_TYPE_AUTHOR = "作者"
# 默认Top K值
+DEFAULT_TOP_K = 3
# 向量索引名称
# 图书向量索引名称
+VECTOR_INDEX_BOOK = "book_embeddings"
# 作者向量索引名称
+VECTOR_INDEX_AUTHOR = "author_embeddings"7.4. 1_问答系统.py #
pages/1_问答系统.py
# Streamlit图书知识图谱问答系统主应用
"""Streamlit图书知识图谱问答系统主应用"""
# 导入Streamlit库
import streamlit as st
# 导入配置文件
from config import config
# 导入嵌入向量服务
+from services.embedding_service import embedding_service
# 导入检索服务
+from services.retrieval_service import retrieval_service
# 导入查询类型常量
+from constants import QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR
# 定义主函数
def main():
"""主函数"""
# 设置页面配置(标题、布局)
st.set_page_config(
page_title="问答系统 - 图书知识图谱",
layout="wide"
)
# 如果session_state中还没有'messages',则初始化为空列表
if "messages" not in st.session_state:
st.session_state.messages = []
# 侧边栏参数设置
with st.sidebar:
# 显示参数设置的标题
st.markdown("### 参数设置")
# 单选框选择查询类型(默认选中第一个,即图书)
+ query_type = st.radio("选择查询类型", [QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR], index=0)
# 滑块选择返回结果数量(Top K),范围1-10,默认3
+ top_k = st.slider("返回结果数量 (Top K)", 1, 10, 3, 1)
# 在主界面居中显示系统标题
st.markdown(
"<h2 style='text-align: center; color: blue;'>图书知识图谱查询系统</h1>",
unsafe_allow_html=True,
)
# 按顺序显示历史消息
for message in st.session_state.messages:
st.chat_message(message["role"]).write(message["content"])
# 获取用户输入的问题(如果有输入)
if query := st.chat_input("输入图书相关问题", key="query_input"):
# 将用户输入添加到消息历史记录
st.session_state.messages.append({"role": "user", "content": query})
# 显示用户输入的消息
st.chat_message("user").write(query)
# 显示查询中loading动画
+ with st.spinner("正在查询中..."):
+ try:
# 获取问题的嵌入向量
+ query_embedding = embedding_service.get_embedding(query)
# 使用嵌入向量进行检索,获取结果
+ results = retrieval_service.query_by_embedding(query_embedding, query_type, top_k)
# 显示“查询结果”标题
+ st.markdown("#### 查询结果")
# 如果有返回结果,逐条显示
+ if results:
+ for idx, item in enumerate(results, 1):
# 用可展开板块显示每一条结果,默认展开第一个
+ with st.expander(f"结果 {idx}", expanded=True if idx == 1 else False):
# 按字典项逐个展示
+ for key, value in item.items():
+ st.write(f"**{key}**: {value}")
+ else:
# 若无结果,显示提示
+ st.info("未找到相关结果。")
+ except Exception as err:
# 捕获异常时的处理
+ error_msg = f"查询过程中出错: {str(err)}"
+ st.session_state.messages.append({
+ "role": "assistant",
+ "content": error_msg
+ })
+ st.chat_message("assistant").write(error_msg)
+ st.error(error_msg)
# 判断是否作为主程序运行
if __name__ == "__main__":
# 调用主函数启动应用
main()8. 提问 #
8.1. init.py #
llm/init.py
# 从当前包导入BaseLLM基类
from .base import BaseLLM
# 从当前包导入DeepSeekLLM类
from .deepseek import DeepSeekLLM
# 从当前包导入VolcengineLLM类
from .volcengine import VolcengineLLM
# 定义__all__变量,指定包导出的公有接口
__all__ = ["BaseLLM", "DeepSeekLLM", "VolcengineLLM"]8.2. base.py #
llm/base.py
# LLM抽象基类的文档字符串
"""LLM抽象基类"""
# 导入ABC和abstractmethod,用于创建抽象基类
from abc import ABC, abstractmethod
# 定义LLM的抽象基类,继承自ABC
class BaseLLM(ABC):
# 类的文档字符串,说明是LLM抽象基类
"""LLM抽象基类"""
# 抽象方法generate,子类必须实现
@abstractmethod
def generate(self, prompt: str, **kwargs) -> str:
# 生成回复的文档字符串
"""生成回复"""
# 抽象方法内使用pass,占位
pass
8.3. deepseek.py #
llm/deepseek.py
"""DeepSeek LLM实现"""
# 导入可选类型提示
from typing import Optional
# 导入 DeepSeek 的 ChatDeepSeek 类
from langchain_deepseek import ChatDeepSeek
# 导入自定义的基础LLM类
from .base import BaseLLM
# 导入全局配置对象
from config import config
# 导入默认温度和最大token常量
from constants import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS
# 定义 DeepSeekLLM 类,继承自 BaseLLM
class DeepSeekLLM(BaseLLM):
"""DeepSeek LLM"""
# 构造函数,支持自定义模型名、api_key、温度和最大输出tokens数量
def __init__(
self,
model_name: Optional[str] = None,
api_key: Optional[str] = None,
temperature: float = DEFAULT_TEMPERATURE,
max_tokens: int = DEFAULT_MAX_TOKENS,
):
# 设置模型名:优先用传入参数,否则用配置里的默认
self.model_name = model_name or config.deepseek.model
# 设置api_key:优先用传入参数,否则用配置里的默认
self.api_key = api_key or config.deepseek.api_key
# 设置温度:优先用传入参数,否则用默认
self.temperature = temperature or DEFAULT_TEMPERATURE
# 设置最大返回token数:优先用传入参数,否则用默认
self.max_tokens = max_tokens or DEFAULT_MAX_TOKENS
# 初始化 DeepSeek 的 LLM 实例
self.llm = ChatDeepSeek(
api_key=self.api_key,
model=self.model_name,
temperature=self.temperature,
max_tokens=self.max_tokens,
)
# 生成方法,实现对LLM的调用
def generate(self, prompt: str) -> str:
"""生成回复"""
try:
# 调用 DeepSeek LLM 获取生成内容
response = self.llm.invoke(prompt)
# 检查是否有返回内容
if not response or not response.content:
# 返回内容为空时抛出异常
raise Exception("API返回空响应")
# 正常时返回生成内容
return response.content
except Exception as e:
# 捕获并重新抛出异常,加上友好的报错说明
raise Exception(f"DeepSeek API调用失败: {e}") from e
8.4. volcengine.py #
llm/volcengine.py
"""火山引擎LLM实现"""
# 导入Optional类型用于参数类型标注
from typing import Optional
# 导入OpenAI库用于与火山引擎API交互
from openai import OpenAI
# 导入自定义的BaseLLM基类
from .base import BaseLLM
# 导入全局配置对象
from config import config
# 导入默认温度和最大tokens常量
from constants import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS
# 定义VolcengineLLM类,继承自BaseLLM
class VolcengineLLM(BaseLLM):
"""火山引擎LLM"""
# 构造函数,初始化各项参数
def __init__(self, model_name: Optional[str] = None, api_key: Optional[str] = None,
temperature: float = DEFAULT_TEMPERATURE, max_tokens: int = DEFAULT_MAX_TOKENS):
# 获取API基础URL
self.base_url = config.volcengine.base_url
# 获取模型名称,优先使用传入值,否则取配置中的默认值
self.model_name = model_name or config.volcengine.model
# 获取API密钥,优先使用传入值,否则取配置中的默认值
self.api_key = api_key or config.volcengine.api_key
# 设置温度参数,控制生成内容的多样性
self.temperature = temperature or DEFAULT_TEMPERATURE
# 设置最大token数量
self.max_tokens = max_tokens or DEFAULT_MAX_TOKENS
# 初始化OpenAI客户端,用于与火山引擎API交互
self.client = OpenAI(
base_url=self.base_url,
api_key=self.api_key,
)
# 定义generate方法用于生成模型回复
def generate(self, prompt: str) -> str:
"""生成回复"""
try:
# 向火山引擎API发送请求,生成回复
response = self.client.chat.completions.create(
model=self.model_name,
messages=[{"role": "user", "content": prompt}],
temperature=self.temperature,
max_tokens=self.max_tokens,
)
# 判断响应内容是否有效
if not response.choices or not response.choices[0].message.content:
# 如果返回内容为空,抛出异常
raise Exception("API返回空响应")
# 返回生成的回复内容
return response.choices[0].message.content
except Exception as e:
# 捕获异常并抛出自定义异常信息
raise Exception(f"Volcengine API调用失败: {e}") from e
8.5. init.py #
utils/init.py
"""工具函数模块"""
8.6. formatters.py #
utils/formatters.py
"""格式化工具"""
# 导入类型提示List, Dict, Any
from typing import List, Dict, Any
# 定义用于格式化查询结果上下文的函数
def format_context(query_type: str, results: List[Dict[str, Any]]) -> str:
"""格式化查询结果上下文"""
# 定义内部函数:格式化一本图书的方法
def format_book(result: Dict[str, Any]) -> str:
# 准备字段及相应的值,元组形式依次为(字段名, 字段值)
fields = [
("作者", result.get("作者")),
("出版社", result.get("出版社")),
("类别", result.get("类别")),
("出版年份", result.get("出版年份")),
("简介", result.get("简介")),
# 如果关键词存在,拼接成字符串,否则为None
("关键词", ", ".join(result["关键词"]) if result.get("关键词") else None),
]
# 对所有有值的字段拼接为多行字符串,格式为“- 字段名: 字段值”
return "\n".join([f" - {k}: {v}" for k, v in fields if v])
# 定义内部函数:格式化作者(只显示相关图书)
def format_author(result: Dict[str, Any]) -> str:
# 如果结果中有“相关图书”字段,则拼接显示
if result.get("相关图书"):
return f" - 相关图书: {', '.join(result['相关图书'])}"
# 否则返回空串
return ""
# 用于存放所有格式化后的每条结果内容
lines = []
# 遍历全部结果数据
for idx, result in enumerate(results, 1):
# 构建当前项的标题,包括序号、名称及相似度
header = f"{idx}. {result['name']} (相似度: {result['similarity']:.4f})"
# 根据查询类型决定调用哪个格式化方法
details = format_book(result) if query_type == "图书" else format_author(result)
# 若有详情,拼接标题、详情及换行;否则仅标题加换行
info = f"{header}\n{details}\n" if details else f"{header}\n"
# 将本次内容添加到lines列表
lines.append(info)
# 用两个换行符分隔拼接所有结果,返回最终字符串
return "\n\n".join(lines)
8.7. .env #
.env
NEO4J_URI="bolt://localhost:7687"
NEO4J_USER="neo4j"
NEO4J_PASSWORD="12345678"
VOLC_EMBEDDINGS_API_URL=https://ark.cn-beijing.volces.com/api/v3/embeddings
VOLC_EMBEDDING_MODEL=doubao-embedding-text-240715
VOLC_API_KEY=d52e49a1-36ea-44bb-bc6e-65ce789a72f6
+VOLCENGINE_BASE_URL=https://ark.cn-beijing.volces.com/api/v3
+VOLCENGINE_MODEL=doubao-seed-1-6-250615
+VOLCENGINE_API_KEY=d52e49a1-36ea-44bb-bc6e-65ce789a72f6
+DeepSeek_BASE_URL="https://api.deepseek.com/v1"
+DeepSeek_API_KEY="sk-278496d471bc4f4cb0ccb8c389a15018"
+DeepSeek_MODEL="deepseek-chat"8.8. config.py #
config.py
"""配置管理模块"""
# 导入os模块,用于获取环境变量
import os
# 导入dotenv模块,用于加载.env文件中的环境变量
import dotenv
# 从dataclasses模块导入dataclass,用于简化数据类定义
from dataclasses import dataclass
+from typing import Optional
# 加载.env文件中的所有环境变量
dotenv.load_dotenv()
# 定义Neo4j数据库配置的数据类
@dataclass
class Neo4jConfig:
"""Neo4j数据库配置"""
# 数据库URI
uri: str
# 用户名
user: str
# 密码
password: str
@dataclass
class EmbeddingConfig:
"""火山方舟嵌入API配置"""
api_url: str
api_key: str
model: str
+@dataclass
+class VolcengineConfig:
+ """火山引擎API配置"""
+ base_url: Optional[str]
+ api_key: Optional[str]
+ model: Optional[str]
+@dataclass
+class DeepSeekConfig:
+ """DeepSeek API配置"""
+ base_url: str
+ api_key: str
+ model: str
# 定义应用整体配置的数据类
@dataclass
class AppConfig:
"""应用配置"""
# Neo4j数据库配置属性
neo4j: Neo4jConfig
# 火山方舟嵌入API配置属性
embedding: EmbeddingConfig
# 火山引擎API配置属性
+ volcengine: VolcengineConfig
# DeepSeek API配置属性
+ deepseek: DeepSeekConfig
# 创建一个AppConfig对象,保存应用的整体配置
config = AppConfig(
# Neo4j数据库配置,使用Neo4jConfig类初始化
neo4j=Neo4jConfig(
# 从环境变量获取数据库URI,如果没有则使用默认值"bolt://localhost:7687"
uri=os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
# 从环境变量获取数据库用户名,如果没有则使用默认值"neo4j"
user=os.environ.get("NEO4J_USER", "neo4j"),
# 从环境变量获取数据库密码,如果没有则使用默认值"12345678"
password=os.environ.get("NEO4J_PASSWORD", "12345678"),
),
# 嵌入API配置,使用EmbeddingConfig类初始化
embedding=EmbeddingConfig(
# 从环境变量获取嵌入API的URL,没有则使用默认值
api_url=os.environ.get(
"VOLC_EMBEDDINGS_API_URL",
"https://ark.cn-beijing.volces.com/api/v3/embeddings"
),
# 从环境变量获取API密钥,如果没有则为空字符串
api_key=os.environ.get("VOLC_API_KEY", ""),
# 从环境变量获取嵌入模型名,没有则使用默认值"doubao-embedding-text-240715"
model=os.environ.get("VOLC_EMBEDDING_MODEL", "doubao-embedding-text-240715"),
+ ),
# 初始化火山引擎API配置,从环境变量获取相关参数
+ volcengine=VolcengineConfig(
# 从环境变量获取火山引擎API的基础URL
+ base_url=os.environ.get("VOLCENGINE_BASE_URL"),
# 从环境变量获取火山引擎API的密钥
+ api_key=os.environ.get("VOLCENGINE_API_KEY"),
# 从环境变量获取火山引擎API的模型名称
+ model=os.environ.get("VOLCENGINE_MODEL"),
+ ),
# 初始化DeepSeek API配置,支持从环境变量读取参数,并提供默认值
+ deepseek=DeepSeekConfig(
# 从环境变量获取DeepSeek基础URL,若未设置则使用默认值"https://api.deepseek.com"
+ base_url=os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com"),
# 从环境变量获取DeepSeek API密钥,默认值为空字符串
+ api_key=os.environ.get("DEEPSEEK_API_KEY", ""),
# 从环境变量获取DeepSeek模型名,若未设置则使用默认值"deepseek-chat"
+ model=os.environ.get("DEEPSEEK_MODEL", "deepseek-chat"),
+ ),
)8.9. constants.py #
constants.py
# 定义一个列表,包含CSV文件导入时必需的列名
REQUIRED_CSV_COLUMNS = [
# 书名
"name",
# 作者
"author",
# 出版社
"publisher",
# 类别
"category",
# 出版年份
"publish_year",
# 简介
"summary",
# 关键词(用分号分隔)
"keywords"
]
# 定义图书节点标签
NODE_LABEL_BOOK = "Book"
# 定义作者节点标签
NODE_LABEL_AUTHOR = "Author"
# 定义出版社节点标签
NODE_LABEL_PUBLISHER = "Publisher"
# 定义类别节点标签
NODE_LABEL_CATEGORY = "Category"
# 定义关键词节点标签
NODE_LABEL_KEYWORD = "Keyword"
# 定义“书籍-作者”关系类型
RELATIONSHIP_TYPE_WRITTEN_BY = "written_by"
# 定义“书籍-出版社”关系类型
RELATIONSHIP_TYPE_PUBLISHED_BY = "published_by"
# 定义“书籍-类别”关系类型
RELATIONSHIP_TYPE_HAS_CATEGORY = "has_category"
# 定义“书籍-关键词”关系类型
RELATIONSHIP_TYPE_HAS_KEYWORD = "has_keyword"
# 查询类型
# 图书查询类型
QUERY_TYPE_BOOK = "图书"
# 作者查询类型
QUERY_TYPE_AUTHOR = "作者"
# 默认Top K值
DEFAULT_TOP_K = 3
# 默认最大Tokens值
+DEFAULT_MAX_TOKENS = 4096
# 默认温度值
+DEFAULT_TEMPERATURE = 0.7
# 向量索引名称
# 图书向量索引名称
VECTOR_INDEX_BOOK = "book_embeddings"
# 作者向量索引名称
VECTOR_INDEX_AUTHOR = "author_embeddings"8.10. 1_问答系统.py #
pages/1_问答系统.py
# Streamlit图书知识图谱问答系统主应用
"""Streamlit图书知识图谱问答系统主应用"""
# 导入Streamlit库
import streamlit as st
# 导入PromptTemplate
+from langchain_core.prompts import PromptTemplate
# 导入配置文件
from config import config
# 导入嵌入向量服务
from services.embedding_service import embedding_service
# 导入检索服务
from services.retrieval_service import retrieval_service
# 导入查询类型常量
from constants import QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR
# 导入大模型服务
+from llm import VolcengineLLM, DeepSeekLLM
+from utils.formatters import format_context
+prompt_template = PromptTemplate(
+ input_variables=["question", "context"],
+ template="""你是一名图书知识助手,需要根据提供的图书信息回答用户的提问。
+ 请直接回答问题,如果信息不足,请回答"根据现有信息无法确定"。
+ 问题:{question}
+ 图书信息:\n{context}
+ 回答:""",
+)
# 定义主函数
def main():
"""主函数"""
# 设置页面配置(标题、布局)
st.set_page_config(
page_title="问答系统 - 图书知识图谱",
layout="wide"
)
# 如果session_state中还没有'messages',则初始化为空列表
if "messages" not in st.session_state:
st.session_state.messages = []
# 侧边栏参数设置
with st.sidebar:
# 显示参数设置的标题
st.markdown("### 参数设置")
# 单选框选择查询类型(默认选中第一个,即图书)
query_type = st.radio("选择查询类型", [QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR], index=0)
# 滑块选择返回结果数量(Top K),范围1-10,默认3
top_k = st.slider("返回结果数量 (Top K)", 1, 10, 3, 1)
# 滑块选择温度值,范围0.0-1.0,默认0.3
+ temperature = st.slider("温度 (Temperature)", 0.0, 1.0, 0.3, 0.1)
# 选择大模型服务商,默认火山引擎
+ llm_provider = st.selectbox("选择大模型服务商", ["volcengine", "deepseek"], index=0)
# 如果选择火山引擎,则显示火山引擎API Key和模型名
+ if llm_provider == "volcengine":
+ api_key = st.text_input(
+ "火山引擎 API Key",
+ value=config.volcengine.api_key or "",
+ type="password",
+ help="如留空则使用服务器默认配置",
+ )
+ model_name = st.text_input(
+ "火山引擎模型名",
+ value=config.volcengine.model or "",
+ help="如留空则使用服务器默认配置"
+ )
+ else:
+ api_key = st.text_input(
+ "DeepSeek API Key",
+ value=config.deepseek.api_key or "",
+ type="password",
+ help="如留空则使用服务器默认配置",
+ )
+ model_name = st.text_input(
+ "DeepSeek模型名",
+ value=config.deepseek.model or "",
+ help="如留空则使用服务器默认配置"
+ )
# 在主界面居中显示系统标题
st.markdown(
"<h2 style='text-align: center; color: blue;'>图书知识图谱查询系统</h1>",
unsafe_allow_html=True,
)
# 按顺序显示历史消息
for message in st.session_state.messages:
st.chat_message(message["role"]).write(message["content"])
+ if message["role"] == "assistant":
+ # 展开可查看详细检索结果
+ with st.expander("查看详细结果"):
+ # 以 json 格式显示检索的类型和结果
+ st.json({"type": message['query_type'], "results": message['results']})
# 获取用户输入的问题(如果有输入)
if query := st.chat_input("输入图书相关问题", key="query_input"):
# 将用户输入添加到消息历史记录
st.session_state.messages.append({"role": "user", "content": query})
# 显示用户输入的消息
st.chat_message("user").write(query)
# 显示查询中的 loading 动画
with st.spinner("正在查询中..."):
try:
# 获取用户输入 query 的嵌入向量
query_embedding = embedding_service.get_embedding(query)
# 用嵌入向量进行检索,查询对应的结果
results = retrieval_service.query_by_embedding(query_embedding, query_type, top_k)
# 如果没有检索到结果
+ if not results or len(results) == 0:
# 设置回复内容为没有找到相关信息
+ answer = "抱歉,没有找到相关的信息。"
else:
# 如果选择的服务商是火山引擎
+ if llm_provider == "volcengine":
# 实例化火山引擎大模型
+ llm = VolcengineLLM(
+ model_name=model_name if model_name else None,
+ api_key=api_key if api_key else None,
+ )
+ else:
# 否则实例化 DeepSeek 大模型
+ llm = DeepSeekLLM(
+ model_name=model_name if model_name else None,
+ api_key=api_key if api_key else None,
+ )
# 根据检索结果格式化上下文字符串
+ context_str = format_context(query_type, results)
# 打印上下文字符串,便于调试
+ print("context_str: ", context_str)
# 用模板格式化最终 prompt
+ final_prompt = prompt_template.format(
+ question=query,
+ context=context_str
+ )
# 打印最终 prompt,便于调试
+ print("final_prompt: ", final_prompt)
# 调用大模型生成答案
+ answer = llm.generate(final_prompt)
# 展开可查看详细检索结果
+ with st.expander("查看详细结果"):
# 以 json 格式显示检索的类型和结果
+ st.json({"type": query_type, "results": results})
# 把大模型回复添加到会话历史
+ st.session_state.messages.append({
+ "role": "assistant",
+ "content": answer,
+ "query_type": query_type,
+ "results": results # 检索结果
+ })
# 用 chat_message 组件显示大模型回复
+ st.chat_message("assistant").write(answer)
# 重新运行页面,用于强制刷新
+ st.rerun()
except Exception as err:
# 打印异常信息,便于调试
+ print(err)
# 捕获异常时组织错误提示内容
error_msg = f"查询过程中出错: {str(err)}"
# 将错误信息添加到会话历史
st.session_state.messages.append({
"role": "assistant",
"content": error_msg
+ "query_type": query_type, # 查询类型
+ "results": results # 检索结果
})
# 显示错误信息在对话界面
st.chat_message("assistant").write(error_msg)
# 在页面上弹出错误提示
st.error(error_msg)
# 判断是否作为主程序运行
if __name__ == "__main__":
# 调用主函数启动应用
main()9. 历史记录 #
9.1. 1_问答系统.py #
pages/1_问答系统.py
# Streamlit图书知识图谱问答系统主应用
"""Streamlit图书知识图谱问答系统主应用"""
# 导入Streamlit库
import streamlit as st
# 导入PromptTemplate
from langchain_core.prompts import PromptTemplate
# 导入配置文件
from config import config
# 导入嵌入向量服务
from services.embedding_service import embedding_service
# 导入检索服务
from services.retrieval_service import retrieval_service
# 导入查询类型常量
from constants import QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR
# 导入大模型服务
from llm import VolcengineLLM, DeepSeekLLM
from utils.formatters import format_context
prompt_template = PromptTemplate(
input_variables=["question", "context"],
template="""你是一名图书知识助手,需要根据提供的图书信息回答用户的提问。
请直接回答问题,如果信息不足,请回答"根据现有信息无法确定"。
问题:{question}
图书信息:\n{context}
回答:""",
)
# 定义主函数
def main():
"""主函数"""
# 设置页面配置(标题、布局)
st.set_page_config(
page_title="问答系统 - 图书知识图谱",
layout="wide"
)
# 如果session_state中还没有'messages',则初始化为空列表
if "messages" not in st.session_state:
st.session_state.messages = []
# 如果session_state中还没有'history',则初始化为空列表
+ if "history" not in st.session_state:
+ st.session_state.history = []
# 侧边栏参数设置
with st.sidebar:
# 显示参数设置的标题
st.markdown("### 参数设置")
# 单选框选择查询类型(默认选中第一个,即图书)
query_type = st.radio("选择查询类型", [QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR], index=0)
# 滑块选择返回结果数量(Top K),范围1-10,默认3
top_k = st.slider("返回结果数量 (Top K)", 1, 10, 3, 1)
# 滑块选择温度值,范围0.0-1.0,默认0.3
temperature = st.slider("温度 (Temperature)", 0.0, 1.0, 0.3, 0.1)
# 选择大模型服务商,默认火山引擎
llm_provider = st.selectbox("选择大模型服务商", ["volcengine", "deepseek"], index=0)
# 如果选择火山引擎,则显示火山引擎API Key和模型名
if llm_provider == "volcengine":
api_key = st.text_input(
"火山引擎 API Key",
value=config.volcengine.api_key or "",
type="password",
help="如留空则使用服务器默认配置",
)
model_name = st.text_input(
"火山引擎模型名",
value=config.volcengine.model or "",
help="如留空则使用服务器默认配置"
)
else:
api_key = st.text_input(
"DeepSeek API Key",
value=config.deepseek.api_key or "",
type="password",
help="如留空则使用服务器默认配置",
)
model_name = st.text_input(
"DeepSeek模型名",
value=config.deepseek.model or "",
help="如留空则使用服务器默认配置"
)
+ st.markdown("### 历史查询")
+ if st.session_state.history:
+ for i, item in enumerate(st.session_state.history):
+ with st.expander(f"查询 {i+1}: {item['question']}"):
+ st.json(item)
+ else:
+ st.info("暂无历史查询记录")
# 在主界面居中显示系统标题
st.markdown(
"<h2 style='text-align: center; color: blue;'>图书知识图谱查询系统</h1>",
unsafe_allow_html=True,
)
# 按顺序显示历史消息
for message in st.session_state.messages:
st.chat_message(message["role"]).write(message["content"])
# 获取用户输入的问题(如果有输入)
if query := st.chat_input("输入图书相关问题", key="query_input"):
# 将用户输入添加到消息历史记录
st.session_state.messages.append({"role": "user", "content": query})
# 显示用户输入的消息
st.chat_message("user").write(query)
# 显示查询中的 loading 动画
with st.spinner("正在查询中..."):
try:
# 获取用户输入 query 的嵌入向量
query_embedding = embedding_service.get_embedding(query)
# 用嵌入向量进行检索,查询对应的结果
results = retrieval_service.query_by_embedding(query_embedding, query_type, top_k)
# 如果没有检索到结果
if not results or len(results) == 0:
# 设置回复内容为没有找到相关信息
answer = "抱歉,没有找到相关的信息。"
else:
# 如果选择的服务商是火山引擎
if llm_provider == "volcengine":
# 实例化火山引擎大模型
llm = VolcengineLLM(
model_name=model_name if model_name else None,
api_key=api_key if api_key else None,
)
else:
# 否则实例化 DeepSeek 大模型
llm = DeepSeekLLM(
model_name=model_name if model_name else None,
api_key=api_key if api_key else None,
)
# 根据检索结果格式化上下文字符串
context_str = format_context(query_type, results)
# 打印上下文字符串,便于调试
print("context_str: ", context_str)
# 用模板格式化最终 prompt
final_prompt = prompt_template.format(
question=query,
context=context_str
)
# 打印最终 prompt,便于调试
print("final_prompt: ", final_prompt)
# 调用大模型生成答案
answer = llm.generate(final_prompt)
# 将本次对话历史添加到 session_state.history,也包含了用户问题、检索类型、上下文、模型回复及温度参数
+ st.session_state.history.append({
+ "question": query, # 用户输入的问题
+ "query_type": query_type, # 查询类型(如“图书”/“作者”)
+ "context": context_str, # 格式化后的上下文字符串
+ "answer": answer, # 大模型生成的答案
+ "temperature": temperature, # 当前使用的温度参数
+ })
# 展开可查看详细检索结果
with st.expander("查看详细结果"):
# 以 json 格式显示检索的类型和结果
st.json({"type": query_type, "results": results})
# 把大模型回复添加到会话历史
st.session_state.messages.append({
"role": "assistant",
"content": answer,
"query_type": query_type,
"results": results # 检索结果
})
# 用 chat_message 组件显示大模型回复
st.chat_message("assistant").write(answer)
# 重新运行页面,用于强制刷新
st.rerun()
except Exception as err:
# 打印异常信息,便于调试
print(err)
# 捕获异常时组织错误提示内容
error_msg = f"查询过程中出错: {str(err)}"
# 将错误信息添加到会话历史
st.session_state.messages.append({
"role": "assistant",
"content": error_msg
})
# 显示错误信息在对话界面
st.chat_message("assistant").write(error_msg)
# 在页面上弹出错误提示
st.error(error_msg)
# 判断是否作为主程序运行
if __name__ == "__main__":
# 调用主函数启动应用
main()