ai_member_xiaoyan/skills/interactive-component-json/scripts/proxy_server.py

243 lines
8.8 KiB
Python

#!/usr/bin/env python3
"""轻量 CORS 代理服务 - 转发推送请求到目标 API + 单组件重新生成"""
import json
import sys
import time
import logging
import threading
import traceback
from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
from logging.handlers import TimedRotatingFileHandler
from pathlib import Path
import requests
# --- Paths ---
# SCRIPTS_DIR is the directory holding this script; the project root is the
# directory above it. Logs and config.json live under the project root.
SCRIPTS_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = SCRIPTS_DIR.parent
LOG_DIR = PROJECT_ROOT / 'logs'
CONFIG_PATH = PROJECT_ROOT / 'config.json'

LOG_DIR.mkdir(exist_ok=True)

# --- Logging: midnight-rotated file (30 backups kept) plus console ---
_LOG_FORMAT = '%(asctime)s [%(levelname)s] %(message)s'

file_handler = TimedRotatingFileHandler(
    LOG_DIR / 'proxy.log',
    when='midnight',
    backupCount=30,
    encoding='utf-8',
)
file_handler.suffix = '%Y-%m-%d'
file_handler.setFormatter(logging.Formatter(_LOG_FORMAT))

stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter(_LOG_FORMAT))

logging.basicConfig(level=logging.INFO, handlers=[file_handler, stream_handler])
logger = logging.getLogger(__name__)
def load_config(path=None):
    """Load and return the ``proxy`` section of the JSON config file.

    Args:
        path: Config file to read. Defaults to ``CONFIG_PATH``
            (``config.json`` in the project root).

    Returns:
        The value stored under the top-level ``"proxy"`` key.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the file is not valid UTF-8 JSON.
        KeyError: if the document has no ``"proxy"`` key.
    """
    if path is None:
        path = CONFIG_PATH
    # Explicit encoding: without it open() uses the locale-dependent default,
    # which can mis-decode non-ASCII config values. 'utf-8-sig' reads plain
    # UTF-8 unchanged and also strips a leading BOM (as some editors write).
    with open(path, 'r', encoding='utf-8-sig') as f:
        return json.load(f)['proxy']
# Top-level keys accepted by /api/push; a body must contain at least one.
VALID_BODY_KEYS = {'pushType', 'pushRelationKp', 'componentData'}
# Proxy settings: the "proxy" section of config.json, loaded once at import.
CONFIG = load_config()
# --- Rate Limiter ---
class RateLimiter:
    """Sliding-window request limiter keyed by client identifier (IP).

    Each key may make at most ``max_per_minute`` allowed requests within any
    ``window_seconds`` span (60 seconds by default). Keys whose timestamps
    have all expired are swept at most once per window, so ``records`` does
    not grow without bound as new client IPs appear.

    All access to ``records`` happens while holding ``lock``.
    """

    def __init__(self, max_per_minute, window_seconds=60, clock=time.monotonic):
        """
        Args:
            max_per_minute: Allowed requests per key per window.
            window_seconds: Length of the sliding window in seconds.
            clock: Zero-argument callable returning seconds. Monotonic by
                default so wall-clock adjustments cannot reset or freeze
                the window.
        """
        self.max_per_minute = max_per_minute
        self.window_seconds = window_seconds
        self._clock = clock
        self.records = {}  # key -> timestamps of allowed requests in the window
        self.lock = threading.Lock()
        self._last_sweep = clock()

    def _sweep(self, now):
        """Drop expired timestamps and forget keys left empty (caller holds lock)."""
        for key in list(self.records):
            fresh = [t for t in self.records[key] if now - t < self.window_seconds]
            if fresh:
                self.records[key] = fresh
            else:
                del self.records[key]
        self._last_sweep = now

    def is_allowed(self, ip):
        """Return True and record the request if ``ip`` is under its limit.

        Rejected requests are not recorded, so they do not extend the block.
        """
        now = self._clock()
        with self.lock:
            # Amortized cleanup of idle keys: one full pass per window at most.
            if now - self._last_sweep >= self.window_seconds:
                self._sweep(now)
            timestamps = [t for t in self.records.get(ip, ()) if now - t < self.window_seconds]
            if len(timestamps) >= self.max_per_minute:
                self.records[ip] = timestamps
                return False
            timestamps.append(now)
            self.records[ip] = timestamps
            return True
# Process-wide limiter consulted by the /api/push handler, keyed by client IP.
rate_limiter = RateLimiter(max_per_minute=CONFIG['rate_limit_per_minute'])
# --- Regenerate Handler ---
def _do_regenerate(body_data):
"""调用 generate_component 重新生成单个组件"""
if str(SCRIPTS_DIR) not in sys.path:
sys.path.insert(0, str(SCRIPTS_DIR))
from generate_json import generate_component
from llm_client import LLMClient
teaching_config = body_data.get("teaching_config", "")
type_name = body_data.get("type_name", "")
cId = body_data.get("cId", "")
character_map = body_data.get("character_map") or {}
level = body_data.get("level") or None
if not teaching_config or not type_name:
return 400, {"error": "缺少 teaching_config 或 type_name"}
component = {
"type_name": type_name,
"cId": cId,
"teaching_config": teaching_config,
"has_image": body_data.get("has_image", False),
"knowledge_text": body_data.get("knowledge_text", ""),
"config_info": body_data.get("config_info", ""),
}
try:
llm_client = LLMClient()
result = generate_component(component, character_map=character_map, llm_client=llm_client, level=level)
return 200, result
except Exception as e:
logger.error(f"Regenerate failed: {traceback.format_exc()}")
return 500, {"error": str(e)}
class ProxyHandler(BaseHTTPRequestHandler):
    """Request handler for the proxy's HTTP endpoints.

    Routes:
        GET     /health          -> {"status": "ok"}
        POST    /api/push        -> validated body forwarded to CONFIG['target_url']
        POST    /api/regenerate  -> single-component regeneration (_do_regenerate)
        OPTIONS /api/push, /api/regenerate -> CORS preflight (204)
    Responses sent through ``_send_json`` carry permissive CORS headers.
    """

    # Upper bounds on accepted request bodies, in bytes.
    PUSH_MAX_BODY = 1024 * 1024
    REGENERATE_MAX_BODY = 2 * 1024 * 1024

    def _send_json(self, code, data, extra_headers=None):
        """Send ``data`` as a UTF-8 JSON response with CORS headers."""
        body = json.dumps(data, ensure_ascii=False).encode('utf-8')
        self.send_response(code)
        self.send_header('Content-Type', 'application/json; charset=utf-8')
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS, GET')
        self.send_header('Access-Control-Allow-Headers', 'Content-Type')
        if extra_headers:
            for k, v in extra_headers.items():
                self.send_header(k, v)
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _get_client_ip(self):
        """Return the client IP used for rate limiting and logging.

        Prefers the first ``X-Forwarded-For`` entry, else the socket peer.
        """
        # NOTE(review): X-Forwarded-For is client-controlled unless a trusted
        # reverse proxy overwrites it; a caller can rotate this value to evade
        # the per-IP rate limit. Confirm the deployment sanitizes the header.
        forwarded = self.headers.get('X-Forwarded-For')
        if forwarded:
            return forwarded.split(',')[0].strip()
        return self.client_address[0]

    def do_OPTIONS(self):
        if self.path in ('/api/push', '/api/regenerate'):
            self.send_response(204)
            self.send_header('Access-Control-Allow-Origin', '*')
            self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS')
            self.send_header('Access-Control-Allow-Headers', 'Content-Type')
            self.send_header('Access-Control-Max-Age', '86400')
            self.end_headers()
        else:
            self.send_error(404)

    def do_GET(self):
        if self.path == '/health':
            self._send_json(200, {'status': 'ok'})
        else:
            self.send_error(404)

    def do_POST(self):
        if self.path == '/api/push':
            self._handle_push()
        elif self.path == '/api/regenerate':
            self._handle_regenerate()
        else:
            self.send_error(404)

    def _read_raw_body(self, max_size):
        """Read the request body as bytes, bounded by ``max_size``.

        Returns ``(raw_bytes, None)`` on success, or ``(None, (status, payload))``
        describing the error response to send.
        """
        try:
            content_length = int(self.headers.get('Content-Length', 0))
        except ValueError:
            return None, (400, {'error': 'invalid Content-Length'})
        # A negative length would make rfile.read() read until EOF, bypassing
        # max_size and blocking on keep-alive connections.
        if content_length < 0:
            return None, (400, {'error': 'invalid Content-Length'})
        if content_length == 0:
            return None, (400, {'error': 'empty body'})
        if content_length > max_size:
            return None, (413, {'error': 'body too large'})
        return self.rfile.read(content_length), None

    def _read_body(self, max_size=2 * 1024 * 1024):
        """Read and JSON-decode the request body.

        Returns ``(data, None)`` on success, or ``(None, (status, payload))``.
        """
        raw_body, err = self._read_raw_body(max_size)
        if err:
            return None, err
        try:
            return json.loads(raw_body), None
        except ValueError:
            # Covers json.JSONDecodeError and UnicodeDecodeError (bad UTF-8).
            return None, (400, {'error': 'invalid JSON'})

    def _handle_push(self):
        """Validate a push request and relay it to the configured upstream."""
        client_ip = self._get_client_ip()
        if not rate_limiter.is_allowed(client_ip):
            self._send_json(429, {'error': 'rate limit exceeded'})
            logger.warning(f'Rate limit hit: {client_ip}')
            return

        raw_body, err = self._read_raw_body(self.PUSH_MAX_BODY)
        if err:
            self._send_json(err[0], err[1])
            return
        try:
            body_data = json.loads(raw_body)
        except ValueError:
            # Covers json.JSONDecodeError and UnicodeDecodeError (bad UTF-8).
            self._send_json(400, {'error': 'invalid JSON'})
            logger.warning(f'[{client_ip}] Invalid JSON body: {raw_body[:200]}')
            return

        if not isinstance(body_data, dict) or not VALID_BODY_KEYS.intersection(body_data.keys()):
            # sorted(): set iteration order varies between runs.
            allowed = ", ".join(sorted(VALID_BODY_KEYS))
            self._send_json(400, {'error': f'body must contain one of: {allowed}'})
            logger.warning(f'[{client_ip}] Invalid body keys: {list(body_data.keys()) if isinstance(body_data, dict) else type(body_data)}')
            return

        logger.info(f'[{client_ip}] POST /api/push | keys={list(body_data.keys())} | size={len(raw_body)}B')
        logger.info(f'[{client_ip}] >>> body: {raw_body.decode("utf-8", errors="replace")[:2000]}')
        try:
            resp = requests.post(
                CONFIG['target_url'],
                json=body_data,
                headers={
                    'Content-Type': 'application/json',
                    'Origin': CONFIG['target_origin'],
                },
                timeout=15,
            )
        except requests.RequestException as e:
            self._send_json(502, {'error': f'upstream error: {e}'})
            logger.error(f'[{client_ip}] Forward failed: {e}')
            return

        resp_body = resp.text[:1000]
        payload = {'raw': resp_body}
        if resp.headers.get('content-type', '').lower().startswith('application/json'):
            try:
                payload = resp.json()
            except ValueError:
                # Upstream labelled the body JSON but it isn't; relay the text.
                logger.warning(f'[{client_ip}] upstream sent invalid JSON despite JSON content-type')
        self._send_json(resp.status_code, payload)
        logger.info(f'[{client_ip}] -> upstream {resp.status_code} | resp={resp_body[:300]}')

    def _handle_regenerate(self):
        """Validate a regenerate request and run ``_do_regenerate`` synchronously."""
        client_ip = self._get_client_ip()
        logger.info(f'[{client_ip}] POST /api/regenerate')
        # NOTE(review): unlike /api/push this endpoint is not rate limited even
        # though each call runs a generation via LLMClient; consider gating it.
        body_data, err = self._read_body(self.REGENERATE_MAX_BODY)
        if err:
            self._send_json(err[0], err[1])
            return
        if not isinstance(body_data, dict) or not body_data.get('teaching_config'):
            self._send_json(400, {'error': 'body must contain teaching_config'})
            return
        logger.info(f'[{client_ip}] Regenerating: type_name={body_data.get("type_name")}, cId={body_data.get("cId")}')
        code, result = _do_regenerate(body_data)
        self._send_json(code, result)
        logger.info(f'[{client_ip}] Regenerate done: status={code}')

    def log_message(self, format, *args):
        # Route BaseHTTPRequestHandler's per-request access log to DEBUG.
        logger.debug(f'{self.client_address[0]} - {format % args}')
def main():
    """Run the threaded proxy server until interrupted with Ctrl+C.

    Listens on ``CONFIG['host']`` (default ``'0.0.0.0'``) and ``CONFIG['port']``.
    """
    host = CONFIG.get('host', '0.0.0.0')
    port = CONFIG['port']
    # Using the server as a context manager calls server_close() on exit, so
    # the listening socket is released after Ctrl+C. (Calling shutdown() from
    # the thread that ran serve_forever() is unnecessary and unsafe per docs.)
    with ThreadingHTTPServer((host, port), ProxyHandler) as server:
        logger.info(f'Proxy server started on {host}:{port}')
        logger.info(f'Target: {CONFIG["target_url"]}')
        logger.info(f'Rate limit: {CONFIG["rate_limit_per_minute"]} req/min per IP')
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            logger.info('Shutting down...')


if __name__ == '__main__':
    main()