313 lines
11 KiB
Python
313 lines
11 KiB
Python
"""扫描调度器 — 群聊发现 + 优先级队列 + 自适应频率 + COS 上传"""
|
||
|
||
import heapq
|
||
import logging
|
||
import os
|
||
import re
|
||
import sys
|
||
import time
|
||
import random
|
||
from datetime import datetime
|
||
|
||
from .config import CollectorConfig
|
||
from .storage import MessageStorage
|
||
from .wechat_adapter import WeChatAdapter, MEDIA_TYPES
|
||
|
||
# Module-level logger named after this module (__name__); all scanner log output goes through it.
log = logging.getLogger(__name__)
|
||
|
||
|
||
def _init_cos_uploader(config: CollectorConfig):
    """Create a CosUploader from the collector configuration.

    The ``cos_upload`` import is performed here, at call time, rather than at
    module import time, so this module still imports when the COS skill
    scripts / SDK are not installed. The skill's ``scripts`` directory
    (``<parent of this file's directory>/skills/tencent-cos-upload.xiaokui/scripts``)
    is prepended to ``sys.path`` if it is not already there.

    Any exception raised by the import or by the CosUploader constructor
    propagates to the caller.
    """
    project_root = os.path.dirname(os.path.dirname(__file__))
    scripts_dir = os.path.join(
        project_root, "skills", "tencent-cos-upload.xiaokui", "scripts",
    )
    if scripts_dir not in sys.path:
        sys.path.insert(0, scripts_dir)

    from cos_upload import CosUploader

    uploader_kwargs = {
        "secret_id": config.cos_secret_id,
        "secret_key": config.cos_secret_key,
        "region": config.cos_region,
        "bucket": config.cos_bucket,
        "domain": config.cos_download_domain,
    }
    return CosUploader(**uploader_kwargs)
|
||
|
||
|
||
def _build_cos_key(config: CollectorConfig, msg_type: str, create_time: int, filename: str) -> str:
    """Return the COS object key ``{cos_base_path}/{msg_type}/{YYYY-MM}/{name}``.

    ``YYYY-MM`` comes from ``datetime.fromtimestamp(create_time)``, i.e. the
    host's local timezone. ``name`` is ``filename`` with every character that
    is not a word character, ``.`` or ``-`` replaced by ``_``.

    NOTE(review): an earlier comment claimed this strips Chinese and keeps only
    ASCII. It does not: word-character matching on a ``str`` pattern is
    Unicode-aware, so CJK letters/digits are kept. Adding ``re.ASCII`` would
    match that comment but would map different all-CJK names with the same
    extension onto one key within a month (silent overwrite) — decide the
    intended behavior before changing it.
    """
    month_bucket = datetime.fromtimestamp(create_time).strftime("%Y-%m")
    cleaned = re.sub(r'[^\w.\-]', '_', filename)
    return "{}/{}/{}/{}".format(config.cos_base_path, msg_type, month_bucket, cleaned)
|
||
|
||
|
||
class GroupState:
    """Mutable scan bookkeeping for a single group chat.

    Instances are kept in a heapq priority queue; ordering is defined only by
    ``next_scan_time`` (earliest first).
    """

    __slots__ = (
        "username", "display_name", "next_scan_time",
        "scan_interval", "activity_level", "last_message_at",
        "consecutive_empty",
    )

    def __init__(self, username, display_name, next_scan_time=0,
                 scan_interval=300.0, activity_level="cold",
                 last_message_at=0, consecutive_empty=0):
        # Identity of the group.
        self.username, self.display_name = username, display_name
        # Scheduling: absolute due time and current interval (seconds).
        self.next_scan_time, self.scan_interval = next_scan_time, scan_interval
        # Activity tracking used by the adaptive scheduler.
        self.activity_level = activity_level
        self.last_message_at = last_message_at
        self.consecutive_empty = consecutive_empty

    def __lt__(self, other):
        # heapq only needs "<"; order by due time.
        mine, theirs = self.next_scan_time, other.next_scan_time
        return mine < theirs
|
||
|
||
|
||
class GroupScanner:
|
||
def __init__(self, adapter: WeChatAdapter, storage: MessageStorage, config: CollectorConfig):
|
||
self._adapter = adapter
|
||
self._storage = storage
|
||
self._cfg = config
|
||
self._groups: dict[str, GroupState] = {} # username -> GroupState
|
||
self._queue: list[GroupState] = [] # heapq
|
||
self._cos = None
|
||
self._last_discovery = 0
|
||
|
||
def _get_cos(self):
|
||
if self._cos is None:
|
||
try:
|
||
self._cos = _init_cos_uploader(self._cfg)
|
||
log.info("COS 上传器初始化成功")
|
||
except Exception as e:
|
||
log.warning("COS 上传器初始化失败,媒体文件将不会上传: %s", e)
|
||
return self._cos
|
||
|
||
def _matches_filter(self, display_name: str, username: str, patterns: list) -> bool:
|
||
"""检查群名或 username 是否匹配过滤模式列表"""
|
||
for pat in patterns:
|
||
if pat == display_name or pat == username:
|
||
return True
|
||
try:
|
||
if re.search(pat, display_name) or re.search(pat, username):
|
||
return True
|
||
except re.error:
|
||
pass
|
||
return False
|
||
|
||
def discover_groups(self) -> list[str]:
|
||
"""发现所有群聊,应用过滤规则,返回新增群列表"""
|
||
sessions = self._adapter.list_group_sessions(limit=500)
|
||
new_groups = []
|
||
|
||
for s in sessions:
|
||
username = s["username"]
|
||
display = s["display_name"]
|
||
|
||
# 白名单过滤
|
||
if self._cfg.whitelist:
|
||
if not self._matches_filter(display, username, self._cfg.whitelist):
|
||
continue
|
||
|
||
# 黑名单过滤
|
||
if self._cfg.blacklist:
|
||
if self._matches_filter(display, username, self._cfg.blacklist):
|
||
continue
|
||
|
||
if username not in self._groups:
|
||
# 新发现的群
|
||
last_ts = self._storage.get_last_msg_timestamp(username)
|
||
state = GroupState(
|
||
username=username,
|
||
display_name=display,
|
||
next_scan_time=time.time(), # 立即扫描
|
||
scan_interval=self._cfg.base_interval,
|
||
last_message_at=last_ts,
|
||
)
|
||
self._groups[username] = state
|
||
heapq.heappush(self._queue, state)
|
||
new_groups.append(username)
|
||
log.info("发现群聊: %s (%s)", display, username)
|
||
else:
|
||
# 已知群,更新名称
|
||
self._groups[username].display_name = display
|
||
|
||
self._last_discovery = time.time()
|
||
log.info("群聊发现完成: 共 %d 个群, 新增 %d 个", len(self._groups), len(new_groups))
|
||
return new_groups
|
||
|
||
def scan_next_batch(self) -> int:
|
||
"""扫描到期的群聊批次,返回本批新增消息总数"""
|
||
now = time.time()
|
||
total_new = 0
|
||
scanned = 0
|
||
|
||
while scanned < self._cfg.batch_size and self._queue:
|
||
# peek
|
||
if self._queue[0].next_scan_time > now:
|
||
break
|
||
|
||
state = heapq.heappop(self._queue)
|
||
|
||
# 可能已经被移除
|
||
if state.username not in self._groups:
|
||
continue
|
||
|
||
start_t = time.time()
|
||
try:
|
||
new_count = self._scan_single_group(state)
|
||
duration_ms = int((time.time() - start_t) * 1000)
|
||
total_new += new_count
|
||
self._update_activity(state, new_count)
|
||
log.info(
|
||
"[%s] %s: +%d 条, 耗时 %dms, 下次 %ds 后 (%s)",
|
||
state.activity_level.upper(), state.display_name,
|
||
new_count, duration_ms, int(state.scan_interval),
|
||
state.username,
|
||
)
|
||
except Exception as e:
|
||
duration_ms = int((time.time() - start_t) * 1000)
|
||
log.error("[ERROR] %s 扫描失败: %s", state.display_name, e)
|
||
# 出错后延长间隔
|
||
state.scan_interval = min(
|
||
state.scan_interval * 2, self._cfg.max_interval
|
||
)
|
||
|
||
# 推回队列
|
||
state.next_scan_time = time.time() + state.scan_interval
|
||
heapq.heappush(self._queue, state)
|
||
scanned += 1
|
||
|
||
# 群间 jitter
|
||
if scanned < self._cfg.batch_size and self._queue:
|
||
jitter = random.uniform(0, self._cfg.jitter_max)
|
||
time.sleep(jitter)
|
||
|
||
return total_new
|
||
|
||
def _scan_single_group(self, state: GroupState) -> int:
|
||
"""扫描单个群聊,返回新增消息数"""
|
||
messages = self._adapter.query_new_messages(
|
||
state.username,
|
||
after_ts=state.last_message_at,
|
||
limit=self._cfg.messages_per_scan,
|
||
)
|
||
if not messages:
|
||
return 0
|
||
|
||
# 处理媒体上传 + 详细日志
|
||
media_stats = {"uploaded": 0, "no_file": 0, "failed": 0}
|
||
for msg in messages:
|
||
log.info(
|
||
" 新消息: [%s] %s: %s (local_id=%d)",
|
||
msg["msg_type"], msg.get("sender_name", "?"),
|
||
(msg.get("content") or "")[:80],
|
||
msg["local_id"],
|
||
)
|
||
if msg["msg_type"] in MEDIA_TYPES:
|
||
if msg.get("media_path"):
|
||
url = self._upload_media(msg)
|
||
msg["media_url"] = url
|
||
if url:
|
||
media_stats["uploaded"] += 1
|
||
else:
|
||
media_stats["failed"] += 1
|
||
else:
|
||
msg["media_url"] = None
|
||
media_stats["no_file"] += 1
|
||
log.info(" → 媒体文件未在本地找到, 跳过上传")
|
||
else:
|
||
msg["media_url"] = None
|
||
|
||
# 批量入库
|
||
inserted = self._storage.insert_messages(messages)
|
||
|
||
# 媒体统计日志
|
||
if any(v > 0 for v in media_stats.values()):
|
||
log.info(
|
||
" 媒体统计: 上传成功=%d, 本地无文件=%d, 上传失败=%d",
|
||
media_stats["uploaded"], media_stats["no_file"], media_stats["failed"],
|
||
)
|
||
|
||
# 更新状态
|
||
max_ts = max(m["create_time"] for m in messages)
|
||
state.last_message_at = max_ts
|
||
|
||
return inserted
|
||
|
||
def _upload_media(self, msg: dict) -> str | None:
|
||
"""上传媒体文件到 COS,返回 URL"""
|
||
cos = self._get_cos()
|
||
if not cos:
|
||
return None
|
||
|
||
local_path = msg["media_path"]
|
||
if not local_path or not os.path.isfile(local_path):
|
||
return None
|
||
|
||
try:
|
||
filename = os.path.basename(local_path)
|
||
cos_key = _build_cos_key(
|
||
self._cfg, msg["msg_type"], msg["create_time"], filename
|
||
)
|
||
url = cos.upload(local_path, cos_key)
|
||
log.info(" → COS 上传成功: %s → %s", local_path, url)
|
||
return url
|
||
except Exception as e:
|
||
log.warning("上传失败 %s: %s", local_path, e)
|
||
return None
|
||
|
||
def _update_activity(self, state: GroupState, new_count: int):
|
||
"""根据扫描结果更新活跃度和下次扫描间隔"""
|
||
now = time.time()
|
||
|
||
if new_count > 0:
|
||
state.consecutive_empty = 0
|
||
age = now - state.last_message_at if state.last_message_at else float("inf")
|
||
|
||
if age < self._cfg.hot_threshold:
|
||
state.activity_level = "hot"
|
||
state.scan_interval = self._cfg.min_interval
|
||
elif age < self._cfg.warm_threshold:
|
||
state.activity_level = "warm"
|
||
state.scan_interval = self._cfg.base_interval
|
||
else:
|
||
state.activity_level = "warm"
|
||
state.scan_interval = self._cfg.base_interval
|
||
else:
|
||
state.consecutive_empty += 1
|
||
state.activity_level = "cold"
|
||
state.scan_interval = min(
|
||
state.scan_interval * self._cfg.backoff_factor,
|
||
self._cfg.max_interval,
|
||
)
|
||
|
||
def should_discover(self) -> bool:
|
||
return time.time() - self._last_discovery >= self._cfg.discovery_interval
|
||
|
||
def get_status(self) -> dict:
|
||
"""获取扫描器当前状态"""
|
||
total_msgs = self._storage.get_total_count()
|
||
group_stats = {}
|
||
for username, state in self._groups.items():
|
||
group_stats[state.display_name] = {
|
||
"username": username,
|
||
"activity": state.activity_level,
|
||
"interval": f"{int(state.scan_interval)}s",
|
||
"last_msg_at": state.last_message_at,
|
||
}
|
||
return {
|
||
"total_groups": len(self._groups),
|
||
"total_messages": total_msgs,
|
||
"groups": group_stats,
|
||
}
|
||
|
||
def time_to_next(self) -> float:
|
||
"""返回距离下一个群到期的秒数"""
|
||
if not self._queue:
|
||
return self._cfg.cycle_sleep
|
||
wait = self._queue[0].next_scan_time - time.time()
|
||
return max(0, min(wait, self._cfg.cycle_sleep))
|