导航菜单

  • 1.langchain.intro
  • 2.langchain.chat_models
  • 3.langchain.prompts
  • 4.langchain.example_selectors
  • 5.output_parsers
  • 6.Runnable
  • 7.memory
  • 8.document_loaders
  • 9.text_splitters
  • 10.embeddings
  • 11.tool
  • 12.retrievers
  • 13.optimize
  • 14.项目介绍
  • 15.启动HTTP
  • 16.数据与模型
  • 17.权限管理
  • 18.知识库管理
  • 19.设置
  • 20.文档管理
  • 21.聊天
  • 22.API文档
  • 23.RAG优化
  • 24.索引时优化
  • 25.检索前优化
  • 26.检索后优化
  • 27.系统优化
  • 28.GraphRAG
  • 29.图
  • 30.为什么选择图数据库
  • 31.什么是 Neo4j
  • 32.安装和连接 Neo4j
  • 33.Neo4j核心概念
  • 34.Cypher基础
  • 35.模式匹配
  • 36.数据CRUD操作
  • 37.GraphRAG
  • 38.查询和过滤
  • 39.结果处理和聚合
  • 40.语句组合
  • 41.子查询
  • 42.模式和约束
  • 43.日期时间处理
  • 44.Cypher内置函数
  • 45.Python操作Neo4j
  • 46.neo4j
  • 47.py2neo
  • 48.Streamlit
  • 49.Pandas
  • 50.graphRAG
  • 51.deepdoc
  • 52.deepdoc
  • 53.deepdoc
  • 55.deepdoc
  • 54.deepdoc
  • Pillow
  • 1. 环境准备
    • 1.1 启动Neo4j
    • 1.2 创建项目目录
  • 2. 数据导入
    • 2.1. 0_数据导入.py
  • 3. 读取csv文件
    • 3.1. constants.py
    • 3.2. import_service.py
    • 3.3. 0_数据导入.py
  • 4. 数据导入
    • 4.1. .env
    • 4.2. config.py
    • 4.3. init.py
    • 4.4. connection.py
    • 4.5. constants.py
    • 4.6. 0_数据导入.py
    • 4.8. import_service.py
  • 5. 向量嵌入
    • 5.1. embedding_service.py
    • 5.2. .env
    • 5.3. config.py
    • 5.4. connection.py
    • 5.5. 0_数据导入.py
  • 6. 提问
    • 6.1. 1_问答系统.py
  • 7. 向量化查询
    • 7.1. queries.py
    • 7.2. retrieval_service.py
    • 7.3. constants.py
    • 7.4. 1_问答系统.py
  • 8. 提问
    • 8.1. init.py
    • 8.2. base.py
    • 8.3. deepseek.py
    • 8.4. volcengine.py
    • 8.5. init.py
    • 8.6. formatters.py
    • 8.7. .env
    • 8.8. config.py
    • 8.9. constants.py
    • 8.10. 1_问答系统.py
  • 9. 历史记录
    • 9.1. 1_问答系统.py

1. 环境准备 #

1.1 启动Neo4j #

  • 下载地址:https://neo4j.com/download/
  • 或使用 Docker:
    docker run -d -p 7474:7474 -p 7687:7687 -e NEO4J_AUTH=neo4j/12345678 neo4j:latest

1.2 创建项目目录 #

mkdir graphrag
cd graphrag
uv init
uv add  dotenv langchain langchain-deepseek  pandas py2neo streamlit requests  openai 

2. 数据导入 #

2.1. 0_数据导入.py #

pages/0_数据导入.py

# 数据导入页面模块注释
"""数据导入页面"""

# 导入streamlit库并简写为st
import streamlit as st

# 定义主函数
def main():
    # 数据导入页面主函数的文档字符串
    """数据导入页面主函数"""
    # 配置Streamlit页面,设置标题和布局
    st.set_page_config(
        page_title="数据导入 - 图书知识图谱",
        layout="wide"
    )

    # 显示页面主标题
    st.title("数据导入管理")
    # 插入分割线
    st.markdown("---")

# 判断是否为主程序入口
if __name__ == "__main__":
    # 调用主函数
    main()    

3. 读取csv文件 #

3.1. constants.py #

constants.py

# 定义一个列表,包含CSV文件导入时必需的列名
REQUIRED_CSV_COLUMNS = [
    # 书名
    "name", 
    # 作者
    "author", 
    # 出版社
    "publisher", 
    # 类别
    "category",
    # 出版年份
    "publish_year", 
    # 简介
    "summary", 
    # 关键词(用分号分隔)
    "keywords"
]

3.2. import_service.py #

services/import_service.py

"""数据导入服务"""

# 导入io模块用于处理内存中的流数据
import io
# 导入pandas库用于数据分析处理
import pandas as pd
# 导入类型提示
from typing import List, Dict, Any, Optional, Callable
# 从constants模块导入必需的CSV列名常量
from constants import (REQUIRED_CSV_COLUMNS)

# 定义数据导入服务类(单例模式)
class ImportService:
    """数据导入服务(单例模式)"""

    # 类变量:存储单例实例
    _instance = None

    # 重写__new__方法实现单例模式
    def __new__(cls):
        """创建单例实例"""
        # 如果实例不存在,则创建新实例
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        # 返回单例实例
        return cls._instance

    # 解析CSV内容为DataFrame的方法
    def parse_csv(self, csv_content: str) -> pd.DataFrame:
        """解析CSV内容"""
        # 使用pandas读取内存中的CSV内容并返回DataFrame
        return pd.read_csv(io.StringIO(csv_content))

    # 验证DataFrame格式的方法
    def validate_csv(self, df: pd.DataFrame) -> tuple[bool, Optional[str]]:
        """验证CSV格式"""
        # 如果DataFrame为空,返回False和错误信息
        if df.empty:
            return False, "CSV文件为空"
        # 检查是否有缺失的必需字段
        missing = [col for col in REQUIRED_CSV_COLUMNS if col not in df.columns]
        # 如果没有缺失则返回True,否则返回缺失列信息
        return (True, None) if not missing else (False, f"缺少列: {', '.join(missing)}")


# 获取导入服务实例的便捷方法
def get_import_service() -> ImportService:
    """获取导入服务单例实例"""
    # 返回单例实例(无论调用多少次都返回同一个实例)
    return ImportService()

3.3. 0_数据导入.py #

pages/0_数据导入.py

# 数据导入页面模块注释
"""数据导入页面"""

# 导入streamlit库并简写为st
import streamlit as st
+from services.import_service import get_import_service
# 定义主函数
def main():
    # 数据导入页面主函数的文档字符串
    """数据导入页面主函数"""
    # 配置Streamlit页面,设置标题和布局
    st.set_page_config(
        page_title="数据导入 - 图书知识图谱",
        layout="wide"
    )

    # 显示页面主标题
    st.title("数据导入管理")
    # 插入分割线
    st.markdown("---")
     # 获取服务实例
+   import_service = get_import_service()

    # 创建一个标签页,只包含“CSV数据导入”标签,并将其赋值给tab1变量
+   tab1, = st.tabs(["CSV数据导入"])

    # 在tab1标签页下进行后续操作
+   with tab1:
        # 显示“CSV数据导入”二级标题
+       st.header("CSV数据导入")
        # 展示有关上传CSV文件的说明及字段要求
+       st.markdown("""
+       请上传包含图书信息的CSV文件。CSV文件应包含以下列:
+       - `name`: 书名
+       - `author`: 作者
+       - `publisher`: 出版社
+       - `category`: 类别
+       - `publish_year`: 出版年份
+       - `summary`: 简介
+       - `keywords`: 关键词(用分号分隔)
+       """)

        # 显示文件上传控件,限定仅可上传csv类型文件,并显示帮助提示
+       uploaded_file = st.file_uploader(
+           "选择CSV文件",
+           type=["csv"],
+           help="上传包含图书信息的CSV文件"
+       )
        # 如果用户已经上传了文件
+       if uploaded_file is not None:
+           try:
                # 读取上传的文件内容,并解码为utf-8格式的字符串
+               csv_content = uploaded_file.read().decode("utf-8")
                # 调用服务将CSV内容解析为DataFrame
+               df = import_service.parse_csv(csv_content)

                # 验证上传的DataFrame是否符合CSV格式要求
+               is_valid, error_msg = import_service.validate_csv(df)

                # 如果格式校验未通过,显示错误
+               if not is_valid:
+                   st.error(f"CSV格式错误: {error_msg}")
                # 如果校验通过,显示成功消息,并告知记录数量
+               else:
+                   st.success(f"CSV文件验证通过!共 {len(df)} 条记录")
            # 捕获任何异常,显示读取文件失败的错误信息
+           except Exception as e:
+               st.error(f"读取文件失败: {str(e)}")



# 判断是否为主程序入口
if __name__ == "__main__":
    # 调用主函数
    main()    

4. 数据导入 #

4.1. .env #

.env

NEO4J_URI="bolt://localhost:7687"
NEO4J_USER="neo4j"
NEO4J_PASSWORD="12345678"

4.2. config.py #

config.py

"""配置管理模块"""
# 导入os模块,用于获取环境变量
import os
# 导入dotenv模块,用于加载.env文件中的环境变量
import dotenv
# 从dataclasses模块导入dataclass,用于简化数据类定义
from dataclasses import dataclass

# 加载.env文件中的所有环境变量
dotenv.load_dotenv()

# 定义Neo4j数据库配置的数据类
@dataclass
class Neo4jConfig:
    """Neo4j数据库配置"""
    # 数据库URI
    uri: str
    # 用户名
    user: str
    # 密码
    password: str

# 定义应用整体配置的数据类
@dataclass
class AppConfig:
    """应用配置"""
    # Neo4j数据库配置属性
    neo4j: Neo4jConfig

config=AppConfig(
        neo4j=Neo4jConfig(
            uri=os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
            user=os.environ.get("NEO4J_USER", "neo4j"),
            password=os.environ.get("NEO4J_PASSWORD", "12345678"),
        )
)

4.3. init.py #

database/init.py

"""数据库模块"""

4.4. connection.py #

database/connection.py

# 这是Neo4j数据库连接管理的模块说明
"""Neo4j数据库连接管理"""

# 从py2neo库导入Graph类,用于连接和操作Neo4j数据库
from py2neo import Graph

# 从config模块导入config对象,用于获取配置信息
from config import config

# 使用配置信息创建Neo4j数据库连接实例
graph = Graph(config.neo4j.uri, auth=(config.neo4j.user, config.neo4j.password))

4.5. constants.py #

constants.py

# 定义一个列表,包含CSV文件导入时必需的列名
REQUIRED_CSV_COLUMNS = [
    # 书名
    "name", 
    # 作者
    "author", 
    # 出版社
    "publisher", 
    # 类别
    "category",
    # 出版年份
    "publish_year", 
    # 简介
    "summary", 
    # 关键词(用分号分隔)
    "keywords"
]

# 定义图书节点标签
+NODE_LABEL_BOOK = "Book"
# 定义作者节点标签
+NODE_LABEL_AUTHOR = "Author"
# 定义出版社节点标签
+NODE_LABEL_PUBLISHER = "Publisher"
# 定义类别节点标签
+NODE_LABEL_CATEGORY = "Category"
# 定义关键词节点标签
+NODE_LABEL_KEYWORD = "Keyword"

# 定义“书籍-作者”关系类型
+RELATIONSHIP_TYPE_WRITTEN_BY = "written_by"
# 定义“书籍-出版社”关系类型
+RELATIONSHIP_TYPE_PUBLISHED_BY = "published_by"
# 定义“书籍-类别”关系类型
+RELATIONSHIP_TYPE_HAS_CATEGORY = "has_category"
# 定义“书籍-关键词”关系类型
+RELATIONSHIP_TYPE_HAS_KEYWORD = "has_keyword"

4.6. 0_数据导入.py #

pages/0_数据导入.py

# 数据导入页面模块注释
"""数据导入页面"""

# 导入streamlit库并简写为st
import streamlit as st
+from services.import_service import import_service
# 定义主函数
def main():
    # 数据导入页面主函数的文档字符串
    """数据导入页面主函数"""
    # 配置Streamlit页面,设置标题和布局
    st.set_page_config(
        page_title="数据导入 - 图书知识图谱",
        layout="wide"
    )

    # 显示页面主标题
    st.title("数据导入管理")
    # 插入分割线
    st.markdown("---")
    # 创建一个标签页,只包含“CSV数据导入”标签,并将其赋值给tab1变量
    tab1, = st.tabs(["CSV数据导入"])

    # 在tab1标签页下进行后续操作
    with tab1:
        # 显示“CSV数据导入”二级标题
        st.header("CSV数据导入")
        # 展示有关上传CSV文件的说明及字段要求
        st.markdown("""
        请上传包含图书信息的CSV文件。CSV文件应包含以下列:
        - `name`: 书名
        - `author`: 作者
        - `publisher`: 出版社
        - `category`: 类别
        - `publish_year`: 出版年份
        - `summary`: 简介
        - `keywords`: 关键词(用分号分隔)
        """)

        # 显示文件上传控件,限定仅可上传csv类型文件,并显示帮助提示
        uploaded_file = st.file_uploader(
            "选择CSV文件",
            type=["csv"],
            help="上传包含图书信息的CSV文件"
        )
        # 如果用户已经上传了文件
        if uploaded_file is not None:
            try:
                # 读取上传的文件内容,并解码为utf-8格式的字符串
                csv_content = uploaded_file.read().decode("utf-8")
                # 调用服务将CSV内容解析为DataFrame
                df = import_service.parse_csv(csv_content)

                # 验证上传的DataFrame是否符合CSV格式要求
                is_valid, error_msg = import_service.validate_csv(df)

                # 如果格式校验未通过,显示错误
                if not is_valid:
                    st.error(f"CSV格式错误: {error_msg}")
                # 如果校验通过,显示成功消息,并告知记录数量
                else:
                    st.success(f"CSV文件验证通过!共 {len(df)} 条记录")
                    # 显示预览
+                   with st.expander("数据预览", expanded=True):
+                       st.dataframe(df)
                    # 导入选项
+                   clear_existing = st.checkbox(
+                       "清空现有数据",
+                       value=False,
+                       help="导入前清空数据库中的所有数据"
+                   )
                    # 当用户点击"开始导入"按钮时执行
+                   if st.button("开始导入", type="primary", width='stretch'):
                        # 创建一个显示日志信息的空容器
+                       log_container = st.empty()
                        # 用于存放日志消息的列表
+                       log_messages = []

                        # 定义进度回调函数
+                       def progress_callback(message, current, total):
                            # 添加当前消息到日志列表(包含进度信息)
+                           log_messages.append(f"- [{current}/{total}] {message}")
                            # 刷新日志容器内容
+                           log_container.text("\n".join(log_messages))

+                       try:
                            # 调用导入服务执行数据导入
+                           stats = import_service.import_books(
+                               df,#数据源
+                               clear_existing=clear_existing,# 是否清空现有数据
+                               progress_callback=progress_callback# 进度回调函数
+                           )
                            # 显示最终统计信息
+                           if stats:
+                               st.info(
+                                   f"完成: 图书 {stats.get('books', 0)} | "
+                                   f"作者 {stats.get('authors', 0)} | "
+                                   f"出版社 {stats.get('publishers', 0)} | "
+                                   f"类别 {stats.get('categories', 0)} | "
+                                   f"关键词 {stats.get('keywords', 0)} | "
+                                   f"关系 {stats.get('relationships', 0)}"
+                               )
                                # 如果有错误信息,显示出来
+                               if stats.get("errors") and len(stats["errors"]) > 0:
+                                   with st.expander("⚠️ 导入过程中的错误", expanded=False):
+                                       for error in stats["errors"]:
+                                           st.error(error)
+                       except Exception as e:
                            # 导入过程中如有异常则显示错误信息
+                           st.error(f"导入失败: {str(e)}")
                            # 导入traceback模块以显示详细错误
+                           import traceback
                            # 展开详细错误信息
+                           with st.expander("详细错误信息"):
                                # 格式化并显示完整的异常调用栈
+                               st.code(traceback.format_exc())    



            # 捕获任何异常,显示读取文件失败的错误信息
            except Exception as e:
                st.error(f"读取文件失败: {str(e)}")



# 判断是否为主程序入口
if __name__ == "__main__":
    # 调用主函数
    main()    

4.8. import_service.py #

services/import_service.py

# 数据导入服务说明
"""数据导入服务"""

# 导入io模块,用于内存数据流操作
import io
# 导入pandas进行数据处理
import pandas as pd
# 从typing模块导入类型提示
from typing import List, Dict, Any, Optional, Callable
# 从database.connection导入数据库实例graph
+from database.connection import graph
# 从constants模块导入各类常量
+from constants import (
+   REQUIRED_CSV_COLUMNS, # 必需的CSV列名
+   NODE_LABEL_BOOK, # 图书节点标签
+   NODE_LABEL_AUTHOR, # 作者节点标签
+   NODE_LABEL_PUBLISHER, # 出版社节点标签
+   NODE_LABEL_CATEGORY, # 类别节点标签
+   NODE_LABEL_KEYWORD, # 关键词节点标签
+   RELATIONSHIP_TYPE_WRITTEN_BY, # 书-作者关系类型
+   RELATIONSHIP_TYPE_PUBLISHED_BY, # 书-出版社关系类型
+   RELATIONSHIP_TYPE_HAS_CATEGORY, # 书-类别关系类型
+   RELATIONSHIP_TYPE_HAS_KEYWORD # 书-关键词关系类型
+)
# 导入py2neo中的Node、Relationship、NodeMatcher
+from py2neo import Node, Relationship, NodeMatcher
# 定义数据导入服务类,采用单例模式
class ImportService:
    # 类说明:数据导入服务(单例)
    """数据导入服务(单例模式)"""

    # 初始化方法
+   def __init__(self):
        # 创建NodeMatcher对象用于节点查找
+       self.node_matcher = NodeMatcher(graph)

    # 解析CSV内容的函数
    def parse_csv(self, csv_content: str) -> pd.DataFrame:
        # 方法说明:解析CSV内容为DataFrame
        """解析CSV内容"""
        # 利用pandas读取内存中的CSV字符串,返回DataFrame
        return pd.read_csv(io.StringIO(csv_content))

    # 验证CSV格式合法性的方法
    def validate_csv(self, df: pd.DataFrame) -> tuple[bool, Optional[str]]:
        # 方法说明:验证CSV格式
        """验证CSV格式"""
        # 判断DataFrame是否为空,为空则返回错误消息
        if df.empty:
            return False, "CSV文件为空"
        # 检查必需CSV列是否全部存在
        missing = [col for col in REQUIRED_CSV_COLUMNS if col not in df.columns]
        # 若无漏列返回True,否则返回缺失列名
        return (True, None) if not missing else (False, f"缺少列: {', '.join(missing)}")

    # 清空全部数据库的方法
+   def clear_database(self) -> None:
        # 方法说明:清空数据库,删除所有节点和关系
+       """清空数据库"""
        # 执行Cypher语句,删除所有节点及其关系
+       graph.run("MATCH (n) DETACH DELETE n")

    # 内部方法:创建节点
+   def _create_nodes(
+       self, label, node_list, stats, stat_key,
+       progress_callback, current, total
+   ) -> int:
        # 方法说明:创建指定类型的节点并更新进度
+       """创建节点

+       Args:
+           label: 节点标签
+           node_list: 节点列表,每项为字典
+           stats: 统计计数与错误信息字典
+           stat_key: 统计键名
+           progress_callback: 进度回调函数
+           current: 当前进度值
+           total: 总进度步数
+       """
        # 有进度回调时,通知正在创建节点
+       if progress_callback:
+           progress_callback(f"正在创建{label}节点...", current, total)

        # 初始化进度步数
+       step = current
        # 遍历所有节点数据逐一创建
+       for node in node_list:
+           try:
                # 取节点名称用于反馈与出错描述,默认“未知”
+               node_name = node.get("name", "未知")
                # 构建节点对象
+               node = Node(label, **node)
                # 写入数据库
+               graph.create(node)
                # 节点计数器+1
+               stats[stat_key] += 1
                # 进度步数+1
+               step += 1
                # 回调进度反馈
+               if progress_callback:
+                   progress_callback(f"创建{label}: {node_name}", step, total)
+           except Exception as e:
                # 异常情况将出错信息记录到stats["errors"]
+               node_name = node.get("name", "未知")
+               stats["errors"].append(f"创建{label} {node_name} 失败: {str(e)}")
+               step += 1
        # 返回当前进度步数
+       return step

    # 内部方法:创建节点间关系
+   def _create_relationships(
+       self, start_label, end_label, edges, rel_type,
+       progress_callback, current, total, stats=None
+   ) -> int:
+       """
+       创建关系
+       Args:
+           start_label: 起始节点标签
+           end_label: 终止节点标签
+           edges: 关系对列表,每项为[起始节点名称, 终止节点名称]
+           rel_type: 关系类型
+           progress_callback: 进度回调函数
+           current: 当前进度值
+           total: 总进度步数
+           stats: 统计计数与错误信息字典
+       """
        # 初始化关系计数
+       count = 0
        # 初始化进度值
+       step = current
        # 遍历所有关系对
+       for edge in edges:
+           try:
                # 查找起始节点对象
+               start_node = self.node_matcher.match(start_label, name=str(edge[0])).first()
                # 查找终止节点对象
+               end_node = self.node_matcher.match(end_label, name=str(edge[1])).first()
                # 两端节点均存在后创建关系
+               if start_node and end_node:
+                   rel = Relationship(start_node, rel_type, end_node)#创建关系对象
+                   graph.create(rel)#写入数据库
+                   count += 1#关系计数器+1
+                   step += 1#进度步数+1
                    # 调用进度回调反馈
+                   if progress_callback:# 有进度回调时,通知正在创建关系
+                       progress_callback(f"创建关系: {edge[0]}-{rel_type}->{edge[1]}", step, total)
+               else:
                    # 如有节点不存在,将详细信息加到stats["errors"]
+                   if stats:
+                       missing = []#缺失节点列表
+                       if not start_node:
+                           missing.append(f"{start_label}:{edge[0]}")#添加起始节点
+                       if not end_node:
+                           missing.append(f"{end_label}:{edge[1]}")#添加终止节点
+                       stats["errors"].append(f"关系创建失败: 找不到节点 {', '.join(missing)}")#添加错误信息
+                   step += 1#进度步数+1
+           except Exception as e:
                # 其他异常也写入stats["errors"]
+               if stats:# 有统计信息时,添加错误信息
+                   stats["errors"].append(f"创建关系 {edge[0]}-{rel_type}->{edge[1]} 时出错: {str(e)}")
+               step += 1
        # 返回成功创建的关系数
+       return count    

    # 图书数据导入的主方法
+   def import_books(
+       self, df: pd.DataFrame, clear_existing: bool = False,
+       progress_callback: Optional[Callable] = None
+   ) -> Dict[str, Any]:
        # 方法说明:将DataFrame批量导入为图数据库节点与关系
+       """导入图书数据"""
        # 初始化统计信息字典
+       stats = {
+           "books": 0, "authors": 0, "publishers": 0,
+           "categories": 0, "keywords": 0, "relationships": 0, "errors": []
+       }

+       try:
            # 是否需要清空库
+           if clear_existing:
                # 有进度回调时先通知
+               if progress_callback:
+                   progress_callback("正在清空数据库...", 0, 0)
                # 调用清库
+               self.clear_database()

            # 生成图书节点数据列表
+           books = []
+           for _, row in df.iterrows():
+               if pd.notna(row["name"]):
                    # 创建基本字典(书名)
+                   book = {"name": row["name"]}
                    # 补充出版年份
+                   if pd.notna(row.get("publish_year")):
+                       book["publish_year"] = int(row["publish_year"])
                    # 补充简介
+                   if pd.notna(row.get("summary")):
+                       book["summary"] = str(row["summary"])
                    # 处理关键字列表
+                   if pd.notna(row.get("keywords")):
+                       book["keywords"] = [kw.strip() for kw in str(row["keywords"]).split(";") if kw.strip()]
+                   books.append(book)

            # authors节点只创建唯一且非空的作者名
+           authors = [{"name": name} for name in set(df["author"].dropna())]
            # publishers节点只创建唯一且非空的出版社名
+           publishers = [{"name": name} for name in set(df["publisher"].dropna())]
            # categories节点只创建唯一且非空的类别名
+           categories = [{"name": name} for name in set(df["category"].dropna())]

            # 提取所有keywords列的分号分割项,并去重空
+           keywords_set = set()
            # 遍历所有关键词,并去重空格
+           for keyword in df["keywords"].dropna():
                # 将关键词按分号分割,并去重空格
+               keywords_set.update([kw.strip() for kw in keyword.split(";") if kw.strip()])
            # 将关键词列表转换为字典列表
+           keywords = [{"name": keyword.strip()} for keyword in keywords_set if keyword.strip()]

            # 构建关系对列表
+           rels_written_by = []       # 书-作者
+           rels_published_by = []     # 书-出版社
+           rels_has_category = []     # 书-类别
+           rels_has_keyword = []      # 书-关键词

            # 遍历每行收集上述所有关系
+           for _, row in df.iterrows():
                # 书-作者
+               if pd.notna(row["name"]) and pd.notna(row["author"]):
+                   rels_written_by.append([row["name"], row["author"]])
                # 书-出版社
+               if pd.notna(row["name"]) and pd.notna(row["publisher"]):
+                   rels_published_by.append([row["name"], row["publisher"]])
                # 书-类别
+               if pd.notna(row["name"]) and pd.notna(row["category"]):
+                   rels_has_category.append([row["name"], row["category"]])
                # 书-关键词(支持多关键词)
+               if pd.notna(row["name"]) and pd.notna(row["keywords"]):
+                   for kw in row["keywords"].split(";"):
+                       rels_has_keyword.append([row["name"], kw.strip()])

            # 统计全部节点和关系数,便于进度反馈
+           total_steps = (
+               len(books) + len(authors) + len(publishers) +
+               len(categories) + len(keywords) +
+               len(rels_written_by) + len(rels_published_by) +
+               len(rels_has_category) + len(rels_has_keyword)
+           )
            # 初始化进度指针
+           current = 0

            # 创建所有图书节点
+           current = self._create_nodes(
+               NODE_LABEL_BOOK, books, stats, "books",
+               progress_callback, current, total_steps
+           )
            # 创建作者节点
+           current = self._create_nodes(
+               NODE_LABEL_AUTHOR, authors, stats, "authors",
+               progress_callback, current, total_steps
+           )
            # 创建出版社节点
+           current = self._create_nodes(
+               NODE_LABEL_PUBLISHER, publishers, stats, "publishers",
+               progress_callback, current, total_steps
+           )
            # 创建类别节点
+           current = self._create_nodes(
+               NODE_LABEL_CATEGORY, categories, stats, "categories",
+               progress_callback, current, total_steps
+           )
            # 创建关键词节点
+           current = self._create_nodes(
+               NODE_LABEL_KEYWORD, keywords, stats, "keywords",
+               progress_callback, current, total_steps
+           )
            # 所有节点创建后,反馈将创建关系
+           if progress_callback:
+               progress_callback("正在创建关系...", current, total_steps)

            # 创建书-作者关系并累计关系数
+           stats["relationships"] += self._create_relationships(
+               NODE_LABEL_BOOK, NODE_LABEL_AUTHOR, rels_written_by,
+               RELATIONSHIP_TYPE_WRITTEN_BY, progress_callback, current, total_steps, stats
+           )
            # 更新进度
+           current += len(rels_written_by)
            # 创建书-出版社关系
+           stats["relationships"] += self._create_relationships(
+               NODE_LABEL_BOOK, NODE_LABEL_PUBLISHER, rels_published_by,
+               RELATIONSHIP_TYPE_PUBLISHED_BY, progress_callback, current, total_steps, stats
+           )
+           current += len(rels_published_by)
            # 创建书-类别关系
+           stats["relationships"] += self._create_relationships(
+               NODE_LABEL_BOOK, NODE_LABEL_CATEGORY, rels_has_category,
+               RELATIONSHIP_TYPE_HAS_CATEGORY, progress_callback, current, total_steps, stats
+           )
+           current += len(rels_has_category)
            # 创建书-关键词关系
+           stats["relationships"] += self._create_relationships(
+               NODE_LABEL_BOOK, NODE_LABEL_KEYWORD, rels_has_keyword,
+               RELATIONSHIP_TYPE_HAS_KEYWORD, progress_callback, current, total_steps, stats
+           )
        # 如果整个导入过程中出现异常,记录到stats["errors"]
+       except Exception as e:
+           stats["errors"].append(f"导入过程出错: {str(e)}")

        # 返回统计字典
+       return stats    

# 创建ImportService的单例实例
+import_service =  ImportService()

5. 向量嵌入 #

5.1. embedding_service.py #

services/embedding_service.py

"""向量初始化服务"""
# 导入类型提示所需的类型List、Optional、Callable
from typing import List, Optional, Callable
# 导入HTTP请求库requests
import requests
# 导入py2neo包中的NodeMatcher用于节点查询
from py2neo import NodeMatcher
# 导入常量:节点标签(图书、作者)
from constants import NODE_LABEL_BOOK, NODE_LABEL_AUTHOR
# 导入Neo4j数据库连接
from database.connection import graph
# 导入配置对象
from config import config

# 定义向量初始化服务类
class EmbeddingInitService:
    """向量初始化服务"""
    # 构造函数,初始化必要的成员变量
    def __init__(self):
        # 创建NodeMatcher对象,用于查找节点
        self.node_matcher = NodeMatcher(graph)
        # 嵌入API的URL,从配置获取
        self.api_url = config.embedding.api_url
        # API密钥,从配置获取
        self.api_key = config.embedding.api_key
        # 嵌入模型名,从配置获取
        self.model = config.embedding.model

    # 为节点批量生成和更新嵌入向量
    def update_embeddings(
        self,
        progress_callback: Optional[Callable] = None
    ) -> dict:
        """为节点生成嵌入向量"""
        # 初始化统计信息字典
        stats = {"total_nodes": 0, "processed": 0, "failed": 0, "errors": []}
        # 要处理的节点标签列表(图书和作者)
        node_labels = [NODE_LABEL_BOOK, NODE_LABEL_AUTHOR]
        # 统计所有目标节点的总数
        total_nodes = 0
        for label in node_labels:
            # 查找指定标签的所有节点
            nodes = list(self.node_matcher.match(label))
            # 累加节点数量
            total_nodes += len(nodes)
        # 将总节点数写入统计信息
        stats["total_nodes"] = total_nodes
        # 当前已处理节点数
        current = 0
        # 遍历每个标签
        for node_label in node_labels:
            try:
                # 获取该标签下所有节点
                nodes = list(self.node_matcher.match(node_label))
                # 遍历每个节点
                for node in nodes:
                    # 获取节点name属性
                    name = node.get("name")
                    try:
                        # 获取文本的向量
                        embedding = self.get_embedding(name)
                        # 将嵌入向量写入节点属性
                        node["embedding"] = embedding
                        # 更新节点到数据库
                        graph.push(node)
                        # 成功处理节点数量加一
                        stats["processed"] += 1
                        # 当前计数加一
                        current += 1
                        # 调用进度回调函数(已成功处理该节点)
                        progress_callback(f" 已为 {name} 设置嵌入向量", current, total_nodes)
                    except Exception as e:
                        # 处理该节点失败时,记录错误信息
                        stats["errors"].append(f"处理 {name} 时出错: {str(e)}")
                        # 失败计数加一
                        stats["failed"] += 1
                        # 当前计数加一
                        current += 1
                        # 调用进度回调函数(处理失败)
                        progress_callback(f"处理 {name} 失败", current, total_nodes)
            except Exception as e:
                # 标签下所有节点处理发生异常时,记录错误
                stats["errors"].append(f"处理{node_label}节点时出错: {str(e)}")
                # 调用进度回调函数(标签整体失败)
                progress_callback(f"处理{node_label}节点时出错: {str(e)}", current, total_nodes)
        # 返回统计信息
        return stats

    # 获取单条文本的嵌入向量
    def get_embedding(self, text: str) -> List[float]:
        """获取文本的嵌入向量"""
        # 构造HTTP请求头
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        # 构造请求体
        payload = {"model": self.model, "input": text}
        # 输出请求体到控制台(调试用)
        print(payload)
        try:
            # 发送post请求到嵌入API
            response = requests.post(
                self.api_url,
                json=payload,
                headers=headers
            )
            # 若请求成功,且返回状态码200
            if response.status_code == 200:
                # 获取响应体中的数据内容
                data = response.json()
                # 从响应体获取嵌入向量
                embedding = data["data"][0].get("embedding")
                # 若未获得向量内容,则抛出异常
                if not embedding:
                    raise Exception("API返回数据格式错误")
                # 返回嵌入向量
                return embedding
            else:
                # 若HTTP响应非200,抛出异常包含错误状态码
                error_msg = f"API错误: HTTP {response.status_code}"
                raise Exception(error_msg)
        # 网络异常处理
        except requests.exceptions.RequestException as e:
            # 抛出网络相关异常
            raise Exception(f"网络错误: {e}") from e

# 创建EmbeddingInitService实例并赋值给embedding_service变量,供外部调用
embedding_service = EmbeddingInitService()

5.2. .env #

.env

NEO4J_URI="bolt://localhost:7687"
NEO4J_USER="neo4j"
NEO4J_PASSWORD="12345678"

+VOLC_EMBEDDINGS_API_URL=https://ark.cn-beijing.volces.com/api/v3/embeddings
+VOLC_EMBEDDING_MODEL=doubao-embedding-text-240715
+VOLC_API_KEY=d52e49a1-36ea-44bb-bc6e-65ce789a72f6

5.3. config.py #

config.py

"""配置管理模块"""
# 导入os模块,用于获取环境变量
import os
# 导入dotenv模块,用于加载.env文件中的环境变量
import dotenv
# 从dataclasses模块导入dataclass,用于简化数据类定义
from dataclasses import dataclass

# 加载.env文件中的所有环境变量
dotenv.load_dotenv()

# 定义Neo4j数据库配置的数据类
@dataclass
class Neo4jConfig:
    """Neo4j数据库配置"""
    # 数据库URI
    uri: str
    # 用户名
    user: str
    # 密码
    password: str

+@dataclass
+class EmbeddingConfig:
+   """火山方舟嵌入API配置"""
+   api_url: str
+   api_key: str
+   model: str

# 定义应用整体配置的数据类
@dataclass
class AppConfig:
    """应用配置"""
    # Neo4j数据库配置属性
    neo4j: Neo4jConfig
    # 火山方舟嵌入API配置属性
+   embedding: EmbeddingConfig

# 创建一个AppConfig对象,保存应用的整体配置
+config = AppConfig(
    # Neo4j数据库配置,使用Neo4jConfig类初始化
+   neo4j=Neo4jConfig(
        # 从环境变量获取数据库URI,如果没有则使用默认值"bolt://localhost:7687"
+       uri=os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
        # 从环境变量获取数据库用户名,如果没有则使用默认值"neo4j"
+       user=os.environ.get("NEO4J_USER", "neo4j"),
        # 从环境变量获取数据库密码,如果没有则使用默认值"12345678"
+       password=os.environ.get("NEO4J_PASSWORD", "12345678"),
+   ),
    # 嵌入API配置,使用EmbeddingConfig类初始化
+   embedding=EmbeddingConfig(
        # 从环境变量获取嵌入API的URL,没有则使用默认值
+       api_url=os.environ.get(
+           "VOLC_EMBEDDINGS_API_URL",
+           "https://ark.cn-beijing.volces.com/api/v3/embeddings"
+       ),
        # 从环境变量获取API密钥,如果没有则为空字符串
+       api_key=os.environ.get("VOLC_API_KEY", ""),
        # 从环境变量获取嵌入模型名,没有则使用默认值"doubao-embedding-text-240715"
+       model=os.environ.get("VOLC_EMBEDDING_MODEL", "doubao-embedding-text-240715"),
    )
+)

5.4. connection.py #

database/connection.py

# 这是Neo4j数据库连接管理的模块说明
"""Neo4j数据库连接管理"""

# 从py2neo库导入Graph类,用于连接和操作Neo4j数据库
from py2neo import Graph

# 从config模块导入get_config函数,用于获取配置信息
+from config import config

# 使用配置信息创建Neo4j数据库连接实例
graph = Graph(config.neo4j.uri, auth=(config.neo4j.user, config.neo4j.password))

5.5. 0_数据导入.py #

pages/0_数据导入.py

# 数据导入页面模块注释
"""数据导入页面"""

# 导入streamlit库并简写为st
import streamlit as st
from services.import_service import import_service
+from services.embedding_service import embedding_service

# 定义主函数
def main():
    # 数据导入页面主函数的文档字符串
    """数据导入页面主函数"""
    # 配置Streamlit页面,设置标题和布局
    st.set_page_config(
        page_title="数据导入 - 图书知识图谱",
        layout="wide"
    )

    # 显示页面主标题
    st.title("数据导入管理")
    # 插入分割线
    st.markdown("---")
    # 创建一个标签页,只包含“CSV数据导入”标签,并将其赋值给tab1变量
+   tab1, tab2 = st.tabs(["CSV数据导入","向量初始化"])

    # 在tab1标签页下进行后续操作
    with tab1:
        # 显示“CSV数据导入”二级标题
        st.header("CSV数据导入")
        # 展示有关上传CSV文件的说明及字段要求
        st.markdown("""
        请上传包含图书信息的CSV文件。CSV文件应包含以下列:
        - `name`: 书名
        - `author`: 作者
        - `publisher`: 出版社
        - `category`: 类别
        - `publish_year`: 出版年份
        - `summary`: 简介
        - `keywords`: 关键词(用分号分隔)
        """)

        # 显示文件上传控件,限定仅可上传csv类型文件,并显示帮助提示
        uploaded_file = st.file_uploader(
            "选择CSV文件",
            type=["csv"],
            help="上传包含图书信息的CSV文件"
        )
        # 如果用户已经上传了文件
        if uploaded_file is not None:
            try:
                # 读取上传的文件内容,并解码为utf-8格式的字符串
                csv_content = uploaded_file.read().decode("utf-8")
                # 调用服务将CSV内容解析为DataFrame
                df = import_service.parse_csv(csv_content)

                # 验证上传的DataFrame是否符合CSV格式要求
                is_valid, error_msg = import_service.validate_csv(df)

                # 如果格式校验未通过,显示错误
                if not is_valid:
                    st.error(f"CSV格式错误: {error_msg}")
                # 如果校验通过,显示成功消息,并告知记录数量
                else:
                    st.success(f"CSV文件验证通过!共 {len(df)} 条记录")
                    # 显示预览
                    with st.expander("数据预览", expanded=True):
                        st.dataframe(df)
                    # 导入选项
                    clear_existing = st.checkbox(
                        "清空现有数据",
                        value=False,
                        help="导入前清空数据库中的所有数据"
                    )
                    # 当用户点击"开始导入"按钮时执行
                    if st.button("开始导入", type="primary", width='stretch'):
                        # 创建一个显示日志信息的空容器
                        log_container = st.empty()
                        # 用于存放日志消息的列表
                        log_messages = []

                        # 定义进度回调函数
                        def progress_callback(message, current, total):
                            # 添加当前消息到日志列表(包含进度信息)
                            log_messages.append(f"- [{current}/{total}] {message}")
                            # 刷新日志容器内容
                            log_container.text("\n".join(log_messages))

                        try:
                            # 调用导入服务执行数据导入
                            stats = import_service.import_books(
                                df,#数据源
                                clear_existing=clear_existing,# 是否清空现有数据
                                progress_callback=progress_callback# 进度回调函数
                            )
                            # 显示最终统计信息
                            if stats:
                                st.info(
                                    f"完成: 图书 {stats.get('books', 0)} | "
                                    f"作者 {stats.get('authors', 0)} | "
                                    f"出版社 {stats.get('publishers', 0)} | "
                                    f"类别 {stats.get('categories', 0)} | "
                                    f"关键词 {stats.get('keywords', 0)} | "
                                    f"关系 {stats.get('relationships', 0)}"
                                )
                                # 如果有错误信息,显示出来
                                if stats.get("errors") and len(stats["errors"]) > 0:
                                    with st.expander("⚠️ 导入过程中的错误", expanded=False):
                                        for error in stats["errors"]:
                                            st.error(error)
                        except Exception as e:
                            # 导入过程中如有异常则显示错误信息
                            st.error(f"导入失败: {str(e)}")
                            # 导入traceback模块以显示详细错误
                            import traceback
                            # 展开详细错误信息
                            with st.expander("详细错误信息"):
                                # 格式化并显示完整的异常调用栈
                                st.code(traceback.format_exc())    



            # 捕获任何异常,显示读取文件失败的错误信息
            except Exception as e:
                st.error(f"读取文件失败: {str(e)}")
    # 使用tab2选项卡
+   with tab2:
        # 设置页面标题为“向量初始化”
+       st.header("向量初始化")
        # 显示多行文字介绍向量初始化的功能
+       st.markdown("""
+       为数据库中的节点生成嵌入向量。此操作会为所有Book和Author节点生成向量,
+       用于后续的向量检索功能。
+       """)

        # 如果用户点击“开始生成向量”按钮
+       if st.button("开始生成向量", type="primary", width='stretch'):
            # 创建一个空的日志容器用于显示进度
+           log_container = st.empty()
            # 用于存放日志消息的列表
+           log_messages = []

            # 定义进度回调函数
+           def progress_callback(message,current, total):
                # 向日志列表追加当前进度消息
+               log_messages.append(f"- [{current}/{total}] {message}")
                # 在日志容器中刷新显示所有日志消息
+               log_container.text("\n".join(log_messages))

+           try:
                # 调用embedding_service的update_embeddings方法,开始进行向量初始化
+               stats = embedding_service.update_embeddings(progress_callback=progress_callback)

                # 在页面上显示向量生成成功的提示
+               st.success("向量生成完成!")

                # 显示统计信息标题
+               st.markdown("### 生成统计")
                # 将页面分为三列用于展示不同统计项
+               col1, col2, col3 = st.columns(3)

                # 第一列显示总节点数
+               with col1:
+                   st.metric("总节点数", stats["total_nodes"])
                # 第二列显示已处理成功的数量
+               with col2:
+                   st.metric("成功", stats["processed"])
                # 第三列显示处理失败的数量
+               with col3:
+                   st.metric("失败", stats["failed"])

                # 如果存在错误信息,则在可展开区域中展示所有错误
+               if stats["errors"]:
+                   with st.expander("错误信息", expanded=False):
+                       for error in stats["errors"]:
+                           st.error(error) 
            # 捕获异常并显示错误信息
+           except Exception as e:
                # 在页面上显示向量生成失败的错误信息
+               st.error(f"向量生成失败: {str(e)}")
                # 导入traceback模块用于获取详细出错信息
+               import traceback
                # 可展开区域显示完整Traceback
+               with st.expander("详细错误信息"):
+                   st.code(traceback.format_exc())



# 判断是否为主程序入口
if __name__ == "__main__":
    # 调用主函数
    main()    

6. 提问 #

6.1. 1_问答系统.py #

pages/1_问答系统.py


"""Streamlit图书知识图谱问答系统主应用"""
import streamlit as st
from config import config

def main():
    """主函数"""
    st.set_page_config(
        page_title="问答系统 - 图书知识图谱",
        layout="wide"
    )
    if "messages" not in st.session_state:
        st.session_state.messages = []
    with st.sidebar:
        st.markdown("### 参数设置")
    st.markdown(
        "<h2 style='text-align: center; color: blue;'>图书知识图谱查询系统</h1>",
        unsafe_allow_html=True,
    )    
    for message in st.session_state.messages:
        st.chat_message(message["role"]).write(message["content"])
    if query := st.chat_input("输入图书相关问题", key="query_input"):
        st.session_state.messages.append({"role": "user", "content": query})
        st.chat_message("user").write(query)




if __name__ == "__main__":
    main()

7. 向量化查询 #

7.1. queries.py #

database/queries.py

"""图数据库查询"""
# 导入类型提示
from typing import List, Dict, Any, Optional
# 导入py2neo的Node类型
from py2neo import Node
# 导入Neo4j连接对象
from database.connection import graph
# 导入常量:向量索引名
from constants import VECTOR_INDEX_BOOK, VECTOR_INDEX_AUTHOR

# 定义图书查询类
class BookQuery:
    """图书查询"""

    # 基于嵌入向量检索图书的方法
    def query_by_embedding(
        self, query_embedding: List[float], top_k: int = 3
    ) -> List[Dict[str, Any]]:
        """基于向量检索图书"""
        # 定义Cypher向量检索查询语句
        query = f"""
        CALL db.index.vector.queryNodes('{VECTOR_INDEX_BOOK}', $top_k, $query_embedding)
        YIELD node, score
        MATCH (node:Book)
        OPTIONAL MATCH (node)-[:written_by]->(author:Author)
        OPTIONAL MATCH (node)-[:published_by]->(publisher:Publisher)
        OPTIONAL MATCH (node)-[:has_category]->(category:Category)
        OPTIONAL MATCH (node)-[:has_keyword]->(keyword:Keyword)
        WITH node, score, author, publisher, category, COLLECT(DISTINCT keyword.name) AS keyword_names
        RETURN node, score, author, publisher, category, keyword_names
        ORDER BY score DESC
        """

        try:
            # 执行Cypher查询,传入top_k和query_embedding参数
            results = graph.run(query, top_k=top_k, query_embedding=query_embedding)
            # 初始化图书结果列表
            books = []
            # 遍历查询结果
            for record in results:
                # 获取图书节点
                node: Node = record["node"]
                # 获取相似度分数
                score = float(record["score"])
                # 获取作者节点
                author_node: Optional[Node] = record.get("author")
                # 获取出版社节点
                publisher_node: Optional[Node] = record.get("publisher")
                # 获取类别节点
                category_node: Optional[Node] = record.get("category")
                # 获取关键词名称列表
                keyword_names = record.get("keyword_names", [])

                # 如果关系中存在关键词则使用,否则优先节点属性keywords
                keywords = keyword_names if keyword_names else (node.get("keywords", []) or [])

                # 构建图书结果字典
                books.append({
                    "name": node.get("name", ""),
                    "similarity": score,
                    "作者": author_node.get("name") if author_node else None,
                    "出版社": publisher_node.get("name") if publisher_node else None,
                    "类别": category_node.get("name") if category_node else None,
                    "出版年份": node.get("publish_year"),
                    "简介": node.get("summary"),
                    "关键词": keywords,
                })
            # 返回图书结果列表
            return books
        except Exception as e:
            # 查询出错时抛出异常并提示
            raise Exception(f"图书查询失败: {e}") from e


# 定义作者查询类
class AuthorQuery:
    """作者查询"""

    # 基于嵌入向量检索作者的方法
    def query_by_embedding(
        self, query_embedding: List[float], top_k: int = 3
    ) -> List[Dict[str, Any]]:
        """基于向量检索作者"""
        # 定义Cypher向量检索查询语句
        query = f"""
        CALL db.index.vector.queryNodes('{VECTOR_INDEX_AUTHOR}', $top_k, $query_embedding)
        YIELD node, score
        MATCH (node:Author)
        OPTIONAL MATCH (node)<-[:written_by]-(book:Book)
        RETURN node, score, COLLECT(DISTINCT book.name) AS book_names
        ORDER BY score DESC
        """

        try:
            # 执行Cypher查询,传入top_k和query_embedding参数
            results = graph.run(query, top_k=top_k, query_embedding=query_embedding)
            # 初始化作者结果列表
            authors = []
            # 遍历查询结果
            for record in results:
                # 获取作者节点
                node: Node = record["node"]
                # 获取相似度分数
                score = float(record["score"])
                # 获取相关图书名称列表
                book_names = record.get("book_names", [])

                # 构建作者结果字典
                authors.append({
                    "name": node.get("name", ""),
                    "similarity": score,
                    "相关图书": book_names,
                })
            # 返回作者结果列表
            return authors
        except Exception as e:
            # 查询出错时抛出异常并提示
            raise Exception(f"作者查询失败: {e}") from e

7.2. retrieval_service.py #

services/retrieval_service.py

"""检索服务"""
# 导入类型提示
from typing import List, Dict, Any
# 导入图书和作者的查询类
from database.queries import BookQuery, AuthorQuery
# 导入嵌入服务(虽然此文件中未直接用到)
from services.embedding_service import embedding_service
# 导入用于区分查询类型和默认top_k值的常量
from constants import QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR, DEFAULT_TOP_K


class RetrievalService:
    """检索服务"""

    # 初始化方法,分别实例化图书和作者查询对象
    def __init__(self):
        self.book_query = BookQuery()
        self.author_query = AuthorQuery()

    # 基于嵌入向量查询图书或作者的主方法
    def query_by_embedding(
        self,
        query_embedding: List[float],   # 问题的嵌入向量
        query_type: str,                # 查询类型(图书或作者)
        top_k: int = DEFAULT_TOP_K,     # 返回结果数量,默认值为常量
    ) -> List[Dict[str, Any]]:
        """使用嵌入向量查询"""
        # 若查询类型为“图书”,则使用book_query执行查询
        if query_type == QUERY_TYPE_BOOK:
            return self.book_query.query_by_embedding(query_embedding, top_k)
        # 若查询类型为“作者”,则使用author_query执行查询
        elif query_type == QUERY_TYPE_AUTHOR:
            return self.author_query.query_by_embedding(query_embedding, top_k)
        # 其他类型(不支持的类型)抛出异常
        else:
            raise ValueError(f"不支持的查询类型: {query_type}")


# 创建RetrievalService的单例,便于外部调用
retrieval_service = RetrievalService()

7.3. constants.py #

constants.py

# 定义一个列表,包含CSV文件导入时必需的列名
REQUIRED_CSV_COLUMNS = [
    # 书名
    "name", 
    # 作者
    "author", 
    # 出版社
    "publisher", 
    # 类别
    "category",
    # 出版年份
    "publish_year", 
    # 简介
    "summary", 
    # 关键词(用分号分隔)
    "keywords"
]

# 定义图书节点标签
NODE_LABEL_BOOK = "Book"
# 定义作者节点标签
NODE_LABEL_AUTHOR = "Author"
# 定义出版社节点标签
NODE_LABEL_PUBLISHER = "Publisher"
# 定义类别节点标签
NODE_LABEL_CATEGORY = "Category"
# 定义关键词节点标签
NODE_LABEL_KEYWORD = "Keyword"

# 定义“书籍-作者”关系类型
RELATIONSHIP_TYPE_WRITTEN_BY = "written_by"
# 定义“书籍-出版社”关系类型
RELATIONSHIP_TYPE_PUBLISHED_BY = "published_by"
# 定义“书籍-类别”关系类型
RELATIONSHIP_TYPE_HAS_CATEGORY = "has_category"
# 定义“书籍-关键词”关系类型
RELATIONSHIP_TYPE_HAS_KEYWORD = "has_keyword"

# 查询类型
# 图书查询类型
+QUERY_TYPE_BOOK = "图书"
# 作者查询类型
+QUERY_TYPE_AUTHOR = "作者"
# 默认Top K值
+DEFAULT_TOP_K = 3

# 向量索引名称
# 图书向量索引名称
+VECTOR_INDEX_BOOK = "book_embeddings"
# 作者向量索引名称
+VECTOR_INDEX_AUTHOR = "author_embeddings"

7.4. 1_问答系统.py #

pages/1_问答系统.py

# Streamlit图书知识图谱问答系统主应用
"""Streamlit图书知识图谱问答系统主应用"""

# 导入Streamlit库
import streamlit as st
# 导入配置文件
from config import config
# 导入嵌入向量服务
+from services.embedding_service import embedding_service
# 导入检索服务
+from services.retrieval_service import retrieval_service
# 导入查询类型常量
+from constants import QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR

# 定义主函数
def main():
    """主函数"""
    # 设置页面配置(标题、布局)
    st.set_page_config(
        page_title="问答系统 - 图书知识图谱",
        layout="wide"
    )
    # 如果session_state中还没有'messages',则初始化为空列表
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # 侧边栏参数设置
    with st.sidebar:
        # 显示参数设置的标题
        st.markdown("### 参数设置")
        # 单选框选择查询类型(默认选中第一个,即图书)
+       query_type = st.radio("选择查询类型", [QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR], index=0)
        # 滑块选择返回结果数量(Top K),范围1-10,默认3
+       top_k = st.slider("返回结果数量 (Top K)", 1, 10, 3, 1)
    # 在主界面居中显示系统标题
    st.markdown(
        "<h2 style='text-align: center; color: blue;'>图书知识图谱查询系统</h1>",
        unsafe_allow_html=True,
    )    
    # 按顺序显示历史消息
    for message in st.session_state.messages:
        st.chat_message(message["role"]).write(message["content"])
    # 获取用户输入的问题(如果有输入)
    if query := st.chat_input("输入图书相关问题", key="query_input"):
        # 将用户输入添加到消息历史记录
        st.session_state.messages.append({"role": "user", "content": query})
        # 显示用户输入的消息
        st.chat_message("user").write(query)
        # 显示查询中loading动画
+       with st.spinner("正在查询中..."):
+           try:
                # 获取问题的嵌入向量
+               query_embedding = embedding_service.get_embedding(query)
                # 使用嵌入向量进行检索,获取结果
+               results = retrieval_service.query_by_embedding(query_embedding, query_type, top_k)
                # 显示“查询结果”标题
+               st.markdown("#### 查询结果")
                # 如果有返回结果,逐条显示
+               if results:
+                   for idx, item in enumerate(results, 1):
                        # 用可展开板块显示每一条结果,默认展开第一个
+                       with st.expander(f"结果 {idx}", expanded=True if idx == 1 else False):
                            # 按字典项逐个展示
+                           for key, value in item.items():
+                               st.write(f"**{key}**: {value}")
+               else:
                    # 若无结果,显示提示
+                   st.info("未找到相关结果。")
+           except Exception as err:
                # 捕获异常时的处理
+               error_msg = f"查询过程中出错: {str(err)}"
+               st.session_state.messages.append({
+                   "role": "assistant",
+                   "content": error_msg
+               })
+               st.chat_message("assistant").write(error_msg)
+               st.error(error_msg)  


# 判断是否作为主程序运行
if __name__ == "__main__":
    # 调用主函数启动应用
    main()

8. 提问 #

8.1. init.py #

llm/init.py

# 从当前包导入BaseLLM基类
from .base import BaseLLM
# 从当前包导入DeepSeekLLM类
from .deepseek import DeepSeekLLM
# 从当前包导入VolcengineLLM类
from .volcengine import VolcengineLLM
# 定义__all__变量,指定包导出的公有接口
__all__ = ["BaseLLM", "DeepSeekLLM", "VolcengineLLM"]

8.2. base.py #

llm/base.py

# LLM抽象基类的文档字符串
"""LLM抽象基类"""

# 导入ABC和abstractmethod,用于创建抽象基类
from abc import ABC, abstractmethod


# 定义LLM的抽象基类,继承自ABC
class BaseLLM(ABC):
    # 类的文档字符串,说明是LLM抽象基类
    """LLM抽象基类"""

    # 抽象方法generate,子类必须实现
    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> str:
        # 生成回复的文档字符串
        """生成回复"""
        # 抽象方法内使用pass,占位
        pass

8.3. deepseek.py #

llm/deepseek.py

"""DeepSeek LLM实现"""

# 导入可选类型提示
from typing import Optional
# 导入 DeepSeek 的 ChatDeepSeek 类
from langchain_deepseek import ChatDeepSeek
# 导入自定义的基础LLM类
from .base import BaseLLM
# 导入全局配置对象
from config import config
# 导入默认温度和最大token常量
from constants import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS

# 定义 DeepSeekLLM 类,继承自 BaseLLM
class DeepSeekLLM(BaseLLM):
    """DeepSeek LLM"""

    # 构造函数,支持自定义模型名、api_key、温度和最大输出tokens数量
    def __init__(
        self,
        model_name: Optional[str] = None,
        api_key: Optional[str] = None,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_MAX_TOKENS,
    ):
        # 设置模型名:优先用传入参数,否则用配置里的默认
        self.model_name = model_name or config.deepseek.model
        # 设置api_key:优先用传入参数,否则用配置里的默认
        self.api_key = api_key or config.deepseek.api_key
        # 设置温度:优先用传入参数,否则用默认
        self.temperature = temperature or DEFAULT_TEMPERATURE
        # 设置最大返回token数:优先用传入参数,否则用默认
        self.max_tokens = max_tokens or DEFAULT_MAX_TOKENS
        # 初始化 DeepSeek 的 LLM 实例
        self.llm = ChatDeepSeek(
            api_key=self.api_key,
            model=self.model_name,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )

    # 生成方法,实现对LLM的调用
    def generate(self, prompt: str) -> str:
        """生成回复"""
        try:
            # 调用 DeepSeek LLM 获取生成内容
            response = self.llm.invoke(prompt)
            # 检查是否有返回内容
            if not response or not response.content:
                # 返回内容为空时抛出异常
                raise Exception("API返回空响应")
            # 正常时返回生成内容
            return response.content
        except Exception as e:
            # 捕获并重新抛出异常,加上友好的报错说明
            raise Exception(f"DeepSeek API调用失败: {e}") from e

8.4. volcengine.py #

llm/volcengine.py

"""火山引擎LLM实现"""

# 导入Optional类型用于参数类型标注
from typing import Optional
# 导入OpenAI库用于与火山引擎API交互
from openai import OpenAI
# 导入自定义的BaseLLM基类
from .base import BaseLLM
# 导入全局配置对象
from config import config
# 导入默认温度和最大tokens常量
from constants import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS


# 定义VolcengineLLM类,继承自BaseLLM
class VolcengineLLM(BaseLLM):
    """火山引擎LLM"""

    # 构造函数,初始化各项参数
    def __init__(self, model_name: Optional[str] = None, api_key: Optional[str] = None,
                 temperature: float = DEFAULT_TEMPERATURE, max_tokens: int = DEFAULT_MAX_TOKENS):
        # 获取API基础URL
        self.base_url = config.volcengine.base_url
        # 获取模型名称,优先使用传入值,否则取配置中的默认值
        self.model_name = model_name or config.volcengine.model
        # 获取API密钥,优先使用传入值,否则取配置中的默认值
        self.api_key = api_key or config.volcengine.api_key
        # 设置温度参数,控制生成内容的多样性
        self.temperature = temperature or DEFAULT_TEMPERATURE
        # 设置最大token数量
        self.max_tokens = max_tokens or DEFAULT_MAX_TOKENS
        # 初始化OpenAI客户端,用于与火山引擎API交互
        self.client = OpenAI(
            base_url=self.base_url,
            api_key=self.api_key,
        )

    # 定义generate方法用于生成模型回复
    def generate(self, prompt: str) -> str:
        """生成回复"""
        try:
            # 向火山引擎API发送请求,生成回复
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
            # 判断响应内容是否有效
            if not response.choices or not response.choices[0].message.content:
                # 如果返回内容为空,抛出异常
                raise Exception("API返回空响应")
            # 返回生成的回复内容
            return response.choices[0].message.content
        except Exception as e:
            # 捕获异常并抛出自定义异常信息
            raise Exception(f"Volcengine API调用失败: {e}") from e

8.5. init.py #

utils/init.py

"""工具函数模块"""

8.6. formatters.py #

utils/formatters.py

"""格式化工具"""
# 导入类型提示List, Dict, Any
from typing import List, Dict, Any

# 定义用于格式化查询结果上下文的函数
def format_context(query_type: str, results: List[Dict[str, Any]]) -> str:
    """格式化查询结果上下文"""

    # 定义内部函数:格式化一本图书的方法
    def format_book(result: Dict[str, Any]) -> str:
        # 准备字段及相应的值,元组形式依次为(字段名, 字段值)
        fields = [
            ("作者", result.get("作者")),
            ("出版社", result.get("出版社")),
            ("类别", result.get("类别")),
            ("出版年份", result.get("出版年份")),
            ("简介", result.get("简介")),
            # 如果关键词存在,拼接成字符串,否则为None
            ("关键词", ", ".join(result["关键词"]) if result.get("关键词") else None),
        ]
        # 对所有有值的字段拼接为多行字符串,格式为“- 字段名: 字段值”
        return "\n".join([f"   - {k}: {v}" for k, v in fields if v])

    # 定义内部函数:格式化作者(只显示相关图书)
    def format_author(result: Dict[str, Any]) -> str:
        # 如果结果中有“相关图书”字段,则拼接显示
        if result.get("相关图书"):
            return f"   - 相关图书: {', '.join(result['相关图书'])}"
        # 否则返回空串
        return ""

    # 用于存放所有格式化后的每条结果内容
    lines = []
    # 遍历全部结果数据
    for idx, result in enumerate(results, 1):
        # 构建当前项的标题,包括序号、名称及相似度
        header = f"{idx}. {result['name']} (相似度: {result['similarity']:.4f})"
        # 根据查询类型决定调用哪个格式化方法
        details = format_book(result) if query_type == "图书" else format_author(result)
        # 若有详情,拼接标题、详情及换行;否则仅标题加换行
        info = f"{header}\n{details}\n" if details else f"{header}\n"
        # 将本次内容添加到lines列表
        lines.append(info)

    # 用两个换行符分隔拼接所有结果,返回最终字符串
    return "\n\n".join(lines)

8.7. .env #

.env

NEO4J_URI="bolt://localhost:7687"
NEO4J_USER="neo4j"
NEO4J_PASSWORD="12345678"

VOLC_EMBEDDINGS_API_URL=https://ark.cn-beijing.volces.com/api/v3/embeddings
VOLC_EMBEDDING_MODEL=doubao-embedding-text-240715
VOLC_API_KEY=d52e49a1-36ea-44bb-bc6e-65ce789a72f6


+VOLCENGINE_BASE_URL=https://ark.cn-beijing.volces.com/api/v3
+VOLCENGINE_MODEL=doubao-seed-1-6-250615
+VOLCENGINE_API_KEY=d52e49a1-36ea-44bb-bc6e-65ce789a72f6

+DeepSeek_BASE_URL="https://api.deepseek.com/v1"
+DeepSeek_API_KEY="sk-278496d471bc4f4cb0ccb8c389a15018"
+DeepSeek_MODEL="deepseek-chat"

8.8. config.py #

config.py

"""配置管理模块"""
# 导入os模块,用于获取环境变量
import os
# 导入dotenv模块,用于加载.env文件中的环境变量
import dotenv
# 从dataclasses模块导入dataclass,用于简化数据类定义
from dataclasses import dataclass
+from typing import Optional

# 加载.env文件中的所有环境变量
dotenv.load_dotenv()

# 定义Neo4j数据库配置的数据类
@dataclass
class Neo4jConfig:
    """Neo4j数据库配置"""
    # 数据库URI
    uri: str
    # 用户名
    user: str
    # 密码
    password: str

@dataclass
class EmbeddingConfig:
    """火山方舟嵌入API配置"""
    api_url: str
    api_key: str
    model: str
+@dataclass
+class VolcengineConfig:
+   """火山引擎API配置"""
+   base_url: Optional[str]
+   api_key: Optional[str]
+   model: Optional[str]


+@dataclass
+class DeepSeekConfig:
+   """DeepSeek API配置"""
+   base_url: str
+   api_key: str
+   model: str
# 定义应用整体配置的数据类
@dataclass
class AppConfig:
    """应用配置"""
    # Neo4j数据库配置属性
    neo4j: Neo4jConfig
    # 火山方舟嵌入API配置属性
    embedding: EmbeddingConfig
    # 火山引擎API配置属性
+   volcengine: VolcengineConfig
    # DeepSeek API配置属性
+   deepseek: DeepSeekConfig

# 创建一个AppConfig对象,保存应用的整体配置
config = AppConfig(
    # Neo4j数据库配置,使用Neo4jConfig类初始化
    neo4j=Neo4jConfig(
        # 从环境变量获取数据库URI,如果没有则使用默认值"bolt://localhost:7687"
        uri=os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
        # 从环境变量获取数据库用户名,如果没有则使用默认值"neo4j"
        user=os.environ.get("NEO4J_USER", "neo4j"),
        # 从环境变量获取数据库密码,如果没有则使用默认值"12345678"
        password=os.environ.get("NEO4J_PASSWORD", "12345678"),
    ),
    # 嵌入API配置,使用EmbeddingConfig类初始化
    embedding=EmbeddingConfig(
        # 从环境变量获取嵌入API的URL,没有则使用默认值
        api_url=os.environ.get(
            "VOLC_EMBEDDINGS_API_URL",
            "https://ark.cn-beijing.volces.com/api/v3/embeddings"
        ),
        # 从环境变量获取API密钥,如果没有则为空字符串
        api_key=os.environ.get("VOLC_API_KEY", ""),
        # 从环境变量获取嵌入模型名,没有则使用默认值"doubao-embedding-text-240715"
        model=os.environ.get("VOLC_EMBEDDING_MODEL", "doubao-embedding-text-240715"),
+   ),
    # 初始化火山引擎API配置,从环境变量获取相关参数
+   volcengine=VolcengineConfig(
        # 从环境变量获取火山引擎API的基础URL
+       base_url=os.environ.get("VOLCENGINE_BASE_URL"),
        # 从环境变量获取火山引擎API的密钥
+       api_key=os.environ.get("VOLCENGINE_API_KEY"),
        # 从环境变量获取火山引擎API的模型名称
+       model=os.environ.get("VOLCENGINE_MODEL"),
+   ),
    # 初始化DeepSeek API配置,支持从环境变量读取参数,并提供默认值
+   deepseek=DeepSeekConfig(
        # 从环境变量获取DeepSeek基础URL,若未设置则使用默认值"https://api.deepseek.com"
+       base_url=os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com"),
        # 从环境变量获取DeepSeek API密钥,默认值为空字符串
+       api_key=os.environ.get("DEEPSEEK_API_KEY", ""),
        # 从环境变量获取DeepSeek模型名,若未设置则使用默认值"deepseek-chat"
+       model=os.environ.get("DEEPSEEK_MODEL", "deepseek-chat"),
+   ),
)

8.9. constants.py #

constants.py

# 定义一个列表,包含CSV文件导入时必需的列名
REQUIRED_CSV_COLUMNS = [
    # 书名
    "name", 
    # 作者
    "author", 
    # 出版社
    "publisher", 
    # 类别
    "category",
    # 出版年份
    "publish_year", 
    # 简介
    "summary", 
    # 关键词(用分号分隔)
    "keywords"
]

# 定义图书节点标签
NODE_LABEL_BOOK = "Book"
# 定义作者节点标签
NODE_LABEL_AUTHOR = "Author"
# 定义出版社节点标签
NODE_LABEL_PUBLISHER = "Publisher"
# 定义类别节点标签
NODE_LABEL_CATEGORY = "Category"
# 定义关键词节点标签
NODE_LABEL_KEYWORD = "Keyword"

# 定义“书籍-作者”关系类型
RELATIONSHIP_TYPE_WRITTEN_BY = "written_by"
# 定义“书籍-出版社”关系类型
RELATIONSHIP_TYPE_PUBLISHED_BY = "published_by"
# 定义“书籍-类别”关系类型
RELATIONSHIP_TYPE_HAS_CATEGORY = "has_category"
# 定义“书籍-关键词”关系类型
RELATIONSHIP_TYPE_HAS_KEYWORD = "has_keyword"

# 查询类型
# 图书查询类型
QUERY_TYPE_BOOK = "图书"
# 作者查询类型
QUERY_TYPE_AUTHOR = "作者"
# 默认Top K值
DEFAULT_TOP_K = 3
# 默认最大Tokens值
+DEFAULT_MAX_TOKENS = 4096
# 默认温度值
+DEFAULT_TEMPERATURE = 0.7

# 向量索引名称
# 图书向量索引名称
VECTOR_INDEX_BOOK = "book_embeddings"
# 作者向量索引名称
VECTOR_INDEX_AUTHOR = "author_embeddings"

8.10. 1_问答系统.py #

pages/1_问答系统.py

# Streamlit图书知识图谱问答系统主应用
"""Streamlit图书知识图谱问答系统主应用"""

# 导入Streamlit库
import streamlit as st
# 导入PromptTemplate
+from langchain_core.prompts import PromptTemplate
# 导入配置文件
from config import config
# 导入嵌入向量服务
from services.embedding_service import embedding_service
# 导入检索服务
from services.retrieval_service import retrieval_service
# 导入查询类型常量
from constants import QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR
# 导入大模型服务
+from llm import VolcengineLLM, DeepSeekLLM
+from utils.formatters import format_context
+prompt_template = PromptTemplate(
+   input_variables=["question", "context"],
+   template="""你是一名图书知识助手,需要根据提供的图书信息回答用户的提问。
+   请直接回答问题,如果信息不足,请回答"根据现有信息无法确定"。
+   问题:{question}
+   图书信息:\n{context}
+   回答:""",
+)
# 定义主函数
def main():
    """主函数"""
    # 设置页面配置(标题、布局)
    st.set_page_config(
        page_title="问答系统 - 图书知识图谱",
        layout="wide"
    )
    # 如果session_state中还没有'messages',则初始化为空列表
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # 侧边栏参数设置
    with st.sidebar:
        # 显示参数设置的标题
        st.markdown("### 参数设置")
        # 单选框选择查询类型(默认选中第一个,即图书)
        query_type = st.radio("选择查询类型", [QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR], index=0)
        # 滑块选择返回结果数量(Top K),范围1-10,默认3
        top_k = st.slider("返回结果数量 (Top K)", 1, 10, 3, 1)
        # 滑块选择温度值,范围0.0-1.0,默认0.3
+       temperature = st.slider("温度 (Temperature)", 0.0, 1.0, 0.3, 0.1)
        # 选择大模型服务商,默认火山引擎
+       llm_provider = st.selectbox("选择大模型服务商", ["volcengine", "deepseek"], index=0)
        # 如果选择火山引擎,则显示火山引擎API Key和模型名
+       if llm_provider == "volcengine":
+           api_key = st.text_input(
+               "火山引擎 API Key",
+               value=config.volcengine.api_key or "",
+               type="password",
+               help="如留空则使用服务器默认配置",
+           )
+           model_name = st.text_input(
+               "火山引擎模型名",
+               value=config.volcengine.model or "",
+               help="如留空则使用服务器默认配置"
+           )
+       else:
+           api_key = st.text_input(
+               "DeepSeek API Key",
+               value=config.deepseek.api_key or "",
+               type="password",
+               help="如留空则使用服务器默认配置",
+           )
+           model_name = st.text_input(
+               "DeepSeek模型名",
+               value=config.deepseek.model or "",
+               help="如留空则使用服务器默认配置"
+           )
    # 在主界面居中显示系统标题
    st.markdown(
        "<h2 style='text-align: center; color: blue;'>图书知识图谱查询系统</h1>",
        unsafe_allow_html=True,
    )    
    # 按顺序显示历史消息
    for message in st.session_state.messages:
        st.chat_message(message["role"]).write(message["content"])
+       if message["role"] == "assistant":  
+           # 展开可查看详细检索结果
+           with st.expander("查看详细结果"):
+               # 以 json 格式显示检索的类型和结果
+               st.json({"type": message['query_type'], "results": message['results']})
    # 获取用户输入的问题(如果有输入)
    if query := st.chat_input("输入图书相关问题", key="query_input"):
        # 将用户输入添加到消息历史记录
        st.session_state.messages.append({"role": "user", "content": query})
        # 显示用户输入的消息
        st.chat_message("user").write(query)
        # 显示查询中的 loading 动画
        with st.spinner("正在查询中..."):
            try:
                # 获取用户输入 query 的嵌入向量
                query_embedding = embedding_service.get_embedding(query)
                # 用嵌入向量进行检索,查询对应的结果
                results = retrieval_service.query_by_embedding(query_embedding, query_type, top_k)
                # 如果没有检索到结果
+               if not results or len(results) == 0:
                    # 设置回复内容为没有找到相关信息
+                   answer = "抱歉,没有找到相关的信息。"
                else:
                    # 如果选择的服务商是火山引擎
+                   if llm_provider == "volcengine":
                        # 实例化火山引擎大模型
+                       llm = VolcengineLLM(
+                           model_name=model_name if model_name else None,
+                           api_key=api_key if api_key else None,
+                       )
+                   else:
                        # 否则实例化 DeepSeek 大模型
+                       llm = DeepSeekLLM(
+                           model_name=model_name if model_name else None,
+                           api_key=api_key if api_key else None,
+                       )
                    # 根据检索结果格式化上下文字符串
+                   context_str = format_context(query_type, results)
                    # 打印上下文字符串,便于调试
+                   print("context_str: ", context_str)
                    # 用模板格式化最终 prompt
+                   final_prompt = prompt_template.format(
+                       question=query,
+                       context=context_str
+                   )
                    # 打印最终 prompt,便于调试
+                   print("final_prompt: ", final_prompt)
                    # 调用大模型生成答案
+                   answer = llm.generate(final_prompt)

                    # 展开可查看详细检索结果
+                   with st.expander("查看详细结果"):
                        # 以 json 格式显示检索的类型和结果
+                       st.json({"type": query_type, "results": results})
                    # 把大模型回复添加到会话历史
+                   st.session_state.messages.append({
+                       "role": "assistant",
+                       "content": answer,
+                       "query_type": query_type,
+                       "results": results          # 检索结果
+                   })
                    # 用 chat_message 组件显示大模型回复
+                   st.chat_message("assistant").write(answer)
                    # 重新运行页面,用于强制刷新
+                   st.rerun()
            except Exception as err:
                # 打印异常信息,便于调试
+               print(err)
                # 捕获异常时组织错误提示内容
                error_msg = f"查询过程中出错: {str(err)}"
                # 将错误信息添加到会话历史
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": error_msg
+                   "query_type": query_type,   # 查询类型
+                   "results": results          # 检索结果
                })
                # 显示错误信息在对话界面
                st.chat_message("assistant").write(error_msg)
                # 在页面上弹出错误提示
                st.error(error_msg)  


# 判断是否作为主程序运行
if __name__ == "__main__":
    # 调用主函数启动应用
    main()

9. 历史记录 #

9.1. 1_问答系统.py #

pages/1_问答系统.py

# Streamlit图书知识图谱问答系统主应用
"""Streamlit图书知识图谱问答系统主应用"""

# 导入Streamlit库
import streamlit as st
# 导入PromptTemplate
from langchain_core.prompts import PromptTemplate
# 导入配置文件
from config import config
# 导入嵌入向量服务
from services.embedding_service import embedding_service
# 导入检索服务
from services.retrieval_service import retrieval_service
# 导入查询类型常量
from constants import QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR
# 导入大模型服务
from llm import VolcengineLLM, DeepSeekLLM
from utils.formatters import format_context
prompt_template = PromptTemplate(
    input_variables=["question", "context"],
    template="""你是一名图书知识助手,需要根据提供的图书信息回答用户的提问。
    请直接回答问题,如果信息不足,请回答"根据现有信息无法确定"。
    问题:{question}
    图书信息:\n{context}
    回答:""",
)
# 定义主函数
def main():
    """主函数"""
    # 设置页面配置(标题、布局)
    st.set_page_config(
        page_title="问答系统 - 图书知识图谱",
        layout="wide"
    )
    # 如果session_state中还没有'messages',则初始化为空列表
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # 如果session_state中还没有'history',则初始化为空列表
+   if "history" not in st.session_state:
+       st.session_state.history = []    
    # 侧边栏参数设置
    with st.sidebar:
        # 显示参数设置的标题
        st.markdown("### 参数设置")
        # 单选框选择查询类型(默认选中第一个,即图书)
        query_type = st.radio("选择查询类型", [QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR], index=0)
        # 滑块选择返回结果数量(Top K),范围1-10,默认3
        top_k = st.slider("返回结果数量 (Top K)", 1, 10, 3, 1)
        # 滑块选择温度值,范围0.0-1.0,默认0.3
        temperature = st.slider("温度 (Temperature)", 0.0, 1.0, 0.3, 0.1)
        # 选择大模型服务商,默认火山引擎
        llm_provider = st.selectbox("选择大模型服务商", ["volcengine", "deepseek"], index=0)
        # 如果选择火山引擎,则显示火山引擎API Key和模型名
        if llm_provider == "volcengine":
            api_key = st.text_input(
                "火山引擎 API Key",
                value=config.volcengine.api_key or "",
                type="password",
                help="如留空则使用服务器默认配置",
            )
            model_name = st.text_input(
                "火山引擎模型名",
                value=config.volcengine.model or "",
                help="如留空则使用服务器默认配置"
            )
        else:
            api_key = st.text_input(
                "DeepSeek API Key",
                value=config.deepseek.api_key or "",
                type="password",
                help="如留空则使用服务器默认配置",
            )
            model_name = st.text_input(
                "DeepSeek模型名",
                value=config.deepseek.model or "",
                help="如留空则使用服务器默认配置"
            )
+       st.markdown("### 历史查询")
+       if st.session_state.history:
+           for i, item in enumerate(st.session_state.history):
+               with st.expander(f"查询 {i+1}: {item['question']}"):
+                   st.json(item)
+       else:
+           st.info("暂无历史查询记录")    
    # 在主界面居中显示系统标题
    st.markdown(
        "<h2 style='text-align: center; color: blue;'>图书知识图谱查询系统</h1>",
        unsafe_allow_html=True,
    )    
    # 按顺序显示历史消息
    for message in st.session_state.messages:
        st.chat_message(message["role"]).write(message["content"])
    # 获取用户输入的问题(如果有输入)
    if query := st.chat_input("输入图书相关问题", key="query_input"):
        # 将用户输入添加到消息历史记录
        st.session_state.messages.append({"role": "user", "content": query})
        # 显示用户输入的消息
        st.chat_message("user").write(query)
        # 显示查询中的 loading 动画
        with st.spinner("正在查询中..."):
            try:
                # 获取用户输入 query 的嵌入向量
                query_embedding = embedding_service.get_embedding(query)
                # 用嵌入向量进行检索,查询对应的结果
                results = retrieval_service.query_by_embedding(query_embedding, query_type, top_k)
                # 如果没有检索到结果
                if not results or len(results) == 0:
                    # 设置回复内容为没有找到相关信息
                    answer = "抱歉,没有找到相关的信息。"
                else:
                    # 如果选择的服务商是火山引擎
                    if llm_provider == "volcengine":
                        # 实例化火山引擎大模型
                        llm = VolcengineLLM(
                            model_name=model_name if model_name else None,
                            api_key=api_key if api_key else None,
                        )
                    else:
                        # 否则实例化 DeepSeek 大模型
                        llm = DeepSeekLLM(
                            model_name=model_name if model_name else None,
                            api_key=api_key if api_key else None,
                        )
                    # 根据检索结果格式化上下文字符串
                    context_str = format_context(query_type, results)
                    # 打印上下文字符串,便于调试
                    print("context_str: ", context_str)
                    # 用模板格式化最终 prompt
                    final_prompt = prompt_template.format(
                        question=query,
                        context=context_str
                    )
                    # 打印最终 prompt,便于调试
                    print("final_prompt: ", final_prompt)
                    # 调用大模型生成答案
                    answer = llm.generate(final_prompt)
                    # 将本次对话历史添加到 session_state.history,也包含了用户问题、检索类型、上下文、模型回复及温度参数
+                   st.session_state.history.append({
+                       "question": query,          # 用户输入的问题
+                       "query_type": query_type,   # 查询类型(如“图书”/“作者”)
+                       "context": context_str,     # 格式化后的上下文字符串
+                       "answer": answer,           # 大模型生成的答案
+                       "temperature": temperature, # 当前使用的温度参数
+                   })
                    # 展开可查看详细检索结果
                    with st.expander("查看详细结果"):
                        # 以 json 格式显示检索的类型和结果
                        st.json({"type": query_type, "results": results})
                    # 把大模型回复添加到会话历史
                    st.session_state.messages.append({
                        "role": "assistant",
                        "content": answer,
                        "query_type": query_type,
                        "results": results          # 检索结果

                    })
                    # 用 chat_message 组件显示大模型回复
                    st.chat_message("assistant").write(answer)
                    # 重新运行页面,用于强制刷新
                    st.rerun()
            except Exception as err:
                # 打印异常信息,便于调试
                print(err)
                # 捕获异常时组织错误提示内容
                error_msg = f"查询过程中出错: {str(err)}"
                # 将错误信息添加到会话历史
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": error_msg
                })
                # 显示错误信息在对话界面
                st.chat_message("assistant").write(error_msg)
                # 在页面上弹出错误提示
                st.error(error_msg)  


# 判断是否作为主程序运行
if __name__ == "__main__":
    # 调用主函数启动应用
    main()
← 上一节 49.Pandas 下一节 51.deepdoc →

访问验证

请输入访问令牌

Token不正确,请重新输入