# model_utils.py

from typing import List, Optional
import re
import os

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

import qa_store
from loader import (
    load_curriculum,
    load_manual_qa,
    rebuild_combined_qa,
    load_glossary,
    sync_download_manual_qa,  # <--- Import it
    sync_download_cache,      # <--- Add this import
    sync_upload_cache,        # <--- Add this import
    CACHE_PATH                # <--- Add this import
)

# -----------------------------
# Base chat model
# -----------------------------
MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"
EMBED_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
CACHE_FILE = os.path.join(BASE_DIR, "data", "cached_embeddings.pt")

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Use float16 on GPU to save memory, float32 on CPU
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=dtype)
model.to(device)
model.eval()

embed_model = SentenceTransformer(EMBED_MODEL_NAME)
embed_model = embed_model.to(device)

# Number of textbook entries to include in the RAG context
MAX_CONTEXT_ENTRIES = 4

# -----------------------------
# Embedding builders
# -----------------------------

# 👇👇👇 ADD THIS NEW FUNCTION 👇👇👇
def admin_force_rebuild_cache() -> str:
    """
    Forcibly re-calculate all embeddings and upload the cache to the cloud.
    Triggered by the Teacher Panel button.
    """
    status_msg = []

    # 1. Compute Textbook
    print("[ADMIN] Rebuilding Textbook Embeddings...")
    texts = []
    for e in qa_store.ENTRIES:
        chapter = e.get("chapter_title", "") or ""
        section = e.get("section_title", "") or ""
        text = e.get("text", "") or ""
        texts.append(f"{chapter}\n{section}\n{text}")

    if texts:
        qa_store.TEXT_EMBEDDINGS = embed_model.encode(texts, convert_to_tensor=True)
        status_msg.append(f"✅ Textbook ({len(texts)})")

    # 2. Compute Glossary
    print("[ADMIN] Rebuilding Glossary Embeddings...")
    gloss_texts = [f"{i.get('term')} :: {i.get('definition')}" for i in qa_store.GLOSSARY]
    if gloss_texts:
        qa_store.GLOSSARY_EMBEDDINGS = embed_model.encode(
            gloss_texts,
            convert_to_numpy=True,
            normalize_embeddings=True
        )
        status_msg.append(f"✅ Glossary ({len(gloss_texts)})")

    # 3. Save to Disk
    print("[ADMIN] Saving to disk...")
    torch.save({
        "textbook": qa_store.TEXT_EMBEDDINGS,
        "glossary": qa_store.GLOSSARY_EMBEDDINGS
    }, CACHE_PATH)

    # 4. Upload to Cloud
    upload_status = sync_upload_cache()

    return f"Rebuild Complete: {', '.join(status_msg)} | {upload_status}"


def _build_entry_embeddings() -> None:
    """
    Load pre-computed embeddings if available, otherwise build them.
    """
    if not getattr(qa_store, "ENTRIES", None):
        qa_store.TEXT_EMBEDDINGS = None
        return

    # 1. Try Loading from Cache
    if os.path.exists(CACHE_FILE):
        try:
            print(f"[INFO] Loading cached embeddings from {CACHE_FILE}...")
            cache = torch.load(CACHE_FILE, map_location=device)
            if "textbook" in cache and cache["textbook"] is not None:
                # Validate that the cache size matches the current data
                if len(cache["textbook"]) == len(qa_store.ENTRIES):
                    qa_store.TEXT_EMBEDDINGS = cache["textbook"].to(device)
                    print("[INFO] Textbook embeddings loaded successfully.")
                    return
                else:
                    print("[WARN] Cache size mismatch (data changed?). Re-computing...")
        except Exception as e:
            print(f"[WARN] Failed to load cache: {e}")

    # 2. Fallback: compute from scratch (the old slow way)
    print("[INFO] Computing textbook embeddings from scratch...")
    texts: List[str] = []
    for e in qa_store.ENTRIES:
        chapter = e.get("chapter_title", "") or e.get("chapter", "") or ""
        section = e.get("section_title", "") or e.get("section", "") or ""
        text = e.get("text", "") or ""
        combined = f"{chapter}\n{section}\n{text}"
        texts.append(combined)

    qa_store.TEXT_EMBEDDINGS = embed_model.encode(
        texts,
        convert_to_tensor=True,
        show_progress_bar=False,
    )


def _build_glossary_embeddings() -> None:
    """Create embeddings for glossary terms + definitions."""
    if not getattr(qa_store, "GLOSSARY", None):
        qa_store.GLOSSARY_EMBEDDINGS = None
        print("[INFO] No glossary terms to embed.")
        return

    # Embed term + definition together
    texts = [
        f"{item.get('term', '')} :: {item.get('definition', '')}"
        for item in qa_store.GLOSSARY
    ]

    embeddings = embed_model.encode(
        texts,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    qa_store.GLOSSARY_EMBEDDINGS = embeddings
    print(f"[INFO] Built glossary embeddings for {len(texts)} terms.")


# -----------------------------
# Load data once at import time
# -----------------------------
sync_download_manual_qa()
sync_download_cache()  # <--- Add this line!
load_curriculum()
load_manual_qa()
load_glossary()
rebuild_combined_qa()

_build_entry_embeddings()
_build_glossary_embeddings()

# -----------------------------
# System prompt (Natural Science)
# -----------------------------
# Lao system prompt. Rough English meaning: "You are a natural-science helper
# for students in grades M.1-M.4. Answer only in Lao, briefly in 2-3 sentences
# and easy to understand. Rely only on the reference information below. If the
# information is insufficient or unclear, say you are not sure."
SYSTEM_PROMPT = (
    "ທ່ານແມ່ນຜູ້ຊ່ວຍເຫຼືອດ້ານວິທະຍາສາດທໍາມະຊາດ "
    "ສໍາລັບນັກຮຽນຊັ້ນ ມ.1-ມ.4. "
    "ຕອບແຕ່ພາສາລາວ ໃຫ້ຕອບສັ້ນໆ 2–3 ປະໂຫຍກ ແລະເຂົ້າໃຈງ່າຍ. "
    "ໃຫ້ອີງຈາກຂໍ້ມູນອ້າງອີງຂ້າງລຸ່ມນີ້ເທົ່ານັ້ນ. "
    "ຖ້າຂໍ້ມູນບໍ່ພຽງພໍ ຫຼືບໍ່ຊັດເຈນ ໃຫ້ບອກວ່າບໍ່ແນ່ໃຈ."
)

# -----------------------------
# Helper: history formatting
# -----------------------------
def _format_history(history: Optional[List]) -> str:
    """
    Convert the last few chat turns into a Lao conversation snippet
    to give the model context for follow-up questions.

    Gradio history format: [[user_msg, bot_msg], [user_msg, bot_msg], ...]
    """
    if not history:
        return ""

    # keep only the last 3 turns to avoid very long prompts
    recent = history[-3:]
    lines: List[str] = []
    for turn in recent:
        if not isinstance(turn, (list, tuple)) or len(turn) != 2:
            continue
        user_msg, bot_msg = turn
        lines.append(f"ນັກຮຽນ: {user_msg}")   # "Student: ..."
        lines.append(f"ອາຈານ AI: {bot_msg}")  # "AI teacher: ..."

    if not lines:
        return ""

    joined = "\n".join(lines) + "\n\n"
    return joined


# -----------------------------
# RAG: retrieve textbook context
# -----------------------------
def retrieve_context(question: str, max_entries: int = MAX_CONTEXT_ENTRIES) -> str:
    """
    Embedding-based retrieval over textbook entries.
    Falls back to concatenated raw knowledge if embeddings are missing.
    """
""" if not getattr(qa_store, "ENTRIES", None): # Fallback: raw knowledge (if available) or empty string return getattr(qa_store, "RAW_KNOWLEDGE", "") if qa_store.TEXT_EMBEDDINGS is None: top_entries = qa_store.ENTRIES[:max_entries] else: # 1) Encode the question q_vec = embed_model.encode( question, convert_to_tensor=True, show_progress_bar=False, ) # 2) Cosine similarity with all entry embeddings sims = cos_sim(q_vec, qa_store.TEXT_EMBEDDINGS)[0] # shape [N] # 3) Take top-k top_indices = torch.topk(sims, k=min(max_entries, sims.shape[0])).indices top_entries = [qa_store.ENTRIES[i] for i in top_indices.tolist()] # Build context string for the prompt context_blocks: List[str] = [] for e in top_entries: header = ( f"[เบŠเบฑเป‰เบ™ {e.get('grade','')}, " f"เปœเปˆเบงเบ {e.get('unit','')}, " f"เบšเบปเบ” {e.get('chapter_title','')}, " f"เบซเบปเบงเบ‚เปเป‰ {e.get('section_title','')}]" ) context_blocks.append(f"{header}\n{e.get('text','')}") return "\n\n".join(context_blocks) # ----------------------------- # Glossary-based answering # ----------------------------- def normalize_lao_text(text: str) -> str: """ Clean Lao text for accurate matching. Removes punctuation and extra spaces. """ if not text: return "" # 1. Lowercase text = text.lower().strip() # 2. Remove punctuation (Using the safe single-quote format) text = re.sub(r'[?.!,;ึ‰:\'\""โ€œโ€โ€˜โ€™]', "", text) # 3. Collapse multiple spaces into one (THIS WAS MISSING) text = re.sub(r"\s+", " ", text) return text.strip() def answer_from_glossary(message: str) -> Optional[str]: """ Try to answer using the glossary index. Tier 1: Exact/Substring match (Sorted by Length to fix overlap bugs). Tier 2: Vector embedding match (Fallback). """ if not getattr(qa_store, "GLOSSARY", None): return None norm_msg = normalize_lao_text(message) # --- FIX START: Sort by Length + Exact Match --- # 1. Sort glossary terms by length (Longest first) # This ensures we match "เบ™เบฑเบเบงเบดเบ—เบฐเบเบฒเบชเบฒเบ”" (14 chars) BEFORE "เบงเบดเบ—เบฐเบเบฒเบชเบฒเบ”" (11 chars) sorted_glossary = sorted( qa_store.GLOSSARY, key=lambda x: len(normalize_lao_text(x.get("term", ""))), reverse=True ) for item in sorted_glossary: term_raw = item.get("term", "") norm_term = normalize_lao_text(term_raw) if not norm_term: continue # Condition A: EXACT Match (Perfect precision) # Example: User types "เบ™เบฑเบเบงเบดเบ—เบฐเบเบฒเบชเบฒเบ”" is_exact = (norm_msg == norm_term) # Condition B: Substring Match (High precision for questions) # Example: User types "เบ™เบฑเบเบงเบดเบ—เบฐเบเบฒเบชเบฒเบ” เปเบกเปˆเบ™เบซเบเบฑเบ‡" # We enforce a length check so "Science" doesn't match a huge paragraph about Pollution. 
        is_substring = (norm_term in norm_msg) and (len(norm_msg) < len(norm_term) + 20)

        if is_exact or is_substring:
            definition = item.get("definition", "").strip()
            example = item.get("example", "").strip()

            # Return the result immediately once the longest match is found
            if example:
                return f"{definition} ຕົວຢ່າງ: {example}"  # "ຕົວຢ່າງ" = "Example"
            return definition
    # --- FIX END ---

    # If no text match, proceed to Vector Similarity (Tier 2)
    if qa_store.GLOSSARY_EMBEDDINGS is None:
        return None

    q_emb = embed_model.encode(
        [message],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )[0]

    sims = np.dot(qa_store.GLOSSARY_EMBEDDINGS, q_emb)
    best_idx = int(np.argmax(sims))
    best_sim = float(sims[best_idx])

    # Threshold 0.65 to prevent weak matches
    if best_sim < 0.65:
        return None

    item = qa_store.GLOSSARY[best_idx]
    definition = item.get("definition", "").strip()
    example = item.get("example", "").strip()
    if example:
        return f"{definition} ຕົວຢ່າງ: {example}"
    return definition


# -----------------------------
# Prompt + LLM generation
# -----------------------------
def build_prompt(question: str, history: Optional[List] = None) -> str:
    context = retrieve_context(question, max_entries=MAX_CONTEXT_ENTRIES)
    history_block = _format_history(history)

    # Prompt layout (labels in Lao): reference information, question, then an
    # "answer in Lao" cue for the model.
    return f"""{SYSTEM_PROMPT}

{history_block}ຂໍ້ມູນອ້າງອີງ:
{context}

ຄຳຖາມ: {question}

ຄຳຕອບດ້ວຍພາສາລາວ:"""


def generate_answer(question: str, history: Optional[List] = None) -> str:
    prompt = build_prompt(question, history)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=160,
            do_sample=False,
        )

    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    # Enforce 2–3 sentence answers for students
    sentences = re.split(r"(?<=[\.?!…])\s+", answer)
    short_answer = " ".join(sentences[:3]).strip()

    return short_answer if short_answer else answer


# -----------------------------
# QA lookup (exact + fuzzy)
# -----------------------------
def answer_from_qa(question: str) -> Optional[str]:
    """
    1) Exact match in QA_INDEX
    2) Fuzzy match via word overlap with ALL_QA_KNOWLEDGE
    """
    norm_q = qa_store.normalize_question(question)
    if not norm_q:
        return None

    # Exact match
    if norm_q in qa_store.QA_INDEX:
        return qa_store.QA_INDEX[norm_q]

    # Fuzzy match
    q_terms = [t for t in norm_q.split(" ") if len(t) > 1]
    if not q_terms:
        return None

    best_score = 0
    best_answer: Optional[str] = None

    for item in qa_store.ALL_QA_KNOWLEDGE:
        stored_terms = [t for t in item["norm_q"].split(" ") if len(t) > 1]
        overlap = sum(1 for t in q_terms if t in stored_terms)
        if overlap > best_score:
            best_score = overlap
            best_answer = item["a"]

    # require at least 2 overlapping words to accept a fuzzy match
    if best_score >= 2 and best_answer is not None:
        # optional: log when a fuzzy match is used
        print(f"[FUZZY MATCH] score={best_score} -> {best_answer[:50]!r}")
        return best_answer

    return None


# -----------------------------
# Main chatbot entry
# -----------------------------
def laos_science_bot(message: str, history: List) -> str:
    """
    Main chatbot function for the Student tab (Gradio ChatInterface).
    """
    if not message.strip():
        return "ກະລຸນາພິມຄໍາຖາມກ່ອນ."  # "Please type a question first."
    # 0) Try the glossary first for key concepts
    gloss = answer_from_glossary(message)
    if gloss:
        return gloss

    # 1) Then try exact / fuzzy Q&A lookup
    direct = answer_from_qa(message)
    if direct:
        return direct

    # 2) Fall back to the LLM + retrieved context
    try:
        answer = generate_answer(message, history)
    except Exception as e:  # noqa: BLE001
        return f"ລະບົບມີບັນຫາ: {e}"  # "The system has a problem: ..."

    return answer
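

# -----------------------------
# Local smoke test (illustrative)
# -----------------------------
# Minimal usage sketch, not part of the app: running this module directly sends
# one example Lao question through laos_science_bot with an empty history.
# The question text below is a hypothetical example ("What is a scientist?");
# note that importing this module already loads the chat and embedding models,
# so a direct run is slow and downloads weights on first use.
if __name__ == "__main__":
    demo_question = "ນັກວິທະຍາສາດ ແມ່ນຫຍັງ"  # example only: "What is a scientist?"
    print(laos_science_bot(demo_question, history=[]))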