Heng2004 committed
Commit 11d64bd · verified · 1 Parent(s): 2d6429a

Update model_utils.py

Files changed (1): model_utils.py (+23 -4)
model_utils.py CHANGED

@@ -2,6 +2,7 @@
 from typing import List, Optional
 import re
 
+import os
 import numpy as np
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -23,6 +24,9 @@ from loader import (
 MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"
 EMBED_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
 
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+CACHE_FILE = os.path.join(BASE_DIR, "data", "cached_embeddings.pt")
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
@@ -44,15 +48,30 @@ MAX_CONTEXT_ENTRIES = 4
 # -----------------------------
 def _build_entry_embeddings() -> None:
     """
-    Build embeddings for each textbook entry using chapter + section + text
-    and store them in qa_store.TEXT_EMBEDDINGS.
-
-    Call this after loading / reloading curriculum.
+    Load pre-computed embeddings if available, otherwise build them.
     """
     if not getattr(qa_store, "ENTRIES", None):
         qa_store.TEXT_EMBEDDINGS = None
         return
 
+    # 1. Try Loading from Cache
+    if os.path.exists(CACHE_FILE):
+        try:
+            print(f"[INFO] Loading cached embeddings from {CACHE_FILE}...")
+            cache = torch.load(CACHE_FILE, map_location=device)
+            if "textbook" in cache and cache["textbook"] is not None:
+                # Validate size matches
+                if len(cache["textbook"]) == len(qa_store.ENTRIES):
+                    qa_store.TEXT_EMBEDDINGS = cache["textbook"].to(device)
+                    print("[INFO] Textbook embeddings loaded successfully.")
+                    return
+                else:
+                    print("[WARN] Cache size mismatch (Data changed?). Re-computing...")
+        except Exception as e:
+            print(f"[WARN] Failed to load cache: {e}")
+
+    # 2. Fallback: Compute from scratch (The old slow way)
+    print("[INFO] Computing textbook embeddings from scratch...")
     texts: List[str] = []
     for e in qa_store.ENTRIES:
         chapter = e.get("chapter_title", "") or e.get("chapter", "") or ""
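
The commit reads data/cached_embeddings.pt but does not include the code that writes it. Below is a minimal sketch of an offline pre-compute script, assuming the sentence-transformers library is used to run EMBED_MODEL_NAME and that entries carry the chapter + section + text fields named in the removed docstring. The script name build_cache.py and the section/text field fallbacks are illustrative assumptions; the {"textbook": tensor} layout and the CACHE_FILE path come from the commit itself.

    # build_cache.py -- hypothetical helper, not part of this commit.
    # Pre-computes textbook embeddings and writes them in the layout that
    # _build_entry_embeddings() expects: {"textbook": <tensor of len(ENTRIES)>}.
    import os

    import torch
    from sentence_transformers import SentenceTransformer  # assumed embedding backend

    import qa_store  # assumed to be populated by the same loader as model_utils.py

    EMBED_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    CACHE_FILE = os.path.join(BASE_DIR, "data", "cached_embeddings.pt")


    def main() -> None:
        # One string per entry: chapter + section + text, per the removed docstring.
        # The section/text field names below are assumptions mirroring the chapter line.
        texts = []
        for e in qa_store.ENTRIES:
            chapter = e.get("chapter_title", "") or e.get("chapter", "") or ""
            section = e.get("section_title", "") or e.get("section", "") or ""
            texts.append(" ".join(p for p in (chapter, section, e.get("text", "")) if p))

        model = SentenceTransformer(EMBED_MODEL_NAME)
        embeddings = model.encode(texts, convert_to_tensor=True)  # (num_entries, dim)

        os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)
        # Save on CPU so torch.load(..., map_location=device) works on any machine.
        torch.save({"textbook": embeddings.cpu()}, CACHE_FILE)
        print(f"Saved {len(texts)} embeddings to {CACHE_FILE}")


    if __name__ == "__main__":
        main()

Because the loader validates len(cache["textbook"]) against len(qa_store.ENTRIES), such a script would need to be re-run whenever the curriculum data changes; a stale cache is detected at load time and the embeddings are recomputed from scratch.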