BaseLLM 基类设计¶
BaseLLM 是所有 LLM 封装的抽象基类,定义了统一的核心接口和通用功能。理解 BaseLLM 的设计是扩展自定义 LLM 或深入使用现有 LLM 的关键。
设计理念¶
职责分离¶
BaseLLM 遵循单一职责原则,专注于以下核心功能:
- 统一接口定义:定义所有 LLM 实现必须遵循的接口契约
- 通用功能实现:提供跨厂商共享的功能(token 计算、上下文压缩、遥测)
- 生命周期管理:处理 LLM 对象的初始化、连接、断开等生命周期
- 错误处理框架:定义统一的错误处理策略
抽象层设计¶
BaseLLM 在以下层次提供了抽象:
┌─────────────────────────────────────────┐
│ 应用层 (Chain/Drive) │
├─────────────────────────────────────────┤
│ BaseLLM (抽象基类) │
│ - 完整的 token 计算逻辑 │
│ - 智能上下文压缩 │
│ - 遥测追踪 │
│ - 媒体缓存管理 │
├─────────────────────────────────────────┤
│ ChatLLM / GenerationLLM (中间抽象) │
│ - Prompt 位置系统 │
│ - 消息格式化 │
├─────────────────────────────────────────┤
│ 具体实现 (GPT/Claude/Gemini...) │
│ - 厂商 API 调用 │
│ - 厂商格式转换 │
└─────────────────────────────────────────┘
核心抽象方法¶
子类必须实现以下抽象方法:
complete / async_complete¶
def complete(
self,
current_input: BaseMessage,
conversation: Optional[list[BaseMessage]] = None,
elements: Optional[list[DocElement]] = None,
knowledge: Optional[str] = None,
tools: Optional[list[BaseTool]] = None,
intermediate_msgs: Optional[list[BaseMessage]] = None,
response_format: Optional[LLMResponseFormat] = None,
) -> LLMResult:
"""生成聊天结果(同步版本)"""
关键设计点:
- 当模型 API 返回上下文超长错误时,不应在 LLM 层进行折叠重试,而应转换为
ContextTooLargeError异常抛出 - 由 Chain 层捕获异常并触发 compact 流程
- 这样设计是为了职责分离:LLM 负责调用,Chain 负责上下文管理
construct_request_params¶
def construct_request_params(
self,
current_input: BaseMessage,
conversation: Optional[list[BaseMessage]] = None,
...
) -> dict:
"""构造聊天请求参数"""
职责:将 TFRobot 内部的数据结构转换为厂商 API 所需的请求参数。
construct_llm_result¶
def construct_llm_result(
self,
response: Any,
prompt: dict,
) -> LLMResult:
"""构造 LLMResult 对象"""
职责:将厂商 API 的响应转换为统一的 LLMResult 对象。
ChatLLM 专属抽象方法¶
reformat_request_msg_to_api¶
def reformat_request_msg_to_api(
self,
msg: BaseLLMMessage
) -> ProviderMessageType | list[ProviderMessageType]:
"""将 TFRobot 内部消息类型转换为厂商格式"""
设计原因:
- 不同厂商的消息格式差异巨大(如 Anthropic 不支持 system 角色)
- 一个 TFRobot 消息可能对应多个厂商消息(如多模态拆分)
- 使用泛型
ProviderMessageType允许子类定义自己的消息类型
extract_context_size_from_error¶
def _extract_context_size_from_error(
self,
exception: Exception
) -> tuple[Optional[int], Optional[int]]:
"""从厂商异常中提取上下文大小信息"""
用途:当上下文超长时,从异常中提取当前大小和最大大小,用于智能压缩。
核心功能实现¶
1. Token 计算系统¶
BaseLLM 提供了完整的 token 计算功能,支持多模态内容:
def calculate_token_count(
self,
content: str | BaseMessage | DocElement
) -> int:
"""计算任意内容的 token 数量"""
计算流程:
content
↓
类型分发 (match-case)
├─→ str → tokenize_str
├─→ BaseMessage → _count_message_tokens
└─→ DocElement → _count_element_tokens
多模态支持:
- 图片:使用
tile_image计算瓦片数量,应用厂商计费规则 - 视频:根据时长估算(Gemini: 258 + 258 × 秒数)
- 音频:根据时长估算(GPT-4o: 32 × 秒数)
- PDF:按页数估算(默认 500 tokens/页)
2. 智能上下文压缩¶
当上下文超出限制时,collapse_context 方法会执行智能压缩:
def collapse_context(
self,
current_input: BaseMessage,
conversation: Optional[list[BaseMessage]] = None,
...
to_size: Optional[int] = None,
current_size: Optional[int] = None,
) -> tuple[
Optional[list[BaseMessage]],
Optional[list[DocElement]],
Optional[str],
Optional[list[BaseMessage]],
list[LLMResult],
]:
压缩策略:
- 预压缩阶段(当
current_size > to_size): - 使用 Splitter 快速机械压缩
- 优先压缩
conversation、elements、knowledge -
尽可能保留
intermediate_msgs -
LLM 智能压缩阶段:
- 构建压缩提示(支持多语言)
- 调用 LLM 生成摘要
-
清空所有上下文,只保留压缩结果
-
递归处理:
- 如果 LLM 调用仍然超出限制
- 目标大小衰减为
to_size × 0.9 - 最多递归 5 次
3. 媒体缓存系统¶
使用 TTLCache 缓存媒体数据和元数据:
class MediaCacheEntry(TypedDict):
media_type: str # "image" | "video" | "audio" | "pdf"
bytes_data: Optional[tuple[bytes, str]] # (data, mime_type)
dimensions: Optional[tuple[int, int]] # (width, height)
tiles: Optional[tuple[int, int, int]] # (tiles_w, tiles_h, total)
tokens: Optional[int] # token 数量
cached_at: float # 缓存时间戳
缓存 Key 格式:"{media_type}:{uri}:{param_hash}"
缓存策略: - 每个子类独立的缓存实例 - 最大 5000 条目 - 24 小时过期
4. 遥测追踪¶
自动通过 OpenTelemetry 记录 LLM 调用事件:
@span_decorator(
span_attr,
SpanEvent.BEFORE_LLM_GENERATE,
SpanEvent.AFTER_LLM_GENERATE,
SpanEvent.LLM_GENERATE_RAISE,
SpanEvent.LLM_GENERATE_ABORT,
...
)
def complete(self, ...):
...
追踪的事件:
- BEFORE_LLM_GENERATE:调用开始
- AFTER_LLM_GENERATE:调用成功
- LLM_GENERATE_RAISE:抛出异常
- LLM_GENERATE_ABORT:调用中止
5. 工具调用优化¶
optimize_map_reduce_messages 方法优化工具调用的中间消息:
def optimize_map_reduce_messages(
self,
intermediate_msgs: list[BaseMessage]
) -> tuple[list[BaseMessage], bool, bool]:
"""优化 Map-Reduce 模式的工具调用"""
优化逻辑:
- Map 阶段:仅保留最后一个 Map 消息,折叠中间的 Map 消息
- Reduce 阶段:将所有 Map 的返回结果融入到 Reduce 消息中
- 工具调用树:支持嵌套的工具调用(树型结构)
配置参数详解¶
核心参数¶
| 参数 | 类型 | 说明 | 默认值 |
|---|---|---|---|
name |
str |
模型名称(如 gpt-4o) |
- |
input_price |
float |
输入价格(元/kTokens) | 0.0 |
output_price |
float |
输出价格(元/kTokens) | 0.0 |
context_window |
int |
上下文窗口大小(tokens) | 自动获取 |
response_format |
LLMResponseFormat |
响应格式 | text |
tool_filter |
str |
工具过滤表达式 | None |
locale |
Locale |
内部提示语言 | DEFAULT |
响应格式配置¶
response_format: LLMResponseFormat = {
"type": "json_schema", # "text" | "json_object" | "json_schema"
"json_schema": {...}, # JSON Schema 定义(仅 json_schema 需要)
"examples": [...] # 示例(可选,用于 few-shot)
}
扩展 BaseLLM¶
实现新的 LLM 提供商¶
- 继承 ChatLLM 或 GenerationLLM
from tfrobot.brain.chain.llms.chat_llm import ChatLLM
class MyLLM(ChatLLM[dict]):
name: str = "my-model"
input_price: float = 0.01
output_price: float = 0.02
- 实现抽象方法
def model_post_init(self, __context: Any) -> None:
"""初始化客户端"""
self._client = MyLLMClient(api_key=self.api_key)
def construct_request_params(self, ...) -> dict:
"""构造请求参数"""
# 调用父类方法获取消息列表
req_msgs = self.format_to_request_msgs(current_input, prompt_ctx)
# 转换为厂商格式
messages = [
self.reformat_request_msg_to_api(msg)
for msg in req_msgs.to_list()
]
return {"messages": messages, ...}
def complete(self, ...) -> LLMResult:
"""调用 API"""
params = self.construct_request_params(...)
try:
response = self._client.chat(**params)
return self.construct_llm_result(response, params)
except MyContextTooLargeError as e:
current_size, target_size = self._extract_context_size_from_error(e)
raise ContextTooLargeError(
current_size=current_size,
target_size=target_size
) from e
def reformat_request_msg_to_api(self, msg: BaseLLMMessage) -> dict:
"""转换消息格式"""
if isinstance(msg, LLMSystemMessage):
return {"role": "system", "content": msg.content}
elif isinstance(msg, LLMUserMessage):
return {"role": "user", "content": msg.content}
...
def _extract_context_size_from_error(self, e: Exception):
"""提取上下文大小"""
# 从异常中解析当前大小和最大大小
# 返回 (current_size, max_size)
...
- 处理多模态内容
如果模型支持多模态,重写 _estimate_*_tokens 方法:
def _estimate_image_tokens(self, uri: str, detail: Optional[str] = None) -> int:
"""计算图片 token(自定义规则)"""
# 使用 tile_image 获取瓦片信息
tiles_w, tiles_h, n = self.tile_image(uri, 512, 768)
return 100 + 200 * n # 自定义计费规则
关键实现细节¶
错误识别¶
正确识别上下文超长错误是关键:
try:
response = self._client.chat(**params)
except MyBadRequestError as e:
if "context_length_exceeded" in str(e):
current_size = parse_current_size(e)
max_size = parse_max_size(e)
raise ContextTooLargeError(
current_size=current_size,
target_size=max_size
) from e
raise
重试策略¶
使用 tenacity 实现重试:
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((MyTimeoutError, MyConnectionError)),
)
def complete(self, ...):
...
异步支持¶
同步和异步方法应共享实现逻辑:
def complete(self, ...) -> LLMResult:
params = self.construct_request_params(...)
response = self._client.chat(**params)
return self._process_response(response)
async def async_complete(self, ...) -> LLMResult:
params = self.construct_request_params(...)
response = await self._async_client.chat(**params)
return self._process_response(response)
最佳实践¶
- 使用
whosellm获取模型元数据:自动获取上下文窗口、支持的功能等 - 复用父类方法:如
calculate_token_count、tile_image等 - 正确处理异常:将厂商异常转换为
ContextTooLargeError - 支持异步:同时实现同步和异步方法
- 使用缓存:媒体数据、瓦片信息等都应缓存
- 记录遥测:使用
@tracer.start_as_current_span装饰器