#!/usr/bin/env python3
"""Lightweight CORS proxy service.

Forwards push requests to the configured target API and regenerates a single
component on demand.

Endpoints:
    GET  /health          -- liveness probe.
    POST /api/push        -- rate-limited forward of the JSON body upstream.
    POST /api/regenerate  -- regenerate one component via ``generate_json``.
"""

import json
import sys
import time
import logging
import threading
import traceback
from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
from logging.handlers import TimedRotatingFileHandler
from pathlib import Path

import requests

PROJECT_ROOT = Path(__file__).resolve().parent.parent
SCRIPTS_DIR = Path(__file__).resolve().parent
LOG_DIR = PROJECT_ROOT / 'logs'
LOG_DIR.mkdir(exist_ok=True)

# Log to a file rotated at midnight (30 backups kept) and to the console.
file_handler = TimedRotatingFileHandler(
    LOG_DIR / 'proxy.log', when='midnight', backupCount=30, encoding='utf-8'
)
file_handler.suffix = '%Y-%m-%d'
file_handler.setFormatter(logging.Formatter('%(asctime)s [%(levelname)s] %(message)s'))
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter('%(asctime)s [%(levelname)s] %(message)s'))
logging.basicConfig(level=logging.INFO, handlers=[file_handler, stream_handler])
logger = logging.getLogger(__name__)

CONFIG_PATH = PROJECT_ROOT / 'config.json'


def load_config():
    """Return the ``proxy`` section of the JSON file at ``CONFIG_PATH``.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError: if the top-level object has no ``proxy`` key.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        return json.load(f)['proxy']


CONFIG = load_config()

# A push body must contain at least one of these top-level keys.
VALID_BODY_KEYS = {'componentData', 'pushRelationKp', 'pushType'}


# --- Rate Limiter ---

class RateLimiter:
    """Sliding-window request limiter keyed by client IP.

    ``records`` maps an IP to the ascending list of request times (taken from
    ``time.monotonic()``) that fall inside the current 60-second window.
    """

    WINDOW_SECONDS = 60

    def __init__(self, max_per_minute):
        self.max_per_minute = max_per_minute
        self.records = {}
        self.lock = threading.Lock()
        self._last_sweep = time.monotonic()

    def is_allowed(self, ip):
        """Record a request from *ip*; return ``False`` if it exceeds the limit.

        A rejected request is not recorded, so it does not extend the window.
        """
        # Monotonic clock: interval arithmetic must not be affected by
        # wall-clock adjustments (NTP steps, manual changes).
        now = time.monotonic()
        with self.lock:
            self._sweep(now)
            timestamps = [
                t for t in self.records.get(ip, ())
                if now - t < self.WINDOW_SECONDS
            ]
            if len(timestamps) >= self.max_per_minute:
                self.records[ip] = timestamps
                return False
            timestamps.append(now)
            self.records[ip] = timestamps
            return True

    def _sweep(self, now):
        """Drop IPs with no request inside the window (caller holds the lock).

        Runs at most once per window so ``records`` cannot grow without bound
        when many distinct IPs are seen once and never again.
        """
        if now - self._last_sweep < self.WINDOW_SECONDS:
            return
        self._last_sweep = now
        # Timestamps are appended in order, so the last one is the newest.
        stale = [
            ip for ip, stamps in self.records.items()
            if not stamps or now - stamps[-1] >= self.WINDOW_SECONDS
        ]
        for ip in stale:
            del self.records[ip]


rate_limiter = RateLimiter(CONFIG['rate_limit_per_minute'])


# --- Regenerate Handler ---

def _do_regenerate(body_data):
    """Regenerate a single component by calling ``generate_component``.

    Args:
        body_data: decoded JSON request body (a dict). ``teaching_config`` and
            ``type_name`` are required; ``cId``, ``character_map``, ``level``,
            ``has_image``, ``knowledge_text`` and ``config_info`` are optional.

    Returns:
        A ``(status_code, payload)`` tuple: 400 when a required field is
        missing, 500 with the exception text when generation fails, otherwise
        200 with whatever ``generate_component`` returned.
    """
    # The generator modules live next to this script; make them importable
    # even when the process was started from another directory.
    if str(SCRIPTS_DIR) not in sys.path:
        sys.path.insert(0, str(SCRIPTS_DIR))
    from generate_json import generate_component
    from llm_client import LLMClient

    teaching_config = body_data.get("teaching_config", "")
    type_name = body_data.get("type_name", "")
    cId = body_data.get("cId", "")
    character_map = body_data.get("character_map") or {}
    level = body_data.get("level") or None

    if not teaching_config or not type_name:
        return 400, {"error": "缺少 teaching_config 或 type_name"}

    component = {
        "type_name": type_name,
        "cId": cId,
        "teaching_config": teaching_config,
        "has_image": body_data.get("has_image", False),
        "knowledge_text": body_data.get("knowledge_text", ""),
        "config_info": body_data.get("config_info", ""),
    }

    # Top-level boundary: any failure in the generator becomes a 500 response.
    try:
        llm_client = LLMClient()
        result = generate_component(
            component,
            character_map=character_map,
            llm_client=llm_client,
            level=level,
        )
        return 200, result
    except Exception as e:
        logger.error(f"Regenerate failed: {traceback.format_exc()}")
        return 500, {"error": str(e)}


class ProxyHandler(BaseHTTPRequestHandler):
    """HTTP handler for ``/health``, ``/api/push`` and ``/api/regenerate``."""

    def _send_json(self, code, data, extra_headers=None):
        """Send *data* as a UTF-8 JSON response with permissive CORS headers."""
        body = json.dumps(data, ensure_ascii=False).encode('utf-8')
        self.send_response(code)
        self.send_header('Content-Type', 'application/json; charset=utf-8')
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS, GET')
        self.send_header('Access-Control-Allow-Headers', 'Content-Type')
        if extra_headers:
            for k, v in extra_headers.items():
                self.send_header(k, v)
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _get_client_ip(self):
        """Return the first ``X-Forwarded-For`` entry, else the socket peer IP.

        NOTE(review): ``X-Forwarded-For`` is client-controlled unless a trusted
        reverse proxy overwrites it, so the rate limit can be bypassed by
        spoofing the header when this server is exposed directly.
        """
        forwarded = self.headers.get('X-Forwarded-For')
        if forwarded:
            first_hop = forwarded.split(',')[0].strip()
            # A blank first entry (e.g. " ,1.2.3.4") must not become the key.
            if first_hop:
                return first_hop
        return self.client_address[0]

    def do_OPTIONS(self):
        """Answer CORS preflight requests for the two POST endpoints."""
        if self.path in ('/api/push', '/api/regenerate'):
            self.send_response(204)
            self.send_header('Access-Control-Allow-Origin', '*')
            self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS')
            self.send_header('Access-Control-Allow-Headers', 'Content-Type')
            self.send_header('Access-Control-Max-Age', '86400')
            self.end_headers()
        else:
            self.send_error(404)

    def do_GET(self):
        """Serve the ``/health`` probe; every other path is 404."""
        if self.path == '/health':
            self._send_json(200, {'status': 'ok'})
        else:
            self.send_error(404)

    def do_POST(self):
        """Dispatch POST requests by exact path."""
        if self.path == '/api/push':
            self._handle_push()
        elif self.path == '/api/regenerate':
            self._handle_regenerate()
        else:
            self.send_error(404)

    def _read_raw_body(self, max_size):
        """Read the request body as bytes, validating ``Content-Length``.

        Returns:
            ``(raw_bytes, None)`` on success, or ``(None, (status, payload))``
            describing the error response to send.
        """
        try:
            content_length = int(self.headers.get('Content-Length', 0))
        except ValueError:
            return None, (400, {'error': 'invalid Content-Length'})
        # A negative length would make rfile.read() block until EOF.
        if content_length < 0:
            return None, (400, {'error': 'invalid Content-Length'})
        if content_length == 0:
            return None, (400, {'error': 'empty body'})
        if content_length > max_size:
            return None, (413, {'error': 'body too large'})
        return self.rfile.read(content_length), None

    def _read_body(self, max_size=2 * 1024 * 1024):
        """Read and JSON-decode the request body.

        Returns:
            ``(decoded_json, None)`` on success, or ``(None, (status, payload))``
            describing the error response to send.
        """
        raw_body, err = self._read_raw_body(max_size)
        if err:
            return None, err
        try:
            return json.loads(raw_body), None
        except ValueError:
            # Covers json.JSONDecodeError and the UnicodeDecodeError raised
            # when the bytes are not valid UTF-8/16/32.
            return None, (400, {'error': 'invalid JSON'})

    def _handle_push(self):
        """Validate a push request and forward it to ``CONFIG['target_url']``.

        Responds 429 when the client is over the rate limit, 400/413 for a bad
        body, 502 when the upstream request fails, and otherwise relays the
        upstream status code and body.
        """
        client_ip = self._get_client_ip()

        if not rate_limiter.is_allowed(client_ip):
            self._send_json(429, {'error': 'rate limit exceeded'})
            logger.warning(f'Rate limit hit: {client_ip}')
            return

        raw_body, err = self._read_raw_body(1024 * 1024)
        if err:
            self._send_json(err[0], err[1])
            return

        try:
            body_data = json.loads(raw_body)
        except ValueError:
            # ValueError covers both malformed JSON and undecodable bytes.
            self._send_json(400, {'error': 'invalid JSON'})
            logger.warning(f'[{client_ip}] Invalid JSON body: {raw_body[:200]}')
            return

        if not isinstance(body_data, dict) or not VALID_BODY_KEYS.intersection(body_data.keys()):
            # Sorted so the message is stable across runs (set order is not).
            self._send_json(400, {'error': f'body must contain one of: {", ".join(sorted(VALID_BODY_KEYS))}'})
            logger.warning(f'[{client_ip}] Invalid body keys: {list(body_data.keys()) if isinstance(body_data, dict) else type(body_data)}')
            return

        logger.info(f'[{client_ip}] POST /api/push | keys={list(body_data.keys())} | size={len(raw_body)}B')
        logger.info(f'[{client_ip}] >>> body: {raw_body.decode("utf-8", errors="replace")[:2000]}')

        # Keep the try body to the network call only, so errors raised while
        # writing our own response are not misreported as upstream failures.
        try:
            resp = requests.post(
                CONFIG['target_url'],
                json=body_data,
                headers={
                    'Content-Type': 'application/json',
                    'Origin': CONFIG['target_origin'],
                },
                timeout=15,
            )
        except requests.RequestException as e:
            self._send_json(502, {'error': f'upstream error: {str(e)}'})
            logger.error(f'[{client_ip}] Forward failed: {e}')
            return

        resp_body = resp.text[:1000]
        payload = {'raw': resp_body}
        if resp.headers.get('content-type', '').startswith('application/json'):
            try:
                payload = resp.json()
            except ValueError:
                # Upstream declared JSON but sent something else; relay the
                # raw text instead of failing the whole request.
                logger.warning(f'[{client_ip}] upstream sent invalid JSON, relaying raw body')
        self._send_json(resp.status_code, payload)
        logger.info(f'[{client_ip}] -> upstream {resp.status_code} | resp={resp_body[:300]}')

    def _handle_regenerate(self):
        """Validate a regenerate request and return the regenerated component.

        NOTE(review): unlike ``/api/push`` this endpoint is not rate-limited;
        confirm that is intentional.
        """
        client_ip = self._get_client_ip()
        logger.info(f'[{client_ip}] POST /api/regenerate')

        body_data, err = self._read_body()
        if err:
            self._send_json(err[0], err[1])
            return

        if not isinstance(body_data, dict) or not body_data.get('teaching_config'):
            self._send_json(400, {'error': 'body must contain teaching_config'})
            return

        logger.info(f'[{client_ip}] Regenerating: type_name={body_data.get("type_name")}, cId={body_data.get("cId")}')
        code, result = _do_regenerate(body_data)
        self._send_json(code, result)
        logger.info(f'[{client_ip}] Regenerate done: status={code}')

    def log_message(self, format, *args):
        """Route the base class's access log through ``logger`` at DEBUG level."""
        logger.debug(f'{self.client_address[0]} - {format % args}')


def main():
    """Start the threaded proxy server and serve until interrupted."""
    port = CONFIG['port']
    server = ThreadingHTTPServer(('0.0.0.0', port), ProxyHandler)
    logger.info(f'Proxy server started on 0.0.0.0:{port}')
    logger.info(f'Target: {CONFIG["target_url"]}')
    logger.info(f'Rate limit: {CONFIG["rate_limit_per_minute"]} req/min per IP')
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        logger.info('Shutting down...')
    finally:
        # serve_forever() has already returned here, so shutdown() (which is
        # meant to be called from another thread) is unnecessary; what is
        # needed is releasing the listening socket.
        server.server_close()


if __name__ == '__main__':
    main()