Skip to content

BaseLLM 基类设计

BaseLLM 是所有 LLM 封装的抽象基类,定义了统一的核心接口和通用功能。理解 BaseLLM 的设计是扩展自定义 LLM 或深入使用现有 LLM 的关键。

设计理念

职责分离

BaseLLM 遵循单一职责原则,专注于以下核心功能:

  1. 统一接口定义:定义所有 LLM 实现必须遵循的接口契约
  2. 通用功能实现:提供跨厂商共享的功能(token 计算、上下文压缩、遥测)
  3. 生命周期管理:处理 LLM 对象的初始化、连接、断开等生命周期
  4. 错误处理框架:定义统一的错误处理策略

抽象层设计

BaseLLM 在以下层次提供了抽象:

┌─────────────────────────────────────────┐
│           应用层 (Chain/Drive)           │
├─────────────────────────────────────────┤
│          BaseLLM (抽象基类)               │
│  - 完整的 token 计算逻辑                  │
│  - 智能上下文压缩                         │
│  - 遥测追踪                              │
│  - 媒体缓存管理                           │
├─────────────────────────────────────────┤
│  ChatLLM / GenerationLLM (中间抽象)       │
│  - Prompt 位置系统                       │
│  - 消息格式化                            │
├─────────────────────────────────────────┤
│  具体实现 (GPT/Claude/Gemini...)         │
│  - 厂商 API 调用                         │
│  - 厂商格式转换                          │
└─────────────────────────────────────────┘

核心抽象方法

子类必须实现以下抽象方法:

complete / async_complete

def complete(
    self,
    current_input: BaseMessage,
    conversation: Optional[list[BaseMessage]] = None,
    elements: Optional[list[DocElement]] = None,
    knowledge: Optional[str] = None,
    tools: Optional[list[BaseTool]] = None,
    intermediate_msgs: Optional[list[BaseMessage]] = None,
    response_format: Optional[LLMResponseFormat] = None,
) -> LLMResult:
    """生成聊天结果(同步版本)"""

关键设计点

  • 当模型 API 返回上下文超长错误时,不应在 LLM 层进行折叠重试,而应转换为 ContextTooLargeError 异常抛出
  • 由 Chain 层捕获异常并触发 compact 流程
  • 这样设计是为了职责分离:LLM 负责调用,Chain 负责上下文管理

construct_request_params

def construct_request_params(
    self,
    current_input: BaseMessage,
    conversation: Optional[list[BaseMessage]] = None,
    ...
) -> dict:
    """构造聊天请求参数"""

职责:将 TFRobot 内部的数据结构转换为厂商 API 所需的请求参数。

construct_llm_result

def construct_llm_result(
    self,
    response: Any,
    prompt: dict,
) -> LLMResult:
    """构造 LLMResult 对象"""

职责:将厂商 API 的响应转换为统一的 LLMResult 对象。

ChatLLM 专属抽象方法

reformat_request_msg_to_api

def reformat_request_msg_to_api(
    self,
    msg: BaseLLMMessage
) -> ProviderMessageType | list[ProviderMessageType]:
    """将 TFRobot 内部消息类型转换为厂商格式"""

设计原因

  • 不同厂商的消息格式差异巨大(如 Anthropic 不支持 system 角色)
  • 一个 TFRobot 消息可能对应多个厂商消息(如多模态拆分)
  • 使用泛型 ProviderMessageType 允许子类定义自己的消息类型

extract_context_size_from_error

def _extract_context_size_from_error(
    self,
    exception: Exception
) -> tuple[Optional[int], Optional[int]]:
    """从厂商异常中提取上下文大小信息"""

用途:当上下文超长时,从异常中提取当前大小和最大大小,用于智能压缩。

核心功能实现

1. Token 计算系统

BaseLLM 提供了完整的 token 计算功能,支持多模态内容:

def calculate_token_count(
    self,
    content: str | BaseMessage | DocElement
) -> int:
    """计算任意内容的 token 数量"""

计算流程

content
    ↓
类型分发 (match-case)
    ├─→ str        → tokenize_str
    ├─→ BaseMessage → _count_message_tokens
    └─→ DocElement  → _count_element_tokens

多模态支持

  • 图片:使用 tile_image 计算瓦片数量,应用厂商计费规则
  • 视频:根据时长估算(Gemini: 258 + 258 × 秒数)
  • 音频:根据时长估算(GPT-4o: 32 × 秒数)
  • PDF:按页数估算(默认 500 tokens/页)

2. 智能上下文压缩

当上下文超出限制时,collapse_context 方法会执行智能压缩:

def collapse_context(
    self,
    current_input: BaseMessage,
    conversation: Optional[list[BaseMessage]] = None,
    ...
    to_size: Optional[int] = None,
    current_size: Optional[int] = None,
) -> tuple[
    Optional[list[BaseMessage]],
    Optional[list[DocElement]],
    Optional[str],
    Optional[list[BaseMessage]],
    list[LLMResult],
]:

压缩策略

  1. 预压缩阶段(当 current_size > to_size):
  2. 使用 Splitter 快速机械压缩
  3. 优先压缩 conversationelementsknowledge
  4. 尽可能保留 intermediate_msgs

  5. LLM 智能压缩阶段

  6. 构建压缩提示(支持多语言)
  7. 调用 LLM 生成摘要
  8. 清空所有上下文,只保留压缩结果

  9. 递归处理

  10. 如果 LLM 调用仍然超出限制
  11. 目标大小衰减为 to_size × 0.9
  12. 最多递归 5 次

3. 媒体缓存系统

使用 TTLCache 缓存媒体数据和元数据:

class MediaCacheEntry(TypedDict):
    media_type: str           # "image" | "video" | "audio" | "pdf"
    bytes_data: Optional[tuple[bytes, str]]  # (data, mime_type)
    dimensions: Optional[tuple[int, int]]    # (width, height)
    tiles: Optional[tuple[int, int, int]]    # (tiles_w, tiles_h, total)
    tokens: Optional[int]     # token 数量
    cached_at: float          # 缓存时间戳

缓存 Key 格式"{media_type}:{uri}:{param_hash}"

缓存策略: - 每个子类独立的缓存实例 - 最大 5000 条目 - 24 小时过期

4. 遥测追踪

自动通过 OpenTelemetry 记录 LLM 调用事件:

@span_decorator(
    span_attr,
    SpanEvent.BEFORE_LLM_GENERATE,
    SpanEvent.AFTER_LLM_GENERATE,
    SpanEvent.LLM_GENERATE_RAISE,
    SpanEvent.LLM_GENERATE_ABORT,
    ...
)
def complete(self, ...):
    ...

追踪的事件: - BEFORE_LLM_GENERATE:调用开始 - AFTER_LLM_GENERATE:调用成功 - LLM_GENERATE_RAISE:抛出异常 - LLM_GENERATE_ABORT:调用中止

5. 工具调用优化

optimize_map_reduce_messages 方法优化工具调用的中间消息:

def optimize_map_reduce_messages(
    self,
    intermediate_msgs: list[BaseMessage]
) -> tuple[list[BaseMessage], bool, bool]:
    """优化 Map-Reduce 模式的工具调用"""

优化逻辑

  1. Map 阶段:仅保留最后一个 Map 消息,折叠中间的 Map 消息
  2. Reduce 阶段:将所有 Map 的返回结果融入到 Reduce 消息中
  3. 工具调用树:支持嵌套的工具调用(树型结构)

配置参数详解

核心参数

参数 类型 说明 默认值
name str 模型名称(如 gpt-4o -
input_price float 输入价格(元/kTokens) 0.0
output_price float 输出价格(元/kTokens) 0.0
context_window int 上下文窗口大小(tokens) 自动获取
response_format LLMResponseFormat 响应格式 text
tool_filter str 工具过滤表达式 None
locale Locale 内部提示语言 DEFAULT

响应格式配置

response_format: LLMResponseFormat = {
    "type": "json_schema",  # "text" | "json_object" | "json_schema"
    "json_schema": {...},   # JSON Schema 定义(仅 json_schema 需要)
    "examples": [...]       # 示例(可选,用于 few-shot)
}

扩展 BaseLLM

实现新的 LLM 提供商

  1. 继承 ChatLLM 或 GenerationLLM
from tfrobot.brain.chain.llms.chat_llm import ChatLLM

class MyLLM(ChatLLM[dict]):
    name: str = "my-model"
    input_price: float = 0.01
    output_price: float = 0.02
  1. 实现抽象方法
def model_post_init(self, __context: Any) -> None:
    """初始化客户端"""
    self._client = MyLLMClient(api_key=self.api_key)

def construct_request_params(self, ...) -> dict:
    """构造请求参数"""
    # 调用父类方法获取消息列表
    req_msgs = self.format_to_request_msgs(current_input, prompt_ctx)
    # 转换为厂商格式
    messages = [
        self.reformat_request_msg_to_api(msg)
        for msg in req_msgs.to_list()
    ]
    return {"messages": messages, ...}

def complete(self, ...) -> LLMResult:
    """调用 API"""
    params = self.construct_request_params(...)
    try:
        response = self._client.chat(**params)
        return self.construct_llm_result(response, params)
    except MyContextTooLargeError as e:
        current_size, target_size = self._extract_context_size_from_error(e)
        raise ContextTooLargeError(
            current_size=current_size,
            target_size=target_size
        ) from e

def reformat_request_msg_to_api(self, msg: BaseLLMMessage) -> dict:
    """转换消息格式"""
    if isinstance(msg, LLMSystemMessage):
        return {"role": "system", "content": msg.content}
    elif isinstance(msg, LLMUserMessage):
        return {"role": "user", "content": msg.content}
    ...

def _extract_context_size_from_error(self, e: Exception):
    """提取上下文大小"""
    # 从异常中解析当前大小和最大大小
    # 返回 (current_size, max_size)
    ...
  1. 处理多模态内容

如果模型支持多模态,重写 _estimate_*_tokens 方法:

def _estimate_image_tokens(self, uri: str, detail: Optional[str] = None) -> int:
    """计算图片 token(自定义规则)"""
    # 使用 tile_image 获取瓦片信息
    tiles_w, tiles_h, n = self.tile_image(uri, 512, 768)
    return 100 + 200 * n  # 自定义计费规则

关键实现细节

错误识别

正确识别上下文超长错误是关键:

try:
    response = self._client.chat(**params)
except MyBadRequestError as e:
    if "context_length_exceeded" in str(e):
        current_size = parse_current_size(e)
        max_size = parse_max_size(e)
        raise ContextTooLargeError(
            current_size=current_size,
            target_size=max_size
        ) from e
    raise

重试策略

使用 tenacity 实现重试:

from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)

@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((MyTimeoutError, MyConnectionError)),
)
def complete(self, ...):
    ...

异步支持

同步和异步方法应共享实现逻辑:

def complete(self, ...) -> LLMResult:
    params = self.construct_request_params(...)
    response = self._client.chat(**params)
    return self._process_response(response)

async def async_complete(self, ...) -> LLMResult:
    params = self.construct_request_params(...)
    response = await self._async_client.chat(**params)
    return self._process_response(response)

最佳实践

  1. 使用 whosellm 获取模型元数据:自动获取上下文窗口、支持的功能等
  2. 复用父类方法:如 calculate_token_counttile_image
  3. 正确处理异常:将厂商异常转换为 ContextTooLargeError
  4. 支持异步:同时实现同步和异步方法
  5. 使用缓存:媒体数据、瓦片信息等都应缓存
  6. 记录遥测:使用 @tracer.start_as_current_span 装饰器

相关文档