#!/usr/bin/env python3
"""
Knowledge-point matching module.

Connects to the MySQL table ``vala_test.vala_kp`` and matches knowledge-point
text in two stages: a coarse recall on the ``title`` column, followed by an
optional LLM-based fine selection among the recalled candidates.
"""

import os
import re
import json
import logging
import threading

import pymysql

# This file's directory; SKILL_ROOT is its parent and WORKSPACE_ROOT is two
# directory levels above SKILL_ROOT.
CURRENT_PATH = os.path.dirname(os.path.abspath(__file__))
SKILL_ROOT = os.path.dirname(CURRENT_PATH)
WORKSPACE_ROOT = os.path.dirname(os.path.dirname(SKILL_ROOT))

_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s"

logger = logging.getLogger("kp_matcher")
if not logger.handlers:
    # Only configure when no handler is attached yet, so executing this code
    # again (e.g. a module reload) does not stack duplicate handlers.
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(_LOG_FORMAT))
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)


# ============ Secrets loading ============

# Matches a "read_only: <password>" line right below a "### 测试 MySQL" heading;
# group 1 captures the password text.
_SECRETS_MYSQL_PW_RE = re.compile(r'###\s*测试\s*MySQL\s*\n\s*read_only:\s*(.+)')


def _load_mysql_password_from_secrets(secrets_path=None):
    """Read the test-MySQL password from the workspace ``secrets.md``.

    Args:
        secrets_path: Path of the secrets file. Defaults to
            ``<WORKSPACE_ROOT>/secrets.md``.

    Returns:
        str or None: The stripped password, or None when the file does not
        exist, cannot be read/decoded (a warning is logged), or contains no
        matching entry.
    """
    if secrets_path is None:
        secrets_path = os.path.join(WORKSPACE_ROOT, "secrets.md")
    try:
        with open(secrets_path, "r", encoding="utf-8") as f:
            content = f.read()
    except FileNotFoundError:
        # A missing secrets file is a normal configuration; stay quiet.
        return None
    except (OSError, UnicodeDecodeError) as e:
        # Present but unreadable: surface it instead of silently ignoring it.
        logger.warning(f"读取 secrets.md 失败 ({secrets_path}): {e}")
        return None

    match = _SECRETS_MYSQL_PW_RE.search(content)
    return match.group(1).strip() if match else None


def _get_mysql_password():
    """Resolve the MySQL password.

    Precedence: the ``MYSQL_PASSWORD`` environment variable (when non-empty),
    then the entry read from ``secrets.md``. If neither yields a password, a
    warning is logged and an empty string is returned.
    """
    password = os.getenv("MYSQL_PASSWORD") or _load_mysql_password_from_secrets()
    if not password:
        logger.warning("未找到 MySQL 密码(环境变量 MYSQL_PASSWORD 和 secrets.md 均不可用)")
        return ""
    return password


# ============ MySQL configuration ============

# Keyword arguments for pymysql.connect(). host, port, user and database can be
# overridden via MYSQL_HOST / MYSQL_PORT / MYSQL_USER / MYSQL_DB; the password
# is resolved once, when this module is imported.
MYSQL_CONFIG = dict(
    host=os.getenv("MYSQL_HOST", "bj-cdb-8frbdwju.sql.tencentcdb.com"),
    port=int(os.getenv("MYSQL_PORT", "25413")),
    user=os.getenv("MYSQL_USER", "read_only"),
    password=_get_mysql_password(),
    database=os.getenv("MYSQL_DB", "vala_test"),
    charset="utf8mb4",
    connect_timeout=10,
    read_timeout=10,
)

# ============ Default kpSkill per kp type ============

# kp type -> (kpSkill, kpSkillName); "grammar" shares the "sentence" pair.
KP_SKILL_DEFAULTS = dict(
    vocab=("vocab_meaning", "词义"),
    sentence=("sentence_meaning", "语义"),
    grammar=("sentence_meaning", "语义"),
    pron=("vocab_pron", "发音"),
)

# ============ LLM fine-selection prompt ============

# System prompt for the fine-selection step: the model must reply with only
# the chosen kp_id, or "null" when no candidate fits.
KP_MATCH_SYSTEM_PROMPT = "\n".join([
    "你是一个知识点匹配专家。根据给定的知识点文本和候选列表,选出最匹配的知识点条目。",
    "",
    "## 规则",
    '1. 从候选列表中选出与"待匹配知识点"含义最接近的一条',
    "2. 如果有 level 信息,优先选择同 level 的条目",
    "3. 输出格式:仅输出被选中条目的 kp_id,不要包含其他内容",
    "4. 如果没有任何候选匹配,输出 null",
])


def _build_match_user_prompt(kp_text, candidates, level=None, context=""):
    """Render the user prompt for the LLM fine-selection step.

    The prompt starts with the text to match, then the level and the component
    context when they are truthy, then a "候选列表:" section with one
    ``- kp_id=..., title=..., type=..., level=..., desc=...`` line per
    candidate (a falsy ``en_desc`` is rendered as an empty string).

    Args:
        kp_text: Knowledge-point text to be matched.
        candidates: Rows with keys kp_id, title, type, vala_level, en_desc.
        level: Optional current level.
        context: Optional component context string.

    Returns:
        str: The newline-joined prompt.
    """
    header = [f"待匹配知识点: {kp_text}"]
    if level:
        header.append(f"当前级别: {level}")
    if context:
        header.append(f"组件上下文: {context}")
    header.append("\n候选列表:")

    candidate_lines = [
        f"- kp_id={c['kp_id']}, title={c['title']}, type={c['type']}, "
        f"level={c['vala_level']}, desc={c['en_desc'] or ''}"
        for c in candidates
    ]
    return "\n".join(header + candidate_lines)


# ============ MySQL connection ============

# Per-thread slot holding each thread's cached connection.
_thread_local = threading.local()


def _get_connection():
    """Return this thread's MySQL connection, (re)opening it when needed.

    A cached, open connection is health-checked with ``ping(reconnect=True)``
    and reused if that succeeds. If the ping raises, the failure is logged, the
    stale connection is closed on a best-effort basis, and a new connection is
    opened from ``MYSQL_CONFIG`` with a ``DictCursor`` (rows come back as dicts).

    Returns:
        The pymysql connection cached for the calling thread.

    Raises:
        Whatever ``pymysql.connect`` raises if a new connection cannot be opened.
    """
    conn = getattr(_thread_local, 'connection', None)
    if conn and conn.open:
        try:
            conn.ping(reconnect=True)
            return conn
        except Exception as e:
            logger.warning(f"MySQL 连接检测失败,重新建立连接: {e}")
            try:
                conn.close()
            except Exception:
                pass  # best effort: the connection is already unusable

    # Forget the stale reference first so a failing connect() below does not
    # leave a dead connection cached for this thread.
    _thread_local.connection = None
    conn = pymysql.connect(**MYSQL_CONFIG, cursorclass=pymysql.cursors.DictCursor)
    _thread_local.connection = conn
    return conn


def _close_connection():
    """Close the calling thread's cached MySQL connection, if there is one.

    Errors raised by ``close()`` are ignored. The thread-local slot is then
    cleared, so the next ``_get_connection()`` call opens a new connection.
    """
    cached = getattr(_thread_local, 'connection', None)
    if not cached:
        return
    try:
        cached.close()
    except Exception:
        pass  # best effort: errors while closing are deliberately ignored
    _thread_local.connection = None


# ============ Coarse recall ============

_KP_SELECT_SQL = (
    "SELECT kp_id, type, title, vala_level, en_desc "
    "FROM vala_kp WHERE deleted_at IS NULL"
)
_RECALL_LIMIT = 10


def _escape_like(text):
    """Escape backslash, percent and underscore so *text* matches literally in LIKE.

    Backslash is MySQL's default LIKE escape character, so it is escaped first.
    """
    return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


def _query_by_title(cursor, operator, value, level=None):
    """Run one title lookup (``=`` or ``LIKE``) and return the fetched rows.

    Rows are ordered by ``kp_id`` so the LIMIT keeps a deterministic subset.
    """
    if operator not in ("=", "LIKE"):
        raise ValueError(f"unsupported title operator: {operator!r}")
    sql = f"{_KP_SELECT_SQL} AND title {operator} %s"
    params = [value]
    if level:
        sql += " AND vala_level = %s"
        params.append(level)
    sql += f" ORDER BY kp_id LIMIT {_RECALL_LIMIT}"
    cursor.execute(sql, params)
    return cursor.fetchall()


def _rough_recall(title_text, level=None):
    """Coarse recall of ``vala_kp`` rows by title.

    Strategy (the first non-empty result wins):
      1. exact ``title = text``, filtered by ``vala_level`` when *level* is given;
      2. the same exact match without the level filter (only if *level* given);
      3. ``title LIKE %text%`` (wildcards in *text* are matched literally),
         filtered by level when given;
      4. the same LIKE match without the level filter (only if *level* given).
    Each query returns at most 10 rows, ordered by ``kp_id``.

    Returns:
        The fetched candidate rows, or an empty list when nothing matches or
        *title_text* is empty/blank (a blank LIKE would match every row).
    """
    if not title_text or not title_text.strip():
        return []

    lookups = (("=", title_text), ("LIKE", f"%{_escape_like(title_text)}%"))
    level_filters = (level, None) if level else (None,)

    conn = _get_connection()
    with conn.cursor() as cursor:
        for operator, value in lookups:
            for level_filter in level_filters:
                rows = _query_by_title(cursor, operator, value, level_filter)
                if rows:
                    return rows
    return []


# ============ Fine selection ============

def _normalize_llm_kp_id(text):
    """Reduce a raw LLM reply to a bare kp_id string.

    Strips whitespace, surrounding quotes/backticks and a leading ``kp_id``
    label (e.g. ``kp_id=42`` or ``kp_id: '42'`` -> ``42``).
    """
    cleaned = str(text).strip().strip("`\"'").strip()
    if cleaned.lower().startswith("kp_id"):
        cleaned = cleaned[len("kp_id"):].lstrip(" =:").strip("`\"'").strip()
    return cleaned


def _precise_match(kp_text, candidates, level=None, context="", llm_client=None):
    """Pick the single best candidate for *kp_text*.

    Selection order:
      1. no candidates -> None; exactly one -> that candidate;
      2. if *level* is given and some candidates have that ``vala_level``,
         only those are kept (a single survivor is returned directly);
      3. without an ``llm_client``, the candidate with the smallest ``kp_id``;
      4. otherwise the LLM picks a kp_id. A ``null`` answer yields None; an
         unknown id or any LLM error falls back to the first candidate.

    The caller's ``candidates`` list is never modified.

    Args:
        kp_text: Knowledge-point text being matched.
        candidates: Candidate rows (dicts with at least ``kp_id``).
        level: Optional level used to prefer same-level candidates.
        context: Component context passed into the LLM prompt.
        llm_client: Object exposing ``call(system, user, max_tokens=,
            temperature=)`` returning ``(text, usage)``; None disables the LLM.

    Returns:
        dict or None: The selected candidate.
    """
    if not candidates:
        return None
    if len(candidates) == 1:
        return candidates[0]

    if level:
        level_matched = [c for c in candidates if c.get("vala_level") == level]
        if level_matched:
            candidates = level_matched
        if len(candidates) == 1:
            return candidates[0]

    if llm_client is None:
        # Smallest kp_id wins. min() instead of list.sort() so the caller's
        # list is not reordered in place.
        return min(candidates, key=lambda x: x.get("kp_id", ""))

    user_prompt = _build_match_user_prompt(kp_text, candidates, level, context)
    try:
        response, _usage = llm_client.call(
            KP_MATCH_SYSTEM_PROMPT, user_prompt, max_tokens=64, temperature=0.0
        )
        response = response.strip()
        chosen_id = _normalize_llm_kp_id(response)
        if chosen_id.lower() == "null":
            return None
        # Compare as strings: the LLM always answers with text, while kp_id
        # may come back from the database as a non-string value.
        for c in candidates:
            if str(c["kp_id"]) == chosen_id:
                return c
        logger.warning(f"LLM 返回的 kp_id '{response}' 不在候选列表中,使用第一个候选")
        return candidates[0]
    except Exception as e:
        logger.warning(f"LLM 精筛失败: {e},使用第一个候选")
        return candidates[0]


# ============ Single knowledge-point matching ============

def _match_single_kp(kp_text, level=None, context="", llm_client=None):
    """Match one knowledge-point text against ``vala_kp`` records.

    The text is recalled coarsely via ``_rough_recall`` and narrowed to one
    record via ``_precise_match``. When nothing is matched (recall raises,
    finds nothing, or the fine selection returns no candidate), the entry has
    ``kpId=None`` and a type guessed from the text: ``"sentence"`` if it
    contains ``"..."`` or has more than three whitespace-separated words,
    otherwise ``"vocab"``.

    Args:
        kp_text: Knowledge-point text to match.
        level: Optional script level (e.g. "L2") used to prefer same-level rows.
        context: Component context forwarded to the LLM prompt.
        llm_client: Optional LLM client used to disambiguate several candidates.

    Returns:
        dict: ``{"kpId", "kpType", "kpTitle", "kpSkill", "kpSkillName",
        "candidates"}``. ``kpTitle`` is always the input text; ``candidates``
        holds up to 10 recalled rows for display (empty when recall failed or
        found nothing).
    """
    looks_like_sentence = "..." in kp_text or len(kp_text.split()) > 3
    fallback_type = "sentence" if looks_like_sentence else "vocab"

    def build_entry(kp_id, kp_type, ui_candidates):
        # Assemble one kpInfo entry; the skill pair is derived from the type.
        skill, skill_name = KP_SKILL_DEFAULTS.get(kp_type, ("vocab_meaning", "词义"))
        return {
            "kpId": kp_id,
            "kpType": kp_type,
            "kpTitle": kp_text,
            "kpSkill": skill,
            "kpSkillName": skill_name,
            "candidates": ui_candidates,
        }

    try:
        recalled = _rough_recall(kp_text, level)
    except Exception as e:
        logger.error(f"MySQL 粗召回失败 '{kp_text}': {e}", exc_info=True)
        return build_entry(None, fallback_type, [])

    if not recalled:
        logger.info(f"未找到匹配: '{kp_text}' (level={level})")
        return build_entry(None, fallback_type, [])

    # Snapshot of the candidates for the HTML drop-down, taken before the
    # fine-selection step runs.
    ui_candidates = [
        {
            "kp_id": row["kp_id"],
            "type": row.get("type", ""),
            "title": row.get("title", ""),
            "vala_level": row.get("vala_level", ""),
            "en_desc": row.get("en_desc") or "",
        }
        for row in recalled[:10]
    ]

    best = _precise_match(kp_text, recalled, level, context, llm_client)
    if not best:
        return build_entry(None, fallback_type, ui_candidates)
    return build_entry(best["kp_id"], best.get("type") or fallback_type, ui_candidates)


# ============ Public entry point ============

def _split_knowledge_lines(knowledge_text):
    """Split a raw knowledge-point cell into individual knowledge-point texts.

    Removes ``<text ...>`` / ``</text>`` tags, drops blank lines, strips each
    line, and removes a trailing whitespace-separated number
    (``"school 1"`` -> ``"school"``).
    """
    clean_text = re.sub(r'<text[^>]*>', '', knowledge_text)
    clean_text = re.sub(r'</text>', '', clean_text)

    items = []
    for raw_line in clean_text.split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        stripped = re.sub(r'\s+\d+$', '', line).strip()
        if stripped:
            items.append(stripped)
    return items


def match_knowledge_points(knowledge_text, cType, cId, level=None, llm_client=None):
    """Parse knowledge-point text and match each line against the database.

    Args:
        knowledge_text: Raw knowledge-point text, possibly multi-line, taken
            from the script sheet.
        cType: Component type identifier.
        cId: Component ID (stringified in the result).
        level: Script level such as "L1" or "L2".
        llm_client: LLM client used to disambiguate multiple candidates.

    Returns:
        dict or None: ``{"pushType": "relationKp", "cType", "cId", "kpInfo"}``,
        or None when the input holds no knowledge points. Input that is not a
        string is treated the same as blank input.
    """
    # Sheet cells may not be strings (e.g. None or a number); calling .strip()
    # on them would raise, so treat them like an empty cell.
    if not isinstance(knowledge_text, str) or not knowledge_text.strip():
        return None

    kp_texts = _split_knowledge_lines(knowledge_text)
    if not kp_texts:
        return None

    context = f"cType={cType}, cId={cId}"
    kp_list = [_match_single_kp(text, level, context, llm_client) for text in kp_texts]

    matched_count = sum(1 for k in kp_list if k['kpId'] is not None)
    logger.info(f"知识点匹配完成: cType={cType}, cId={cId}, "
                f"总数={len(kp_list)}, "
                f"已匹配={matched_count}")

    return {
        "pushType": "relationKp",
        "cType": cType,
        "cId": str(cId),
        "kpInfo": kp_list,
    }


# ============ CLI smoke test ============

if __name__ == "__main__":
    import sys

    # Make sibling modules (such as llm_client) importable when run as a script.
    sys.path.insert(0, CURRENT_PATH)

    print("=== 知识点匹配模块测试 ===\n")

    # Everything runs under try/finally so the thread's MySQL connection is
    # closed even when a step raises or sys.exit() is called.
    try:
        # 1. MySQL connectivity
        print("1. 测试 MySQL 连接...")
        try:
            conn = _get_connection()
            with conn.cursor() as cursor:
                cursor.execute("SELECT COUNT(*) as cnt FROM vala_kp WHERE deleted_at IS NULL")
                row = cursor.fetchone()
                print(f" 连接成功!有效记录数: {row['cnt']}")
        except Exception as e:
            print(f" 连接失败: {e}")
            sys.exit(1)

        # 2. Single-word match
        print("\n2. 测试单词匹配 'clean' (L2)...")
        result = _match_single_kp("clean", level="L2")
        print(f" 结果: kpId={result['kpId']}, kpType={result['kpType']}")

        # 3. Sentence-pattern match
        print("\n3. 测试句型匹配 \"Let's/Let me...\"...")
        result = _match_single_kp("Let's/Let me...", level="L2")
        print(f" 结果: kpId={result['kpId']}, kpType={result['kpType']}")

        # 4. Full entry point with multi-line input
        print("\n4. 测试完整接口(多行知识点)...")
        test_text = "clean\nschool\nLet's go to school"
        result = match_knowledge_points(test_text, "mid_dialog_repeat", "0112001", level="L2")
        if result:
            print(f" kpInfo 条数: {len(result['kpInfo'])}")
            for kp in result["kpInfo"]:
                print(f" - {kp['kpTitle']}: kpId={kp['kpId']}, kpType={kp['kpType']}")
        else:
            print(" 无结果")

        # 5. LLM-assisted disambiguation (skipped if llm_client is unavailable)
        print("\n5. 测试带 LLM 精筛(多候选消歧)...")
        try:
            from llm_client import get_client
        except ImportError as e:
            print(f" 跳过: 无法导入 llm_client ({e})")
        else:
            llm = get_client()
            result = _match_single_kp("clean", level=None, context="cType=mid_dialog_repeat, 教学场景:打扫教室", llm_client=llm)
            print(f" 结果: kpId={result['kpId']}, kpType={result['kpType']}")
    finally:
        _close_connection()

    print("\n=== 测试完成 ===")