Heng2004's picture
Update loader.py
31e421c verified
raw
history blame
9.31 kB
# loader.py
import os
import json
from typing import List, Dict, Any
import qa_store
# ---------------------------------------------------------
# CONFIGURATION
# ---------------------------------------------------------
# Hugging Face dataset repo used for cloud sync, and the file name under
# which the manual Q&A data is stored in that repo.
DATASET_REPO_ID = "Heng2004/lao-science-qa-store"
DATASET_FILENAME = "manual_qa.jsonl"

# All local data lives in a "data" folder next to this source file.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(BASE_DIR, "data")

# Teacher-authored Q&A. Kept as a module-level path so the Teacher Panel
# can read and rewrite it directly.
MANUAL_QA_PATH = os.path.join(DATA_DIR, "manual_qa.jsonl")

# Locally generated embedding cache (also uploaded to / downloaded from the
# dataset repo by the sync helpers below).
CACHE_FILENAME = "cached_embeddings.pt"
CACHE_PATH = os.path.join(DATA_DIR, CACHE_FILENAME)
# ---------------------------------------------------------
# CLOUD SYNC (Hugging Face dataset repo)
# ---------------------------------------------------------
def sync_upload_cache() -> str:
    """Upload the local embedding cache (CACHE_PATH) to the Hugging Face dataset repo.

    The file is stored in the repo under CACHE_FILENAME.

    Returns:
        A short, user-facing status string. Errors are printed and reported
        in the returned message; nothing is raised to the caller.
    """
    # Skip when the repo ID is unset or still the template placeholder.
    if not DATASET_REPO_ID or "YOUR_USERNAME" in DATASET_REPO_ID:
        return "⚠️ Upload Skipped (Repo ID not set)"
    # Report a missing cache clearly instead of surfacing an opaque
    # "file not found" error from the hub client.
    if not os.path.isfile(CACHE_PATH):
        print(f"[ERROR] Upload cache failed: file not found: {CACHE_PATH}")
        return f"⚠️ Cache Upload Failed: cache file not found ({CACHE_PATH})"
    try:
        from huggingface_hub import HfApi

        # Pass the token explicitly, like the download helpers do; None
        # falls back to the hub client's default credential lookup.
        api = HfApi(token=os.environ.get("HF_TOKEN"))
        api.upload_file(
            path_or_fileobj=CACHE_PATH,
            path_in_repo=CACHE_FILENAME,
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            commit_message="System: Updated embedding cache",
        )
        return "☁️ Cache Upload Success"
    except Exception as e:
        print(f"[ERROR] Upload cache failed: {e}")
        return f"⚠️ Cache Upload Failed: {e}"
def sync_download_cache() -> None:
    """Download cached_embeddings.pt from the dataset repo into CACHE_PATH.

    Called at startup. Any failure (repo not configured, file absent in the
    repo, network error, ...) is printed as a warning and not raised.
    """
    # Same guard as the upload helpers: skip when the repo ID is unset or
    # still the template placeholder.
    if not DATASET_REPO_ID or "YOUR_USERNAME" in DATASET_REPO_ID:
        return
    try:
        from huggingface_hub import hf_hub_download
        import shutil

        downloaded_path = hf_hub_download(
            repo_id=DATASET_REPO_ID,
            filename=CACHE_FILENAME,
            repo_type="dataset",
            token=os.environ.get("HF_TOKEN"),
        )
        # The data/ folder may not exist yet on a fresh checkout; without
        # this the copy fails and the downloaded cache is silently dropped.
        os.makedirs(os.path.dirname(CACHE_PATH), exist_ok=True)
        shutil.copy(downloaded_path, CACHE_PATH)
        print("[INFO] Cache download success!")
    except Exception as e:
        print(f"[WARN] Could not download cache (First run?): {e}")
def sync_upload_manual_qa() -> str:
    """Upload the local manual Q&A file (MANUAL_QA_PATH) to the dataset repo.

    The file is stored in the repo under DATASET_FILENAME.

    Returns:
        A short, user-facing status string. Errors are printed and reported
        in the returned message; nothing is raised to the caller.
    """
    # Skip when the repo ID is unset or still the template placeholder.
    if not DATASET_REPO_ID or "YOUR_USERNAME" in DATASET_REPO_ID:
        return "⚠️ Upload Skipped"
    # Report a missing file clearly instead of an opaque hub-client error.
    if not os.path.isfile(MANUAL_QA_PATH):
        print(f"[ERROR] Upload manual QA failed: file not found: {MANUAL_QA_PATH}")
        return f"⚠️ Cloud Upload Failed: file not found ({MANUAL_QA_PATH})"
    try:
        from huggingface_hub import HfApi

        # Explicit token, consistent with the download helpers; None falls
        # back to the hub client's default credential lookup.
        api = HfApi(token=os.environ.get("HF_TOKEN"))
        api.upload_file(
            path_or_fileobj=MANUAL_QA_PATH,
            path_in_repo=DATASET_FILENAME,
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            commit_message="Teacher Panel: Updated Q&A data",
        )
        return "☁️ Cloud Upload Success"
    except Exception as e:
        # Also log the failure (as sync_upload_cache does) so it is visible
        # in server output, not only in the returned status string.
        print(f"[ERROR] Upload manual QA failed: {e}")
        return f"⚠️ Cloud Upload Failed: {e}"
def sync_download_manual_qa() -> None:
    """Download the manual Q&A file from the dataset repo into MANUAL_QA_PATH.

    Overwrites the local file with the repo copy of DATASET_FILENAME. Any
    failure is printed as a warning and not raised.
    """
    # Same guard as the upload helpers: skip when the repo ID is unset or
    # still the template placeholder.
    if not DATASET_REPO_ID or "YOUR_USERNAME" in DATASET_REPO_ID:
        return
    try:
        from huggingface_hub import hf_hub_download
        import shutil

        downloaded_path = hf_hub_download(
            repo_id=DATASET_REPO_ID,
            filename=DATASET_FILENAME,
            repo_type="dataset",
            token=os.environ.get("HF_TOKEN"),
        )
        # The data/ folder may not exist yet on a fresh checkout.
        os.makedirs(os.path.dirname(MANUAL_QA_PATH), exist_ok=True)
        shutil.copy(downloaded_path, MANUAL_QA_PATH)
        print("[INFO] Manual QA download success!")
    except Exception as e:
        print(f"[WARN] Could not download manual_qa.jsonl: {e}")
# ---------------------------------------------------------
# RECURSIVE LOADERS (scan all subfolders of data/)
# ---------------------------------------------------------
def load_curriculum() -> None:
    """Recursively find and load all textbook JSONL files under DATA_DIR.

    A file counts as textbook content when it is named ``textbook.jsonl``,
    or when its name starts with ``M`` (case-sensitive) and ends with
    ``.jsonl``. ``qa_store.ENTRIES`` and ``qa_store.AUTO_QA_KNOWLEDGE`` are
    cleared and refilled via ``_parse_curriculum_file``.
    ``qa_store.RAW_KNOWLEDGE`` is rebuilt from the entries' "text" fields,
    or set to a placeholder when nothing was loaded.
    """
    qa_store.ENTRIES.clear()
    qa_store.AUTO_QA_KNOWLEDGE.clear()
    print(f"[INFO] Scanning {DATA_DIR} for textbook content...")
    file_count = 0
    # os.walk descends into nested folders (e.g. M_1/U_1/...). Its listing
    # order is filesystem-dependent, so sort directories and files to make
    # entry order reproducible across machines. (Presumably relevant if
    # cached embeddings are stored per entry position — verify.)
    for root, dirs, files in os.walk(DATA_DIR):
        dirs.sort()  # in-place sort controls the order os.walk visits subfolders
        for file in sorted(files):
            is_textbook = file == "textbook.jsonl" or (
                file.startswith("M") and file.endswith(".jsonl")
            )
            if is_textbook:
                _parse_curriculum_file(os.path.join(root, file))
                file_count += 1
    if qa_store.ENTRIES:
        qa_store.RAW_KNOWLEDGE = "\n\n".join(e["text"] for e in qa_store.ENTRIES)
        print(f"[INFO] Loaded {len(qa_store.ENTRIES)} entries from {file_count} files.")
    else:
        # Lao placeholder text (roughly "no data yet").
        qa_store.RAW_KNOWLEDGE = "ຍັງບໍ່ມີຂໍ້ມູນ."
        print("[WARN] No curriculum files found.")
def _parse_curriculum_file(path: str) -> None:
    """Read one textbook JSONL file into the shared qa_store lists.

    Each non-blank line should be a JSON object. Objects with a "text" key
    are appended to ``qa_store.ENTRIES``. For each such object, every
    ``{"q": ..., "a": ...}`` pair under its "qa" key with a non-empty
    question and answer is also appended to ``qa_store.AUTO_QA_KNOWLEDGE``
    (tagged ``source="auto"`` with the object's "id").

    Skipped silently:
      - blank lines;
      - invalid JSON;
      - JSON values that are not objects;
      - objects without "text";
      - malformed pairs.

    Args:
        path: Path of the UTF-8 JSONL file to read.
    """
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue
            # A valid JSON line may still be a list/string/number. Only
            # objects carry the "text"/"qa" fields. A bare string would
            # otherwise pass the `in` test as a substring match.
            if not isinstance(obj, dict) or "text" not in obj:
                continue
            qa_store.ENTRIES.append(obj)
            # Extract Auto-QA; "qa" may be missing or null.
            for pair in obj.get("qa") or []:
                if not isinstance(pair, dict):
                    continue
                # str() guards against non-string values (e.g. numeric answers).
                q = str(pair.get("q") or "").strip()
                a = str(pair.get("a") or "").strip()
                if q and a:
                    qa_store.AUTO_QA_KNOWLEDGE.append({
                        "norm_q": qa_store.normalize_question(q),
                        "q": q,
                        "a": a,
                        "source": "auto",
                        "id": obj.get("id", ""),
                    })
def load_glossary() -> None:
    """Recursively find and load all glossary JSONL files under DATA_DIR.

    Any file whose name contains "glossary" anywhere (case-sensitive) and
    ends with ".jsonl" is read. Every non-blank line that parses as JSON is
    appended to ``qa_store.GLOSSARY``, which is cleared first. Invalid JSON
    lines are skipped.

    NOTE: a name such as ``M1_glossary.jsonl`` also matches load_curriculum's
    "starts with M" rule. There, only lines with a "text" key are kept.
    """
    qa_store.GLOSSARY.clear()
    print(f"[INFO] Scanning {DATA_DIR} for glossary files...")
    # Sort directories (in place, so os.walk follows it) and files so the
    # glossary order does not depend on filesystem listing order.
    for root, dirs, files in os.walk(DATA_DIR):
        dirs.sort()
        for file in sorted(files):
            if "glossary" not in file or not file.endswith(".jsonl"):
                continue
            full_path = os.path.join(root, file)
            with open(full_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        obj = json.loads(line)
                    except json.JSONDecodeError:
                        continue
                    qa_store.GLOSSARY.append(obj)
    print(f"[INFO] Loaded {len(qa_store.GLOSSARY)} glossary terms.")
# ---------------------------------------------------------
# MANUAL QA & UTILS
# ---------------------------------------------------------
def load_manual_qa() -> None:
    """(Re)load teacher-authored Q&A pairs from MANUAL_QA_PATH.

    The following ``qa_store`` fields are cleared and refilled:
      - ``MANUAL_QA_LIST``: every line with a non-empty "q" and "a", in file
        order, as ``{"id", "q", "a", "norm_q"}`` dicts.
      - ``MANUAL_QA_INDEX``: normalized question -> entry. A later duplicate
        question overwrites an earlier one here.
      - ``NEXT_MANUAL_ID``: one past the largest trailing number found in
        any line's "id", counting lines without q/a too; 1 if there is none.

    A missing file leaves both collections empty. Blank lines, invalid JSON,
    and non-object lines are skipped.
    """
    import re

    qa_store.MANUAL_QA_LIST.clear()
    qa_store.MANUAL_QA_INDEX.clear()
    if not os.path.exists(MANUAL_QA_PATH):
        print(f"[WARN] Manual QA file not found: {MANUAL_QA_PATH}")
        qa_store.NEXT_MANUAL_ID = 1
        return
    # Trailing digits of an ID, e.g. "manual_0042" -> 42. Compiled once
    # rather than re-imported/re-looked-up per line.
    trailing_num = re.compile(r"(\d+)$")
    max_num = 0
    with open(MANUAL_QA_PATH, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue
            # Valid JSON that is not an object has no fields to read.
            if not isinstance(obj, dict):
                continue
            entry_id = str(obj.get("id") or "")
            m = trailing_num.search(entry_id)
            if m:
                max_num = max(max_num, int(m.group(1)))
            # str() guards against non-string values (e.g. numeric answers).
            q = str(obj.get("q") or "").strip()
            a = str(obj.get("a") or "").strip()
            if q and a:
                norm_q = qa_store.normalize_question(q)
                entry = {"id": entry_id, "q": q, "a": a, "norm_q": norm_q}
                qa_store.MANUAL_QA_LIST.append(entry)
                qa_store.MANUAL_QA_INDEX[norm_q] = entry
    # max_num starts at 0 and never decreases, so this yields 1 when no ID
    # carried a number.
    qa_store.NEXT_MANUAL_ID = max_num + 1
def generate_new_manual_id() -> str:
    """Return the lowest free manual ID, formatted as ``manual_NNNN``.

    The trailing integer of every entry ID in ``qa_store.MANUAL_QA_LIST`` is
    treated as taken. The smallest positive integer not taken is chosen, so
    gaps in the numbering are reused. It is zero-padded to at least 4 digits.
    """
    import re

    taken = set()
    for entry in qa_store.MANUAL_QA_LIST:
        match = re.search(r"(\d+)$", str(entry.get("id") or ""))
        if match is not None:
            taken.add(int(match.group(1)))
    # Pigeonhole: among 1..len(taken)+1 at least one number is free, and the
    # minimum of the free ones is the smallest positive integer not taken.
    free = set(range(1, len(taken) + 2)) - taken
    return f"manual_{min(free):04d}"
def save_manual_qa_file() -> None:
    """Write ``qa_store.MANUAL_QA_LIST`` to MANUAL_QA_PATH as JSONL.

    Only the persistent fields (id, q, a) are written, one JSON object per
    line with non-ASCII text kept as-is. Derived fields such as ``norm_q``
    are recomputed by load_manual_qa.

    The data is first written to a sibling ``.tmp`` file, flushed to disk,
    and then swapped into place with ``os.replace``. An exception or crash
    mid-write therefore cannot leave a truncated manual_qa.jsonl behind.
    """
    os.makedirs(os.path.dirname(MANUAL_QA_PATH), exist_ok=True)
    # Ends in ".tmp", not ".jsonl", so the JSONL loaders never pick it up.
    tmp_path = MANUAL_QA_PATH + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            for e in qa_store.MANUAL_QA_LIST:
                obj = {"id": e["id"], "q": e["q"], "a": e["a"]}
                f.write(json.dumps(obj, ensure_ascii=False) + "\n")
            f.flush()
            os.fsync(f.fileno())
        # Atomic when source and destination are on the same filesystem,
        # which holds here because both live in the same directory.
        os.replace(tmp_path, MANUAL_QA_PATH)
    finally:
        # After a successful replace the temp file is gone; otherwise clean up.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def rebuild_combined_qa() -> None:
    """Rebuild the merged Q&A lookup from auto-extracted and manual pairs.

    ``qa_store.ALL_QA_KNOWLEDGE`` receives every auto item, then every manual
    entry re-wrapped with ``source="manual"``. ``qa_store.QA_INDEX`` maps a
    normalized question to its answer. Manual pairs are written after auto
    pairs, so a manual answer overrides an auto answer for the same
    normalized question.
    """
    qa_store.QA_INDEX.clear()
    qa_store.ALL_QA_KNOWLEDGE.clear()
    # Lazily built so manual entries are converted only after all auto
    # items have been processed, preserving the original ordering.
    manual_items = (
        {
            "norm_q": entry["norm_q"],
            "q": entry["q"],
            "a": entry["a"],
            "source": "manual",
            "id": entry["id"],
        }
        for entry in qa_store.MANUAL_QA_LIST
    )
    for group in (qa_store.AUTO_QA_KNOWLEDGE, manual_items):
        for record in group:
            qa_store.QA_INDEX[record["norm_q"]] = record["a"]
            qa_store.ALL_QA_KNOWLEDGE.append(record)
def manual_qa_table_data() -> List[List[str]]:
    """Return the manual Q&A entries as ``[id, question, answer]`` rows.

    Produces one row per entry of ``qa_store.MANUAL_QA_LIST``, in list order.
    """
    rows: List[List[str]] = []
    for entry in qa_store.MANUAL_QA_LIST:
        rows.append([entry["id"], entry["q"], entry["a"]])
    return rows