243 lines
8.8 KiB
Python
243 lines
8.8 KiB
Python
#!/usr/bin/env python3
|
|
"""轻量 CORS 代理服务 - 转发推送请求到目标 API + 单组件重新生成"""
|
|
|
|
import json
|
|
import sys
|
|
import time
|
|
import logging
|
|
import threading
|
|
import traceback
|
|
from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
|
|
from logging.handlers import TimedRotatingFileHandler
|
|
from pathlib import Path
|
|
|
|
import requests
|
|
|
|
# This script's own directory (sibling generator modules are imported from
# here) and the project root one level above it.
SCRIPTS_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = SCRIPTS_DIR.parent

# Logs go to <project root>/logs; create the directory if it is missing.
LOG_DIR = PROJECT_ROOT / 'logs'
LOG_DIR.mkdir(exist_ok=True)

# Both handlers share one line format.
_LOG_FORMAT = '%(asctime)s [%(levelname)s] %(message)s'

# File log rotated at midnight, keeping 30 dated backups.
file_handler = TimedRotatingFileHandler(
    LOG_DIR / 'proxy.log',
    when='midnight',
    backupCount=30,
    encoding='utf-8',
)
file_handler.suffix = '%Y-%m-%d'
stream_handler = logging.StreamHandler()

file_handler.setFormatter(logging.Formatter(_LOG_FORMAT))
stream_handler.setFormatter(logging.Formatter(_LOG_FORMAT))

logging.basicConfig(handlers=[file_handler, stream_handler], level=logging.INFO)
logger = logging.getLogger(__name__)

# Project-wide JSON config; this service reads its "proxy" section.
CONFIG_PATH = PROJECT_ROOT / 'config.json'
|
|
|
|
def load_config(path=None):
    """Load the ``proxy`` section of the project's JSON config file.

    Args:
        path: Config file to read. Defaults to ``CONFIG_PATH``.

    Returns:
        The value stored under the top-level ``"proxy"`` key.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError: if the document has no ``"proxy"`` key.
    """
    if path is None:
        path = CONFIG_PATH
    # Decode explicitly as UTF-8 so non-ASCII config values load the same on
    # every platform instead of depending on the locale's default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)['proxy']
|
|
|
|
# The "proxy" section of config.json, read once at import time.
CONFIG = load_config()

# A /api/push body must contain at least one of these top-level keys.
VALID_BODY_KEYS = {
    'pushType',
    'pushRelationKp',
    'componentData',
}
|
|
|
|
# --- Rate Limiter ---
|
|
class RateLimiter:
    """Thread-safe sliding-window rate limiter keyed by client IP.

    A key may make at most ``max_per_minute`` accepted calls within any
    ``window`` seconds (60 by default).
    """

    def __init__(self, max_per_minute, window=60.0, clock=time.monotonic):
        """Create a limiter.

        Args:
            max_per_minute: Accepted calls allowed per key per window.
            window: Window length in seconds.
            clock: Zero-argument callable returning seconds. Monotonic by
                default so wall-clock adjustments cannot shrink or stretch
                a window.
        """
        self.max_per_minute = max_per_minute
        self.window = window
        self.clock = clock
        self.records = {}  # key -> timestamps of accepted calls, oldest first
        self.lock = threading.Lock()
        self._last_sweep = clock()

    def is_allowed(self, ip):
        """Return True and record the call if ``ip`` is under its limit, else False."""
        now = self.clock()
        with self.lock:
            self._sweep_stale(now)
            recent = [t for t in self.records.get(ip, ()) if now - t < self.window]
            if len(recent) >= self.max_per_minute:
                self.records[ip] = recent
                return False
            recent.append(now)
            self.records[ip] = recent
            return True

    def _sweep_stale(self, now):
        """Drop keys with no call inside the window; does work at most once per window.

        Without this, every distinct key ever seen (client IPs, or arbitrary
        X-Forwarded-For values) would stay in ``records`` forever.
        The caller must hold ``self.lock``.
        """
        if now - self._last_sweep < self.window:
            return
        self._last_sweep = now
        stale = [
            key for key, stamps in self.records.items()
            if not stamps or now - stamps[-1] >= self.window
        ]
        for key in stale:
            del self.records[key]
|
|
|
|
# Single module-level limiter shared by every request-handler thread.
rate_limiter = RateLimiter(max_per_minute=CONFIG['rate_limit_per_minute'])
|
|
|
|
|
|
# --- Regenerate Handler ---
|
|
def _do_regenerate(body_data):
|
|
"""调用 generate_component 重新生成单个组件"""
|
|
if str(SCRIPTS_DIR) not in sys.path:
|
|
sys.path.insert(0, str(SCRIPTS_DIR))
|
|
|
|
from generate_json import generate_component
|
|
from llm_client import LLMClient
|
|
|
|
teaching_config = body_data.get("teaching_config", "")
|
|
type_name = body_data.get("type_name", "")
|
|
cId = body_data.get("cId", "")
|
|
character_map = body_data.get("character_map") or {}
|
|
level = body_data.get("level") or None
|
|
|
|
if not teaching_config or not type_name:
|
|
return 400, {"error": "缺少 teaching_config 或 type_name"}
|
|
|
|
component = {
|
|
"type_name": type_name,
|
|
"cId": cId,
|
|
"teaching_config": teaching_config,
|
|
"has_image": body_data.get("has_image", False),
|
|
"knowledge_text": body_data.get("knowledge_text", ""),
|
|
"config_info": body_data.get("config_info", ""),
|
|
}
|
|
|
|
try:
|
|
llm_client = LLMClient()
|
|
result = generate_component(component, character_map=character_map, llm_client=llm_client, level=level)
|
|
return 200, result
|
|
except Exception as e:
|
|
logger.error(f"Regenerate failed: {traceback.format_exc()}")
|
|
return 500, {"error": str(e)}
|
|
|
|
|
|
class ProxyHandler(BaseHTTPRequestHandler):
    """Request handler for the CORS proxy.

    Routes:
        GET     /health                        -> ``{"status": "ok"}``
        OPTIONS /api/push, /api/regenerate     -> CORS preflight (204)
        POST    /api/push                      -> rate-limited forward to ``CONFIG['target_url']``
        POST    /api/regenerate                -> single-component regeneration
    Any other path receives a 404.
    """

    def _send_json(self, code, data, extra_headers=None):
        """Send ``data`` as a UTF-8 JSON response with permissive CORS headers.

        Args:
            code: HTTP status code.
            data: JSON-serializable payload.
            extra_headers: Optional mapping of additional response headers.
        """
        body = json.dumps(data, ensure_ascii=False).encode('utf-8')
        self.send_response(code)
        self.send_header('Content-Type', 'application/json; charset=utf-8')
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS, GET')
        self.send_header('Access-Control-Allow-Headers', 'Content-Type')
        if extra_headers:
            for k, v in extra_headers.items():
                self.send_header(k, v)
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _get_client_ip(self):
        """Return the caller's IP, preferring the first X-Forwarded-For entry."""
        # NOTE(review): X-Forwarded-For is client-controlled unless a trusted
        # reverse proxy overwrites it, so a direct caller can set it to dodge
        # the per-IP rate limit. Left unchanged since a deployment behind a
        # proxy presumably relies on it — confirm the deployment topology.
        forwarded = self.headers.get('X-Forwarded-For')
        if forwarded:
            return forwarded.split(',')[0].strip()
        return self.client_address[0]

    def do_OPTIONS(self):
        """Answer CORS preflight requests for the two POST endpoints."""
        if self.path in ('/api/push', '/api/regenerate'):
            self.send_response(204)
            self.send_header('Access-Control-Allow-Origin', '*')
            self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS')
            self.send_header('Access-Control-Allow-Headers', 'Content-Type')
            self.send_header('Access-Control-Max-Age', '86400')
            self.end_headers()
        else:
            self.send_error(404)

    def do_GET(self):
        """Serve the health-check endpoint."""
        if self.path == '/health':
            self._send_json(200, {'status': 'ok'})
        else:
            self.send_error(404)

    def do_POST(self):
        """Dispatch POST requests to the push or regenerate handler."""
        if self.path == '/api/push':
            self._handle_push()
        elif self.path == '/api/regenerate':
            self._handle_regenerate()
        else:
            self.send_error(404)

    def _parse_content_length(self):
        """Return the request's Content-Length as a non-negative int.

        A missing header counts as 0. Returns ``None`` when the header is not
        an integer or is negative.
        """
        raw = self.headers.get('Content-Length')
        if raw is None:
            return 0
        try:
            length = int(raw)
        except ValueError:
            return None
        # A negative length would turn rfile.read() into "read until EOF",
        # blocking this worker thread until the client disconnects.
        return length if length >= 0 else None

    def _read_raw_body(self, max_size):
        """Read the raw request body, requiring ``0 < Content-Length <= max_size``.

        Returns:
            ``(raw_bytes, None)`` on success, or ``(None, (status, payload))``
            describing the error response to send.
        """
        content_length = self._parse_content_length()
        if content_length is None:
            return None, (400, {'error': 'invalid Content-Length'})
        if content_length == 0:
            return None, (400, {'error': 'empty body'})
        if content_length > max_size:
            return None, (413, {'error': 'body too large'})
        return self.rfile.read(content_length), None

    def _read_body(self, max_size=2 * 1024 * 1024):
        """Read and JSON-decode the request body.

        Returns:
            ``(parsed, None)`` on success, or ``(None, (status, payload))``
            describing the error response to send.
        """
        raw_body, err = self._read_raw_body(max_size)
        if err:
            return None, err
        try:
            return json.loads(raw_body), None
        except ValueError:  # JSONDecodeError, or UnicodeDecodeError for non-UTF bytes
            return None, (400, {'error': 'invalid JSON'})

    @staticmethod
    def _upstream_payload(resp, resp_body):
        """Build the client-facing payload from an upstream response.

        Returns the decoded JSON when the upstream declares a JSON content
        type and its body actually parses, otherwise ``{'raw': resp_body}``.
        """
        content_type = resp.headers.get('content-type', '').lower()
        if content_type.startswith('application/json'):
            try:
                return resp.json()
            except ValueError:  # body did not match its declared content type
                pass
        return {'raw': resp_body}

    def _handle_push(self):
        """Validate a push request and forward it to ``CONFIG['target_url']``.

        Rate-limited per client IP; bodies are capped at 1 MiB and must be a
        JSON object containing at least one key from ``VALID_BODY_KEYS``.
        The upstream status and payload are relayed to the client; transport
        failures become 502.
        """
        client_ip = self._get_client_ip()
        if not rate_limiter.is_allowed(client_ip):
            self._send_json(429, {'error': 'rate limit exceeded'})
            logger.warning(f'Rate limit hit: {client_ip}')
            return

        raw_body, err = self._read_raw_body(max_size=1024 * 1024)
        if err:
            self._send_json(*err)
            return

        try:
            body_data = json.loads(raw_body)
        except ValueError:  # JSONDecodeError, or UnicodeDecodeError for non-UTF bytes
            self._send_json(400, {'error': 'invalid JSON'})
            logger.warning(f'[{client_ip}] Invalid JSON body: {raw_body[:200]}')
            return

        if not isinstance(body_data, dict) or not VALID_BODY_KEYS.intersection(body_data.keys()):
            # sorted() keeps the message stable; set iteration order varies per process.
            self._send_json(400, {'error': f'body must contain one of: {", ".join(sorted(VALID_BODY_KEYS))}'})
            logger.warning(f'[{client_ip}] Invalid body keys: {list(body_data.keys()) if isinstance(body_data, dict) else type(body_data)}')
            return

        logger.info(f'[{client_ip}] POST /api/push | keys={list(body_data.keys())} | size={len(raw_body)}B')
        logger.info(f'[{client_ip}] >>> body: {raw_body.decode("utf-8", errors="replace")[:2000]}')

        try:
            resp = requests.post(
                CONFIG['target_url'],
                json=body_data,
                headers={
                    'Content-Type': 'application/json',
                    'Origin': CONFIG['target_origin'],
                },
                timeout=15,
            )
        except requests.RequestException as e:
            self._send_json(502, {'error': f'upstream error: {str(e)}'})
            logger.error(f'[{client_ip}] Forward failed: {e}')
            return

        resp_body = resp.text[:1000]
        self._send_json(resp.status_code, self._upstream_payload(resp, resp_body))
        logger.info(f'[{client_ip}] -> upstream {resp.status_code} | resp={resp_body[:300]}')

    def _handle_regenerate(self):
        """Regenerate one component from a JSON body containing ``teaching_config``."""
        # NOTE(review): unlike /api/push this endpoint is not rate-limited.
        client_ip = self._get_client_ip()
        logger.info(f'[{client_ip}] POST /api/regenerate')

        body_data, err = self._read_body()
        if err:
            self._send_json(*err)
            return

        if not isinstance(body_data, dict) or not body_data.get('teaching_config'):
            self._send_json(400, {'error': 'body must contain teaching_config'})
            return

        logger.info(f'[{client_ip}] Regenerating: type_name={body_data.get("type_name")}, cId={body_data.get("cId")}')

        code, result = _do_regenerate(body_data)
        self._send_json(code, result)
        logger.info(f'[{client_ip}] Regenerate done: status={code}')

    def log_message(self, format, *args):
        """Route BaseHTTPRequestHandler's access log to the module logger at DEBUG."""
        logger.debug(f'{self.client_address[0]} - {format % args}')
|
|
|
|
|
|
def main():
    """Run the threaded proxy server until interrupted (Ctrl+C).

    Binds to ``CONFIG['port']`` on ``CONFIG['host']`` when that key is present,
    otherwise on ``'0.0.0.0'``.
    """
    host = CONFIG.get('host', '0.0.0.0')
    port = CONFIG['port']
    # Using the server as a context manager runs server_close() on exit, which
    # releases the listening socket; shutdown() alone does not close it.
    with ThreadingHTTPServer((host, port), ProxyHandler) as server:
        logger.info(f'Proxy server started on {host}:{port}')
        logger.info(f'Target: {CONFIG["target_url"]}')
        logger.info(f'Rate limit: {CONFIG["rate_limit_per_minute"]} req/min per IP')
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            logger.info('Shutting down...')


if __name__ == '__main__':
    main()
|