1. 项目准备 #
1.1 项目介绍 #
基于Neo4j图数据库和向量检索的智能图书问答系统,支持RAG(检索增强生成)架构。
1.2 项目架构 #
graphrag/ # 项目根目录
├── database/ # 数据库相关模块
│ ├── __init__.py # 数据库包初始化
│ ├── connection.py # Neo4j连接相关代码
│ └── queries.py # 数据库查询实现
├── llm/ # 大模型相关模块
│ ├── __init__.py # LLM包初始化
│ ├── base.py # LLM基类定义
│ ├── deepseek.py # DeepSeek模型相关代码
│ └── volcengine.py # 火山引擎(豆包)模型相关代码
├── pages/ # Streamlit页面目录
│ ├── 0_数据导入.py # 数据导入页面
│ └── 1_问答系统.py # 问答系统页面
├── services/ # 服务层
│ ├── embedding_service.py # 嵌入服务
│ ├── import_service.py # 导入服务
│ └── retrieval_service.py # 检索服务
├── utils/ # 工具类目录
│ ├── __init__.py # 工具包初始化
│ └── formatters.py # 格式化相关工具函数
├── config.py # 配置文件
├── constants.py # 常量定义
├── main.py # 项目入口,运行主程序
├── pyproject.toml # Python项目依赖与配置
└── 首页.py # Streamlit首页入口1.3 核心功能 #
1.3.1. 向量检索 #
- 使用火山方舟嵌入API生成文本向量
- 基于Neo4j向量索引进行相似度检索
- 支持图书和作者两种查询类型
1.3.2. RAG问答 #
- 向量检索获取相关上下文
- 使用LLM(豆包/DeepSeek)生成答案
- 支持温度参数调整
1.3.3. 多LLM支持 #
- 豆包(Doubao)LLM
- DeepSeek LLM
- 统一的抽象接口,易于扩展
1.3.4. 数据导入管理 #
- CSV文件上传和解析
- 自动创建节点和关系
- 批量生成嵌入向量
- 实时进度显示和错误处理
1.4 技术栈 #
- 前端框架: Streamlit
- 图数据库: Neo4j (使用py2neo)
- 向量嵌入: 火山方舟 (doubao-embedding-text-240715)
- LLM框架: LangChain
- LLM服务: 豆包 / DeepSeek
1.5 启动Neo4j #
- 下载地址:https://neo4j.com/download/
- 或使用 Docker:
docker run -d -p 7474:7474 -p 7687:7687 -e NEO4J_AUTH=neo4j/12345678 neo4j:latest
1.6 创建项目目录 #
mkdir graphrag
cd graphrag
uv init
uv add python-dotenv langchain langchain-deepseek pandas py2neo streamlit requests openai 1.7 首页 #
1.7.1 main.py #
# 应用入口说明
"""应用入口"""
# 导入 subprocess 模块,用于运行子进程
import subprocess
# 导入 sys 模块,用于访问解释器相关变量
import sys
# 判断当前脚本是否为主程序执行
if __name__ == "__main__":
# 使用 subprocess 运行 streamlit 应用,指定脚本为“首页.py”
subprocess.run([sys.executable, "-m", "streamlit", "run", "首页.py"])1.7.2 首页.py #
"""Streamlit应用主入口"""
import streamlit as st
st.set_page_config(
page_title="图书知识图谱系统",
layout="wide",
initial_sidebar_state="expanded"
)
st.title("图书知识图谱系统")
st.markdown("---")
st.markdown("""
欢迎使用图书知识图谱系统!
### 功能模块
1. **数据导入** - 导入CSV数据并生成向量
2. **问答系统** - 基于向量检索的智能问答
""")1.8 建立索引 #
# 清空整个Neo4j数据库,删除所有节点和关系
MATCH (n) DETACH DELETE n
# 显示当前数据库中的所有索引信息
SHOW INDEXES
# 删除名为“book_embeddings”的向量索引
DROP INDEX book_embeddings
# 删除名为“author_embeddings”的向量索引
DROP INDEX author_embeddings
# 为Book节点的embedding属性创建名为“book_embeddings”的向量索引,若不存在则创建
CREATE VECTOR INDEX book_embeddings IF NOT EXISTS
FOR (m:Book) ON m.embedding
OPTIONS { indexConfig: {
`vector.dimensions`: 2560, # 设置向量维数为2560
`vector.similarity_function`: 'cosine' # 使用余弦相似度作为索引匹配方式
}}
# 为Author节点的embedding属性创建名为“author_embeddings”的向量索引,若不存在则创建
CREATE VECTOR INDEX author_embeddings IF NOT EXISTS
FOR (m:Author) ON m.embedding
OPTIONS { indexConfig: {
`vector.dimensions`: 2560, # 设置向量维数为2560
`vector.similarity_function`: 'cosine' # 使用余弦相似度作为索引匹配方式
}}2. 数据导入 #
2.1. 0_数据导入.py #
pages/0_数据导入.py
# 数据导入页面模块注释
"""数据导入页面"""
# 导入streamlit库并简写为st
import streamlit as st
# 定义主函数
def main():
# 数据导入页面主函数的文档字符串
"""数据导入页面主函数"""
# 配置Streamlit页面,设置标题和布局
st.set_page_config(
page_title="数据导入 - 图书知识图谱",
layout="wide"
)
# 显示页面主标题
st.title("数据导入管理")
# 插入分割线
st.markdown("---")
# 判断是否为主程序入口
if __name__ == "__main__":
# 调用主函数
main() 3. 读取csv文件 #
3.1. constants.py #
constants.py
# 定义一个列表,包含CSV文件导入时必需的列名
REQUIRED_CSV_COLUMNS = [
# 书名
"name",
# 作者
"author",
# 出版社
"publisher",
# 类别
"category",
# 出版年份
"publish_year",
# 简介
"summary",
# 关键词(用分号分隔)
"keywords"
]3.2. import_service.py #
services/import_service.py
"""数据导入服务"""
# 导入io模块用于处理内存中的流数据
import io
# 导入pandas库用于数据分析处理
import pandas as pd
# 导入类型提示
from typing import List, Dict, Any, Optional, Callable
# 从constants模块导入必需的CSV列名常量
from constants import (REQUIRED_CSV_COLUMNS)
# 定义数据导入服务类(单例模式)
class ImportService:
"""数据导入服务(单例模式)"""
# 类变量:存储单例实例
_instance = None
# 重写__new__方法实现单例模式
def __new__(cls):
"""创建单例实例"""
# 如果实例不存在,则创建新实例
if cls._instance is None:
cls._instance = super().__new__(cls)
# 返回单例实例
return cls._instance
# 解析CSV内容为DataFrame的方法
def parse_csv(self, csv_content: str) -> pd.DataFrame:
"""解析CSV内容"""
# 使用pandas读取内存中的CSV内容并返回DataFrame
return pd.read_csv(io.StringIO(csv_content))
# 验证DataFrame格式的方法
def validate_csv(self, df: pd.DataFrame) -> tuple[bool, Optional[str]]:
"""验证CSV格式"""
# 如果DataFrame为空,返回False和错误信息
if df.empty:
return False, "CSV文件为空"
# 检查是否有缺失的必需字段
missing = [col for col in REQUIRED_CSV_COLUMNS if col not in df.columns]
# 如果没有缺失则返回True,否则返回缺失列信息
return (True, None) if not missing else (False, f"缺少列: {', '.join(missing)}")
# 获取导入服务实例的便捷方法
def get_import_service() -> ImportService:
"""获取导入服务单例实例"""
# 返回单例实例(无论调用多少次都返回同一个实例)
return ImportService()3.3. 0_数据导入.py #
pages/0_数据导入.py
# 数据导入页面模块注释
"""数据导入页面"""
# 导入streamlit库并简写为st
import streamlit as st
+from services.import_service import get_import_service
# 定义主函数
def main():
# 数据导入页面主函数的文档字符串
"""数据导入页面主函数"""
# 配置Streamlit页面,设置标题和布局
st.set_page_config(
page_title="数据导入 - 图书知识图谱",
layout="wide"
)
# 显示页面主标题
st.title("数据导入管理")
# 插入分割线
st.markdown("---")
# 获取服务实例
+ import_service = get_import_service()
# 创建一个标签页,只包含“CSV数据导入”标签,并将其赋值给tab1变量
+ tab1, = st.tabs(["CSV数据导入"])
# 在tab1标签页下进行后续操作
+ with tab1:
# 显示“CSV数据导入”二级标题
+ st.header("CSV数据导入")
# 展示有关上传CSV文件的说明及字段要求
+ st.markdown("""
+ 请上传包含图书信息的CSV文件。CSV文件应包含以下列:
+ - `name`: 书名
+ - `author`: 作者
+ - `publisher`: 出版社
+ - `category`: 类别
+ - `publish_year`: 出版年份
+ - `summary`: 简介
+ - `keywords`: 关键词(用分号分隔)
+ """)
# 显示文件上传控件,限定仅可上传csv类型文件,并显示帮助提示
+ uploaded_file = st.file_uploader(
+ "选择CSV文件",
+ type=["csv"],
+ help="上传包含图书信息的CSV文件"
+ )
# 如果用户已经上传了文件
+ if uploaded_file is not None:
+ try:
# 读取上传的文件内容,并解码为utf-8格式的字符串
+ csv_content = uploaded_file.read().decode("utf-8")
# 调用服务将CSV内容解析为DataFrame
+ df = import_service.parse_csv(csv_content)
# 验证上传的DataFrame是否符合CSV格式要求
+ is_valid, error_msg = import_service.validate_csv(df)
# 如果格式校验未通过,显示错误
+ if not is_valid:
+ st.error(f"CSV格式错误: {error_msg}")
# 如果校验通过,显示成功消息,并告知记录数量
+ else:
+ st.success(f"CSV文件验证通过!共 {len(df)} 条记录")
# 捕获任何异常,显示读取文件失败的错误信息
+ except Exception as e:
+ st.error(f"读取文件失败: {str(e)}")
# 判断是否为主程序入口
if __name__ == "__main__":
# 调用主函数
main() 4. 数据导入 #
4.1. .env #
.env
NEO4J_URI="bolt://localhost:7687"
NEO4J_USER="neo4j"
NEO4J_PASSWORD="12345678"
VOLC_EMBEDDINGS_API_URL=https://ark.cn-beijing.volces.com/api/v3/embeddings
VOLC_EMBEDDING_MODEL=doubao-embedding-text-240715
VOLC_API_KEY=d52e49a1-36ea-44bb-bc6e-65ce789a72f6
DeepSeek_BASE_URL="https://api.deepseek.com/v1"
DeepSeek_API_KEY="sk-ae59c1b6731e4d5d8f1f3fd0ad340b39"
DeepSeek_MODEL="deepseek-chat"
VOLCENGINE_BASE_URL=https://ark.cn-beijing.volces.com/api/v3
VOLCENGINE_API_KEY=d52e49a1-36ea-44bb-bc6e-65ce789a72f6
VOLCENGINE_MODEL=doubao-seed-1-6-2506154.2. config.py #
config.py
"""配置管理模块"""
# 导入os模块,用于获取环境变量
import os
# 导入dotenv模块,用于加载.env文件中的环境变量
import dotenv
# 从dataclasses模块导入dataclass,用于简化数据类定义
from dataclasses import dataclass
# 加载.env文件中的所有环境变量
dotenv.load_dotenv()
# 定义Neo4j数据库配置的数据类
@dataclass
class Neo4jConfig:
"""Neo4j数据库配置"""
# 数据库URI
uri: str
# 用户名
user: str
# 密码
password: str
# 定义应用整体配置的数据类
@dataclass
class AppConfig:
"""应用配置"""
# Neo4j数据库配置属性
neo4j: Neo4jConfig
config=AppConfig(
neo4j=Neo4jConfig(
uri=os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
user=os.environ.get("NEO4J_USER", "neo4j"),
password=os.environ.get("NEO4J_PASSWORD", "12345678"),
)
)4.3. init.py #
database/init.py
"""数据库模块"""
4.4. connection.py #
database/connection.py
# 这是Neo4j数据库连接管理的模块说明
"""Neo4j数据库连接管理"""
# 从py2neo库导入Graph类,用于连接和操作Neo4j数据库
from py2neo import Graph
# 从config模块导入config对象,用于获取配置信息
from config import config
# 使用配置信息创建Neo4j数据库连接实例
graph = Graph(config.neo4j.uri, auth=(config.neo4j.user, config.neo4j.password))
4.5. constants.py #
constants.py
# 定义一个列表,包含CSV文件导入时必需的列名
REQUIRED_CSV_COLUMNS = [
# 书名
"name",
# 作者
"author",
# 出版社
"publisher",
# 类别
"category",
# 出版年份
"publish_year",
# 简介
"summary",
# 关键词(用分号分隔)
"keywords"
]
# 定义图书节点标签
+NODE_LABEL_BOOK = "Book"
# 定义作者节点标签
+NODE_LABEL_AUTHOR = "Author"
# 定义出版社节点标签
+NODE_LABEL_PUBLISHER = "Publisher"
# 定义类别节点标签
+NODE_LABEL_CATEGORY = "Category"
# 定义关键词节点标签
+NODE_LABEL_KEYWORD = "Keyword"
# 定义“书籍-作者”关系类型
+RELATIONSHIP_TYPE_WRITTEN_BY = "written_by"
# 定义“书籍-出版社”关系类型
+RELATIONSHIP_TYPE_PUBLISHED_BY = "published_by"
# 定义“书籍-类别”关系类型
+RELATIONSHIP_TYPE_HAS_CATEGORY = "has_category"
# 定义“书籍-关键词”关系类型
+RELATIONSHIP_TYPE_HAS_KEYWORD = "has_keyword"4.6. 0_数据导入.py #
pages/0_数据导入.py
# 数据导入页面模块注释
"""数据导入页面"""
# 导入streamlit库并简写为st
import streamlit as st
+from services.import_service import import_service
# 定义主函数
def main():
# 数据导入页面主函数的文档字符串
"""数据导入页面主函数"""
# 配置Streamlit页面,设置标题和布局
st.set_page_config(
page_title="数据导入 - 图书知识图谱",
layout="wide"
)
# 显示页面主标题
st.title("数据导入管理")
# 插入分割线
st.markdown("---")
# 创建一个标签页,只包含“CSV数据导入”标签,并将其赋值给tab1变量
tab1, = st.tabs(["CSV数据导入"])
# 在tab1标签页下进行后续操作
with tab1:
# 显示“CSV数据导入”二级标题
st.header("CSV数据导入")
# 展示有关上传CSV文件的说明及字段要求
st.markdown("""
请上传包含图书信息的CSV文件。CSV文件应包含以下列:
- `name`: 书名
- `author`: 作者
- `publisher`: 出版社
- `category`: 类别
- `publish_year`: 出版年份
- `summary`: 简介
- `keywords`: 关键词(用分号分隔)
""")
# 显示文件上传控件,限定仅可上传csv类型文件,并显示帮助提示
uploaded_file = st.file_uploader(
"选择CSV文件",
type=["csv"],
help="上传包含图书信息的CSV文件"
)
# 如果用户已经上传了文件
if uploaded_file is not None:
try:
# 读取上传的文件内容,并解码为utf-8格式的字符串
csv_content = uploaded_file.read().decode("utf-8")
# 调用服务将CSV内容解析为DataFrame
df = import_service.parse_csv(csv_content)
# 验证上传的DataFrame是否符合CSV格式要求
is_valid, error_msg = import_service.validate_csv(df)
# 如果格式校验未通过,显示错误
if not is_valid:
st.error(f"CSV格式错误: {error_msg}")
# 如果校验通过,显示成功消息,并告知记录数量
else:
st.success(f"CSV文件验证通过!共 {len(df)} 条记录")
# 显示预览
+ with st.expander("数据预览", expanded=True):
+ st.dataframe(df)
# 导入选项
+ clear_existing = st.checkbox(
+ "清空现有数据",
+ value=False,
+ help="导入前清空数据库中的所有数据"
+ )
# 当用户点击"开始导入"按钮时执行
+ if st.button("开始导入", type="primary", width='stretch'):
# 创建一个显示日志信息的空容器
+ log_container = st.empty()
# 用于存放日志消息的列表
+ log_messages = []
# 定义进度回调函数
+ def progress_callback(message, current, total):
# 添加当前消息到日志列表(包含进度信息)
+ log_messages.append(f"- [{current}/{total}] {message}")
# 刷新日志容器内容
+ log_container.text("\n".join(log_messages))
+ try:
# 调用导入服务执行数据导入
+ stats = import_service.import_books(
+ df,#数据源
+ clear_existing=clear_existing,# 是否清空现有数据
+ progress_callback=progress_callback# 进度回调函数
+ )
# 显示最终统计信息
+ if stats:
+ st.info(
+ f"完成: 图书 {stats.get('books', 0)} | "
+ f"作者 {stats.get('authors', 0)} | "
+ f"出版社 {stats.get('publishers', 0)} | "
+ f"类别 {stats.get('categories', 0)} | "
+ f"关键词 {stats.get('keywords', 0)} | "
+ f"关系 {stats.get('relationships', 0)}"
+ )
# 如果有错误信息,显示出来
+ if stats.get("errors") and len(stats["errors"]) > 0:
+ with st.expander("⚠️ 导入过程中的错误", expanded=False):
+ for error in stats["errors"]:
+ st.error(error)
+ except Exception as e:
# 导入过程中如有异常则显示错误信息
+ st.error(f"导入失败: {str(e)}")
# 导入traceback模块以显示详细错误
+ import traceback
# 展开详细错误信息
+ with st.expander("详细错误信息"):
# 格式化并显示完整的异常调用栈
+ st.code(traceback.format_exc())
# 捕获任何异常,显示读取文件失败的错误信息
except Exception as e:
st.error(f"读取文件失败: {str(e)}")
# 判断是否为主程序入口
if __name__ == "__main__":
# 调用主函数
main() 4.8. import_service.py #
services/import_service.py
# 数据导入服务说明
"""数据导入服务"""
# 导入io模块,用于内存数据流操作
import io
# 导入pandas进行数据处理
import pandas as pd
# 从typing模块导入类型提示
from typing import List, Dict, Any, Optional, Callable
# 从database.connection导入数据库实例graph
+from database.connection import graph
# 从constants模块导入各类常量
+from constants import (
+ REQUIRED_CSV_COLUMNS, # 必需的CSV列名
+ NODE_LABEL_BOOK, # 图书节点标签
+ NODE_LABEL_AUTHOR, # 作者节点标签
+ NODE_LABEL_PUBLISHER, # 出版社节点标签
+ NODE_LABEL_CATEGORY, # 类别节点标签
+ NODE_LABEL_KEYWORD, # 关键词节点标签
+ RELATIONSHIP_TYPE_WRITTEN_BY, # 书-作者关系类型
+ RELATIONSHIP_TYPE_PUBLISHED_BY, # 书-出版社关系类型
+ RELATIONSHIP_TYPE_HAS_CATEGORY, # 书-类别关系类型
+ RELATIONSHIP_TYPE_HAS_KEYWORD # 书-关键词关系类型
+)
# 导入py2neo中的Node、Relationship、NodeMatcher
+from py2neo import Node, Relationship, NodeMatcher
# 定义数据导入服务类,采用单例模式
class ImportService:
# 类说明:数据导入服务(单例)
"""数据导入服务(单例模式)"""
# 初始化方法
+ def __init__(self):
# 创建NodeMatcher对象用于节点查找
+ self.node_matcher = NodeMatcher(graph)
# 解析CSV内容的函数
def parse_csv(self, csv_content: str) -> pd.DataFrame:
# 方法说明:解析CSV内容为DataFrame
"""解析CSV内容"""
# 利用pandas读取内存中的CSV字符串,返回DataFrame
return pd.read_csv(io.StringIO(csv_content))
# 验证CSV格式合法性的方法
def validate_csv(self, df: pd.DataFrame) -> tuple[bool, Optional[str]]:
# 方法说明:验证CSV格式
"""验证CSV格式"""
# 判断DataFrame是否为空,为空则返回错误消息
if df.empty:
return False, "CSV文件为空"
# 检查必需CSV列是否全部存在
missing = [col for col in REQUIRED_CSV_COLUMNS if col not in df.columns]
# 若无漏列返回True,否则返回缺失列名
return (True, None) if not missing else (False, f"缺少列: {', '.join(missing)}")
# 清空全部数据库的方法
+ def clear_database(self) -> None:
# 方法说明:清空数据库,删除所有节点和关系
+ """清空数据库"""
# 执行Cypher语句,删除所有节点及其关系
+ graph.run("MATCH (n) DETACH DELETE n")
# 内部方法:创建节点
+ def _create_nodes(
+ self, label, node_list, stats, stat_key,
+ progress_callback, current, total
+ ) -> int:
# 方法说明:创建指定类型的节点并更新进度
+ """创建节点
+ Args:
+ label: 节点标签
+ node_list: 节点列表,每项为字典
+ stats: 统计计数与错误信息字典
+ stat_key: 统计键名
+ progress_callback: 进度回调函数
+ current: 当前进度值
+ total: 总进度步数
+ """
# 有进度回调时,通知正在创建节点
+ if progress_callback:
+ progress_callback(f"正在创建{label}节点...", current, total)
# 初始化进度步数
+ step = current
# 遍历所有节点数据逐一创建
+ for node in node_list:
+ try:
# 取节点名称用于反馈与出错描述,默认“未知”
+ node_name = node.get("name", "未知")
# 构建节点对象
+ node = Node(label, **node)
# 写入数据库
+ graph.create(node)
# 节点计数器+1
+ stats[stat_key] += 1
# 进度步数+1
+ step += 1
# 回调进度反馈
+ if progress_callback:
+ progress_callback(f"创建{label}: {node_name}", step, total)
+ except Exception as e:
# 异常情况将出错信息记录到stats["errors"]
+ node_name = node.get("name", "未知")
+ stats["errors"].append(f"创建{label} {node_name} 失败: {str(e)}")
+ step += 1
# 返回当前进度步数
+ return step
# 内部方法:创建节点间关系
+ def _create_relationships(
+ self, start_label, end_label, edges, rel_type,
+ progress_callback, current, total, stats=None
+ ) -> int:
+ """
+ 创建关系
+ Args:
+ start_label: 起始节点标签
+ end_label: 终止节点标签
+ edges: 关系对列表,每项为[起始节点名称, 终止节点名称]
+ rel_type: 关系类型
+ progress_callback: 进度回调函数
+ current: 当前进度值
+ total: 总进度步数
+ stats: 统计计数与错误信息字典
+ """
# 初始化关系计数
+ count = 0
# 初始化进度值
+ step = current
# 遍历所有关系对
+ for edge in edges:
+ try:
# 查找起始节点对象
+ start_node = self.node_matcher.match(start_label, name=str(edge[0])).first()
# 查找终止节点对象
+ end_node = self.node_matcher.match(end_label, name=str(edge[1])).first()
# 两端节点均存在后创建关系
+ if start_node and end_node:
+ rel = Relationship(start_node, rel_type, end_node)#创建关系对象
+ graph.create(rel)#写入数据库
+ count += 1#关系计数器+1
+ step += 1#进度步数+1
# 调用进度回调反馈
+ if progress_callback:# 有进度回调时,通知正在创建关系
+ progress_callback(f"创建关系: {edge[0]}-{rel_type}->{edge[1]}", step, total)
+ else:
# 如有节点不存在,将详细信息加到stats["errors"]
+ if stats:
+ missing = []#缺失节点列表
+ if not start_node:
+ missing.append(f"{start_label}:{edge[0]}")#添加起始节点
+ if not end_node:
+ missing.append(f"{end_label}:{edge[1]}")#添加终止节点
+ stats["errors"].append(f"关系创建失败: 找不到节点 {', '.join(missing)}")#添加错误信息
+ step += 1#进度步数+1
+ except Exception as e:
# 其他异常也写入stats["errors"]
+ if stats:# 有统计信息时,添加错误信息
+ stats["errors"].append(f"创建关系 {edge[0]}-{rel_type}->{edge[1]} 时出错: {str(e)}")
+ step += 1
# 返回成功创建的关系数
+ return count
# 图书数据导入的主方法
+ def import_books(
+ self, df: pd.DataFrame, clear_existing: bool = False,
+ progress_callback: Optional[Callable] = None
+ ) -> Dict[str, Any]:
# 方法说明:将DataFrame批量导入为图数据库节点与关系
+ """导入图书数据"""
# 初始化统计信息字典
+ stats = {
+ "books": 0, "authors": 0, "publishers": 0,
+ "categories": 0, "keywords": 0, "relationships": 0, "errors": []
+ }
+ try:
# 是否需要清空库
+ if clear_existing:
# 有进度回调时先通知
+ if progress_callback:
+ progress_callback("正在清空数据库...", 0, 0)
# 调用清库
+ self.clear_database()
# 生成图书节点数据列表
+ books = []
+ for _, row in df.iterrows():
+ if pd.notna(row["name"]):
# 创建基本字典(书名)
+ book = {"name": row["name"]}
# 补充出版年份
+ if pd.notna(row.get("publish_year")):
+ book["publish_year"] = int(row["publish_year"])
# 补充简介
+ if pd.notna(row.get("summary")):
+ book["summary"] = str(row["summary"])
# 处理关键字列表
+ if pd.notna(row.get("keywords")):
+ book["keywords"] = [kw.strip() for kw in str(row["keywords"]).split(";") if kw.strip()]
+ books.append(book)
# authors节点只创建唯一且非空的作者名
+ authors = [{"name": name} for name in set(df["author"].dropna())]
# publishers节点只创建唯一且非空的出版社名
+ publishers = [{"name": name} for name in set(df["publisher"].dropna())]
# categories节点只创建唯一且非空的类别名
+ categories = [{"name": name} for name in set(df["category"].dropna())]
# 提取所有keywords列的分号分割项,并去重空
+ keywords_set = set()
# 遍历所有关键词,并去重空格
+ for keyword in df["keywords"].dropna():
# 将关键词按分号分割,并去重空格
+ keywords_set.update([kw.strip() for kw in keyword.split(";") if kw.strip()])
# 将关键词列表转换为字典列表
+ keywords = [{"name": keyword.strip()} for keyword in keywords_set if keyword.strip()]
# 构建关系对列表
+ rels_written_by = [] # 书-作者
+ rels_published_by = [] # 书-出版社
+ rels_has_category = [] # 书-类别
+ rels_has_keyword = [] # 书-关键词
# 遍历每行收集上述所有关系
+ for _, row in df.iterrows():
# 书-作者
+ if pd.notna(row["name"]) and pd.notna(row["author"]):
+ rels_written_by.append([row["name"], row["author"]])
# 书-出版社
+ if pd.notna(row["name"]) and pd.notna(row["publisher"]):
+ rels_published_by.append([row["name"], row["publisher"]])
# 书-类别
+ if pd.notna(row["name"]) and pd.notna(row["category"]):
+ rels_has_category.append([row["name"], row["category"]])
# 书-关键词(支持多关键词)
+ if pd.notna(row["name"]) and pd.notna(row["keywords"]):
+ for kw in row["keywords"].split(";"):
+ rels_has_keyword.append([row["name"], kw.strip()])
# 统计全部节点和关系数,便于进度反馈
+ total_steps = (
+ len(books) + len(authors) + len(publishers) +
+ len(categories) + len(keywords) +
+ len(rels_written_by) + len(rels_published_by) +
+ len(rels_has_category) + len(rels_has_keyword)
+ )
# 初始化进度指针
+ current = 0
# 创建所有图书节点
+ current = self._create_nodes(
+ NODE_LABEL_BOOK, books, stats, "books",
+ progress_callback, current, total_steps
+ )
# 创建作者节点
+ current = self._create_nodes(
+ NODE_LABEL_AUTHOR, authors, stats, "authors",
+ progress_callback, current, total_steps
+ )
# 创建出版社节点
+ current = self._create_nodes(
+ NODE_LABEL_PUBLISHER, publishers, stats, "publishers",
+ progress_callback, current, total_steps
+ )
# 创建类别节点
+ current = self._create_nodes(
+ NODE_LABEL_CATEGORY, categories, stats, "categories",
+ progress_callback, current, total_steps
+ )
# 创建关键词节点
+ current = self._create_nodes(
+ NODE_LABEL_KEYWORD, keywords, stats, "keywords",
+ progress_callback, current, total_steps
+ )
# 所有节点创建后,反馈将创建关系
+ if progress_callback:
+ progress_callback("正在创建关系...", current, total_steps)
# 创建书-作者关系并累计关系数
+ stats["relationships"] += self._create_relationships(
+ NODE_LABEL_BOOK, NODE_LABEL_AUTHOR, rels_written_by,
+ RELATIONSHIP_TYPE_WRITTEN_BY, progress_callback, current, total_steps, stats
+ )
# 更新进度
+ current += len(rels_written_by)
# 创建书-出版社关系
+ stats["relationships"] += self._create_relationships(
+ NODE_LABEL_BOOK, NODE_LABEL_PUBLISHER, rels_published_by,
+ RELATIONSHIP_TYPE_PUBLISHED_BY, progress_callback, current, total_steps, stats
+ )
+ current += len(rels_published_by)
# 创建书-类别关系
+ stats["relationships"] += self._create_relationships(
+ NODE_LABEL_BOOK, NODE_LABEL_CATEGORY, rels_has_category,
+ RELATIONSHIP_TYPE_HAS_CATEGORY, progress_callback, current, total_steps, stats
+ )
+ current += len(rels_has_category)
# 创建书-关键词关系
+ stats["relationships"] += self._create_relationships(
+ NODE_LABEL_BOOK, NODE_LABEL_KEYWORD, rels_has_keyword,
+ RELATIONSHIP_TYPE_HAS_KEYWORD, progress_callback, current, total_steps, stats
+ )
# 如果整个导入过程中出现异常,记录到stats["errors"]
+ except Exception as e:
+ stats["errors"].append(f"导入过程出错: {str(e)}")
# 返回统计字典
+ return stats
# 创建ImportService的单例实例
+import_service = ImportService()5. 向量嵌入 #
5.1. embedding_service.py #
services/embedding_service.py
"""向量初始化服务"""
# 导入类型提示所需的类型List、Optional、Callable
from typing import List, Optional, Callable
# 导入HTTP请求库requests
import requests
# 导入py2neo包中的NodeMatcher用于节点查询
from py2neo import NodeMatcher
# 导入常量:节点标签(图书、作者)
from constants import NODE_LABEL_BOOK, NODE_LABEL_AUTHOR
# 导入Neo4j数据库连接
from database.connection import graph
# 导入配置对象
from config import config
# 定义向量初始化服务类
class EmbeddingInitService:
"""向量初始化服务"""
# 构造函数,初始化必要的成员变量
def __init__(self):
# 创建NodeMatcher对象,用于查找节点
self.node_matcher = NodeMatcher(graph)
# 嵌入API的URL,从配置获取
self.api_url = config.embedding.api_url
# API密钥,从配置获取
self.api_key = config.embedding.api_key
# 嵌入模型名,从配置获取
self.model = config.embedding.model
# 为节点批量生成和更新嵌入向量
def update_embeddings(
self,
progress_callback: Optional[Callable] = None
) -> dict:
"""为节点生成嵌入向量"""
# 初始化统计信息字典
stats = {"total_nodes": 0, "processed": 0, "failed": 0, "errors": []}
# 要处理的节点标签列表(图书和作者)
node_labels = [NODE_LABEL_BOOK, NODE_LABEL_AUTHOR]
# 统计所有目标节点的总数
total_nodes = 0
for label in node_labels:
# 查找指定标签的所有节点
nodes = list(self.node_matcher.match(label))
# 累加节点数量
total_nodes += len(nodes)
# 将总节点数写入统计信息
stats["total_nodes"] = total_nodes
# 当前已处理节点数
current = 0
# 遍历每个标签
for node_label in node_labels:
try:
# 获取该标签下所有节点
nodes = list(self.node_matcher.match(node_label))
# 遍历每个节点
for node in nodes:
# 获取节点name属性
name = node.get("name")
try:
# 获取文本的向量
embedding = self.get_embedding(name)
# 将嵌入向量写入节点属性
node["embedding"] = embedding
# 更新节点到数据库
graph.push(node)
# 成功处理节点数量加一
stats["processed"] += 1
# 当前计数加一
current += 1
# 调用进度回调函数(已成功处理该节点)
progress_callback(f" 已为 {name} 设置嵌入向量", current, total_nodes)
except Exception as e:
# 处理该节点失败时,记录错误信息
stats["errors"].append(f"处理 {name} 时出错: {str(e)}")
# 失败计数加一
stats["failed"] += 1
# 当前计数加一
current += 1
# 调用进度回调函数(处理失败)
progress_callback(f"处理 {name} 失败", current, total_nodes)
except Exception as e:
# 标签下所有节点处理发生异常时,记录错误
stats["errors"].append(f"处理{node_label}节点时出错: {str(e)}")
# 调用进度回调函数(标签整体失败)
progress_callback(f"处理{node_label}节点时出错: {str(e)}", current, total_nodes)
# 返回统计信息
return stats
# 获取单条文本的嵌入向量
def get_embedding(self, text: str) -> List[float]:
"""获取文本的嵌入向量"""
# 构造HTTP请求头
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
# 构造请求体
payload = {"model": self.model, "input": text}
# 输出请求体到控制台(调试用)
print(payload)
try:
# 发送post请求到嵌入API
response = requests.post(
self.api_url,
json=payload,
headers=headers
)
# 若请求成功,且返回状态码200
if response.status_code == 200:
# 获取响应体中的数据内容
data = response.json()
# 从响应体获取嵌入向量
embedding = data["data"][0].get("embedding")
# 若未获得向量内容,则抛出异常
if not embedding:
raise Exception("API返回数据格式错误")
# 返回嵌入向量
return embedding
else:
# 若HTTP响应非200,抛出异常包含错误状态码
error_msg = f"API错误: HTTP {response.status_code}"
raise Exception(error_msg)
# 网络异常处理
except requests.exceptions.RequestException as e:
# 抛出网络相关异常
raise Exception(f"网络错误: {e}") from e
# 创建EmbeddingInitService实例并赋值给embedding_service变量,供外部调用
embedding_service = EmbeddingInitService()5.2. .env #
.env
NEO4J_URI="bolt://localhost:7687"
NEO4J_USER="neo4j"
NEO4J_PASSWORD="12345678"
+VOLC_EMBEDDINGS_API_URL=https://ark.cn-beijing.volces.com/api/v3/embeddings
+VOLC_EMBEDDING_MODEL=doubao-embedding-text-240715
+VOLC_API_KEY=d52e49a1-36ea-44bb-bc6e-65ce789a72f65.3. config.py #
config.py
"""配置管理模块"""
# 导入os模块,用于获取环境变量
import os
# 导入dotenv模块,用于加载.env文件中的环境变量
import dotenv
# 从dataclasses模块导入dataclass,用于简化数据类定义
from dataclasses import dataclass
# 加载.env文件中的所有环境变量
dotenv.load_dotenv()
# 定义Neo4j数据库配置的数据类
@dataclass
class Neo4jConfig:
"""Neo4j数据库配置"""
# 数据库URI
uri: str
# 用户名
user: str
# 密码
password: str
+@dataclass
+class EmbeddingConfig:
+ """火山方舟嵌入API配置"""
+ api_url: str
+ api_key: str
+ model: str
# 定义应用整体配置的数据类
@dataclass
class AppConfig:
"""应用配置"""
# Neo4j数据库配置属性
neo4j: Neo4jConfig
# 火山方舟嵌入API配置属性
+ embedding: EmbeddingConfig
# 创建一个AppConfig对象,保存应用的整体配置
+config = AppConfig(
# Neo4j数据库配置,使用Neo4jConfig类初始化
+ neo4j=Neo4jConfig(
# 从环境变量获取数据库URI,如果没有则使用默认值"bolt://localhost:7687"
+ uri=os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
# 从环境变量获取数据库用户名,如果没有则使用默认值"neo4j"
+ user=os.environ.get("NEO4J_USER", "neo4j"),
# 从环境变量获取数据库密码,如果没有则使用默认值"12345678"
+ password=os.environ.get("NEO4J_PASSWORD", "12345678"),
+ ),
# 嵌入API配置,使用EmbeddingConfig类初始化
+ embedding=EmbeddingConfig(
# 从环境变量获取嵌入API的URL,没有则使用默认值
+ api_url=os.environ.get(
+ "VOLC_EMBEDDINGS_API_URL",
+ "https://ark.cn-beijing.volces.com/api/v3/embeddings"
+ ),
# 从环境变量获取API密钥,如果没有则为空字符串
+ api_key=os.environ.get("VOLC_API_KEY", ""),
# 从环境变量获取嵌入模型名,没有则使用默认值"doubao-embedding-text-240715"
+ model=os.environ.get("VOLC_EMBEDDING_MODEL", "doubao-embedding-text-240715"),
)
+)5.4. connection.py #
database/connection.py
# 这是Neo4j数据库连接管理的模块说明
"""Neo4j数据库连接管理"""
# 从py2neo库导入Graph类,用于连接和操作Neo4j数据库
from py2neo import Graph
# 从config模块导入get_config函数,用于获取配置信息
+from config import config
# 使用配置信息创建Neo4j数据库连接实例
graph = Graph(config.neo4j.uri, auth=(config.neo4j.user, config.neo4j.password))
5.5. 0_数据导入.py #
pages/0_数据导入.py
# 数据导入页面模块注释
"""数据导入页面"""
# 导入streamlit库并简写为st
import streamlit as st
from services.import_service import import_service
+from services.embedding_service import embedding_service
# 定义主函数
def main():
# 数据导入页面主函数的文档字符串
"""数据导入页面主函数"""
# 配置Streamlit页面,设置标题和布局
st.set_page_config(
page_title="数据导入 - 图书知识图谱",
layout="wide"
)
# 显示页面主标题
st.title("数据导入管理")
# 插入分割线
st.markdown("---")
# 创建一个标签页,只包含“CSV数据导入”标签,并将其赋值给tab1变量
+ tab1, tab2 = st.tabs(["CSV数据导入","向量初始化"])
# 在tab1标签页下进行后续操作
with tab1:
# 显示“CSV数据导入”二级标题
st.header("CSV数据导入")
# 展示有关上传CSV文件的说明及字段要求
st.markdown("""
请上传包含图书信息的CSV文件。CSV文件应包含以下列:
- `name`: 书名
- `author`: 作者
- `publisher`: 出版社
- `category`: 类别
- `publish_year`: 出版年份
- `summary`: 简介
- `keywords`: 关键词(用分号分隔)
""")
# 显示文件上传控件,限定仅可上传csv类型文件,并显示帮助提示
uploaded_file = st.file_uploader(
"选择CSV文件",
type=["csv"],
help="上传包含图书信息的CSV文件"
)
# 如果用户已经上传了文件
if uploaded_file is not None:
try:
# 读取上传的文件内容,并解码为utf-8格式的字符串
csv_content = uploaded_file.read().decode("utf-8")
# 调用服务将CSV内容解析为DataFrame
df = import_service.parse_csv(csv_content)
# 验证上传的DataFrame是否符合CSV格式要求
is_valid, error_msg = import_service.validate_csv(df)
# 如果格式校验未通过,显示错误
if not is_valid:
st.error(f"CSV格式错误: {error_msg}")
# 如果校验通过,显示成功消息,并告知记录数量
else:
st.success(f"CSV文件验证通过!共 {len(df)} 条记录")
# 显示预览
with st.expander("数据预览", expanded=True):
st.dataframe(df)
# 导入选项
clear_existing = st.checkbox(
"清空现有数据",
value=False,
help="导入前清空数据库中的所有数据"
)
# 当用户点击"开始导入"按钮时执行
if st.button("开始导入", type="primary", width='stretch'):
# 创建一个显示日志信息的空容器
log_container = st.empty()
# 用于存放日志消息的列表
log_messages = []
# 定义进度回调函数
def progress_callback(message, current, total):
# 添加当前消息到日志列表(包含进度信息)
log_messages.append(f"- [{current}/{total}] {message}")
# 刷新日志容器内容
log_container.text("\n".join(log_messages))
try:
# 调用导入服务执行数据导入
stats = import_service.import_books(
df,#数据源
clear_existing=clear_existing,# 是否清空现有数据
progress_callback=progress_callback# 进度回调函数
)
# 显示最终统计信息
if stats:
st.info(
f"完成: 图书 {stats.get('books', 0)} | "
f"作者 {stats.get('authors', 0)} | "
f"出版社 {stats.get('publishers', 0)} | "
f"类别 {stats.get('categories', 0)} | "
f"关键词 {stats.get('keywords', 0)} | "
f"关系 {stats.get('relationships', 0)}"
)
# 如果有错误信息,显示出来
if stats.get("errors") and len(stats["errors"]) > 0:
with st.expander("⚠️ 导入过程中的错误", expanded=False):
for error in stats["errors"]:
st.error(error)
except Exception as e:
# 导入过程中如有异常则显示错误信息
st.error(f"导入失败: {str(e)}")
# 导入traceback模块以显示详细错误
import traceback
# 展开详细错误信息
with st.expander("详细错误信息"):
# 格式化并显示完整的异常调用栈
st.code(traceback.format_exc())
# 捕获任何异常,显示读取文件失败的错误信息
except Exception as e:
st.error(f"读取文件失败: {str(e)}")
# 使用tab2选项卡
+ with tab2:
# 设置页面标题为“向量初始化”
+ st.header("向量初始化")
# 显示多行文字介绍向量初始化的功能
+ st.markdown("""
+ 为数据库中的节点生成嵌入向量。此操作会为所有Book和Author节点生成向量,
+ 用于后续的向量检索功能。
+ """)
# 如果用户点击“开始生成向量”按钮
+ if st.button("开始生成向量", type="primary", width='stretch'):
# 创建一个空的日志容器用于显示进度
+ log_container = st.empty()
# 用于存放日志消息的列表
+ log_messages = []
# 定义进度回调函数
+ def progress_callback(message,current, total):
# 向日志列表追加当前进度消息
+ log_messages.append(f"- [{current}/{total}] {message}")
# 在日志容器中刷新显示所有日志消息
+ log_container.text("\n".join(log_messages))
+ try:
# 调用embedding_service的update_embeddings方法,开始进行向量初始化
+ stats = embedding_service.update_embeddings(progress_callback=progress_callback)
# 在页面上显示向量生成成功的提示
+ st.success("向量生成完成!")
# 显示统计信息标题
+ st.markdown("### 生成统计")
# 将页面分为三列用于展示不同统计项
+ col1, col2, col3 = st.columns(3)
# 第一列显示总节点数
+ with col1:
+ st.metric("总节点数", stats["total_nodes"])
# 第二列显示已处理成功的数量
+ with col2:
+ st.metric("成功", stats["processed"])
# 第三列显示处理失败的数量
+ with col3:
+ st.metric("失败", stats["failed"])
# 如果存在错误信息,则在可展开区域中展示所有错误
+ if stats["errors"]:
+ with st.expander("错误信息", expanded=False):
+ for error in stats["errors"]:
+ st.error(error)
# 捕获异常并显示错误信息
+ except Exception as e:
# 在页面上显示向量生成失败的错误信息
+ st.error(f"向量生成失败: {str(e)}")
# 导入traceback模块用于获取详细出错信息
+ import traceback
# 可展开区域显示完整Traceback
+ with st.expander("详细错误信息"):
+ st.code(traceback.format_exc())
# 判断是否为主程序入口
if __name__ == "__main__":
# 调用主函数
main() 6. 提问 #
6.1. 1_问答系统.py #
pages/1_问答系统.py
"""Streamlit图书知识图谱问答系统主应用"""
import streamlit as st
from config import config
def main():
"""主函数"""
st.set_page_config(
page_title="问答系统 - 图书知识图谱",
layout="wide"
)
if "messages" not in st.session_state:
st.session_state.messages = []
with st.sidebar:
st.markdown("### 参数设置")
st.markdown(
"<h2 style='text-align: center; color: blue;'>图书知识图谱查询系统</h1>",
unsafe_allow_html=True,
)
for message in st.session_state.messages:
st.chat_message(message["role"]).write(message["content"])
if query := st.chat_input("输入图书相关问题", key="query_input"):
st.session_state.messages.append({"role": "user", "content": query})
st.chat_message("user").write(query)
if __name__ == "__main__":
main()
7. 向量化查询 #
7.1. queries.py #
database/queries.py
"""图数据库查询"""
# 导入类型提示
from typing import List, Dict, Any, Optional
# 导入py2neo的Node类型
from py2neo import Node
# 导入Neo4j连接对象
from database.connection import graph
# 导入常量:向量索引名
from constants import VECTOR_INDEX_BOOK, VECTOR_INDEX_AUTHOR
# 定义图书查询类
class BookQuery:
"""图书查询"""
# 基于嵌入向量检索图书的方法
def query_by_embedding(
self, query_embedding: List[float], top_k: int = 3
) -> List[Dict[str, Any]]:
"""基于向量检索图书"""
# 定义Cypher向量检索查询语句
query = f"""
CALL db.index.vector.queryNodes('{VECTOR_INDEX_BOOK}', $top_k, $query_embedding)
YIELD node, score
MATCH (node:Book)
OPTIONAL MATCH (node)-[:written_by]->(author:Author)
OPTIONAL MATCH (node)-[:published_by]->(publisher:Publisher)
OPTIONAL MATCH (node)-[:has_category]->(category:Category)
OPTIONAL MATCH (node)-[:has_keyword]->(keyword:Keyword)
WITH node, score, author, publisher, category, COLLECT(DISTINCT keyword.name) AS keyword_names
RETURN node, score, author, publisher, category, keyword_names
ORDER BY score DESC
"""
try:
# 执行Cypher查询,传入top_k和query_embedding参数
results = graph.run(query, top_k=top_k, query_embedding=query_embedding)
# 初始化图书结果列表
books = []
# 遍历查询结果
for record in results:
# 获取图书节点
node: Node = record["node"]
# 获取相似度分数
score = float(record["score"])
# 获取作者节点
author_node: Optional[Node] = record.get("author")
# 获取出版社节点
publisher_node: Optional[Node] = record.get("publisher")
# 获取类别节点
category_node: Optional[Node] = record.get("category")
# 获取关键词名称列表
keyword_names = record.get("keyword_names", [])
# 如果关系中存在关键词则使用,否则优先节点属性keywords
keywords = keyword_names if keyword_names else (node.get("keywords", []) or [])
# 构建图书结果字典
books.append({
"name": node.get("name", ""),
"similarity": score,
"作者": author_node.get("name") if author_node else None,
"出版社": publisher_node.get("name") if publisher_node else None,
"类别": category_node.get("name") if category_node else None,
"出版年份": node.get("publish_year"),
"简介": node.get("summary"),
"关键词": keywords,
})
# 返回图书结果列表
return books
except Exception as e:
# 查询出错时抛出异常并提示
raise Exception(f"图书查询失败: {e}") from e
# 定义作者查询类
class AuthorQuery:
"""作者查询"""
# 基于嵌入向量检索作者的方法
def query_by_embedding(
self, query_embedding: List[float], top_k: int = 3
) -> List[Dict[str, Any]]:
"""基于向量检索作者"""
# 定义Cypher向量检索查询语句
query = f"""
CALL db.index.vector.queryNodes('{VECTOR_INDEX_AUTHOR}', $top_k, $query_embedding)
YIELD node, score
MATCH (node:Author)
OPTIONAL MATCH (node)<-[:written_by]-(book:Book)
RETURN node, score, COLLECT(DISTINCT book.name) AS book_names
ORDER BY score DESC
"""
try:
# 执行Cypher查询,传入top_k和query_embedding参数
results = graph.run(query, top_k=top_k, query_embedding=query_embedding)
# 初始化作者结果列表
authors = []
# 遍历查询结果
for record in results:
# 获取作者节点
node: Node = record["node"]
# 获取相似度分数
score = float(record["score"])
# 获取相关图书名称列表
book_names = record.get("book_names", [])
# 构建作者结果字典
authors.append({
"name": node.get("name", ""),
"similarity": score,
"相关图书": book_names,
})
# 返回作者结果列表
return authors
except Exception as e:
# 查询出错时抛出异常并提示
raise Exception(f"作者查询失败: {e}") from e
7.2. retrieval_service.py #
services/retrieval_service.py
"""检索服务"""
# 导入类型提示
from typing import List, Dict, Any
# 导入图书和作者的查询类
from database.queries import BookQuery, AuthorQuery
# 导入嵌入服务(虽然此文件中未直接用到)
from services.embedding_service import embedding_service
# 导入用于区分查询类型和默认top_k值的常量
from constants import QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR, DEFAULT_TOP_K
class RetrievalService:
"""检索服务"""
# 初始化方法,分别实例化图书和作者查询对象
def __init__(self):
self.book_query = BookQuery()
self.author_query = AuthorQuery()
# 基于嵌入向量查询图书或作者的主方法
def query_by_embedding(
self,
query_embedding: List[float], # 问题的嵌入向量
query_type: str, # 查询类型(图书或作者)
top_k: int = DEFAULT_TOP_K, # 返回结果数量,默认值为常量
) -> List[Dict[str, Any]]:
"""使用嵌入向量查询"""
# 若查询类型为“图书”,则使用book_query执行查询
if query_type == QUERY_TYPE_BOOK:
return self.book_query.query_by_embedding(query_embedding, top_k)
# 若查询类型为“作者”,则使用author_query执行查询
elif query_type == QUERY_TYPE_AUTHOR:
return self.author_query.query_by_embedding(query_embedding, top_k)
# 其他类型(不支持的类型)抛出异常
else:
raise ValueError(f"不支持的查询类型: {query_type}")
# 创建RetrievalService的单例,便于外部调用
retrieval_service = RetrievalService()7.3. constants.py #
constants.py
# 定义一个列表,包含CSV文件导入时必需的列名
REQUIRED_CSV_COLUMNS = [
# 书名
"name",
# 作者
"author",
# 出版社
"publisher",
# 类别
"category",
# 出版年份
"publish_year",
# 简介
"summary",
# 关键词(用分号分隔)
"keywords"
]
# 定义图书节点标签
NODE_LABEL_BOOK = "Book"
# 定义作者节点标签
NODE_LABEL_AUTHOR = "Author"
# 定义出版社节点标签
NODE_LABEL_PUBLISHER = "Publisher"
# 定义类别节点标签
NODE_LABEL_CATEGORY = "Category"
# 定义关键词节点标签
NODE_LABEL_KEYWORD = "Keyword"
# 定义“书籍-作者”关系类型
RELATIONSHIP_TYPE_WRITTEN_BY = "written_by"
# 定义“书籍-出版社”关系类型
RELATIONSHIP_TYPE_PUBLISHED_BY = "published_by"
# 定义“书籍-类别”关系类型
RELATIONSHIP_TYPE_HAS_CATEGORY = "has_category"
# 定义“书籍-关键词”关系类型
RELATIONSHIP_TYPE_HAS_KEYWORD = "has_keyword"
# 查询类型
# 图书查询类型
+QUERY_TYPE_BOOK = "图书"
# 作者查询类型
+QUERY_TYPE_AUTHOR = "作者"
# 默认Top K值
+DEFAULT_TOP_K = 3
# 向量索引名称
# 图书向量索引名称
+VECTOR_INDEX_BOOK = "book_embeddings"
# 作者向量索引名称
+VECTOR_INDEX_AUTHOR = "author_embeddings"7.4. 1_问答系统.py #
pages/1_问答系统.py
# Streamlit图书知识图谱问答系统主应用
"""Streamlit图书知识图谱问答系统主应用"""
# 导入Streamlit库
import streamlit as st
# 导入配置文件
from config import config
# 导入嵌入向量服务
+from services.embedding_service import embedding_service
# 导入检索服务
+from services.retrieval_service import retrieval_service
# 导入查询类型常量
+from constants import QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR
# 定义主函数
def main():
"""主函数"""
# 设置页面配置(标题、布局)
st.set_page_config(
page_title="问答系统 - 图书知识图谱",
layout="wide"
)
# 如果session_state中还没有'messages',则初始化为空列表
if "messages" not in st.session_state:
st.session_state.messages = []
# 侧边栏参数设置
with st.sidebar:
# 显示参数设置的标题
st.markdown("### 参数设置")
# 单选框选择查询类型(默认选中第一个,即图书)
+ query_type = st.radio("选择查询类型", [QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR], index=0)
# 滑块选择返回结果数量(Top K),范围1-10,默认3
+ top_k = st.slider("返回结果数量 (Top K)", 1, 10, 3, 1)
# 在主界面居中显示系统标题
st.markdown(
"<h2 style='text-align: center; color: blue;'>图书知识图谱查询系统</h1>",
unsafe_allow_html=True,
)
# 按顺序显示历史消息
for message in st.session_state.messages:
st.chat_message(message["role"]).write(message["content"])
# 获取用户输入的问题(如果有输入)
if query := st.chat_input("输入图书相关问题", key="query_input"):
# 将用户输入添加到消息历史记录
st.session_state.messages.append({"role": "user", "content": query})
# 显示用户输入的消息
st.chat_message("user").write(query)
# 显示查询中loading动画
+ with st.spinner("正在查询中..."):
+ try:
# 获取问题的嵌入向量
+ query_embedding = embedding_service.get_embedding(query)
# 使用嵌入向量进行检索,获取结果
+ results = retrieval_service.query_by_embedding(query_embedding, query_type, top_k)
# 显示“查询结果”标题
+ st.markdown("#### 查询结果")
# 如果有返回结果,逐条显示
+ if results:
+ for idx, item in enumerate(results, 1):
# 用可展开板块显示每一条结果,默认展开第一个
+ with st.expander(f"结果 {idx}", expanded=True if idx == 1 else False):
# 按字典项逐个展示
+ for key, value in item.items():
+ st.write(f"**{key}**: {value}")
+ else:
# 若无结果,显示提示
+ st.info("未找到相关结果。")
+ except Exception as err:
# 捕获异常时的处理
+ error_msg = f"查询过程中出错: {str(err)}"
+ st.session_state.messages.append({
+ "role": "assistant",
+ "content": error_msg
+ })
+ st.chat_message("assistant").write(error_msg)
+ st.error(error_msg)
# 判断是否作为主程序运行
if __name__ == "__main__":
# 调用主函数启动应用
main()8. 提问 #
8.1. init.py #
llm/init.py
# 从当前包导入BaseLLM基类
from .base import BaseLLM
# 从当前包导入DeepSeekLLM类
from .deepseek import DeepSeekLLM
# 从当前包导入VolcengineLLM类
from .volcengine import VolcengineLLM
# 定义__all__变量,指定包导出的公有接口
__all__ = ["BaseLLM", "DeepSeekLLM", "VolcengineLLM"]8.2. base.py #
llm/base.py
# LLM抽象基类的文档字符串
"""LLM抽象基类"""
# 导入ABC和abstractmethod,用于创建抽象基类
from abc import ABC, abstractmethod
# 定义LLM的抽象基类,继承自ABC
class BaseLLM(ABC):
# 类的文档字符串,说明是LLM抽象基类
"""LLM抽象基类"""
# 抽象方法generate,子类必须实现
@abstractmethod
def generate(self, prompt: str, **kwargs) -> str:
# 生成回复的文档字符串
"""生成回复"""
# 抽象方法内使用pass,占位
pass
8.3. deepseek.py #
llm/deepseek.py
"""DeepSeek LLM实现"""
# 导入可选类型提示
from typing import Optional
# 导入 DeepSeek 的 ChatDeepSeek 类
from langchain_deepseek import ChatDeepSeek
# 导入自定义的基础LLM类
from .base import BaseLLM
# 导入全局配置对象
from config import config
# 导入默认温度和最大token常量
from constants import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS
# 定义 DeepSeekLLM 类,继承自 BaseLLM
class DeepSeekLLM(BaseLLM):
"""DeepSeek LLM"""
# 构造函数,支持自定义模型名、api_key、温度和最大输出tokens数量
def __init__(
self,
model_name: Optional[str] = None,
api_key: Optional[str] = None,
temperature: float = DEFAULT_TEMPERATURE,
max_tokens: int = DEFAULT_MAX_TOKENS,
):
# 设置模型名:优先用传入参数,否则用配置里的默认
self.model_name = model_name or config.deepseek.model
# 设置api_key:优先用传入参数,否则用配置里的默认
self.api_key = api_key or config.deepseek.api_key
# 设置温度:优先用传入参数,否则用默认
self.temperature = temperature or DEFAULT_TEMPERATURE
# 设置最大返回token数:优先用传入参数,否则用默认
self.max_tokens = max_tokens or DEFAULT_MAX_TOKENS
# 初始化 DeepSeek 的 LLM 实例
self.llm = ChatDeepSeek(
api_key=self.api_key,
model=self.model_name,
temperature=self.temperature,
max_tokens=self.max_tokens,
)
# 生成方法,实现对LLM的调用
def generate(self, prompt: str) -> str:
"""生成回复"""
try:
# 调用 DeepSeek LLM 获取生成内容
response = self.llm.invoke(prompt)
# 检查是否有返回内容
if not response or not response.content:
# 返回内容为空时抛出异常
raise Exception("API返回空响应")
# 正常时返回生成内容
return response.content
except Exception as e:
# 捕获并重新抛出异常,加上友好的报错说明
raise Exception(f"DeepSeek API调用失败: {e}") from e
8.4. volcengine.py #
llm/volcengine.py
"""火山引擎LLM实现"""
# 导入Optional类型用于参数类型标注
from typing import Optional
# 导入OpenAI库用于与火山引擎API交互
from openai import OpenAI
# 导入自定义的BaseLLM基类
from .base import BaseLLM
# 导入全局配置对象
from config import config
# 导入默认温度和最大tokens常量
from constants import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS
# 定义VolcengineLLM类,继承自BaseLLM
class VolcengineLLM(BaseLLM):
"""火山引擎LLM"""
# 构造函数,初始化各项参数
def __init__(self, model_name: Optional[str] = None, api_key: Optional[str] = None,
temperature: float = DEFAULT_TEMPERATURE, max_tokens: int = DEFAULT_MAX_TOKENS):
# 获取API基础URL
self.base_url = config.volcengine.base_url
# 获取模型名称,优先使用传入值,否则取配置中的默认值
self.model_name = model_name or config.volcengine.model
# 获取API密钥,优先使用传入值,否则取配置中的默认值
self.api_key = api_key or config.volcengine.api_key
# 设置温度参数,控制生成内容的多样性
self.temperature = temperature or DEFAULT_TEMPERATURE
# 设置最大token数量
self.max_tokens = max_tokens or DEFAULT_MAX_TOKENS
# 初始化OpenAI客户端,用于与火山引擎API交互
self.client = OpenAI(
base_url=self.base_url,
api_key=self.api_key,
)
# 定义generate方法用于生成模型回复
def generate(self, prompt: str) -> str:
"""生成回复"""
try:
# 向火山引擎API发送请求,生成回复
response = self.client.chat.completions.create(
model=self.model_name,
messages=[{"role": "user", "content": prompt}],
temperature=self.temperature,
max_tokens=self.max_tokens,
)
# 判断响应内容是否有效
if not response.choices or not response.choices[0].message.content:
# 如果返回内容为空,抛出异常
raise Exception("API返回空响应")
# 返回生成的回复内容
return response.choices[0].message.content
except Exception as e:
# 捕获异常并抛出自定义异常信息
raise Exception(f"Volcengine API调用失败: {e}") from e
8.5. init.py #
utils/init.py
"""工具函数模块"""
8.6. formatters.py #
utils/formatters.py
"""格式化工具"""
# 导入类型提示List, Dict, Any
from typing import List, Dict, Any
# 定义用于格式化查询结果上下文的函数
def format_context(query_type: str, results: List[Dict[str, Any]]) -> str:
"""格式化查询结果上下文"""
# 定义内部函数:格式化一本图书的方法
def format_book(result: Dict[str, Any]) -> str:
# 准备字段及相应的值,元组形式依次为(字段名, 字段值)
fields = [
("作者", result.get("作者")),
("出版社", result.get("出版社")),
("类别", result.get("类别")),
("出版年份", result.get("出版年份")),
("简介", result.get("简介")),
# 如果关键词存在,拼接成字符串,否则为None
("关键词", ", ".join(result["关键词"]) if result.get("关键词") else None),
]
# 对所有有值的字段拼接为多行字符串,格式为“- 字段名: 字段值”
return "\n".join([f" - {k}: {v}" for k, v in fields if v])
# 定义内部函数:格式化作者(只显示相关图书)
def format_author(result: Dict[str, Any]) -> str:
# 如果结果中有“相关图书”字段,则拼接显示
if result.get("相关图书"):
return f" - 相关图书: {', '.join(result['相关图书'])}"
# 否则返回空串
return ""
# 用于存放所有格式化后的每条结果内容
lines = []
# 遍历全部结果数据
for idx, result in enumerate(results, 1):
# 构建当前项的标题,包括序号、名称及相似度
header = f"{idx}. {result['name']} (相似度: {result['similarity']:.4f})"
# 根据查询类型决定调用哪个格式化方法
details = format_book(result) if query_type == "图书" else format_author(result)
# 若有详情,拼接标题、详情及换行;否则仅标题加换行
info = f"{header}\n{details}\n" if details else f"{header}\n"
# 将本次内容添加到lines列表
lines.append(info)
# 用两个换行符分隔拼接所有结果,返回最终字符串
return "\n\n".join(lines)
8.7. .env #
.env
NEO4J_URI="bolt://localhost:7687"
NEO4J_USER="neo4j"
NEO4J_PASSWORD="12345678"
VOLC_EMBEDDINGS_API_URL=https://ark.cn-beijing.volces.com/api/v3/embeddings
VOLC_EMBEDDING_MODEL=doubao-embedding-text-240715
VOLC_API_KEY=d52e49a1-36ea-44bb-bc6e-65ce789a72f6
+VOLCENGINE_BASE_URL=https://ark.cn-beijing.volces.com/api/v3
+VOLCENGINE_MODEL=doubao-seed-1-6-250615
+VOLCENGINE_API_KEY=d52e49a1-36ea-44bb-bc6e-65ce789a72f6
+DeepSeek_BASE_URL="https://api.deepseek.com/v1"
+DeepSeek_API_KEY="sk-278496d471bc4f4cb0ccb8c389a15018"
+DeepSeek_MODEL="deepseek-chat"8.8. config.py #
config.py
"""配置管理模块"""
# 导入os模块,用于获取环境变量
import os
# 导入dotenv模块,用于加载.env文件中的环境变量
import dotenv
# 从dataclasses模块导入dataclass,用于简化数据类定义
from dataclasses import dataclass
+from typing import Optional
# 加载.env文件中的所有环境变量
dotenv.load_dotenv()
# 定义Neo4j数据库配置的数据类
@dataclass
class Neo4jConfig:
"""Neo4j数据库配置"""
# 数据库URI
uri: str
# 用户名
user: str
# 密码
password: str
@dataclass
class EmbeddingConfig:
"""火山方舟嵌入API配置"""
api_url: str
api_key: str
model: str
+@dataclass
+class VolcengineConfig:
+ """火山引擎API配置"""
+ base_url: Optional[str]
+ api_key: Optional[str]
+ model: Optional[str]
+@dataclass
+class DeepSeekConfig:
+ """DeepSeek API配置"""
+ base_url: str
+ api_key: str
+ model: str
# 定义应用整体配置的数据类
@dataclass
class AppConfig:
"""应用配置"""
# Neo4j数据库配置属性
neo4j: Neo4jConfig
# 火山方舟嵌入API配置属性
embedding: EmbeddingConfig
# 火山引擎API配置属性
+ volcengine: VolcengineConfig
# DeepSeek API配置属性
+ deepseek: DeepSeekConfig
# 创建一个AppConfig对象,保存应用的整体配置
config = AppConfig(
# Neo4j数据库配置,使用Neo4jConfig类初始化
neo4j=Neo4jConfig(
# 从环境变量获取数据库URI,如果没有则使用默认值"bolt://localhost:7687"
uri=os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
# 从环境变量获取数据库用户名,如果没有则使用默认值"neo4j"
user=os.environ.get("NEO4J_USER", "neo4j"),
# 从环境变量获取数据库密码,如果没有则使用默认值"12345678"
password=os.environ.get("NEO4J_PASSWORD", "12345678"),
),
# 嵌入API配置,使用EmbeddingConfig类初始化
embedding=EmbeddingConfig(
# 从环境变量获取嵌入API的URL,没有则使用默认值
api_url=os.environ.get(
"VOLC_EMBEDDINGS_API_URL",
"https://ark.cn-beijing.volces.com/api/v3/embeddings"
),
# 从环境变量获取API密钥,如果没有则为空字符串
api_key=os.environ.get("VOLC_API_KEY", ""),
# 从环境变量获取嵌入模型名,没有则使用默认值"doubao-embedding-text-240715"
model=os.environ.get("VOLC_EMBEDDING_MODEL", "doubao-embedding-text-240715"),
),
# 初始化火山引擎API配置,从环境变量获取相关参数
volcengine=VolcengineConfig(
# 从环境变量获取火山引擎API的基础URL
base_url=os.environ.get("VOLCENGINE_BASE_URL","https://ark.cn-beijing.volces.com/api/v3"),
# 从环境变量获取火山引擎API的密钥
api_key=os.environ.get("VOLCENGINE_API_KEY","d52e49a1-36ea-44bb-bc6e-65ce789a72f6"),
# 从环境变量获取火山引擎API的模型名称
model=os.environ.get("VOLCENGINE_MODEL","doubao-seed-1-6-250615"),
),
# 初始化DeepSeek API配置,支持从环境变量读取参数,并提供默认值
deepseek=DeepSeekConfig(
# 从环境变量获取DeepSeek基础URL,若未设置则使用默认值"https://api.deepseek.com"
base_url=os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com"),
# 从环境变量获取DeepSeek API密钥,默认值为空字符串
api_key=os.environ.get("DEEPSEEK_API_KEY", "ae59c1b6731e4d5d8f1f3fd0ad340b39"),
# 从环境变量获取DeepSeek模型名,若未设置则使用默认值"deepseek-chat"
model=os.environ.get("DEEPSEEK_MODEL", "deepseek-chat"),
),
)8.9. constants.py #
constants.py
# 定义一个列表,包含CSV文件导入时必需的列名
REQUIRED_CSV_COLUMNS = [
# 书名
"name",
# 作者
"author",
# 出版社
"publisher",
# 类别
"category",
# 出版年份
"publish_year",
# 简介
"summary",
# 关键词(用分号分隔)
"keywords"
]
# 定义图书节点标签
NODE_LABEL_BOOK = "Book"
# 定义作者节点标签
NODE_LABEL_AUTHOR = "Author"
# 定义出版社节点标签
NODE_LABEL_PUBLISHER = "Publisher"
# 定义类别节点标签
NODE_LABEL_CATEGORY = "Category"
# 定义关键词节点标签
NODE_LABEL_KEYWORD = "Keyword"
# 定义“书籍-作者”关系类型
RELATIONSHIP_TYPE_WRITTEN_BY = "written_by"
# 定义“书籍-出版社”关系类型
RELATIONSHIP_TYPE_PUBLISHED_BY = "published_by"
# 定义“书籍-类别”关系类型
RELATIONSHIP_TYPE_HAS_CATEGORY = "has_category"
# 定义“书籍-关键词”关系类型
RELATIONSHIP_TYPE_HAS_KEYWORD = "has_keyword"
# 查询类型
# 图书查询类型
QUERY_TYPE_BOOK = "图书"
# 作者查询类型
QUERY_TYPE_AUTHOR = "作者"
# 默认Top K值
DEFAULT_TOP_K = 3
# 默认最大Tokens值
+DEFAULT_MAX_TOKENS = 4096
# 默认温度值
+DEFAULT_TEMPERATURE = 0.7
# 向量索引名称
# 图书向量索引名称
VECTOR_INDEX_BOOK = "book_embeddings"
# 作者向量索引名称
VECTOR_INDEX_AUTHOR = "author_embeddings"8.10. 1_问答系统.py #
pages/1_问答系统.py
# Streamlit图书知识图谱问答系统主应用
"""Streamlit图书知识图谱问答系统主应用"""
# 导入Streamlit库
import streamlit as st
# 导入PromptTemplate
+from langchain_core.prompts import PromptTemplate
# 导入配置文件
from config import config
# 导入嵌入向量服务
from services.embedding_service import embedding_service
# 导入检索服务
from services.retrieval_service import retrieval_service
# 导入查询类型常量
from constants import QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR
# 导入大模型服务
+from llm import VolcengineLLM, DeepSeekLLM
+from utils.formatters import format_context
+prompt_template = PromptTemplate(
+ input_variables=["question", "context"],
+ template="""你是一名图书知识助手,需要根据提供的图书信息回答用户的提问。
+ 请直接回答问题,如果信息不足,请回答"根据现有信息无法确定"。
+ 问题:{question}
+ 图书信息:\n{context}
+ 回答:""",
+)
# 定义主函数
def main():
"""主函数"""
# 设置页面配置(标题、布局)
st.set_page_config(
page_title="问答系统 - 图书知识图谱",
layout="wide"
)
# 如果session_state中还没有'messages',则初始化为空列表
if "messages" not in st.session_state:
st.session_state.messages = []
# 侧边栏参数设置
with st.sidebar:
# 显示参数设置的标题
st.markdown("### 参数设置")
# 单选框选择查询类型(默认选中第一个,即图书)
query_type = st.radio("选择查询类型", [QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR], index=0)
# 滑块选择返回结果数量(Top K),范围1-10,默认3
top_k = st.slider("返回结果数量 (Top K)", 1, 10, 3, 1)
# 滑块选择温度值,范围0.0-1.0,默认0.3
+ temperature = st.slider("温度 (Temperature)", 0.0, 1.0, 0.3, 0.1)
# 选择大模型服务商,默认火山引擎
+ llm_provider = st.selectbox("选择大模型服务商", ["volcengine", "deepseek"], index=0)
# 如果选择火山引擎,则显示火山引擎API Key和模型名
+ if llm_provider == "volcengine":
+ api_key = st.text_input(
+ "火山引擎 API Key",
+ value=config.volcengine.api_key or "",
+ type="password",
+ help="如留空则使用服务器默认配置",
+ )
+ model_name = st.text_input(
+ "火山引擎模型名",
+ value=config.volcengine.model or "",
+ help="如留空则使用服务器默认配置"
+ )
+ else:
+ api_key = st.text_input(
+ "DeepSeek API Key",
+ value=config.deepseek.api_key or "",
+ type="password",
+ help="如留空则使用服务器默认配置",
+ )
+ model_name = st.text_input(
+ "DeepSeek模型名",
+ value=config.deepseek.model or "",
+ help="如留空则使用服务器默认配置"
+ )
# 在主界面居中显示系统标题
st.markdown(
"<h2 style='text-align: center; color: blue;'>图书知识图谱查询系统</h1>",
unsafe_allow_html=True,
)
# 按顺序显示历史消息
for message in st.session_state.messages:
st.chat_message(message["role"]).write(message["content"])
+ if message["role"] == "assistant":
+ # 展开可查看详细检索结果
+ with st.expander("查看详细结果"):
+ # 以 json 格式显示检索的类型和结果
+ st.json({"type": message['query_type'], "results": message['results']})
# 获取用户输入的问题(如果有输入)
if query := st.chat_input("输入图书相关问题", key="query_input"):
# 将用户输入添加到消息历史记录
st.session_state.messages.append({"role": "user", "content": query})
# 显示用户输入的消息
st.chat_message("user").write(query)
# 显示查询中的 loading 动画
with st.spinner("正在查询中..."):
try:
# 获取用户输入 query 的嵌入向量
query_embedding = embedding_service.get_embedding(query)
# 用嵌入向量进行检索,查询对应的结果
results = retrieval_service.query_by_embedding(query_embedding, query_type, top_k)
# 如果没有检索到结果
+ if not results or len(results) == 0:
# 设置回复内容为没有找到相关信息
+ answer = "抱歉,没有找到相关的信息。"
else:
# 如果选择的服务商是火山引擎
+ if llm_provider == "volcengine":
# 实例化火山引擎大模型
+ llm = VolcengineLLM(
+ model_name=model_name if model_name else None,
+ api_key=api_key if api_key else None,
+ )
+ else:
# 否则实例化 DeepSeek 大模型
+ llm = DeepSeekLLM(
+ model_name=model_name if model_name else None,
+ api_key=api_key if api_key else None,
+ )
# 根据检索结果格式化上下文字符串
+ context_str = format_context(query_type, results)
# 打印上下文字符串,便于调试
+ print("context_str: ", context_str)
# 用模板格式化最终 prompt
+ final_prompt = prompt_template.format(
+ question=query,
+ context=context_str
+ )
# 打印最终 prompt,便于调试
+ print("final_prompt: ", final_prompt)
# 调用大模型生成答案
+ answer = llm.generate(final_prompt)
# 展开可查看详细检索结果
+ with st.expander("查看详细结果"):
# 以 json 格式显示检索的类型和结果
+ st.json({"type": query_type, "results": results})
# 把大模型回复添加到会话历史
+ st.session_state.messages.append({
+ "role": "assistant",
+ "content": answer,
+ "query_type": query_type,
+ "results": results # 检索结果
+ })
# 用 chat_message 组件显示大模型回复
+ st.chat_message("assistant").write(answer)
# 重新运行页面,用于强制刷新
+ st.rerun()
except Exception as err:
# 打印异常信息,便于调试
+ print(err)
# 捕获异常时组织错误提示内容
error_msg = f"查询过程中出错: {str(err)}"
# 将错误信息添加到会话历史
st.session_state.messages.append({
"role": "assistant",
"content": error_msg
+ "query_type": query_type, # 查询类型
+ "results": results # 检索结果
})
# 显示错误信息在对话界面
st.chat_message("assistant").write(error_msg)
# 在页面上弹出错误提示
st.error(error_msg)
# 判断是否作为主程序运行
if __name__ == "__main__":
# 调用主函数启动应用
main()9. 历史记录 #
9.1. 1_问答系统.py #
pages/1_问答系统.py
# Streamlit图书知识图谱问答系统主应用
"""Streamlit图书知识图谱问答系统主应用"""
# 导入Streamlit库
import streamlit as st
# 导入PromptTemplate
from langchain_core.prompts import PromptTemplate
# 导入配置文件
from config import config
# 导入嵌入向量服务
from services.embedding_service import embedding_service
# 导入检索服务
from services.retrieval_service import retrieval_service
# 导入查询类型常量
from constants import QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR
# 导入大模型服务
from llm import VolcengineLLM, DeepSeekLLM
from utils.formatters import format_context
prompt_template = PromptTemplate(
input_variables=["question", "context"],
template="""你是一名图书知识助手,需要根据提供的图书信息回答用户的提问。
请直接回答问题,如果信息不足,请回答"根据现有信息无法确定"。
问题:{question}
图书信息:\n{context}
回答:""",
)
# 定义主函数
def main():
"""主函数"""
# 设置页面配置(标题、布局)
st.set_page_config(
page_title="问答系统 - 图书知识图谱",
layout="wide"
)
# 如果session_state中还没有'messages',则初始化为空列表
if "messages" not in st.session_state:
st.session_state.messages = []
# 如果session_state中还没有'history',则初始化为空列表
+ if "history" not in st.session_state:
+ st.session_state.history = []
# 侧边栏参数设置
with st.sidebar:
# 显示参数设置的标题
st.markdown("### 参数设置")
# 单选框选择查询类型(默认选中第一个,即图书)
query_type = st.radio("选择查询类型", [QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR], index=0)
# 滑块选择返回结果数量(Top K),范围1-10,默认3
top_k = st.slider("返回结果数量 (Top K)", 1, 10, 3, 1)
# 滑块选择温度值,范围0.0-1.0,默认0.3
temperature = st.slider("温度 (Temperature)", 0.0, 1.0, 0.3, 0.1)
# 选择大模型服务商,默认火山引擎
llm_provider = st.selectbox("选择大模型服务商", ["volcengine", "deepseek"], index=0)
# 如果选择火山引擎,则显示火山引擎API Key和模型名
if llm_provider == "volcengine":
api_key = st.text_input(
"火山引擎 API Key",
value=config.volcengine.api_key or "",
type="password",
help="如留空则使用服务器默认配置",
)
model_name = st.text_input(
"火山引擎模型名",
value=config.volcengine.model or "",
help="如留空则使用服务器默认配置"
)
else:
api_key = st.text_input(
"DeepSeek API Key",
value=config.deepseek.api_key or "",
type="password",
help="如留空则使用服务器默认配置",
)
model_name = st.text_input(
"DeepSeek模型名",
value=config.deepseek.model or "",
help="如留空则使用服务器默认配置"
)
+ st.markdown("### 历史查询")
+ if st.session_state.history:
+ for i, item in enumerate(st.session_state.history):
+ with st.expander(f"查询 {i+1}: {item['question']}"):
+ st.json(item)
+ else:
+ st.info("暂无历史查询记录")
# 在主界面居中显示系统标题
st.markdown(
"<h2 style='text-align: center; color: blue;'>图书知识图谱查询系统</h1>",
unsafe_allow_html=True,
)
# 按顺序显示历史消息
for message in st.session_state.messages:
st.chat_message(message["role"]).write(message["content"])
# 获取用户输入的问题(如果有输入)
if query := st.chat_input("输入图书相关问题", key="query_input"):
# 将用户输入添加到消息历史记录
st.session_state.messages.append({"role": "user", "content": query})
# 显示用户输入的消息
st.chat_message("user").write(query)
# 显示查询中的 loading 动画
with st.spinner("正在查询中..."):
try:
# 获取用户输入 query 的嵌入向量
query_embedding = embedding_service.get_embedding(query)
# 用嵌入向量进行检索,查询对应的结果
results = retrieval_service.query_by_embedding(query_embedding, query_type, top_k)
# 如果没有检索到结果
if not results or len(results) == 0:
# 设置回复内容为没有找到相关信息
answer = "抱歉,没有找到相关的信息。"
else:
# 如果选择的服务商是火山引擎
if llm_provider == "volcengine":
# 实例化火山引擎大模型
llm = VolcengineLLM(
model_name=model_name if model_name else None,
api_key=api_key if api_key else None,
)
else:
# 否则实例化 DeepSeek 大模型
llm = DeepSeekLLM(
model_name=model_name if model_name else None,
api_key=api_key if api_key else None,
)
# 根据检索结果格式化上下文字符串
context_str = format_context(query_type, results)
# 打印上下文字符串,便于调试
print("context_str: ", context_str)
# 用模板格式化最终 prompt
final_prompt = prompt_template.format(
question=query,
context=context_str
)
# 打印最终 prompt,便于调试
print("final_prompt: ", final_prompt)
# 调用大模型生成答案
answer = llm.generate(final_prompt)
# 将本次对话历史添加到 session_state.history,也包含了用户问题、检索类型、上下文、模型回复及温度参数
+ st.session_state.history.append({
+ "question": query, # 用户输入的问题
+ "query_type": query_type, # 查询类型(如“图书”/“作者”)
+ "context": context_str, # 格式化后的上下文字符串
+ "answer": answer, # 大模型生成的答案
+ "temperature": temperature, # 当前使用的温度参数
+ })
# 展开可查看详细检索结果
with st.expander("查看详细结果"):
# 以 json 格式显示检索的类型和结果
st.json({"type": query_type, "results": results})
# 把大模型回复添加到会话历史
st.session_state.messages.append({
"role": "assistant",
"content": answer,
"query_type": query_type,
"results": results # 检索结果
})
# 用 chat_message 组件显示大模型回复
st.chat_message("assistant").write(answer)
# 重新运行页面,用于强制刷新
st.rerun()
except Exception as err:
# 打印异常信息,便于调试
print(err)
# 捕获异常时组织错误提示内容
error_msg = f"查询过程中出错: {str(err)}"
# 将错误信息添加到会话历史
st.session_state.messages.append({
"role": "assistant",
"content": error_msg
})
# 显示错误信息在对话界面
st.chat_message("assistant").write(error_msg)
# 在页面上弹出错误提示
st.error(error_msg)
# 判断是否作为主程序运行
if __name__ == "__main__":
# 调用主函数启动应用
main()
10. 项目整体架构 #
本项目是一个基于 Streamlit 多页面应用 的图书知识图谱问答系统,采用分层架构:
用户浏览器
↓
┌─────────────────────────────────────┐
│ Streamlit UI 层 (pages/) │
│ ├─ 首页.py (欢迎页) │
│ ├─ 0_数据导入.py (导入+向量化) │
│ └─ 1_问答系统.py (RAG问答) │
├─────────────────────────────────────┤
│ 服务层 (services/) │
│ ├─ ImportService (CSV导入) │
│ ├─ EmbeddingInitService (向量初始化) │
│ ├─ EmbeddingService (嵌入API) │
│ └─ RetrievalService (检索编排) │
├─────────────────────────────────────┤
│ LLM层 (llm/) │
│ ├─ BaseLLM (抽象基类) │
│ ├─ VolcengineLLM (火山引擎) │
│ └─ DeepSeekLLM (DeepSeek) │
├─────────────────────────────────────┤
│ 数据库层 (database/) │
│ ├─ GraphManager (连接管理) │
│ ├─ BookQuery (图书查询) │
│ └─ AuthorQuery (作者查询) │
├─────────────────────────────────────┤
│ 基础设施 │
│ ├─ config.py (配置管理) │
│ ├─ constants.py (常量定义) │
│ └─ exceptions.py (异常定义) │
└─────────────────────────────────────┘
↓ ↓
Neo4j 火山引擎API
图数据库 (Embedding/LLM)11 类图 #
classDiagram
direction TB
%% ========== 配置层 ==========
class AppConfig {
+neo4j: Neo4jConfig
+volcengine: VolcengineConfig
+deepseek: DeepSeekConfig
+embedding: EmbeddingConfig
}
class Neo4jConfig {
+uri: str
+user: str
+password: str
}
class VolcengineConfig {
+base_url: Optional~str~
+api_key: Optional~str~
+model: Optional~str~
}
class DeepSeekConfig {
+base_url: str
+api_key: str
+model: str
}
class EmbeddingConfig {
+api_url: str
+api_key: str
+model: str
}
AppConfig *-- Neo4jConfig
AppConfig *-- VolcengineConfig
AppConfig *-- DeepSeekConfig
AppConfig *-- EmbeddingConfig
%% ========== 数据库层 ==========
class GraphManager {
-_instance: GraphManager$
-_graph: Graph$
+graph: Graph
+__new__() GraphManager
+__init__()
}
class BookQuery {
-graph: Graph
+query_by_embedding(embedding, top_k) List~Dict~
}
class AuthorQuery {
-graph: Graph
+query_by_embedding(embedding, top_k) List~Dict~
}
GraphManager ..> AppConfig : 读取配置
BookQuery --> GraphManager : get_graph()
AuthorQuery --> GraphManager : get_graph()
%% ========== LLM层 ==========
class BaseLLM {
<<abstract>>
+generate(prompt, **kwargs)* str
}
class VolcengineLLM {
-model_name: str
-base_url: str
-api_key: str
-client: OpenAI
+generate(prompt, temperature) str
}
class DeepSeekLLM {
-model_name: str
-llm: ChatDeepSeek
+generate(prompt, temperature) str
}
BaseLLM <|-- VolcengineLLM
BaseLLM <|-- DeepSeekLLM
VolcengineLLM ..> AppConfig : 读取配置
DeepSeekLLM ..> AppConfig : 读取配置
%% ========== 服务层 ==========
class EmbeddingService {
-api_url: str
-api_key: str
-model: str
-max_retries: int
+get_embedding(text) List~float~
}
class RetrievalService {
-book_query: BookQuery
-author_query: AuthorQuery
-embedding_service: EmbeddingService
+query_by_embedding(embedding, type, top_k) List~Dict~
}
class ImportService {
-graph: Graph
-node_matcher: NodeMatcher
+parse_csv(content) DataFrame
+validate_csv(df) tuple
+clear_database()
+import_books(df, clear, callback) Dict
-_create_nodes(...) int
-_create_relationships(...) int
}
class EmbeddingInitService {
-graph: Graph
-node_matcher: NodeMatcher
-embedding_service: EmbeddingService
+init_embeddings(labels, callback) dict
}
EmbeddingService ..> AppConfig : 读取配置
RetrievalService --> BookQuery
RetrievalService --> AuthorQuery
RetrievalService --> EmbeddingService
ImportService --> GraphManager : get_graph()
EmbeddingInitService --> GraphManager : get_graph()
EmbeddingInitService --> EmbeddingService
%% ========== 异常 ==========
class ConfigurationError
class DatabaseError
class EmbeddingAPIError
class LLMAPIError
Exception <|-- ConfigurationError
Exception <|-- DatabaseError
Exception <|-- EmbeddingAPIError
Exception <|-- LLMAPIError
12 页面执行过程 #
12.1. 首页 #
"""Streamlit应用主入口"""
import streamlit as st
st.set_page_config(
page_title="图书知识图谱系统",
layout="wide",
initial_sidebar_state="expanded"
)
st.title("图书知识图谱系统")
st.markdown("---")
st.markdown("""
欢迎使用图书知识图谱系统!
**功能模块**
1. **问答系统** - 基于向量检索的智能问答
2. **数据导入** - 导入CSV数据并生成向量
请使用左侧导航栏访问各个功能模块。
""")执行过程很简单:
- Streamlit 启动时加载
首页.py作为主入口 set_page_config()设置浏览器标题和布局- 渲染标题和功能说明的 Markdown 文本
- Streamlit 自动扫描
pages/目录,在左侧栏生成导航链接(0_数据导入、1_问答系统)
12.2 数据导入页 #
这是最复杂的页面,包含两个功能标签页:CSV导入 和 向量初始化。
12.2.1 时序图:CSV 数据导入流程 #
sequenceDiagram
actor User as 用户
participant UI as 0_数据导入.py
participant IS as ImportService
participant GM as GraphManager
participant Neo4j as Neo4j数据库
User->>UI: 上传 CSV 文件
UI->>IS: parse_csv(csv_content)
IS-->>UI: 返回 DataFrame
UI->>IS: validate_csv(df)
IS-->>UI: (True, None) 或 (False, 错误信息)
alt 验证失败
UI-->>User: 显示错误 ❌
end
UI-->>User: 显示数据预览表格
User->>UI: 点击「开始导入」
UI->>IS: import_books(df, clear_existing, callback)
alt 清空现有数据
IS->>GM: get_graph()
GM-->>IS: Graph 实例
IS->>Neo4j: MATCH (n) DETACH DELETE n
end
Note over IS: 提取唯一值:books, authors,<br/>publishers, categories, keywords
Note over IS: 遍历 DataFrame<br/>准备关系边列表
loop 创建 Book 节点
IS->>Neo4j: graph.create(Node("Book", ...))
IS-->>UI: callback("创建Book: 三体", 1, N)
UI-->>User: 更新进度条
end
loop 创建 Author/Publisher/Category/Keyword 节点
IS->>Neo4j: graph.create(Node(label, name=...))
IS-->>UI: callback(...)
end
loop 创建关系(written_by, published_by, ...)
IS->>Neo4j: node_matcher.match("Book", name=...)
Neo4j-->>IS: 起始节点
IS->>Neo4j: node_matcher.match("Author", name=...)
Neo4j-->>IS: 目标节点
IS->>Neo4j: graph.create(Relationship(start, "written_by", end))
IS-->>UI: callback(...)
end
IS-->>UI: 返回 stats 统计
UI-->>User: 显示导入统计(图书/作者/关系数等)
关键执行步骤详解:
| 步骤 | 代码位置 | 说明 |
|---|---|---|
| ① 文件上传 | 0_数据导入.py:37-41 |
st.file_uploader 接收 CSV 文件 |
| ② 解析CSV | ImportService.parse_csv() |
用 pd.read_csv 解析为 DataFrame |
| ③ 验证格式 | ImportService.validate_csv() |
检查是否包含7个必需列 |
| ④ 预览 | 0_数据导入.py:58-59 |
st.dataframe 显示前10行 |
| ⑤ 导入 | ImportService.import_books() |
核心逻辑:先提取唯一值,再批量创建节点和关系 |
| ⑥ 节点创建 | _create_nodes() |
用 py2neo.Node + graph.create() 创建 |
| ⑦ 关系创建 | _create_relationships() |
用 NodeMatcher.match() 查找节点,Relationship() 创建关系 |
12.2.2 向量初始化流程 #
sequenceDiagram
actor User as 用户
participant UI as 0_数据导入.py (Tab2)
participant EIS as EmbeddingInitService
participant ES as EmbeddingService
participant GM as GraphManager
participant Neo4j as Neo4j数据库
participant API as 火山方舟 Embedding API
User->>UI: 选择节点类型 [Book, Author]
User->>UI: 点击「开始生成向量」
UI->>EIS: init_embeddings(["Book","Author"], callback)
loop 对每个标签(Book, Author)
EIS->>Neo4j: node_matcher.match("Book")
Neo4j-->>EIS: 返回所有 Book 节点列表
loop 对每个节点
EIS->>ES: get_embedding(node.name)
ES->>API: POST /embeddings {model, input: "三体"}
API-->>ES: {data: [{embedding: [0.1, 0.2, ...]}]}
ES-->>EIS: 返回向量 List[float]
EIS->>Neo4j: node["embedding"] = 向量
EIS->>Neo4j: graph.push(node)
EIS-->>UI: callback("✓ 已为 三体 设置嵌入向量", n, total)
UI-->>User: 更新进度条
end
end
EIS-->>UI: 返回 stats
UI-->>User: 显示统计(总节点/成功/失败)
关键点:
EmbeddingService.get_embedding()调用火山方舟的 REST API 获取文本的嵌入向量(2048维浮点数组)- 内置重试机制:服务器错误(5xx)或网络异常时最多重试3次
- 向量通过
graph.push(node)写回 Neo4j 节点的embedding属性
12.3. 问答页 #
这是系统的核心功能页面,实现了 RAG(检索增强生成) 流程。
12.3.1 问答流程 #
sequenceDiagram
actor User as 用户
participant UI as 1_问答系统.py
participant RS as RetrievalService
participant ES as EmbeddingService
participant BQ as BookQuery
participant Neo4j as Neo4j数据库
participant FMT as format_context()
participant LLM as VolcengineLLM / DeepSeekLLM
participant API as LLM API
Note over UI: 侧边栏:设置查询类型、Top K、<br/>温度、模型服务商、API Key
User->>UI: 输入问题:"活着的作者是谁?"
UI->>UI: session_state.messages.append(user msg)
rect rgb(240, 248, 255)
Note over UI,API: RAG 核心流程
%% Step 1: 生成问题的嵌入向量
UI->>ES: get_embedding("活着的作者是谁?")
ES->>API: POST 火山方舟 /embeddings
API-->>ES: 返回 query_embedding [0.12, 0.34, ...]
ES-->>UI: query_embedding
%% Step 2: 创建 LLM 实例
UI->>LLM: new VolcengineLLM(api_key, model_name)
LLM->>LLM: 读取 AppConfig, 创建 OpenAI client
%% Step 3: 向量检索
UI->>RS: query_by_embedding(embedding, "图书", top_k=3)
RS->>BQ: query_by_embedding(embedding, 3)
BQ->>Neo4j: CALL db.index.vector.queryNodes(...)<br/>+ OPTIONAL MATCH 作者/出版社/类别/关键词
Neo4j-->>BQ: 返回匹配的节点 + 相似度分数
BQ-->>RS: [{name:"活着", similarity:0.93, 作者:"余华", ...}, ...]
RS-->>UI: results
%% Step 4: 格式化上下文
UI->>FMT: format_context("图书", results)
FMT-->>UI: "1. 活着 (相似度: 0.9342)\n - 作者: 余华\n..."
%% Step 5: 构造 Prompt
UI->>UI: prompt_template.format(question, context)
Note over UI: "你是一名图书知识助手...<br/>问题:活着的作者是谁?<br/>图书信息:1. 活着...<br/>回答:"
%% Step 6: 调用 LLM 生成回答
UI->>LLM: generate(final_prompt, temperature=0.3)
LLM->>API: POST chat/completions {model, messages, temperature}
API-->>LLM: "活着的作者是余华。"
LLM-->>UI: answer
end
UI->>UI: session_state.messages.append(assistant msg)
UI->>UI: session_state.history.append(记录)
UI-->>User: 显示回答 + 可展开的详细结果
UI->>UI: st.rerun() 刷新页面
RAG 核心 6 步详解:
| 步骤 | 作用 | 代码位置 |
|---|---|---|
| ① 问题向量化 | 将用户问题转为嵌入向量 | get_embedding(query) → EmbeddingService |
| ② 创建 LLM | 根据用户选择的服务商创建对应 LLM 实例 | VolcengineLLM(...) 或 DeepSeekLLM(...) |
| ③ 向量检索 | 用问题向量在 Neo4j 中搜索最相似的图书/作者 | RetrievalService → BookQuery/AuthorQuery |
| ④ 格式化上下文 | 将检索结果格式化为 LLM 可读的文本 | format_context() |
| ⑤ 构造 Prompt | 用 LangChain PromptTemplate 拼接最终提示词 | prompt_template.format(...) |
| ⑥ LLM 生成 | 调用 LLM API 生成最终回答 | llm.generate(final_prompt) |
13. 数据流总结图 #
graph LR
subgraph RAG问答流程
J[用户问题] -->|get_embedding| H2[火山方舟API]
H2 -->|query_embedding| K[向量检索]
K -->|Cypher| E2[(Neo4j)]
E2 -->|相似结果| L[format_context]
L -->|上下文| M[PromptTemplate]
J -->|问题| M
M -->|final_prompt| N[LLM API]
N -->|回答| O[显示给用户]
end
subgraph 数据导入流程
A[CSV文件] -->|parse_csv| B[DataFrame]
B -->|validate_csv| C{验证}
C -->|通过| D[创建节点]
D -->|py2neo| E[(Neo4j)]
D --> F[创建关系]
F -->|py2neo| E
E -->|init_embeddings| G[读取节点]
G -->|get_embedding| H[火山方舟API]
H -->|向量| I[写回节点]
I -->|graph.push| E
end
整个系统的核心思路就是:先导入数据建图 → 再向量化 → 最后用 RAG 检索+生成 回答用户问题。