导航菜单

  • 1.langchain.intro
  • 2.langchain.chat_models
  • 3.langchain.prompts
  • 4.langchain.example_selectors
  • 5.output_parsers
  • 6.Runnable
  • 7.memory
  • 8.document_loaders
  • 9.text_splitters
  • 10.embeddings
  • 11.tool
  • 12.retrievers
  • 13.optimize
  • 14.项目介绍
  • 15.启动HTTP
  • 16.数据与模型
  • 17.权限管理
  • 18.知识库管理
  • 19.设置
  • 20.文档管理
  • 21.聊天
  • 22.API文档
  • 23.RAG优化
  • 24.索引时优化
  • 25.检索前优化
  • 26.检索后优化
  • 27.系统优化
  • 28.GraphRAG
  • 29.图
  • 30.为什么选择图数据库
  • 31.什么是 Neo4j
  • 32.安装和连接 Neo4j
  • 33.Neo4j核心概念
  • 34.Cypher基础
  • 35.模式匹配
  • 36.数据CRUD操作
  • 37.Python操作Neo4j
  • 38.GraphRAG
  • 39.查询和过滤
  • 40.结果处理和聚合
  • 41.语句组合
  • 42.子查询
  • 43.模式和约束
  • 44.日期时间处理
  • 45.Cypher内置函数
  • 46.py2neo
  • 47.Streamlit
  • 48.Pandas
  • 49.graphRAG
  • 50.deepdoc
  • 51.deepdoc
  • 52.deepdoc
  • 53.deepdoc
  • 54.deepdoc
  • Pillow
  • 1. 项目准备
    • 1.1 项目介绍
    • 1.2 项目架构
    • 1.3 核心功能
      • 1.3.1. 向量检索
      • 1.3.2. RAG问答
      • 1.3.3. 多LLM支持
      • 1.3.4. 数据导入管理
    • 1.4 技术栈
    • 1.5 启动Neo4j
    • 1.6 创建项目目录
    • 1.7 首页
      • 1.7.1 main.py
      • 1.7.2 首页.py
    • 1.8 建立索引
  • 2. 数据导入
    • 2.1. 0_数据导入.py
  • 3. 读取csv文件
    • 3.1. constants.py
    • 3.2. import_service.py
    • 3.3. 0_数据导入.py
  • 4. 数据导入
    • 4.1. .env
    • 4.2. config.py
    • 4.3. init.py
    • 4.4. connection.py
    • 4.5. constants.py
    • 4.6. 0_数据导入.py
    • 4.8. import_service.py
  • 5. 向量嵌入
    • 5.1. embedding_service.py
    • 5.2. .env
    • 5.3. config.py
    • 5.4. connection.py
    • 5.5. 0_数据导入.py
  • 6. 提问
    • 6.1. 1_问答系统.py
  • 7. 向量化查询
    • 7.1. queries.py
    • 7.2. retrieval_service.py
    • 7.3. constants.py
    • 7.4. 1_问答系统.py
  • 8. 提问
    • 8.1. init.py
    • 8.2. base.py
    • 8.3. deepseek.py
    • 8.4. volcengine.py
    • 8.5. init.py
    • 8.6. formatters.py
    • 8.7. .env
    • 8.8. config.py
    • 8.9. constants.py
    • 8.10. 1_问答系统.py
  • 9. 历史记录
    • 9.1. 1_问答系统.py
  • 10. 项目整体架构
  • 11 类图
  • 12 页面执行过程
    • 12.1. 首页
    • 12.2 数据导入页
      • 12.2.1 时序图:CSV 数据导入流程
      • 12.2.2 向量初始化流程
    • 12.3. 问答页
      • 12.3.1 问答流程
  • 13. 数据流总结图

1. 项目准备 #

1.1 项目介绍 #

基于Neo4j图数据库和向量检索的智能图书问答系统,支持RAG(检索增强生成)架构。

1.2 项目架构 #

graphrag/                          # 项目根目录
├── database/                      # 数据库相关模块
│   ├── __init__.py                # 数据库包初始化
│   ├── connection.py              # Neo4j连接相关代码
│   └── queries.py                 # 数据库查询实现
├── llm/                           # 大模型相关模块
│   ├── __init__.py                # LLM包初始化
│   ├── base.py                    # LLM基类定义
│   ├── deepseek.py                # DeepSeek模型相关代码
│   └── volcengine.py              # 火山引擎(豆包)模型相关代码
├── pages/                         # Streamlit页面目录
│   ├── 0_数据导入.py              # 数据导入页面
│   └── 1_问答系统.py              # 问答系统页面
├── services/                      # 服务层
│   ├── embedding_service.py       # 嵌入服务
│   ├── import_service.py          # 导入服务
│   └── retrieval_service.py       # 检索服务
├── utils/                         # 工具类目录
│   ├── __init__.py                # 工具包初始化
│   └── formatters.py              # 格式化相关工具函数
├── config.py                      # 配置文件
├── constants.py                   # 常量定义
├── main.py                        # 项目入口,运行主程序
├── pyproject.toml                 # Python项目依赖与配置
└── 首页.py                         # Streamlit首页入口

1.3 核心功能 #

1.3.1. 向量检索 #

  • 使用火山方舟嵌入API生成文本向量
  • 基于Neo4j向量索引进行相似度检索
  • 支持图书和作者两种查询类型

1.3.2. RAG问答 #

  • 向量检索获取相关上下文
  • 使用LLM(豆包/DeepSeek)生成答案
  • 支持温度参数调整

1.3.3. 多LLM支持 #

  • 豆包(Doubao)LLM
  • DeepSeek LLM
  • 统一的抽象接口,易于扩展

1.3.4. 数据导入管理 #

  • CSV文件上传和解析
  • 自动创建节点和关系
  • 批量生成嵌入向量
  • 实时进度显示和错误处理

1.4 技术栈 #

  • 前端框架: Streamlit
  • 图数据库: Neo4j (使用py2neo)
  • 向量嵌入: 火山方舟 (doubao-embedding-text-240715)
  • LLM框架: LangChain
  • LLM服务: 豆包 / DeepSeek

1.5 启动Neo4j #

  • 下载地址:https://neo4j.com/download/
  • 或使用 Docker:
    docker run -d -p 7474:7474 -p 7687:7687 -e NEO4J_AUTH=neo4j/12345678 neo4j:latest

1.6 创建项目目录 #

mkdir graphrag
cd graphrag
uv init
uv add  python-dotenv langchain langchain-deepseek  pandas py2neo streamlit requests  openai 

1.7 首页 #

1.7.1 main.py #

# 应用入口说明
"""应用入口"""
# 导入 subprocess 模块,用于运行子进程
import subprocess
# 导入 sys 模块,用于访问解释器相关变量
import sys

# 判断当前脚本是否为主程序执行
if __name__ == "__main__":
    # 使用 subprocess 运行 streamlit 应用,指定脚本为“首页.py”
    subprocess.run([sys.executable, "-m", "streamlit", "run", "首页.py"])

1.7.2 首页.py #

"""Streamlit应用主入口"""
import streamlit as st

st.set_page_config(
    page_title="图书知识图谱系统",
    layout="wide",
    initial_sidebar_state="expanded"
)

st.title("图书知识图谱系统")
st.markdown("---")
st.markdown("""
欢迎使用图书知识图谱系统!

### 功能模块

1. **数据导入** - 导入CSV数据并生成向量
2. **问答系统** - 基于向量检索的智能问答
""")

1.8 建立索引 #

# 清空整个Neo4j数据库,删除所有节点和关系
MATCH (n) DETACH DELETE n

# 显示当前数据库中的所有索引信息
SHOW INDEXES

# 删除名为“book_embeddings”的向量索引
DROP INDEX book_embeddings

# 删除名为“author_embeddings”的向量索引
DROP INDEX author_embeddings

# 为Book节点的embedding属性创建名为“book_embeddings”的向量索引,若不存在则创建
CREATE VECTOR INDEX book_embeddings IF NOT EXISTS
FOR (m:Book) ON m.embedding
OPTIONS { indexConfig: {
  `vector.dimensions`: 2560,  # 设置向量维数为2560
  `vector.similarity_function`: 'cosine'  # 使用余弦相似度作为索引匹配方式
}}

# 为Author节点的embedding属性创建名为“author_embeddings”的向量索引,若不存在则创建
CREATE VECTOR INDEX author_embeddings IF NOT EXISTS
FOR (m:Author) ON m.embedding
OPTIONS { indexConfig: {
  `vector.dimensions`: 2560,  # 设置向量维数为2560
  `vector.similarity_function`: 'cosine'  # 使用余弦相似度作为索引匹配方式
}}

2. 数据导入 #

2.1. 0_数据导入.py #

pages/0_数据导入.py

# 数据导入页面模块注释
"""数据导入页面"""

# 导入streamlit库并简写为st
import streamlit as st

# 定义主函数
def main():
    # 数据导入页面主函数的文档字符串
    """数据导入页面主函数"""
    # 配置Streamlit页面,设置标题和布局
    st.set_page_config(
        page_title="数据导入 - 图书知识图谱",
        layout="wide"
    )

    # 显示页面主标题
    st.title("数据导入管理")
    # 插入分割线
    st.markdown("---")

# 判断是否为主程序入口
if __name__ == "__main__":
    # 调用主函数
    main()    

3. 读取csv文件 #

3.1. constants.py #

constants.py

# 定义一个列表,包含CSV文件导入时必需的列名
REQUIRED_CSV_COLUMNS = [
    # 书名
    "name", 
    # 作者
    "author", 
    # 出版社
    "publisher", 
    # 类别
    "category",
    # 出版年份
    "publish_year", 
    # 简介
    "summary", 
    # 关键词(用分号分隔)
    "keywords"
]

3.2. import_service.py #

services/import_service.py

"""数据导入服务"""

# 导入io模块用于处理内存中的流数据
import io
# 导入pandas库用于数据分析处理
import pandas as pd
# 导入类型提示
from typing import List, Dict, Any, Optional, Callable
# 从constants模块导入必需的CSV列名常量
from constants import (REQUIRED_CSV_COLUMNS)

# 定义数据导入服务类(单例模式)
class ImportService:
    """数据导入服务(单例模式)"""

    # 类变量:存储单例实例
    _instance = None

    # 重写__new__方法实现单例模式
    def __new__(cls):
        """创建单例实例"""
        # 如果实例不存在,则创建新实例
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        # 返回单例实例
        return cls._instance

    # 解析CSV内容为DataFrame的方法
    def parse_csv(self, csv_content: str) -> pd.DataFrame:
        """解析CSV内容"""
        # 使用pandas读取内存中的CSV内容并返回DataFrame
        return pd.read_csv(io.StringIO(csv_content))

    # 验证DataFrame格式的方法
    def validate_csv(self, df: pd.DataFrame) -> tuple[bool, Optional[str]]:
        """验证CSV格式"""
        # 如果DataFrame为空,返回False和错误信息
        if df.empty:
            return False, "CSV文件为空"
        # 检查是否有缺失的必需字段
        missing = [col for col in REQUIRED_CSV_COLUMNS if col not in df.columns]
        # 如果没有缺失则返回True,否则返回缺失列信息
        return (True, None) if not missing else (False, f"缺少列: {', '.join(missing)}")


# 获取导入服务实例的便捷方法
def get_import_service() -> ImportService:
    """获取导入服务单例实例"""
    # 返回单例实例(无论调用多少次都返回同一个实例)
    return ImportService()

3.3. 0_数据导入.py #

pages/0_数据导入.py

# 数据导入页面模块注释
"""数据导入页面"""

# 导入streamlit库并简写为st
import streamlit as st
+from services.import_service import get_import_service
# 定义主函数
def main():
    # 数据导入页面主函数的文档字符串
    """数据导入页面主函数"""
    # 配置Streamlit页面,设置标题和布局
    st.set_page_config(
        page_title="数据导入 - 图书知识图谱",
        layout="wide"
    )

    # 显示页面主标题
    st.title("数据导入管理")
    # 插入分割线
    st.markdown("---")
     # 获取服务实例
+   import_service = get_import_service()

    # 创建一个标签页,只包含“CSV数据导入”标签,并将其赋值给tab1变量
+   tab1, = st.tabs(["CSV数据导入"])

    # 在tab1标签页下进行后续操作
+   with tab1:
        # 显示“CSV数据导入”二级标题
+       st.header("CSV数据导入")
        # 展示有关上传CSV文件的说明及字段要求
+       st.markdown("""
+       请上传包含图书信息的CSV文件。CSV文件应包含以下列:
+       - `name`: 书名
+       - `author`: 作者
+       - `publisher`: 出版社
+       - `category`: 类别
+       - `publish_year`: 出版年份
+       - `summary`: 简介
+       - `keywords`: 关键词(用分号分隔)
+       """)

        # 显示文件上传控件,限定仅可上传csv类型文件,并显示帮助提示
+       uploaded_file = st.file_uploader(
+           "选择CSV文件",
+           type=["csv"],
+           help="上传包含图书信息的CSV文件"
+       )
        # 如果用户已经上传了文件
+       if uploaded_file is not None:
+           try:
                # 读取上传的文件内容,并解码为utf-8格式的字符串
+               csv_content = uploaded_file.read().decode("utf-8")
                # 调用服务将CSV内容解析为DataFrame
+               df = import_service.parse_csv(csv_content)

                # 验证上传的DataFrame是否符合CSV格式要求
+               is_valid, error_msg = import_service.validate_csv(df)

                # 如果格式校验未通过,显示错误
+               if not is_valid:
+                   st.error(f"CSV格式错误: {error_msg}")
                # 如果校验通过,显示成功消息,并告知记录数量
+               else:
+                   st.success(f"CSV文件验证通过!共 {len(df)} 条记录")
            # 捕获任何异常,显示读取文件失败的错误信息
+           except Exception as e:
+               st.error(f"读取文件失败: {str(e)}")



# 判断是否为主程序入口
if __name__ == "__main__":
    # 调用主函数
    main()    

4. 数据导入 #

4.1. .env #

.env

NEO4J_URI="bolt://localhost:7687"
NEO4J_USER="neo4j"
NEO4J_PASSWORD="12345678"

VOLC_EMBEDDINGS_API_URL=https://ark.cn-beijing.volces.com/api/v3/embeddings
VOLC_EMBEDDING_MODEL=doubao-embedding-text-240715
VOLC_API_KEY=d52e49a1-36ea-44bb-bc6e-65ce789a72f6


DeepSeek_BASE_URL="https://api.deepseek.com/v1"
DeepSeek_API_KEY="sk-ae59c1b6731e4d5d8f1f3fd0ad340b39"
DeepSeek_MODEL="deepseek-chat"


VOLCENGINE_BASE_URL=https://ark.cn-beijing.volces.com/api/v3
VOLCENGINE_API_KEY=d52e49a1-36ea-44bb-bc6e-65ce789a72f6
VOLCENGINE_MODEL=doubao-seed-1-6-250615

4.2. config.py #

config.py

"""配置管理模块"""
# 导入os模块,用于获取环境变量
import os
# 导入dotenv模块,用于加载.env文件中的环境变量
import dotenv
# 从dataclasses模块导入dataclass,用于简化数据类定义
from dataclasses import dataclass

# 加载.env文件中的所有环境变量
dotenv.load_dotenv()

# 定义Neo4j数据库配置的数据类
@dataclass
class Neo4jConfig:
    """Neo4j数据库配置"""
    # 数据库URI
    uri: str
    # 用户名
    user: str
    # 密码
    password: str

# 定义应用整体配置的数据类
@dataclass
class AppConfig:
    """应用配置"""
    # Neo4j数据库配置属性
    neo4j: Neo4jConfig

config=AppConfig(
        neo4j=Neo4jConfig(
            uri=os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
            user=os.environ.get("NEO4J_USER", "neo4j"),
            password=os.environ.get("NEO4J_PASSWORD", "12345678"),
        )
)

4.3. init.py #

database/init.py

"""数据库模块"""

4.4. connection.py #

database/connection.py

# 这是Neo4j数据库连接管理的模块说明
"""Neo4j数据库连接管理"""

# 从py2neo库导入Graph类,用于连接和操作Neo4j数据库
from py2neo import Graph

# 从config模块导入config对象,用于获取配置信息
from config import config

# 使用配置信息创建Neo4j数据库连接实例
graph = Graph(config.neo4j.uri, auth=(config.neo4j.user, config.neo4j.password))

4.5. constants.py #

constants.py

# 定义一个列表,包含CSV文件导入时必需的列名
REQUIRED_CSV_COLUMNS = [
    # 书名
    "name", 
    # 作者
    "author", 
    # 出版社
    "publisher", 
    # 类别
    "category",
    # 出版年份
    "publish_year", 
    # 简介
    "summary", 
    # 关键词(用分号分隔)
    "keywords"
]

# 定义图书节点标签
+NODE_LABEL_BOOK = "Book"
# 定义作者节点标签
+NODE_LABEL_AUTHOR = "Author"
# 定义出版社节点标签
+NODE_LABEL_PUBLISHER = "Publisher"
# 定义类别节点标签
+NODE_LABEL_CATEGORY = "Category"
# 定义关键词节点标签
+NODE_LABEL_KEYWORD = "Keyword"

# 定义“书籍-作者”关系类型
+RELATIONSHIP_TYPE_WRITTEN_BY = "written_by"
# 定义“书籍-出版社”关系类型
+RELATIONSHIP_TYPE_PUBLISHED_BY = "published_by"
# 定义“书籍-类别”关系类型
+RELATIONSHIP_TYPE_HAS_CATEGORY = "has_category"
# 定义“书籍-关键词”关系类型
+RELATIONSHIP_TYPE_HAS_KEYWORD = "has_keyword"

4.6. 0_数据导入.py #

pages/0_数据导入.py

# 数据导入页面模块注释
"""数据导入页面"""

# 导入streamlit库并简写为st
import streamlit as st
+from services.import_service import import_service
# 定义主函数
def main():
    # 数据导入页面主函数的文档字符串
    """数据导入页面主函数"""
    # 配置Streamlit页面,设置标题和布局
    st.set_page_config(
        page_title="数据导入 - 图书知识图谱",
        layout="wide"
    )

    # 显示页面主标题
    st.title("数据导入管理")
    # 插入分割线
    st.markdown("---")
    # 创建一个标签页,只包含“CSV数据导入”标签,并将其赋值给tab1变量
    tab1, = st.tabs(["CSV数据导入"])

    # 在tab1标签页下进行后续操作
    with tab1:
        # 显示“CSV数据导入”二级标题
        st.header("CSV数据导入")
        # 展示有关上传CSV文件的说明及字段要求
        st.markdown("""
        请上传包含图书信息的CSV文件。CSV文件应包含以下列:
        - `name`: 书名
        - `author`: 作者
        - `publisher`: 出版社
        - `category`: 类别
        - `publish_year`: 出版年份
        - `summary`: 简介
        - `keywords`: 关键词(用分号分隔)
        """)

        # 显示文件上传控件,限定仅可上传csv类型文件,并显示帮助提示
        uploaded_file = st.file_uploader(
            "选择CSV文件",
            type=["csv"],
            help="上传包含图书信息的CSV文件"
        )
        # 如果用户已经上传了文件
        if uploaded_file is not None:
            try:
                # 读取上传的文件内容,并解码为utf-8格式的字符串
                csv_content = uploaded_file.read().decode("utf-8")
                # 调用服务将CSV内容解析为DataFrame
                df = import_service.parse_csv(csv_content)

                # 验证上传的DataFrame是否符合CSV格式要求
                is_valid, error_msg = import_service.validate_csv(df)

                # 如果格式校验未通过,显示错误
                if not is_valid:
                    st.error(f"CSV格式错误: {error_msg}")
                # 如果校验通过,显示成功消息,并告知记录数量
                else:
                    st.success(f"CSV文件验证通过!共 {len(df)} 条记录")
                    # 显示预览
+                   with st.expander("数据预览", expanded=True):
+                       st.dataframe(df)
                    # 导入选项
+                   clear_existing = st.checkbox(
+                       "清空现有数据",
+                       value=False,
+                       help="导入前清空数据库中的所有数据"
+                   )
                    # 当用户点击"开始导入"按钮时执行
+                   if st.button("开始导入", type="primary", width='stretch'):
                        # 创建一个显示日志信息的空容器
+                       log_container = st.empty()
                        # 用于存放日志消息的列表
+                       log_messages = []

                        # 定义进度回调函数
+                       def progress_callback(message, current, total):
                            # 添加当前消息到日志列表(包含进度信息)
+                           log_messages.append(f"- [{current}/{total}] {message}")
                            # 刷新日志容器内容
+                           log_container.text("\n".join(log_messages))

+                       try:
                            # 调用导入服务执行数据导入
+                           stats = import_service.import_books(
+                               df,#数据源
+                               clear_existing=clear_existing,# 是否清空现有数据
+                               progress_callback=progress_callback# 进度回调函数
+                           )
                            # 显示最终统计信息
+                           if stats:
+                               st.info(
+                                   f"完成: 图书 {stats.get('books', 0)} | "
+                                   f"作者 {stats.get('authors', 0)} | "
+                                   f"出版社 {stats.get('publishers', 0)} | "
+                                   f"类别 {stats.get('categories', 0)} | "
+                                   f"关键词 {stats.get('keywords', 0)} | "
+                                   f"关系 {stats.get('relationships', 0)}"
+                               )
                                # 如果有错误信息,显示出来
+                               if stats.get("errors") and len(stats["errors"]) > 0:
+                                   with st.expander("⚠️ 导入过程中的错误", expanded=False):
+                                       for error in stats["errors"]:
+                                           st.error(error)
+                       except Exception as e:
                            # 导入过程中如有异常则显示错误信息
+                           st.error(f"导入失败: {str(e)}")
                            # 导入traceback模块以显示详细错误
+                           import traceback
                            # 展开详细错误信息
+                           with st.expander("详细错误信息"):
                                # 格式化并显示完整的异常调用栈
+                               st.code(traceback.format_exc())    



            # 捕获任何异常,显示读取文件失败的错误信息
            except Exception as e:
                st.error(f"读取文件失败: {str(e)}")



# 判断是否为主程序入口
if __name__ == "__main__":
    # 调用主函数
    main()    

4.8. import_service.py #

services/import_service.py

# 数据导入服务说明
"""数据导入服务"""

# 导入io模块,用于内存数据流操作
import io
# 导入pandas进行数据处理
import pandas as pd
# 从typing模块导入类型提示
from typing import List, Dict, Any, Optional, Callable
# 从database.connection导入数据库实例graph
+from database.connection import graph
# 从constants模块导入各类常量
+from constants import (
+   REQUIRED_CSV_COLUMNS, # 必需的CSV列名
+   NODE_LABEL_BOOK, # 图书节点标签
+   NODE_LABEL_AUTHOR, # 作者节点标签
+   NODE_LABEL_PUBLISHER, # 出版社节点标签
+   NODE_LABEL_CATEGORY, # 类别节点标签
+   NODE_LABEL_KEYWORD, # 关键词节点标签
+   RELATIONSHIP_TYPE_WRITTEN_BY, # 书-作者关系类型
+   RELATIONSHIP_TYPE_PUBLISHED_BY, # 书-出版社关系类型
+   RELATIONSHIP_TYPE_HAS_CATEGORY, # 书-类别关系类型
+   RELATIONSHIP_TYPE_HAS_KEYWORD # 书-关键词关系类型
+)
# 导入py2neo中的Node、Relationship、NodeMatcher
+from py2neo import Node, Relationship, NodeMatcher
# 定义数据导入服务类,采用单例模式
class ImportService:
    # 类说明:数据导入服务(单例)
    """数据导入服务(单例模式)"""

    # 初始化方法
+   def __init__(self):
        # 创建NodeMatcher对象用于节点查找
+       self.node_matcher = NodeMatcher(graph)

    # 解析CSV内容的函数
    def parse_csv(self, csv_content: str) -> pd.DataFrame:
        # 方法说明:解析CSV内容为DataFrame
        """解析CSV内容"""
        # 利用pandas读取内存中的CSV字符串,返回DataFrame
        return pd.read_csv(io.StringIO(csv_content))

    # 验证CSV格式合法性的方法
    def validate_csv(self, df: pd.DataFrame) -> tuple[bool, Optional[str]]:
        # 方法说明:验证CSV格式
        """验证CSV格式"""
        # 判断DataFrame是否为空,为空则返回错误消息
        if df.empty:
            return False, "CSV文件为空"
        # 检查必需CSV列是否全部存在
        missing = [col for col in REQUIRED_CSV_COLUMNS if col not in df.columns]
        # 若无漏列返回True,否则返回缺失列名
        return (True, None) if not missing else (False, f"缺少列: {', '.join(missing)}")

    # 清空全部数据库的方法
+   def clear_database(self) -> None:
        # 方法说明:清空数据库,删除所有节点和关系
+       """清空数据库"""
        # 执行Cypher语句,删除所有节点及其关系
+       graph.run("MATCH (n) DETACH DELETE n")

    # 内部方法:创建节点
+   def _create_nodes(
+       self, label, node_list, stats, stat_key,
+       progress_callback, current, total
+   ) -> int:
        # 方法说明:创建指定类型的节点并更新进度
+       """创建节点

+       Args:
+           label: 节点标签
+           node_list: 节点列表,每项为字典
+           stats: 统计计数与错误信息字典
+           stat_key: 统计键名
+           progress_callback: 进度回调函数
+           current: 当前进度值
+           total: 总进度步数
+       """
        # 有进度回调时,通知正在创建节点
+       if progress_callback:
+           progress_callback(f"正在创建{label}节点...", current, total)

        # 初始化进度步数
+       step = current
        # 遍历所有节点数据逐一创建
+       for node in node_list:
+           try:
                # 取节点名称用于反馈与出错描述,默认“未知”
+               node_name = node.get("name", "未知")
                # 构建节点对象
+               node = Node(label, **node)
                # 写入数据库
+               graph.create(node)
                # 节点计数器+1
+               stats[stat_key] += 1
                # 进度步数+1
+               step += 1
                # 回调进度反馈
+               if progress_callback:
+                   progress_callback(f"创建{label}: {node_name}", step, total)
+           except Exception as e:
                # 异常情况将出错信息记录到stats["errors"]
+               node_name = node.get("name", "未知")
+               stats["errors"].append(f"创建{label} {node_name} 失败: {str(e)}")
+               step += 1
        # 返回当前进度步数
+       return step

    # 内部方法:创建节点间关系
+   def _create_relationships(
+       self, start_label, end_label, edges, rel_type,
+       progress_callback, current, total, stats=None
+   ) -> int:
+       """
+       创建关系
+       Args:
+           start_label: 起始节点标签
+           end_label: 终止节点标签
+           edges: 关系对列表,每项为[起始节点名称, 终止节点名称]
+           rel_type: 关系类型
+           progress_callback: 进度回调函数
+           current: 当前进度值
+           total: 总进度步数
+           stats: 统计计数与错误信息字典
+       """
        # 初始化关系计数
+       count = 0
        # 初始化进度值
+       step = current
        # 遍历所有关系对
+       for edge in edges:
+           try:
                # 查找起始节点对象
+               start_node = self.node_matcher.match(start_label, name=str(edge[0])).first()
                # 查找终止节点对象
+               end_node = self.node_matcher.match(end_label, name=str(edge[1])).first()
                # 两端节点均存在后创建关系
+               if start_node and end_node:
+                   rel = Relationship(start_node, rel_type, end_node)#创建关系对象
+                   graph.create(rel)#写入数据库
+                   count += 1#关系计数器+1
+                   step += 1#进度步数+1
                    # 调用进度回调反馈
+                   if progress_callback:# 有进度回调时,通知正在创建关系
+                       progress_callback(f"创建关系: {edge[0]}-{rel_type}->{edge[1]}", step, total)
+               else:
                    # 如有节点不存在,将详细信息加到stats["errors"]
+                   if stats:
+                       missing = []#缺失节点列表
+                       if not start_node:
+                           missing.append(f"{start_label}:{edge[0]}")#添加起始节点
+                       if not end_node:
+                           missing.append(f"{end_label}:{edge[1]}")#添加终止节点
+                       stats["errors"].append(f"关系创建失败: 找不到节点 {', '.join(missing)}")#添加错误信息
+                   step += 1#进度步数+1
+           except Exception as e:
                # 其他异常也写入stats["errors"]
+               if stats:# 有统计信息时,添加错误信息
+                   stats["errors"].append(f"创建关系 {edge[0]}-{rel_type}->{edge[1]} 时出错: {str(e)}")
+               step += 1
        # 返回成功创建的关系数
+       return count    

    # 图书数据导入的主方法
+   def import_books(
+       self, df: pd.DataFrame, clear_existing: bool = False,
+       progress_callback: Optional[Callable] = None
+   ) -> Dict[str, Any]:
        # 方法说明:将DataFrame批量导入为图数据库节点与关系
+       """导入图书数据"""
        # 初始化统计信息字典
+       stats = {
+           "books": 0, "authors": 0, "publishers": 0,
+           "categories": 0, "keywords": 0, "relationships": 0, "errors": []
+       }

+       try:
            # 是否需要清空库
+           if clear_existing:
                # 有进度回调时先通知
+               if progress_callback:
+                   progress_callback("正在清空数据库...", 0, 0)
                # 调用清库
+               self.clear_database()

            # 生成图书节点数据列表
+           books = []
+           for _, row in df.iterrows():
+               if pd.notna(row["name"]):
                    # 创建基本字典(书名)
+                   book = {"name": row["name"]}
                    # 补充出版年份
+                   if pd.notna(row.get("publish_year")):
+                       book["publish_year"] = int(row["publish_year"])
                    # 补充简介
+                   if pd.notna(row.get("summary")):
+                       book["summary"] = str(row["summary"])
                    # 处理关键字列表
+                   if pd.notna(row.get("keywords")):
+                       book["keywords"] = [kw.strip() for kw in str(row["keywords"]).split(";") if kw.strip()]
+                   books.append(book)

            # authors节点只创建唯一且非空的作者名
+           authors = [{"name": name} for name in set(df["author"].dropna())]
            # publishers节点只创建唯一且非空的出版社名
+           publishers = [{"name": name} for name in set(df["publisher"].dropna())]
            # categories节点只创建唯一且非空的类别名
+           categories = [{"name": name} for name in set(df["category"].dropna())]

            # 提取所有keywords列的分号分割项,并去重空
+           keywords_set = set()
            # 遍历所有关键词,并去重空格
+           for keyword in df["keywords"].dropna():
                # 将关键词按分号分割,并去重空格
+               keywords_set.update([kw.strip() for kw in keyword.split(";") if kw.strip()])
            # 将关键词列表转换为字典列表
+           keywords = [{"name": keyword.strip()} for keyword in keywords_set if keyword.strip()]

            # 构建关系对列表
+           rels_written_by = []       # 书-作者
+           rels_published_by = []     # 书-出版社
+           rels_has_category = []     # 书-类别
+           rels_has_keyword = []      # 书-关键词

            # 遍历每行收集上述所有关系
+           for _, row in df.iterrows():
                # 书-作者
+               if pd.notna(row["name"]) and pd.notna(row["author"]):
+                   rels_written_by.append([row["name"], row["author"]])
                # 书-出版社
+               if pd.notna(row["name"]) and pd.notna(row["publisher"]):
+                   rels_published_by.append([row["name"], row["publisher"]])
                # 书-类别
+               if pd.notna(row["name"]) and pd.notna(row["category"]):
+                   rels_has_category.append([row["name"], row["category"]])
                # 书-关键词(支持多关键词)
+               if pd.notna(row["name"]) and pd.notna(row["keywords"]):
+                   for kw in row["keywords"].split(";"):
+                       rels_has_keyword.append([row["name"], kw.strip()])

            # 统计全部节点和关系数,便于进度反馈
+           total_steps = (
+               len(books) + len(authors) + len(publishers) +
+               len(categories) + len(keywords) +
+               len(rels_written_by) + len(rels_published_by) +
+               len(rels_has_category) + len(rels_has_keyword)
+           )
            # 初始化进度指针
+           current = 0

            # 创建所有图书节点
+           current = self._create_nodes(
+               NODE_LABEL_BOOK, books, stats, "books",
+               progress_callback, current, total_steps
+           )
            # 创建作者节点
+           current = self._create_nodes(
+               NODE_LABEL_AUTHOR, authors, stats, "authors",
+               progress_callback, current, total_steps
+           )
            # 创建出版社节点
+           current = self._create_nodes(
+               NODE_LABEL_PUBLISHER, publishers, stats, "publishers",
+               progress_callback, current, total_steps
+           )
            # 创建类别节点
+           current = self._create_nodes(
+               NODE_LABEL_CATEGORY, categories, stats, "categories",
+               progress_callback, current, total_steps
+           )
            # 创建关键词节点
+           current = self._create_nodes(
+               NODE_LABEL_KEYWORD, keywords, stats, "keywords",
+               progress_callback, current, total_steps
+           )
            # 所有节点创建后,反馈将创建关系
+           if progress_callback:
+               progress_callback("正在创建关系...", current, total_steps)

            # 创建书-作者关系并累计关系数
+           stats["relationships"] += self._create_relationships(
+               NODE_LABEL_BOOK, NODE_LABEL_AUTHOR, rels_written_by,
+               RELATIONSHIP_TYPE_WRITTEN_BY, progress_callback, current, total_steps, stats
+           )
            # 更新进度
+           current += len(rels_written_by)
            # 创建书-出版社关系
+           stats["relationships"] += self._create_relationships(
+               NODE_LABEL_BOOK, NODE_LABEL_PUBLISHER, rels_published_by,
+               RELATIONSHIP_TYPE_PUBLISHED_BY, progress_callback, current, total_steps, stats
+           )
+           current += len(rels_published_by)
            # 创建书-类别关系
+           stats["relationships"] += self._create_relationships(
+               NODE_LABEL_BOOK, NODE_LABEL_CATEGORY, rels_has_category,
+               RELATIONSHIP_TYPE_HAS_CATEGORY, progress_callback, current, total_steps, stats
+           )
+           current += len(rels_has_category)
            # 创建书-关键词关系
+           stats["relationships"] += self._create_relationships(
+               NODE_LABEL_BOOK, NODE_LABEL_KEYWORD, rels_has_keyword,
+               RELATIONSHIP_TYPE_HAS_KEYWORD, progress_callback, current, total_steps, stats
+           )
        # 如果整个导入过程中出现异常,记录到stats["errors"]
+       except Exception as e:
+           stats["errors"].append(f"导入过程出错: {str(e)}")

        # 返回统计字典
+       return stats    

# 创建ImportService的单例实例
+import_service =  ImportService()

5. 向量嵌入 #

5.1. embedding_service.py #

services/embedding_service.py

"""向量初始化服务"""
# 导入类型提示所需的类型List、Optional、Callable
from typing import List, Optional, Callable
# 导入HTTP请求库requests
import requests
# 导入py2neo包中的NodeMatcher用于节点查询
from py2neo import NodeMatcher
# 导入常量:节点标签(图书、作者)
from constants import NODE_LABEL_BOOK, NODE_LABEL_AUTHOR
# 导入Neo4j数据库连接
from database.connection import graph
# 导入配置对象
from config import config

# 定义向量初始化服务类
class EmbeddingInitService:
    """向量初始化服务"""
    # 构造函数,初始化必要的成员变量
    def __init__(self):
        # 创建NodeMatcher对象,用于查找节点
        self.node_matcher = NodeMatcher(graph)
        # 嵌入API的URL,从配置获取
        self.api_url = config.embedding.api_url
        # API密钥,从配置获取
        self.api_key = config.embedding.api_key
        # 嵌入模型名,从配置获取
        self.model = config.embedding.model

    # 为节点批量生成和更新嵌入向量
    def update_embeddings(
        self,
        progress_callback: Optional[Callable] = None
    ) -> dict:
        """为节点生成嵌入向量"""
        # 初始化统计信息字典
        stats = {"total_nodes": 0, "processed": 0, "failed": 0, "errors": []}
        # 要处理的节点标签列表(图书和作者)
        node_labels = [NODE_LABEL_BOOK, NODE_LABEL_AUTHOR]
        # 统计所有目标节点的总数
        total_nodes = 0
        for label in node_labels:
            # 查找指定标签的所有节点
            nodes = list(self.node_matcher.match(label))
            # 累加节点数量
            total_nodes += len(nodes)
        # 将总节点数写入统计信息
        stats["total_nodes"] = total_nodes
        # 当前已处理节点数
        current = 0
        # 遍历每个标签
        for node_label in node_labels:
            try:
                # 获取该标签下所有节点
                nodes = list(self.node_matcher.match(node_label))
                # 遍历每个节点
                for node in nodes:
                    # 获取节点name属性
                    name = node.get("name")
                    try:
                        # 获取文本的向量
                        embedding = self.get_embedding(name)
                        # 将嵌入向量写入节点属性
                        node["embedding"] = embedding
                        # 更新节点到数据库
                        graph.push(node)
                        # 成功处理节点数量加一
                        stats["processed"] += 1
                        # 当前计数加一
                        current += 1
                        # 调用进度回调函数(已成功处理该节点)
                        progress_callback(f" 已为 {name} 设置嵌入向量", current, total_nodes)
                    except Exception as e:
                        # 处理该节点失败时,记录错误信息
                        stats["errors"].append(f"处理 {name} 时出错: {str(e)}")
                        # 失败计数加一
                        stats["failed"] += 1
                        # 当前计数加一
                        current += 1
                        # 调用进度回调函数(处理失败)
                        progress_callback(f"处理 {name} 失败", current, total_nodes)
            except Exception as e:
                # 标签下所有节点处理发生异常时,记录错误
                stats["errors"].append(f"处理{node_label}节点时出错: {str(e)}")
                # 调用进度回调函数(标签整体失败)
                progress_callback(f"处理{node_label}节点时出错: {str(e)}", current, total_nodes)
        # 返回统计信息
        return stats

    # 获取单条文本的嵌入向量
    def get_embedding(self, text: str) -> List[float]:
        """获取文本的嵌入向量"""
        # 构造HTTP请求头
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        # 构造请求体
        payload = {"model": self.model, "input": text}
        # 输出请求体到控制台(调试用)
        print(payload)
        try:
            # 发送post请求到嵌入API
            response = requests.post(
                self.api_url,
                json=payload,
                headers=headers
            )
            # 若请求成功,且返回状态码200
            if response.status_code == 200:
                # 获取响应体中的数据内容
                data = response.json()
                # 从响应体获取嵌入向量
                embedding = data["data"][0].get("embedding")
                # 若未获得向量内容,则抛出异常
                if not embedding:
                    raise Exception("API返回数据格式错误")
                # 返回嵌入向量
                return embedding
            else:
                # 若HTTP响应非200,抛出异常包含错误状态码
                error_msg = f"API错误: HTTP {response.status_code}"
                raise Exception(error_msg)
        # 网络异常处理
        except requests.exceptions.RequestException as e:
            # 抛出网络相关异常
            raise Exception(f"网络错误: {e}") from e

# 创建EmbeddingInitService实例并赋值给embedding_service变量,供外部调用
embedding_service = EmbeddingInitService()

5.2. .env #

.env

NEO4J_URI="bolt://localhost:7687"
NEO4J_USER="neo4j"
NEO4J_PASSWORD="12345678"

+VOLC_EMBEDDINGS_API_URL=https://ark.cn-beijing.volces.com/api/v3/embeddings
+VOLC_EMBEDDING_MODEL=doubao-embedding-text-240715
+VOLC_API_KEY=d52e49a1-36ea-44bb-bc6e-65ce789a72f6

5.3. config.py #

config.py

"""配置管理模块"""
# 导入os模块,用于获取环境变量
import os
# 导入dotenv模块,用于加载.env文件中的环境变量
import dotenv
# 从dataclasses模块导入dataclass,用于简化数据类定义
from dataclasses import dataclass

# 加载.env文件中的所有环境变量
dotenv.load_dotenv()

# 定义Neo4j数据库配置的数据类
@dataclass
class Neo4jConfig:
    """Neo4j数据库配置"""
    # 数据库URI
    uri: str
    # 用户名
    user: str
    # 密码
    password: str

+@dataclass
+class EmbeddingConfig:
+   """火山方舟嵌入API配置"""
+   api_url: str
+   api_key: str
+   model: str

# 定义应用整体配置的数据类
@dataclass
class AppConfig:
    """应用配置"""
    # Neo4j数据库配置属性
    neo4j: Neo4jConfig
    # 火山方舟嵌入API配置属性
+   embedding: EmbeddingConfig

# 创建一个AppConfig对象,保存应用的整体配置
+config = AppConfig(
    # Neo4j数据库配置,使用Neo4jConfig类初始化
+   neo4j=Neo4jConfig(
        # 从环境变量获取数据库URI,如果没有则使用默认值"bolt://localhost:7687"
+       uri=os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
        # 从环境变量获取数据库用户名,如果没有则使用默认值"neo4j"
+       user=os.environ.get("NEO4J_USER", "neo4j"),
        # 从环境变量获取数据库密码,如果没有则使用默认值"12345678"
+       password=os.environ.get("NEO4J_PASSWORD", "12345678"),
+   ),
    # 嵌入API配置,使用EmbeddingConfig类初始化
+   embedding=EmbeddingConfig(
        # 从环境变量获取嵌入API的URL,没有则使用默认值
+       api_url=os.environ.get(
+           "VOLC_EMBEDDINGS_API_URL",
+           "https://ark.cn-beijing.volces.com/api/v3/embeddings"
+       ),
        # 从环境变量获取API密钥,如果没有则为空字符串
+       api_key=os.environ.get("VOLC_API_KEY", ""),
        # 从环境变量获取嵌入模型名,没有则使用默认值"doubao-embedding-text-240715"
+       model=os.environ.get("VOLC_EMBEDDING_MODEL", "doubao-embedding-text-240715"),
    )
+)

5.4. connection.py #

database/connection.py

# 这是Neo4j数据库连接管理的模块说明
"""Neo4j数据库连接管理"""

# 从py2neo库导入Graph类,用于连接和操作Neo4j数据库
from py2neo import Graph

# 从config模块导入get_config函数,用于获取配置信息
+from config import config

# 使用配置信息创建Neo4j数据库连接实例
graph = Graph(config.neo4j.uri, auth=(config.neo4j.user, config.neo4j.password))

5.5. 0_数据导入.py #

pages/0_数据导入.py

# 数据导入页面模块注释
"""数据导入页面"""

# 导入streamlit库并简写为st
import streamlit as st
from services.import_service import import_service
+from services.embedding_service import embedding_service

# 定义主函数
def main():
    # 数据导入页面主函数的文档字符串
    """数据导入页面主函数"""
    # 配置Streamlit页面,设置标题和布局
    st.set_page_config(
        page_title="数据导入 - 图书知识图谱",
        layout="wide"
    )

    # 显示页面主标题
    st.title("数据导入管理")
    # 插入分割线
    st.markdown("---")
    # 创建一个标签页,只包含“CSV数据导入”标签,并将其赋值给tab1变量
+   tab1, tab2 = st.tabs(["CSV数据导入","向量初始化"])

    # 在tab1标签页下进行后续操作
    with tab1:
        # 显示“CSV数据导入”二级标题
        st.header("CSV数据导入")
        # 展示有关上传CSV文件的说明及字段要求
        st.markdown("""
        请上传包含图书信息的CSV文件。CSV文件应包含以下列:
        - `name`: 书名
        - `author`: 作者
        - `publisher`: 出版社
        - `category`: 类别
        - `publish_year`: 出版年份
        - `summary`: 简介
        - `keywords`: 关键词(用分号分隔)
        """)

        # 显示文件上传控件,限定仅可上传csv类型文件,并显示帮助提示
        uploaded_file = st.file_uploader(
            "选择CSV文件",
            type=["csv"],
            help="上传包含图书信息的CSV文件"
        )
        # 如果用户已经上传了文件
        if uploaded_file is not None:
            try:
                # 读取上传的文件内容,并解码为utf-8格式的字符串
                csv_content = uploaded_file.read().decode("utf-8")
                # 调用服务将CSV内容解析为DataFrame
                df = import_service.parse_csv(csv_content)

                # 验证上传的DataFrame是否符合CSV格式要求
                is_valid, error_msg = import_service.validate_csv(df)

                # 如果格式校验未通过,显示错误
                if not is_valid:
                    st.error(f"CSV格式错误: {error_msg}")
                # 如果校验通过,显示成功消息,并告知记录数量
                else:
                    st.success(f"CSV文件验证通过!共 {len(df)} 条记录")
                    # 显示预览
                    with st.expander("数据预览", expanded=True):
                        st.dataframe(df)
                    # 导入选项
                    clear_existing = st.checkbox(
                        "清空现有数据",
                        value=False,
                        help="导入前清空数据库中的所有数据"
                    )
                    # 当用户点击"开始导入"按钮时执行
                    if st.button("开始导入", type="primary", width='stretch'):
                        # 创建一个显示日志信息的空容器
                        log_container = st.empty()
                        # 用于存放日志消息的列表
                        log_messages = []

                        # 定义进度回调函数
                        def progress_callback(message, current, total):
                            # 添加当前消息到日志列表(包含进度信息)
                            log_messages.append(f"- [{current}/{total}] {message}")
                            # 刷新日志容器内容
                            log_container.text("\n".join(log_messages))

                        try:
                            # 调用导入服务执行数据导入
                            stats = import_service.import_books(
                                df,#数据源
                                clear_existing=clear_existing,# 是否清空现有数据
                                progress_callback=progress_callback# 进度回调函数
                            )
                            # 显示最终统计信息
                            if stats:
                                st.info(
                                    f"完成: 图书 {stats.get('books', 0)} | "
                                    f"作者 {stats.get('authors', 0)} | "
                                    f"出版社 {stats.get('publishers', 0)} | "
                                    f"类别 {stats.get('categories', 0)} | "
                                    f"关键词 {stats.get('keywords', 0)} | "
                                    f"关系 {stats.get('relationships', 0)}"
                                )
                                # 如果有错误信息,显示出来
                                if stats.get("errors") and len(stats["errors"]) > 0:
                                    with st.expander("⚠️ 导入过程中的错误", expanded=False):
                                        for error in stats["errors"]:
                                            st.error(error)
                        except Exception as e:
                            # 导入过程中如有异常则显示错误信息
                            st.error(f"导入失败: {str(e)}")
                            # 导入traceback模块以显示详细错误
                            import traceback
                            # 展开详细错误信息
                            with st.expander("详细错误信息"):
                                # 格式化并显示完整的异常调用栈
                                st.code(traceback.format_exc())    



            # 捕获任何异常,显示读取文件失败的错误信息
            except Exception as e:
                st.error(f"读取文件失败: {str(e)}")
    # 使用tab2选项卡
+   with tab2:
        # 设置页面标题为“向量初始化”
+       st.header("向量初始化")
        # 显示多行文字介绍向量初始化的功能
+       st.markdown("""
+       为数据库中的节点生成嵌入向量。此操作会为所有Book和Author节点生成向量,
+       用于后续的向量检索功能。
+       """)

        # 如果用户点击“开始生成向量”按钮
+       if st.button("开始生成向量", type="primary", width='stretch'):
            # 创建一个空的日志容器用于显示进度
+           log_container = st.empty()
            # 用于存放日志消息的列表
+           log_messages = []

            # 定义进度回调函数
+           def progress_callback(message,current, total):
                # 向日志列表追加当前进度消息
+               log_messages.append(f"- [{current}/{total}] {message}")
                # 在日志容器中刷新显示所有日志消息
+               log_container.text("\n".join(log_messages))

+           try:
                # 调用embedding_service的update_embeddings方法,开始进行向量初始化
+               stats = embedding_service.update_embeddings(progress_callback=progress_callback)

                # 在页面上显示向量生成成功的提示
+               st.success("向量生成完成!")

                # 显示统计信息标题
+               st.markdown("### 生成统计")
                # 将页面分为三列用于展示不同统计项
+               col1, col2, col3 = st.columns(3)

                # 第一列显示总节点数
+               with col1:
+                   st.metric("总节点数", stats["total_nodes"])
                # 第二列显示已处理成功的数量
+               with col2:
+                   st.metric("成功", stats["processed"])
                # 第三列显示处理失败的数量
+               with col3:
+                   st.metric("失败", stats["failed"])

                # 如果存在错误信息,则在可展开区域中展示所有错误
+               if stats["errors"]:
+                   with st.expander("错误信息", expanded=False):
+                       for error in stats["errors"]:
+                           st.error(error) 
            # 捕获异常并显示错误信息
+           except Exception as e:
                # 在页面上显示向量生成失败的错误信息
+               st.error(f"向量生成失败: {str(e)}")
                # 导入traceback模块用于获取详细出错信息
+               import traceback
                # 可展开区域显示完整Traceback
+               with st.expander("详细错误信息"):
+                   st.code(traceback.format_exc())



# 判断是否为主程序入口
if __name__ == "__main__":
    # 调用主函数
    main()    

6. 提问 #

6.1. 1_问答系统.py #

pages/1_问答系统.py


"""Streamlit图书知识图谱问答系统主应用"""
import streamlit as st
from config import config

def main():
    """主函数"""
    st.set_page_config(
        page_title="问答系统 - 图书知识图谱",
        layout="wide"
    )
    if "messages" not in st.session_state:
        st.session_state.messages = []
    with st.sidebar:
        st.markdown("### 参数设置")
    st.markdown(
        "<h2 style='text-align: center; color: blue;'>图书知识图谱查询系统</h1>",
        unsafe_allow_html=True,
    )    
    for message in st.session_state.messages:
        st.chat_message(message["role"]).write(message["content"])
    if query := st.chat_input("输入图书相关问题", key="query_input"):
        st.session_state.messages.append({"role": "user", "content": query})
        st.chat_message("user").write(query)




if __name__ == "__main__":
    main()

7. 向量化查询 #

7.1. queries.py #

database/queries.py

"""图数据库查询"""
# 导入类型提示
from typing import List, Dict, Any, Optional
# 导入py2neo的Node类型
from py2neo import Node
# 导入Neo4j连接对象
from database.connection import graph
# 导入常量:向量索引名
from constants import VECTOR_INDEX_BOOK, VECTOR_INDEX_AUTHOR

# 定义图书查询类
class BookQuery:
    """图书查询"""

    # 基于嵌入向量检索图书的方法
    def query_by_embedding(
        self, query_embedding: List[float], top_k: int = 3
    ) -> List[Dict[str, Any]]:
        """基于向量检索图书"""
        # 定义Cypher向量检索查询语句
        query = f"""
        CALL db.index.vector.queryNodes('{VECTOR_INDEX_BOOK}', $top_k, $query_embedding)
        YIELD node, score
        MATCH (node:Book)
        OPTIONAL MATCH (node)-[:written_by]->(author:Author)
        OPTIONAL MATCH (node)-[:published_by]->(publisher:Publisher)
        OPTIONAL MATCH (node)-[:has_category]->(category:Category)
        OPTIONAL MATCH (node)-[:has_keyword]->(keyword:Keyword)
        WITH node, score, author, publisher, category, COLLECT(DISTINCT keyword.name) AS keyword_names
        RETURN node, score, author, publisher, category, keyword_names
        ORDER BY score DESC
        """

        try:
            # 执行Cypher查询,传入top_k和query_embedding参数
            results = graph.run(query, top_k=top_k, query_embedding=query_embedding)
            # 初始化图书结果列表
            books = []
            # 遍历查询结果
            for record in results:
                # 获取图书节点
                node: Node = record["node"]
                # 获取相似度分数
                score = float(record["score"])
                # 获取作者节点
                author_node: Optional[Node] = record.get("author")
                # 获取出版社节点
                publisher_node: Optional[Node] = record.get("publisher")
                # 获取类别节点
                category_node: Optional[Node] = record.get("category")
                # 获取关键词名称列表
                keyword_names = record.get("keyword_names", [])

                # 如果关系中存在关键词则使用,否则优先节点属性keywords
                keywords = keyword_names if keyword_names else (node.get("keywords", []) or [])

                # 构建图书结果字典
                books.append({
                    "name": node.get("name", ""),
                    "similarity": score,
                    "作者": author_node.get("name") if author_node else None,
                    "出版社": publisher_node.get("name") if publisher_node else None,
                    "类别": category_node.get("name") if category_node else None,
                    "出版年份": node.get("publish_year"),
                    "简介": node.get("summary"),
                    "关键词": keywords,
                })
            # 返回图书结果列表
            return books
        except Exception as e:
            # 查询出错时抛出异常并提示
            raise Exception(f"图书查询失败: {e}") from e


# 定义作者查询类
class AuthorQuery:
    """作者查询"""

    # 基于嵌入向量检索作者的方法
    def query_by_embedding(
        self, query_embedding: List[float], top_k: int = 3
    ) -> List[Dict[str, Any]]:
        """基于向量检索作者"""
        # 定义Cypher向量检索查询语句
        query = f"""
        CALL db.index.vector.queryNodes('{VECTOR_INDEX_AUTHOR}', $top_k, $query_embedding)
        YIELD node, score
        MATCH (node:Author)
        OPTIONAL MATCH (node)<-[:written_by]-(book:Book)
        RETURN node, score, COLLECT(DISTINCT book.name) AS book_names
        ORDER BY score DESC
        """

        try:
            # 执行Cypher查询,传入top_k和query_embedding参数
            results = graph.run(query, top_k=top_k, query_embedding=query_embedding)
            # 初始化作者结果列表
            authors = []
            # 遍历查询结果
            for record in results:
                # 获取作者节点
                node: Node = record["node"]
                # 获取相似度分数
                score = float(record["score"])
                # 获取相关图书名称列表
                book_names = record.get("book_names", [])

                # 构建作者结果字典
                authors.append({
                    "name": node.get("name", ""),
                    "similarity": score,
                    "相关图书": book_names,
                })
            # 返回作者结果列表
            return authors
        except Exception as e:
            # 查询出错时抛出异常并提示
            raise Exception(f"作者查询失败: {e}") from e

7.2. retrieval_service.py #

services/retrieval_service.py

"""检索服务"""
# 导入类型提示
from typing import List, Dict, Any
# 导入图书和作者的查询类
from database.queries import BookQuery, AuthorQuery
# 导入嵌入服务(虽然此文件中未直接用到)
from services.embedding_service import embedding_service
# 导入用于区分查询类型和默认top_k值的常量
from constants import QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR, DEFAULT_TOP_K


class RetrievalService:
    """检索服务"""

    # 初始化方法,分别实例化图书和作者查询对象
    def __init__(self):
        self.book_query = BookQuery()
        self.author_query = AuthorQuery()

    # 基于嵌入向量查询图书或作者的主方法
    def query_by_embedding(
        self,
        query_embedding: List[float],   # 问题的嵌入向量
        query_type: str,                # 查询类型(图书或作者)
        top_k: int = DEFAULT_TOP_K,     # 返回结果数量,默认值为常量
    ) -> List[Dict[str, Any]]:
        """使用嵌入向量查询"""
        # 若查询类型为“图书”,则使用book_query执行查询
        if query_type == QUERY_TYPE_BOOK:
            return self.book_query.query_by_embedding(query_embedding, top_k)
        # 若查询类型为“作者”,则使用author_query执行查询
        elif query_type == QUERY_TYPE_AUTHOR:
            return self.author_query.query_by_embedding(query_embedding, top_k)
        # 其他类型(不支持的类型)抛出异常
        else:
            raise ValueError(f"不支持的查询类型: {query_type}")


# 创建RetrievalService的单例,便于外部调用
retrieval_service = RetrievalService()

7.3. constants.py #

constants.py

# 定义一个列表,包含CSV文件导入时必需的列名
REQUIRED_CSV_COLUMNS = [
    # 书名
    "name", 
    # 作者
    "author", 
    # 出版社
    "publisher", 
    # 类别
    "category",
    # 出版年份
    "publish_year", 
    # 简介
    "summary", 
    # 关键词(用分号分隔)
    "keywords"
]

# 定义图书节点标签
NODE_LABEL_BOOK = "Book"
# 定义作者节点标签
NODE_LABEL_AUTHOR = "Author"
# 定义出版社节点标签
NODE_LABEL_PUBLISHER = "Publisher"
# 定义类别节点标签
NODE_LABEL_CATEGORY = "Category"
# 定义关键词节点标签
NODE_LABEL_KEYWORD = "Keyword"

# 定义“书籍-作者”关系类型
RELATIONSHIP_TYPE_WRITTEN_BY = "written_by"
# 定义“书籍-出版社”关系类型
RELATIONSHIP_TYPE_PUBLISHED_BY = "published_by"
# 定义“书籍-类别”关系类型
RELATIONSHIP_TYPE_HAS_CATEGORY = "has_category"
# 定义“书籍-关键词”关系类型
RELATIONSHIP_TYPE_HAS_KEYWORD = "has_keyword"

# 查询类型
# 图书查询类型
+QUERY_TYPE_BOOK = "图书"
# 作者查询类型
+QUERY_TYPE_AUTHOR = "作者"
# 默认Top K值
+DEFAULT_TOP_K = 3

# 向量索引名称
# 图书向量索引名称
+VECTOR_INDEX_BOOK = "book_embeddings"
# 作者向量索引名称
+VECTOR_INDEX_AUTHOR = "author_embeddings"

7.4. 1_问答系统.py #

pages/1_问答系统.py

# Streamlit图书知识图谱问答系统主应用
"""Streamlit图书知识图谱问答系统主应用"""

# 导入Streamlit库
import streamlit as st
# 导入配置文件
from config import config
# 导入嵌入向量服务
+from services.embedding_service import embedding_service
# 导入检索服务
+from services.retrieval_service import retrieval_service
# 导入查询类型常量
+from constants import QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR

# 定义主函数
def main():
    """主函数"""
    # 设置页面配置(标题、布局)
    st.set_page_config(
        page_title="问答系统 - 图书知识图谱",
        layout="wide"
    )
    # 如果session_state中还没有'messages',则初始化为空列表
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # 侧边栏参数设置
    with st.sidebar:
        # 显示参数设置的标题
        st.markdown("### 参数设置")
        # 单选框选择查询类型(默认选中第一个,即图书)
+       query_type = st.radio("选择查询类型", [QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR], index=0)
        # 滑块选择返回结果数量(Top K),范围1-10,默认3
+       top_k = st.slider("返回结果数量 (Top K)", 1, 10, 3, 1)
    # 在主界面居中显示系统标题
    st.markdown(
        "<h2 style='text-align: center; color: blue;'>图书知识图谱查询系统</h1>",
        unsafe_allow_html=True,
    )    
    # 按顺序显示历史消息
    for message in st.session_state.messages:
        st.chat_message(message["role"]).write(message["content"])
    # 获取用户输入的问题(如果有输入)
    if query := st.chat_input("输入图书相关问题", key="query_input"):
        # 将用户输入添加到消息历史记录
        st.session_state.messages.append({"role": "user", "content": query})
        # 显示用户输入的消息
        st.chat_message("user").write(query)
        # 显示查询中loading动画
+       with st.spinner("正在查询中..."):
+           try:
                # 获取问题的嵌入向量
+               query_embedding = embedding_service.get_embedding(query)
                # 使用嵌入向量进行检索,获取结果
+               results = retrieval_service.query_by_embedding(query_embedding, query_type, top_k)
                # 显示“查询结果”标题
+               st.markdown("#### 查询结果")
                # 如果有返回结果,逐条显示
+               if results:
+                   for idx, item in enumerate(results, 1):
                        # 用可展开板块显示每一条结果,默认展开第一个
+                       with st.expander(f"结果 {idx}", expanded=True if idx == 1 else False):
                            # 按字典项逐个展示
+                           for key, value in item.items():
+                               st.write(f"**{key}**: {value}")
+               else:
                    # 若无结果,显示提示
+                   st.info("未找到相关结果。")
+           except Exception as err:
                # 捕获异常时的处理
+               error_msg = f"查询过程中出错: {str(err)}"
+               st.session_state.messages.append({
+                   "role": "assistant",
+                   "content": error_msg
+               })
+               st.chat_message("assistant").write(error_msg)
+               st.error(error_msg)  


# 判断是否作为主程序运行
if __name__ == "__main__":
    # 调用主函数启动应用
    main()

8. 提问 #

8.1. init.py #

llm/init.py

# 从当前包导入BaseLLM基类
from .base import BaseLLM
# 从当前包导入DeepSeekLLM类
from .deepseek import DeepSeekLLM
# 从当前包导入VolcengineLLM类
from .volcengine import VolcengineLLM
# 定义__all__变量,指定包导出的公有接口
__all__ = ["BaseLLM", "DeepSeekLLM", "VolcengineLLM"]

8.2. base.py #

llm/base.py

# LLM抽象基类的文档字符串
"""LLM抽象基类"""

# 导入ABC和abstractmethod,用于创建抽象基类
from abc import ABC, abstractmethod


# 定义LLM的抽象基类,继承自ABC
class BaseLLM(ABC):
    # 类的文档字符串,说明是LLM抽象基类
    """LLM抽象基类"""

    # 抽象方法generate,子类必须实现
    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> str:
        # 生成回复的文档字符串
        """生成回复"""
        # 抽象方法内使用pass,占位
        pass

8.3. deepseek.py #

llm/deepseek.py

"""DeepSeek LLM实现"""

# 导入可选类型提示
from typing import Optional
# 导入 DeepSeek 的 ChatDeepSeek 类
from langchain_deepseek import ChatDeepSeek
# 导入自定义的基础LLM类
from .base import BaseLLM
# 导入全局配置对象
from config import config
# 导入默认温度和最大token常量
from constants import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS

# 定义 DeepSeekLLM 类,继承自 BaseLLM
class DeepSeekLLM(BaseLLM):
    """DeepSeek LLM"""

    # 构造函数,支持自定义模型名、api_key、温度和最大输出tokens数量
    def __init__(
        self,
        model_name: Optional[str] = None,
        api_key: Optional[str] = None,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_MAX_TOKENS,
    ):
        # 设置模型名:优先用传入参数,否则用配置里的默认
        self.model_name = model_name or config.deepseek.model
        # 设置api_key:优先用传入参数,否则用配置里的默认
        self.api_key = api_key or config.deepseek.api_key
        # 设置温度:优先用传入参数,否则用默认
        self.temperature = temperature or DEFAULT_TEMPERATURE
        # 设置最大返回token数:优先用传入参数,否则用默认
        self.max_tokens = max_tokens or DEFAULT_MAX_TOKENS
        # 初始化 DeepSeek 的 LLM 实例
        self.llm = ChatDeepSeek(
            api_key=self.api_key,
            model=self.model_name,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )

    # 生成方法,实现对LLM的调用
    def generate(self, prompt: str) -> str:
        """生成回复"""
        try:
            # 调用 DeepSeek LLM 获取生成内容
            response = self.llm.invoke(prompt)
            # 检查是否有返回内容
            if not response or not response.content:
                # 返回内容为空时抛出异常
                raise Exception("API返回空响应")
            # 正常时返回生成内容
            return response.content
        except Exception as e:
            # 捕获并重新抛出异常,加上友好的报错说明
            raise Exception(f"DeepSeek API调用失败: {e}") from e

8.4. volcengine.py #

llm/volcengine.py

"""火山引擎LLM实现"""

# 导入Optional类型用于参数类型标注
from typing import Optional
# 导入OpenAI库用于与火山引擎API交互
from openai import OpenAI
# 导入自定义的BaseLLM基类
from .base import BaseLLM
# 导入全局配置对象
from config import config
# 导入默认温度和最大tokens常量
from constants import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS


# 定义VolcengineLLM类,继承自BaseLLM
class VolcengineLLM(BaseLLM):
    """火山引擎LLM"""

    # 构造函数,初始化各项参数
    def __init__(self, model_name: Optional[str] = None, api_key: Optional[str] = None,
                 temperature: float = DEFAULT_TEMPERATURE, max_tokens: int = DEFAULT_MAX_TOKENS):
        # 获取API基础URL
        self.base_url = config.volcengine.base_url
        # 获取模型名称,优先使用传入值,否则取配置中的默认值
        self.model_name = model_name or config.volcengine.model
        # 获取API密钥,优先使用传入值,否则取配置中的默认值
        self.api_key = api_key or config.volcengine.api_key
        # 设置温度参数,控制生成内容的多样性
        self.temperature = temperature or DEFAULT_TEMPERATURE
        # 设置最大token数量
        self.max_tokens = max_tokens or DEFAULT_MAX_TOKENS
        # 初始化OpenAI客户端,用于与火山引擎API交互
        self.client = OpenAI(
            base_url=self.base_url,
            api_key=self.api_key,
        )

    # 定义generate方法用于生成模型回复
    def generate(self, prompt: str) -> str:
        """生成回复"""
        try:
            # 向火山引擎API发送请求,生成回复
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
            # 判断响应内容是否有效
            if not response.choices or not response.choices[0].message.content:
                # 如果返回内容为空,抛出异常
                raise Exception("API返回空响应")
            # 返回生成的回复内容
            return response.choices[0].message.content
        except Exception as e:
            # 捕获异常并抛出自定义异常信息
            raise Exception(f"Volcengine API调用失败: {e}") from e

8.5. init.py #

utils/init.py

"""工具函数模块"""

8.6. formatters.py #

utils/formatters.py

"""格式化工具"""
# 导入类型提示List, Dict, Any
from typing import List, Dict, Any

# 定义用于格式化查询结果上下文的函数
def format_context(query_type: str, results: List[Dict[str, Any]]) -> str:
    """格式化查询结果上下文"""

    # 定义内部函数:格式化一本图书的方法
    def format_book(result: Dict[str, Any]) -> str:
        # 准备字段及相应的值,元组形式依次为(字段名, 字段值)
        fields = [
            ("作者", result.get("作者")),
            ("出版社", result.get("出版社")),
            ("类别", result.get("类别")),
            ("出版年份", result.get("出版年份")),
            ("简介", result.get("简介")),
            # 如果关键词存在,拼接成字符串,否则为None
            ("关键词", ", ".join(result["关键词"]) if result.get("关键词") else None),
        ]
        # 对所有有值的字段拼接为多行字符串,格式为“- 字段名: 字段值”
        return "\n".join([f"   - {k}: {v}" for k, v in fields if v])

    # 定义内部函数:格式化作者(只显示相关图书)
    def format_author(result: Dict[str, Any]) -> str:
        # 如果结果中有“相关图书”字段,则拼接显示
        if result.get("相关图书"):
            return f"   - 相关图书: {', '.join(result['相关图书'])}"
        # 否则返回空串
        return ""

    # 用于存放所有格式化后的每条结果内容
    lines = []
    # 遍历全部结果数据
    for idx, result in enumerate(results, 1):
        # 构建当前项的标题,包括序号、名称及相似度
        header = f"{idx}. {result['name']} (相似度: {result['similarity']:.4f})"
        # 根据查询类型决定调用哪个格式化方法
        details = format_book(result) if query_type == "图书" else format_author(result)
        # 若有详情,拼接标题、详情及换行;否则仅标题加换行
        info = f"{header}\n{details}\n" if details else f"{header}\n"
        # 将本次内容添加到lines列表
        lines.append(info)

    # 用两个换行符分隔拼接所有结果,返回最终字符串
    return "\n\n".join(lines)

8.7. .env #

.env

NEO4J_URI="bolt://localhost:7687"
NEO4J_USER="neo4j"
NEO4J_PASSWORD="12345678"

VOLC_EMBEDDINGS_API_URL=https://ark.cn-beijing.volces.com/api/v3/embeddings
VOLC_EMBEDDING_MODEL=doubao-embedding-text-240715
VOLC_API_KEY=d52e49a1-36ea-44bb-bc6e-65ce789a72f6


+VOLCENGINE_BASE_URL=https://ark.cn-beijing.volces.com/api/v3
+VOLCENGINE_MODEL=doubao-seed-1-6-250615
+VOLCENGINE_API_KEY=d52e49a1-36ea-44bb-bc6e-65ce789a72f6

+DeepSeek_BASE_URL="https://api.deepseek.com/v1"
+DeepSeek_API_KEY="sk-278496d471bc4f4cb0ccb8c389a15018"
+DeepSeek_MODEL="deepseek-chat"

8.8. config.py #

config.py

"""配置管理模块"""
# 导入os模块,用于获取环境变量
import os
# 导入dotenv模块,用于加载.env文件中的环境变量
import dotenv
# 从dataclasses模块导入dataclass,用于简化数据类定义
from dataclasses import dataclass
+from typing import Optional

# 加载.env文件中的所有环境变量
dotenv.load_dotenv()

# 定义Neo4j数据库配置的数据类
@dataclass
class Neo4jConfig:
    """Neo4j数据库配置"""
    # 数据库URI
    uri: str
    # 用户名
    user: str
    # 密码
    password: str

@dataclass
class EmbeddingConfig:
    """火山方舟嵌入API配置"""
    api_url: str
    api_key: str
    model: str
+@dataclass
+class VolcengineConfig:
+   """火山引擎API配置"""
+   base_url: Optional[str]
+   api_key: Optional[str]
+   model: Optional[str]


+@dataclass
+class DeepSeekConfig:
+   """DeepSeek API配置"""
+   base_url: str
+   api_key: str
+   model: str
# 定义应用整体配置的数据类
@dataclass
class AppConfig:
    """应用配置"""
    # Neo4j数据库配置属性
    neo4j: Neo4jConfig
    # 火山方舟嵌入API配置属性
    embedding: EmbeddingConfig
    # 火山引擎API配置属性
+   volcengine: VolcengineConfig
    # DeepSeek API配置属性
+   deepseek: DeepSeekConfig

# 创建一个AppConfig对象,保存应用的整体配置
config = AppConfig(
    # Neo4j数据库配置,使用Neo4jConfig类初始化
    neo4j=Neo4jConfig(
        # 从环境变量获取数据库URI,如果没有则使用默认值"bolt://localhost:7687"
        uri=os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
        # 从环境变量获取数据库用户名,如果没有则使用默认值"neo4j"
        user=os.environ.get("NEO4J_USER", "neo4j"),
        # 从环境变量获取数据库密码,如果没有则使用默认值"12345678"
        password=os.environ.get("NEO4J_PASSWORD", "12345678"),
    ),
    # 嵌入API配置,使用EmbeddingConfig类初始化
    embedding=EmbeddingConfig(
        # 从环境变量获取嵌入API的URL,没有则使用默认值
        api_url=os.environ.get(
            "VOLC_EMBEDDINGS_API_URL",
            "https://ark.cn-beijing.volces.com/api/v3/embeddings"
        ),
        # 从环境变量获取API密钥,如果没有则为空字符串
        api_key=os.environ.get("VOLC_API_KEY", ""),
        # 从环境变量获取嵌入模型名,没有则使用默认值"doubao-embedding-text-240715"
        model=os.environ.get("VOLC_EMBEDDING_MODEL", "doubao-embedding-text-240715"),
    ),
   # 初始化火山引擎API配置,从环境变量获取相关参数
   volcengine=VolcengineConfig(
        # 从环境变量获取火山引擎API的基础URL
       base_url=os.environ.get("VOLCENGINE_BASE_URL","https://ark.cn-beijing.volces.com/api/v3"),
        # 从环境变量获取火山引擎API的密钥
       api_key=os.environ.get("VOLCENGINE_API_KEY","d52e49a1-36ea-44bb-bc6e-65ce789a72f6"),
        # 从环境变量获取火山引擎API的模型名称
       model=os.environ.get("VOLCENGINE_MODEL","doubao-seed-1-6-250615"),
   ),
    # 初始化DeepSeek API配置,支持从环境变量读取参数,并提供默认值
   deepseek=DeepSeekConfig(
        # 从环境变量获取DeepSeek基础URL,若未设置则使用默认值"https://api.deepseek.com"
       base_url=os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com"),
        # 从环境变量获取DeepSeek API密钥,默认值为空字符串
       api_key=os.environ.get("DEEPSEEK_API_KEY", "ae59c1b6731e4d5d8f1f3fd0ad340b39"),
        # 从环境变量获取DeepSeek模型名,若未设置则使用默认值"deepseek-chat"
       model=os.environ.get("DEEPSEEK_MODEL", "deepseek-chat"),
   ),
)

8.9. constants.py #

constants.py

# 定义一个列表,包含CSV文件导入时必需的列名
REQUIRED_CSV_COLUMNS = [
    # 书名
    "name", 
    # 作者
    "author", 
    # 出版社
    "publisher", 
    # 类别
    "category",
    # 出版年份
    "publish_year", 
    # 简介
    "summary", 
    # 关键词(用分号分隔)
    "keywords"
]

# 定义图书节点标签
NODE_LABEL_BOOK = "Book"
# 定义作者节点标签
NODE_LABEL_AUTHOR = "Author"
# 定义出版社节点标签
NODE_LABEL_PUBLISHER = "Publisher"
# 定义类别节点标签
NODE_LABEL_CATEGORY = "Category"
# 定义关键词节点标签
NODE_LABEL_KEYWORD = "Keyword"

# 定义“书籍-作者”关系类型
RELATIONSHIP_TYPE_WRITTEN_BY = "written_by"
# 定义“书籍-出版社”关系类型
RELATIONSHIP_TYPE_PUBLISHED_BY = "published_by"
# 定义“书籍-类别”关系类型
RELATIONSHIP_TYPE_HAS_CATEGORY = "has_category"
# 定义“书籍-关键词”关系类型
RELATIONSHIP_TYPE_HAS_KEYWORD = "has_keyword"

# 查询类型
# 图书查询类型
QUERY_TYPE_BOOK = "图书"
# 作者查询类型
QUERY_TYPE_AUTHOR = "作者"
# 默认Top K值
DEFAULT_TOP_K = 3
# 默认最大Tokens值
+DEFAULT_MAX_TOKENS = 4096
# 默认温度值
+DEFAULT_TEMPERATURE = 0.7

# 向量索引名称
# 图书向量索引名称
VECTOR_INDEX_BOOK = "book_embeddings"
# 作者向量索引名称
VECTOR_INDEX_AUTHOR = "author_embeddings"

8.10. 1_问答系统.py #

pages/1_问答系统.py

# Streamlit图书知识图谱问答系统主应用
"""Streamlit图书知识图谱问答系统主应用"""

# 导入Streamlit库
import streamlit as st
# 导入PromptTemplate
+from langchain_core.prompts import PromptTemplate
# 导入配置文件
from config import config
# 导入嵌入向量服务
from services.embedding_service import embedding_service
# 导入检索服务
from services.retrieval_service import retrieval_service
# 导入查询类型常量
from constants import QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR
# 导入大模型服务
+from llm import VolcengineLLM, DeepSeekLLM
+from utils.formatters import format_context
+prompt_template = PromptTemplate(
+   input_variables=["question", "context"],
+   template="""你是一名图书知识助手,需要根据提供的图书信息回答用户的提问。
+   请直接回答问题,如果信息不足,请回答"根据现有信息无法确定"。
+   问题:{question}
+   图书信息:\n{context}
+   回答:""",
+)
# 定义主函数
def main():
    """主函数"""
    # 设置页面配置(标题、布局)
    st.set_page_config(
        page_title="问答系统 - 图书知识图谱",
        layout="wide"
    )
    # 如果session_state中还没有'messages',则初始化为空列表
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # 侧边栏参数设置
    with st.sidebar:
        # 显示参数设置的标题
        st.markdown("### 参数设置")
        # 单选框选择查询类型(默认选中第一个,即图书)
        query_type = st.radio("选择查询类型", [QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR], index=0)
        # 滑块选择返回结果数量(Top K),范围1-10,默认3
        top_k = st.slider("返回结果数量 (Top K)", 1, 10, 3, 1)
        # 滑块选择温度值,范围0.0-1.0,默认0.3
+       temperature = st.slider("温度 (Temperature)", 0.0, 1.0, 0.3, 0.1)
        # 选择大模型服务商,默认火山引擎
+       llm_provider = st.selectbox("选择大模型服务商", ["volcengine", "deepseek"], index=0)
        # 如果选择火山引擎,则显示火山引擎API Key和模型名
+       if llm_provider == "volcengine":
+           api_key = st.text_input(
+               "火山引擎 API Key",
+               value=config.volcengine.api_key or "",
+               type="password",
+               help="如留空则使用服务器默认配置",
+           )
+           model_name = st.text_input(
+               "火山引擎模型名",
+               value=config.volcengine.model or "",
+               help="如留空则使用服务器默认配置"
+           )
+       else:
+           api_key = st.text_input(
+               "DeepSeek API Key",
+               value=config.deepseek.api_key or "",
+               type="password",
+               help="如留空则使用服务器默认配置",
+           )
+           model_name = st.text_input(
+               "DeepSeek模型名",
+               value=config.deepseek.model or "",
+               help="如留空则使用服务器默认配置"
+           )
    # 在主界面居中显示系统标题
    st.markdown(
        "<h2 style='text-align: center; color: blue;'>图书知识图谱查询系统</h1>",
        unsafe_allow_html=True,
    )    
    # 按顺序显示历史消息
    for message in st.session_state.messages:
        st.chat_message(message["role"]).write(message["content"])
+       if message["role"] == "assistant":  
+           # 展开可查看详细检索结果
+           with st.expander("查看详细结果"):
+               # 以 json 格式显示检索的类型和结果
+               st.json({"type": message['query_type'], "results": message['results']})
    # 获取用户输入的问题(如果有输入)
    if query := st.chat_input("输入图书相关问题", key="query_input"):
        # 将用户输入添加到消息历史记录
        st.session_state.messages.append({"role": "user", "content": query})
        # 显示用户输入的消息
        st.chat_message("user").write(query)
        # 显示查询中的 loading 动画
        with st.spinner("正在查询中..."):
            try:
                # 获取用户输入 query 的嵌入向量
                query_embedding = embedding_service.get_embedding(query)
                # 用嵌入向量进行检索,查询对应的结果
                results = retrieval_service.query_by_embedding(query_embedding, query_type, top_k)
                # 如果没有检索到结果
+               if not results or len(results) == 0:
                    # 设置回复内容为没有找到相关信息
+                   answer = "抱歉,没有找到相关的信息。"
                else:
                    # 如果选择的服务商是火山引擎
+                   if llm_provider == "volcengine":
                        # 实例化火山引擎大模型
+                       llm = VolcengineLLM(
+                           model_name=model_name if model_name else None,
+                           api_key=api_key if api_key else None,
+                       )
+                   else:
                        # 否则实例化 DeepSeek 大模型
+                       llm = DeepSeekLLM(
+                           model_name=model_name if model_name else None,
+                           api_key=api_key if api_key else None,
+                       )
                    # 根据检索结果格式化上下文字符串
+                   context_str = format_context(query_type, results)
                    # 打印上下文字符串,便于调试
+                   print("context_str: ", context_str)
                    # 用模板格式化最终 prompt
+                   final_prompt = prompt_template.format(
+                       question=query,
+                       context=context_str
+                   )
                    # 打印最终 prompt,便于调试
+                   print("final_prompt: ", final_prompt)
                    # 调用大模型生成答案
+                   answer = llm.generate(final_prompt)

                    # 展开可查看详细检索结果
+                   with st.expander("查看详细结果"):
                        # 以 json 格式显示检索的类型和结果
+                       st.json({"type": query_type, "results": results})
                    # 把大模型回复添加到会话历史
+                   st.session_state.messages.append({
+                       "role": "assistant",
+                       "content": answer,
+                       "query_type": query_type,
+                       "results": results          # 检索结果
+                   })
                    # 用 chat_message 组件显示大模型回复
+                   st.chat_message("assistant").write(answer)
                    # 重新运行页面,用于强制刷新
+                   st.rerun()
            except Exception as err:
                # 打印异常信息,便于调试
+               print(err)
                # 捕获异常时组织错误提示内容
                error_msg = f"查询过程中出错: {str(err)}"
                # 将错误信息添加到会话历史
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": error_msg
+                   "query_type": query_type,   # 查询类型
+                   "results": results          # 检索结果
                })
                # 显示错误信息在对话界面
                st.chat_message("assistant").write(error_msg)
                # 在页面上弹出错误提示
                st.error(error_msg)  


# 判断是否作为主程序运行
if __name__ == "__main__":
    # 调用主函数启动应用
    main()

9. 历史记录 #

9.1. 1_问答系统.py #

pages/1_问答系统.py

# Streamlit图书知识图谱问答系统主应用
"""Streamlit图书知识图谱问答系统主应用"""

# 导入Streamlit库
import streamlit as st
# 导入PromptTemplate
from langchain_core.prompts import PromptTemplate
# 导入配置文件
from config import config
# 导入嵌入向量服务
from services.embedding_service import embedding_service
# 导入检索服务
from services.retrieval_service import retrieval_service
# 导入查询类型常量
from constants import QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR
# 导入大模型服务
from llm import VolcengineLLM, DeepSeekLLM
from utils.formatters import format_context
prompt_template = PromptTemplate(
    input_variables=["question", "context"],
    template="""你是一名图书知识助手,需要根据提供的图书信息回答用户的提问。
    请直接回答问题,如果信息不足,请回答"根据现有信息无法确定"。
    问题:{question}
    图书信息:\n{context}
    回答:""",
)
# 定义主函数
def main():
    """主函数"""
    # 设置页面配置(标题、布局)
    st.set_page_config(
        page_title="问答系统 - 图书知识图谱",
        layout="wide"
    )
    # 如果session_state中还没有'messages',则初始化为空列表
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # 如果session_state中还没有'history',则初始化为空列表
+   if "history" not in st.session_state:
+       st.session_state.history = []    
    # 侧边栏参数设置
    with st.sidebar:
        # 显示参数设置的标题
        st.markdown("### 参数设置")
        # 单选框选择查询类型(默认选中第一个,即图书)
        query_type = st.radio("选择查询类型", [QUERY_TYPE_BOOK, QUERY_TYPE_AUTHOR], index=0)
        # 滑块选择返回结果数量(Top K),范围1-10,默认3
        top_k = st.slider("返回结果数量 (Top K)", 1, 10, 3, 1)
        # 滑块选择温度值,范围0.0-1.0,默认0.3
        temperature = st.slider("温度 (Temperature)", 0.0, 1.0, 0.3, 0.1)
        # 选择大模型服务商,默认火山引擎
        llm_provider = st.selectbox("选择大模型服务商", ["volcengine", "deepseek"], index=0)
        # 如果选择火山引擎,则显示火山引擎API Key和模型名
        if llm_provider == "volcengine":
            api_key = st.text_input(
                "火山引擎 API Key",
                value=config.volcengine.api_key or "",
                type="password",
                help="如留空则使用服务器默认配置",
            )
            model_name = st.text_input(
                "火山引擎模型名",
                value=config.volcengine.model or "",
                help="如留空则使用服务器默认配置"
            )
        else:
            api_key = st.text_input(
                "DeepSeek API Key",
                value=config.deepseek.api_key or "",
                type="password",
                help="如留空则使用服务器默认配置",
            )
            model_name = st.text_input(
                "DeepSeek模型名",
                value=config.deepseek.model or "",
                help="如留空则使用服务器默认配置"
            )
+       st.markdown("### 历史查询")
+       if st.session_state.history:
+           for i, item in enumerate(st.session_state.history):
+               with st.expander(f"查询 {i+1}: {item['question']}"):
+                   st.json(item)
+       else:
+           st.info("暂无历史查询记录")    
    # 在主界面居中显示系统标题
    st.markdown(
        "<h2 style='text-align: center; color: blue;'>图书知识图谱查询系统</h1>",
        unsafe_allow_html=True,
    )    
    # 按顺序显示历史消息
    for message in st.session_state.messages:
        st.chat_message(message["role"]).write(message["content"])
    # 获取用户输入的问题(如果有输入)
    if query := st.chat_input("输入图书相关问题", key="query_input"):
        # 将用户输入添加到消息历史记录
        st.session_state.messages.append({"role": "user", "content": query})
        # 显示用户输入的消息
        st.chat_message("user").write(query)
        # 显示查询中的 loading 动画
        with st.spinner("正在查询中..."):
            try:
                # 获取用户输入 query 的嵌入向量
                query_embedding = embedding_service.get_embedding(query)
                # 用嵌入向量进行检索,查询对应的结果
                results = retrieval_service.query_by_embedding(query_embedding, query_type, top_k)
                # 如果没有检索到结果
                if not results or len(results) == 0:
                    # 设置回复内容为没有找到相关信息
                    answer = "抱歉,没有找到相关的信息。"
                else:
                    # 如果选择的服务商是火山引擎
                    if llm_provider == "volcengine":
                        # 实例化火山引擎大模型
                        llm = VolcengineLLM(
                            model_name=model_name if model_name else None,
                            api_key=api_key if api_key else None,
                        )
                    else:
                        # 否则实例化 DeepSeek 大模型
                        llm = DeepSeekLLM(
                            model_name=model_name if model_name else None,
                            api_key=api_key if api_key else None,
                        )
                    # 根据检索结果格式化上下文字符串
                    context_str = format_context(query_type, results)
                    # 打印上下文字符串,便于调试
                    print("context_str: ", context_str)
                    # 用模板格式化最终 prompt
                    final_prompt = prompt_template.format(
                        question=query,
                        context=context_str
                    )
                    # 打印最终 prompt,便于调试
                    print("final_prompt: ", final_prompt)
                    # 调用大模型生成答案
                    answer = llm.generate(final_prompt)
                    # 将本次对话历史添加到 session_state.history,也包含了用户问题、检索类型、上下文、模型回复及温度参数
+                   st.session_state.history.append({
+                       "question": query,          # 用户输入的问题
+                       "query_type": query_type,   # 查询类型(如“图书”/“作者”)
+                       "context": context_str,     # 格式化后的上下文字符串
+                       "answer": answer,           # 大模型生成的答案
+                       "temperature": temperature, # 当前使用的温度参数
+                   })
                    # 展开可查看详细检索结果
                    with st.expander("查看详细结果"):
                        # 以 json 格式显示检索的类型和结果
                        st.json({"type": query_type, "results": results})
                    # 把大模型回复添加到会话历史
                    st.session_state.messages.append({
                        "role": "assistant",
                        "content": answer,
                        "query_type": query_type,
                        "results": results          # 检索结果

                    })
                    # 用 chat_message 组件显示大模型回复
                    st.chat_message("assistant").write(answer)
                    # 重新运行页面,用于强制刷新
                    st.rerun()
            except Exception as err:
                # 打印异常信息,便于调试
                print(err)
                # 捕获异常时组织错误提示内容
                error_msg = f"查询过程中出错: {str(err)}"
                # 将错误信息添加到会话历史
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": error_msg
                })
                # 显示错误信息在对话界面
                st.chat_message("assistant").write(error_msg)
                # 在页面上弹出错误提示
                st.error(error_msg)  


# 判断是否作为主程序运行
if __name__ == "__main__":
    # 调用主函数启动应用
    main()

10. 项目整体架构 #

本项目是一个基于 Streamlit 多页面应用 的图书知识图谱问答系统,采用分层架构:

用户浏览器
    ↓
┌─────────────────────────────────────┐
│  Streamlit UI 层 (pages/)           │
│  ├─ 首页.py          (欢迎页)       │
│  ├─ 0_数据导入.py    (导入+向量化)   │
│  └─ 1_问答系统.py    (RAG问答)      │
├─────────────────────────────────────┤
│  服务层 (services/)                  │
│  ├─ ImportService        (CSV导入)   │
│  ├─ EmbeddingInitService (向量初始化) │
│  ├─ EmbeddingService     (嵌入API)   │
│  └─ RetrievalService     (检索编排)  │
├─────────────────────────────────────┤
│  LLM层 (llm/)                       │
│  ├─ BaseLLM              (抽象基类)  │
│  ├─ VolcengineLLM        (火山引擎)  │
│  └─ DeepSeekLLM          (DeepSeek)  │
├─────────────────────────────────────┤
│  数据库层 (database/)                │
│  ├─ GraphManager          (连接管理) │
│  ├─ BookQuery             (图书查询) │
│  └─ AuthorQuery           (作者查询) │
├─────────────────────────────────────┤
│  基础设施                            │
│  ├─ config.py     (配置管理)         │
│  ├─ constants.py  (常量定义)         │
│  └─ exceptions.py (异常定义)         │
└─────────────────────────────────────┘
    ↓                    ↓
  Neo4j             火山引擎API
  图数据库          (Embedding/LLM)

11 类图 #

classDiagram direction TB %% ========== 配置层 ========== class AppConfig { +neo4j: Neo4jConfig +volcengine: VolcengineConfig +deepseek: DeepSeekConfig +embedding: EmbeddingConfig } class Neo4jConfig { +uri: str +user: str +password: str } class VolcengineConfig { +base_url: Optional~str~ +api_key: Optional~str~ +model: Optional~str~ } class DeepSeekConfig { +base_url: str +api_key: str +model: str } class EmbeddingConfig { +api_url: str +api_key: str +model: str } AppConfig *-- Neo4jConfig AppConfig *-- VolcengineConfig AppConfig *-- DeepSeekConfig AppConfig *-- EmbeddingConfig %% ========== 数据库层 ========== class GraphManager { -_instance: GraphManager$ -_graph: Graph$ +graph: Graph +__new__() GraphManager +__init__() } class BookQuery { -graph: Graph +query_by_embedding(embedding, top_k) List~Dict~ } class AuthorQuery { -graph: Graph +query_by_embedding(embedding, top_k) List~Dict~ } GraphManager ..> AppConfig : 读取配置 BookQuery --> GraphManager : get_graph() AuthorQuery --> GraphManager : get_graph() %% ========== LLM层 ========== class BaseLLM { <<abstract>> +generate(prompt, **kwargs)* str } class VolcengineLLM { -model_name: str -base_url: str -api_key: str -client: OpenAI +generate(prompt, temperature) str } class DeepSeekLLM { -model_name: str -llm: ChatDeepSeek +generate(prompt, temperature) str } BaseLLM <|-- VolcengineLLM BaseLLM <|-- DeepSeekLLM VolcengineLLM ..> AppConfig : 读取配置 DeepSeekLLM ..> AppConfig : 读取配置 %% ========== 服务层 ========== class EmbeddingService { -api_url: str -api_key: str -model: str -max_retries: int +get_embedding(text) List~float~ } class RetrievalService { -book_query: BookQuery -author_query: AuthorQuery -embedding_service: EmbeddingService +query_by_embedding(embedding, type, top_k) List~Dict~ } class ImportService { -graph: Graph -node_matcher: NodeMatcher +parse_csv(content) DataFrame +validate_csv(df) tuple +clear_database() +import_books(df, clear, callback) Dict -_create_nodes(...) int -_create_relationships(...) int } class EmbeddingInitService { -graph: Graph -node_matcher: NodeMatcher -embedding_service: EmbeddingService +init_embeddings(labels, callback) dict } EmbeddingService ..> AppConfig : 读取配置 RetrievalService --> BookQuery RetrievalService --> AuthorQuery RetrievalService --> EmbeddingService ImportService --> GraphManager : get_graph() EmbeddingInitService --> GraphManager : get_graph() EmbeddingInitService --> EmbeddingService %% ========== 异常 ========== class ConfigurationError class DatabaseError class EmbeddingAPIError class LLMAPIError Exception <|-- ConfigurationError Exception <|-- DatabaseError Exception <|-- EmbeddingAPIError Exception <|-- LLMAPIError

12 页面执行过程 #

12.1. 首页 #

"""Streamlit应用主入口"""
import streamlit as st

st.set_page_config(
    page_title="图书知识图谱系统",
    layout="wide",
    initial_sidebar_state="expanded"
)

st.title("图书知识图谱系统")
st.markdown("---")
st.markdown("""
欢迎使用图书知识图谱系统!

**功能模块**

1. **问答系统** - 基于向量检索的智能问答
2. **数据导入** - 导入CSV数据并生成向量

请使用左侧导航栏访问各个功能模块。
""")

执行过程很简单:

  1. Streamlit 启动时加载 首页.py 作为主入口
  2. set_page_config() 设置浏览器标题和布局
  3. 渲染标题和功能说明的 Markdown 文本
  4. Streamlit 自动扫描 pages/ 目录,在左侧栏生成导航链接(0_数据导入、1_问答系统)

12.2 数据导入页 #

这是最复杂的页面,包含两个功能标签页:CSV导入 和 向量初始化。

12.2.1 时序图:CSV 数据导入流程 #

sequenceDiagram actor User as 用户 participant UI as 0_数据导入.py participant IS as ImportService participant GM as GraphManager participant Neo4j as Neo4j数据库 User->>UI: 上传 CSV 文件 UI->>IS: parse_csv(csv_content) IS-->>UI: 返回 DataFrame UI->>IS: validate_csv(df) IS-->>UI: (True, None) 或 (False, 错误信息) alt 验证失败 UI-->>User: 显示错误 ❌ end UI-->>User: 显示数据预览表格 User->>UI: 点击「开始导入」 UI->>IS: import_books(df, clear_existing, callback) alt 清空现有数据 IS->>GM: get_graph() GM-->>IS: Graph 实例 IS->>Neo4j: MATCH (n) DETACH DELETE n end Note over IS: 提取唯一值:books, authors,<br/>publishers, categories, keywords Note over IS: 遍历 DataFrame<br/>准备关系边列表 loop 创建 Book 节点 IS->>Neo4j: graph.create(Node("Book", ...)) IS-->>UI: callback("创建Book: 三体", 1, N) UI-->>User: 更新进度条 end loop 创建 Author/Publisher/Category/Keyword 节点 IS->>Neo4j: graph.create(Node(label, name=...)) IS-->>UI: callback(...) end loop 创建关系(written_by, published_by, ...) IS->>Neo4j: node_matcher.match("Book", name=...) Neo4j-->>IS: 起始节点 IS->>Neo4j: node_matcher.match("Author", name=...) Neo4j-->>IS: 目标节点 IS->>Neo4j: graph.create(Relationship(start, "written_by", end)) IS-->>UI: callback(...) end IS-->>UI: 返回 stats 统计 UI-->>User: 显示导入统计(图书/作者/关系数等)

关键执行步骤详解:

步骤 代码位置 说明
① 文件上传 0_数据导入.py:37-41 st.file_uploader 接收 CSV 文件
② 解析CSV ImportService.parse_csv() 用 pd.read_csv 解析为 DataFrame
③ 验证格式 ImportService.validate_csv() 检查是否包含7个必需列
④ 预览 0_数据导入.py:58-59 st.dataframe 显示前10行
⑤ 导入 ImportService.import_books() 核心逻辑:先提取唯一值,再批量创建节点和关系
⑥ 节点创建 _create_nodes() 用 py2neo.Node + graph.create() 创建
⑦ 关系创建 _create_relationships() 用 NodeMatcher.match() 查找节点,Relationship() 创建关系

12.2.2 向量初始化流程 #

sequenceDiagram actor User as 用户 participant UI as 0_数据导入.py (Tab2) participant EIS as EmbeddingInitService participant ES as EmbeddingService participant GM as GraphManager participant Neo4j as Neo4j数据库 participant API as 火山方舟 Embedding API User->>UI: 选择节点类型 [Book, Author] User->>UI: 点击「开始生成向量」 UI->>EIS: init_embeddings(["Book","Author"], callback) loop 对每个标签(Book, Author) EIS->>Neo4j: node_matcher.match("Book") Neo4j-->>EIS: 返回所有 Book 节点列表 loop 对每个节点 EIS->>ES: get_embedding(node.name) ES->>API: POST /embeddings {model, input: "三体"} API-->>ES: {data: [{embedding: [0.1, 0.2, ...]}]} ES-->>EIS: 返回向量 List[float] EIS->>Neo4j: node["embedding"] = 向量 EIS->>Neo4j: graph.push(node) EIS-->>UI: callback("✓ 已为 三体 设置嵌入向量", n, total) UI-->>User: 更新进度条 end end EIS-->>UI: 返回 stats UI-->>User: 显示统计(总节点/成功/失败)

关键点:

  • EmbeddingService.get_embedding() 调用火山方舟的 REST API 获取文本的嵌入向量(2048维浮点数组)
  • 内置重试机制:服务器错误(5xx)或网络异常时最多重试3次
  • 向量通过 graph.push(node) 写回 Neo4j 节点的 embedding 属性

12.3. 问答页 #

这是系统的核心功能页面,实现了 RAG(检索增强生成) 流程。

12.3.1 问答流程 #

sequenceDiagram actor User as 用户 participant UI as 1_问答系统.py participant RS as RetrievalService participant ES as EmbeddingService participant BQ as BookQuery participant Neo4j as Neo4j数据库 participant FMT as format_context() participant LLM as VolcengineLLM / DeepSeekLLM participant API as LLM API Note over UI: 侧边栏:设置查询类型、Top K、<br/>温度、模型服务商、API Key User->>UI: 输入问题:"活着的作者是谁?" UI->>UI: session_state.messages.append(user msg) rect rgb(240, 248, 255) Note over UI,API: RAG 核心流程 %% Step 1: 生成问题的嵌入向量 UI->>ES: get_embedding("活着的作者是谁?") ES->>API: POST 火山方舟 /embeddings API-->>ES: 返回 query_embedding [0.12, 0.34, ...] ES-->>UI: query_embedding %% Step 2: 创建 LLM 实例 UI->>LLM: new VolcengineLLM(api_key, model_name) LLM->>LLM: 读取 AppConfig, 创建 OpenAI client %% Step 3: 向量检索 UI->>RS: query_by_embedding(embedding, "图书", top_k=3) RS->>BQ: query_by_embedding(embedding, 3) BQ->>Neo4j: CALL db.index.vector.queryNodes(...)<br/>+ OPTIONAL MATCH 作者/出版社/类别/关键词 Neo4j-->>BQ: 返回匹配的节点 + 相似度分数 BQ-->>RS: [{name:"活着", similarity:0.93, 作者:"余华", ...}, ...] RS-->>UI: results %% Step 4: 格式化上下文 UI->>FMT: format_context("图书", results) FMT-->>UI: "1. 活着 (相似度: 0.9342)\n - 作者: 余华\n..." %% Step 5: 构造 Prompt UI->>UI: prompt_template.format(question, context) Note over UI: "你是一名图书知识助手...<br/>问题:活着的作者是谁?<br/>图书信息:1. 活着...<br/>回答:" %% Step 6: 调用 LLM 生成回答 UI->>LLM: generate(final_prompt, temperature=0.3) LLM->>API: POST chat/completions {model, messages, temperature} API-->>LLM: "活着的作者是余华。" LLM-->>UI: answer end UI->>UI: session_state.messages.append(assistant msg) UI->>UI: session_state.history.append(记录) UI-->>User: 显示回答 + 可展开的详细结果 UI->>UI: st.rerun() 刷新页面

RAG 核心 6 步详解:

步骤 作用 代码位置
① 问题向量化 将用户问题转为嵌入向量 get_embedding(query) → EmbeddingService
② 创建 LLM 根据用户选择的服务商创建对应 LLM 实例 VolcengineLLM(...) 或 DeepSeekLLM(...)
③ 向量检索 用问题向量在 Neo4j 中搜索最相似的图书/作者 RetrievalService → BookQuery/AuthorQuery
④ 格式化上下文 将检索结果格式化为 LLM 可读的文本 format_context()
⑤ 构造 Prompt 用 LangChain PromptTemplate 拼接最终提示词 prompt_template.format(...)
⑥ LLM 生成 调用 LLM API 生成最终回答 llm.generate(final_prompt)

13. 数据流总结图 #

graph LR subgraph RAG问答流程 J[用户问题] -->|get_embedding| H2[火山方舟API] H2 -->|query_embedding| K[向量检索] K -->|Cypher| E2[(Neo4j)] E2 -->|相似结果| L[format_context] L -->|上下文| M[PromptTemplate] J -->|问题| M M -->|final_prompt| N[LLM API] N -->|回答| O[显示给用户] end subgraph 数据导入流程 A[CSV文件] -->|parse_csv| B[DataFrame] B -->|validate_csv| C{验证} C -->|通过| D[创建节点] D -->|py2neo| E[(Neo4j)] D --> F[创建关系] F -->|py2neo| E E -->|init_embeddings| G[读取节点] G -->|get_embedding| H[火山方舟API] H -->|向量| I[写回节点] I -->|graph.push| E end

整个系统的核心思路就是:先导入数据建图 → 再向量化 → 最后用 RAG 检索+生成 回答用户问题。

← 上一节 48.Pandas 下一节 50.deepdoc →

访问验证

请输入访问令牌

Token不正确,请重新输入