250 lines
8.8 KiB
Python
250 lines
8.8 KiB
Python
#!/usr/bin/env python3
|
||
"""
LLM client.

Wrapper modelled on /root/.openclaw/reference/llm_shared/common_llm_api.py.
Calls the Volcengine Ark OpenAI-compatible API directly with requests.
"""
|
||
|
||
import os
|
||
import json
|
||
import time
|
||
import logging
|
||
import traceback
|
||
import requests
|
||
from requests.adapters import HTTPAdapter
|
||
from urllib3.util.retry import Retry
|
||
|
||
# Absolute path of the directory that contains this module.
_THIS_FILE = os.path.abspath(__file__)
CURRENT_PATH = os.path.dirname(_THIS_FILE)

# ============ Logging ============
# Module logger. The stream handler is attached only when the logger has no
# handler yet, so importing the module again does not stack duplicates.
logger = logging.getLogger("llm_client")
if not logger.handlers:
    _LOG_FORMAT = "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s"
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(_LOG_FORMAT))
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
|
||
|
||
# ============ Model configuration ============
# Every registered model is reached through the same Ark API root.
_ARK_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"


def _volcano_entry(endpoint, **extra):
    """Build one registry entry for a model hosted by the "volcano" provider.

    Args:
        endpoint: Endpoint id stored under the ``endpoint`` key.
        **extra: Optional per-model flags (e.g. ``disable_thinking``).

    Returns:
        dict: A fresh entry with ``provider``, ``endpoint``, ``base_url`` and
        any extra flags, in that order.
    """
    entry = {
        "provider": "volcano",
        "endpoint": endpoint,
        "base_url": _ARK_BASE_URL,
    }
    entry.update(extra)
    return entry


# Maps the public model name to its connection settings.
MODEL_REGISTRY = {
    "doubao-seed-2.0-pro": _volcano_entry("ep-m-20260301164317-vmmj4", disable_thinking=True),
    "doubao-1.8-volcano": _volcano_entry("ep-20260106175024-6stxn"),
    "doubao-1.6-volcano": _volcano_entry("ep-20250729144911-s9hwc"),
    "doubao-1.5-pro-volcano": _volcano_entry("ep-20250206151029-tm7xl"),
    "deepseek-v3-volcano": _volcano_entry("ep-20250206151202-5jh9t"),
}
|
||
|
||
# API key resolution order: the provider's environment variable first, then
# the DEFAULT_API_KEY constant below (no configuration file is consulted).
# NOTE(review): DEFAULT_API_KEY is a credential committed to source. It is
# kept so existing deployments keep working, but it should be rotated and
# supplied through the environment instead.
DEFAULT_API_KEY = "32994652-505c-492b-b6da-616ec5c5733c"
DEFAULT_MODEL = "doubao-seed-2.0-pro"
LLM_RETRY_TIMES = 2  # upper bound of the attempt loop in LLMClient.call


def _get_api_key(provider="volcano"):
    """Return the API key to use for ``provider``.

    The key is read from the provider's environment variable; providers that
    are not listed fall back to ``VOLCANO_API_KEY``. Surrounding whitespace
    (e.g. a trailing newline left by ``export KEY=$(cat file)``) is stripped,
    because it would otherwise end up verbatim in the Authorization header.
    A missing, empty or whitespace-only variable yields ``DEFAULT_API_KEY``.

    Args:
        provider: Provider name as stored in ``MODEL_REGISTRY`` entries.

    Returns:
        str: The resolved API key.
    """
    env_map = {
        "volcano": "VOLCANO_API_KEY",
    }
    env_var = env_map.get(provider, "VOLCANO_API_KEY")
    key = (os.getenv(env_var) or "").strip()
    return key or DEFAULT_API_KEY
|
||
|
||
|
||
class LLMClient:
    """General-purpose LLM client that calls an OpenAI-compatible API via requests.

    Attributes set by ``__init__``:
        model_name: Registry key of the selected model.
        model_config: The ``MODEL_REGISTRY`` entry for ``model_name``.
        api_key: Bearer token sent with every request.
        base_url: API root; ``/chat/completions`` is appended per call.
        endpoint: Value sent as the ``model`` field of the request payload.
        total_err_cnt: Number of failed attempts accumulated over all calls.
        session: ``requests.Session`` reused by every call of this client.
    """

    def __init__(self, model_name=None):
        """Resolve the model configuration and open the HTTP session.

        Args:
            model_name: Key into ``MODEL_REGISTRY``; ``DEFAULT_MODEL`` when
                omitted or falsy.

        Raises:
            ValueError: If the model name is not registered.
        """
        self.model_name = model_name or DEFAULT_MODEL
        if self.model_name not in MODEL_REGISTRY:
            raise ValueError(f"未注册的模型: {self.model_name},可用: {list(MODEL_REGISTRY.keys())}")

        self.model_config = MODEL_REGISTRY[self.model_name]
        self.api_key = _get_api_key(self.model_config["provider"])
        self.base_url = self.model_config["base_url"]
        self.endpoint = self.model_config["endpoint"]
        self.total_err_cnt = 0
        self._init_session()

    def _init_session(self):
        """Create the requests session with connection pooling and transport retries."""
        self.session = requests.Session()
        # NOTE(review): urllib3's Retry only retries idempotent HTTP methods by
        # default, so ``status_forcelist`` presumably never applies to the POST
        # issued by ``call`` -- confirm before relying on it. Retrying failed
        # requests is handled by the attempt loop in ``call``.
        retry_strategy = Retry(
            total=3,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
        )
        adapter = HTTPAdapter(
            pool_connections=5,
            pool_maxsize=10,
            max_retries=retry_strategy,
        )
        self.session.mount("http://", adapter)
        self.session.mount("https://", adapter)

    def call(self, system_prompt, user_prompt="", max_tokens=4096, temperature=0.3, timeout=120):
        """Send one chat-completion request and return the reply text.

        Args:
            system_prompt: System prompt. When ``user_prompt`` is empty it is
                sent as the only message.
            user_prompt: User prompt; may be empty.
            max_tokens: Maximum number of output tokens.
            temperature: Sampling temperature.
            timeout: Per-request timeout in seconds.

        Returns:
            tuple: ``(content, usage)`` where ``usage`` is
            ``{"prompt_tokens": int, "completion_tokens": int}``; counts that
            the API does not report are 0.

        Raises:
            ValueError: If both prompts are empty.
            RuntimeError: If every one of the ``LLM_RETRY_TIMES`` attempts
                fails; the last underlying exception is chained as the cause.
        """
        if not system_prompt and not user_prompt:
            raise ValueError("system_prompt 和 user_prompt 不能同时为空")

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        if user_prompt:
            messages.append({"role": "user", "content": user_prompt})

        url = f"{self.base_url}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.endpoint,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": False,
        }

        # Disable thinking for models whose registry entry asks for it.
        if self.model_config.get("disable_thinking"):
            payload["thinking"] = {"type": "disabled"}

        last_err = ""
        last_exc = None
        for attempt in range(1, LLM_RETRY_TIMES + 1):
            try:
                started = time.time()
                resp = self.session.post(url, headers=headers, json=payload, timeout=timeout)
                elapsed = time.time() - started

                resp.raise_for_status()
                data = resp.json()

                # Some failures come back as a 200 response with an "error" body.
                if "error" in data:
                    raise RuntimeError(f"API error: {data['error']}")

                content = data["choices"][0]["message"]["content"]
                # ``usage`` may be absent or null; treat both as "no counts".
                reported = data.get("usage") or {}
                usage = {
                    "prompt_tokens": reported.get("prompt_tokens", 0),
                    "completion_tokens": reported.get("completion_tokens", 0),
                }

                logger.info(
                    f"LLM [{self.model_name}] 耗时={elapsed:.2f}s "
                    f"tokens={usage['prompt_tokens']}+{usage['completion_tokens']}"
                )
                return content, usage

            except Exception as exc:
                self.total_err_cnt += 1
                last_exc = exc
                last_err = traceback.format_exc()
                logger.warning(f"LLM [{self.model_name}] 第{attempt}次调用失败: {last_err[:200]}")
                if attempt < LLM_RETRY_TIMES:
                    # Linear backoff: 2s after the first failure, 4s after the second, ...
                    time.sleep(2 * attempt)

        raise RuntimeError(
            f"LLM [{self.model_name}] 调用{LLM_RETRY_TIMES}次均失败! 最后错误:\n{last_err}"
        ) from last_exc

    @staticmethod
    def _strip_code_fence(text):
        """Return ``text`` without a surrounding markdown code fence.

        If ``text`` starts with a fence line (``` or ```json), that line is
        dropped and everything from the first closing fence onward is
        discarded, so trailing commentary after the fenced block is ignored.
        Text that does not start with a fence is returned unchanged.
        """
        if not text.startswith("```"):
            return text
        body_lines = text.split("\n")[1:]
        for idx, line in enumerate(body_lines):
            if line.strip() == "```":
                body_lines = body_lines[:idx]
                break
        return "\n".join(body_lines).strip()

    @staticmethod
    def _empty_result_for(system_prompt, user_prompt):
        """Pick the empty structure matching the output format the prompt asks for.

        Returns ``[]`` when the effective prompt tells the model to start its
        output with "[", otherwise ``{}``. The effective prompt is
        ``user_prompt``, or ``system_prompt`` when no user prompt was given
        (in which case it is the only message sent).
        """
        prompt_text = user_prompt or system_prompt or ""
        list_markers = ("直接以\"[\"开始", "以\"[\"开始输出")
        if any(marker in prompt_text for marker in list_markers):
            return []
        return {}

    def call_for_json(self, system_prompt, user_prompt="", max_tokens=4096, temperature=0.1, timeout=120):
        """Call the LLM and parse its reply as JSON.

        Args:
            system_prompt: System prompt (see ``call``).
            user_prompt: User prompt (see ``call``).
            max_tokens: Maximum number of output tokens.
            temperature: Sampling temperature.
            timeout: Per-request timeout in seconds.

        Returns:
            tuple: ``(parsed_json_obj, usage_dict)``. A reply of exactly "无"
            is mapped to an empty list or dict, see ``_empty_result_for``.

        Raises:
            ValueError: If the reply is not valid JSON (also when the reply
                content is empty or null).
        """
        content, usage = self.call(system_prompt, user_prompt, max_tokens, temperature, timeout)

        # A null content is treated as empty text so that it is reported below
        # as invalid JSON instead of failing with AttributeError.
        content = content or ""
        json_str = content.strip()

        # Tolerance: with empty input the model may answer "无" ("none").
        if json_str == "无":
            logger.warning("LLM返回'无',视为空结构(输入可能为空)")
            return self._empty_result_for(system_prompt, user_prompt), usage

        json_str = self._strip_code_fence(json_str)

        try:
            return json.loads(json_str), usage
        except json.JSONDecodeError as e:
            # Log the prompts and the raw reply (truncated) for diagnosis.
            # Falsy prompts are normalised so the log line itself cannot fail.
            sys_text = system_prompt or ""
            user_text = user_prompt or ""
            logger.error(
                f"JSON解析失败!\n"
                f"── system_prompt ({len(sys_text)} chars) ──\n{sys_text[:2000]}\n"
                f"── user_prompt ({len(user_text)} chars) ──\n{user_text[:3000]}\n"
                f"── LLM原始返回 ({len(content)} chars) ──\n{content[:1000]}\n"
                f"── 解析错误 ──\n{e}"
            )
            raise ValueError(f"LLM返回内容不是合法JSON: {e}\n内容: {json_str[:300]}") from e
|
||
|
||
|
||
# ============ Convenience accessor ============
# Client returned by the most recent get_client() call.
_default_client = None
# One client per model name, so that alternating between models reuses each
# client's pooled HTTP session instead of rebuilding it on every switch.
_clients = {}


def get_client(model_name=None):
    """Return a shared LLM client for ``model_name``.

    Clients are created on first use and cached per model name; repeated
    calls with the same name return the same instance.

    Args:
        model_name: Registered model name; ``DEFAULT_MODEL`` when omitted or
            falsy.

    Returns:
        LLMClient: The shared client for the requested model.

    Raises:
        ValueError: Propagated from ``LLMClient`` for an unregistered model.
    """
    global _default_client
    target = model_name or DEFAULT_MODEL

    # Fast path: the most recently used client already matches.
    if _default_client is not None and _default_client.model_name == target:
        _clients.setdefault(target, _default_client)
        return _default_client

    client = _clients.get(target)
    if client is None:
        client = LLMClient(target)
        _clients[target] = client
    _default_client = client
    return client
|
||
|
||
|
||
# ============ CLI smoke test ============
# Usage: run this file directly, optionally passing the user prompt as the
# first command-line argument; the reply and token usage are printed.
if __name__ == "__main__":
    import sys

    # First CLI argument if given, otherwise the built-in greeting prompt.
    prompt = next(iter(sys.argv[1:]), "你好,请用一句话介绍你自己。")
    client = get_client()

    print(f"模型: {client.model_name} ({client.endpoint})")
    print(f"提示: {prompt}")

    content, usage = client.call("你是一个有用的助手。", prompt)

    print(f"回复: {content}")
    print(f"用量: {usage}")
|