"""MySQL storage layer — operates only on the wechat_group_message table."""
import logging

import pymysql
# Module-level logger, named after this module so handlers/levels can be
# configured per module by the application.
log = logging.getLogger(__name__)
# DDL template for the message table. `{table}` is filled in with
# str.format(); the statement contains no other braces, so formatting is safe.
# The unique key (group_username, local_id, source_db, msg_timestamp) is what
# lets INSERT IGNORE skip rows that were already collected.
CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS `{table}` (
    `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
    `group_username` VARCHAR(128) NOT NULL COMMENT '群聊 username, 如 xxx@chatroom',
    `group_name` VARCHAR(255) NOT NULL DEFAULT '' COMMENT '群聊名称',
    `sender_username` VARCHAR(128) NOT NULL DEFAULT '' COMMENT '发送者 wxid',
    `sender_name` VARCHAR(255) NOT NULL DEFAULT '' COMMENT '发送者显示名称',
    `msg_type` VARCHAR(32) NOT NULL DEFAULT 'text' COMMENT '消息类型: text/voice/video/file/sticker/link/location/call/system',
    `content` TEXT COMMENT '消息文本内容或描述',
    `media_url` VARCHAR(1024) DEFAULT NULL COMMENT '媒体文件 COS URL',
    `local_id` BIGINT NOT NULL DEFAULT 0 COMMENT '微信消息 local_id',
    `local_type` BIGINT NOT NULL DEFAULT 0 COMMENT '微信消息原始类型',
    `source_db` VARCHAR(255) NOT NULL DEFAULT '' COMMENT '来源 message_N.db 路径',
    `msg_time` DATETIME NOT NULL COMMENT '消息发送时间(微信 create_time)',
    `msg_timestamp` BIGINT NOT NULL DEFAULT 0 COMMENT '消息时间戳(unix)',
    `svr_msg_id` BIGINT UNSIGNED DEFAULT NULL COMMENT '微信服务端消息ID(server_id)',
    `refer_msg_svrid` BIGINT UNSIGNED DEFAULT NULL COMMENT '引用消息的服务端ID(refermsg/svrid)',
    `collected_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '采集入库时间',
    PRIMARY KEY (`id`),
    UNIQUE KEY `uk_group_local` (`group_username`, `local_id`, `source_db`, `msg_timestamp`),
    KEY `idx_group_time` (`group_username`, `msg_timestamp`),
    KEY `idx_msg_time` (`msg_timestamp`),
    KEY `idx_sender` (`sender_username`),
    KEY `idx_svr_msg_id` (`svr_msg_id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
  COMMENT='微信群聊消息采集表';
"""
# Insert template; `{table}` is filled in with str.format() and the 14 `%s`
# placeholders are bound by the driver. INSERT IGNORE makes rows that collide
# with an existing unique key a silent no-op instead of an error.
# The 11th placeholder is converted with FROM_UNIXTIME, so callers pass the
# unix timestamp twice: once for msg_time and once for msg_timestamp.
INSERT_SQL = """
INSERT IGNORE INTO `{table}`
    (group_username, group_name, sender_username, sender_name,
     msg_type, content, media_url, local_id, local_type, source_db,
     msg_time, msg_timestamp, svr_msg_id, refer_msg_svrid)
VALUES
    (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, FROM_UNIXTIME(%s), %s, %s, %s)
"""
class MessageStorage:
    """Persistence layer for collected WeChat group messages in one MySQL table.

    A single PyMySQL connection is created lazily and cached on the instance.
    It is opened with ``autocommit=False``: write methods commit or roll back
    explicitly, and read methods commit after fetching so that the next read
    starts a fresh transaction instead of reusing an earlier snapshot.

    ``config`` must expose ``mysql_host``, ``mysql_port``, ``mysql_user``,
    ``mysql_password``, ``mysql_database`` and ``mysql_table``. The table name
    is interpolated into SQL as a back-quoted identifier, so it must come from
    trusted configuration.
    """

    def __init__(self, config):
        """Remember the configuration; no connection is opened yet."""
        self._cfg = config
        self._table = config.mysql_table
        self._conn = None

    def _get_conn(self):
        """Return a live connection, reusing the cached one when possible.

        The cached connection is pinged with ``reconnect=True``; if that
        raises, the handle is dropped and a brand-new connection is opened.
        Errors from ``pymysql.connect`` propagate to the caller.
        """
        if self._conn is not None:
            try:
                self._conn.ping(reconnect=True)
                return self._conn
            except Exception:
                # Best effort: the cached handle is unusable, fall through
                # and build a new connection below.
                self._conn = None
        self._conn = pymysql.connect(
            host=self._cfg.mysql_host,
            port=self._cfg.mysql_port,
            user=self._cfg.mysql_user,
            password=self._cfg.mysql_password,
            database=self._cfg.mysql_database,
            charset="utf8mb4",
            connect_timeout=10,
            read_timeout=30,
            write_timeout=30,
            autocommit=False,
        )
        return self._conn

    def _read_rows(self, sql, params=None, *, as_dict=False):
        """Run a read-only statement and return every row.

        With ``as_dict=True`` rows are fetched through
        ``pymysql.cursors.DictCursor``; otherwise the connection's default
        cursor is used.

        The trailing ``commit()`` ends the read transaction. Because the
        connection runs with autocommit disabled, a SELECT would otherwise
        leave a transaction open, and under REPEATABLE READ (InnoDB's default
        isolation level) later reads on the cached connection would keep
        seeing the first read's snapshot.
        """
        conn = self._get_conn()
        cursor_cls = pymysql.cursors.DictCursor if as_dict else None
        with conn.cursor(cursor_cls) as cur:
            cur.execute(sql, params)
            rows = cur.fetchall()
        conn.commit()
        return rows

    def _read_scalar(self, sql, params=None):
        """Run a single-row, single-column query and return that value.

        Ends the read transaction afterwards for the same reason as
        ``_read_rows``.
        """
        conn = self._get_conn()
        with conn.cursor() as cur:
            cur.execute(sql, params)
            row = cur.fetchone()
        conn.commit()
        return row[0]

    def ensure_table(self):
        """Create the message table if missing and add any newer columns."""
        conn = self._get_conn()
        with conn.cursor() as cur:
            cur.execute(CREATE_TABLE_SQL.format(table=self._table))
            self._ensure_columns(cur)
        conn.commit()
        log.info("表 %s 已就绪", self._table)

    def _ensure_columns(self, cur):
        """Add columns introduced after the first release to an existing table.

        ``cur`` is an open cursor on the target database. Columns already
        reported by ``SHOW COLUMNS`` are left untouched.
        """
        cur.execute(f"SHOW COLUMNS FROM `{self._table}`")
        existing = {row[0] for row in cur.fetchall()}
        migrations = [
            ("svr_msg_id", "BIGINT UNSIGNED DEFAULT NULL COMMENT '微信服务端消息ID(server_id)' AFTER `msg_timestamp`"),
            ("refer_msg_svrid", "BIGINT UNSIGNED DEFAULT NULL COMMENT '引用消息的服务端ID(refermsg/svrid)' AFTER `svr_msg_id`"),
        ]
        for col, definition in migrations:
            if col not in existing:
                cur.execute(f"ALTER TABLE `{self._table}` ADD COLUMN `{col}` {definition}")
                log.info("已为表 %s 添加列 %s", self._table, col)
        # `existing` was captured before the ALTERs, so this is true exactly
        # when svr_msg_id was added just above and still needs its index.
        if "svr_msg_id" not in existing:
            cur.execute(f"ALTER TABLE `{self._table}` ADD INDEX `idx_svr_msg_id` (`svr_msg_id`)")
            log.info("已为表 %s 添加索引 idx_svr_msg_id", self._table)

    def insert_messages(self, messages: list[dict]) -> int:
        """Insert a batch of messages and return the affected-row count.

        Each message dict must contain ``group_username``, ``local_id`` and
        ``create_time`` (a unix timestamp, stored both as ``msg_time`` via
        FROM_UNIXTIME and as ``msg_timestamp``); a missing key raises
        ``KeyError``. All other fields are optional and fall back to the
        defaults used below.

        The statement is INSERT IGNORE, so the returned value is the cursor's
        ``rowcount`` after ``executemany``. An empty batch returns 0 without
        touching the database. On any error the transaction is rolled back
        and the exception is re-raised.
        """
        if not messages:
            return 0
        conn = self._get_conn()
        sql = INSERT_SQL.format(table=self._table)
        rows = [
            (
                m["group_username"],
                m.get("group_name", ""),
                m.get("sender_username", ""),
                m.get("sender_name", ""),
                m.get("msg_type", "text"),
                m.get("content", ""),
                m.get("media_url"),
                m["local_id"],
                m.get("local_type", 0),
                m.get("source_db", ""),
                m["create_time"],  # -> msg_time through FROM_UNIXTIME(%s)
                m["create_time"],  # -> msg_timestamp (raw unix value)
                m.get("svr_msg_id"),
                m.get("refer_msg_svrid"),
            )
            for m in messages
        ]
        try:
            with conn.cursor() as cur:
                cur.executemany(sql, rows)
                inserted = cur.rowcount
            conn.commit()
            return inserted
        except Exception:
            conn.rollback()
            raise

    def get_last_msg_timestamp(self, group_username: str) -> int:
        """Return the newest stored ``msg_timestamp`` for a group, or 0 if none."""
        sql = f"SELECT MAX(msg_timestamp) FROM `{self._table}` WHERE group_username = %s"
        # MAX() yields NULL when the group has no rows; map that to 0.
        return self._read_scalar(sql, (group_username,)) or 0

    def get_group_stats(self) -> list[dict]:
        """Return per-group collection statistics, most recently active first.

        Each row is a dict with ``group_username``, ``group_name``, ``total``,
        ``last_ts`` and ``first_ts``. Rows are grouped by username *and* name,
        so one username stored under several names yields several rows.
        """
        sql = f"""
            SELECT group_username, group_name,
                   COUNT(*) AS total,
                   MAX(msg_timestamp) AS last_ts,
                   MIN(msg_timestamp) AS first_ts
            FROM `{self._table}`
            GROUP BY group_username, group_name
            ORDER BY last_ts DESC
        """
        return self._read_rows(sql, as_dict=True)

    def get_total_count(self) -> int:
        """Return the total number of rows in the message table."""
        return self._read_scalar(f"SELECT COUNT(*) FROM `{self._table}`")

    def get_pending_media_messages(self, lookback_days: int = 7, limit: int = 500) -> list[dict]:
        """Return recent media messages whose ``media_url`` is still NULL.

        Used to back-fill media links. Only rows of type image/voice/video/file
        with ``msg_time`` within the last ``lookback_days`` days are
        considered; at most ``limit`` rows are returned, newest first, as
        dicts.
        """
        sql = f"""
            SELECT id, group_username, msg_type, content, local_id,
                   local_type, source_db, msg_timestamp
            FROM `{self._table}`
            WHERE media_url IS NULL
              AND msg_type IN ('image', 'voice', 'video', 'file')
              AND msg_time > DATE_SUB(NOW(), INTERVAL %s DAY)
            ORDER BY msg_timestamp DESC
            LIMIT %s
        """
        return self._read_rows(sql, (lookback_days, limit), as_dict=True)

    def update_media_url(self, record_id: int, media_url: str) -> bool:
        """Set ``media_url`` on one row, only if it is currently NULL.

        Returns True when a row was changed, False when the id does not exist
        or the row already has a URL. On any error the transaction is rolled
        back and the exception is re-raised.
        """
        conn = self._get_conn()
        sql = f"UPDATE `{self._table}` SET media_url = %s WHERE id = %s AND media_url IS NULL"
        try:
            with conn.cursor() as cur:
                cur.execute(sql, (media_url, record_id))
                changed = cur.rowcount > 0
            conn.commit()
            return changed
        except Exception:
            conn.rollback()
            raise

    def close(self):
        """Close the cached connection if it is open and forget it."""
        if self._conn and self._conn.open:
            self._conn.close()
        self._conn = None