#!/usr/bin/env python3
"""
Knowledge-point matching module.

Connects to the MySQL ``vala_test.vala_kp`` table and implements a two-stage
match: rough recall followed by LLM-based precise selection.
"""

import os
import re
import json
import logging
import threading

import pymysql

CURRENT_PATH = os.path.dirname(os.path.abspath(__file__))
SKILL_ROOT = os.path.dirname(CURRENT_PATH)
WORKSPACE_ROOT = os.path.dirname(os.path.dirname(SKILL_ROOT))

logger = logging.getLogger("kp_matcher")
# Attach a stream handler only once so re-imports do not duplicate output.
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(
        "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s"
    ))
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)


# ============ Secrets loading ============

def _load_mysql_password_from_secrets():
    """Read the test MySQL password from ``secrets.md`` in the workspace root.

    The file is searched for a ``### 测试 MySQL`` heading followed by a
    ``read_only: <password>`` line.

    Returns:
        The stripped password text, or None when the file is missing,
        unreadable, or does not contain the expected entry.
    """
    secrets_path = os.path.join(WORKSPACE_ROOT, "secrets.md")
    try:
        with open(secrets_path, "r", encoding="utf-8") as f:
            content = f.read()
    except FileNotFoundError:
        # A missing secrets file is an expected configuration; stay quiet.
        return None
    except (OSError, UnicodeDecodeError) as exc:
        # Still best-effort, but surface unexpected read problems.
        logger.warning("Failed to read %s: %s", secrets_path, exc)
        return None

    # Match the "read_only: xxx" line that follows the "### 测试 MySQL" heading.
    match = re.search(r'###\s*测试\s*MySQL\s*\n\s*read_only:\s*(.+)', content)
    if match:
        return match.group(1).strip()
    return None


def _get_mysql_password():
    """Resolve the MySQL password: environment variable first, then secrets.md.

    Returns:
        The password string, or "" (after logging a warning) when neither
        the ``MYSQL_PASSWORD`` environment variable nor secrets.md yields one.
    """
    env_pw = os.getenv("MYSQL_PASSWORD")
    if env_pw:
        return env_pw
    pw = _load_mysql_password_from_secrets()
    if pw:
        return pw
    logger.warning("未找到 MySQL 密码(环境变量 MYSQL_PASSWORD 和 secrets.md 均不可用)")
    return ""


# ============ MySQL configuration ============
# Every entry except the charset and timeouts can be overridden through the
# environment. The password is resolved once, at import time.
MYSQL_CONFIG = {
    "host": os.getenv("MYSQL_HOST", "bj-cdb-8frbdwju.sql.tencentcdb.com"),
    "port": int(os.getenv("MYSQL_PORT", "25413")),
    "user": os.getenv("MYSQL_USER", "read_only"),
    "password": _get_mysql_password(),
    "database": os.getenv("MYSQL_DB", "vala_test"),
    "charset": "utf8mb4",
    "connect_timeout": 10,
    "read_timeout": 10,
}

# ============ kpSkill default mapping ============
# kp type -> (kpSkill, kpSkillName) used when building a kpInfo entry.
KP_SKILL_DEFAULTS = {
    "vocab": ("vocab_meaning", "词义"),
    "sentence": ("sentence_meaning", "语义"),
    "grammar": ("sentence_meaning", "语义"),
    "pron": ("vocab_pron", "发音"),
}

# ============ LLM precise-selection prompt ============
KP_MATCH_SYSTEM_PROMPT = """你是一个知识点匹配专家。根据给定的知识点文本和候选列表,选出最匹配的知识点条目。

## 规则
1. 从候选列表中选出与"待匹配知识点"含义最接近的一条
2. 如果有 level 信息,优先选择同 level 的条目
3. 输出格式:仅输出被选中条目的 kp_id,不要包含其他内容
4. 如果没有任何候选匹配,输出 null"""


def _build_match_user_prompt(kp_text, candidates, level=None, context=""):
    """Build the user prompt for the LLM precise-selection step.

    Args:
        kp_text: Knowledge-point text being matched.
        candidates: Iterable of row dicts with the keys ``kp_id``, ``title``,
            ``type``, ``vala_level`` and ``en_desc``.
        level: Optional level label; added to the prompt when truthy.
        context: Optional component context; added to the prompt when truthy.

    Returns:
        The prompt as a single newline-joined string.
    """
    lines = [f"待匹配知识点: {kp_text}"]
    if level:
        lines.append(f"当前级别: {level}")
    if context:
        lines.append(f"组件上下文: {context}")
    lines.append("\n候选列表:")
    lines.extend(
        f"- kp_id={c['kp_id']}, title={c['title']}, type={c['type']}, "
        f"level={c['vala_level']}, desc={c['en_desc'] or ''}"
        for c in candidates
    )
    return "\n".join(lines)


# ============ MySQL connection ============

_thread_local = threading.local()


def _discard_connection(conn):
    """Close *conn* on a best-effort basis; a failure here is only logged."""
    try:
        conn.close()
    except Exception as exc:
        # Closing an already-broken connection may itself raise; the caller
        # is abandoning the connection either way.
        logger.debug("Ignoring error while closing MySQL connection: %s", exc)


def _get_connection():
    """Return this thread's MySQL connection, creating it when needed.

    The connection is cached on a ``threading.local`` object, so each thread
    works with its own connection. A cached connection is pinged (with
    reconnect) before reuse. If it is closed or the ping fails, it is
    released and replaced by a fresh one.

    Raises:
        Whatever ``pymysql.connect`` raises when a new connection cannot be
        established.
    """
    conn = getattr(_thread_local, 'connection', None)
    if conn is not None:
        if conn.open:
            try:
                conn.ping(reconnect=True)
                return conn
            except Exception as exc:
                logger.warning("MySQL ping failed, reconnecting: %s", exc)
        # Release the stale connection before replacing it so its socket is
        # not leaked.
        _thread_local.connection = None
        _discard_connection(conn)

    conn = pymysql.connect(**MYSQL_CONFIG, cursorclass=pymysql.cursors.DictCursor)
    _thread_local.connection = conn
    return conn


def _close_connection():
    """Close and forget the current thread's cached connection, if any."""
    conn = getattr(_thread_local, 'connection', None)
    if conn:
        _discard_connection(conn)
    _thread_local.connection = None


# ============ Rough recall ============

_KP_SELECT_SQL = (
    "SELECT kp_id, type, title, vala_level, en_desc "
    "FROM vala_kp WHERE deleted_at IS NULL"
)
# An explicit, non-backslash LIKE escape character keeps the pattern
# independent of how the server treats backslashes.
_LIKE_ESCAPE_CHAR = "|"
_TITLE_EQUALS = "title = %s"
_TITLE_LIKE = f"title LIKE %s ESCAPE '{_LIKE_ESCAPE_CHAR}'"


def _escape_like(text, escape_char=_LIKE_ESCAPE_CHAR):
    """Escape LIKE wildcards (``%``, ``_``) and the escape character itself."""
    # The escape character must be handled first so the escapes added for
    # the wildcards are not escaped a second time.
    for ch in (escape_char, "%", "_"):
        text = text.replace(ch, escape_char + ch)
    return text


def _query_candidates(cursor, title_condition, value, level, limit):
    """Run one recall query and return its rows.

    ``title_condition`` is one of the module's fixed SQL fragments; all
    variable data is passed as bound parameters.
    """
    sql = f"{_KP_SELECT_SQL} AND {title_condition}"
    params = [value]
    if level:
        sql += " AND vala_level = %s"
        params.append(level)
    sql += " LIMIT %s"
    params.append(int(limit))
    cursor.execute(sql, params)
    return cursor.fetchall()


def _rough_recall(title_text, level=None, limit=10):
    """Rough recall of candidate rows based on the ``title`` column.

    Queries are tried in order and the first non-empty result wins:
      1. exact title match, filtered by ``level`` when given
      2. exact title match without the level filter (only when level given)
      3. substring (LIKE) match, filtered by ``level`` when given
      4. substring (LIKE) match without the level filter (only when level given)

    Args:
        title_text: Text to look up; LIKE wildcards in it are matched
            literally.
        level: Optional ``vala_level`` value to prefer.
        limit: Maximum number of rows per query (default 10).

    Returns:
        A sequence of row dicts (``kp_id``, ``type``, ``title``,
        ``vala_level``, ``en_desc``); empty when nothing matched.

    NOTE(review): the queries have no ORDER BY, so which rows survive the
    LIMIT is up to the server — confirm whether a stable order is wanted.
    """
    if not title_text:
        # An empty pattern would turn the LIKE query into "match every row".
        return []

    like_pattern = f"%{_escape_like(title_text)}%"
    attempts = [(_TITLE_EQUALS, title_text, level)]
    if level:
        attempts.append((_TITLE_EQUALS, title_text, None))
    attempts.append((_TITLE_LIKE, like_pattern, level))
    if level:
        attempts.append((_TITLE_LIKE, like_pattern, None))

    conn = _get_connection()
    results = []
    with conn.cursor() as cursor:
        for condition, value, level_filter in attempts:
            results = _query_candidates(cursor, condition, value, level_filter, limit)
            if results:
                break
    return results


# ============ Precise selection ============

def _normalize_llm_kp_id(response):
    """Reduce an LLM reply to a bare kp_id token.

    Strips surrounding whitespace, quotes and backticks, plus a leading
    ``kp_id=`` / ``kp_id:`` label (the candidate list shown to the model uses
    that form, so the model may echo it back).
    """
    answer = response.strip().strip("`\"'").strip()
    label = re.match(r"kp_id\s*[=:]\s*", answer, flags=re.IGNORECASE)
    if label:
        answer = answer[label.end():].strip().strip("`\"'").strip()
    return answer


def _precise_match(kp_text, candidates, level=None, context="", llm_client=None):
    """Pick the best candidate, using the LLM only when it is really needed.

    Args:
        kp_text: Knowledge-point text being matched.
        candidates: Candidate row dicts from the rough recall; not modified.
        level: Optional level; candidates on this level are preferred.
        context: Optional context string forwarded to the LLM prompt.
        llm_client: Object exposing ``call(system_prompt, user_prompt,
            max_tokens=..., temperature=...)`` returning ``(text, usage)``.
            When None, the candidate with the smallest ``kp_id`` is chosen.

    Returns:
        dict or None: the selected candidate, or None when there are no
        candidates or the LLM answers ``null``.
    """
    if not candidates:
        return None
    if len(candidates) == 1:
        return candidates[0]

    # Prefer candidates on the requested level when any exist.
    if level:
        level_matched = [c for c in candidates if c.get("vala_level") == level]
        if level_matched:
            candidates = level_matched
        if len(candidates) == 1:
            return candidates[0]

    if llm_client is None:
        # No LLM: take the smallest kp_id without reordering the caller's list.
        return min(candidates, key=lambda c: c.get("kp_id", ""))

    user_prompt = _build_match_user_prompt(kp_text, candidates, level, context)
    try:
        response, _usage = llm_client.call(
            KP_MATCH_SYSTEM_PROMPT, user_prompt, max_tokens=64, temperature=0.0
        )
        answer = _normalize_llm_kp_id(response)
        if answer.lower() == "null":
            return None

        # Compare string forms so integer kp_id values can match the
        # textual LLM reply.
        for c in candidates:
            if str(c.get("kp_id")) == answer:
                return c

        # Fallback: first candidate.
        logger.warning(f"LLM 返回的 kp_id '{response}' 不在候选列表中,使用第一个候选")
        return candidates[0]
    except Exception as e:
        # Deliberate best-effort: an LLM failure must not fail the match.
        logger.warning(f"LLM 精筛失败: {e},使用第一个候选")
        return candidates[0]


# ============ Single knowledge-point match ============

def _skill_for(kp_type):
    """Return the ``(kpSkill, kpSkillName)`` pair for *kp_type*."""
    return KP_SKILL_DEFAULTS.get(kp_type, ("vocab_meaning", "词义"))


def _match_single_kp(kp_text, level=None, context="", llm_client=None):
    """Match one knowledge-point text to a database record.

    Returns:
        dict: {"kpId": str|None, "kpType": str, "kpTitle": str,
               "kpSkill": str, "kpSkillName": str, "candidates": list}
        ``kpId`` is None when the lookup failed or nothing matched.
    """
    # Default (unmatched) result. The type is guessed from the text shape:
    # an ellipsis or more than three words is treated as a sentence.
    is_sentence = "..." in kp_text or len(kp_text.split()) > 3
    default_type = "sentence" if is_sentence else "vocab"
    skill, skill_name = _skill_for(default_type)
    default_result = {
        "kpId": None,
        "kpType": default_type,
        "kpTitle": kp_text,
        "kpSkill": skill,
        "kpSkillName": skill_name,
        "candidates": [],
    }

    try:
        candidates = _rough_recall(kp_text, level)
    except Exception as e:
        logger.error(f"MySQL 粗召回失败 '{kp_text}': {e}", exc_info=True)
        return default_result

    if not candidates:
        logger.info(f"未找到匹配: '{kp_text}' (level={level})")
        return default_result

    # Keep the candidate list so the HTML dropdown can offer alternatives.
    candidates_for_ui = [
        {
            "kp_id": c["kp_id"],
            "type": c.get("type", ""),
            "title": c.get("title", ""),
            "vala_level": c.get("vala_level", ""),
            "en_desc": c.get("en_desc") or "",
        }
        for c in candidates[:10]
    ]

    matched = _precise_match(kp_text, candidates, level, context, llm_client)
    if not matched:
        default_result["candidates"] = candidates_for_ui
        return default_result

    # Build the kpInfo entry from the matched row.
    kp_type = matched.get("type") or default_type
    skill, skill_name = _skill_for(kp_type)
    return {
        "kpId": matched["kp_id"],
        "kpType": kp_type,
        "kpTitle": kp_text,
        "kpSkill": skill,
        "kpSkillName": skill_name,
        "candidates": candidates_for_ui,
    }


# ============ Main entry point ============

def match_knowledge_points(knowledge_text, cType, cId, level=None, llm_client=None):
    """Parse knowledge-point text and match each line against the database.

    Args:
        knowledge_text: Raw knowledge-point text (may span several lines).
        cType: Component type identifier.
        cId: Component id; emitted as a string.
        level: Script level (e.g. "L1", "L2").
        llm_client: LLM client used to disambiguate multiple candidates.

    Returns:
        dict or None: the kpInfo payload, or None when the text contains no
        knowledge points.
    """
    if not knowledge_text or not knowledge_text.strip():
        return None

    # Text cleanup.
    # NOTE(review): both patterns below look truncated (they were presumably
    # meant to strip markup tags); as written, the first removes any run of
    # "]" followed by ">" and the second is a no-op. Confirm the intended
    # patterns against the original source.
    clean_text = re.sub(r']*>', '', knowledge_text)
    clean_text = re.sub(r'', '', clean_text)
    clean_text = clean_text.strip()
    if not clean_text:
        return None

    context = f"cType={cType}, cId={cId}"
    kp_list = []
    for raw_line in clean_text.split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        # Drop a trailing numeric suffix (e.g. "school 1" -> "school").
        stripped = re.sub(r'\s+\d+$', '', line).strip()
        if not stripped:
            continue
        kp_list.append(_match_single_kp(stripped, level, context, llm_client))

    if not kp_list:
        return None

    logger.info(f"知识点匹配完成: cType={cType}, cId={cId}, "
                f"总数={len(kp_list)}, "
                f"已匹配={sum(1 for k in kp_list if k['kpId'] is not None)}")

    return {
        "pushType": "relationKp",
        "cType": cType,
        "cId": str(cId),
        "kpInfo": kp_list,
    }


# ============ CLI smoke test ============

def _run_cli_checks():
    """Run the manual smoke tests against the live database and LLM client."""
    import sys

    sys.path.insert(0, CURRENT_PATH)

    print("=== 知识点匹配模块测试 ===\n")

    try:
        # MySQL connectivity.
        print("1. 测试 MySQL 连接...")
        try:
            conn = _get_connection()
            with conn.cursor() as cursor:
                cursor.execute("SELECT COUNT(*) as cnt FROM vala_kp WHERE deleted_at IS NULL")
                row = cursor.fetchone()
                print(f" 连接成功!有效记录数: {row['cnt']}")
        except Exception as e:
            print(f" 连接失败: {e}")
            sys.exit(1)

        # Single-word match.
        print("\n2. 测试单词匹配 'clean' (L2)...")
        result = _match_single_kp("clean", level="L2")
        print(f" 结果: kpId={result['kpId']}, kpType={result['kpType']}")

        # Sentence-pattern match.
        print("\n3. 测试句型匹配 \"Let's/Let me...\"...")
        result = _match_single_kp("Let's/Let me...", level="L2")
        print(f" 结果: kpId={result['kpId']}, kpType={result['kpType']}")

        # Full interface with multi-line input.
        print("\n4. 测试完整接口(多行知识点)...")
        test_text = "clean\nschool\nLet's go to school"
        result = match_knowledge_points(test_text, "mid_dialog_repeat", "0112001", level="L2")
        if result:
            print(f" kpInfo 条数: {len(result['kpInfo'])}")
            for kp in result["kpInfo"]:
                print(f" - {kp['kpTitle']}: kpId={kp['kpId']}, kpType={kp['kpType']}")
        else:
            print(" 无结果")

        # LLM-backed disambiguation between multiple candidates.
        print("\n5. 测试带 LLM 精筛(多候选消歧)...")
        from llm_client import get_client
        llm = get_client()
        result = _match_single_kp("clean", level=None,
                                  context="cType=mid_dialog_repeat, 教学场景:打扫教室",
                                  llm_client=llm)
        print(f" 结果: kpId={result['kpId']}, kpType={result['kpType']}")
    finally:
        # Release the connection even when one of the steps raises or exits.
        _close_connection()

    print("\n=== 测试完成 ===")


if __name__ == "__main__":
    _run_cli_checks()