# Source snapshot: ai_member_xiaoban/scripts/mid_component_aggregation.py
# Captured: 2026-06-25 08:00:01 +08:00 (377 lines, 13 KiB, Python)
# NOTE: the repository viewer warned that this file contains ambiguous Unicode
# characters (characters that might be confused with other characters).

#!/usr/bin/env python3
"""
中互动组件按类型聚合统计 + 元数据导出(逐表聚合版)
策略:逐表查询+聚合Python 端合并,避免大 UNION ALL 超时。
"""
import argparse
import json
import os
import sys
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple
import psycopg2
import pymysql
import pandas as pd
# Display name for every known mid_* component type (c_type -> Chinese label).
# Callers fall back to the raw c_type when a type is missing from this table.
# NOTE(review): "mid_dialog_choose" and "mid_dialog_select" share the label
# "对话选择", so the two are indistinguishable by name alone — confirm intended.
C_TYPE_NAME_MAPPING: Dict[str, str] = {
    # --- dialog components ---
    "mid_dialog_choose": "对话选择",
    "mid_dialog_express": "对话表达",
    "mid_dialog_fillin": "对话填空",
    "mid_dialog_repeat": "对话跟读",
    "mid_dialog_select": "对话选择",
    "mid_dialog_sentence": "对话组句",
    # --- image components ---
    "mid_image_choose": "图片选择",
    "mid_image_drag": "图片拖拽",
    "mid_image_multiple": "图片多选",
    "mid_image_sequence": "图片排序",
    # --- message components ---
    "mid_message_combine": "消息组合",
    "mid_message_fillin": "消息填空",
    "mid_message_sentence": "消息组句",
    "mid_message_spell": "消息拼写",
    "mid_message_trace": "消息描红",
    "mid_message_word": "消息选词",
    # --- grammar components ---
    "mid_grammar_cloze": "语法挖空",
    "mid_grammar_sentence": "语法组句",
    # --- pronunciation components ---
    "mid_pron_pron": "发音互动",
    # --- sentence components ---
    "mid_sentence_dialogue": "句子对话",
    "mid_sentence_makeSentence": "句子造句",
    "mid_sentence_material": "句子材料",
    "mid_sentence_voice": "句子语音",
    # --- vocabulary components ---
    "mid_vocab_fillBlank": "词汇填空",
    "mid_vocab_image": "词汇图片",
    "mid_vocab_instruction": "词汇指令",
    "mid_vocab_item": "词汇物品",
}
def get_pg_conn():
    """Open a new PostgreSQL connection configured from the environment.

    Reads PG_DB_HOST, PG_DB_PORT, PG_DB_USER, PG_DB_PASSWORD and
    PG_DB_DATABASE via ``os.getenv``; unset variables are passed on as None.
    The caller owns the returned connection and must close it.
    """
    settings = {
        "host": os.getenv("PG_DB_HOST"),
        "port": os.getenv("PG_DB_PORT"),
        "user": os.getenv("PG_DB_USER"),
        "password": os.getenv("PG_DB_PASSWORD"),
        "database": os.getenv("PG_DB_DATABASE"),
    }
    return psycopg2.connect(**settings)
def get_mysql_conn():
    """Open a new MySQL connection to the ``vala_test`` database.

    Host, user and password come from MYSQL_HOST, MYSQL_USERNAME and
    MYSQL_PASSWORD; the port from MYSQL_PORT (default 3306, converted with
    ``int``). The connection uses the utf8mb4 charset. The caller owns the
    returned connection and must close it.
    """
    mysql_port = int(os.getenv("MYSQL_PORT", 3306))
    return pymysql.connect(
        host=os.getenv("MYSQL_HOST"),
        user=os.getenv("MYSQL_USERNAME"),
        password=os.getenv("MYSQL_PASSWORD"),
        database="vala_test",
        port=mysql_port,
        charset="utf8mb4",
    )
def count_help_actions(user_behavior_info: Any) -> int:
    """Count behaviour entries whose ``submitOpt`` mentions "help".

    Args:
        user_behavior_info: Either a JSON string encoding a list, or an
            already-decoded list. Entries that are dicts with a truthy
            ``submitOpt`` containing "help" (case-insensitive) each count once.

    Returns:
        The number of help entries. Returns 0 for empty/None input, for
        anything that is not a str or list, for JSON that does not decode to a
        list, and when decoding or inspection raises JSONDecodeError/TypeError.
    """
    if not user_behavior_info:
        return 0
    try:
        if isinstance(user_behavior_info, str):
            entries = json.loads(user_behavior_info)
        elif isinstance(user_behavior_info, list):
            entries = user_behavior_info
        else:
            return 0
        if not isinstance(entries, list):
            return 0
        # Lazily pull submitOpt out of dict entries; both generators are
        # consumed by sum() inside the try, so errors are still caught below.
        submit_opts = (
            entry.get("submitOpt", "") for entry in entries if isinstance(entry, dict)
        )
        return sum(1 for opt in submit_opts if opt and "help" in str(opt).lower())
    except (json.JSONDecodeError, TypeError):
        return 0
def query_table_aggregation(cursor, table: str) -> Dict[str, Dict]:
    """Aggregate completed (play_status = 1) mid_* play records of one table by c_type.

    Args:
        cursor: An open DB-API cursor; the query is executed without bind
            parameters.
        table: Table name interpolated directly into the SQL, so it must come
            from trusted code, never from user input.

    Returns:
        Mapping c_type -> dict with:
          total_records, user_count (distinct users within THIS table only),
          perfect_count, pass_count, oops_count, none_count,
          avg_interval_time, sum_interval_time, and
          interval_count — the number of rows whose interval_time is not NULL,
          i.e. the denominator SQL AVG uses. Merging code needs it to
          recompute a correct cross-table average; total_records over-counts
          when interval_time has NULLs.
    """
    # NOTE(review): 'Opps' is matched alongside 'Oops' — presumably a
    # misspelled value present in the data; confirm.
    # NOTE(review): in LIKE, '_' is itself a single-character wildcard, so
    # 'mid_%%' matches any c_type starting with "mid" plus one more character,
    # not only the literal prefix "mid_". Kept as-is to match the other queries.
    query = f"""
        SELECT
            c_type,
            COUNT(*) AS total_records,
            COUNT(DISTINCT user_id) AS user_count,
            SUM(CASE WHEN play_result = 'Perfect' THEN 1 ELSE 0 END) AS perfect_count,
            SUM(CASE WHEN play_result = 'Pass' THEN 1 ELSE 0 END) AS pass_count,
            SUM(CASE WHEN play_result IN ('Oops', 'Opps') THEN 1 ELSE 0 END) AS oops_count,
            SUM(CASE WHEN play_result = 'None' THEN 1 ELSE 0 END) AS none_count,
            AVG(interval_time) AS avg_interval_time,
            SUM(interval_time) AS sum_interval_time,
            COUNT(interval_time) AS interval_count
        FROM {table}
        WHERE c_type LIKE 'mid_%%' AND c_type != ''
          AND play_status = 1
        GROUP BY c_type
    """
    cursor.execute(query)
    result: Dict[str, Dict] = {}
    for row in cursor.fetchall():
        result[row[0]] = {
            "total_records": int(row[1]),
            "user_count": int(row[2]),
            "perfect_count": int(row[3]),
            "pass_count": int(row[4]),
            "oops_count": int(row[5]),
            "none_count": int(row[6]),
            # AVG/SUM are NULL when every interval_time in the group is NULL.
            "avg_interval_time": float(row[7]) if row[7] else 0,
            "sum_interval_time": float(row[8]) if row[8] else 0,
            "interval_count": int(row[9]),
        }
    return result
def query_table_exit(cursor, table: str) -> Dict[str, int]:
    """Count exited (play_status = 2) mid_* play records in one table.

    Args:
        cursor: An open DB-API cursor; the query runs without bind parameters.
        table: Table name interpolated into the SQL (trusted input only).

    Returns:
        Mapping c_type -> number of exit records in this table.
    """
    sql = (
        "\n"
        "SELECT c_type, COUNT(*) as cnt\n"
        f"FROM {table}\n"
        "WHERE c_type LIKE 'mid_%%' AND c_type != ''\n"
        "AND play_status = 2\n"
        "GROUP BY c_type\n"
    )
    cursor.execute(sql)
    exit_counts: Dict[str, int] = {}
    for c_type, cnt in cursor.fetchall():
        exit_counts[c_type] = int(cnt)
    return exit_counts
def query_table_help_samples(cursor, table: str, limit: int = 5000) -> Dict[str, List[int]]:
    """Sample completed mid_* records of one table and count help actions per record.

    Only rows with play_status = 1 and a user_behavior_info that is neither
    NULL nor '[]' are read. At most ``limit`` rows are fetched in total (not
    per c_type), and with no ORDER BY the database decides which rows those are.

    Args:
        cursor: An open DB-API cursor; the query runs without bind parameters.
        table: Table name interpolated into the SQL (trusted input only).
        limit: Maximum number of rows to fetch, interpolated into the SQL.

    Returns:
        Mapping c_type -> list of per-record help counts from count_help_actions.
    """
    sql = (
        "\n"
        "SELECT c_type, user_behavior_info\n"
        f"FROM {table}\n"
        "WHERE c_type LIKE 'mid_%%' AND c_type != ''\n"
        "AND play_status = 1\n"
        "AND user_behavior_info IS NOT NULL\n"
        "AND user_behavior_info != '[]'\n"
        f"LIMIT {limit}\n"
    )
    cursor.execute(sql)
    samples: Dict[str, List[int]] = {}
    for c_type, behavior in cursor.fetchall():
        samples.setdefault(c_type, []).append(count_help_actions(behavior))
    return samples
def query_table_used_ids(cursor, table: str) -> Set[Tuple[str, str]]:
    """Return the distinct (c_type, c_id) pairs of completed mid_* records in one table.

    c_id is converted with ``str`` so the pairs can be compared against
    identifiers coming from other sources.

    Args:
        cursor: An open DB-API cursor; the query runs without bind parameters.
        table: Table name interpolated into the SQL (trusted input only).
    """
    sql = (
        "\n"
        "SELECT DISTINCT c_type, c_id\n"
        f"FROM {table}\n"
        "WHERE c_type LIKE 'mid_%%' AND c_type != ''\n"
        "AND play_status = 1\n"
    )
    cursor.execute(sql)
    used: Set[Tuple[str, str]] = set()
    for c_type, c_id in cursor.fetchall():
        used.add((c_type, str(c_id)))
    return used
def merge_aggregations(per_table: List[Dict[str, Dict]]) -> Dict[str, Dict]:
    """Merge per-table aggregation dicts (one per table) into totals keyed by c_type.

    Counts and sum_interval_time are summed. ``interval_count`` accumulates
    the number of rows that actually carried an interval_time, so that
    ``sum_interval_time / interval_count`` reproduces SQL AVG semantics (NULLs
    excluded). Inputs without an ``interval_count`` key fall back to
    ``total_records``, which matches the previous behaviour.

    ``user_count`` is initialised to 0 and NOT summed: the same user may appear
    in several tables, so callers must fill in an exact cross-table value.

    Args:
        per_table: List of mappings c_type -> per-table aggregation dict.

    Returns:
        Mapping c_type -> merged aggregation dict.
    """
    summed_keys = (
        "total_records",
        "perfect_count",
        "pass_count",
        "oops_count",
        "none_count",
        "sum_interval_time",
    )
    merged: Dict[str, Dict] = {}
    for table_data in per_table:
        for ct, d in table_data.items():
            if ct not in merged:
                merged[ct] = {
                    "total_records": 0,
                    "user_count": 0,
                    "perfect_count": 0,
                    "pass_count": 0,
                    "oops_count": 0,
                    "none_count": 0,
                    "sum_interval_time": 0.0,
                    "interval_count": 0,
                }
            m = merged[ct]
            for key in summed_keys:
                m[key] += d[key]
            # Denominator for the average: rows with a non-NULL interval_time.
            m["interval_count"] += d.get("interval_count", d["total_records"])
    return merged
def _collect_per_table_stats(
    cursor, tables: List[str]
) -> Tuple[
    List[Dict[str, Dict]],
    Dict[str, int],
    Dict[str, List[int]],
    Set[Tuple[str, str]],
    Dict[str, Set[str]],
]:
    """Run the four per-table queries on every table and accumulate their results.

    Prints one progress line per table.

    Returns:
        (per-table aggregations, exit counts by c_type, sampled help counts by
        c_type, all used (c_type, c_id) pairs, used c_ids grouped by c_type).
    """
    all_aggs: List[Dict[str, Dict]] = []
    all_exits: Dict[str, int] = {}
    all_helps: Dict[str, List[int]] = defaultdict(list)
    all_used_ids: Set[Tuple[str, str]] = set()
    all_used_ids_by_type: Dict[str, Set[str]] = defaultdict(set)
    for table in tables:
        print(f" {table}...", end=" ", flush=True)
        agg = query_table_aggregation(cursor, table)
        exits = query_table_exit(cursor, table)
        helps = query_table_help_samples(cursor, table)
        ids = query_table_used_ids(cursor, table)
        all_aggs.append(agg)
        for ct, cnt in exits.items():
            all_exits[ct] = all_exits.get(ct, 0) + cnt
        for ct, vals in helps.items():
            all_helps[ct].extend(vals)
        all_used_ids.update(ids)
        for ct, cid in ids:
            all_used_ids_by_type[ct].add(cid)
        total = sum(d["total_records"] for d in agg.values())
        print(f"完成({total}条, {len(agg)}种类型)")
    return all_aggs, all_exits, all_helps, all_used_ids, all_used_ids_by_type


def _query_exact_user_counts(cursor, tables: List[str], c_types: List[str]) -> Dict[str, int]:
    """Count distinct users with a completed (play_status = 1) record per c_type across all tables.

    Per-table COUNT(DISTINCT user_id) values cannot simply be added because
    the same user may appear in several tables, so each c_type is counted
    over a UNION of all tables.
    """
    # The UNION text depends only on the table list, so build it once.
    union = " UNION ".join(
        f"SELECT DISTINCT user_id FROM {table} WHERE c_type = %s AND play_status = 1"
        for table in tables
    )
    query = f"SELECT COUNT(*) FROM ({union}) AS u"
    counts: Dict[str, int] = {}
    for ct in c_types:
        # One bind parameter per UNION branch.
        cursor.execute(query, (ct,) * len(tables))
        counts[ct] = cursor.fetchone()[0]
    return counts


def _build_summary_df(
    merged: Dict[str, Dict],
    all_exits: Dict[str, int],
    all_helps: Dict[str, List[int]],
    used_ids_by_type: Dict[str, Set[str]],
) -> pd.DataFrame:
    """Build the per-c_type summary sheet, ordered by total_records descending.

    Rates are formatted as one-decimal percentage strings. The exit rate uses
    completed + exited records as its denominator. The average help count is
    taken over the sampled help counts only. The column list is fixed, so an
    empty result still has the expected headers.
    """
    columns = [
        "组件类型(c_type)", "组件名称", "组件个数", "总完成人数", "总记录数",
        "Perfect数", "Perfect率", "Pass数", "Pass率", "Oops数", "Oops率",
        "退出数", "退出率", "平均帮助次数", "平均耗时(秒)",
    ]

    def pct(part, whole):
        return part / whole * 100 if whole > 0 else 0

    rows = []
    for ct, m in sorted(merged.items(), key=lambda x: x[1]["total_records"], reverse=True):
        total = m["total_records"]
        perfect = m["perfect_count"]
        passed = m["pass_count"]
        oops = m["oops_count"]
        exit_cnt = all_exits.get(ct, 0)
        helps = all_helps.get(ct, [])
        avg_help = sum(helps) / len(helps) if helps else 0
        avg_time = m["sum_interval_time"] / m["interval_count"] if m["interval_count"] > 0 else 0
        rows.append({
            "组件类型(c_type)": ct,
            "组件名称": C_TYPE_NAME_MAPPING.get(ct, ct),
            "组件个数": len(used_ids_by_type.get(ct, set())),
            "总完成人数": m["user_count"],
            "总记录数": total,
            "Perfect数": perfect,
            "Perfect率": f"{pct(perfect, total):.1f}%",
            "Pass数": passed,
            "Pass率": f"{pct(passed, total):.1f}%",
            "Oops数": oops,
            "Oops率": f"{pct(oops, total):.1f}%",
            "退出数": exit_cnt,
            "退出率": f"{pct(exit_cnt, total + exit_cnt):.1f}%",
            "平均帮助次数": round(avg_help, 2),
            # interval_time is divided by 1000 for the seconds column, i.e. it is
            # treated as milliseconds — NOTE(review): confirm the unit.
            "平均耗时(秒)": round(avg_time / 1000, 1),
        })
    return pd.DataFrame(rows, columns=columns)


def _load_used_component_metadata(used_ids: Set[Tuple[str, str]]) -> pd.DataFrame:
    """Fetch middle_interaction_component rows from MySQL and keep only played components.

    A row is kept when its (c_type, str(c_id)) pair is in ``used_ids``. A
    display-name column is added via C_TYPE_NAME_MAPPING, falling back to the
    raw c_type.
    """
    mysql_conn = get_mysql_conn()
    try:
        query = """
            SELECT c_type, c_id, title, component_config,
                   audio_list, text_analysis, related_path,
                   created_at, updated_at
            FROM middle_interaction_component
            WHERE c_type IS NOT NULL AND c_type != ''
            ORDER BY c_type, c_id
        """
        meta_df = pd.read_sql_query(query, mysql_conn)
    finally:
        mysql_conn.close()
    # Build the membership mask with zip instead of DataFrame.apply(axis=1):
    # apply on an empty frame can return a DataFrame rather than a Series,
    # which cannot be assigned to a single column.
    is_used = pd.Series(
        [(ct, str(cid)) in used_ids for ct, cid in zip(meta_df["c_type"], meta_df["c_id"])],
        index=meta_df.index,
        dtype=bool,
    )
    # .copy() so the column assignment below targets an independent frame.
    meta_df = meta_df.loc[is_used].copy()
    meta_df["组件名称"] = meta_df["c_type"].map(C_TYPE_NAME_MAPPING).fillna(meta_df["c_type"])
    meta_cols = ["c_type", "组件名称", "c_id", "title", "component_config",
                 "audio_list", "text_analysis", "related_path", "created_at", "updated_at"]
    return meta_df[meta_cols]


def _write_excel(output_path: str, sheets: Dict[str, pd.DataFrame]) -> None:
    """Write each DataFrame to its own sheet (in dict order) and size columns to content.

    Column width counts non-ASCII characters as two units, adds 4 units of
    padding and is capped at 60.
    """
    from openpyxl.utils import get_column_letter

    with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
        for sheet_name, df in sheets.items():
            df.to_excel(writer, sheet_name=sheet_name, index=False)
        for sheet_name in sheets:
            ws = writer.sheets[sheet_name]
            for col_idx, col_cells in enumerate(ws.columns, 1):
                max_len = 0
                for cell in col_cells:
                    if cell.value:
                        val = str(cell.value)
                        max_len = max(max_len, sum(2 if ord(c) > 127 else 1 for c in val))
                ws.column_dimensions[get_column_letter(col_idx)].width = min(max_len + 4, 60)


def main():
    """CLI entry point: aggregate mid_* component play records and export them to Excel.

    Steps (matching the printed [n/4] markers):
      1. query every user_component_play_record_{0..7} table separately;
      2. merge the per-table results and compute exact distinct-user counts;
      3. build the summary sheet and load MySQL metadata of played components;
      4. write both sheets to ``<output-dir>/中互动组件聚合统计_<YYYYMMDD>.xlsx``.
    """
    from contextlib import closing

    parser = argparse.ArgumentParser(description="中互动组件按类型聚合统计")
    parser.add_argument("--output-dir", default="output", help="输出目录")
    args = parser.parse_args()
    print("=" * 60)
    print("中互动组件按类型聚合统计(逐表聚合版)")
    print("=" * 60)

    tables = [f"user_component_play_record_{i}" for i in range(8)]

    # ===== 1/4: per-table aggregation, exits, help samples and used ids =====
    # closing() releases the cursor and connection even if a query raises.
    with closing(get_pg_conn()) as pg_conn, closing(pg_conn.cursor()) as cursor:
        print("\n[1/4] 逐表查询...")
        (
            all_aggs,
            all_exits,
            all_helps,
            all_used_ids,
            all_used_ids_by_type,
        ) = _collect_per_table_stats(cursor, tables)

    # ===== 2/4: merge, then overwrite user_count with an exact cross-table value =====
    print("\n[2/4] 合并聚合结果...")
    merged = merge_aggregations(all_aggs)
    print(" 计算精确用户数...")
    with closing(get_pg_conn()) as pg_conn2, closing(pg_conn2.cursor()) as cursor2:
        exact_counts = _query_exact_user_counts(cursor2, tables, list(merged))
    for ct, user_count in exact_counts.items():
        merged[ct]["user_count"] = user_count

    # ===== 3/4: summary table + MySQL metadata of the components actually played =====
    print("\n[3/4] 构建输出...")
    result_df = _build_summary_df(merged, all_exits, all_helps, all_used_ids_by_type)
    print(" 获取 MySQL 组件元数据...")
    meta_df = _load_used_component_metadata(all_used_ids)

    # ===== 4/4: Excel export =====
    print("\n[4/4] 输出 Excel...")
    os.makedirs(args.output_dir, exist_ok=True)
    date_str = datetime.now().strftime("%Y%m%d")
    filename = f"中互动组件聚合统计_{date_str}.xlsx"
    output_path = os.path.join(args.output_dir, filename)
    _write_excel(output_path, {"按类型聚合统计": result_df, "组件元数据明细": meta_df})

    print(f"\n{'=' * 60}")
    print("导出完成!")
    print(f"文件: {output_path}")
    print(f"Sheet1「按类型聚合统计」: {len(result_df)}")
    print(f"Sheet2「组件元数据明细」: {len(meta_df)}")
    print(f"{'=' * 60}")
    print("\n📊 聚合统计预览:")
    print(result_df[["组件名称", "组件个数", "总完成人数", "总记录数",
                     "Perfect率", "Pass率", "Oops率", "退出率",
                     "平均帮助次数", "平均耗时(秒)"]].to_string(index=False))


if __name__ == "__main__":
    main()