Heng2004's picture
Create build_cache.py
2d6429a verified
# build_cache.py
import os
import torch
from sentence_transformers import SentenceTransformer
import qa_store
from loader import load_curriculum, load_glossary
# 1. Configuration
# Identifier of the sentence-embedding model used for every vector in the cache.
EMBED_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

# All paths are anchored at the directory containing this script, so the
# build behaves the same regardless of the current working directory.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
DATA_DIR = os.path.join(BASE_DIR, "data")

# Output file that build_and_save() writes the embeddings to.
CACHE_FILE = os.path.join(BASE_DIR, "data", "cached_embeddings.pt")
def _entry_to_text(entry):
    """Join an entry's chapter, section and body text into one string to embed."""
    # Fall back from the "*_title" key to the plain key, and treat missing or
    # None values as empty strings.
    chapter = entry.get("chapter_title", "") or entry.get("chapter", "") or ""
    section = entry.get("section_title", "") or entry.get("section", "") or ""
    text = entry.get("text", "") or ""
    return f"{chapter}\n{section}\n{text}"


def build_and_save():
    """Compute textbook and glossary embeddings and save them to CACHE_FILE.

    Calls ``load_curriculum()`` and ``load_glossary()`` (presumably these
    populate ``qa_store.ENTRIES`` and ``qa_store.GLOSSARY`` -- confirm in
    ``loader``), encodes both collections with the model named by
    ``EMBED_MODEL_NAME`` and writes a dict to ``CACHE_FILE`` with
    ``torch.save``:

    * ``"textbook"``: a CPU ``torch.Tensor`` of embeddings (not normalized
      here), or ``None`` when there are no textbook entries.
    * ``"glossary"``: a numpy array of normalized embeddings, or ``None``
      when the glossary is empty.
    """
    print("⏳ Loading data...")
    load_curriculum()
    load_glossary()

    print(f"⏳ Loading model: {EMBED_MODEL_NAME}...")
    # Encode on the GPU when one is available; results are moved to the CPU
    # before saving so the cache stays portable.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    embed_model = SentenceTransformer(EMBED_MODEL_NAME, device=device)

    # --- 2. Build Textbook Embeddings ---
    print(f"🧮 Computing embeddings for {len(qa_store.ENTRIES)} textbook entries...")
    textbook_texts = [_entry_to_text(e) for e in qa_store.ENTRIES]

    if textbook_texts:
        textbook_embeddings = embed_model.encode(
            textbook_texts,
            convert_to_tensor=True,
            show_progress_bar=True,
        )
        # A tensor saved from a CUDA device cannot be read back by a plain
        # torch.load() on a CPU-only host, so always store a CPU tensor.
        # This is a no-op when the model already ran on the CPU.
        textbook_embeddings = textbook_embeddings.cpu()
    else:
        textbook_embeddings = None

    # --- 3. Build Glossary Embeddings ---
    print(f"🧮 Computing embeddings for {len(qa_store.GLOSSARY)} glossary terms...")
    # `or ""` keeps an explicit None value from being embedded as "None".
    glossary_texts = [
        f"{item.get('term') or ''} :: {item.get('definition') or ''}"
        for item in qa_store.GLOSSARY
    ]

    if glossary_texts:
        glossary_embeddings = embed_model.encode(
            glossary_texts,
            convert_to_numpy=True,
            normalize_embeddings=True,
            show_progress_bar=True,
        )
    else:
        glossary_embeddings = None

    # --- 4. Save to Disk ---
    print(f"💾 Saving to {CACHE_FILE}...")
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)
    torch.save(
        {
            "textbook": textbook_embeddings,
            "glossary": glossary_embeddings,
        },
        CACHE_FILE,
    )
    print("✅ Done! You can now upload 'data/cached_embeddings.pt' to Hugging Face.")
# Build the cache only when this file is run as a script, not when imported.
if __name__ == "__main__":
    build_and_save()