wechat_msg_clicker/wechat_clicker/config.py
2026-05-06 14:31:56 +08:00

206 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""配置加载与管理"""
import copy
import os
import random
from datetime import datetime

import yaml
# Default configuration. Config._load() merges the user's YAML on top of
# this, so any key the user omits keeps the value given here.
DEFAULTS = {
    "wechat": {
        "bundle_id": "com.tencent.xinWeChat",
        "process_name": "WeChat",
    },
    "scan": {
        "interval_seconds": 15,
        # NOTE(review): meaning of 0 (presumably "no limit") is decided by
        # the caller, not shown here — confirm.
        "max_chats_per_scan": 0,
        "scroll_chat_list": False,
    },
    # Each entry is a [min, max] pair in seconds, sampled by Config.get_delay().
    "delays": {
        "before_click_chat": [0.5, 2],
        "after_open_chat": [0.5, 1.5],
        "before_click_media": [0.3, 1.5],
        "after_click_media": [1.5, 4],
        "before_close_preview": [0.3, 1],
        "before_close_chat": [0.3, 1],
        "between_messages": [0.2, 0.8],
    },
    # start_hour > end_hour means the window wraps past midnight
    # (here: 07:00 until 00:59).
    "schedule": {
        "enabled": True,
        "start_hour": 7,
        "end_hour": 1,
        "pause_on_weekends": False,
    },
    # mode is "all", "whitelist" or "blacklist" (see Config.should_process_chat).
    "filter": {
        "mode": "all",
        "whitelist": [],
        "blacklist": ["腾讯新闻", "微信支付", "微信团队", "服务号", "订阅号", "文件传输助手"],
    },
    "media": {
        "click_images": True,
        "click_files": False,
        "click_videos": True,
        "max_media_per_chat": 20,
    },
    "logging": {
        "level": "DEBUG",
        "file": "wechat_clicker.log",
        "max_bytes": 10 * 1024 * 1024,  # 10 MiB
        "backup_count": 5,
        "console": True,
    },
}
def _deep_merge(base: dict, override: dict) -> dict:
"""深度合并两个字典override 中的值覆盖 base 中的值。"""
result = base.copy()
for key, value in override.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = _deep_merge(result[key], value)
else:
result[key] = value
return result
class Config:
    """Configuration manager.

    Loads a YAML file (if it exists), merges it over the module-level
    ``DEFAULTS`` and exposes the result through dotted-key lookup, runtime
    overrides, convenience properties and a few decision helpers.
    """

    def __init__(self, config_path: str = "config.yaml"):
        self._config_path = config_path
        self._data = self._load()

    def _load(self) -> dict:
        """Load the config file and merge it over the defaults.

        Returns:
            A dict that shares no mutable state with ``DEFAULTS``, so later
            ``override()`` calls cannot leak into the module-level defaults
            or into other ``Config`` instances.

        Raises:
            TypeError: if the YAML document's top level is not a mapping.
        """
        # Deep-copy first: with a shallow copy, override() would mutate the
        # nested dicts/lists owned by DEFAULTS.
        defaults = copy.deepcopy(DEFAULTS)
        if not os.path.exists(self._config_path):
            return defaults
        with open(self._config_path, "r", encoding="utf-8") as f:
            # An empty file yields None; treat it as "no overrides".
            user_config = yaml.safe_load(f) or {}
        if not isinstance(user_config, dict):
            raise TypeError(
                f"{self._config_path}: top level of the config file must be a "
                f"mapping, got {type(user_config).__name__}"
            )
        return _deep_merge(defaults, user_config)

    def _get(self, dotted_key: str, default=None):
        """Look up a value by dotted key, e.g. ``'scan.interval_seconds'``.

        Returns ``default`` if a path segment is missing, or if a non-dict
        value is reached before the last segment. An explicit ``null`` in the
        YAML is returned as ``None``; it is not replaced by ``default``.
        """
        value = self._data
        for key in dotted_key.split("."):
            if isinstance(value, dict) and key in value:
                value = value[key]
            else:
                return default
        return value

    def override(self, dotted_key: str, value):
        """Set a config value at runtime, creating missing parent dicts."""
        *parents, leaf = dotted_key.split(".")
        target = self._data
        for key in parents:
            target = target.setdefault(key, {})
        target[leaf] = value

    # --- Convenience properties ---
    @property
    def bundle_id(self) -> str:
        return self._get("wechat.bundle_id")

    @property
    def process_name(self) -> str:
        return self._get("wechat.process_name")

    @property
    def scan_interval(self) -> int:
        return self._get("scan.interval_seconds")

    @property
    def max_chats_per_scan(self) -> int:
        return self._get("scan.max_chats_per_scan")

    @property
    def scroll_chat_list(self) -> bool:
        return self._get("scan.scroll_chat_list")

    @property
    def click_images(self) -> bool:
        return self._get("media.click_images")

    @property
    def click_files(self) -> bool:
        return self._get("media.click_files")

    @property
    def click_videos(self) -> bool:
        return self._get("media.click_videos")

    @property
    def max_media_per_chat(self) -> int:
        return self._get("media.max_media_per_chat")

    @property
    def log_level(self) -> str:
        return self._get("logging.level")

    @property
    def log_file(self) -> str:
        return self._get("logging.file")

    @property
    def log_max_bytes(self) -> int:
        return self._get("logging.max_bytes")

    @property
    def log_backup_count(self) -> int:
        return self._get("logging.backup_count")

    @property
    def log_console(self) -> bool:
        return self._get("logging.console")

    # --- Delays ---
    def get_delay(self, delay_name: str) -> float:
        """Return a random delay in seconds for ``delays.<delay_name>``.

        The configured value is normally a ``[min, max]`` pair. A sample is
        drawn from a normal distribution centred on the midpoint with
        ``sigma = (max - min) / 4`` and then clamped to ``[min, max]``. It is
        clamped, not resampled, so about 2.3% of draws land exactly on each
        bound. A bare number is accepted as a fixed delay, and a reversed
        pair is reordered. Unknown names fall back to ``[1, 3]``.
        """
        delay_range = self._get(f"delays.{delay_name}", [1, 3])
        if isinstance(delay_range, (int, float)):
            # A scalar in the YAML means "always wait exactly this long".
            return float(delay_range)
        low, high = delay_range[0], delay_range[1]
        if low > high:
            low, high = high, low
        mid = (low + high) / 2
        std = (high - low) / 4  # ±2 sigma: ~95% of raw samples fall in range
        value = random.gauss(mid, std)
        return float(max(low, min(high, value)))

    # --- Working hours ---
    def is_within_working_hours(self) -> bool:
        """Return True if the current local time is inside the work window.

        Always True when ``schedule.enabled`` is falsy. With
        ``pause_on_weekends`` set, Saturday and Sunday are excluded based on
        the current calendar date. That includes the after-midnight tail of a
        Friday window that wraps past midnight.

        ``start_hour``/``end_hour`` define a half-open window ``[start, end)``.
        If ``start > end`` the window wraps past midnight (e.g. 7 -> 1 means
        07:00 to 00:59). ``start == end`` yields an empty window, so the check
        always returns False.
        """
        if not self._get("schedule.enabled"):
            return True
        now = datetime.now()
        if self._get("schedule.pause_on_weekends") and now.weekday() >= 5:
            return False
        start_hour = self._get("schedule.start_hour")
        end_hour = self._get("schedule.end_hour")
        hour = now.hour
        if start_hour <= end_hour:
            return start_hour <= hour < end_hour
        # Window crosses midnight.
        return hour >= start_hour or hour < end_hour

    # --- Chat filtering ---
    def should_process_chat(self, chat_name: str) -> bool:
        """Decide whether ``chat_name`` should be processed.

        * ``whitelist`` mode: only chats listed in ``filter.whitelist``.
        * ``all`` / ``blacklist`` mode: every chat not in ``filter.blacklist``.
        * Any other (e.g. misspelled) mode behaves like ``all``, so a typo
          never silently bypasses the blacklist.

        A missing or ``null`` list is treated as empty.
        """
        mode = self._get("filter.mode", "all")
        if mode == "whitelist":
            return chat_name in (self._get("filter.whitelist") or [])
        return chat_name not in (self._get("filter.blacklist") or [])