ai_member_xiaoyan/skills/interactive-component-json/scripts/proxy_server.py

169 lines
6.1 KiB
Python

#!/usr/bin/env python3
"""轻量 CORS 代理服务 - 转发推送请求到目标 API"""
import json
import logging
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer, ThreadingHTTPServer
from logging.handlers import TimedRotatingFileHandler
from pathlib import Path

import requests
# Base directory: the folder one level above the directory holding this script.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
LOG_DIR = PROJECT_ROOT / 'logs'
LOG_DIR.mkdir(exist_ok=True)

# A single formatter instance shared by the file and console handlers.
_LOG_FORMATTER = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')

# File log rotated at midnight; 30 rotated files are kept, each named with a
# '%Y-%m-%d' date suffix.
file_handler = TimedRotatingFileHandler(
    LOG_DIR / 'proxy.log',
    when='midnight',
    backupCount=30,
    encoding='utf-8',
)
file_handler.suffix = '%Y-%m-%d'
file_handler.setFormatter(_LOG_FORMATTER)

# Console output mirrors the file log.
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(_LOG_FORMATTER)

logging.basicConfig(level=logging.INFO, handlers=[file_handler, stream_handler])
logger = logging.getLogger(__name__)

# JSON configuration file; the proxy reads its "proxy" section.
CONFIG_PATH = PROJECT_ROOT / 'config.json'
def load_config(path=None):
    """Load and return the ``proxy`` section of the JSON configuration file.

    Args:
        path: File to read. Defaults to the module-level ``CONFIG_PATH``.

    Returns:
        The value stored under the top-level ``"proxy"`` key.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the file is not valid UTF-8 JSON.
        KeyError: if the document has no top-level ``"proxy"`` key.
    """
    if path is None:
        path = CONFIG_PATH
    # Explicit UTF-8: open()'s default encoding is platform/locale dependent
    # (e.g. GBK on Chinese-locale Windows) and would mis-decode the file.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)['proxy']
# Proxy settings, loaded once when the module is imported.
CONFIG = load_config()

# Top-level keys accepted by POST /api/push; a body must contain at least one.
# Immutable so the allow-list cannot be altered at runtime.
VALID_BODY_KEYS = frozenset({'componentData', 'pushRelationKp', 'pushType'})
# --- Rate Limiter ---
class RateLimiter:
    """Per-key sliding-window rate limiter (keys are client IPs here).

    A key may be allowed at most ``max_per_minute`` requests within any
    ``window``-second span (60 s by default, hence the parameter name, kept
    for compatibility). All state is guarded by ``self.lock``.
    """

    def __init__(self, max_per_minute, window=60.0):
        self.max_per_minute = max_per_minute
        self.window = window
        # key -> time.monotonic() stamps of allowed requests, oldest first
        self.records = {}
        self.lock = threading.Lock()
        self._last_sweep = time.monotonic()

    def _sweep(self, now):
        """Forget keys whose newest stamp has left the window. Caller holds the lock."""
        stale = [
            key for key, stamps in self.records.items()
            if not stamps or now - stamps[-1] >= self.window
        ]
        for key in stale:
            del self.records[key]
        self._last_sweep = now

    def is_allowed(self, ip):
        """Return True and record the request if ``ip`` is under its limit, else False."""
        # monotonic() is unaffected by wall-clock changes (NTP steps, manual edits),
        # which would otherwise shrink or stretch the window.
        now = time.monotonic()
        with self.lock:
            # Evict idle keys at most once per window so the table cannot grow
            # without bound; keys may come from client-supplied headers.
            if now - self._last_sweep >= self.window:
                self._sweep(now)
            timestamps = [t for t in self.records.get(ip, []) if now - t < self.window]
            if len(timestamps) >= self.max_per_minute:
                self.records[ip] = timestamps
                return False
            timestamps.append(now)
            self.records[ip] = timestamps
            return True
# Process-wide limiter consulted by every POST /api/push request.
rate_limiter = RateLimiter(max_per_minute=CONFIG['rate_limit_per_minute'])
class ProxyHandler(BaseHTTPRequestHandler):
    """HTTP handler serving ``GET /health`` and a CORS-enabled ``POST /api/push`` forwarder.

    A valid push body is re-posted as JSON to ``CONFIG['target_url']`` with
    ``CONFIG['target_origin']`` as the ``Origin`` header, and the upstream
    status code and body are relayed back to the caller. Other paths get 404.
    """

    # Per-connection socket timeout (seconds) applied by StreamRequestHandler.
    # Without it, a client that declares a larger Content-Length than it sends
    # makes rfile.read() block forever; http.server logs the timeout and closes
    # the connection instead.
    timeout = 30
    # Largest request body accepted on /api/push; larger bodies get 413.
    MAX_BODY_BYTES = 1024 * 1024
    # Seconds to wait for the upstream API before answering 502.
    UPSTREAM_TIMEOUT = 15

    def _send_json(self, code, data, extra_headers=None):
        """Send ``data`` as a UTF-8 JSON response with permissive CORS headers."""
        body = json.dumps(data, ensure_ascii=False).encode('utf-8')
        self.send_response(code)
        self.send_header('Content-Type', 'application/json; charset=utf-8')
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS, GET')
        self.send_header('Access-Control-Allow-Headers', 'Content-Type')
        for name, value in (extra_headers or {}).items():
            self.send_header(name, value)
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _get_client_ip(self):
        """Return the client address used for rate limiting and logging.

        Uses the first ``X-Forwarded-For`` entry when present and non-empty,
        otherwise the socket peer address.
        """
        # NOTE(review): X-Forwarded-For is client-supplied unless a trusted reverse
        # proxy overwrites it, so a direct caller can choose its own rate-limit key.
        # Only rely on it when the server is reachable solely through such a proxy.
        forwarded = self.headers.get('X-Forwarded-For')
        if forwarded:
            first = forwarded.split(',')[0].strip()
            if first:
                return first
        return self.client_address[0]

    def do_OPTIONS(self):
        """Answer the CORS preflight for /api/push; any other path is 404."""
        if self.path != '/api/push':
            self.send_error(404)
            return
        self.send_response(204)
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS')
        self.send_header('Access-Control-Allow-Headers', 'Content-Type')
        self.send_header('Access-Control-Max-Age', '86400')
        self.end_headers()

    def do_GET(self):
        """Serve the /health liveness probe; any other path is 404."""
        if self.path == '/health':
            self._send_json(200, {'status': 'ok'})
        else:
            self.send_error(404)

    def do_POST(self):
        """Rate-limit, validate and forward a push request to the upstream API."""
        if self.path != '/api/push':
            self.send_error(404)
            return
        client_ip = self._get_client_ip()
        if not rate_limiter.is_allowed(client_ip):
            self._send_json(429, {'error': 'rate limit exceeded'})
            logger.warning(f'Rate limit hit: {client_ip}')
            return
        parsed = self._read_push_body(client_ip)
        if parsed is None:
            return
        body_data, raw_body = parsed
        logger.info(f'[{client_ip}] POST /api/push | keys={list(body_data)} | size={len(raw_body)}B')
        logger.info(f'[{client_ip}] >>> body: {raw_body.decode("utf-8", errors="replace")[:2000]}')
        self._forward(client_ip, body_data)

    def _read_push_body(self, client_ip):
        """Read, parse and validate the /api/push request body.

        Returns:
            ``(body_data, raw_body)`` on success; ``None`` after an error
            response (400/413) has already been sent.
        """
        header_value = self.headers.get('Content-Length', 0)
        try:
            content_length = int(header_value)
        except ValueError:
            content_length = -1
        if content_length < 0:
            # A malformed length used to raise out of the handler, and a negative
            # one made rfile.read() read until the client closed the connection.
            self._send_json(400, {'error': 'invalid Content-Length'})
            logger.warning(f'[{client_ip}] Invalid Content-Length: {header_value!r}')
            return None
        if content_length == 0:
            self._send_json(400, {'error': 'empty body'})
            return None
        if content_length > self.MAX_BODY_BYTES:
            self._send_json(413, {'error': 'body too large'})
            return None

        raw_body = self.rfile.read(content_length)
        try:
            body_data = json.loads(raw_body)
        except (ValueError, RecursionError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError
            # (non-UTF bytes); RecursionError comes from pathologically deep nesting.
            self._send_json(400, {'error': 'invalid JSON'})
            logger.warning(f'[{client_ip}] Invalid JSON body: {raw_body[:200]}')
            return None

        if not isinstance(body_data, dict) or not VALID_BODY_KEYS.intersection(body_data):
            # Sorted so the message is stable across runs (set order is hash-seeded).
            allowed = ', '.join(sorted(VALID_BODY_KEYS))
            self._send_json(400, {'error': f'body must contain one of: {allowed}'})
            received = list(body_data) if isinstance(body_data, dict) else type(body_data)
            logger.warning(f'[{client_ip}] Invalid body keys: {received}')
            return None
        return body_data, raw_body

    def _forward(self, client_ip, body_data):
        """POST ``body_data`` upstream and relay its status and body; 502 on transport errors."""
        try:
            resp = requests.post(
                CONFIG['target_url'],
                json=body_data,
                headers={
                    'Content-Type': 'application/json',
                    'Origin': CONFIG['target_origin'],
                },
                timeout=self.UPSTREAM_TIMEOUT,
            )
        except requests.RequestException as e:
            self._send_json(502, {'error': f'upstream error: {e}'})
            logger.error(f'[{client_ip}] Forward failed: {e}')
            return

        resp_body = resp.text[:1000]
        payload = {'raw': resp_body}
        # MIME types are case-insensitive.
        content_type = resp.headers.get('content-type', '').lower()
        if content_type.startswith('application/json'):
            try:
                payload = resp.json()
            except ValueError:
                # Upstream labelled the body JSON but it does not parse; relay it raw
                # rather than failing the request.
                logger.warning(f'[{client_ip}] Upstream sent invalid JSON (content-type {content_type!r})')
        self._send_json(resp.status_code, payload)
        logger.info(f'[{client_ip}] -> upstream {resp.status_code} | resp={resp_body[:300]}')

    def log_message(self, format, *args):
        """Route http.server's built-in request/error log lines to ``logger`` at DEBUG."""
        # Parameter name ``format`` is fixed by BaseHTTPRequestHandler's signature.
        logger.debug(f'{self.client_address[0]} - {format % args}')
def main():
    """Start the proxy server and serve until interrupted (Ctrl+C).

    Binds to ``CONFIG['host']`` when present (default ``'0.0.0.0'``) and
    ``CONFIG['port']``.
    """
    host = CONFIG.get('host', '0.0.0.0')
    port = CONFIG['port']
    # ThreadingHTTPServer handles each connection in its own daemon thread, so a
    # slow client or a slow upstream call does not stall every other request.
    # The ``with`` block calls server_close(), releasing the listening socket.
    with ThreadingHTTPServer((host, port), ProxyHandler) as server:
        logger.info(f'Proxy server started on {host}:{port}')
        logger.info(f'Target: {CONFIG["target_url"]}')
        logger.info(f'Rate limit: {CONFIG["rate_limit_per_minute"]} req/min per IP')
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            logger.info('Shutting down...')


if __name__ == '__main__':
    main()