46.BaseTool #
如何基于BaseTool快速实现一个结构化参数的自定义工具类。
核心知识点:
- 基类 BaseTool 提供统一接口(如 name, description, args_schema, run/invoke),方便各种工具标准化实现。
- 使用 Pydantic 模型(如 AddInput)进行参数类型声明和校验,让输入更安全、易用也便于自动文档化。
- 工具通过实现
_run(同步)和可选的_arun(异步)接口,定义实际的业务逻辑。 - 典型调用方式为
invoke或run,支持直接以字典或参数形式输入,也能在主程序中便捷测试。
应用场景: 适合于需要标准化工具调用、接口自描述、参数自动校验的场景,尤其对接大模型 Agent、自动工具链等需求。
关键优势:
- 解耦参数校验与业务逻辑,易于扩展与集成;
- 支持多种调用方式(同步/异步,dict/参数);
- 保持类型安全、结构清晰。
此范式实际应用时,只需根据业务需求自定义 args_schema 和 _run 方法即可,非常适合构建各类 Agent 工具组件。
46.1. 46.BaseTool.py #
46.BaseTool.py
# 从pydantic库导入BaseModel和Field,用于数据模型和字段定义
from pydantic import BaseModel, Field
# from langchain_core.tools import BaseTool # 这一行是被注释掉的替代导入
# 从smartchain.tools模块导入BaseTool作为基类
from smartchain.tools import BaseTool
# 定义加法输入的数据模型,继承自BaseModel
class AddInput(BaseModel):
"""加法输入参数"""
# 被加数,类型为int,带有描述信息
a: int = Field(..., description="被加数")
# 加数,类型为int,带有描述信息
b: int = Field(..., description="加数")
# 定义加法工具类,继承自BaseTool
class AddTool(BaseTool):
"""继承 BaseTool,声明 name/description/args_schema 并实现 _run/_arun"""
# 工具名称,字符串类型
name: str = "add"
# 工具描述,字符串类型
description: str = "计算两个数的和"
# 参数模式,类型为BaseModel,此处为AddInput
args_schema: type[BaseModel] = AddInput
# 实现同步运行方法,计算两个整数的和
def _run(self, a: int, b: int, **kwargs) -> int:
return a + b
# 实现异步运行方法,调用同步方法完成计算
async def _arun(self, a: int, b: int, **kwargs) -> int:
return self._run(a, b, **kwargs)
# 主程序入口
if __name__ == "__main__":
# 实例化加法工具
tool = AddTool()
# 打印工具名称
print("工具名称:", tool.name)
# 打印工具描述
print("工具描述:", tool.description)
# 打印参数模式的JSON属性
print("参数模式:", tool.args_schema.model_json_schema()["properties"])
# 同步调用invoke方法,传入字典参数,打印调用结果
print("调用结果:", tool.invoke({"a": 3, "b": 5}))
46.2. tools.py #
smartchain/tools.py
# 导入抽象基类ABC和抽象方法装饰器abstractmethod
from abc import ABC, abstractmethod
# 从pydantic导入BaseModel,用于参数校验模型
from pydantic import BaseModel
# 定义通用工具基类,继承自ABC
class BaseTool(ABC):
"""
工具基类:
- name / description:工具元信息
- args_schema:可选的 Pydantic 输入模型,用于参数校验
- invoke/run:统一调用入口
"""
# 工具名称,默认为空字符串
name: str = ""
# 工具描述信息,默认为空字符串
description: str = ""
# 参数校验模型类型,默认为None,可选Pydantic模型
args_schema: type[BaseModel] | None = None
# 初始化方法,支持通过参数覆盖name/description/args_schema及扩展属性
def __init__(
self,
*,
name: str | None = None,
description: str | None = None,
args_schema: type[BaseModel] | None = None,
**kwargs,
) -> None:
# 如果传入了name,则覆盖类的默认name属性
if name is not None:
self.name = name
# 如果传入了description,则覆盖类的默认description属性
if description is not None:
self.description = description
# 如果传入了args_schema,则覆盖类的默认args_schema属性
if args_schema is not None:
self.args_schema = args_schema
# 循环设置其它自定义扩展属性
for k, v in kwargs.items():
setattr(self, k, v)
# 必须由子类实现的同步运行方法,未实现时报错
@abstractmethod
def _run(self, *args, **kwargs):
raise NotImplementedError
# 可选的异步运行方法,默认抛出未实现异常
async def _arun(self, *args, **kwargs):
raise NotImplementedError("该工具未实现异步调用")
# 工具统一同步调用入口
def invoke(self, input, **kwargs):
# 如果输入是字典,则使用解包传递参数
if isinstance(input, dict):
return self._run(**input, **kwargs)
# 否则直接传递单一参数
return self._run(input, **kwargs)
# run方法与invoke等价,也是同步调用入口
def run(self, *args, **kwargs):
return self._run(*args, **kwargs)
47.StructuredTool #
StructuredTool 是一个结构化工具类,用于简化工具(函数)与参数 Pydantic 模型的绑定,并自动生成参数验证逻辑。它通过 from_function 工厂方法,可以将任何具名函数与参数 schema(基于 Pydantic BaseModel)结合,自动完成工具说明与参数定义的提取和标准化。
主要特性:
- 结构化参数校验:通过 Pydantic BaseModel 定义输入参数类型和描述,自动校验输入参数的合法性。
- 灵活工具封装:支持将已有的 Python 函数快速封装为可被统一调度和调用的工具对象,工具参数和返回值类型清晰。
- 元数据管理:每个 StructuredTool 除了 name 和 description 外,还能统一管理参数模型(args_schema)结构,便于文档、交互、自动补全等场景使用。
- 同步调用接口:提供 invoke/run/_run 等方法,兼容同步参数、命名参数两种方式调用底层函数。
一般用法如下:
- 继承 Pydantic BaseModel 定义输入数据模型,包括每个参数的类型与说明。
- 编写实现具体功能的普通 Python 函数。
- 使用 StructuredTool.from_function 函数,将函数、名称、描述、参数模型等信息绑定为 StructuredTool 工具实例。
- 通过工具实例的 invoke 等方法传入参数字典即可方便、安全地调用底层函数,并得到校验后的结果。
这极大方便了工具链的封装、调用、参数自动校验与文档生成,是构建智能体/自动化链路中的推荐模式。
47.1. 47.StructuredTool.py #
47.StructuredTool.py
# 示例:使用 smartchain.tools.StructuredTool 定义加法工具
from pydantic import BaseModel, Field
#from langchain_core.tools import StructuredTool
from smartchain.tools import StructuredTool
class AddInput(BaseModel):
"""加法输入参数"""
a: int = Field(..., description="被加数")
b: int = Field(..., description="加数")
def add(a: int, b: int) -> int:
"""计算两个数的和"""
return a + b
if __name__ == "__main__":
# 通过 from_function 创建结构化工具,绑定参数模式
add_tool = StructuredTool.from_function(
func=add,
name="add",
description="计算两个数的和",
args_schema=AddInput,
)
print("工具名称:", add_tool.name)
print("工具描述:", add_tool.description)
print("参数模式:", add_tool.args_schema.model_json_schema()["properties"])
print("调用结果:", add_tool.invoke({"a": 3, "b": 5}))
47.2. tools.py #
smartchain/tools.py
# 导入抽象基类ABC和抽象方法装饰器abstractmethod
from abc import ABC, abstractmethod
+from inspect import signature, Parameter
# 从pydantic导入BaseModel,用于参数校验模型
+from pydantic import BaseModel, Field, create_model
# 定义通用工具基类,继承自ABC
class BaseTool(ABC):
"""
工具基类:
- name / description:工具元信息
- args_schema:可选的 Pydantic 输入模型,用于参数校验
- invoke/run:统一调用入口
"""
# 工具名称,默认为空字符串
name: str = ""
# 工具描述信息,默认为空字符串
description: str = ""
# 参数校验模型类型,默认为None,可选Pydantic模型
+ args_schema = None
# 初始化方法,支持通过参数覆盖name/description/args_schema及扩展属性
def __init__(
self,
*,
+ name = None,
+ description = None,
+ args_schema = None,
**kwargs,
) -> None:
# 如果传入了name,则覆盖类的默认name属性
if name is not None:
self.name = name
# 如果传入了description,则覆盖类的默认description属性
if description is not None:
self.description = description
# 如果传入了args_schema,则覆盖类的默认args_schema属性
if args_schema is not None:
self.args_schema = args_schema
# 循环设置其它自定义扩展属性
for k, v in kwargs.items():
setattr(self, k, v)
# 必须由子类实现的同步运行方法,未实现时报错
@abstractmethod
def _run(self, *args, **kwargs):
raise NotImplementedError
# 可选的异步运行方法,默认抛出未实现异常
async def _arun(self, *args, **kwargs):
raise NotImplementedError("该工具未实现异步调用")
# 工具统一同步调用入口
def invoke(self, input, **kwargs):
# 如果输入是字典,则使用解包传递参数
if isinstance(input, dict):
return self._run(**input, **kwargs)
# 否则直接传递单一参数
return self._run(input, **kwargs)
# run方法与invoke等价,也是同步调用入口
def run(self, *args, **kwargs):
return self._run(*args, **kwargs)
# 定义结构化工具类,继承自BaseTool
+class StructuredTool(BaseTool):
+ """
+ 结构化工具:
+ - 可绑定函数并使用 Pydantic 模型做参数校验
+ - 支持 from_function 创建
+ """
# 存储要绑定的函数,初始为None
+ func = None
# 初始化方法
+ def __init__(
+ self,
+ func, # 要绑定的函数
+ *,
+ name = None, # 工具名称
+ description = None, # 工具描述
+ args_schema = None, # 参数校验模型
+ **kwargs, # 其他扩展参数
+ ) -> None:
# 如果未提供name,则使用函数名
+ name = name or func.__name__
# 如果未提供description,则用函数的docstring
+ description = description or (func.__doc__ or "")
# 如果未提供参数模型,则根据函数签名自动推断模型
+ if args_schema is None:
+ args_schema = self._infer_schema_from_function(func, name)
# 调用父类构造方法,初始化相关属性
+ super().__init__(
+ name=name,#工具名称
+ description=description,#工具描述
+ args_schema=args_schema,#参数校验模型
+ **kwargs,#其他扩展参数
+ )
# 绑定要调用的函数
+ self.func = func#要绑定的函数
# 静态方法:通过函数签名推断Pydantic模型
+ @staticmethod
+ def _infer_schema_from_function(func, model_name):
+ """
+ 根据函数签名推断 Pydantic 模型。
+ - 无注解默认 str
+ - 有默认值则设置 Field(default=...)
+ - 无默认值则必填 Field(...)
+ """
# 获取函数的签名
+ sig = signature(func)
# 存储参数及其类型定义
+ fields = {}
# 遍历所有参数
+ for param_name, param in sig.parameters.items():
# 排除self和cls参数
+ if param_name in ("self", "cls"):
+ continue
# 获取参数注解类型,如果没有注解默认用str
+ param_type = param.annotation
+ if param_type == Parameter.empty:
+ param_type = str
# 如果有默认值,则用Field设置默认值,否则为必填
+ if param.default != Parameter.empty:
+ fields[param_name] = (param_type, Field(default=param.default))
+ else:
+ fields[param_name] = (param_type, Field(...))
# 如果没有参数,返回None
+ if not fields:
+ return None
# 动态生成Pydantic模型类型
+ return create_model(f"{model_name}Input", **fields)
# 实现同步运行工具的方法
+ def _run(self, *args, **kwargs):
+ """执行工具函数并校验输入"""
# 如果存在参数校验模型且是BaseModel的子类
+ if self.args_schema and isinstance(self.args_schema, type) and issubclass(self.args_schema, BaseModel):
# 如果传递了args参数
+ if args:
# 如果仅一个参数并且是字典,合并dict和kwargs
+ if len(args) == 1 and isinstance(args[0], dict):
+ input_dict = {**args[0], **kwargs}
+ else:
# 否则将args顺序匹配到schema字段名,再合并kwargs
+ field_names = list(self.args_schema.model_fields.keys())
+ input_dict = {field_names[i]: args[i] for i in range(min(len(args), len(field_names)))}
+ input_dict.update(kwargs)
+ else:
# 只用kwargs
+ input_dict = kwargs
# 用schema做参数校验
+ validated = self.args_schema(**input_dict)
# 调用绑定函数,并以模型校验后的数据作为参数
+ return self.func(**validated.model_dump())
# 如果没有schema校验,直接普通调用
+ return self.func(*args, **kwargs)
# 实现异步运行方法,默认调用同步方法
+ async def _arun(self, *args, **kwargs):
+ return self._run(*args, **kwargs)
# 类方法:通过函数快速生成StructuredTool实例
+ @classmethod
+ def from_function(
+ cls,
+ func, # 绑定的函数
+ *,
+ name = None, # 可选工具名
+ description = None, # 可选描述
+ args_schema = None, # 可选参数模型
+ **kwargs, # 其它参数
+ ):
# 创建并返回StructuredTool实例
+ return cls(
+ func=func,
+ name=name,
+ description=description,
+ args_schema=args_schema,
+ **kwargs,
+ )
48.tool #
48.tool 提供了一个便捷型 @tool 装饰器,它可以将普通 Python 函数一键包装为结构化工具对象(StructuredTool),自动绑定参数校验模型(Pydantic BaseModel),统一整理工具的名称、描述、参数和调用接口。
核心优势:
- 只需用
@tool或@tool(...)修饰函数,无需手动实例化 StructuredTool。 - 支持函数签名自动推断参数模型,也能通过
args_schema指定精细的 Pydantic 参数校验。 - 统一后的工具对象具有
.name,.description,.args_schema,.invoke()等标准属性,便于后续自动化处理、构建工具链和接口适配。 - 装饰器用法灵活,可
@tool、@tool(name=..., args_schema=...)多种模式。
应用场景:
- 快速把各类函数(如数学运算、接口调用、复杂业务逻辑)转为具备结构化接口的 AI 工具,方便 LLM/智能体编排与自动参数校验。
- 支持工具注册、自动文档生成、参数校验、前后端动态表单自动推断等全流程自动化。
tool极大简化了工具函数的标准化和集成,是现代 Python 智能体/自动化框架的重要基础设施。
48.tool
48.1. 48.tool.py #
48.tool.py
# 使用 tool 装饰器定义加法工具
from pydantic import BaseModel, Field
#from langchain_core.tools import tool
from smartchain.tools import tool
class AddInput(BaseModel):
"""加法输入参数"""
a: int = Field(..., description="被加数")
b: int = Field(..., description="加数")
@tool(args_schema=AddInput, description="计算两个数的和")
def add(a: int, b: int) -> int:
"""计算两个数的和"""
return a + b
if __name__ == "__main__":
print("工具名称:", add.name)
print("工具描述:", add.description)
print("参数模式:", add.args_schema.model_json_schema()["properties"])
print("调用结果:", add.invoke({"a": 3, "b": 5}))
48.2. tools.py #
smartchain/tools.py
# 导入抽象基类ABC和抽象方法装饰器abstractmethod
from abc import ABC, abstractmethod
from inspect import signature, Parameter
+from typing import Any, Callable
# 从pydantic导入BaseModel,用于参数校验模型
from pydantic import BaseModel, Field, create_model
# 定义通用工具基类,继承自ABC
class BaseTool(ABC):
"""
工具基类:
- name / description:工具元信息
- args_schema:可选的 Pydantic 输入模型,用于参数校验
- invoke/run:统一调用入口
"""
# 工具名称,默认为空字符串
name: str = ""
# 工具描述信息,默认为空字符串
description: str = ""
# 参数校验模型类型,默认为None,可选Pydantic模型
args_schema = None
# 初始化方法,支持通过参数覆盖name/description/args_schema及扩展属性
def __init__(
self,
*,
name = None,
description = None,
args_schema = None,
**kwargs,
) -> None:
# 如果传入了name,则覆盖类的默认name属性
if name is not None:
self.name = name
# 如果传入了description,则覆盖类的默认description属性
if description is not None:
self.description = description
# 如果传入了args_schema,则覆盖类的默认args_schema属性
if args_schema is not None:
self.args_schema = args_schema
# 循环设置其它自定义扩展属性
for k, v in kwargs.items():
setattr(self, k, v)
# 必须由子类实现的同步运行方法,未实现时报错
@abstractmethod
def _run(self, *args, **kwargs):
raise NotImplementedError
# 可选的异步运行方法,默认抛出未实现异常
async def _arun(self, *args, **kwargs):
raise NotImplementedError("该工具未实现异步调用")
# 工具统一同步调用入口
def invoke(self, input, **kwargs):
# 如果输入是字典,则使用解包传递参数
if isinstance(input, dict):
return self._run(**input, **kwargs)
# 否则直接传递单一参数
return self._run(input, **kwargs)
# run方法与invoke等价,也是同步调用入口
def run(self, *args, **kwargs):
return self._run(*args, **kwargs)
# 定义结构化工具类,继承自BaseTool
class StructuredTool(BaseTool):
"""
结构化工具:
- 可绑定函数并使用 Pydantic 模型做参数校验
- 支持 from_function 创建
"""
# 存储要绑定的函数,初始为None
func = None
# 初始化方法
def __init__(
self,
func, # 要绑定的函数
*,
name = None, # 工具名称
description = None, # 工具描述
args_schema = None, # 参数校验模型
**kwargs, # 其他扩展参数
) -> None:
# 如果未提供name,则使用函数名
name = name or func.__name__
# 如果未提供description,则用函数的docstring
description = description or (func.__doc__ or "")
# 如果未提供参数模型,则根据函数签名自动推断模型
if args_schema is None:
args_schema = self._infer_schema_from_function(func, name)
# 调用父类构造方法,初始化相关属性
super().__init__(
name=name,#工具名称
description=description,#工具描述
args_schema=args_schema,#参数校验模型
**kwargs,#其他扩展参数
)
# 绑定要调用的函数
self.func = func#要绑定的函数
# 静态方法:通过函数签名推断Pydantic模型
@staticmethod
def _infer_schema_from_function(func, model_name):
"""
根据函数签名推断 Pydantic 模型。
- 无注解默认 str
- 有默认值则设置 Field(default=...)
- 无默认值则必填 Field(...)
"""
# 获取函数的签名
sig = signature(func)
# 存储参数及其类型定义
fields = {}
# 遍历所有参数
for param_name, param in sig.parameters.items():
# 排除self和cls参数
if param_name in ("self", "cls"):
continue
# 获取参数注解类型,如果没有注解默认用str
param_type = param.annotation
if param_type == Parameter.empty:
param_type = str
# 如果有默认值,则用Field设置默认值,否则为必填
if param.default != Parameter.empty:
fields[param_name] = (param_type, Field(default=param.default))
else:
fields[param_name] = (param_type, Field(...))
# 如果没有参数,返回None
if not fields:
return None
# 动态生成Pydantic模型类型
return create_model(f"{model_name}Input", **fields)
# 实现同步运行工具的方法
def _run(self, *args, **kwargs):
"""执行工具函数并校验输入"""
# 如果存在参数校验模型且是BaseModel的子类
if self.args_schema and isinstance(self.args_schema, type) and issubclass(self.args_schema, BaseModel):
# 如果传递了args参数
if args:
# 如果仅一个参数并且是字典,合并dict和kwargs
if len(args) == 1 and isinstance(args[0], dict):
input_dict = {**args[0], **kwargs}
else:
# 否则将args顺序匹配到schema字段名,再合并kwargs
field_names = list(self.args_schema.model_fields.keys())
input_dict = {field_names[i]: args[i] for i in range(min(len(args), len(field_names)))}
input_dict.update(kwargs)
else:
# 只用kwargs
input_dict = kwargs
# 用schema做参数校验
validated = self.args_schema(**input_dict)
# 调用绑定函数,并以模型校验后的数据作为参数
return self.func(**validated.model_dump())
# 如果没有schema校验,直接普通调用
return self.func(*args, **kwargs)
# 实现异步运行方法,默认调用同步方法
async def _arun(self, *args, **kwargs):
return self._run(*args, **kwargs)
# 类方法:通过函数快速生成StructuredTool实例
@classmethod
def from_function(
cls,
func, # 绑定的函数
*,
name = None, # 可选工具名
description = None, # 可选描述
args_schema = None, # 可选参数模型
**kwargs, # 其它参数
):
# 创建并返回StructuredTool实例
return cls(
func=func,
name=name,
description=description,
args_schema=args_schema,
**kwargs,
)
# tool 装饰器:将函数快速转为 StructuredTool
+def tool(
+ name_or_callable=None,#工具名称或函数
+ *,
+ description=None,#工具描述
+ args_schema=None,#参数校验模型
+ **kwargs,#其他扩展参数
+):
+ """
+ 用法:
+ @tool
+ def f(x: int) -> int: ...
+ @tool("add", args_schema=Schema)
+ def add(a: int, b: int) -> int: ...
+ """
# 定义一个内部函数,用于实际创建 StructuredTool 实例
+ def _create_tool(func, override_name=None):
# 优先使用override_name,若无则用函数自身的名称
+ tool_name = override_name or func.__name__
# 优先使用外部description,否则使用函数文档字符串
+ tool_desc = description or (func.__doc__ or "")
# 调用 StructuredTool.from_function 创建工具实例
+ return StructuredTool.from_function(
+ func=func, # 要绑定的函数
+ name=tool_name, # 工具名称
+ description=tool_desc, # 工具描述
+ args_schema=args_schema, # 参数校验模型
+ **kwargs, # 其他扩展参数
+ )
# 装饰器无参数形式:@tool
+ if name_or_callable is None:
# 返回一个装饰器函数
+ def decorator(func):
# 装饰后直接创建 StructuredTool
+ return _create_tool(func)
+ return decorator
# 直接装饰函数且无参数:@tool 用于函数本身
+ if callable(name_or_callable) and hasattr(name_or_callable, "__name__"):
# 直接返回 StructuredTool 实例
+ return _create_tool(name_or_callable)
# 字符串形式:@tool("name", ...)
+ if isinstance(name_or_callable, str):
# 返回一个装饰器,设置工具名为指定字符串
+ def decorator(func):
+ return _create_tool(func, override_name=name_or_callable)
+ return decorator
# 不支持的参数类型,抛出异常
+ raise ValueError(f"tool 装饰器的参数类型不正确: {type(name_or_callable)}")
49.bind_tools #
bind_tools 方法和工具绑定机制允许你为大模型(LLM)绑定一组可调用的工具函数(如自定义 API、数据库查询等),使得模型能够在推理和对话中动态调用这些工具、执行操作并将结果反馈进对话流,实现“智能助理”般的增强能力。
一、bind_tools 方法是什么?
bind_tools 是 LLM 类(如 ChatDeepSeek、ChatOpenAI 等)的一个实例方法,其作用是将工具列表(可以是自定义的函数、继承自 BaseTool 的类等)与模型对象绑定,返回一个新的包装对象,随后你只需通过 invoke 或 stream 方法,即可实现带工具调用能力的对话。
用法示例:
from smartchain.chat_models import ChatDeepSeek
from smartchain.tools import tool
# 1. 使用装饰器注册一个工具
@tool
def say_hello(name: str) -> str:
return f"你好,{name}!"
# 2. 创建 LLM 实例并绑定工具
llm = ChatDeepSeek()
llm_tools = llm.bind_tools([say_hello])
# 3. 现在模型即可在对话中自动选择调用 say_hello 工具
response = llm_tools.invoke("请帮我向小明打个招呼")
print(response.content) # 可能输出:"你好,小明!"二、bind_tools 返回的是什么?
bind_tools返回一个 _BoundToolsLLM 对象,它包装了原始的 LLM 和工具列表。你可以像操作普通模型一样调用 .invoke(...) 或 .stream(...),模型会根据上下文自动决定是否调用某个(或多个)工具。
这个机制底层会把你的工具信息(名称、描述、参数类型等)打包成“大语言模型”能够理解的格式,模型需要时会以结构化的方式发起工具调用(常见于 OpenAI 的 Function Calling/Tools API、DeepSeek 的 Function call 等),然后框架替你路由到合适的 Python 函数或对象,并把结果反馈给模型,支持多轮、连续工具调用。
三、工具如何定义(与 tool 装饰器配合)?
推荐结合 @tool 装饰器来开发结构化工具。你可以为函数声明参数类型或 Pydantic 模型,工具会自动生成对应的参数说明表和校验逻辑,模型更易于正确调用。
例子如下:
from pydantic import BaseModel, Field
from smartchain.tools import tool
class AddInput(BaseModel):
a: int = Field(..., description="加数1")
b: int = Field(..., description="加数2")
@tool(args_schema=AddInput, description="计算两数之和")
def add(a: int, b: int) -> int:
return a + b你还可以直接用装饰器自动推断参数,无需声明 schema,只要参数类型清晰即可。
四、bind_tools 适合什么场景?
- LLM 需要调用外部工具、API、数据库或执行环境指令的场景
- 希望对话流程自动发现问题、动态调用一个或多个工具、并能把调用结果反馈给 LLM 剧本/链
- 支持复杂、多轮多工具连续调用(如先查天气再发提醒)
49.1. 49.bind_tools.py #
49.bind_tools.py
# 导入pydantic中的BaseModel和Field用于参数校验与注释
from pydantic import BaseModel, Field
# 从smartchain.tools导入tool装饰器
from smartchain.tools import tool
# 从smartchain.chat_models导入ChatOpenAI,用于实例化模型
from smartchain.chat_models import ChatOpenAI
# 从smartchain.messages导入ToolMessage,表示工具调用的消息
from smartchain.messages import ToolMessage
import os
os.environ["DEEPSEEK_API_KEY"] = "sk-c4e682d07ed643e0bce7bb66f24c5720"
# 定义加法输入参数的数据结构,继承自BaseModel
class AddInput(BaseModel):
"""加法输入参数"""
# 第一个被加数,使用pydantic的Field并加上描述说明
a: int = Field(..., description="被加数")
# 第二个加数,Field用于参数描述
b: int = Field(..., description="加数")
# 利用tool装饰器注册add函数为工具,指定输入模式以及描述
@tool(args_schema=AddInput, description="计算两个数的和")
def add(a: int, b: int) -> int:
"""计算两个数的和"""
# 打印加法操作的日志
print(f"计算两个数的和: {a} + {b}")
# 返回两个数的和
return a + b
# 定义一个支持多轮工具调用的对话函数
def chat_with_tools(llm_with_tools, initial_messages):
"""
支持多轮/多次连续工具调用,必要时二次/多次循环。
"""
# 构建工具名称到对象的映射字典
tool_map = {add.name: add}
# 把初始消息转为列表
messages = list(initial_messages)
# 开始循环直到模型不再要求调用工具
while True:
# 调用大模型接口并传入消息
resp = llm_with_tools.invoke(messages)
# 打印模型的响应内容
print("模型响应:", resp.content)
# 检查模型返回内容中是否包含tool_calls
if not getattr(resp, "tool_calls", None):
# 如果没有工具调用,表示已完成
print("最终回答:", resp.content)
return resp
# 工具调用的结果列表
tool_results = []
# 遍历每个工具调用请求
for tc in resp.tool_calls:
# 取出工具名
tool_name = tc["name"]
# 根据名称查找工具对象
tool_obj = tool_map.get(tool_name)
# 如果工具不存在,打印警告并跳过
if tool_obj is None:
print(f"未知工具: {tool_name}")
continue
# 调用工具并传入参数,获得结果
result = tool_obj.invoke(tc["args"])
# 打印工具调用及其结果
print(f"工具[{tool_name}]执行({tc['args']}): {result}")
# 用ToolMessage封装工具调用结果,并记下tool_call_id
tool_results.append(ToolMessage(str(result), tool_call_id=tc["id"]))
# 将原消息、最新模型响应、和工具结果一起组成下一轮消息
messages = messages + [resp] + tool_results
# 初始化ChatOpenAI大语言模型实例,指定使用的模型名
llm = ChatOpenAI(model="gpt-4o")
# 将add工具绑定到模型上,获得支持工具调用的llm对象
llm_with_tools = llm.bind_tools([add])
# 构造一轮初始人类消息,格式为元组(角色, 内容)
messages = [("human", "请帮我计算 3 和 5 的和,并调用合适的工具")]
# 调用对话函数,启动与模型的多轮交互
chat_with_tools(llm_with_tools, messages)49.2. chat_models.py #
smartchain/chat_models.py
# 导入操作系统相关模块
import os
# 导入 openai 模块
import openai
# 从 .messages 模块导入 AIMessage、HumanMessage 和 SystemMessage 类
+from .messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from .prompts import ChatPromptValue
# 定义与 OpenAI 聊天模型交互的类
class ChatOpenAI:
# 初始化方法
def __init__(self, model: str = "gpt-4o", **kwargs):
# 初始化 ChatOpenAI 类
"""
初始化 ChatOpenAI
Args:
model: 模型名称,如 "gpt-4o"
**kwargs: 其他参数(如 temperature, max_tokens 等)
"""
# 设置模型名称
self.model = model
# 获取 api_key,优先从参数获取,否则从环境变量获取
self.api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
# 如果没有提供 api_key,则抛出异常
if not self.api_key:
raise ValueError("需要提供 api_key 或设置 OPENAI_API_KEY 环境变量")
# 保存除 api_key 之外的其他参数,用于 API 调用
self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}
# 创建 OpenAI 客户端实例
self.client = openai.OpenAI(api_key=self.api_key)
# 调用模型生成回复的方法
def invoke(self, input, **kwargs):
# 调用模型生成回复
"""
调用模型生成回复
Args:
input: 输入内容,可以是字符串或消息列表
**kwargs: 额外的 API 参数
Returns:
AIMessage: AI 的回复消息
"""
# 将输入数据转换为消息格式
messages = self._convert_input(input)
# 构建 API 请求参数字典
params = {
"model": self.model,
"messages": messages,
**self.model_kwargs,
**kwargs
}
# 使用 OpenAI 客户端发起 chat.completions.create 调用获取回复
response = self.client.chat.completions.create(**params)
# 取出返回结果中的第一个选项
choice = response.choices[0]
# 获取消息内容
content = choice.message.content or ""
# 创建 AIMessage 对象
+ msg = AIMessage(content=content)
# 如果有工具调用,解析并添加到消息中
+ if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls:
+ import json
+ tool_calls = []
+ for tc in choice.message.tool_calls:
+ try:
+ args = json.loads(tc.function.arguments) if tc.function.arguments else {}
+ except:
+ args = {}
+ tool_calls.append({
+ "id": tc.id,
+ "name": tc.function.name,
+ "args": args,
+ "type": "tool_call"
+ })
+ msg.tool_calls = tool_calls
# 返回 AIMessage 对象
+ return msg
# 流式调用模型生成回复的方法
def stream(self, input, **kwargs):
# 流式调用模型生成回复
"""
流式调用模型生成回复
Args:
input: 输入内容,可以是字符串或消息列表
**kwargs: 额外的 API 参数
Yields:
AIMessage: AI 的回复消息块(每次产生部分内容)
"""
# 将输入数据转换为消息格式
messages = self._convert_input(input)
# 构建 API 请求参数字典,启用流式输出
params = {
"model": self.model,
"messages": messages,
"stream": True, # 启用流式输出
**self.model_kwargs,
**kwargs
}
# 使用 OpenAI 客户端发起流式调用
stream = self.client.chat.completions.create(**params)
# 迭代流式响应
for chunk in stream:
# 检查是否有内容增量
if chunk.choices and len(chunk.choices) > 0:
delta = chunk.choices[0].delta
# 检查 delta 中是否有 content,如果有则发送
if hasattr(delta, 'content') and delta.content:
# 产生包含部分内容的 AIMessage
yield AIMessage(content=delta.content)
# 内部方法,将输入转换为 OpenAI API 需要的消息格式
def _convert_input(self, input):
# 将输入转换为 OpenAI API 需要的消息格式
"""
将输入转换为 OpenAI API 需要的消息格式
Args:
input: 字符串、消息列表或 ChatPromptValue
Returns:
list[dict]: OpenAI API 格式的消息列表
"""
if isinstance(input, ChatPromptValue):
input = input.to_messages()
# 输入为字符串时,直接封装为用户角色消息
if isinstance(input, str):
return [{"role": "user", "content": input}]
# 如果输入是列表类型
elif isinstance(input, list):
# 新建一个空的消息列表
messages = []
# 遍历输入列表中的每一个元素
for msg in input:
# 判断是否为字符串,是则作为用户消息加入
if isinstance(msg, str):
messages.append({"role": "user", "content": msg})
# 判断是否为 HumanMessage、AIMessage、SystemMessage 或 ToolMessage 实例
+ elif isinstance(msg, (HumanMessage, AIMessage, SystemMessage, ToolMessage)):
# 如果是 HumanMessage,将角色设为 user
if isinstance(msg, HumanMessage):
role = "user"
# 如果是 AIMessage,将角色设为 assistant
elif isinstance(msg, AIMessage):
role = "assistant"
# 如果 AIMessage 包含 tool_calls,需要转换为 OpenAI 格式
+ if hasattr(msg, "tool_calls") and msg.tool_calls:
# 构建包含 tool_calls 的 assistant 消息
+ tool_calls = []
+ for tc in msg.tool_calls:
+ tool_calls.append({
+ "id": tc.get("id", ""),
+ "type": "function",
+ "function": {
+ "name": tc.get("name", ""),
+ "arguments": str(tc.get("args", {}))
+ }
+ })
+ msg_dict = {
+ "role": role,
+ "content": msg.content if hasattr(msg, "content") else str(msg),
+ "tool_calls": tool_calls
+ }
+ messages.append(msg_dict)
+ continue
# 如果是 SystemMessage,将角色设为 system
elif isinstance(msg, SystemMessage):
role = "system"
# 如果是 ToolMessage,将角色设为 tool
+ elif isinstance(msg, ToolMessage):
+ role = "tool"
# 获取消息内容(有 content 属性则取 content,否则转为字符串)
content = msg.content if hasattr(msg, "content") else str(msg)
# 构建消息字典
+ msg_dict = {"role": role, "content": content}
# 如果是 ToolMessage,添加 tool_call_id
+ if isinstance(msg, ToolMessage) and hasattr(msg, "tool_call_id") and msg.tool_call_id:
+ msg_dict["tool_call_id"] = msg.tool_call_id
# 将消息添加到消息列表
+ messages.append(msg_dict)
# 如果元素本身为字典,直接添加进消息列表
elif isinstance(msg, dict):
# 直接添加字典类型的消息
messages.append(msg)
# 如果元素为长度为 2 的元组,将其解包为 role 和 content
elif isinstance(msg, tuple) and len(msg) == 2:
# 将元组解包为 role 和 content
role, content = msg
# 将 "human" 映射为 "user"(OpenAI API 要求)
+ if role == "human":
+ role = "user"
# 将角色和内容添加到消息列表
messages.append({"role": role, "content": content})
# 返回构建好的消息列表
return messages
else:
# 其他输入类型,转为字符串作为 user 消息
return [{"role": "user", "content": str(input)}]
# 定义可配置字段的方法,用于包装当前实例,支持部分参数运行时动态调整
def configurable_fields(self, **fields):
# 配置可动态调整的字段
"""
配置可动态调整的字段
Args:
**fields: 可配置字段的字典,键为字段名,值为 ConfigurableField 实例
Returns:
RunnableConfigurableFields: 包装后的 Runnable 实例
示例:
``python
from smartchain.runnables import ConfigurableField
llm = ChatDeepSeek(temperature=0).configurable_fields(
temperature=ConfigurableField(
id="temperature",
name="温度值",
description="LLM 的采样温度参数"
)
)
# 使用默认 temperature=0
result1 = llm.invoke("你好")
# 使用 temperature=1.0
result2 = llm.invoke("你好", config={"configurable": {"temperature": 1.0}})
``
"""
# 从当前目录导入 RunnableConfigurableFields 类
from .runnables import RunnableConfigurableFields
# 返回将当前实例 self 及配置字段 fields 包装后的新对象
return RunnableConfigurableFields(default=self, fields=fields)
# 定义方法,用于根据 selector_field 和 config 动态切换不同 runnable 分支
def configurable_alternatives(self, selector_field, *, default_key, **alternatives):
# """
# 配置可替代的 Runnable 选项,根据 config["configurable"] 动态切换
#
# Args:
# selector_field: ConfigurableField,定义选择键的 id/name/description
# default_key: 默认使用的分支 key(必须存在于 alternatives 中)
# **alternatives: key -> runnable 或具有 invoke 方法的对象
#
# Returns:
# RunnableConfigurableAlternatives: 包装后的 Runnable 实例
#
# 示例:
# from smartchain.runnables import ConfigurableField
# from smartchain.chat_models import ChatDeepSeek, ChatDeepSeek
#
# llm = ChatDeepSeek().configurable_alternatives(
# ConfigurableField(
# id="provider",
# name="LLM 提供方",
# description="在 OpenAI 与 DeepSeek 之间切换"
# ),
# default_key="openai",
# openai=ChatDeepSeek(temperature=0),
# deepseek=ChatDeepSeek(temperature=0),
# )
#
# # 默认使用 openai
# result1 = llm.invoke("你好")
#
# # 切换为 deepseek
# result2 = llm.invoke("你好", config={"configurable": {"provider": "deepseek"}})
# """
# 从 .runnables 模块导入 RunnableConfigurableAlternatives 类
from .runnables import RunnableConfigurableAlternatives
# 返回 RunnableConfigurableAlternatives 实例,实现 runtime 动态分支切换
return RunnableConfigurableAlternatives(
selector_field=selector_field, # 用于选择分支的字段信息
default_key=default_key, # 默认分支 key
alternatives=alternatives, # 所有可供切换的分支字典
)
# 绑定工具到模型的方法
+ def bind_tools(self, tools, **kwargs):
+ """
+ 绑定工具到模型,返回一个支持工具调用的包装对象
+ Args:
+ tools: 工具列表,可以是 BaseTool 实例或函数
+ **kwargs: 其他参数
+ Returns:
+ _BoundToolsLLM: 包装后的 LLM 对象,支持工具调用
+ """
+ return _BoundToolsLLM(self, tools, **kwargs)
# 绑定工具的 LLM 包装类
+class _BoundToolsLLM:
# 绑定工具后的 LLM 包装类文档字符串
+ """绑定工具后的 LLM 包装类"""
# 构造方法,初始化绑定工具的 LLM
+ def __init__(self, llm, tools, **kwargs):
+ """
+ 初始化绑定工具的 LLM
+ Args:
+ llm: 原始 LLM 实例
+ tools: 工具列表
+ **kwargs: 其他参数
+ """
# 保存原始 LLM 实例
+ self.llm = llm
# 保存工具列表
+ self.tools = tools
# 保存可能的其他参数
+ self.kwargs = kwargs
# 将工具转换为 OpenAI 格式的工具描述列表
+ self.openai_tools = [self._tool_to_openai(tool) for tool in tools]
# 私有方法:将工具对象转换成 OpenAI 格式
+ def _tool_to_openai(self, tool):
+ """
+ 将工具转换为 OpenAI 格式
+ Args:
+ tool: BaseTool 实例或函数
+ Returns:
+ dict: OpenAI 格式的工具定义
+ """
# 判断是否为 BaseTool 实例(通常有 name、description、args_schema 属性)
+ if hasattr(tool, 'name') and hasattr(tool, 'description') and hasattr(tool, 'args_schema'):
# 有参数模式时,优先尝试 model_json_schema 方法
+ if tool.args_schema:
+ if hasattr(tool.args_schema, 'model_json_schema'):
# 使用 pydantic 的 model_json_schema 生成 schema
+ schema = tool.args_schema.model_json_schema()
+ else:
# 若没有此方法直接使用 args_schema 本身
+ schema = tool.args_schema
+ else:
# 无参数时返回空参数对象
+ schema = {"type": "object", "properties": {}}
# 按 OpenAI 工具格式规范组织
+ return {
+ "type": "function",
+ "function": {
+ "name": tool.name,
+ "description": tool.description,
+ "parameters": schema
+ }
+ }
+ else:
# 如果不是 BaseTool 实例,则尝试处理为已装饰的函数
+ if hasattr(tool, 'name'):
# 获取工具名
+ name = tool.name
# 获取工具描述,如果无则给空字符串
+ description = getattr(tool, 'description', '')
# 获取参数模式
+ args_schema = getattr(tool, 'args_schema', None)
# 若可用则调用 model_json_schema 方法
+ if args_schema and hasattr(args_schema, 'model_json_schema'):
+ schema = args_schema.model_json_schema()
+ else:
# 否则返回空对象参数
+ schema = {"type": "object", "properties": {}}
# 返回 OpenAI 格式定义
+ return {
+ "type": "function",
+ "function": {
+ "name": name,
+ "description": description,
+ "parameters": schema
+ }
+ }
+ else:
# 无法处理未知的工具类型时报错
+ raise ValueError(f"无法转换工具: {tool}")
# 调用模型,支持工具调用
+ def invoke(self, input, **kwargs):
+ """
+ 调用模型,支持工具调用
+ Args:
+ input: 输入内容
+ **kwargs: 其他参数
+ Returns:
+ AIMessage: AI 的回复消息,可能包含 tool_calls
+ """
# 合并初始化时和调用时传入的参数
+ merged_kwargs = {**self.kwargs, **kwargs}
# 指定 tools 参数为已转换的 openai_tools 列表
+ merged_kwargs["tools"] = self.openai_tools
# 执行原始 LLM 的 invoke 方法,并返回结果
+ return self.llm.invoke(input, **merged_kwargs)
# 定义与 DeepSeek 聊天模型交互的类
class ChatDeepSeek:
# 初始化方法
# model: 模型名称,默认为 "deepseek-chat"
# **kwargs: 其他可选参数(如 temperature, max_tokens 等)
def __init__(self, model: str = "deepseek-chat", **kwargs):
"""
初始化 ChatDeepSeek
Args:
model: 模型名称,如 "deepseek-chat"
**kwargs: 其他参数(如 temperature, max_tokens 等)
"""
# 设置模型名称
self.model = model
# 获取 api_key,优先从参数获取,否则从环境变量获取
self.api_key = kwargs.get("api_key") or os.getenv("DEEPSEEK_API_KEY")
# 如果没有提供 api_key,则抛出异常
if not self.api_key:
raise ValueError("需要提供 api_key 或设置 DEEPSEEK_API_KEY 环境变量")
# 保存除 api_key 之外的其他参数,用于 API 调用
self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}
# 获取 DeepSeek 的 base_url,默认为官方地址
base_url = kwargs.get("base_url", "https://api.deepseek.com/v1")
# 创建 OpenAI 兼容的客户端实例(DeepSeek 使用 OpenAI 兼容的 API)
self.client = openai.OpenAI(api_key=self.api_key, base_url=base_url)
# 调用模型生成回复的方法
# input: 输入内容,可以是字符串或消息列表
# **kwargs: 额外的 API 参数
def invoke(self, input, **kwargs):
"""
调用模型生成回复
Args:
input: 输入内容,可以是字符串或消息列表
**kwargs: 额外的 API 参数
Returns:
AIMessage: AI 的回复消息
"""
# 将输入数据转换为消息格式
messages = self._convert_input(input)
# 构建 API 请求参数字典
params = {
"model": self.model,
"messages": messages,
**self.model_kwargs,
**kwargs
}
# 使用 OpenAI 兼容的客户端发起 chat.completions.create 调用获取回复
response = self.client.chat.completions.create(**params)
# 取出返回结果中的第一个选项
choice = response.choices[0]
# 获取消息内容
content = choice.message.content or ""
# 返回一个 AIMessage 对象
return AIMessage(content=content)
# 内部方法,将输入转换为 API 需要的消息格式
# input: 字符串、消息列表或 ChatPromptValue
def _convert_input(self, input):
"""
将输入转换为 API 需要的消息格式
Args:
input: 字符串、消息列表或 ChatPromptValue
Returns:
list[dict]: API 格式的消息列表
"""
# 如果输入是字符串,直接作为用户消息
if isinstance(input, str):
return [{"role": "user", "content": input}]
else:
# 其他输入类型,转为字符串作为 user 消息
return [{"role": "user", "content": str(input)}]
# 定义与通义千问(Tongyi)聊天模型交互的类
class ChatTongyi:
# 初始化方法
# 初始化方法,设置模型名称和 API 相关参数
def __init__(self, model: str = "qwen-max", **kwargs):
"""
初始化 ChatTongyi
Args:
model: 模型名称,如 "qwen-max"
**kwargs: 其他参数(如 temperature, max_tokens 等)
"""
# 设置模型名称
self.model = model
# 获取 api_key,优先从参数获取,否则从环境变量获取
self.api_key = kwargs.get("api_key") or os.getenv("DASHSCOPE_API_KEY")
# 如果没有提供 api_key,则抛出异常
if not self.api_key:
raise ValueError("需要提供 api_key 或设置 DASHSCOPE_API_KEY 环境变量")
# 保存除 api_key 之外的其他参数,用于 API 调用
self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}
# 获取通义千问的 API base URL(使用 OpenAI 兼容模式),如果未指定则使用默认值
base_url = kwargs.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
# 创建 OpenAI 兼容的客户端实例(通义千问使用 OpenAI 兼容的 API)
self.client = openai.OpenAI(api_key=self.api_key, base_url=base_url)
# 调用模型生成回复的方法
# 调用模型生成回复,返回 AIMessage 对象
def invoke(self, input, **kwargs):
"""
调用模型生成回复
Args:
input: 输入内容,可以是字符串或消息列表
**kwargs: 额外的 API 参数
Returns:
AIMessage: AI 的回复消息
"""
# 将输入数据转换为消息格式
messages = self._convert_input(input)
# 构建 API 请求参数字典,包含模型名、消息内容和其他参数
params = {
"model": self.model,
"messages": messages,
**self.model_kwargs,
**kwargs
}
# 使用 OpenAI 兼容的客户端发起 chat.completions.create 调用以获取回复
response = self.client.chat.completions.create(**params)
# 取出返回结果中的第一个回复选项
choice = response.choices[0]
# 获取回复的消息内容,如果内容不存在则返回空字符串
content = choice.message.content or ""
# 构建并返回一个 AIMessage 对象
return AIMessage(content=content)
# 内部方法,将输入转换为 API 需要的消息格式
# 支持字符串、消息列表等输入,统一包装为 OpenAI API 格式
def _convert_input(self, input):
"""
将输入转换为 API 需要的消息格式
Args:
input: 字符串、消息列表或 ChatPromptValue
Returns:
list[dict]: API 格式的消息列表
"""
# 如果输入是字符串,直接包装为“用户”角色的消息
if isinstance(input, str):
return [{"role": "user", "content": input}]
else:
# 其他输入类型,转换为字符串作为“用户”消息内容
return [{"role": "user", "content": str(input)}] 49.3. messages.py #
smartchain/messages.py
# 定义基础消息类
class BaseMessage:
# 基础消息类的文档字符串
"""基础消息类"""
# 初始化方法,content为消息内容,其余为可选参数
def __init__(self, content: str, **kwargs):
# 初始化消息
# content: 消息内容
# **kwargs: 其他可选参数
self.content = content # 保存消息内容
self.type = kwargs.get("type", "base") # 获取消息类型,默认为"base"
for key, value in kwargs.items(): # 遍历所有其他参数
if key != "type": # 排除type参数
setattr(self, key, value) # 设置为成员变量
# 定义当对象被str()或print时的输出内容
def __str__(self):
return self.content # 返回消息内容
# 定义对象的官方字符串表示,用于debug
def __repr__(self):
return f"{self.__class__.__name__}(content={self.content!r})" # 返回类名和内容
# 定义用户消息类,继承自BaseMessage
class HumanMessage(BaseMessage):
# 用户消息类的文档字符串
"""用户消息"""
# 初始化方法,调用父类构造方法,并指定type为"human"
def __init__(self, content: str, **kwargs):
# 调用父类构造方法,并固定type为"human"
super().__init__(content, type="human", **kwargs)
# 定义AI消息类,继承自BaseMessage
class AIMessage(BaseMessage):
# AI消息类的文档字符串
"""AI 消息"""
# 初始化方法,调用父类构造方法,并指定type为"ai"
def __init__(self, content: str, **kwargs):
# 调用父类构造方法,并固定type为"ai"
super().__init__(content, type="ai", **kwargs)
# 定义一个继承自BaseMessage的系统消息类
class SystemMessage(BaseMessage):
# 说明这是系统消息的类
"""系统消息"""
# 初始化方法,content是消息内容,**kwargs为其他可选参数
def __init__(self, content: str, **kwargs):
# 调用父类BaseMessage的初始化方法,并将type参数固定为"system"
super().__init__(content, type="system", **kwargs)
# 定义一个工具消息类,继承自BaseMessage,用于表示工具执行结果的消息
+class ToolMessage(BaseMessage):
# 设置类的文档字符串,说明此类用于表示工具执行结果消息
+ """工具执行结果消息"""
# 初始化方法,接收内容、工具调用ID和其他可选参数
+ def __init__(self, content: str, tool_call_id: str = None, **kwargs):
# 调用父类BaseMessage的初始化方法,同时固定type为"tool"
+ super().__init__(content, type="tool", **kwargs)
# 将传入的工具调用ID保存为实例属性
+ self.tool_call_id = tool_call_id
49.4. runnables.py #
smartchain/runnables.py
# 导入抽象基类 (ABC: 抽象基类基类,abstractmethod: 用于定义抽象方法)
from abc import ABC, abstractmethod
import time
import random
import inspect
import uuid as uuid_module
from .config import ensure_config
+from .messages import ToolMessage, AIMessage
# 定义 Runnable 抽象基类,所有可运行单元必须继承它
class Runnable(ABC):
"""
Runnable 抽象基类
所有可运行组件的基础接口,定义了统一的调用方法。
"""
# 抽象方法,子类必须实现,用于同步调用
@abstractmethod
def invoke(self, input, config = None, **kwargs):
"""
同步调用 Runnable
Args:
input: 输入值
config: 可选的配置字典
**kwargs: 额外的关键字参数
Returns:
输出值
"""
pass # 仅做接口规范,子类务必实现
def stream(self, input, config = None, **kwargs):
"""
流式调用 Runnable
默认实现:先调用 invoke,若返回可迭代且不是字符串/字节/字典,则逐项 yield;
否则直接 yield 单值。
"""
result = self.invoke(input, config=config, **kwargs)
# 字符串/字节/字典不视为流式可迭代,直接返回单值
if hasattr(result, "__iter__") and not isinstance(result, (str, bytes, dict)):
for item in result:
yield item
else:
yield result
# 定义可配置替代分支选择器方法,通过 config["configurable"][field.id] 动态切换分支
def configurable_alternatives(self, selector_field, *, default_key, **alternatives):
"""
根据 config["configurable"] 中的选择键,动态切换不同的 Runnable/对象。
Args:
selector_field: ConfigurableField,定义选择键的 id/name/description
default_key: 默认使用的分支 key(必须存在于 alternatives 中)
**alternatives: key -> runnable 或具有 invoke 方法的对象
Returns:
RunnableConfigurableAlternatives 包装对象
"""
# 从当前模块导入 ConfigurableField 和 RunnableConfigurableAlternatives
from .runnables import ConfigurableField, RunnableConfigurableAlternatives
# 判断 selector_field 是否为 ConfigurableField 的实例
if not isinstance(selector_field, ConfigurableField):
# 如果不是则抛出类型错误
raise TypeError("selector_field 必须是 ConfigurableField 实例")
# 检查默认分支 key 是否包含在 alternatives 中
if default_key not in alternatives:
# 如果不包含则抛出值错误
raise ValueError("default_key 必须存在于 alternatives 中")
# 返回一个 RunnableConfigurableAlternatives 实例,实现动态分支选择
return RunnableConfigurableAlternatives(
selector_field=selector_field,
default_key=default_key,
alternatives=alternatives,
)
# 管道操作符,便于链式拼接
def __or__(self, other):
if not isinstance(other, Runnable):
raise TypeError("管道右侧必须是 Runnable 实例")
return RunnableSequence([self, other])
# 定义批量调用方法,默认实现为遍历输入逐个调用 invoke
def batch(self, inputs, config = None, **kwargs):
"""
批量调用 Runnable
Args:
inputs: 输入值列表
config: 可选的配置字典
**kwargs: 额外的关键字参数
Returns:
输出值列表
"""
# 对每个输入项都调用 invoke,并收集结果
return [self.invoke(input_item, config=config, **kwargs) for input_item in inputs]
# 添加重试功能,返回包装了重试逻辑的 Runnable
# 定义 with_retry 方法,为当前 Runnable 添加重试机制
def with_retry(
self,
*,
retry_if_exception_type=(Exception,), # 指定需要重试的异常类型,默认所有 Exception
stop_after_attempt=3, # 最大尝试次数,默认3次
wait_exponential_jitter=True, # 是否启用指数退避抖动
exponential_jitter_params=None, # 抖动参数字典,支持 initial/max/exp_base/jitter
):
"""
创建带重试功能的 Runnable 包装器
Args:
retry_if_exception_type: 需要重试的异常类型元组
stop_after_attempt: 最大尝试次数
wait_exponential_jitter: 是否启用指数回退抖动
exponential_jitter_params: 抖动参数,支持 initial/max/exp_base/jitter
Returns:
包装了重试逻辑的 RunnableRetry 实例
"""
# 返回带重试功能的 RunnableRetry 实例,绑定当前 runnable 和重试参数
return RunnableRetry(
bound=self,
retry_if_exception_type=retry_if_exception_type,
stop_after_attempt=stop_after_attempt,
wait_exponential_jitter=wait_exponential_jitter,
exponential_jitter_params=exponential_jitter_params,
)
def with_config(self, config=None, **kwargs):
"""
绑定配置到 Runnable,返回一个新的 Runnable
Args:
config: 要绑定的配置字典
**kwargs: 额外的关键字参数,会合并到 config 中
Returns:
一个新的 RunnableBinding 实例,包含绑定的配置
"""
# 合并 config 和 kwargs
merged_config = {}
if config:
merged_config.update(config)
if kwargs:
merged_config.update(kwargs)
# 返回 RunnableBinding 实例
return RunnableBinding(bound=self, config=merged_config)
# 定义 RunnableLambda 类,用于将普通 Python 函数封装为 Runnable 对象
class RunnableLambda(Runnable):
"""
RunnableLambda 将普通 Python 函数包装成 Runnable
这使得普通函数可以在链式调用中使用,并支持统一的 invoke 接口。
示例:
``python
def add_one(x: int) -> int:
return x + 1
runnable = RunnableLambda(add_one)
result = runnable.invoke(5) # 返回 6
results = runnable.batch([1, 2, 3]) # 返回 [2, 3, 4]
``
"""
# 初始化方法,接收一个函数和可选的名称
def __init__(self, func, name=None):
"""
初始化 RunnableLambda
Args:
func: 要包装的函数
name: Runnable 的名称(可选,默认使用函数名)
"""
# 检查传入的 func 是否为可调用对象
if not callable(func):
raise TypeError(f"func 必须是可调用对象,但得到了 {type(func)}")
# 保存待封装的函数
self.func = func
# 如果 name 明确传入则使用,否则合理推断
if name is not None:
self.name = name
else:
try:
# 尽量用函数原名,如果是 lambda 就命名为 "lambda"
self.name = func.__name__ if func.__name__ != "<lambda>" else "lambda"
except AttributeError:
# 对于匿名对象无法获取 __name__ 时兜底
self.name = "runnable"
# 实现 invoke 方法,对被封装的底层函数进行同步调用
def invoke(self, input, config = None, **kwargs):
"""
调用包装的函数
Args:
input: 输入值
config: 可选的配置字典
**kwargs: 额外的关键字参数(会传递给函数)
Returns:
函数的返回值
"""
# 保证 config 不为 None,如为 None 则转为空字典
config = ensure_config(config)
# 从配置字典中获取回调对象 callbacks
callbacks = config.get("callbacks")
# 初始化回调对象列表
callback_list = []
# 获取当前调用的唯一 ID(run_id)
run_id = config.get("run_id")
# 如果没有传入 run_id,则自动生成一个新的 uuid
if run_id is None:
run_id = uuid_module.uuid4()
# 如果 callbacks 不为空
if callbacks:
# 如果 callbacks 已经是列表,则直接用,否则转为单元素列表
if isinstance(callbacks, list):
callback_list = callbacks
else:
callback_list = [callbacks]
# 构造序列化信息,用于回调上报链条标识
serialized = {"name": self.name, "type": "RunnableLambda"}
# 遍历每个回调对象,触发其 on_chain_start 方法
for callback in callback_list:
# 只有回调对象有 on_chain_start 属性才调用
if hasattr(callback, "on_chain_start"):
try:
# 调用回调的 on_chain_start 方法,传入相关参数
callback.on_chain_start(
serialized=serialized,
inputs={"input": input},
run_id=run_id,
parent_run_id=None,
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs
)
except Exception:
# 回调过程中如出现异常则忽略,确保主流程不会终止
pass
# 检查被包装的函数是否接收 config 参数
if _accept_config(self.func):
# 如果接收 config,则将 config 传递下去
kwargs["config"] = config
# 尝试执行被包装的核心函数
try:
# 正常调用被包装的函数,将 input 作为第一个参数,kwargs作为关键字参数
output = self.func(input, **kwargs)
except Exception as e:
# 若捕获到异常,则对所有回调触发 on_chain_error 并继续抛出异常
if callback_list:
for callback in callback_list:
if hasattr(callback, "on_chain_error"):
try:
callback.on_chain_error(
error=e,
run_id=run_id,
parent_run_id=None,
**kwargs
)
except Exception:
# 回调异常不影响主异常继续抛出
pass
# 重新抛出主流程中的异常
raise
else:
# 如果没有异常执行,顺序触发所有回调的 on_chain_end 方法
if callback_list:
for callback in callback_list:
if hasattr(callback, "on_chain_end"):
try:
callback.on_chain_end(
outputs={"output": output},
run_id=run_id,
parent_run_id=None,
**kwargs
)
except Exception:
# 回调异常不影响主逻辑输出
pass
# 返回包装函数的输出结果
return output
# 批量调用内部依然调用 invoke,保证与 Runnable 基类一致
def batch(self, inputs, config = None, **kwargs):
"""
批量调用包装的函数
Args:
inputs: 输入值列表
config: 可选的配置字典
**kwargs: 额外的关键字参数
Returns:
输出值列表
"""
# 调用 invoke 实现批量处理
return [self.invoke(input_item, config=config, **kwargs) for input_item in inputs]
# 流式调用:直接复用基类的流式封装
def stream(self, input, **kwargs):
"""
流式调用包装的函数
对返回生成器/可迭代对象逐项 yield;若是单值则直接 yield。
"""
yield from super().stream(input, **kwargs)
# 返回对象自身的字符串表达,便于排查与日志
def __repr__(self):
"""返回 RunnableLambda 的字符串表示"""
return f"RunnableLambda(func={self.name})"
# 定义 RunnableParallel,继承自 Runnable
class RunnableParallel(Runnable):
"""
并行执行多个 Runnable,返回字典结果。
使用示例:
parallel = RunnableParallel(a=r1, b=r2)
result = parallel.invoke(input) # {"a": ..., "b": ...}
"""
# 构造方法,接收若干个可运行对象作为关键字参数
def __init__(self, **runnables):
# 如果未传递任何 runnable,则报错
if not runnables:
raise ValueError("至少需要一个 runnable")
# 检查每个传入的值是否为 Runnable 实例
for name, r in runnables.items():
if not isinstance(r, Runnable):
raise TypeError(f"键 {name} 的值必须是 Runnable 实例")
# 保存所有传入的 runnable 到实例属性
self.runnables = runnables
# 同步调用,将相同输入传递给所有子 runnable,并收集结果为字典
def invoke(self, input, config = None, **kwargs):
"""
同一输入传给所有子 runnable,收集结果为字典。
"""
# 遍历每个 runnable,调用其 invoke,结果收集为 {name: 返回值}
return {name: r.invoke(input, config=config, **kwargs) for name, r in self.runnables.items()}
# 批量调用,对输入列表每一项都运行 invoke,返回结果字典的列表
def batch(self, inputs, config = None, **kwargs):
"""
对输入列表逐项并行处理,返回字典列表。
"""
# 对每个输入元素调用 invoke,收集所有结果
return [self.invoke(item, config=config, **kwargs) for item in inputs]
# 流式调用,直接调用父类的流式实现
def stream(self, input, **kwargs):
"""
对单次输入执行并返回一个字典,流式单次产出。
"""
# 复用基类的 stream 方法
yield from super().stream(input, **kwargs)
# 返回对象的字符串表示(列出包含的所有子 runnable 的键名)
def __repr__(self):
# 拼接所有 runnable 的键名
keys = ", ".join(self.runnables.keys())
# 返回格式化字符串
return f"RunnableParallel({keys})"
# 定义RunnableBranch类,继承自Runnable,用于条件分支执行不同runnable
class RunnableBranch(Runnable):
"""
条件分支执行:按顺序检查条件,匹配则运行对应 runnable,若都不匹配则走默认分支。
"""
# 构造方法,接受若干分支参数
def __init__(self, *branches):
"""
支持“默认分支作为最后一个位置参数”的用法:
RunnableBranch((cond1, r1), (cond2, r2), default_runnable)
"""
# 分支数量必须至少2(至少一个条件+一个默认)
if len(branches) < 2:
raise ValueError("至少需要一个条件分支和一个默认分支")
# 将分支参数转为列表
branches_list = list(branches)
# 最后一个参数视为默认分支
default = branches_list.pop() # 最后一个位置参数为默认分支
# 校验每个分支
validated_branches = []
for item in branches_list:
# 每个分支需为二元组或二元列表
if not (isinstance(item, (tuple, list)) and len(item) == 2):
raise TypeError("分支必须是 (condition, runnable) 形式的二元组")
# 解包条件函数和runnable
cond, runnable = item
# 条件必须为可调用对象
if not callable(cond):
raise TypeError("分支条件必须是可调用对象")
# runnable必须是Runnable实例
if not isinstance(runnable, Runnable):
raise TypeError("分支 runnable 必须是 Runnable 实例")
# 校验通过则加入分支列表
validated_branches.append((cond, runnable))
# 校验默认分支必须为Runnable实例
if not isinstance(default, Runnable):
raise TypeError("默认分支必须是 Runnable 实例")
# 保存所有条件分支
self.branches = validated_branches
# 保存默认分支
self.default = default
# 单个输入同步调用方法
def invoke(self, input, config = None, **kwargs):
"""
按顺序匹配条件,命中即执行对应 runnable;否则走默认分支。
"""
# 遍历所有分支,遇到条件命中则执行对应runnable
for cond, runnable in self.branches:
if cond(input):
return runnable.invoke(input, config=config, **kwargs)
# 如果有默认分支则执行默认runnable
if self.default is not None:
return self.default.invoke(input, config=config, **kwargs)
# 无匹配分支时报错
raise ValueError("未匹配到任何分支,且未提供默认分支")
# 批量调用,遍历输入批量执行invoke
def batch(self, inputs, config = None, **kwargs):
# 对输入列表逐一执行invoke
return [self.invoke(item, config=config, **kwargs) for item in inputs]
# 流式调用,直接调用父类的stream方法
def stream(self, input, **kwargs):
# 复用父类的流式实现
yield from super().stream(input, **kwargs)
# 返回对象简洁字符串表示
def __repr__(self):
# 拼接分支编号
parts = [f"branch{idx}" for idx, _ in enumerate(self.branches)]
# 若有默认分支则拼接default字符串
if self.default:
parts.append("default")
# 格式化输出
return f"RunnableBranch({', '.join(parts)})"
class RunnablePassthrough(Runnable):
"""
直通型 Runnable:原样返回输入,不做任何处理。
可用于调试或需要保留原始输入的场景。
"""
def invoke(self, input, config = None, **kwargs):
return input
def batch(self, inputs, config = None, **kwargs):
return list(inputs)
def stream(self, input, **kwargs):
# 复用基类流式封装(对单值直接 yield)
yield from super().stream(input, **kwargs)
def __repr__(self):
return "RunnablePassthrough()"
# 定义 RunnableSequence 类,用于实现可运行对象的链式组合(A | B | C 的效果)
class RunnableSequence(Runnable):
"""
Runnable 组合序列,用于支持 A | B | C 的链式拼接。
"""
# 初始化方法,接收一个 Runnable 对象的列表
def __init__(self, runnables):
# 检查传入的 runnables 列表不能为空
if not runnables:
raise ValueError("runnables 不能为空")
# 校验每一个元素都必须是 Runnable 实例
for r in runnables:
if not isinstance(r, Runnable):
raise TypeError("runnables 需全部为 Runnable 实例")
# 保存连成链的 runnable 组件
self.runnables = runnables
# 实现管道操作符 |,使链式拼接成立
def __or__(self, other):
# 右侧对象必须也是 Runnable 实例
if not isinstance(other, Runnable):
raise TypeError("管道右侧必须是 Runnable 实例")
# 返回新的组合链(原有链 + 新加的 runnable)
return RunnableSequence(self.runnables + [other])
# 调用链的同步调用,将输入依次传过所有组件
def invoke(self, input, config = None, **kwargs):
"""
逐个执行链条:上一步输出作为下一步输入。
"""
# 确保 config 存在
config = ensure_config(config)
# 处理回调:如果有 callbacks,则触发链的开始回调
callbacks = config.get("callbacks")
# 初始化回调列表
callback_list = []
# 获取 run_id
run_id = config.get("run_id")
# 如果 run_id 为 None,则生成一个新的 uuid
if run_id is None:
run_id = uuid_module.uuid4()
# 如果 callbacks 不为空
if callbacks:
# 如果 callbacks 是列表,则直接赋值给 callback_list
if isinstance(callbacks, list):
callback_list = callbacks
# 如果 callbacks 不是列表,则转换为单元素列表
else:
callback_list = [callbacks]
# 序列化信息,用于回调上报链条标识
serialized = {"name": "RunnableSequence", "type": "chain"}
# 遍历每个回调对象,触发其 on_chain_start 方法
for callback in callback_list:
# 只有回调对象有 on_chain_start 属性才调用
if hasattr(callback, "on_chain_start"):
# 调用回调的 on_chain_start 方法,传入相关参数
try:
# 调用回调的 on_chain_start 方法,传入相关参数
callback.on_chain_start(serialized, {"input": input}, run_id=run_id, parent_run_id=None, tags=config.get("tags"), metadata=config.get("metadata"), **kwargs)
except Exception:
# 回调过程中如出现异常则忽略,确保主流程不会终止
pass
# 初始 value 为输入 input
value = input
try:
# 依次调用每个 runnable 的 invoke,并传递最新的 value
for runnable in self.runnables:
value = runnable.invoke(value, config=config, **kwargs)
except Exception as e:
# 若捕获到异常,则对所有回调触发 on_chain_error 并继续抛出异常
if callback_list:
for callback in callback_list:
# 只有回调对象有 on_chain_error 属性才调用
if hasattr(callback, "on_chain_error"):
try:
# 调用回调的 on_chain_error 方法,传入相关参数
callback.on_chain_error(e, run_id=run_id, parent_run_id=None, **kwargs)
except Exception:
# 回调过程中如出现异常则忽略,确保主流程不会终止
pass
raise
else:
# 如果没有异常执行,顺序触发所有回调的 on_chain_end 方法
if callback_list:
for callback in callback_list:
# 只有回调对象有 on_chain_end 属性才调用
if hasattr(callback, "on_chain_end"):
try:
# 调用回调的 on_chain_end 方法,传入相关参数
callback.on_chain_end(outputs={"output": value}, run_id=run_id, parent_run_id=None, **kwargs)
except Exception:
# 回调过程中如出现异常则忽略,确保主流程不会终止
pass
# 返回最后一步的输出值
return value
# 批量调用,输入为多个 input,结果为每个 input 执行完整链条的输出
def batch(self, inputs, config = None, **kwargs):
"""
对输入列表逐项执行同一条链。
"""
# 逐项调用 invoke,收集所有输出
return [self.invoke(item, config=config, **kwargs) for item in inputs]
# 流式调用,默认复用基类逻辑(只对链最终结果流式分发)
def stream(self, input, **kwargs):
"""
流式执行:沿用基类逻辑,对最终结果做流式分发。
"""
# 使用基类 stream
yield from super().stream(input, **kwargs)
# 定义字符串表示,便于调试,输出链路结构
def __repr__(self):
# 获取每个 runnable 的名字,用"|"拼接成描述
names = " | ".join(getattr(r, "name", r.__class__.__name__) for r in self.runnables)
# 返回自定义格式
return f"RunnableSequence({names})"
# 定义 RunnableRetry 类,用于包装 Runnable 并添加重试逻辑
class RunnableRetry(Runnable):
"""
带重试功能的 Runnable 包装器
当底层 runnable 抛出指定异常时,会自动重试指定次数。
"""
# 初始化方法,接受被包装的 runnable 以及重试参数
def __init__(
self,
bound,
retry_if_exception_type=(Exception,),
stop_after_attempt=3,
wait_exponential_jitter=True,
exponential_jitter_params=None,
):
"""
初始化 RunnableRetry
Args:
bound: 被包装的 Runnable 对象
retry_if_exception_type: 需要重试的异常类型元组
stop_after_attempt: 最大尝试次数
wait_exponential_jitter: 是否启用指数回退抖动
exponential_jitter_params: 抖动参数 initial/max/exp_base/jitter
"""
# 保存底层被包装的 Runnable
self.bound = bound
# 保存需要重试的异常类型
self.retry_if_exception_type = retry_if_exception_type
# 保存最大尝试次数
self.stop_after_attempt = stop_after_attempt
# 保存是否启用指数回退抖动
self.wait_exponential_jitter = wait_exponential_jitter
# 保存指数回退相关参数(若为 None 则用空字典兜底)
self.exponential_jitter_params = exponential_jitter_params or {}
# 实现同步调用(自动重试机制)
def invoke(self, input, config = None, **kwargs):
"""
调用底层 runnable,失败时自动重试
"""
# 用于记录最后一次抛出的异常
last_exception = None
# 解析重试等待的各项参数
initial = self.exponential_jitter_params.get("initial", 0.1) # 初始延迟
max_wait = self.exponential_jitter_params.get("max", 10.0) # 最大延迟
exp_base = self.exponential_jitter_params.get("exp_base", 2.0) # 幂指数基数
jitter = self.exponential_jitter_params.get("jitter", 0.0) # 抖动范围
# 尝试多次调用,直到最大次数
for attempt in range(1, self.stop_after_attempt + 1):
try:
# 调用底层的 invoke 方法
return self.bound.invoke(input, config=config, **kwargs)
# 捕获需要重试的异常类型
except self.retry_if_exception_type as e:
# 保存本次捕获的异常
last_exception = e
# 若还没到最大次数,可以重试
if attempt < self.stop_after_attempt:
# 判断是否使用指数回退
if self.wait_exponential_jitter:
# 计算当前次的延迟
delay = min(max_wait, initial * (exp_base ** (attempt - 1)))
# 如果配置了 jitter,叠加一个随机抖动
if jitter > 0:
delay += random.uniform(0, jitter)
else:
# 不指数回退则用 initial 固定延迟
delay = initial
# 等待指定时间再重试
time.sleep(delay)
else:
# 达到最大次数仍然失败则抛出最后一次异常
raise last_exception
except Exception:
# 如果是完全不在重试范围的异常,直接抛出
raise
# 如果所有尝试都失败,最终抛出异常
raise last_exception
# 实现批量调用,每个输入独立重试
def batch(self, inputs, config = None, **kwargs):
"""
批量调用,每个输入独立重试
"""
# 对每个输入都单独执行 invoke,收集结果为列表
return [self.invoke(item, config=config, **kwargs) for item in inputs]
# 实现流式调用,直接复用基类逻辑
def stream(self, input, **kwargs):
"""
流式调用,复用基类实现
"""
# 使用父类的 stream,yield 结果
yield from super().stream(input, **kwargs)
# 返回自身字符串表示,便于调试查看 retry 配置与绑定对象
def __repr__(self):
return f"RunnableRetry(bound={self.bound}, max_attempts={self.stop_after_attempt})"
# 工具函数:检查函数是否接受 config 参数
def _accept_config(func) -> bool:
"""
检查函数是否接受 config 参数
Args:
func: 要检查的函数
Returns:
如果函数接受 config 参数则返回 True,否则返回 False
"""
try:
sig = inspect.signature(func)
return "config" in sig.parameters
except (ValueError, TypeError):
return False
# 工具函数:合并配置字典
def _merge_configs(*configs):
"""
合并多个配置字典
Args:
*configs: 要合并的配置字典列表
Returns:
合并后的配置字典
"""
result = {}
for config in configs:
if config:
# 对于嵌套字典(如 metadata),需要深度合并
for key, value in config.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = {**result[key], **value}
else:
result[key] = value
return result
# 定义 RunnableBinding 类,用于包装 Runnable 并绑定配置
class RunnableBinding(Runnable):
"""
Runnable 绑定包装器
用于将配置绑定到 Runnable,返回一个新的 Runnable 实例。
当调用绑定的 Runnable 时,会自动合并绑定的配置和传入的配置。
"""
def __init__(self, bound, config=None, kwargs=None):
"""
初始化 RunnableBinding
Args:
bound: 要绑定的底层 Runnable 实例
config: 要绑定的配置字典
kwargs: 要绑定的额外关键字参数(暂未使用)
"""
if not isinstance(bound, Runnable):
raise TypeError("bound 必须是 Runnable 实例")
self.bound = bound
self.config = ensure_config(config) or {}
self.kwargs = kwargs or {}
def invoke(self, input, config=None, **kwargs):
"""
调用绑定的 Runnable,合并配置
Args:
input: 输入值
config: 可选的配置字典,会与绑定的配置合并
**kwargs: 额外的关键字参数
Returns:
底层 Runnable 的返回值
"""
# 合并绑定的配置和传入的配置
merged_config = _merge_configs(self.config, config)
# 合并关键字参数
merged_kwargs = {**self.kwargs, **kwargs}
# 调用底层 Runnable
return self.bound.invoke(input, config=merged_config, **merged_kwargs)
def batch(self, inputs, config=None, **kwargs):
"""
批量调用绑定的 Runnable,合并配置
Args:
inputs: 输入值列表
config: 可选的配置字典,会与绑定的配置合并
**kwargs: 额外的关键字参数
Returns:
输出值列表
"""
# 合并绑定的配置和传入的配置
merged_config = _merge_configs(self.config, config)
# 合并关键字参数
merged_kwargs = {**self.kwargs, **kwargs}
# 调用底层 Runnable
return self.bound.batch(inputs, config=merged_config, **merged_kwargs)
def stream(self, input, config=None, **kwargs):
"""
流式调用绑定的 Runnable,合并配置
Args:
input: 输入值
config: 可选的配置字典,会与绑定的配置合并
**kwargs: 额外的关键字参数
Yields:
底层 Runnable 的流式输出
"""
# 合并绑定的配置和传入的配置
merged_config = _merge_configs(self.config, config)
# 合并关键字参数
merged_kwargs = {**self.kwargs, **kwargs}
# 调用底层 Runnable
yield from self.bound.stream(input, config=merged_config, **merged_kwargs)
def __repr__(self):
"""返回对象的字符串表示"""
return f"RunnableBinding(bound={self.bound}, config={self.config})"
# 定义 ConfigurableField 类,用于配置可动态调整的字段
from collections import namedtuple
ConfigurableField = namedtuple(
"ConfigurableField",
["id", "name", "description", "annotation", "is_shared"],
defaults=(None, None, None, False)
)
"""可配置字段的定义
Args:
id: 字段的唯一标识符,在 config["configurable"] 中使用
name: 字段的显示名称(可选)
description: 字段的描述(可选)
annotation: 字段的类型注解(可选)
is_shared: 字段是否共享(可选,默认 False)
"""
# 定义 RunnableConfigurableFields 类,用于包装 Runnable 并支持动态配置字段
class RunnableConfigurableFields(Runnable):
"""
Runnable 可配置字段包装器
用于将 Runnable 的某些字段配置为可在运行时动态调整。
当调用时,会从 config["configurable"] 中读取配置值,并创建新的实例。
示例:
``python
from smartchain.runnables import ConfigurableField
llm = ChatDeepSeek(temperature=0).configurable_fields(
temperature=ConfigurableField(
id="temperature",
name="温度值",
description="LLM 的采样温度参数"
)
)
# 使用默认 temperature=0
result1 = llm.invoke("你好")
# 使用 temperature=1.0
result2 = llm.invoke("你好", config={"configurable": {"temperature": 1.0}})
``
"""
# 构造函数:接收默认可执行对象和字段描述字典
def __init__(self, default, fields):
"""
初始化 RunnableConfigurableFields
Args:
default: 默认的 Runnable 实例或具有 invoke 方法的对象
fields: 可配置字段的字典,键为字段名,值为 ConfigurableField 实例
"""
# 检查 default 是否为 Runnable 实例或者拥有 invoke 方法
if not (isinstance(default, Runnable) or (hasattr(default, 'invoke') and callable(getattr(default, 'invoke')))):
raise TypeError("default 必须是 Runnable 实例或具有 invoke 方法的对象")
# 保存默认实例
self.default = default
# 保存字段配置(如果未传入则设为{})
self.fields = fields or {}
# 内部方法,根据 config 动态生成实例,应用动态配置
def _prepare(self, config=None):
"""
准备 Runnable 实例和配置
从 config["configurable"] 中读取配置值,并创建新的实例。
Args:
config: 配置字典
Returns:
tuple: (Runnable 实例, 配置字典)
"""
# 规范化 config(保证为字典)
config = ensure_config(config)
# 从 config 取出 configurable 配置
configurable = config.get("configurable", {})
# 收集需要修改的字段和值
updates = {}
for field_name, field_spec in self.fields.items():
# 检查字段是否为 ConfigurableField
if isinstance(field_spec, ConfigurableField):
# 从 config 找对应 id 的值
config_value = configurable.get(field_spec.id)
if config_value is not None:
updates[field_name] = config_value
# 有更新内容则创建新实例
if updates:
# 获取默认实例的类型
default_class = type(self.default)
# 获取类型名
class_name = default_class.__name__
# 对于特定聊天模型需要特殊参数处理
if class_name in ('ChatDeepSeek', 'ChatDeepSeek', 'ChatTongyi'):
# 构造初始化参数 dict,必须包含 model
init_params = {
'model': self.default.model,
}
# 如果有 model_kwargs 就复制
if hasattr(self.default, 'model_kwargs'):
init_params.update(self.default.model_kwargs.copy())
# 增加本次需更新的参数
init_params.update(updates)
# 保持 api_key(如有)
if hasattr(self.default, 'api_key'):
init_params['api_key'] = self.default.api_key
# 保持 base_url(如有)
if hasattr(self.default, 'base_url'):
init_params['base_url'] = getattr(self.default, 'base_url', None)
# 构造新实例
new_instance = default_class(**init_params)
return (new_instance, config)
else:
# 对于其他类型的实例采用通用方法
if hasattr(self.default, '__dict__'):
# 使用对象字段构建参数(忽略以 _ 开头的字段)
init_params = {k: v for k, v in self.default.__dict__.items()
if not k.startswith('_')}
else:
# 无法获取 __dict__ 则用空参数
init_params = {}
# 更新参数
init_params.update(updates)
try:
# 尝试直接用参数构造新实例
new_instance = default_class(**init_params)
return (new_instance, config)
except Exception:
# 构造失败则深拷贝实例并赋值
import copy
new_instance = copy.deepcopy(self.default)
for key, value in updates.items():
# 优先直接设置属性
if hasattr(new_instance, key):
setattr(new_instance, key, value)
# 对于 ChatDeepSeek 还要更新 model_kwargs 字典
elif hasattr(new_instance, 'model_kwargs'):
new_instance.model_kwargs[key] = value
return (new_instance, config)
# 未指定可配置参数,直接返回默认实例和 config
return (self.default, config)
# 单条输入调用方法,支持动态配置
def invoke(self, input, config=None, **kwargs):
"""
调用 Runnable,支持动态配置
Args:
input: 输入值
config: 配置字典,可以包含 configurable 字段
**kwargs: 额外的关键字参数
Returns:
底层 Runnable 的返回值
"""
# 获取动态配置后的 runnable 实例和配置
runnable, merged_config = self._prepare(config)
# 若为 Runnable 实例则传递 config 参数
if isinstance(runnable, Runnable):
return runnable.invoke(input, config=merged_config, **kwargs)
else:
# 非 Runnable 实例直接调用(初始化时参数已生效)
return runnable.invoke(input, **kwargs)
# 批量输入调用方法,支持动态配置
def batch(self, inputs, config=None, **kwargs):
"""
批量调用 Runnable,支持动态配置
Args:
inputs: 输入值列表
config: 配置字典,可以包含 configurable 字段
**kwargs: 额外的关键字参数
Returns:
输出值列表
"""
# 获取动态配置后的 runnable 实例和配置
runnable, merged_config = self._prepare(config)
# 若为 Runnable 实例则传递 config 参数
if isinstance(runnable, Runnable):
return runnable.batch(inputs, config=merged_config, **kwargs)
else:
# 有 batch 方法就直接调用
if hasattr(runnable, 'batch'):
return runnable.batch(inputs, **kwargs)
else:
# 没有 batch 方法,逐个调用 invoke 实现
return [runnable.invoke(input_item, **kwargs) for input_item in inputs]
# 流式输入调用,支持动态配置
def stream(self, input, config=None, **kwargs):
"""
流式调用 Runnable,支持动态配置
Args:
input: 输入值
config: 配置字典,可以包含 configurable 字段
**kwargs: 额外的关键字参数
Yields:
底层 Runnable 的流式输出
"""
# 获取动态配置后的 runnable 实例和配置
runnable, merged_config = self._prepare(config)
# 若为 Runnable 实例则传递 config 参数
if isinstance(runnable, Runnable):
yield from runnable.stream(input, config=merged_config, **kwargs)
else:
# 有 stream 方法就直接调用
if hasattr(runnable, 'stream'):
yield from runnable.stream(input, **kwargs)
else:
# 没有流式方法则调用 invoke 并 yield 单值
result = runnable.invoke(input, **kwargs)
yield result
# 字符串表示方法,便于调试
def __repr__(self):
"""返回对象的字符串表示"""
return f"RunnableConfigurableFields(default={self.default}, fields={self.fields})"
# 定义用于根据 config["configurable"] 动态选择分支的类
class RunnableConfigurableAlternatives(Runnable):
"""
根据配置动态选择不同分支的 Runnable/对象。
示例:
selector = ConfigurableField(id="provider", name="LLM 提供方")
chain = some_runnable.configurable_alternatives(
selector,
default_key="openai",
openai=ChatDeepSeek(...),
deepseek=ChatDeepSeek(...),
)
# 默认使用 openai
chain.invoke("hi")
# 切换为 deepseek
chain.invoke("hi", config={"configurable": {"provider": "deepseek"}})
"""
# 初始化方法,接收选择字段、默认 key、和所有可选分支
def __init__(self, selector_field, default_key, alternatives):
"""
初始化
Args:
selector_field: ConfigurableField,用于从 config["configurable"] 取值的字段
default_key: 默认分支 key,必须存在于 alternatives
alternatives: dict,key -> runnable 或具有 invoke 方法的对象
"""
# 检查 selector_field 是否为 ConfigurableField 实例
if not isinstance(selector_field, ConfigurableField):
raise TypeError("selector_field 必须是 ConfigurableField 实例")
# 检查默认 key 是否在 alternatives 里
if default_key not in alternatives:
raise ValueError("default_key 必须存在于 alternatives 中")
# 检查 alternatives 是否为非空字典
if not isinstance(alternatives, dict) or not alternatives:
raise ValueError("alternatives 必须是非空字典")
# 保存选择器字段
self.selector_field = selector_field
# 保存默认分支 key
self.default_key = default_key
# 保存所有分支
self.alternatives = alternatives
# 内部方法:按照 config 动态选择分支
def _select(self, config=None):
# 标准化配置,补全可选项结构
config = ensure_config(config)
# 获取 configurable 字段(可能为空)
configurable = config.get("configurable", {}) or {}
# 根据 selector_field.id 查询分支 key,如果没指定则使用默认 key
key = configurable.get(self.selector_field.id, self.default_key)
# 找不到分支则报错
if key not in self.alternatives:
raise ValueError(f"未找到可用分支: {key}")
# 返回被选中的分支和合并后的配置
return self.alternatives[key], config
# 单条输入调用,根据当前 config 路由到对应分支
def invoke(self, input, config=None, **kwargs):
# 动态选择分支和合并后的配置
selected, merged_config = self._select(config)
# 如果是 Runnable,则传递 config
if isinstance(selected, Runnable):
return selected.invoke(input, config=merged_config, **kwargs)
else:
# 否则只调用普通 invoke
return selected.invoke(input, **kwargs)
# 批量调用,根据当前 config 调用子分支
def batch(self, inputs, config=None, **kwargs):
# 选择分支和合并 config
selected, merged_config = self._select(config)
# 如果是 Runnable,传递 config 下批量调用
if isinstance(selected, Runnable):
return selected.batch(inputs, config=merged_config, **kwargs)
else:
# 有 batch 方法直接用
if hasattr(selected, "batch"):
return selected.batch(inputs, **kwargs)
# 否则逐条调用 invoke
return [selected.invoke(item, **kwargs) for item in inputs]
# 流式输出,根据 config 路由
def stream(self, input, config=None, **kwargs):
# 动态选择分支
selected, merged_config = self._select(config)
# 如果支持 stream 且是 Runnable,传递 config
if isinstance(selected, Runnable):
yield from selected.stream(input, config=merged_config, **kwargs)
else:
# 有 stream 方法直接用
if hasattr(selected, "stream"):
yield from selected.stream(input, **kwargs)
else:
# 没有流式方法则调用普通 invoke
yield selected.invoke(input, **kwargs)
# 字符串表示方法,便于调试打印分支
def __repr__(self):
return (
f"RunnableConfigurableAlternatives("
f"selector_field={self.selector_field}, "
f"default_key={self.default_key}, "
f"alternatives={list(self.alternatives.keys())}"
f")"
)
# 定义一个带有消息历史管理功能的 Runnable 包装器类
class RunnableWithMessageHistory(Runnable):
"""
管理聊天消息历史的 Runnable 包装器
自动处理历史消息的读取和更新,支持多会话管理。
示例见文档字符串内容。
"""
# 初始化方法,接收底层runnable、会话历史获取方法和相关key配置
def __init__(
self,
runnable,
get_session_history,
*,
input_messages_key=None,
output_messages_key=None,
history_messages_key=None,
):
"""
初始化 RunnableWithMessageHistory
Args:
runnable: 要包装的 Runnable 实例
get_session_history: 用于获取会话历史的函数,需接受 session_id 参数
input_messages_key: 输入字典中的消息键名
output_messages_key: 输出字典中的消息键名
history_messages_key: 历史消息在输入字典中的键名
"""
# 检查 runnable 是否为 Runnable 的实例
if not isinstance(runnable, Runnable):
raise TypeError("runnable 必须是 Runnable 实例")
# 检查 get_session_history 是否是可调用对象
if not callable(get_session_history):
raise TypeError("get_session_history 必须是可调用对象")
# 保存 runnable 对象
self.runnable = runnable
# 保存用于获取会话历史的方法
self.get_session_history = get_session_history
# 输入消息的键
self.input_messages_key = input_messages_key
# 输出消息的键
self.output_messages_key = output_messages_key
# 历史消息的键
self.history_messages_key = history_messages_key
# 从输入中提取消息对象列表(只提取当前输入内容,不含历史)
def _get_input_messages(self, input_val):
"""
从输入中提取消息列表
Args:
input_val: 输入值,可以是字符串、消息对象、消息列表或字典
Returns:
消息列表
"""
from .messages import HumanMessage, BaseMessage
# 如果输入是字典,提取目标key,对应输入内容
if isinstance(input_val, dict):
if self.input_messages_key:
key = self.input_messages_key
elif len(input_val) == 1:
key = next(iter(input_val.keys()))
else:
key = "input"
input_val = input_val[key]
# 如果输入是字符串,转为HumanMessage
if isinstance(input_val, str):
return [HumanMessage(content=input_val)]
# 如果是基础消息对象,转为单元素列表
if isinstance(input_val, BaseMessage):
return [input_val]
# 如果已经是列表或元组,直接返回为列表
if isinstance(input_val, (list, tuple)):
return list(input_val)
# 其它情况抛出异常
raise ValueError(f"无法从输入中提取消息: {input_val}")
# 从输出中解析AI消息列表
def _get_output_messages(self, output_val):
"""
从输出中提取消息列表
Args:
output_val: 输出值,可以是字符串、消息对象、消息列表或字典
Returns:
消息列表
"""
from .messages import AIMessage, BaseMessage
# 如果输出是字典,根据配置或默认key提取
if isinstance(output_val, dict):
if self.output_messages_key:
key = self.output_messages_key
elif len(output_val) == 1:
key = next(iter(output_val.keys()))
else:
key = "output"
output_val = output_val[key]
# 若为字符串,转为AIMessage
if isinstance(output_val, str):
return [AIMessage(content=output_val)]
# 若为基础消息对象或含content/type属性(兼容AIMessage)
if isinstance(output_val, BaseMessage) or (hasattr(output_val, 'content') and hasattr(output_val, 'type')):
return [output_val]
# 若为列表或元组
if isinstance(output_val, (list, tuple)):
return list(output_val)
# 其它情况抛异常
raise ValueError(f"无法从输出中提取消息: {output_val}")
# 核心:带历史的invoke调用
def invoke(self, input, config=None, **kwargs):
"""
调用 Runnable,自动管理历史消息
Args:
input: 输入值
config: 配置字典,需包含 configurable.session_id
**kwargs: 其余关键参数
Returns:
底层 Runnable 的返回值
"""
# 确保 config 存在和格式标准化
config = ensure_config(config)
# 获取自定义配置部分
configurable = config.get("configurable", {})
# 获取当前会话ID,必须提供
session_id = configurable.get("session_id")
if not session_id:
raise ValueError("config['configurable']['session_id'] 必须提供")
# 调用 get_session_history 拉取(或新建)指定会话的历史对象
history = self.get_session_history(session_id)
# 获取历史消息列表(copy)
history_messages = history.messages
# 保存原始输入,用于后续提取当前输入消息
original_input = input
# ------------ 准备带历史的输入 ----------
# 如果需要指定历史键,插入到输入字典
if self.history_messages_key:
# 必须保证原输入是字典
if not isinstance(input, dict):
raise ValueError(f"当使用 history_messages_key 时,输入必须是字典,但得到了 {type(input)}")
# 复制输入,避免副作用
input = input.copy()
# 加入历史消息到指定键
input[self.history_messages_key] = history_messages
# 如果只指定了输入消息key,加入历史到"history"键
elif self.input_messages_key:
if not isinstance(input, dict):
raise ValueError(f"当使用 input_messages_key 时,输入必须是字典,但得到了 {type(input)}")
# 复制输入
input = input.copy()
# 始终加入"history",供Prompt模板中的MessagesPlaceholder用
input["history"] = history_messages
else:
# 如果输入不是字典,直接在头部追加历史消息
if isinstance(input, (list, tuple)):
input = list(history_messages) + list(input)
elif isinstance(input, str):
from .messages import HumanMessage
input = list(history_messages) + [HumanMessage(content=input)]
# ---------- 调用底层Runnable ----------
output = self.runnable.invoke(input, config=config, **kwargs)
# ----------- 解析输入输出消息 -----------
# 针对不同key配置,从原始输入提取“本轮”的人类消息
if self.history_messages_key:
if isinstance(original_input, dict):
input_messages = self._get_input_messages(
original_input.get(self.input_messages_key, original_input)
)
else:
input_messages = self._get_input_messages(original_input)
elif self.input_messages_key:
if isinstance(original_input, dict):
input_messages = self._get_input_messages(
original_input.get(self.input_messages_key, original_input)
)
else:
input_messages = self._get_input_messages(original_input)
else:
# 没有指定key,直接解析全部的原始输入
input_messages = self._get_input_messages(original_input)
# 解析AI消息
output_messages = self._get_output_messages(output)
# ---------- 更新历史(当前输入+本轮AI回复) ----------
history.add_messages(input_messages + output_messages)
# 返回最终的模型输出
return output
# 批量输入的处理,每条输入单独处理一次并串行更新历史
def batch(self, inputs, config=None, **kwargs):
"""
批量调用 Runnable,自动管理历史消息
Args:
inputs: 输入值组成的列表
config: 配置字典
**kwargs: 其它参数
Returns:
输出值组成的列表
"""
# 针对每个输入,依次调用invoke
return [self.invoke(input_item, config=config, **kwargs) for input_item in inputs]
# 支持流式输出,更新历史后返回流式结果(这里每次只yield一次)
def stream(self, input, config=None, **kwargs):
"""
流式调用 Runnable,自动管理历史消息
Args:
input: 单条输入
config: 配置
**kwargs: 其它参数
Yields:
底层 Runnable 的流式输出结果
"""
# 简单实现:先完整invoke一轮更新历史,然后yield一次输出
output = self.invoke(input, config=config, **kwargs)
yield output
# 字符串显示方法,方便调试
def __repr__(self):
"""返回对象的字符串表示"""
return f"RunnableWithMessageHistory(runnable={self.runnable}, input_messages_key={self.input_messages_key}, history_messages_key={self.history_messages_key})"
# 工具执行器 Runnable,自动处理工具调用流程
+class RunnableToolExecutor(Runnable):
+ """
+ 工具执行器 Runnable,自动处理工具调用流程
+ 封装了工具调用的完整流程:
+ 1. 调用 LLM 获取响应(可能包含 tool_calls)
+ 2. 检测 tool_calls 并执行对应工具
+ 3. 将工具结果回传给 LLM
+ 4. 返回最终答案
+ 支持链式调用,可以和其他 Runnable 组合使用。
+ """
+ def __init__(self, llm_with_tools, tools):
+ """
+ 初始化工具执行器
+ Args:
+ llm_with_tools: 绑定了工具的 LLM(通过 bind_tools 创建)
+ tools: 工具列表,用于执行工具调用
+ """
+ self.llm_with_tools = llm_with_tools
# 创建工具名称到工具实例的映射
+ self.tool_map = {tool.name: tool for tool in tools}
+ def invoke(self, input, config=None, **kwargs):
+ """
+ 执行工具调用流程
+ Args:
+ input: 输入消息(可以是字符串、消息列表等)
+ config: 配置字典
+ **kwargs: 其他参数
+ Returns:
+ AIMessage: 最终的回答消息
+ """
# 确保 input 是列表格式(用于消息历史)
+ if isinstance(input, str):
+ messages = [("human", input)]
+ elif isinstance(input, list):
+ messages = input
+ else:
+ messages = [input]
# 第一次调用 LLM
+ resp = self.llm_with_tools.invoke(messages, **kwargs)
# 如果响应包含工具调用,执行工具并再次调用 LLM
+ if hasattr(resp, "tool_calls") and resp.tool_calls:
+ tool_results = []
# 执行每个工具调用
+ for tc in resp.tool_calls:
+ tool_name = tc["name"]
+ if tool_name in self.tool_map:
# 执行工具
+ result = self.tool_map[tool_name].invoke(tc["args"])
# 创建 ToolMessage
+ tool_results.append(
+ ToolMessage(str(result), tool_call_id=tc["id"])
+ )
# 将工具结果回传给 LLM,获取最终答案
+ final_messages = messages + [resp] + tool_results
+ final = self.llm_with_tools.invoke(final_messages, **kwargs)
+ return final
+ else:
# 没有工具调用,直接返回响应
+ return resp