Update model_utils.py
model_utils.py  CHANGED  (+23 -4)
@@ -2,6 +2,7 @@
 from typing import List, Optional
 import re
 
+import os
 import numpy as np
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -23,6 +24,9 @@ from loader import (
 MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"
 EMBED_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
 
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+CACHE_FILE = os.path.join(BASE_DIR, "data", "cached_embeddings.pt")
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
@@ -44,15 +48,30 @@ MAX_CONTEXT_ENTRIES = 4
 # -----------------------------
 def _build_entry_embeddings() -> None:
     """
-
-    and store them in qa_store.TEXT_EMBEDDINGS.
-
-    Call this after loading / reloading curriculum.
+    Load pre-computed embeddings if available, otherwise build them.
     """
     if not getattr(qa_store, "ENTRIES", None):
         qa_store.TEXT_EMBEDDINGS = None
         return
 
+    # 1. Try loading from the cache.
+    if os.path.exists(CACHE_FILE):
+        try:
+            print(f"[INFO] Loading cached embeddings from {CACHE_FILE}...")
+            cache = torch.load(CACHE_FILE, map_location=device)
+            if "textbook" in cache and cache["textbook"] is not None:
+                # Validate size matches
+                if len(cache["textbook"]) == len(qa_store.ENTRIES):
+                    qa_store.TEXT_EMBEDDINGS = cache["textbook"].to(device)
+                    print("[INFO] Textbook embeddings loaded successfully.")
+                    return
+                else:
+                    print("[WARN] Cache size mismatch (data changed?). Re-computing...")
+        except Exception as e:
+            print(f"[WARN] Failed to load cache: {e}")
+
+    # 2. Fallback: compute from scratch (the old slow way).
+    print("[INFO] Computing textbook embeddings from scratch...")
     texts: List[str] = []
     for e in qa_store.ENTRIES:
         chapter = e.get("chapter_title", "") or e.get("chapter", "") or ""
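
Note: this commit only adds the read side of the cache; nothing in the diff writes data/cached_embeddings.pt. The file is assumed to be produced ahead of time, e.g. by a one-off precompute step or committed to the repo under data/. A minimal sketch of that write side, assuming the dict layout the loader expects ({"textbook": tensor}); the surrounding script is hypothetical, not part of this diff:

    import os
    import torch
    # ... build qa_store.TEXT_EMBEDDINGS the slow way first ...
    os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)  # ensure data/ exists
    # Save on CPU so the cache also loads on machines without a GPU.
    torch.save({"textbook": qa_store.TEXT_EMBEDDINGS.cpu()}, CACHE_FILE)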
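The size check in the new code works because len() on a 2-D tensor returns its first dimension, so the cached matrix's row count can be compared directly with the number of loaded entries. Illustrative shapes (384 is the output width of paraphrase-multilingual-MiniLM-L12-v2; the entry count is made up):

    import torch
    emb = torch.zeros(120, 384)  # 120 entries, one 384-dim embedding each
    assert len(emb) == emb.shape[0] == 120

This only detects a change in the number of entries; edited entry text with an unchanged count would still pass, so a content hash would be a stricter invalidation key.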
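The calling convention is unchanged from the removed docstring: run the builder after the curriculum is loaded or reloaded, and it now transparently prefers the cache. A hypothetical call site (load_curriculum is illustrative, not a function from this diff):

    qa_store.ENTRIES = load_curriculum()
    _build_entry_embeddings()  # hits the cache when the entry count matches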