ai_member_xiaoyan/skills/interactive-component-json/scripts/kp_matcher.py

412 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
知识点匹配模块
连接 MySQL vala_test.vala_kp 表,实现两阶段匹配:粗召回 + LLM精筛。
"""
import os
import re
import json
import logging
import threading
import pymysql
# Path anchors: the scripts/ directory holding this file, the skill directory
# above it, and the workspace root two levels above the skill directory.
CURRENT_PATH = os.path.dirname(os.path.abspath(__file__))
SKILL_ROOT = os.path.dirname(CURRENT_PATH)
WORKSPACE_ROOT = os.path.dirname(os.path.dirname(SKILL_ROOT))

# Module logger at INFO level. A stream handler is attached only when the
# logger has none yet, so repeated module execution does not duplicate output.
logger = logging.getLogger("kp_matcher")
logger.setLevel(logging.INFO)
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(
        fmt="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s"
    ))
    logger.addHandler(handler)
# ============ Secrets 加载 ============
def _load_mysql_password_from_secrets(secrets_path=None):
    """Read the test-MySQL ``read_only`` password from the workspace ``secrets.md``.

    The file is expected to contain a ``### 测试 MySQL`` heading followed
    (possibly after blank lines) by a ``read_only: <password>`` line.

    Args:
        secrets_path: Path of the secrets file. Defaults to
            ``<WORKSPACE_ROOT>/secrets.md``.

    Returns:
        The password string, or ``None`` when the file is missing, cannot be
        read/decoded, or has no non-empty ``read_only`` entry.
    """
    if secrets_path is None:
        secrets_path = os.path.join(WORKSPACE_ROOT, "secrets.md")
    if not os.path.exists(secrets_path):
        return None
    try:
        with open(secrets_path, "r", encoding="utf-8") as f:
            content = f.read()
    except (OSError, UnicodeDecodeError) as e:
        # Best-effort source: fall back to "no password", but leave a trace
        # instead of silently swallowing the failure.
        logger.warning(f"读取 secrets 文件失败 ({secrets_path}): {e}")
        return None
    # "[ \t]*" (not "\s*") keeps the captured value on the read_only line
    # itself, so an empty entry cannot swallow the following line.
    match = re.search(r'###\s*测试\s*MySQL\s*\n\s*read_only:[ \t]*(.+)', content)
    if match:
        return match.group(1).strip() or None
    return None
def _get_mysql_password():
    """Resolve the MySQL password.

    Precedence: the ``MYSQL_PASSWORD`` environment variable first, then the
    ``read_only`` entry of the workspace ``secrets.md``. When neither yields a
    non-empty value, a warning is logged and an empty string is returned.
    """
    password = os.getenv("MYSQL_PASSWORD") or _load_mysql_password_from_secrets()
    if not password:
        logger.warning("未找到 MySQL 密码(环境变量 MYSQL_PASSWORD 和 secrets.md 均不可用)")
        return ""
    return password
# ============ MySQL 配置 ============
# pymysql connection keyword arguments. host/port/user/database can be
# overridden via MYSQL_HOST / MYSQL_PORT / MYSQL_USER / MYSQL_DB; the password
# is resolved once, when this module is imported.
# NOTE(review): the default host is a hard-coded cloud instance; consider
# requiring MYSQL_HOST in non-test environments.
MYSQL_CONFIG = dict(
    host=os.getenv("MYSQL_HOST", "bj-cdb-8frbdwju.sql.tencentcdb.com"),
    port=int(os.getenv("MYSQL_PORT", "25413")),
    user=os.getenv("MYSQL_USER", "read_only"),
    password=_get_mysql_password(),  # env MYSQL_PASSWORD, else secrets.md
    database=os.getenv("MYSQL_DB", "vala_test"),
    charset="utf8mb4",
    connect_timeout=10,
    read_timeout=10,
)
# ============ kpSkill 默认值映射 ============
# Knowledge-point type -> (kpSkill, kpSkillName). Grammar points map to the
# same sentence-meaning skill as sentence points.
KP_SKILL_DEFAULTS = dict(
    vocab=("vocab_meaning", "词义"),
    sentence=("sentence_meaning", "语义"),
    grammar=("sentence_meaning", "语义"),
    pron=("vocab_pron", "发音"),
)
# ============ LLM 精筛 Prompt ============
# System prompt for the LLM re-ranking step in _precise_match. The model is
# asked to answer with a bare kp_id, or the literal "null" when nothing fits.
# Rule 3 previously ran "kp_id" and "不要包含其他内容" together; it now also
# asks the model not to echo the "kp_id=" prefix used in the candidate list.
KP_MATCH_SYSTEM_PROMPT = """你是一个知识点匹配专家。根据给定的知识点文本和候选列表,选出最匹配的知识点条目。
## 规则
1. 从候选列表中选出与"待匹配知识点"含义最接近的一条
2. 如果有 level 信息,优先选择同 level 的条目
3. 输出格式:仅输出被选中条目的 kp_id 值,不要包含 "kp_id=" 前缀、引号或其他内容
4. 如果没有任何候选匹配,输出 null"""
def _build_match_user_prompt(kp_text, candidates, level=None, context=""):
    """Render the user prompt for LLM re-ranking.

    Layout: the text to match, then the optional level and component-context
    lines (only when truthy), a blank line, and one ``- kp_id=...`` line per
    candidate. A missing/empty ``en_desc`` is rendered as an empty string.

    Args:
        kp_text: Knowledge-point text to be matched.
        candidates: Rows with ``kp_id``, ``title``, ``type``, ``vala_level``
            and ``en_desc`` keys.
        level: Optional script level, e.g. "L2".
        context: Optional free-form component context.

    Returns:
        The prompt as a single newline-joined string.
    """
    header = [f"待匹配知识点: {kp_text}"]
    if level:
        header.append(f"当前级别: {level}")
    if context:
        header.append(f"组件上下文: {context}")
    candidate_lines = [
        f"- kp_id={cand['kp_id']}, title={cand['title']}, type={cand['type']}, "
        f"level={cand['vala_level']}, desc={cand['en_desc'] or ''}"
        for cand in candidates
    ]
    return "\n".join(header + ["\n候选列表:"] + candidate_lines)
# ============ MySQL 连接 ============
# Per-thread storage for the cached MySQL connection (see _get_connection).
_thread_local = threading.local()


def _get_connection():
    """Return this thread's MySQL connection, (re)connecting when needed.

    Each thread keeps its own pymysql connection in ``_thread_local``. An
    existing open connection is health-checked with ``ping(reconnect=True)``.
    If the cached connection is closed or the ping fails, the stale handle is
    closed (best-effort) and cleared before a fresh connection is opened, so
    dead sockets are not leaked.

    Returns:
        A pymysql connection whose cursors return dict rows (``DictCursor``).

    Raises:
        Whatever ``pymysql.connect`` raises when a new connection cannot be
        established.
    """
    conn = getattr(_thread_local, "connection", None)
    if conn is not None and conn.open:
        try:
            conn.ping(reconnect=True)
            return conn
        except Exception:
            # Ping (including its own reconnect attempt) failed; rebuild below.
            pass
    if conn is not None:
        # Release the stale handle. Closing an already-closed pymysql
        # connection raises, which is harmless here.
        try:
            conn.close()
        except Exception:
            pass
        _thread_local.connection = None
    conn = pymysql.connect(**MYSQL_CONFIG, cursorclass=pymysql.cursors.DictCursor)
    _thread_local.connection = conn
    return conn
def _close_connection():
    """Close and forget the calling thread's cached MySQL connection.

    Does nothing when no connection is cached. Errors raised by ``close()``
    are ignored (best-effort cleanup); the thread-local slot is cleared
    afterwards either way.
    """
    stale = getattr(_thread_local, 'connection', None)
    if not stale:
        return
    try:
        stale.close()
    except Exception:
        pass
    _thread_local.connection = None
# ============ 粗召回 ============
def _escape_like(text, escape_char="!"):
    """Escape *text* for a literal SQL ``LIKE`` match using *escape_char*.

    ``%`` and ``_`` are LIKE wildcards. The escape character itself is
    doubled first so it cannot pair with the character after it.
    """
    return (
        text.replace(escape_char, escape_char * 2)
        .replace("%", escape_char + "%")
        .replace("_", escape_char + "_")
    )


def _select_kp_rows(cursor, title_condition, value, level=None, order_by="kp_id"):
    """Run one candidate query against ``vala_kp`` and return at most 10 rows.

    Args:
        cursor: An open pymysql cursor.
        title_condition: Trusted SQL predicate on ``title`` with exactly one
            ``%s`` placeholder.
        value: Parameter bound to that placeholder.
        level: Optional ``vala_level`` filter; skipped when falsy.
        order_by: Trusted (never user-supplied) ORDER BY expression, so the
            LIMIT picks a deterministic set of rows.

    Returns:
        list[dict]: the fetched rows (possibly empty).
    """
    sql = (
        "SELECT kp_id, type, title, vala_level, en_desc FROM vala_kp "
        "WHERE deleted_at IS NULL AND " + title_condition
    )
    params = [value]
    if level:
        sql += " AND vala_level = %s"
        params.append(level)
    sql += f" ORDER BY {order_by} LIMIT 10"
    cursor.execute(sql, params)
    return list(cursor.fetchall())


def _rough_recall(title_text, level=None):
    """Coarse candidate recall on ``vala_kp.title`` (at most 10 rows).

    Strategies are tried in order; the first one that returns rows wins:
      1. exact ``title = text`` with the ``vala_level`` filter,
      2. exact match without the level filter,
      3. substring ``LIKE`` match (wildcards in *text* escaped) with level,
      4. substring match without the level filter.
    Steps 1 and 3 apply the level filter only when *level* is given; steps 2
    and 4 run only when *level* is given. Soft-deleted rows (``deleted_at``
    set) are always excluded.

    Args:
        title_text: Knowledge-point text to look up.
        level: Optional script level, e.g. "L2".

    Returns:
        list[dict]: candidate rows with kp_id/type/title/vala_level/en_desc.
        Empty for blank *title_text* (which would otherwise LIKE-match every
        row) or when nothing matches.
    """
    if not title_text or not title_text.strip():
        return []
    like_pattern = f"%{_escape_like(title_text)}%"
    # (predicate, bound value, ordering) per strategy.
    strategies = (
        ("title = %s", title_text, "kp_id"),
        # Shortest titles first: with LIMIT 10 they are the closest fuzzy hits.
        ("title LIKE %s ESCAPE '!'", like_pattern, "CHAR_LENGTH(title), kp_id"),
    )
    # With a level: try level-filtered first, then fall back to any level.
    level_filters = (level, None) if level else (None,)
    conn = _get_connection()
    with conn.cursor() as cursor:
        for condition, value, order_by in strategies:
            for level_filter in level_filters:
                rows = _select_kp_rows(cursor, condition, value, level_filter, order_by)
                if rows:
                    return rows
    return []
# ============ 精筛选 ============
def _normalize_llm_kp_id(response):
    """Reduce a raw LLM answer to a bare kp_id token ("" if nothing usable).

    Takes the first line that still has content once whitespace, quotes and
    backticks are stripped. It then drops an echoed ``kp_id=`` / ``kp_id:``
    prefix, because the candidate list in the prompt uses ``kp_id=...``.
    """
    text = "" if response is None else str(response)
    token = ""
    for raw_line in text.splitlines():
        token = raw_line.strip().strip("`\"'").strip()
        if token:
            break
    token = re.sub(r"^kp_id\s*[=:]\s*", "", token, flags=re.IGNORECASE)
    return token.strip("`\"'").strip()


def _precise_match(kp_text, candidates, level=None, context="", llm_client=None):
    """Pick the single best candidate row for *kp_text*.

    Resolution order:
      1. No candidates -> None. Exactly one -> that one.
      2. If *level* is given and some candidates have that ``vala_level``,
         narrow to those (a unique level match wins immediately).
      3. Without an LLM client, return the candidate with the smallest kp_id.
      4. Otherwise ask the LLM for a kp_id. An answer of ``null`` means
         "no match" and returns None. An unrecognised answer, or any LLM
         error, falls back to the first candidate.

    The caller's *candidates* list is never reordered.

    Args:
        kp_text: Knowledge-point text being matched.
        candidates: Candidate rows from _rough_recall.
        level: Optional script level used for narrowing and in the prompt.
        context: Optional component context passed into the prompt.
        llm_client: Object with ``call(system, user, max_tokens=, temperature=)``
            returning ``(text, usage)``, or None.

    Returns:
        dict or None: the selected candidate row.
    """
    if not candidates:
        return None
    if len(candidates) == 1:
        return candidates[0]

    # Prefer candidates at the requested level; keep all if none match it.
    if level:
        level_matched = [c for c in candidates if c.get("vala_level") == level]
        if level_matched:
            candidates = level_matched
    if len(candidates) == 1:
        return candidates[0]

    if llm_client is None:
        # Smallest kp_id wins. min() returns the same element a stable sort
        # would put first, without mutating the caller's list.
        return min(candidates, key=lambda c: c.get("kp_id", ""))

    user_prompt = _build_match_user_prompt(kp_text, candidates, level, context)
    try:
        response, _usage = llm_client.call(
            KP_MATCH_SYSTEM_PROMPT, user_prompt, max_tokens=64, temperature=0.0
        )
    except Exception as e:
        # External service boundary: any failure degrades to the first candidate.
        logger.warning(f"LLM 精筛失败: {e},使用第一个候选")
        return candidates[0]

    answer = _normalize_llm_kp_id(response)
    if answer.lower() == "null":
        return None
    for c in candidates:
        # str() so integer kp_id columns still match the textual LLM answer.
        if str(c["kp_id"]) == answer:
            return c
    logger.warning(f"LLM 返回的 kp_id '{response}' 不在候选列表中,使用第一个候选")
    return candidates[0]
# ============ 单条知识点匹配 ============
def _match_single_kp(kp_text, level=None, context="", llm_client=None):
    """Resolve one knowledge-point string to a ``vala_kp`` row.

    Candidates come from _rough_recall and are narrowed by _precise_match.
    When recall fails, finds nothing, or no candidate is selected, ``kpId``
    is None and ``kpType`` is a heuristic guess: text containing "..." or
    more than three words is treated as "sentence", anything else as "vocab".

    Args:
        kp_text: Cleaned knowledge-point text.
        level: Optional script level, e.g. "L2".
        context: Component context forwarded to the LLM prompt.
        llm_client: Optional LLM client for multi-candidate disambiguation.

    Returns:
        dict: {"kpId": str|None, "kpType": str, "kpTitle": str,
        "kpSkill": str, "kpSkillName": str, "candidates": list}. The
        ``candidates`` list holds at most 10 rows for the HTML dropdown.
    """
    def _entry(kp_id, kp_type, ui_candidates):
        # Types missing from KP_SKILL_DEFAULTS fall back to vocab meaning.
        skill, skill_name = KP_SKILL_DEFAULTS.get(kp_type, ("vocab_meaning", "词义"))
        return {
            "kpId": kp_id,
            "kpType": kp_type,
            "kpTitle": kp_text,
            "kpSkill": skill,
            "kpSkillName": skill_name,
            "candidates": ui_candidates,
        }

    looks_like_sentence = "..." in kp_text or len(kp_text.split()) > 3
    fallback_type = "sentence" if looks_like_sentence else "vocab"

    try:
        candidates = _rough_recall(kp_text, level)
    except Exception as e:
        logger.error(f"MySQL 粗召回失败 '{kp_text}': {e}", exc_info=True)
        return _entry(None, fallback_type, [])

    if not candidates:
        logger.info(f"未找到匹配: '{kp_text}' (level={level})")
        return _entry(None, fallback_type, [])

    # Snapshot for the HTML dropdown, taken before the selection step runs.
    ui_candidates = []
    for row in candidates[:10]:
        ui_candidates.append({
            "kp_id": row["kp_id"],
            "type": row.get("type", ""),
            "title": row.get("title", ""),
            "vala_level": row.get("vala_level", ""),
            "en_desc": row.get("en_desc") or "",
        })

    matched = _precise_match(kp_text, candidates, level, context, llm_client)
    if not matched:
        return _entry(None, fallback_type, ui_candidates)
    return _entry(matched["kp_id"], matched.get("type") or fallback_type, ui_candidates)
# ============ 主接口 ============
def match_knowledge_points(knowledge_text, cType, cId, level=None, llm_client=None):
    """Parse a knowledge-point cell and match each line against ``vala_kp``.

    ``<text ...>`` / ``</text>`` wrapper tags are removed. The text is split
    into non-blank lines, and each line loses a trailing numeric suffix
    (e.g. "school 1" -> "school") before being matched by _match_single_kp.

    Args:
        knowledge_text: Raw knowledge-point text (possibly multi-line, from
            the script sheet). Non-string values yield None.
        cType: Component type identifier.
        cId: Component id (stringified in the result).
        level: Script level, e.g. "L1" or "L2".
        llm_client: LLMClient instance used for multi-candidate re-ranking.

    Returns:
        dict or None: the kpInfo payload
        ``{"pushType", "cType", "cId", "kpInfo"}``, or None when there is no
        usable knowledge-point text.
    """
    if not isinstance(knowledge_text, str):
        if knowledge_text is not None:
            # e.g. a numeric/NaN cell value; .strip() would crash on it.
            logger.warning(f"知识点文本不是字符串,已忽略: cType={cType}, cId={cId}, "
                           f"value={knowledge_text!r}")
        return None
    if not knowledge_text.strip():
        return None

    clean_text = re.sub(r'<text[^>]*>', '', knowledge_text)
    clean_text = re.sub(r'</text>', '', clean_text).strip()
    if not clean_text:
        return None

    context = f"cType={cType}, cId={cId}"
    kp_list = []
    for raw_line in clean_text.split("\n"):
        # Drop a trailing numeric suffix, e.g. "school 1" -> "school".
        kp_text = re.sub(r'\s+\d+$', '', raw_line.strip()).strip()
        if kp_text:
            kp_list.append(_match_single_kp(kp_text, level, context, llm_client))

    if not kp_list:
        return None

    logger.info(f"知识点匹配完成: cType={cType}, cId={cId}, "
                f"总数={len(kp_list)}, "
                f"已匹配={sum(1 for k in kp_list if k['kpId'] is not None)}")
    return {
        "pushType": "relationKp",
        "cType": cType,
        "cId": str(cId),
        "kpInfo": kp_list,
    }
# ============ CLI 测试 ============
if __name__ == "__main__":
    # Manual smoke test against the configured database (and, when
    # available, the LLM client). The connection is always closed on exit.
    import sys

    # Allow importing sibling modules such as llm_client from scripts/.
    sys.path.insert(0, CURRENT_PATH)
    print("=== 知识点匹配模块测试 ===\n")
    try:
        # 1. Connectivity check; abort early when the database is unreachable.
        print("1. 测试 MySQL 连接...")
        try:
            conn = _get_connection()
            with conn.cursor() as cursor:
                cursor.execute("SELECT COUNT(*) as cnt FROM vala_kp WHERE deleted_at IS NULL")
                row = cursor.fetchone()
            print(f" 连接成功!有效记录数: {row['cnt']}")
        except Exception as e:
            print(f" 连接失败: {e}")
            sys.exit(1)

        # 2. Single-word match.
        print("\n2. 测试单词匹配 'clean' (L2)...")
        result = _match_single_kp("clean", level="L2")
        print(f" 结果: kpId={result['kpId']}, kpType={result['kpType']}")

        # 3. Sentence-pattern match.
        print("\n3. 测试句型匹配 \"Let's/Let me...\"...")
        result = _match_single_kp("Let's/Let me...", level="L2")
        print(f" 结果: kpId={result['kpId']}, kpType={result['kpType']}")

        # 4. Full multi-line interface.
        print("\n4. 测试完整接口(多行知识点)...")
        test_text = "clean\nschool\nLet's go to school"
        result = match_knowledge_points(test_text, "mid_dialog_repeat", "0112001", level="L2")
        if result:
            print(f" kpInfo 条数: {len(result['kpInfo'])}")
            for kp in result["kpInfo"]:
                print(f" - {kp['kpTitle']}: kpId={kp['kpId']}, kpType={kp['kpType']}")
        else:
            print(" 无结果")

        # 5. LLM-assisted disambiguation; skipped if no LLM client is available.
        print("\n5. 测试带 LLM 精筛(多候选消歧)...")
        try:
            from llm_client import get_client
            llm = get_client()
        except Exception as e:
            print(f" 跳过: LLM 客户端不可用: {e}")
        else:
            result = _match_single_kp("clean", level=None, context="cType=mid_dialog_repeat, 教学场景:打扫教室", llm_client=llm)
            print(f" 结果: kpId={result['kpId']}, kpType={result['kpType']}")
    finally:
        # Runs on normal completion, sys.exit(1), and unexpected errors alike.
        _close_connection()
    print("\n=== 测试完成 ===")