ai_member_xiaoyan/skills/interactive-component-json/scripts/llm_client.py

250 lines
8.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
LLM 调用客户端
参考 /root/.openclaw/reference/llm_shared/common_llm_api.py 封装
使用 requests 直接调用 Volcengine Ark OpenAI-compatible API
"""
import os
import json
import time
import logging
import traceback
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
CURRENT_PATH = os.path.dirname(os.path.abspath(__file__))

# ============ Logging ============
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s"

logger = logging.getLogger("llm_client")
# Only attach a console handler when the logger has none yet, so a module
# reload (or a host that already attached handlers) does not duplicate output.
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(_LOG_FORMAT))
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
# ============ Model configuration ============
# Every registered model is served from the same Volcengine Ark region.
_ARK_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"

# Friendly model name -> provider, Ark endpoint ID (sent as the request's
# "model" field) and API base URL. "disable_thinking": True makes the client
# send {"thinking": {"type": "disabled"}} with each request.
MODEL_REGISTRY = {
    "doubao-seed-2.0-pro": {
        "provider": "volcano",
        "endpoint": "ep-m-20260301164317-vmmj4",
        "base_url": _ARK_BASE_URL,
        "disable_thinking": True,
    },
    "doubao-1.8-volcano": {
        "provider": "volcano",
        "endpoint": "ep-20260106175024-6stxn",
        "base_url": _ARK_BASE_URL,
    },
    "doubao-1.6-volcano": {
        "provider": "volcano",
        "endpoint": "ep-20250729144911-s9hwc",
        "base_url": _ARK_BASE_URL,
    },
    "doubao-1.5-pro-volcano": {
        "provider": "volcano",
        "endpoint": "ep-20250206151029-tm7xl",
        "base_url": _ARK_BASE_URL,
    },
    "deepseek-v3-volcano": {
        "provider": "volcano",
        "endpoint": "ep-20250206151202-5jh9t",
        "base_url": _ARK_BASE_URL,
    },
}

# API key resolution order: the provider's environment variable
# (VOLCANO_API_KEY), then the hard-coded fallback below. There is no
# config-file source.
# SECURITY(review): this looks like a real credential committed to source
# control. It should be rotated and removed, with VOLCANO_API_KEY made
# mandatory. It is kept only so deployments relying on the fallback still work.
DEFAULT_API_KEY = "32994652-505c-492b-b6da-616ec5c5733c"
DEFAULT_MODEL = "doubao-seed-2.0-pro"
# Total number of attempts per request (including the first one), not the
# number of additional retries.
LLM_RETRY_TIMES = 2
def _get_api_key(provider="volcano"):
"""获取 API Key"""
env_map = {
"volcano": "VOLCANO_API_KEY",
}
env_var = env_map.get(provider, "VOLCANO_API_KEY")
key = os.getenv(env_var)
if key:
return key
return DEFAULT_API_KEY
class LLMClient:
    """Generic LLM client that calls an OpenAI-compatible chat API via requests.

    The model's settings come from ``MODEL_REGISTRY``. Its ``endpoint`` is sent
    as the request's ``model`` field, and its ``base_url`` is the API root.
    """

    def __init__(self, model_name=None):
        """Create a client for *model_name* (``DEFAULT_MODEL`` when omitted).

        Raises:
            ValueError: if the model is not registered in ``MODEL_REGISTRY``.
        """
        self.model_name = model_name or DEFAULT_MODEL
        if self.model_name not in MODEL_REGISTRY:
            raise ValueError(f"未注册的模型: {self.model_name},可用: {list(MODEL_REGISTRY.keys())}")
        self.model_config = MODEL_REGISTRY[self.model_name]
        self.api_key = _get_api_key(self.model_config["provider"])
        self.base_url = self.model_config["base_url"]
        self.endpoint = self.model_config["endpoint"]
        # Running count of failed attempts across all calls made by this client.
        self.total_err_cnt = 0
        self._init_session()

    def _init_session(self):
        """Create a requests session with a connection pool and transport retries."""
        self.session = requests.Session()
        # NOTE: with urllib3's default ``allowed_methods`` (which excludes POST),
        # status_forcelist retries do not apply to the POST issued by call().
        # Effectively only connection failures are retried here. HTTP 429/5xx
        # surface through raise_for_status() and are retried by call()'s loop.
        retry_strategy = Retry(
            total=3,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
        )
        adapter = HTTPAdapter(
            pool_connections=5,
            pool_maxsize=10,
            max_retries=retry_strategy,
        )
        self.session.mount("http://", adapter)
        self.session.mount("https://", adapter)

    def call(self, system_prompt, user_prompt="", max_tokens=4096, temperature=0.3, timeout=120):
        """Send one chat-completion request and return the reply text.

        Args:
            system_prompt: System prompt. If ``user_prompt`` is empty, this is
                the only message sent.
            user_prompt: User prompt. It is omitted from the request when empty.
            max_tokens: Maximum number of output tokens.
            temperature: Sampling temperature.
            timeout: HTTP timeout in seconds for each attempt.

        Returns:
            tuple: ``(content, {"prompt_tokens": int, "completion_tokens": int})``.
            ``content`` is the first choice's ``message.content`` exactly as
            returned. It is ``None`` if the API sent a null content. Missing or
            null usage figures are reported as 0.

        Raises:
            ValueError: if both prompts are empty.
            Exception: if all ``LLM_RETRY_TIMES`` attempts fail.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        if user_prompt:
            messages.append({"role": "user", "content": user_prompt})
        elif not system_prompt:
            raise ValueError("system_prompt 和 user_prompt 不能同时为空")

        url = f"{self.base_url}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.endpoint,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": False,
        }
        # Disable thinking for models that support it
        if self.model_config.get("disable_thinking"):
            payload["thinking"] = {"type": "disabled"}

        last_err = ""
        last_exc = None
        for attempt in range(1, LLM_RETRY_TIMES + 1):
            try:
                t1 = time.perf_counter()
                resp = self.session.post(url, headers=headers, json=payload, timeout=timeout)
                t2 = time.perf_counter()
                resp.raise_for_status()
                data = resp.json()
                if "error" in data:
                    raise Exception(f"API error: {data['error']}")
                content = data["choices"][0]["message"]["content"]
                # "usage" may be absent or explicitly null; neither should turn a
                # successful reply into a failed (and retried) attempt.
                raw_usage = data.get("usage") or {}
                usage = {
                    "prompt_tokens": raw_usage.get("prompt_tokens", 0),
                    "completion_tokens": raw_usage.get("completion_tokens", 0),
                }
                logger.info(
                    f"LLM [{self.model_name}] 耗时={t2-t1:.2f}s "
                    f"tokens={usage['prompt_tokens']}+{usage['completion_tokens']}"
                )
                return content, usage
            except Exception as exc:
                self.total_err_cnt += 1
                last_exc = exc
                last_err = traceback.format_exc()
                # Log the exception itself. The head of a formatted traceback is
                # only "Traceback (most recent call last):" plus a frame and never
                # contains the error type or message.
                logger.warning(
                    f"LLM [{self.model_name}] 第{attempt}次调用失败: "
                    f"{type(exc).__name__}: {str(exc)[:200]}"
                )
                logger.debug(last_err)
                if attempt < LLM_RETRY_TIMES:
                    time.sleep(2 * attempt)  # linear backoff: 2s, 4s, ...
        raise Exception(
            f"LLM [{self.model_name}] 调用{LLM_RETRY_TIMES}次均失败! 最后错误:\n{last_err}"
        ) from last_exc

    @staticmethod
    def _strip_code_fence(text):
        """Remove a Markdown code fence wrapping *text*, if there is one.

        The opening fence line (``` or ```json) is dropped. A closing ``` is
        dropped whether it is on its own line or glued to the last line.
        """
        if not text.startswith("```"):
            return text
        _, _, body = text.partition("\n")
        body = body.rstrip()
        if body.endswith("```"):
            body = body[:-3]
        return body.strip()

    @staticmethod
    def _expects_json_array(*prompts):
        """Return True if any prompt asks the model to start its output with "["."""
        markers = ("直接以\"[\"开始", "\"[\"开始输出")
        return any(marker in (prompt or "") for prompt in prompts for marker in markers)

    def call_for_json(self, system_prompt, user_prompt="", max_tokens=4096, temperature=0.1, timeout=120):
        """Call the LLM and parse its reply as JSON.

        A Markdown code fence around the reply is removed before parsing.

        An empty reply counts as an empty structure. This covers an empty
        string, an empty code fence, a ``None`` content, or the literal "无"
        ("none"), which the model may produce for empty input. The result is
        ``[]`` if either prompt asks for output starting with "[", otherwise
        ``{}``.

        Returns:
            tuple: ``(parsed_json_obj, usage_dict)``.

        Raises:
            ValueError: if the reply is not valid JSON.
        """
        content, usage = self.call(system_prompt, user_prompt, max_tokens, temperature, timeout)
        content = content or ""  # the API may return null content
        json_str = self._strip_code_fence(content.strip())

        # Tolerance: for empty input the LLM may answer with nothing or "无";
        # treat that as an empty structure instead of a parse error.
        if json_str in ("", "无"):
            logger.warning(f"LLM返回'{json_str}',视为空结构(输入可能为空)")
            if self._expects_json_array(system_prompt, user_prompt):
                return [], usage
            return {}, usage

        try:
            return json.loads(json_str), usage
        except json.JSONDecodeError as e:
            sys_text = system_prompt or ""
            user_text = user_prompt or ""
            # Detailed log: the (truncated) prompts and the raw LLM reply.
            logger.error(
                f"JSON解析失败!\n"
                f"── system_prompt ({len(sys_text)} chars) ──\n{sys_text[:2000]}\n"
                f"── user_prompt ({len(user_text)} chars) ──\n{user_text[:3000]}\n"
                f"── LLM原始返回 ({len(content)} chars) ──\n{content[:1000]}\n"
                f"── 解析错误 ──\n{e}"
            )
            raise ValueError(f"LLM返回内容不是合法JSON: {e}\n内容: {json_str[:300]}") from e
# ============ Convenience singleton ============
_default_client = None


def get_client(model_name=None):
    """Return a shared LLMClient for *model_name* (``DEFAULT_MODEL`` if omitted).

    One client is cached at module level. Asking for a different model
    replaces the cached client with a new one. No locking is performed.
    """
    global _default_client
    wanted = model_name or DEFAULT_MODEL
    cached = _default_client
    if cached is not None and cached.model_name == wanted:
        return cached
    _default_client = LLMClient(wanted)
    return _default_client
# ============ CLI smoke test ============
# Usage: python llm_client.py ["prompt text"]
if __name__ == "__main__":
    import sys

    cli_args = sys.argv[1:]
    demo_client = get_client()
    demo_prompt = cli_args[0] if cli_args else "你好,请用一句话介绍你自己。"
    print(f"模型: {demo_client.model_name} ({demo_client.endpoint})")
    print(f"提示: {demo_prompt}")
    reply, token_usage = demo_client.call("你是一个有用的助手。", demo_prompt)
    print(f"回复: {reply}")
    print(f"用量: {token_usage}")