EphAsad committed (verified)
Commit f2213be · 1 Parent(s): 9fc007c

Upload 21 files

models/genus_xgb.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6530346e18bdb61e778f887fef7ca33e8e0b56e040ce95287c373073db726cfb
+ size 34613851
models/genus_xgb_meta.json ADDED
@@ -0,0 +1,369 @@
1
+ {
2
+ "genus_to_idx": {
3
+ "Staphylococcus": 0,
4
+ "Salmonella": 1,
5
+ "Listeria": 2,
6
+ "Enterobacter": 3,
7
+ "Pseudomonas": 4,
8
+ "Streptococcus": 5,
9
+ "Enterococcus": 6,
10
+ "Bacillus": 7,
11
+ "Shigella": 8,
12
+ "Escherichia": 9,
13
+ "Klebsiella": 10,
14
+ "Proteus": 11,
15
+ "Vibrio": 12,
16
+ "Neisseria": 13,
17
+ "Campylobacter": 14,
18
+ "Clostridium": 15,
19
+ "Corynebacterium": 16,
20
+ "Legionella": 17,
21
+ "Mycobacterium": 18,
22
+ "Bacteroides": 19,
23
+ "Micrococcus": 20,
24
+ "Erysipelothrix": 21,
25
+ "Haemophilus": 22,
26
+ "Aeromonas": 23,
27
+ "Yersinia": 24,
28
+ "Acinetobacter": 25,
29
+ "Serratia": 26,
30
+ "Morganella": 27,
31
+ "Providencia": 28,
32
+ "Burkholderia": 29,
33
+ "Helicobacter": 30,
34
+ "Actinomyces": 31,
35
+ "Nocardia": 32,
36
+ "Pasteurella": 33,
37
+ "Citrobacter": 34,
38
+ "Leptospira": 35,
39
+ "Alcaligenes": 36,
40
+ "Shewanella": 37,
41
+ "Edwardsiella": 38,
42
+ "Chromobacterium": 39,
43
+ "Lactobacillus": 40,
44
+ "Propionibacterium": 41,
45
+ "Peptostreptococcus": 42,
46
+ "Veillonella": 43,
47
+ "Fusobacterium": 44,
48
+ "Eubacterium": 45,
49
+ "Halomonas": 46,
50
+ "Psychrobacter": 47,
51
+ "Rhodococcus": 48,
52
+ "Mycoplasma": 49,
53
+ "Bordetella": 50,
54
+ "Stenotrophomonas": 51,
55
+ "Ralstonia": 52,
56
+ "Achromobacter": 53,
57
+ "Brucella": 54,
58
+ "Arthrobacter": 55,
59
+ "Flavobacterium": 56,
60
+ "Oerskovia": 57,
61
+ "Sphingomonas": 58,
62
+ "Comamonas": 59,
63
+ "Thermococcus": 60,
64
+ "Elizabethkingia": 61,
65
+ "Hafnia": 62,
66
+ "Raoultella": 63,
67
+ "Ochrobactrum": 64,
68
+ "Roseomonas": 65,
69
+ "Actinobacillus": 66,
70
+ "Gemella": 67,
71
+ "Rothia": 68,
72
+ "Carnobacterium": 69,
73
+ "Plesiomonas": 70,
74
+ "Janthinobacterium": 71,
75
+ "Paenibacillus": 72,
76
+ "Moraxella": 73,
77
+ "Aerococcus": 74,
78
+ "Kocuria": 75,
79
+ "Leuconostoc": 76,
80
+ "Arcanobacterium": 77,
81
+ "Gardnerella": 78,
82
+ "Porphyromonas": 79,
83
+ "Prevotella": 80,
84
+ "Pediococcus": 81,
85
+ "Weissella": 82,
86
+ "Lactococcus": 83,
87
+ "Microbacterium": 84,
88
+ "Clostridioides": 85,
89
+ "Cronobacter": 86,
90
+ "Rhizobium": 87,
91
+ "Azotobacter": 88,
92
+ "Spirillum": 89,
93
+ "Candida": 90,
94
+ "Cryptococcus": 91,
95
+ "Saccharomyces": 92,
96
+ "Rickettsia": 93,
97
+ "Borrelia": 94,
98
+ "Chlamydia": 95,
99
+ "Acidaminococcus": 96,
100
+ "Bartonella": 97,
101
+ "Coxiella": 98,
102
+ "Kingella": 99,
103
+ "Eikenella": 100,
104
+ "Bilophila": 101,
105
+ "Anaerococcus": 102,
106
+ "Finegoldia": 103,
107
+ "Parvimonas": 104,
108
+ "Ruminococcus": 105,
109
+ "Cutibacterium": 106,
110
+ "Exiguobacterium": 107,
111
+ "Kluyvera": 108,
112
+ "Pluralibacter": 109,
113
+ "Massilia": 110,
114
+ "Methylobacterium": 111,
115
+ "Cupriavidus": 112,
116
+ "Acidovorax": 113,
117
+ "Geobacillus": 114,
118
+ "Trueperella": 115,
119
+ "Streptomyces": 116,
120
+ "Thermoactinomyces": 117,
121
+ "Capnocytophaga": 118,
122
+ "Cardiobacterium": 119,
123
+ "Yokenella": 120,
124
+ "Brevibacterium": 121,
125
+ "Peptoniphilus": 122,
126
+ "Weisella": 123,
127
+ "Saccharopolyspora": 124,
128
+ "Frankia": 125,
129
+ "Spiroplasma": 126,
130
+ "Cedecea": 127,
131
+ "Photorhabdus": 128,
132
+ "Abiotrophia": 129,
133
+ "Cellulomonas": 130,
134
+ "Leifsonia": 131,
135
+ "Alicyclobacillus": 132,
136
+ "Sporolactobacillus": 133,
137
+ "Leclercia": 134,
138
+ "Kosakonia": 135,
139
+ "Bergeyella": 136,
140
+ "Myroides": 137,
141
+ "Aggregatibacter": 138,
142
+ ":": 139
143
+ },
144
+ "idx_to_genus": {
145
+ "0": "Staphylococcus",
146
+ "1": "Salmonella",
147
+ "2": "Listeria",
148
+ "3": "Enterobacter",
149
+ "4": "Pseudomonas",
150
+ "5": "Streptococcus",
151
+ "6": "Enterococcus",
152
+ "7": "Bacillus",
153
+ "8": "Shigella",
154
+ "9": "Escherichia",
155
+ "10": "Klebsiella",
156
+ "11": "Proteus",
157
+ "12": "Vibrio",
158
+ "13": "Neisseria",
159
+ "14": "Campylobacter",
160
+ "15": "Clostridium",
161
+ "16": "Corynebacterium",
162
+ "17": "Legionella",
163
+ "18": "Mycobacterium",
164
+ "19": "Bacteroides",
165
+ "20": "Micrococcus",
166
+ "21": "Erysipelothrix",
167
+ "22": "Haemophilus",
168
+ "23": "Aeromonas",
169
+ "24": "Yersinia",
170
+ "25": "Acinetobacter",
171
+ "26": "Serratia",
172
+ "27": "Morganella",
173
+ "28": "Providencia",
174
+ "29": "Burkholderia",
175
+ "30": "Helicobacter",
176
+ "31": "Actinomyces",
177
+ "32": "Nocardia",
178
+ "33": "Pasteurella",
179
+ "34": "Citrobacter",
180
+ "35": "Leptospira",
181
+ "36": "Alcaligenes",
182
+ "37": "Shewanella",
183
+ "38": "Edwardsiella",
184
+ "39": "Chromobacterium",
185
+ "40": "Lactobacillus",
186
+ "41": "Propionibacterium",
187
+ "42": "Peptostreptococcus",
188
+ "43": "Veillonella",
189
+ "44": "Fusobacterium",
190
+ "45": "Eubacterium",
191
+ "46": "Halomonas",
192
+ "47": "Psychrobacter",
193
+ "48": "Rhodococcus",
194
+ "49": "Mycoplasma",
195
+ "50": "Bordetella",
196
+ "51": "Stenotrophomonas",
197
+ "52": "Ralstonia",
198
+ "53": "Achromobacter",
199
+ "54": "Brucella",
200
+ "55": "Arthrobacter",
201
+ "56": "Flavobacterium",
202
+ "57": "Oerskovia",
203
+ "58": "Sphingomonas",
204
+ "59": "Comamonas",
205
+ "60": "Thermococcus",
206
+ "61": "Elizabethkingia",
207
+ "62": "Hafnia",
208
+ "63": "Raoultella",
209
+ "64": "Ochrobactrum",
210
+ "65": "Roseomonas",
211
+ "66": "Actinobacillus",
212
+ "67": "Gemella",
213
+ "68": "Rothia",
214
+ "69": "Carnobacterium",
215
+ "70": "Plesiomonas",
216
+ "71": "Janthinobacterium",
217
+ "72": "Paenibacillus",
218
+ "73": "Moraxella",
219
+ "74": "Aerococcus",
220
+ "75": "Kocuria",
221
+ "76": "Leuconostoc",
222
+ "77": "Arcanobacterium",
223
+ "78": "Gardnerella",
224
+ "79": "Porphyromonas",
225
+ "80": "Prevotella",
226
+ "81": "Pediococcus",
227
+ "82": "Weissella",
228
+ "83": "Lactococcus",
229
+ "84": "Microbacterium",
230
+ "85": "Clostridioides",
231
+ "86": "Cronobacter",
232
+ "87": "Rhizobium",
233
+ "88": "Azotobacter",
234
+ "89": "Spirillum",
235
+ "90": "Candida",
236
+ "91": "Cryptococcus",
237
+ "92": "Saccharomyces",
238
+ "93": "Rickettsia",
239
+ "94": "Borrelia",
240
+ "95": "Chlamydia",
241
+ "96": "Acidaminococcus",
242
+ "97": "Bartonella",
243
+ "98": "Coxiella",
244
+ "99": "Kingella",
245
+ "100": "Eikenella",
246
+ "101": "Bilophila",
247
+ "102": "Anaerococcus",
248
+ "103": "Finegoldia",
249
+ "104": "Parvimonas",
250
+ "105": "Ruminococcus",
251
+ "106": "Cutibacterium",
252
+ "107": "Exiguobacterium",
253
+ "108": "Kluyvera",
254
+ "109": "Pluralibacter",
255
+ "110": "Massilia",
256
+ "111": "Methylobacterium",
257
+ "112": "Cupriavidus",
258
+ "113": "Acidovorax",
259
+ "114": "Geobacillus",
260
+ "115": "Trueperella",
261
+ "116": "Streptomyces",
262
+ "117": "Thermoactinomyces",
263
+ "118": "Capnocytophaga",
264
+ "119": "Cardiobacterium",
265
+ "120": "Yokenella",
266
+ "121": "Brevibacterium",
267
+ "122": "Peptoniphilus",
268
+ "123": "Weisella",
269
+ "124": "Saccharopolyspora",
270
+ "125": "Frankia",
271
+ "126": "Spiroplasma",
272
+ "127": "Cedecea",
273
+ "128": "Photorhabdus",
274
+ "129": "Abiotrophia",
275
+ "130": "Cellulomonas",
276
+ "131": "Leifsonia",
277
+ "132": "Alicyclobacillus",
278
+ "133": "Sporolactobacillus",
279
+ "134": "Leclercia",
280
+ "135": "Kosakonia",
281
+ "136": "Bergeyella",
282
+ "137": "Myroides",
283
+ "138": "Aggregatibacter",
284
+ "139": ":"
285
+ },
286
+ "n_features": 73,
287
+ "num_classes": 140,
288
+ "metrics": {
289
+ "train_accuracy": 0.9869916267942583,
290
+ "valid_accuracy": 0.9509569377990431,
291
+ "best_iteration": 270
292
+ },
293
+ "feature_schema_path": "data/feature_schema.json",
294
+ "feature_names": [
295
+ "Gram Stain",
296
+ "Shape",
297
+ "Haemolysis",
298
+ "Haemolysis Type",
299
+ "Catalase",
300
+ "Oxidase",
301
+ "Indole",
302
+ "Urease",
303
+ "Citrate",
304
+ "H2S",
305
+ "DNase",
306
+ "Lysine Decarboxylase",
307
+ "Ornithine Decarboxylase",
308
+ "Arginine dihydrolase",
309
+ "ONPG",
310
+ "Nitrate Reduction",
311
+ "Methyl Red",
312
+ "VP",
313
+ "Coagulase",
314
+ "Lipase Test",
315
+ "Motility",
316
+ "Motility Type",
317
+ "Capsule",
318
+ "Spore Formation",
319
+ "Pigment",
320
+ "Odor",
321
+ "Colony Pattern",
322
+ "TSI Pattern",
323
+ "Temperature_4C",
324
+ "Temperature_25C",
325
+ "Temperature_30C",
326
+ "Temperature_37C",
327
+ "Temperature_42C",
328
+ "Lactose Fermentation",
329
+ "Glucose Fermentation",
330
+ "Sucrose Fermentation",
331
+ "Mannitol Fermentation",
332
+ "Maltose Fermentation",
333
+ "Sorbitol Fermentation",
334
+ "Xylose Fermentation",
335
+ "Rhamnose Fermentation",
336
+ "Arabinose Fermentation",
337
+ "Raffinose Fermentation",
338
+ "Trehalose Fermentation",
339
+ "Inositol Fermentation",
340
+ "Oxygen Requirement",
341
+ "Gas Production",
342
+ "MacConkey Growth",
343
+ "Blood Growth",
344
+ "XLD Growth",
345
+ "Nutrient Growth",
346
+ "Cetrimide Growth",
347
+ "BCYE Growth",
348
+ "Hektoen Enteric Growth",
349
+ "Mannitol Salt Growth",
350
+ "Bordet-Gengou Growth",
351
+ "Thayer Martin Growth",
352
+ "Cycloserine Cefoxitin Fructose Growth",
353
+ "Sabouraud Growth",
354
+ "Lowenstein-Jensen Growth",
355
+ "Yeast Extract Mannitol Growth",
356
+ "BSK Growth",
357
+ "Brucella Growth",
358
+ "Charcoal Growth",
359
+ "BHI Growth",
360
+ "Ashby Growth",
361
+ "MRS Growth",
362
+ "Anaerobic Blood Growth",
363
+ "BP Growth",
364
+ "ALOA Growth",
365
+ "Anaerobic Growth",
366
+ "Chocolate Growth",
367
+ "TCBS Growth"
368
+ ]
369
+ }
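
The meta file above pairs the LFS-tracked booster in models/genus_xgb.json with its label mapping, feature order, and validation metrics. Below is a minimal, hedged sketch of how the two files might be used together at inference time; the xgboost Booster calls and the softprob-style prediction shape are assumptions, since the training code is not part of this upload.

# Hedged sketch (not part of this commit): load the booster with its meta mapping.
# Assumes the model was trained with a softprob-style objective, so predict()
# returns one probability per class.
import json
import numpy as np
import xgboost as xgb

with open("models/genus_xgb_meta.json", "r", encoding="utf-8") as f:
    meta = json.load(f)

booster = xgb.Booster()
booster.load_model("models/genus_xgb.json")

def predict_genus(feature_row):
    """feature_row must follow meta['feature_names'] order (n_features = 73)."""
    dmat = xgb.DMatrix(np.asarray([feature_row], dtype="float32"))
    probs = booster.predict(dmat)[0]            # length == meta["num_classes"]
    return meta["idx_to_genus"][str(int(np.argmax(probs)))]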
rag/context_shaper.py ADDED
@@ -0,0 +1,168 @@
1
+ # rag/context_shaper.py
2
+ # ============================================================
3
+ # Context shaper for RAG
4
+ #
5
+ # Goal:
6
+ # - Convert "flattened schema dumps" (Field: Value lines) into
7
+ # readable evidence blocks the LLM can reason over.
8
+ # - Deterministic, no LLM usage.
9
+ #
10
+ # Works with:
11
+ # - llm_context from rag_retriever (biology-only text)
12
+ # ============================================================
13
+
14
+ from __future__ import annotations
15
+
16
+ import re
17
+ from typing import Dict, List, Tuple, Optional
18
+
19
+
20
+ _FIELD_LINE_RE = re.compile(r"^\s*([^:\n]{1,80})\s*:\s*(.+?)\s*$")
21
+
22
+ # Some fields are usually lists separated by ; or , or |
23
+ _LIST_LIKE_FIELDS = {
24
+ "Media Grown On",
25
+ "Colony Morphology",
26
+ "Colony Pattern",
27
+ "Growth Temperature",
28
+ }
29
+
30
+ # A light grouping map to turn fields into readable sections.
31
+ # (You can expand this over time.)
32
+ _GROUPS: List[Tuple[str, List[str]]] = [
33
+ ("Morphology & staining", [
34
+ "Gram Stain", "Shape", "Cellular Arrangement", "Capsule", "Spore Forming",
35
+ ]),
36
+ ("Culture & colony", [
37
+ "Media Grown On", "Colony Morphology", "Colony Pattern", "Pigment", "Odour",
38
+ "Haemolysis", "Haemolysis Type",
39
+ ]),
40
+ ("Core biochemistry", [
41
+ "Oxidase", "Catalase", "Indole", "Urease", "Citrate", "Methyl Red", "VP",
42
+ "Nitrate Reduction", "ONPG", "TSI Pattern", "H2S", "Gas Production",
43
+ "Glucose Fermentation", "Lactose Fermentation", "Sucrose Fermentation",
44
+ "Inositol Fermentation", "Mannitol Fermentation",
45
+ ]),
46
+ ("Motility & growth conditions", [
47
+ "Motility", "Motility Type", "Growth Temperature", "NaCl", "NaCl Tolerance",
48
+ "Oxygen Requirement",
49
+ ]),
50
+ ("Other tests", [
51
+ "DNase", "Esculin Hydrolysis", "Gelatin Hydrolysis",
52
+ "Lysine Decarboxylase", "Ornithine Decarboxylase", "Arginine Dihydrolase",
53
+ ]),
54
+ ]
55
+
56
+
57
+ def _is_schema_dump(text: str) -> bool:
58
+ """
59
+ Detect if rag context looks like flattened Field: Value lines.
60
+ """
61
+ if not text:
62
+ return False
63
+ lines = [l for l in text.splitlines() if l.strip()]
64
+ if len(lines) < 6:
65
+ return False
66
+ hits = 0
67
+ for l in lines[:40]:
68
+ if _FIELD_LINE_RE.match(l):
69
+ hits += 1
70
+ return hits >= max(4, int(0.5 * min(len(lines), 40)))
71
+
72
+
73
+ def _split_listish(field: str, value: str) -> str:
74
+ """
75
+ Normalize list-like values into comma-separated readable text.
76
+ """
77
+ v = (value or "").strip()
78
+ if not v:
79
+ return v
80
+ if field in _LIST_LIKE_FIELDS or (";" in v) or ("," in v):
81
+ parts = [p.strip() for p in re.split(r"[;,\|]+", v) if p.strip()]
82
+ if parts:
83
+ return ", ".join(parts)
84
+ return v
85
+
86
+
87
+ def _parse_field_lines(text: str) -> Dict[str, str]:
88
+ """
89
+ Parse Field: Value lines into a dict. Keeps last occurrence.
90
+ """
91
+ out: Dict[str, str] = {}
92
+ for raw in (text or "").splitlines():
93
+ line = raw.strip()
94
+ if not line:
95
+ continue
96
+ m = _FIELD_LINE_RE.match(line)
97
+ if not m:
98
+ continue
99
+ field = m.group(1).strip()
100
+ value = m.group(2).strip()
101
+ if not field:
102
+ continue
103
+ out[field] = _split_listish(field, value)
104
+ return out
105
+
106
+
107
+ def _format_grouped_blocks(fields: Dict[str, str]) -> str:
108
+ """
109
+ Turn fields into grouped, readable evidence blocks.
110
+ """
111
+ used = set()
112
+ blocks: List[str] = []
113
+
114
+ for title, keys in _GROUPS:
115
+ lines: List[str] = []
116
+ for k in keys:
117
+ if k in fields:
118
+ val = fields[k]
119
+ if val and val.lower() != "unknown":
120
+ lines.append(f"- {k}: {val}")
121
+ used.add(k)
122
+ if lines:
123
+ blocks.append(f"{title}:\n" + "\n".join(lines))
124
+
125
+ # Any leftovers not in group map
126
+ leftovers: List[str] = []
127
+ for k, v in fields.items():
128
+ if k in used:
129
+ continue
130
+ if not v or v.lower() == "unknown":
131
+ continue
132
+ leftovers.append(f"- {k}: {v}")
133
+ if leftovers:
134
+ blocks.append("Additional traits:\n" + "\n".join(leftovers))
135
+
136
+ return "\n\n".join(blocks).strip()
137
+
138
+
139
+ def shape_llm_context(
140
+ llm_context: str,
141
+ target_genus: str = "",
142
+ max_chars: int = 1800,
143
+ ) -> str:
144
+ """
145
+ Main entrypoint.
146
+ - If context is already narrative, keep it (trim to max_chars).
147
+ - If it is a schema dump, convert to grouped evidence blocks.
148
+ """
149
+ ctx = (llm_context or "").strip()
150
+ if not ctx:
151
+ return ""
152
+
153
+ if _is_schema_dump(ctx):
154
+ fields = _parse_field_lines(ctx)
155
+ shaped = _format_grouped_blocks(fields)
156
+
157
+ # Add a tiny header to cue the LLM that this is reference evidence
158
+ if target_genus:
159
+ shaped = f"Reference evidence for {target_genus} (compiled traits):\n\n{shaped}"
160
+ else:
161
+ shaped = f"Reference evidence (compiled traits):\n\n{shaped}"
162
+
163
+ return shaped[:max_chars].strip()
164
+
165
+ # Narrative context: just trim
166
+ if target_genus:
167
+ ctx = f"Reference context for {target_genus}:\n\n{ctx}"
168
+ return ctx[:max_chars].strip()
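
A short usage sketch for shape_llm_context follows; the trait lines are illustrative values, not data from this repository.

# Hedged usage sketch (illustrative trait values only):
from rag.context_shaper import shape_llm_context

raw = "\n".join([
    "Gram Stain: Positive",
    "Shape: Cocci",
    "Catalase: Positive",
    "Coagulase: Positive",
    "Haemolysis: Positive",
    "Media Grown On: Blood agar; Mannitol salt agar",
])
print(shape_llm_context(raw, target_genus="Staphylococcus"))
# The input is detected as a Field: Value dump, so the output starts with
# "Reference evidence for Staphylococcus (compiled traits):" followed by
# grouped bullets (e.g. "Morphology & staining:", "Culture & colony:").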
rag/rag_embedder.py ADDED
@@ -0,0 +1,112 @@
1
+ # rag/rag_embedder.py
2
+ # ============================================================
3
+ # Embedding utilities for RAG (knowledge base + queries)
4
+ # Uses a SentenceTransformer model for dense embeddings.
5
+ # ============================================================
6
+
7
+ from __future__ import annotations
8
+
9
+ import os
10
+ import json
11
+ from typing import List, Dict, Any
12
+
13
+ import numpy as np
14
+ from sentence_transformers import SentenceTransformer
15
+
16
+
17
+ # ------------------------------------------------------------
18
+ # CONFIG
19
+ # ------------------------------------------------------------
20
+
21
+ EMBEDDING_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
22
+
23
+ _model: SentenceTransformer | None = None
24
+
25
+
26
+ # ------------------------------------------------------------
27
+ # MODEL LOADING
28
+ # ------------------------------------------------------------
29
+
30
+ def get_embedder() -> SentenceTransformer:
31
+ global _model
32
+ if _model is None:
33
+ _model = SentenceTransformer(EMBEDDING_MODEL_NAME)
34
+ return _model
35
+
36
+
37
+ # ------------------------------------------------------------
38
+ # EMBEDDING
39
+ # ------------------------------------------------------------
40
+
41
+ def embed_text(text: str, normalize: bool = True) -> np.ndarray:
42
+ """
43
+ Embed a single piece of text.
44
+ Returns a 1D numpy array (MPNet: 768-dim).
45
+ """
46
+ model = get_embedder()
47
+ emb = model.encode(
48
+ [text],
49
+ show_progress_bar=False,
50
+ normalize_embeddings=normalize,
51
+ )
52
+ return emb[0]
53
+
54
+
55
+ def embed_texts(texts: List[str], normalize: bool = True) -> np.ndarray:
56
+ """
57
+ Embed a list of strings -> (N, D) numpy array.
58
+ """
59
+ model = get_embedder()
60
+ return model.encode(
61
+ texts,
62
+ show_progress_bar=False,
63
+ normalize_embeddings=normalize,
64
+ )
65
+
66
+
67
+ # ------------------------------------------------------------
68
+ # INDEX LOADING
69
+ # ------------------------------------------------------------
70
+
71
+ def load_kb_index(path: str = "data/rag/index/kb_index.json") -> Dict[str, Any]:
72
+ """
73
+ Load the RAG knowledge base index JSON.
74
+
75
+ Expected format:
76
+ {
77
+ "version": int,
78
+ "model_name": str,
79
+ "records": [
80
+ {
81
+ "id": str,
82
+ "genus": str,
83
+ "species": str | null,
84
+ "level": "genus" | "species",
85
+ "chunk_id": int,
86
+ "source_file": str,
87
+ "text": str,
88
+ "embedding": [float, ...]
89
+ }
90
+ ]
91
+ }
92
+ """
93
+ if not os.path.exists(path):
94
+ raise FileNotFoundError(f"KB index not found at {path}")
95
+
96
+ with open(path, "r", encoding="utf-8") as f:
97
+ data = json.load(f)
98
+
99
+ index_model = data.get("model_name")
100
+ if index_model != EMBEDDING_MODEL_NAME:
101
+ raise ValueError(
102
+ f"KB index built with '{index_model}', "
103
+ f"but current embedder is '{EMBEDDING_MODEL_NAME}'. "
104
+ "Rebuild the index."
105
+ )
106
+
107
+ # Convert embeddings to numpy arrays
108
+ for rec in data.get("records", []):
109
+ if isinstance(rec.get("embedding"), list):
110
+ rec["embedding"] = np.array(rec["embedding"], dtype="float32")
111
+
112
+ return data
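
A small sketch of querying the loaded index with this embedder follows; it assumes the default kb_index.json exists and was built with the same MPNet model, and the query text is illustrative.

# Hedged sketch (assumes data/rag/index/kb_index.json exists and was built
# with the same all-mpnet-base-v2 embedder):
import numpy as np
from rag.rag_embedder import embed_text, load_kb_index

kb = load_kb_index()
query_vec = embed_text("Gram negative rod, oxidase positive, green pigment")

# Embeddings are stored L2-normalized, so a dot product is cosine similarity.
ranked = sorted(
    kb["records"],
    key=lambda rec: float(np.dot(query_vec, rec["embedding"])),
    reverse=True,
)
for rec in ranked[:3]:
    print(rec["genus"], rec.get("species"), rec["source_file"])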
rag/rag_generator.py ADDED
@@ -0,0 +1,446 @@
1
+ # rag/rag_generator.py
2
+ # ============================================================
3
+ # RAG generator using google/flan-t5-large (CPU-friendly)
4
+ #
5
+ # Goal (user-visible, structured, deterministic-first):
6
+ # - Show the user:
7
+ # KEY TRAITS:
8
+ # CONFLICTS:
9
+ # CONCLUSION:
10
+ # - KEY TRAITS and CONFLICTS are extracted deterministically from the
11
+ # shaped retriever context (preferred).
12
+ # - The LLM only writes the CONCLUSION (2–5 sentences) based on those
13
+ # extracted sections.
14
+ #
15
+ # Reliability:
16
+ # - flan-t5 sometimes echoes prompt instructions.
17
+ # - We keep the prompt extremely short and avoid imperative bullet rules.
18
+ # - We keep deterministic fallback logic if the LLM output is garbage/echo.
19
+ #
20
+ # Expected usage:
21
+ # ctx = retrieve_rag_context(..., parsed_fields=...)
22
+ # explanation = generate_genus_rag_explanation(
23
+ # phenotype_text=text,
24
+ # rag_context=ctx.get("llm_context_shaped") or ctx.get("llm_context"),
25
+ # genus=genus
26
+ # )
27
+ #
28
+ # Optional HF Space logs:
29
+ # export BACTAI_RAG_GEN_LOG_INPUT=1
30
+ # export BACTAI_RAG_GEN_LOG_OUTPUT=1
31
+ # ============================================================
32
+
33
+ from __future__ import annotations
34
+
35
+ import os
36
+ import re
37
+ import torch
38
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
39
+
40
+
41
+ # ------------------------------------------------------------
42
+ # MODEL CONFIG
43
+ # ------------------------------------------------------------
44
+
45
+ MODEL_NAME = "google/flan-t5-large"
46
+
47
+ _tokenizer: T5Tokenizer | None = None
48
+ _model: T5ForConditionalGeneration | None = None
49
+
50
+ # Keep small for CPU + to reduce prompt truncation weirdness
51
+ _MAX_INPUT_TOKENS = 768
52
+ _DEFAULT_MAX_NEW_TOKENS = 160
53
+
54
+ # Hard cap the context chars we feed to T5 (prevents the model focusing on junk)
55
+ _CONTEXT_CHAR_CAP = 2400
56
+
57
+
58
+ def _get_model() -> tuple[T5Tokenizer, T5ForConditionalGeneration]:
59
+ global _tokenizer, _model
60
+ if _tokenizer is None or _model is None:
61
+ _tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
62
+ _model = T5ForConditionalGeneration.from_pretrained(
63
+ MODEL_NAME,
64
+ device_map="auto",
65
+ torch_dtype=torch.float32,
66
+ )
67
+ return _tokenizer, _model
68
+
69
+
70
+ # ------------------------------------------------------------
71
+ # DEBUG LOGGING (HF Space logs)
72
+ # ------------------------------------------------------------
73
+
74
+ RAG_GEN_LOG_INPUT = os.getenv("BACTAI_RAG_GEN_LOG_INPUT", "0").strip() == "1"
75
+ RAG_GEN_LOG_OUTPUT = os.getenv("BACTAI_RAG_GEN_LOG_OUTPUT", "0").strip() == "1"
76
+
77
+
78
+ def _log_block(title: str, body: str) -> None:
79
+ print("=" * 80)
80
+ print(f"RAG GENERATOR DEBUG — {title}")
81
+ print("=" * 80)
82
+ print(body.strip() if body else "")
83
+ print()
84
+
85
+
86
+ # ------------------------------------------------------------
87
+ # PROMPT (LLM WRITES ONLY THE CONCLUSION)
88
+ # ------------------------------------------------------------
89
+
90
+ # Intentionally minimal. No "rules list", no bullets specification.
91
+ # The LLM sees ONLY extracted matches/conflicts and writes a short conclusion.
92
+ RAG_PROMPT = """summarize: Evaluate whether the phenotype fits the target genus using the provided matches and conflicts.
93
+
94
+ Target genus: {genus}
95
+
96
+ Key traits that match:
97
+ {matches}
98
+
99
+ Conflicts:
100
+ {conflicts}
101
+
102
+ Write a short conclusion (2–5 sentences) stating whether this is a strong, moderate, or tentative genus match, and briefly mention the most important matches and conflicts.
103
+ """
104
+
105
+
106
+ # ------------------------------------------------------------
107
+ # OUTPUT CLEANUP + ECHO DETECTION
108
+ # ------------------------------------------------------------
109
+
110
+ _BAD_SUBSTRINGS = (
111
+ "summarize:",
112
+ "target genus",
113
+ "key traits that match",
114
+ "write a short conclusion",
115
+ "conflicts:",
116
+ )
117
+
118
+ def _clean_generation(text: str) -> str:
119
+ s = (text or "").strip()
120
+ if not s:
121
+ return ""
122
+
123
+ # collapse excessive whitespace/newlines
124
+ s = re.sub(r"\s*\n+\s*", " ", s).strip()
125
+ s = re.sub(r"\s{2,}", " ", s).strip()
126
+
127
+ # guard runaway length
128
+ if len(s) > 900:
129
+ s = s[:900].rstrip() + "..."
130
+
131
+ return s
132
+
133
+
134
+ def _looks_like_echo_or_garbage(text: str) -> bool:
135
+ s = (text or "").strip()
136
+ if not s:
137
+ return True
138
+
139
+ # extremely short / non-sentence
140
+ if len(s) < 25:
141
+ return True
142
+
143
+ low = s.lower()
144
+ if any(bad in low for bad in _BAD_SUBSTRINGS):
145
+ return True
146
+
147
+ # Must look like actual prose
148
+ if "." not in s and "because" not in low and "match" not in low and "fits" not in low:
149
+ return True
150
+
151
+ return False
152
+
153
+
154
+ # ------------------------------------------------------------
155
+ # EXTRACT KEY TRAITS + CONFLICTS FROM SHAPED CONTEXT
156
+ # ------------------------------------------------------------
157
+
158
+ # Shaped context format (example):
159
+ # KEY MATCHES:
160
+ # - Trait: Value (matches reference: ...)
161
+ #
162
+ # CONFLICTS (observed vs CORE traits):
163
+ # - Trait: Value (conflicts reference: ...)
164
+ # or:
165
+ # CONFLICTS: Not specified.
166
+
167
+ _KEY_MATCHES_HEADER_RE = re.compile(r"^\s*KEY MATCHES\s*:\s*$", re.IGNORECASE)
168
+ _CONFLICTS_HEADER_RE = re.compile(r"^\s*CONFLICTS\b.*:\s*$", re.IGNORECASE)
169
+ _CONFLICTS_INLINE_NONE_RE = re.compile(r"^\s*CONFLICTS\s*:\s*not specified\.?\s*$", re.IGNORECASE)
170
+
171
+ _MATCH_LINE_RE = re.compile(
172
+ r"^\s*-\s*([^:]+)\s*:\s*(.+?)\s*\(matches reference:\s*(.+?)\)\s*$",
173
+ re.IGNORECASE,
174
+ )
175
+ _CONFLICT_LINE_RE = re.compile(
176
+ r"^\s*-\s*([^:]+)\s*:\s*(.+?)\s*\(conflicts reference:\s*(.+?)\)\s*$",
177
+ re.IGNORECASE,
178
+ )
179
+
180
+ # More permissive bullet capture (if shaper changes slightly)
181
+ _GENERIC_BULLET_RE = re.compile(r"^\s*-\s*(.+?)\s*$")
182
+
183
+
184
+ def _extract_key_traits_and_conflicts(shaped_ctx: str) -> tuple[list[str], list[str], bool]:
185
+ """
186
+ Extracts KEY MATCHES and CONFLICTS bullets from shaped retriever context.
187
+
188
+ Returns:
189
+ (key_traits, conflicts, found_structured_headers)
190
+
191
+ - key_traits items are short: "Trait: ObservedValue"
192
+ - conflicts items are short: "Trait: ObservedValue"
193
+ """
194
+ key_traits: list[str] = []
195
+ conflicts: list[str] = []
196
+
197
+ lines = (shaped_ctx or "").splitlines()
198
+ if not lines:
199
+ return key_traits, conflicts, False
200
+
201
+ in_matches = False
202
+ in_conflicts = False
203
+ saw_headers = False
204
+
205
+ for raw in lines:
206
+ line = raw.rstrip("\n")
207
+
208
+ # detect headers
209
+ if _KEY_MATCHES_HEADER_RE.match(line.strip()):
210
+ in_matches = True
211
+ in_conflicts = False
212
+ saw_headers = True
213
+ continue
214
+
215
+ if _CONFLICTS_INLINE_NONE_RE.match(line.strip()):
216
+ in_matches = False
217
+ in_conflicts = False
218
+ saw_headers = True
219
+ # explicit "no conflicts"
220
+ continue
221
+
222
+ if _CONFLICTS_HEADER_RE.match(line.strip()):
223
+ in_matches = False
224
+ in_conflicts = True
225
+ saw_headers = True
226
+ continue
227
+
228
+ # stop capture if another section begins (common shaper headings)
229
+ if saw_headers and (line.strip().endswith(":") and not line.strip().startswith("-")):
230
+ # If it's a new heading (and not one of our two), stop both
231
+ if not _KEY_MATCHES_HEADER_RE.match(line.strip()) and not _CONFLICTS_HEADER_RE.match(line.strip()):
232
+ in_matches = False
233
+ in_conflicts = False
234
+
235
+ # capture bullets under each section
236
+ if in_matches and line.strip().startswith("-"):
237
+ m = _MATCH_LINE_RE.match(line.strip())
238
+ if m:
239
+ trait = m.group(1).strip()
240
+ obs = m.group(2).strip()
241
+ key_traits.append(f"{trait}: {obs}")
242
+ else:
243
+ g = _GENERIC_BULLET_RE.match(line.strip())
244
+ if g:
245
+ key_traits.append(g.group(1).strip())
246
+ continue
247
+
248
+ if in_conflicts and line.strip().startswith("-"):
249
+ c = _CONFLICT_LINE_RE.match(line.strip())
250
+ if c:
251
+ trait = c.group(1).strip()
252
+ obs = c.group(2).strip()
253
+ conflicts.append(f"{trait}: {obs}")
254
+ else:
255
+ g = _GENERIC_BULLET_RE.match(line.strip())
256
+ if g:
257
+ conflicts.append(g.group(1).strip())
258
+ continue
259
+
260
+ return key_traits, conflicts, saw_headers
261
+
262
+
263
+ def _extract_matches_conflicts_legacy(shaped_ctx: str) -> tuple[list[str], list[str]]:
264
+ """
265
+ Legacy extraction based purely on (matches reference: ...) / (conflicts reference: ...)
266
+ anywhere in the text. Useful if headers are missing.
267
+ """
268
+ matches: list[str] = []
269
+ conflicts: list[str] = []
270
+
271
+ for raw in (shaped_ctx or "").splitlines():
272
+ line = raw.strip()
273
+ if not line.startswith("-"):
274
+ continue
275
+
276
+ m = _MATCH_LINE_RE.match(line)
277
+ if m:
278
+ trait = m.group(1).strip()
279
+ obs = m.group(2).strip()
280
+ matches.append(f"{trait}: {obs}")
281
+ continue
282
+
283
+ c = _CONFLICT_LINE_RE.match(line)
284
+ if c:
285
+ trait = c.group(1).strip()
286
+ obs = c.group(2).strip()
287
+ conflicts.append(f"{trait}: {obs}")
288
+ continue
289
+
290
+ return matches, conflicts
291
+
292
+
293
+ def _format_bullets(items: list[str], *, none_text: str) -> str:
294
+ if not items:
295
+ return none_text
296
+ return "\n".join(f"- {x}" for x in items)
297
+
298
+
299
+ # ------------------------------------------------------------
300
+ # DETERMINISTIC CONCLUSION FALLBACK
301
+ # ------------------------------------------------------------
302
+
303
+ def _deterministic_conclusion(genus: str, key_traits: list[str], conflicts: list[str]) -> str:
304
+ g = (genus or "").strip() or "Unknown"
305
+
306
+ m = key_traits[:4]
307
+ c = conflicts[:2]
308
+
309
+ if m and c:
310
+ return (
311
+ f"This is a probable match to {g} because it aligns with key traits such as "
312
+ f"{', '.join(m)}. However, there are conflicts ({', '.join(c)}), so treat this "
313
+ f"as a moderate/tentative genus-level fit and consider re-checking the conflicting tests."
314
+ )
315
+ if m and not c:
316
+ return (
317
+ f"This phenotype is consistent with {g} based on key matching traits such as "
318
+ f"{', '.join(m)}. No major conflicts were detected against the retrieved core genus traits, "
319
+ f"supporting a strong genus-level match."
320
+ )
321
+ if (not m) and c:
322
+ return (
323
+ f"This phenotype does not cleanly fit {g} because it conflicts with core traits "
324
+ f"({', '.join(c)}). Consider re-checking those tests or comparing against the next-ranked genera."
325
+ )
326
+
327
+ return (
328
+ f"Reference evidence was available for {g}, but no clear matches or conflicts could be extracted "
329
+ f"from the shaped context. Try increasing top_k genus chunks or ensuring parsed_fields are being "
330
+ f"passed into retrieve_rag_context so the shaper can compute KEY MATCHES and CONFLICTS."
331
+ )
332
+
333
+
334
+ def _trim_context(ctx: str) -> str:
335
+ s = (ctx or "").strip()
336
+ if not s:
337
+ return ""
338
+ if len(s) <= _CONTEXT_CHAR_CAP:
339
+ return s
340
+ return s[:_CONTEXT_CHAR_CAP].rstrip() + "\n... (truncated)"
341
+
342
+
343
+ # ------------------------------------------------------------
344
+ # PUBLIC API
345
+ # ------------------------------------------------------------
346
+
347
+ def generate_genus_rag_explanation(
348
+ phenotype_text: str,
349
+ rag_context: str,
350
+ genus: str,
351
+ max_new_tokens: int = _DEFAULT_MAX_NEW_TOKENS,
352
+ ) -> str:
353
+ """
354
+ Generates a structured RAG output intended for direct display:
355
+
356
+ KEY TRAITS:
357
+ - ...
358
+ CONFLICTS:
359
+ - ...
360
+ CONCLUSION:
361
+ ...
362
+
363
+ Notes:
364
+ - KEY TRAITS + CONFLICTS are extracted deterministically from the (shaped) context.
365
+ - The LLM writes only the CONCLUSION.
366
+ - If the LLM output is garbage/echo, we use a deterministic conclusion fallback.
367
+ """
368
+ tokenizer, model = _get_model()
369
+
370
+ genus_clean = (genus or "").strip() or "Unknown"
371
+ context = _trim_context(rag_context or "")
372
+
373
+ if not context:
374
+ return (
375
+ "KEY TRAITS:\n"
376
+ "- Not specified.\n\n"
377
+ "CONFLICTS:\n"
378
+ "- Not specified.\n\n"
379
+ "CONCLUSION:\n"
380
+ "No reference evidence was available to evaluate this genus against the observed phenotype."
381
+ )
382
+
383
+ # Prefer structured extraction (KEY MATCHES / CONFLICTS sections)
384
+ key_traits, conflicts, saw_headers = _extract_key_traits_and_conflicts(context)
385
+
386
+ # If the headers weren't found or extraction is empty, try legacy extraction
387
+ if (not saw_headers) or (not key_traits and not conflicts):
388
+ legacy_matches, legacy_conflicts = _extract_matches_conflicts_legacy(context)
389
+ if legacy_matches or legacy_conflicts:
390
+ key_traits = key_traits or legacy_matches
391
+ conflicts = conflicts or legacy_conflicts
392
+
393
+ key_traits_text = _format_bullets(key_traits, none_text="- Not specified.")
394
+ conflicts_text = _format_bullets(conflicts, none_text="- Not specified.")
395
+
396
+ # LLM: conclusion only
397
+ prompt = RAG_PROMPT.format(
398
+ genus=genus_clean,
399
+ matches=key_traits_text,
400
+ conflicts=conflicts_text,
401
+ )
402
+
403
+ if RAG_GEN_LOG_INPUT:
404
+ _log_block("PROMPT (CONCLUSION-ONLY)", prompt[:3000] + ("\n...(truncated)" if len(prompt) > 3000 else ""))
405
+
406
+ inputs = tokenizer(
407
+ prompt,
408
+ return_tensors="pt",
409
+ truncation=True,
410
+ max_length=_MAX_INPUT_TOKENS,
411
+ ).to(model.device)
412
+
413
+ output = model.generate(
414
+ **inputs,
415
+ max_new_tokens=max_new_tokens,
416
+ temperature=0.0,
417
+ num_beams=1,
418
+ do_sample=False,
419
+ repetition_penalty=1.2,
420
+ no_repeat_ngram_size=3,
421
+ )
422
+
423
+ decoded = tokenizer.decode(output[0], skip_special_tokens=True).strip()
424
+ cleaned = _clean_generation(decoded)
425
+
426
+ if RAG_GEN_LOG_OUTPUT:
427
+ _log_block("RAW OUTPUT (CONCLUSION)", decoded)
428
+ _log_block("CLEANED OUTPUT (CONCLUSION)", cleaned)
429
+
430
+ # If LLM output is junk, use deterministic conclusion
431
+ if _looks_like_echo_or_garbage(cleaned):
432
+ cleaned = _deterministic_conclusion(genus_clean, key_traits, conflicts)
433
+ if RAG_GEN_LOG_OUTPUT:
434
+ _log_block("FALLBACK CONCLUSION (DETERMINISTIC)", cleaned)
435
+
436
+ # Final user-visible structured output
437
+ final = (
438
+ "KEY TRAITS:\n"
439
+ f"{key_traits_text}\n\n"
440
+ "CONFLICTS:\n"
441
+ f"{conflicts_text}\n\n"
442
+ "CONCLUSION:\n"
443
+ f"{cleaned}"
444
+ )
445
+
446
+ return final
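
An end-to-end sketch mirroring the "Expected usage" note in the file header; the phenotype text, parsed fields, and genus are illustrative.

# Hedged end-to-end sketch (mirrors the "Expected usage" note above; values illustrative):
from rag.rag_retriever import retrieve_rag_context
from rag.rag_generator import generate_genus_rag_explanation

phenotype_text = "Gram positive cocci in clusters, catalase positive, coagulase positive"
parsed_fields = {"Gram Stain": "Positive", "Shape": "Cocci", "Catalase": "Positive"}

ctx = retrieve_rag_context(phenotype_text, "Staphylococcus", parsed_fields=parsed_fields)
explanation = generate_genus_rag_explanation(
    phenotype_text=phenotype_text,
    rag_context=ctx.get("llm_context_shaped") or ctx.get("llm_context"),
    genus="Staphylococcus",
)
print(explanation)  # "KEY TRAITS:" / "CONFLICTS:" / "CONCLUSION:" block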
rag/rag_retriever.py ADDED
@@ -0,0 +1,509 @@
1
+ # rag/rag_retriever.py
2
+ # ============================================================
3
+ # RAG retriever (Stage 2 – microbiology-aware)
4
+ #
5
+ # Key change (GENUS-FIRST):
6
+ # - The generator must NOT see multiple species dumps.
7
+ # - We retrieve GENUS-level records only for llm_context/llm_context_shaped.
8
+ # - Species is handled separately (deterministic species_scorer), not via LLM context.
9
+ #
10
+ # Improvements retained:
11
+ # - Source-type weighting (but genus-only for generator)
12
+ # - Genus-aware query expansion
13
+ # - Diversity enforcement (avoid duplicate sources)
14
+ # - Explicit ranking & score annotations for generator (DEBUG ONLY)
15
+ # - OPTIONAL: species evidence scoring (deterministic)
16
+ # - NEW: Context shaper (deterministic) -> resolves conflicts + emits genus-ready summary
17
+ #
18
+ # IMPORTANT:
19
+ # - We return THREE contexts:
20
+ # 1) llm_context -> GENUS-only raw text (SAFE but unshaped)
21
+ # 2) llm_context_shaped -> shaped, conflict-aware, generator-friendly
22
+ # 3) debug_context -> includes RANK/SCORE/WEIGHTS (UI/logging only)
23
+ # ============================================================
24
+
25
+ from __future__ import annotations
26
+
27
+ from typing import List, Dict, Any, Optional, Tuple
28
+ import re
29
+ import numpy as np
30
+
31
+ from rag.rag_embedder import embed_text, load_kb_index
32
+
33
+ # deterministic species evidence scorer (separate from generator context)
34
+ try:
35
+ from rag.species_scorer import score_species_for_genus
36
+ HAS_SPECIES_SCORER = True
37
+ except Exception:
38
+ score_species_for_genus = None # type: ignore
39
+ HAS_SPECIES_SCORER = False
40
+
41
+
42
+ # ------------------------------------------------------------
43
+ # Configuration
44
+ # ------------------------------------------------------------
45
+
46
+ # NOTE: We keep these for debug display + potential fallback modes.
47
+ SOURCE_TYPE_WEIGHTS = {
48
+ "species": 1.15,
49
+ "genus": 1.00,
50
+ "table": 1.10,
51
+ "note": 0.85,
52
+ }
53
+
54
+ MAX_CHUNKS_PER_SOURCE = 1
55
+
56
+ # Context shaping caps (keeps prompt within LLM limits)
57
+ SHAPER_MAX_CORE = 14
58
+ SHAPER_MAX_VARIABLE = 12
59
+ SHAPER_MAX_MATCHES = 14
60
+ SHAPER_MAX_CONFLICTS = 12
61
+ SHAPER_MAX_TOTAL_CHARS = 9000 # final guardrail
62
+
63
+
64
+ # ------------------------------------------------------------
65
+ # Similarity helper
66
+ # ------------------------------------------------------------
67
+
68
+ def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
69
+ """
70
+ Cosine similarity for normalized embeddings.
71
+ Assumes both vectors are already L2-normalized.
72
+ """
73
+ return float(np.dot(a, b))
74
+
75
+
76
+ # ------------------------------------------------------------
77
+ # Context Shaper (deterministic)
78
+ # ------------------------------------------------------------
79
+
80
+ _TRAIT_LINE_RE = re.compile(
81
+ r"^\s*([A-Za-z0-9][A-Za-z0-9 \/\-\(\)\[\]%>=<\+\.]*?)\s*:\s*(.+?)\s*$"
82
+ )
83
+
84
+ # Headers / junk lines we don't want treated as traits
85
+ _SHAPER_SKIP_PREFIXES = (
86
+ "expected fields for species",
87
+ "expected fields for genus",
88
+ "reference context",
89
+ "genus evidence primer",
90
+ )
91
+
92
+ def _norm_val(v: str) -> str:
93
+ s = (v or "").strip()
94
+ if not s:
95
+ return ""
96
+ s = re.sub(r"\s+", " ", s)
97
+ return s
98
+
99
+ def _canon_bool(v: str) -> str:
100
+ """
101
+ Canonicalize common boolean-ish microbiology values.
102
+ Conservative: no inference.
103
+ """
104
+ s = _norm_val(v).lower()
105
+ if s in {"pos", "positive", "+", "reactive"}:
106
+ return "Positive"
107
+ if s in {"neg", "negative", "-", "nonreactive", "non-reactive"}:
108
+ return "Negative"
109
+ if s in {"none"}:
110
+ return "None"
111
+ if s in {"unknown", "not specified", "n/a", "na"}:
112
+ return "Unknown"
113
+ if s in {"variable"}:
114
+ return "Variable"
115
+ return _norm_val(v)
116
+
117
+ def _canon_trait_name(name: str) -> str:
118
+ s = _norm_val(name)
119
+ s_low = s.lower()
120
+ if s_low == "ornitihine decarboxylase":
121
+ return "Ornithine Decarboxylase"
122
+ return s
123
+
124
+ def _extract_traits_from_text_block(text: str) -> List[Tuple[str, str]]:
125
+ """
126
+ Extract (trait, value) pairs from lines like:
127
+ Trait Name: Value
128
+ """
129
+ pairs: List[Tuple[str, str]] = []
130
+ for raw_line in (text or "").splitlines():
131
+ line = raw_line.strip()
132
+ if not line:
133
+ continue
134
+ low = line.lower()
135
+ if any(low.startswith(p) for p in _SHAPER_SKIP_PREFIXES):
136
+ continue
137
+ m = _TRAIT_LINE_RE.match(line)
138
+ if not m:
139
+ continue
140
+ k = _canon_trait_name(m.group(1))
141
+ v = _canon_bool(m.group(2))
142
+ if not k or not v:
143
+ continue
144
+ pairs.append((k, v))
145
+ return pairs
146
+
147
+ def _compare_vals(observed: str, reference: str) -> Optional[bool]:
148
+ """
149
+ Returns:
150
+ True -> match
151
+ False -> conflict
152
+ None -> cannot compare (unknown/variable/empty)
153
+ """
154
+ o = _canon_bool(observed)
155
+ r = _canon_bool(reference)
156
+
157
+ if not o or o == "Unknown":
158
+ return None
159
+ if not r or r in {"Unknown", "Variable"}:
160
+ return None
161
+
162
+ if o == r:
163
+ return True
164
+
165
+ # Safe equivalences (very conservative)
166
+ eq = {
167
+ ("None", "Negative"),
168
+ ("Negative", "None"),
169
+ }
170
+ if (o, r) in eq:
171
+ return True
172
+
173
+ return False
174
+
175
+ def shape_genus_context(
176
+ *,
177
+ target_genus: str,
178
+ selected_chunks: List[Dict[str, Any]],
179
+ parsed_fields: Optional[Dict[str, str]] = None,
180
+ ) -> str:
181
+ """
182
+ Deterministic, GENUS-focused context shaper.
183
+
184
+ It:
185
+ - aggregates trait lines across retrieved GENUS chunks
186
+ - identifies CORE traits (single consistent value across chunks)
187
+ - identifies VARIABLE traits (multiple values across chunks)
188
+ - if parsed_fields provided, derives:
189
+ - phenotype-supported matches vs CORE traits
190
+ - phenotype conflicts vs CORE traits
191
+ - outputs a compact, reasoning-friendly block for the generator
192
+ """
193
+ genus = (target_genus or "").strip() or "Unknown"
194
+
195
+ trait_values: Dict[str, List[str]] = {}
196
+
197
+ for rec in selected_chunks or []:
198
+ txt = (rec.get("text") or "").strip()
199
+ if not txt:
200
+ continue
201
+ for k, v in _extract_traits_from_text_block(txt):
202
+ trait_values.setdefault(k, []).append(v)
203
+
204
+ # Reduce to unique canonical values
205
+ trait_uniques: Dict[str, List[str]] = {}
206
+ for k, vals in trait_values.items():
207
+ uniq: List[str] = []
208
+ for v in vals:
209
+ vv = _canon_bool(v)
210
+ if not vv:
211
+ continue
212
+ if vv not in uniq:
213
+ uniq.append(vv)
214
+ if uniq:
215
+ trait_uniques[k] = uniq
216
+
217
+ core_traits: List[Tuple[str, str]] = []
218
+ variable_traits: List[Tuple[str, str]] = []
219
+
220
+ for k, uniq in trait_uniques.items():
221
+ if len(uniq) == 1:
222
+ core_traits.append((k, uniq[0]))
223
+ else:
224
+ variable_traits.append((k, " / ".join(uniq)))
225
+
226
+ PRIORITY = {
227
+ "Gram Stain": 1,
228
+ "Shape": 2,
229
+ "Motility": 3,
230
+ "Motility Type": 4,
231
+ "Oxidase": 5,
232
+ "Catalase": 6,
233
+ "Oxygen Requirement": 7,
234
+ "Lactose Fermentation": 8,
235
+ "Glucose Fermentation": 9,
236
+ "H2S": 10,
237
+ "Indole": 11,
238
+ "Urease": 12,
239
+ "Citrate": 13,
240
+ "ONPG": 14,
241
+ "NaCl Tolerant (>=6%)": 15,
242
+ "Media Grown On": 16,
243
+ "Colony Morphology": 17,
244
+ }
245
+
246
+ def _sort_key(item: Tuple[str, str]) -> Tuple[int, str]:
247
+ return (PRIORITY.get(item[0], 999), item[0].lower())
248
+
249
+ core_traits.sort(key=_sort_key)
250
+ variable_traits.sort(key=_sort_key)
251
+
252
+ core_traits = core_traits[:SHAPER_MAX_CORE]
253
+ variable_traits = variable_traits[:SHAPER_MAX_VARIABLE]
254
+
255
+ matches: List[str] = []
256
+ conflicts: List[str] = []
257
+
258
+ if parsed_fields:
259
+ for k, ref_v in core_traits:
260
+ obs_v = parsed_fields.get(k)
261
+ if obs_v is None:
262
+ continue
263
+ cmp = _compare_vals(obs_v, ref_v)
264
+ if cmp is True:
265
+ matches.append(f"- {k}: {_canon_bool(obs_v)} (matches reference: {ref_v})")
266
+ elif cmp is False:
267
+ conflicts.append(f"- {k}: {_canon_bool(obs_v)} (conflicts reference: {ref_v})")
268
+
269
+ matches = matches[:SHAPER_MAX_MATCHES]
270
+ conflicts = conflicts[:SHAPER_MAX_CONFLICTS]
271
+
272
+ lines: List[str] = []
273
+ lines.append(f"GENUS SUMMARY (reference-driven): {genus}")
274
+
275
+ if core_traits:
276
+ lines.append("\nCORE GENUS TRAITS (consistent across retrieved genus references):")
277
+ for k, v in core_traits:
278
+ lines.append(f"- {k}: {v}")
279
+ else:
280
+ lines.append("\nCORE GENUS TRAITS: Not available from retrieved context.")
281
+
282
+ if variable_traits:
283
+ lines.append("\nTRAITS VARIABLE ACROSS RETRIEVED GENUS REFERENCES (do not treat as contradictions):")
284
+ for k, v in variable_traits:
285
+ lines.append(f"- {k}: Variable ({v})")
286
+
287
+ if parsed_fields:
288
+ lines.append("\nPHENOTYPE SUPPORT (observed vs CORE traits):")
289
+ if matches:
290
+ lines.append("KEY MATCHES:")
291
+ lines.extend(matches)
292
+ else:
293
+ lines.append("KEY MATCHES: Not specified.")
294
+
295
+ if conflicts:
296
+ lines.append("\nCONFLICTS (observed vs CORE traits):")
297
+ lines.extend(conflicts)
298
+ else:
299
+ lines.append("\nCONFLICTS: Not specified.")
300
+
301
+ shaped = "\n".join(lines).strip()
302
+
303
+ if len(shaped) > SHAPER_MAX_TOTAL_CHARS:
304
+ shaped = shaped[:SHAPER_MAX_TOTAL_CHARS].rstrip() + "\n... (truncated)"
305
+
306
+ return shaped
307
+
308
+
309
+ # ------------------------------------------------------------
310
+ # Public API
311
+ # ------------------------------------------------------------
312
+
313
+ def retrieve_rag_context(
314
+ phenotype_text: str,
315
+ target_genus: str,
316
+ top_k: int = 5,
317
+ kb_path: str = "data/rag/index/kb_index.json",
318
+ parsed_fields: Optional[Dict[str, str]] = None,
319
+ species_top_n: int = 5,
320
+ allow_species_fallback: bool = False,
321
+ ) -> Dict[str, Any]:
322
+ """
323
+ Retrieve the most relevant RAG chunks for a phenotype + genus.
324
+
325
+ GENUS-FIRST behavior:
326
+ - For LLM generator contexts, we retrieve ONLY genus-level records (level == "genus").
327
+ - Species is handled separately via deterministic species_scorer.
328
+
329
+ Optional:
330
+ parsed_fields -> enables species evidence scoring + context shaping matches/conflicts.
331
+
332
+ Returns:
333
+ {
334
+ "genus": target_genus,
335
+ "chunks": [...], # ranked chunk metadata (GENUS chunks unless fallback enabled)
336
+ "llm_context": "....", # GENUS raw text (no scores)
337
+ "llm_context_shaped": "....", # deterministic genus-friendly summary
338
+ "debug_context": "....", # annotated with rank/score/weights
339
+ "species_evidence": { ... } # optional deterministic species scoring
340
+ }
341
+ """
342
+
343
+ kb = load_kb_index(kb_path)
344
+ records = kb.get("records", [])
345
+
346
+ if not records:
347
+ return {
348
+ "genus": target_genus,
349
+ "chunks": [],
350
+ "llm_context": "",
351
+ "llm_context_shaped": "",
352
+ "debug_context": "",
353
+ "species_evidence": {"genus": target_genus, "ranked": []},
354
+ }
355
+
356
+ query_text = (phenotype_text or "").strip()
357
+ if target_genus:
358
+ query_text = f"{query_text}\nTarget genus: {target_genus}"
359
+
360
+ q_emb = embed_text(query_text, normalize=True)
361
+ target_genus_lc = (target_genus or "").strip().lower()
362
+
363
+ scored_records: List[Dict[str, Any]] = []
364
+
365
+ # --------------------------------------------------------
366
+ # Primary pass: STRICT genus-filtered + GENUS-LEVEL only
367
+ # --------------------------------------------------------
368
+ for rec in records:
369
+ rec_genus = (rec.get("genus") or "").strip().lower()
370
+ if target_genus_lc and rec_genus != target_genus_lc:
371
+ continue
372
+
373
+ level = (rec.get("level") or "").strip().lower()
374
+ if level != "genus":
375
+ continue # GENUS-ONLY for generator context
376
+
377
+ emb = rec.get("embedding")
378
+ if emb is None:
379
+ continue
380
+
381
+ base_score = _cosine_similarity(q_emb, emb)
382
+ weight = SOURCE_TYPE_WEIGHTS.get(level, 1.0)
383
+ score = base_score * weight
384
+
385
+ scored_records.append(
386
+ {
387
+ "id": rec.get("id"),
388
+ "genus": rec.get("genus"),
389
+ "species": rec.get("species"),
390
+ "source_type": level,
391
+ "path": rec.get("source_file"),
392
+ "text": rec.get("text"),
393
+ "score": float(score),
394
+ "base_score": float(base_score),
395
+ "type_weight": float(weight),
396
+ "section": rec.get("section"),
397
+ "role": rec.get("role"),
398
+ "chunk_id": rec.get("chunk_id"),
399
+ }
400
+ )
401
+
402
+ # --------------------------------------------------------
403
+ # Fallback modes
404
+ # --------------------------------------------------------
405
+ if not scored_records and allow_species_fallback:
406
+ # Emergency fallback: allow any level if no genus chunks exist.
407
+ # This keeps your app functioning, but can reintroduce noise.
408
+ for rec in records:
409
+ rec_genus = (rec.get("genus") or "").strip().lower()
410
+ if target_genus_lc and rec_genus != target_genus_lc:
411
+ continue
412
+
413
+ emb = rec.get("embedding")
414
+ if emb is None:
415
+ continue
416
+
417
+ level = (rec.get("level") or "").strip().lower()
418
+ base_score = _cosine_similarity(q_emb, emb)
419
+ weight = SOURCE_TYPE_WEIGHTS.get(level, 1.0)
420
+ score = base_score * weight
421
+
422
+ scored_records.append(
423
+ {
424
+ "id": rec.get("id"),
425
+ "genus": rec.get("genus"),
426
+ "species": rec.get("species"),
427
+ "source_type": level,
428
+ "path": rec.get("source_file"),
429
+ "text": rec.get("text"),
430
+ "score": float(score),
431
+ "base_score": float(base_score),
432
+ "type_weight": float(weight),
433
+ "section": rec.get("section"),
434
+ "role": rec.get("role"),
435
+ "chunk_id": rec.get("chunk_id"),
436
+ }
437
+ )
438
+
439
+ # Sort by score
440
+ scored_records.sort(key=lambda r: r["score"], reverse=True)
441
+
442
+ # Diversity enforcement
443
+ selected: List[Dict[str, Any]] = []
444
+ source_counts: Dict[str, int] = {}
445
+
446
+ for rec in scored_records:
447
+ src = rec.get("path") or ""
448
+ count = source_counts.get(src, 0)
449
+ if count >= MAX_CHUNKS_PER_SOURCE:
450
+ continue
451
+ selected.append(rec)
452
+ source_counts[src] = count + 1
453
+ if len(selected) >= top_k:
454
+ break
455
+
456
+ # Build contexts
457
+ llm_ctx_parts: List[str] = []
458
+ debug_ctx_parts: List[str] = []
459
+
460
+ for idx, rec in enumerate(selected, start=1):
461
+ txt = (rec.get("text") or "").strip()
462
+ if txt:
463
+ llm_ctx_parts.append(txt)
464
+
465
+ label = rec.get("genus") or "Unknown genus"
466
+ if rec.get("species"):
467
+ label = f"{label} {rec['species']}"
468
+
469
+ debug_ctx_parts.append(
470
+ f"[RANK {idx} | SCORE {rec['score']:.3f} | BASE {rec['base_score']:.3f} | "
471
+ f"W {rec['type_weight']:.2f} | {label} — {rec.get('source_type')}]"
472
+ + (
473
+ f" [section={rec.get('section')} role={rec.get('role')}]"
474
+ if rec.get("section") or rec.get("role")
475
+ else ""
476
+ )
477
+ + "\n"
478
+ + (txt or "")
479
+ )
480
+
481
+ llm_context = "\n\n".join(llm_ctx_parts).strip()
482
+ debug_context = "\n\n".join(debug_ctx_parts).strip()
483
+
484
+ llm_context_shaped = shape_genus_context(
485
+ target_genus=target_genus,
486
+ selected_chunks=selected,
487
+ parsed_fields=parsed_fields,
488
+ )
489
+
490
+ # OPTIONAL: deterministic species evidence scoring
491
+ species_evidence = {"genus": target_genus, "ranked": []}
492
+ if parsed_fields and HAS_SPECIES_SCORER and score_species_for_genus is not None:
493
+ try:
494
+ species_evidence = score_species_for_genus(
495
+ target_genus=target_genus,
496
+ parsed_fields=parsed_fields,
497
+ top_n=species_top_n,
498
+ )
499
+ except Exception:
500
+ species_evidence = {"genus": target_genus, "ranked": []}
501
+
502
+ return {
503
+ "genus": target_genus,
504
+ "chunks": selected,
505
+ "llm_context": llm_context,
506
+ "llm_context_shaped": llm_context_shaped,
507
+ "debug_context": debug_context,
508
+ "species_evidence": species_evidence,
509
+ }
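
A brief sketch of consuming the retriever's return value, focusing on the keys documented in the docstring above; the query and parsed fields are illustrative.

# Hedged sketch of consuming the documented return keys (illustrative inputs):
from rag.rag_retriever import retrieve_rag_context

ctx = retrieve_rag_context(
    phenotype_text="Gram negative rod, lactose fermenter, indole positive",
    target_genus="Escherichia",
    top_k=5,
    parsed_fields={"Gram Stain": "Negative", "Indole": "Positive"},
)

print(ctx["llm_context_shaped"])  # conflict-aware summary fed to the generator
print(ctx["debug_context"])       # RANK/SCORE-annotated chunks (UI/logging only)
for entry in ctx["species_evidence"].get("ranked", []):
    print(entry)                  # deterministic species evidence, when the scorer is available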
rag/species_scorer.py ADDED
@@ -0,0 +1,314 @@
1
+ # rag/species_scorer.py
2
+ # ============================================================
3
+ # Species evidence scorer (deterministic, explainable)
4
+ #
5
+ # Given:
6
+ # - target_genus
7
+ # - parsed_fields (from fusion)
8
+ # It loads species JSON files under:
9
+ # data/rag/knowledge_base/<Genus>/*.json (excluding genus.json)
10
+ #
11
+ # And returns:
12
+ # - ranked species list with scores
13
+ # - explicit matches / conflicts
14
+ # - marker hits (importance-weighted)
15
+ #
16
+ # Notes:
17
+ # - This is NOT an LLM. No speculation.
18
+ # - Handles list-like fields (Media / Colony Morphology) as overlap scores.
19
+ # - Handles P/N/V/Unknown fields.
20
+ # ============================================================
21
+
22
+ from __future__ import annotations
23
+
24
+ import json
25
+ import os
26
+ import re
27
+ from typing import Any, Dict, List, Tuple, Optional
28
+
29
+
30
+ KB_ROOT = os.path.join("data", "rag", "knowledge_base")
31
+
32
+ UNKNOWN = "Unknown"
33
+
34
+ LIST_FIELDS = {
35
+ "Media Grown On",
36
+ "Colony Morphology",
37
+ }
38
+
39
+ # Importance → weight
40
+ MARKER_WEIGHT = {
41
+ "high": 3.0,
42
+ "medium": 2.0,
43
+ "low": 1.5,
44
+ }
45
+
46
+ # Base scoring weights
47
+ FIELD_MATCH_WEIGHT = 1.0
48
+ FIELD_CONFLICT_PENALTY = 1.2 # conflicts hurt slightly more than matches help
49
+ VARIABLE_MATCH_BONUS = 0.2 # weak support if expected is Variable
50
+
51
+
52
+ def _norm_str(v: Any) -> str:
53
+ if v is None:
54
+ return ""
55
+ return str(v).strip()
56
+
57
+
58
+ def _norm_val(v: Any) -> str:
59
+ s = _norm_str(v)
60
+ return s if s else UNKNOWN
61
+
62
+
63
+ def _split_semicolon(s: str) -> List[str]:
64
+ parts = [p.strip() for p in re.split(r"[;,\n]+", s or "") if p.strip()]
65
+ # normalize case lightly for matching
66
+ return [p.lower() for p in parts]
67
+
68
+
69
+ def _as_list_lower(v: Any) -> List[str]:
70
+ if v is None:
71
+ return []
72
+ if isinstance(v, list):
73
+ return [str(x).strip().lower() for x in v if str(x).strip()]
74
+ # string fallback
75
+ return _split_semicolon(str(v))
76
+
77
+
78
+ def _overlap_score(expected_list: List[str], observed_list: List[str]) -> float:
79
+ """
80
+ Jaccard-like overlap, but anchored to expected:
81
+ score = (# of expected items found) / (# expected)
82
+ """
83
+ if not expected_list:
84
+ return 0.0
85
+ if not observed_list:
86
+ return 0.0
87
+ exp = set(expected_list)
88
+ obs = set(observed_list)
89
+ hit = len(exp.intersection(obs))
90
+ return hit / max(1, len(exp))
91
+
92
+
93
+ def _load_species_docs_for_genus(target_genus: str) -> List[Dict[str, Any]]:
94
+ genus = (target_genus or "").strip()
95
+ if not genus:
96
+ return []
97
+
98
+ genus_dir = os.path.join(KB_ROOT, genus)
99
+ if not os.path.isdir(genus_dir):
100
+ return []
101
+
102
+ docs: List[Dict[str, Any]] = []
103
+ for fname in sorted(os.listdir(genus_dir)):
104
+ if not fname.lower().endswith(".json"):
105
+ continue
106
+ if fname == "genus.json":
107
+ continue
108
+
109
+ path = os.path.join(genus_dir, fname)
110
+ try:
111
+ with open(path, "r", encoding="utf-8") as f:
112
+ doc = json.load(f)
113
+ if isinstance(doc, dict) and doc.get("level") == "species":
114
+ doc["_source_path"] = os.path.relpath(path)
115
+ docs.append(doc)
116
+ except Exception:
117
+ continue
118
+
119
+ return docs
120
+
121
+
122
+ def _score_expected_fields(
123
+ expected_fields: Dict[str, Any],
124
+ parsed_fields: Dict[str, str],
125
+ ) -> Tuple[float, float, List[str], List[str]]:
126
+ """
127
+ Returns:
128
+ (score, possible, matches, conflicts)
129
+ """
130
+ score = 0.0
131
+ possible = 0.0
132
+ matches: List[str] = []
133
+ conflicts: List[str] = []
134
+
135
+ for field, expected in (expected_fields or {}).items():
136
+ exp_norm = expected
137
+ obs_norm = parsed_fields.get(field, UNKNOWN)
138
+
139
+ # Skip unknown observed
140
+ if obs_norm == UNKNOWN:
141
+ continue
142
+
143
+ # List fields: overlap
144
+ if field in LIST_FIELDS:
145
+ exp_list = _as_list_lower(exp_norm)
146
+ obs_list = _as_list_lower(obs_norm)
147
+ if not exp_list:
148
+ continue
149
+
150
+ possible += FIELD_MATCH_WEIGHT
151
+ ov = _overlap_score(exp_list, obs_list)
152
+
153
+ # thresholding: any overlap = support; none = conflict
154
+ if ov > 0:
155
+ score += FIELD_MATCH_WEIGHT * ov
156
+ matches.append(f"{field}: overlap {ov:.2f}")
157
+ else:
158
+ score -= FIELD_CONFLICT_PENALTY
159
+ conflicts.append(f"{field}: expected {expected}, got {obs_norm}")
160
+ continue
161
+
162
+ exp_val = _norm_val(exp_norm)
163
+ obs_val = _norm_val(obs_norm)
164
+
165
+ # If expected is Unknown, skip
166
+ if exp_val == UNKNOWN:
167
+ continue
168
+
169
+ # If expected is Variable, weakly supportive if observed is known
170
+ if exp_val == "Variable":
171
+ possible += VARIABLE_MATCH_BONUS
172
+ score += VARIABLE_MATCH_BONUS
173
+ matches.append(f"{field}: expected Variable (observed {obs_val})")
174
+ continue
175
+
176
+ # Normal exact match
177
+ possible += FIELD_MATCH_WEIGHT
178
+ if obs_val == exp_val:
179
+ score += FIELD_MATCH_WEIGHT
180
+ matches.append(f"{field}: {obs_val}")
181
+ else:
182
+ score -= FIELD_CONFLICT_PENALTY
183
+ conflicts.append(f"{field}: expected {exp_val}, got {obs_val}")
184
+
185
+ return score, possible, matches, conflicts
186
+
187
+
188
+ def _score_species_markers(
189
+ markers: List[Dict[str, Any]],
190
+ parsed_fields: Dict[str, str],
191
+ ) -> Tuple[float, float, List[str], List[str]]:
192
+ """
193
+ Weighted marker hits. Markers are higher-signal than generic expected fields.
194
+
195
+ Returns:
196
+ (score, possible, marker_hits, marker_misses)
197
+ """
198
+ score = 0.0
199
+ possible = 0.0
200
+ hits: List[str] = []
201
+ misses: List[str] = []
202
+
203
+ for m in markers or []:
204
+ field = _norm_str(m.get("field"))
205
+ val = _norm_val(m.get("value"))
206
+ importance = _norm_str(m.get("importance")).lower() or "medium"
207
+ w = MARKER_WEIGHT.get(importance, 2.0)
208
+
209
+ if not field or val == UNKNOWN:
210
+ continue
211
+
212
+ obs = _norm_val(parsed_fields.get(field, UNKNOWN))
213
+ if obs == UNKNOWN:
214
+ continue
215
+
216
+ possible += w
217
+ if obs == val:
218
+ score += w
219
+ hits.append(f"{field}: {obs} ({importance})")
220
+ else:
221
+ score -= w * 1.1 # marker conflicts hurt more
222
+ misses.append(f"{field}: expected {val}, got {obs} ({importance})")
223
+
224
+ return score, possible, hits, misses
225
+
226
+
227
+ def _to_confidence(raw_score: float, possible: float) -> float:
228
+ """
229
+ Convert raw score into 0..1 confidence.
230
+
231
+ We use a bounded transform:
232
+ - normalize by possible
233
+ - clamp into [0,1]
234
+ """
235
+ if possible <= 0:
236
+ return 0.0
237
+
238
+ # raw_score can be negative; convert to a 0..1 scale
239
+ # normalized_score around 0 means mixed evidence
240
+ normalized = raw_score / possible # roughly -something .. +1
241
+ conf = (normalized + 1.0) / 2.0 # map [-1, +1] -> [0,1] (approx)
242
+ if conf < 0:
243
+ conf = 0.0
244
+ if conf > 1:
245
+ conf = 1.0
246
+ return float(conf)
247
+
248
+
249
+ def score_species_for_genus(
250
+ target_genus: str,
251
+ parsed_fields: Dict[str, str],
252
+ top_n: int = 5,
253
+ ) -> Dict[str, Any]:
254
+ """
255
+ Main entrypoint.
256
+
257
+ Returns:
258
+ {
259
+ "genus": "...",
260
+ "ranked": [
261
+ {
262
+ "species": "cloacae",
263
+ "full_name": "Enterobacter cloacae",
264
+ "score": 0.87,
265
+ "raw_score": ...,
266
+ "possible": ...,
267
+ "matches": [...],
268
+ "conflicts": [...],
269
+ "marker_hits": [...],
270
+ "marker_conflicts": [...],
271
+ "source_file": "data/rag/knowledge_base/Enterobacter/cloacae.json"
272
+ }, ...
273
+ ]
274
+ }
275
+ """
276
+ docs = _load_species_docs_for_genus(target_genus)
277
+ if not docs:
278
+ return {"genus": target_genus, "ranked": []}
279
+
280
+ ranked: List[Dict[str, Any]] = []
281
+
282
+ for doc in docs:
283
+ genus = _norm_str(doc.get("genus") or target_genus)
284
+ species = _norm_str(doc.get("species"))
285
+ full_name = f"{genus} {species}".strip()
286
+
287
+ expected_fields = doc.get("expected_fields") or {}
288
+ markers = doc.get("species_markers") or []
289
+
290
+ s1, p1, matches, conflicts = _score_expected_fields(expected_fields, parsed_fields)
291
+ s2, p2, marker_hits, marker_conflicts = _score_species_markers(markers, parsed_fields)
292
+
293
+ raw_score = s1 + s2
294
+ possible = p1 + p2
295
+
296
+ conf = _to_confidence(raw_score, possible)
297
+
298
+ ranked.append(
299
+ {
300
+ "species": species or os.path.splitext(os.path.basename(doc.get("_source_path", "")))[0],
301
+ "full_name": full_name,
302
+ "score": conf,
303
+ "raw_score": raw_score,
304
+ "possible": possible,
305
+ "matches": matches,
306
+ "conflicts": conflicts,
307
+ "marker_hits": marker_hits,
308
+ "marker_conflicts": marker_conflicts,
309
+ "source_file": doc.get("_source_path", ""),
310
+ }
311
+ )
312
+
313
+ ranked.sort(key=lambda x: x["score"], reverse=True)
314
+ return {"genus": target_genus, "ranked": ranked[: max(1, int(top_n))]}
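A minimal usage sketch of the scorer above. The import path (scoring/species_scorer) and the Enterobacter values are assumptions for illustration; in the pipeline, parsed_fields comes from the fusion parser.

from scoring.species_scorer import score_species_for_genus  # module path assumed

# Hypothetical parsed fields (normally produced by the fusion parser)
parsed_fields = {
    "Gram Stain": "Negative",
    "Indole": "Negative",
    "Citrate": "Positive",
    "Media Grown On": "MacConkey agar; blood agar",
}

result = score_species_for_genus("Enterobacter", parsed_fields, top_n=3)
for row in result["ranked"]:
    print(row["full_name"], round(row["score"], 2), row["conflicts"])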
scoring/diagnostic_anchors.py ADDED
@@ -0,0 +1,97 @@
1
+ # scoring/diagnostic_anchors.py
2
+ # ============================================================
3
+ # Diagnostic anchor overrides:
4
+ # - If the free-text description clearly contains certain
5
+ # pathognomonic phrases, boost the corresponding genus
6
+ # in the unified ranking.
7
+ # ============================================================
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import List, Dict, Any
12
+
13
+ # Simple v1 — can expand over time
14
+ DIAGNOSTIC_ANCHORS = {
15
+ "Yersinia": [
16
+ "bull’s-eye",
17
+ "bull's eye",
18
+ "cin agar",
19
+ "pseudoappendicitis",
20
+ "pseudo-appendicitis",
21
+ ],
22
+ "Campylobacter": [
23
+ "hippurate",
24
+ "darting motility",
25
+ ],
26
+ "Vibrio": [
27
+ "tcbs agar",
28
+ "thiosulfate citrate bile salts sucrose",
29
+ "yellow colonies on tcbs",
30
+ "rice-water stool",
31
+ "rice water stool",
32
+ ],
33
+ "Proteus": [
34
+ "swarming motility",
35
+ "swarm across the plate",
36
+ "burnt chocolate odor",
37
+ "burned chocolate odour",
38
+ ],
39
+ "Listeria": [
40
+ "tumbling motility",
41
+ "cold enrichment",
42
+ "grows at 4°c",
43
+ "4°c enrichment",
44
+ ],
45
+ "Clostridioides": [
46
+ "ccfa agar",
47
+ "cycloserine cefoxitin fructose agar",
48
+ "barnyard odor",
49
+ "ground glass colonies",
50
+ ],
51
+ }
52
+
53
+
54
+ def apply_diagnostic_overrides(
55
+ description_text: str,
56
+ unified_ranking: List[Dict[str, Any]],
57
+ ) -> List[Dict[str, Any]]:
58
+ """
59
+ If the input description strongly suggests a particular genus
60
+ (anchor phrases), boost that genus in the unified ranking.
61
+
62
+ Strategy:
63
+ - If any anchor phrase for a genus is present in the text,
64
+ ensure that genus has at least 0.70 combined_score
65
+ (70% overall) *if it already appears*.
66
+ - Then re-sort by combined_score.
67
+
68
+ This is conservative: it won't hallucinate genera that aren't
69
+ already in the top list, but strengthens strong clinical signals.
70
+ """
71
+ if not description_text or not unified_ranking:
72
+ return unified_ranking
73
+
74
+ text_lc = description_text.lower()
75
+
76
+ # Which genera have anchors present?
77
+ boosted_genera = set()
78
+ for genus, phrases in DIAGNOSTIC_ANCHORS.items():
79
+ for p in phrases:
80
+ if p.lower() in text_lc:
81
+ boosted_genera.add(genus)
82
+ break
83
+
84
+ if not boosted_genera:
85
+ return unified_ranking
86
+
87
+ # Apply boost only if genus already present
88
+ for item in unified_ranking:
89
+ g = item.get("genus")
90
+ if g in boosted_genera:
91
+ score = float(item.get("combined_score", 0.0))
92
+ if score < 0.70:
93
+ item["combined_score"] = 0.70
94
+ item["combined_percent"] = 70.0
95
+
96
+ unified_ranking.sort(key=lambda d: d.get("combined_score", 0.0), reverse=True)
97
+ return unified_ranking
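A short sketch of how the override is applied; the ranking rows and description below are invented for illustration.

from scoring.diagnostic_anchors import apply_diagnostic_overrides

# Hypothetical unified ranking (already combined Tri-Fusion + ML scores)
ranking = [
    {"genus": "Escherichia", "combined_score": 0.62, "combined_percent": 62.0},
    {"genus": "Proteus", "combined_score": 0.55, "combined_percent": 55.0},
]
text = "Non-lactose fermenter with swarming motility across the blood agar plate."

boosted = apply_diagnostic_overrides(text, ranking)
# "swarming motility" is a Proteus anchor, so Proteus is raised to 0.70 and re-sorted first
print([(r["genus"], r["combined_score"]) for r in boosted])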
scoring/overall_ranker.py ADDED
@@ -0,0 +1,146 @@
1
+ # scoring/overall_ranker.py
2
+ # ============================================================
3
+ # Overall Ranker — Probability Normalisation Layer
4
+ #
5
+ # PURPOSE:
6
+ # - Take already-computed combined scores (Tri-Fusion + ML)
7
+ # - Normalize top-K into human-interpretable probabilities
8
+ # - Provide odds per 1000 for UI display
9
+ #
10
+ # IMPORTANT:
11
+ # - This module DOES NOT assign confidence labels
12
+ # - Confidence logic lives in app.py (decision-band contract)
13
+ #
14
+ # OUTPUT CONTRACT (STRICT):
15
+ # {
16
+ # "overall": [
17
+ # {
18
+ # "rank": int,
19
+ # "genus": str,
20
+ # "combined_score": float,
21
+ # "normalized_share": float, # 0–1, sums to 1.0
22
+ # },
23
+ # ...
24
+ # ],
25
+ # "probabilities_1000": [
26
+ # {
27
+ # "genus": str,
28
+ # "odds_1000": int
29
+ # },
30
+ # ...
31
+ # ]
32
+ # }
33
+ # ============================================================
34
+
35
+ from typing import Dict, List, Any
36
+
37
+
38
+ def compute_overall_scores(
39
+ ml_scores: List[Dict[str, Any]],
40
+ tri_scores: Dict[str, float],
41
+ top_k: int = 5,
42
+ ) -> Dict[str, Any]:
43
+ """
44
+ Normalize already-computed combined scores into
45
+ probability shares and odds for the Top-5 decision table.
46
+
47
+ Parameters
48
+ ----------
49
+ ml_scores : list of dict
50
+ Each dict contains at least:
51
+ { "genus": str, "probability": float }
52
+ (Used ONLY to determine candidate genera)
53
+
54
+ tri_scores : dict
55
+ Dict mapping genus -> combined_score (0–1)
56
+ NOTE: This is already unified (Tri-Fusion + ML).
57
+
58
+ top_k : int
59
+ Number of top genera to return.
60
+
61
+ Returns
62
+ -------
63
+ dict
64
+ {
65
+ "overall": [
66
+ {
67
+ "rank": int,
68
+ "genus": str,
69
+ "combined_score": float,
70
+ "normalized_share": float
71
+ }
72
+ ],
73
+ "probabilities_1000": [
74
+ { "genus": str, "odds_1000": int }
75
+ ]
76
+ }
77
+ """
78
+
79
+ # --------------------------------------------------------
80
+ # 1. Build candidate list
81
+ # --------------------------------------------------------
82
+ combined_rows: List[Dict[str, Any]] = []
83
+
84
+ for genus, score in tri_scores.items():
85
+ try:
86
+ cs = float(score)
87
+ except Exception:
88
+ cs = 0.0
89
+
90
+ if cs > 0:
91
+ combined_rows.append({
92
+ "genus": genus,
93
+ "combined_score": cs
94
+ })
95
+
96
+ if not combined_rows:
97
+ return {
98
+ "overall": [],
99
+ "probabilities_1000": [],
100
+ }
101
+
102
+ # --------------------------------------------------------
103
+ # 2. Sort and trim to top_k
104
+ # --------------------------------------------------------
105
+ combined_rows.sort(
106
+ key=lambda x: x["combined_score"],
107
+ reverse=True
108
+ )
109
+
110
+ top = combined_rows[:top_k]
111
+
112
+ # --------------------------------------------------------
113
+ # 3. Normalize to probability shares (sum = 1.0)
114
+ # --------------------------------------------------------
115
+ total_score = sum(x["combined_score"] for x in top)
116
+
117
+ if total_score <= 0:
118
+ total_score = 1.0 # safety fallback
119
+
120
+ overall: List[Dict[str, Any]] = []
121
+ probabilities_1000: List[Dict[str, Any]] = []
122
+
123
+ for idx, row in enumerate(top, start=1):
124
+ share = row["combined_score"] / total_score
125
+
126
+ # Clamp defensively
127
+ share = max(0.0, min(1.0, share))
128
+
129
+ odds_1000 = int(round(share * 1000))
130
+
131
+ overall.append({
132
+ "rank": idx,
133
+ "genus": row["genus"],
134
+ "combined_score": round(row["combined_score"], 6),
135
+ "normalized_share": round(share, 6),
136
+ })
137
+
138
+ probabilities_1000.append({
139
+ "genus": row["genus"],
140
+ "odds_1000": odds_1000,
141
+ })
142
+
143
+ return {
144
+ "overall": overall,
145
+ "probabilities_1000": probabilities_1000,
146
+ }
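A quick sketch of the output contract in practice. The scores are invented, and ml_scores is passed empty because the current implementation ranks purely from tri_scores.

from scoring.overall_ranker import compute_overall_scores

# Hypothetical combined scores (already unified Tri-Fusion + ML)
tri = {"Enterobacter": 0.72, "Klebsiella": 0.48, "Serratia": 0.20}

out = compute_overall_scores(ml_scores=[], tri_scores=tri, top_k=3)
for row in out["overall"]:
    print(row["rank"], row["genus"], row["normalized_share"])
print(out["probabilities_1000"])  # same shares expressed as integer odds per 1000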
static/eph.jpeg ADDED

Git LFS Details

  • SHA256: 0852f987a45e317f52bfacd47f93df1fb9d2cbcb626be12def47672f400c45f3
  • Pointer size: 131 Bytes
  • Size of remote file: 116 kB
training/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Marks the 'training' directory as a Python package
2
+
training/alias_trainer.py ADDED
@@ -0,0 +1,126 @@
1
+ # training/alias_trainer.py
2
+ # ------------------------------------------------------------
3
+ # Stage 10B - Alias Trainer
4
+ #
5
+ # Learns field/value synonyms from gold tests by comparing:
6
+ # - expected values (gold standard)
7
+ # - parsed values (rules + extended)
8
+ #
9
+ # Outputs:
10
+ # - Updated alias_maps.json
11
+ #
12
+ # This is the core intelligence that allows BactAI-D
13
+ # to understand variations in microbiology language.
14
+ # ------------------------------------------------------------
15
+
16
+ import json
17
+ import os
18
+ from collections import defaultdict
19
+
20
+ from engine.parser_rules import parse_text_rules
21
+ from engine.parser_ext import parse_text_extended
22
+
23
+
24
+ GOLD_PATH = "training/gold_tests.json"
25
+ ALIAS_PATH = "data/alias_maps.json"
26
+
27
+
28
+ def normalise(s):
29
+ if s is None:
30
+ return ""
31
+ return str(s).strip().lower()
32
+
33
+
34
+ def learn_aliases():
35
+ """
36
+ Learns synonym mappings from gold tests.
37
+ """
38
+ if not os.path.exists(GOLD_PATH):
39
+ return {"error": f"Gold tests missing: {GOLD_PATH}"}
40
+
41
+ with open(GOLD_PATH, "r", encoding="utf-8") as f:
42
+ gold = json.load(f)
43
+
44
+ # Load or create alias map
45
+ if os.path.exists(ALIAS_PATH):
46
+ with open(ALIAS_PATH, "r", encoding="utf-8") as f:
47
+ alias_maps = json.load(f)
48
+ else:
49
+ alias_maps = {}
50
+
51
+ # Track suggestions
52
+ suggestions = defaultdict(lambda: defaultdict(int))
53
+
54
+ # ------------------------------------------------------------
55
+ # Compare expected vs parsed for all tests
56
+ # ------------------------------------------------------------
57
+ for test in gold:
58
+ text = test.get("input", "")
59
+ expected = test.get("expected", {})
60
+
61
+ rules = parse_text_rules(text).get("parsed_fields", {})
62
+ ext = parse_text_extended(text).get("parsed_fields", {})
63
+
64
+ # merge deterministic parsers
65
+ merged = dict(rules)
66
+ for k, v in ext.items():
67
+ if v != "Unknown":
68
+ merged[k] = v
69
+
70
+ # now compare with expected
71
+ for field, exp_val in expected.items():
72
+ exp_norm = normalise(exp_val)
73
+ got_norm = normalise(merged.get(field, "Unknown"))
74
+
75
+ # Skip correct matches
76
+ if exp_norm == got_norm:
77
+ continue
78
+
79
+ # Skip unknown expected
80
+ if exp_norm in ["", "unknown"]:
81
+ continue
82
+
83
+ # Mismatched → candidate alias
84
+ if got_norm not in ["", "unknown"]:
85
+ suggestions[field][got_norm] += 1
86
+
87
+ # ------------------------------------------------------------
88
+ # Convert suggestions into alias mappings
89
+ # ------------------------------------------------------------
90
+ alias_updates = {}
91
+
92
+ for field, values in suggestions.items():
93
+ # ignore fields with tiny evidence
94
+ for wrong_value, count in values.items():
95
+ if count < 2:
96
+ continue # avoid noise
97
+
98
+ # add/update alias
99
+ if field not in alias_maps:
100
+ alias_maps[field] = {}
101
+
102
+ # map wrong_value → expected canonical version
103
+ # canonical version is the most common value in gold_tests for that field
104
+ canonical = None
105
+ # determine canonical
106
+ field_values = [normalise(t["expected"][field]) for t in gold if field in t.get("expected", {})]
107
+ if field_values:
108
+ # most common expected value
109
+ canonical = max(set(field_values), key=field_values.count)
110
+
111
+ if canonical:
112
+ alias_maps[field][wrong_value] = canonical
113
+ alias_updates[f"{field}:{wrong_value}"] = canonical
114
+
115
+ # ------------------------------------------------------------
116
+ # Save alias maps
117
+ # ------------------------------------------------------------
118
+ with open(ALIAS_PATH, "w", encoding="utf-8") as f:
119
+ json.dump(alias_maps, f, indent=2)
120
+
121
+ return {
122
+ "ok": True,
123
+ "updated_aliases": alias_updates,
124
+ "total_updates": len(alias_updates),
125
+ "alias_map_path": ALIAS_PATH,
126
+ }
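A sketch of running the alias trainer; it expects training/gold_tests.json on disk, and the alias entry in the comment is purely hypothetical.

from training.alias_trainer import learn_aliases

report = learn_aliases()
if report.get("ok"):
    print(report["total_updates"], "aliases written to", report["alias_map_path"])
    # data/alias_maps.json then holds per-field mappings, e.g. (hypothetical):
    # { "Oxidase": { "ox neg": "negative" } }
else:
    print(report.get("error"))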
training/field_weight_trainer.py ADDED
@@ -0,0 +1,330 @@
1
+ # training/field_weight_trainer.py
2
+ # ------------------------------------------------------------
3
+ # Stage 12A — Train Per-Field Parser Weights from Gold Tests
4
+ #
5
+ # Produces:
6
+ # data/field_weights.json
7
+ #
8
+ # This script computes reliability scores for:
9
+ # - parser_rules
10
+ # - parser_ext
11
+ # - parser_llm
12
+ #
13
+ # and outputs:
14
+ # {
15
+ # "global": { ... },
16
+ # "fields": { field -> weights },
17
+ # "meta": { ... }
18
+ # }
19
+ #
20
+ # These weights are used by parser_fusion (Stage 12B).
21
+ # ------------------------------------------------------------
22
+
23
+ from __future__ import annotations
24
+
25
+ import argparse
26
+ import json
27
+ import os
28
+ from collections import defaultdict
29
+ from dataclasses import dataclass
30
+ from typing import Any, Dict, List, Optional, Tuple
31
+
32
+ # Core parsers
33
+ from engine.parser_rules import parse_text_rules
34
+ from engine.parser_ext import parse_text_extended
35
+
36
+ # LLM parser (optional)
37
+ try:
38
+ from engine.parser_llm import parse_llm as parse_text_llm_local
39
+ except Exception:
40
+ parse_text_llm_local = None # gracefully degrade if LLM unavailable
41
+
42
+
43
+ # ------------------------------------------------------------
44
+ # Constants
45
+ # ------------------------------------------------------------
46
+
47
+ DEFAULT_GOLD_PATH = os.path.join("data", "gold_tests.json")
48
+ DEFAULT_OUT_PATH = os.path.join("data", "field_weights.json")
49
+
50
+ MISSING_PENALTY = 0.5
51
+ SMOOTHING = 1e-3
52
+
53
+
54
+ # ------------------------------------------------------------
55
+ # Data Structures
56
+ # ------------------------------------------------------------
57
+
58
+ @dataclass
59
+ class ParserOutcome:
60
+ prediction: Optional[str]
61
+ correct: bool
62
+ wrong: bool
63
+ missing: bool
64
+
65
+
66
+ @dataclass
67
+ class FieldStats:
68
+ correct: int = 0
69
+ wrong: int = 0
70
+ missing: int = 0
71
+
72
+ def total(self) -> int:
73
+ return self.correct + self.wrong + self.missing
74
+
75
+ def score(self, missing_penalty: float = MISSING_PENALTY) -> float:
76
+ if self.total() == 0:
77
+ return 0.0
78
+ denom = self.correct + self.wrong + missing_penalty * self.missing
79
+ if denom == 0:
80
+ return 0.0
81
+ return self.correct / denom
82
+
83
+
84
+ # ------------------------------------------------------------
85
+ # Gold Loading
86
+ # ------------------------------------------------------------
87
+
88
+ def _load_gold_tests(path: str) -> List[Dict[str, Any]]:
89
+ if not os.path.exists(path):
90
+ raise FileNotFoundError(f"Gold tests not found: {path}")
91
+ with open(path, "r", encoding="utf-8") as f:
92
+ data = json.load(f)
93
+ if not isinstance(data, list):
94
+ raise ValueError("gold_tests.json must be a list")
95
+ return data
96
+
97
+
98
+ def _extract_text_and_expected(test_obj: Dict[str, Any]) -> Tuple[str, Dict[str, str]]:
99
+ text = (
100
+ test_obj.get("text")
101
+ or test_obj.get("description")
102
+ or test_obj.get("input")
103
+ or test_obj.get("raw")
104
+ or ""
105
+ )
106
+ if not isinstance(text, str):
107
+ text = str(text)
108
+
109
+ expected: Dict[str, str] = {}
110
+
111
+ if isinstance(test_obj.get("expected"), dict):
112
+ for k, v in test_obj["expected"].items():
113
+ expected[str(k)] = str(v)
114
+ return text, expected
115
+
116
+ if isinstance(test_obj.get("expected_core"), dict):
117
+ for k, v in test_obj["expected_core"].items():
118
+ expected[str(k)] = str(v)
119
+
120
+ if isinstance(test_obj.get("expected_extended"), dict):
121
+ for k, v in test_obj["expected_extended"].items():
122
+ expected[str(k)] = str(v)
123
+
124
+ return text, expected
125
+
126
+
127
+ # ------------------------------------------------------------
128
+ # Parser Execution
129
+ # ------------------------------------------------------------
130
+
131
+ def _get_parser_predictions(text: str, include_llm: bool = True) -> Dict[str, Dict[str, str]]:
132
+ results: Dict[str, Dict[str, str]] = {}
133
+
134
+ r = parse_text_rules(text)
135
+ results["rules"] = dict(r.get("parsed_fields", {}))
136
+
137
+ e = parse_text_extended(text)
138
+ results["extended"] = dict(e.get("parsed_fields", {}))
139
+
140
+ llm_values: Dict[str, str] = {}
141
+ if include_llm and parse_text_llm_local is not None:
142
+ try:
143
+ llm_out = parse_text_llm_local(text)
144
+ llm_values = dict(llm_out.get("parsed_fields", {}))
145
+ except Exception:
146
+ llm_values = {}
147
+ results["llm"] = llm_values
148
+
149
+ return results
150
+
151
+
152
+ def _outcome_for_field(expected_val: str, predicted_val: Optional[str]) -> ParserOutcome:
153
+ if predicted_val is None:
154
+ return ParserOutcome(prediction=None, correct=False, wrong=False, missing=True)
155
+ if predicted_val == expected_val:
156
+ return ParserOutcome(prediction=predicted_val, correct=True, wrong=False, missing=False)
157
+ return ParserOutcome(prediction=predicted_val, correct=False, wrong=True, missing=False)
158
+
159
+
160
+ # ------------------------------------------------------------
161
+ # Stats Computation
162
+ # ------------------------------------------------------------
163
+
164
+ def _compute_stats_from_gold(
165
+ gold_tests: List[Dict[str, Any]],
166
+ include_llm: bool = True,
167
+ ):
168
+ field_stats = defaultdict(lambda: defaultdict(FieldStats))
169
+ global_stats = defaultdict(FieldStats)
170
+
171
+ total_samples = 0
172
+
173
+ for sample in gold_tests:
174
+ text, expected = _extract_text_and_expected(sample)
175
+ if not expected:
176
+ continue
177
+
178
+ total_samples += 1
179
+ preds = _get_parser_predictions(text, include_llm=include_llm)
180
+
181
+ for field, expected_val in expected.items():
182
+ expected_val = str(expected_val)
183
+ for parser_name in ["rules", "extended", "llm"]:
184
+ if parser_name == "llm" and not include_llm:
185
+ continue
186
+
187
+ pred_val = preds.get(parser_name, {}).get(field)
188
+
189
+ outcome = _outcome_for_field(expected_val, pred_val)
190
+
191
+ fs = field_stats[field][parser_name]
192
+ if outcome.correct:
193
+ fs.correct += 1
194
+ if outcome.wrong:
195
+ fs.wrong += 1
196
+ if outcome.missing:
197
+ fs.missing += 1
198
+
199
+ gs = global_stats[parser_name]
200
+ if outcome.correct:
201
+ gs.correct += 1
202
+ if outcome.wrong:
203
+ gs.wrong += 1
204
+ if outcome.missing:
205
+ gs.missing += 1
206
+
207
+ return field_stats, global_stats, total_samples
208
+
209
+
210
+ def _normalise(weights: Dict[str, float], smoothing: float = SMOOTHING) -> Dict[str, float]:
211
+ adjusted = {k: max(smoothing, v) for k, v in weights.items()}
212
+ total = sum(adjusted.values())
213
+ if total <= 0:
214
+ n = len(adjusted)
215
+ return {k: 1.0 / n for k in adjusted}
216
+ return {k: v / total for k, v in adjusted.items()}
217
+
218
+
219
+ def _build_weights_json(
220
+ field_stats,
221
+ global_stats,
222
+ total_samples,
223
+ include_llm=True,
224
+ ):
225
+ # Global scores
226
+ raw_global = {}
227
+ for parser_name, stats in global_stats.items():
228
+ if parser_name == "llm" and not include_llm:
229
+ continue
230
+ raw_global[parser_name] = stats.score(MISSING_PENALTY)
231
+
232
+ global_weights = _normalise(raw_global)
233
+
234
+ # Per-field
235
+ fields_block = {}
236
+
237
+ for field_name, stats_dict in field_stats.items():
238
+ raw_scores = {}
239
+ total_support = 0
240
+
241
+ for parser_name, stats in stats_dict.items():
242
+ if parser_name == "llm" and not include_llm:
243
+ continue
244
+ raw_scores[parser_name] = stats.score(MISSING_PENALTY)
245
+ total_support += stats.total()
246
+
247
+ if total_support < 5:
248
+ # low support → blend global + local
249
+ local_norm = _normalise(raw_scores)
250
+ mixed = {}
251
+ for p in global_weights:
252
+ mixed[p] = 0.7 * global_weights[p] + 0.3 * local_norm.get(p, global_weights[p])
253
+ field_w = _normalise(mixed)
254
+ else:
255
+ field_w = _normalise(raw_scores)
256
+
257
+ fields_block[field_name] = {
258
+ **field_w,
259
+ "support": total_support,
260
+ }
261
+
262
+ return {
263
+ "global": global_weights,
264
+ "fields": fields_block,
265
+ "meta": {
266
+ "total_samples": total_samples,
267
+ "missing_penalty": MISSING_PENALTY,
268
+ "smoothing": SMOOTHING,
269
+ "include_llm": include_llm,
270
+ },
271
+ }
272
+
273
+
274
+ # ------------------------------------------------------------
275
+ # Public API
276
+ # ------------------------------------------------------------
277
+
278
+ def train_field_weights(
279
+ gold_path: str = DEFAULT_GOLD_PATH,
280
+ out_path: str = DEFAULT_OUT_PATH,
281
+ include_llm: bool = False,
282
+ ):
283
+ print(f"[12A] Loading gold tests: {gold_path}")
284
+ gold = _load_gold_tests(gold_path)
285
+ print(f"[12A] {len(gold)} gold samples loaded")
286
+
287
+ field_stats, global_stats, total_samples = _compute_stats_from_gold(
288
+ gold, include_llm=include_llm
289
+ )
290
+
291
+ print("[12A] Computing weights...")
292
+ weights = _build_weights_json(
293
+ field_stats, global_stats, total_samples, include_llm=include_llm
294
+ )
295
+
296
+ out_dir = os.path.dirname(out_path)
297
+ if out_dir and not os.path.exists(out_dir):
298
+ os.makedirs(out_dir, exist_ok=True)
299
+
300
+ print(f"[12A] Writing: {out_path}")
301
+ with open(out_path, "w", encoding="utf-8") as f:
302
+ json.dump(weights, f, indent=2, ensure_ascii=False)
303
+
304
+ print("[12A] Done.")
305
+ return weights
306
+
307
+
308
+ # ------------------------------------------------------------
309
+ # CLI
310
+ # ------------------------------------------------------------
311
+
312
+ def _parse_args(argv=None):
313
+ p = argparse.ArgumentParser(description="Stage 12A — Train parser weights")
314
+ p.add_argument("--gold", type=str, default=DEFAULT_GOLD_PATH)
315
+ p.add_argument("--out", type=str, default=DEFAULT_OUT_PATH)
316
+ p.add_argument("--include-llm", action="store_true")
317
+ return p.parse_args(argv)
318
+
319
+
320
+ def main(argv=None):
321
+ args = _parse_args(argv)
322
+ train_field_weights(
323
+ gold_path=args.gold,
324
+ out_path=args.out,
325
+ include_llm=args.include_llm,
326
+ )
327
+
328
+
329
+ if __name__ == "__main__":
330
+ main()
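A sketch of invoking the trainer directly (equivalent to the CLI). Note that this module defaults to data/gold_tests.json, while the Stage 10 trainers above read training/gold_tests.json, so pass gold_path explicitly if the gold file lives there.

from training.field_weight_trainer import train_field_weights

weights = train_field_weights(
    gold_path="training/gold_tests.json",  # explicit; the module default is data/gold_tests.json
    out_path="data/field_weights.json",
    include_llm=False,
)
print(weights["global"])                 # e.g. {"rules": ..., "extended": ...}
print(weights["meta"]["total_samples"])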
training/gold_tester.py ADDED
@@ -0,0 +1,89 @@
1
+ # training/gold_tester.py
2
+ # ------------------------------------------------------------
3
+ # Stage 10A: Evaluate parsers on gold tests.
4
+ # This MUST NOT crash during import.
5
+ # ------------------------------------------------------------
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ import os
11
+ from typing import Dict, Any, List
12
+
13
+ from engine.parser_rules import parse_text_rules
14
+ from engine.parser_ext import parse_text_extended
15
+
16
+
17
+ GOLD_PATH = "training/gold_tests.json"
18
+ REPORT_DIR = "reports"
19
+
20
+
21
+ def _load_gold_tests() -> List[Dict[str, Any]]:
22
+ if not os.path.exists(GOLD_PATH):
23
+ return []
24
+ with open(GOLD_PATH, "r", encoding="utf-8") as f:
25
+ try:
26
+ data = json.load(f)
27
+ return data if isinstance(data, list) else []
28
+ except Exception:
29
+ return []
30
+
31
+
32
+ def run_gold_tests(mode: str = "rules") -> Dict[str, Any]:
33
+ gold_tests = _load_gold_tests()
34
+ if not gold_tests:
35
+ return {
36
+ "summary": {
37
+ "mode": mode,
38
+ "tests": 0,
39
+ "total_correct": 0,
40
+ "total_fields": 0,
41
+ "overall_accuracy": 0.0,
42
+ "proposals_path": "data/extended_proposals.jsonl",
43
+ }
44
+ }
45
+
46
+ os.makedirs(REPORT_DIR, exist_ok=True)
47
+
48
+ wrong_cases = []
49
+ total_correct = 0
50
+ total_fields = 0
51
+
52
+ for idx, test in enumerate(gold_tests):
53
+ text = test.get("input", "")
54
+ expected = test.get("expected", {})
55
+
56
+ if mode == "rules":
57
+ parsed = parse_text_rules(text).get("parsed_fields", {})
58
+ elif mode == "rules+extended":
59
+ rule_fields = parse_text_rules(text).get("parsed_fields", {})
60
+ ext_fields = parse_text_extended(text).get("parsed_fields", {})
61
+ parsed = {**rule_fields, **ext_fields}
62
+ else:
63
+ parsed = {}
64
+
65
+ # Compare field-by-field
66
+ correct_count = 0
67
+ for key, val in expected.items():
68
+ total_fields += 1
69
+ if key in parsed and str(parsed[key]).strip().lower() == str(val).strip().lower():
70
+ correct_count += 1
71
+
72
+ total_correct += correct_count
73
+
74
+ if correct_count < len(expected):
75
+ wrong_cases.append(idx)
76
+
77
+ accuracy = total_correct / total_fields if total_fields else 0.0
78
+
79
+ summary = {
80
+ "mode": mode,
81
+ "tests": len(gold_tests),
82
+ "total_correct": total_correct,
83
+ "total_fields": total_fields,
84
+ "overall_accuracy": accuracy,
85
+ "wrong_cases": wrong_cases,
86
+ "proposals_path": "data/extended_proposals.jsonl",
87
+ }
88
+
89
+ return {"summary": summary}
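A sketch of running the evaluator and reading the summary block it returns:

from training.gold_tester import run_gold_tests

report = run_gold_tests(mode="rules+extended")
s = report["summary"]
print(f"{s['total_correct']}/{s['total_fields']} fields correct "
      f"({s['overall_accuracy']:.1%}) across {s['tests']} tests")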
training/gold_tests.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb94485222e9733a0d530d4df8ac0f35c8f95770cc9ef44bcb1289807e0b108e
3
+ size 18563634
training/gold_trainer.py ADDED
@@ -0,0 +1,79 @@
1
+ # training/gold_trainer.py
2
+ # ------------------------------------------------------------
3
+ # Stage 10C — Orchestrates gold-test-driven training:
4
+ # 1) Alias trainer (DISABLED for safety)
5
+ # 2) Schema expander (safe v10C)
6
+ # 3) Signals trainer (placeholder)
7
+ #
8
+ # This file MUST successfully import and expose train_from_gold().
9
+ # ------------------------------------------------------------
10
+
11
+ from __future__ import annotations
12
+ from typing import Dict, Any
13
+
14
+ # Safe schema expander
15
+ from training.schema_expander import expand_schema
16
+
17
+ # Placeholder signals trainer
18
+ from training.signal_trainer import train_signals
19
+
20
+
21
+ def train_from_gold() -> Dict[str, Any]:
22
+ """
23
+ Runs all gold-test–driven training components (Stage 10C).
24
+
25
+ Returns a dict:
26
+ {
27
+ "alias_trainer": {...},
28
+ "schema_expander": {...},
29
+ "signals_trainer": {...}
30
+ }
31
+ """
32
+
33
+ # --------------------------------------------------------
34
+ # 1) Alias Trainer — DISABLED to avoid destructive mappings
35
+ # --------------------------------------------------------
36
+ alias_result = {
37
+ "ok": False,
38
+ "message": (
39
+ "Alias trainer is disabled in Stage 10C to prevent unsafe "
40
+ "auto-mappings. Edit data/alias_maps.json manually if needed."
41
+ ),
42
+ "alias_map_path": "data/alias_maps.json",
43
+ }
44
+
45
+ # --------------------------------------------------------
46
+ # 2) Schema Expander — Safe version
47
+ # --------------------------------------------------------
48
+ try:
49
+ schema_result = expand_schema()
50
+ except Exception as e:
51
+ schema_result = {
52
+ "ok": False,
53
+ "message": f"Schema expander crashed: {e}",
54
+ "auto_added_fields": {},
55
+ "proposed_fields": [],
56
+ "schema_path": "data/extended_schema.json",
57
+ "proposals_path": "data/extended_proposals.jsonl",
58
+ }
59
+
60
+ # --------------------------------------------------------
61
+ # 3) Signals Trainer (placeholder)
62
+ # --------------------------------------------------------
63
+ try:
64
+ signals_result = train_signals()
65
+ except Exception as e:
66
+ signals_result = {
67
+ "ok": False,
68
+ "message": f"Signal trainer crashed: {e}",
69
+ "signals_catalog_path": "data/signals_catalog.json",
70
+ }
71
+
72
+ # --------------------------------------------------------
73
+ # Combined report
74
+ # --------------------------------------------------------
75
+ return {
76
+ "alias_trainer": alias_result,
77
+ "schema_expander": schema_result,
78
+ "signals_trainer": signals_result,
79
+ }
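A sketch of the orchestrator's combined report in use:

from training.gold_trainer import train_from_gold

report = train_from_gold()
print(report["alias_trainer"]["message"])          # explains that the alias trainer is disabled
print("schema_expander ok:", report["schema_expander"].get("ok"))
print("signals_trainer ok:", report["signals_trainer"].get("ok"))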
training/hf_sync.py ADDED
@@ -0,0 +1,68 @@
1
+ # training/hf_sync.py
2
+ # ------------------------------------------------------------
3
+ # Sync updated data files back to the SAME Hugging Face Space.
4
+ # ------------------------------------------------------------
5
+
6
+ import os
7
+ from typing import List, Dict, Any
8
+
9
+ from huggingface_hub import HfApi, CommitOperationAdd
10
+
11
+
12
+ def push_to_hf(
13
+ paths: List[str],
14
+ commit_message: str = "train: update extended schema, aliases, signals from gold tests",
15
+ ) -> Dict[str, Any]:
16
+
17
+ repo_id = os.getenv("HF_SPACE_REPO_ID")
18
+ token = os.getenv("HF_TOKEN")
19
+
20
+ if not repo_id:
21
+ return {
22
+ "ok": False,
23
+ "error": "Missing HF_SPACE_REPO_ID environment variable.",
24
+ "uploaded": [],
25
+ }
26
+
27
+ if not token:
28
+ return {
29
+ "ok": False,
30
+ "error": "Missing HF_TOKEN environment variable.",
31
+ "uploaded": [],
32
+ }
33
+
34
+ api = HfApi()
35
+ operations = []
36
+ uploaded = []
37
+
38
+ for p in paths:
39
+ if not os.path.exists(p):
40
+ continue
41
+
42
+ operations.append(
43
+ CommitOperationAdd(path_in_repo=p, path_or_fileobj=p)
44
+ )
45
+ uploaded.append(p)
46
+
47
+ if not operations:
48
+ return {
49
+ "ok": False,
50
+ "error": "No existing files to upload.",
51
+ "uploaded": [],
52
+ }
53
+
54
+ commit_info = api.create_commit(
55
+ repo_id=repo_id,
56
+ repo_type="space",
57
+ operations=operations,
58
+ commit_message=commit_message,
59
+ token=token,
60
+ )
61
+
62
+ return {
63
+ "ok": True,
64
+ "uploaded": uploaded,
65
+ "repo_id": repo_id,
66
+ "commit_message": commit_message,
67
+ "commit_url": commit_info.commit_url,
68
+ }
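A sketch of pushing trained artifacts back to the Space. It assumes HF_SPACE_REPO_ID and HF_TOKEN are set in the environment; the file list is illustrative.

from training.hf_sync import push_to_hf

result = push_to_hf(
    ["data/alias_maps.json", "data/extended_schema.json", "data/field_weights.json"],
    commit_message="train: sync updated data files",
)
if result["ok"]:
    print("Committed:", result["commit_url"])
else:
    print("Sync failed:", result["error"])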
training/parser_eval.py ADDED
@@ -0,0 +1,104 @@
1
+ # training/parser_eval.py
2
+ # ------------------------------------------------------------
3
+ # Parser Evaluation (Stage 10A)
4
+ #
5
+ # This version ONLY evaluates:
6
+ # - Rule parser
7
+ # - Extended parser
8
+ #
9
+ # The LLM parser is intentionally disabled at this stage
10
+ # because alias maps and schema are not trained yet.
11
+ #
12
+ # This makes Stage 10A FAST and stable (< 3 seconds).
13
+ # ------------------------------------------------------------
14
+
15
+ import json
16
+ import os
17
+ from typing import Dict, Any
18
+
19
+ from engine.parser_rules import parse_text_rules
20
+ from engine.parser_ext import parse_text_extended
21
+
22
+
23
+ # Path to the gold tests
24
+ GOLD_PATH = "training/gold_tests.json"
25
+
26
+
27
+ def evaluate_single_test(test: Dict[str, Any]) -> Dict[str, Any]:
28
+ """
29
+ Evaluate one gold test with rules + extended parsers.
30
+ """
31
+ text = test.get("input", "")
32
+ expected = test.get("expected", {})
33
+
34
+ # Run deterministic parsers
35
+ rule_out = parse_text_rules(text).get("parsed_fields", {})
36
+ ext_out = parse_text_extended(text).get("parsed_fields", {})
37
+
38
+ # Merge rule + extended (extended overwrites rules)
39
+ merged = dict(rule_out)
40
+ for k, v in ext_out.items():
41
+ if v != "Unknown":
42
+ merged[k] = v
43
+
44
+ total = len(expected)
45
+ correct = 0
46
+ wrong = {}
47
+
48
+ for field, exp_val in expected.items():
49
+ got = merged.get(field, "Unknown")
50
+ if str(got).strip().lower() == str(exp_val).strip().lower():
51
+ correct += 0 if str(exp_val).strip().lower() == "unknown" else 1 # Unknown matches are neutral
52
+ else:
53
+ wrong[field] = {"expected": exp_val, "got": got}
54
+
55
+ return {
56
+ "correct": correct,
57
+ "total": total,
58
+ "accuracy": correct / total if total else 0,
59
+ "wrong": wrong,
60
+ "merged": merged,
61
+ }
62
+
63
+
64
+ def run_parser_eval(mode: str = "rules_extended") -> Dict[str, Any]:
65
+ """
66
+ Evaluate ALL gold tests using rules + extended parsing only.
67
+ """
68
+ if not os.path.exists(GOLD_PATH):
69
+ return {"error": f"Gold test file not found at {GOLD_PATH}"}
70
+
71
+ with open(GOLD_PATH, "r", encoding="utf-8") as f:
72
+ gold = json.load(f)
73
+
74
+ results = []
75
+ wrong_cases = []
76
+
77
+ total_correct = 0
78
+ total_fields = 0
79
+
80
+ for test in gold:
81
+ out = evaluate_single_test(test)
82
+ results.append(out)
83
+
84
+ total_correct += out["correct"]
85
+ total_fields += out["total"]
86
+
87
+ if out["wrong"]:
88
+ wrong_cases.append({
89
+ "name": test.get("name", "Unnamed"),
90
+ "wrong": out["wrong"],
91
+ "parsed": out["merged"],
92
+ "expected": test.get("expected", {})
93
+ })
94
+
95
+ summary = {
96
+ "mode": "rules+extended",
97
+ "tests": len(gold),
98
+ "total_correct": total_correct,
99
+ "total_fields": total_fields,
100
+ "overall_accuracy": total_correct / total_fields if total_fields else 0,
101
+ "wrong_cases": wrong_cases,
102
+ }
103
+
104
+ return summary
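A sketch of running the Stage 10A evaluation and inspecting a few failing cases:

from training.parser_eval import run_parser_eval

summary = run_parser_eval()
if "error" in summary:
    print(summary["error"])
else:
    print(f"Accuracy: {summary['overall_accuracy']:.1%} over {summary['tests']} tests")
    for case in summary["wrong_cases"][:3]:
        print(case["name"], "->", case["wrong"])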
training/rag_index_builder.py ADDED
@@ -0,0 +1,629 @@
1
+ # training/rag_index_builder.py
2
+ # ============================================================
3
+ # Build RAG index from JSON knowledge base (SECTION-AWARE)
4
+ #
5
+ # - Walks data/rag/knowledge_base/<Genus>/
6
+ # - Reads genus.json + species JSONs
7
+ # - Converts JSON → structured SECTION records
8
+ # - Computes embeddings via rag.rag_embedder.embed_texts
9
+ # - Writes index to data/rag/index/kb_index.json
10
+ #
11
+ # Output record schema (LOCKED):
12
+ # {
13
+ # "id": "Enterobacter|cloacae|species_markers|0",
14
+ # "level": "genus" | "species",
15
+ # "genus": "Enterobacter",
16
+ # "species": "cloacae" | null,
17
+ # "section": "...",
18
+ # "role": "...",
19
+ # "text": "...",
20
+ # "source_file": "...",
21
+ # "chunk_id": 0,
22
+ # "embedding": [...]
23
+ # }
24
+ #
25
+ # NOTE:
26
+ # We keep the locked keys above. We MAY add extra keys (non-breaking),
27
+ # e.g. "field_key" to support future scoring/weighting.
28
+ # ============================================================
29
+
30
+ from __future__ import annotations
31
+
32
+ import json
33
+ import os
34
+ import re
35
+ from typing import Dict, Any, List, Tuple, Optional
36
+
37
+ from rag.rag_embedder import embed_texts, EMBEDDING_MODEL_NAME
38
+
39
+ KB_ROOT = os.path.join("data", "rag", "knowledge_base")
40
+ INDEX_DIR = os.path.join("data", "rag", "index")
41
+ INDEX_PATH = os.path.join(INDEX_DIR, "kb_index.json")
42
+
43
+ # Chunk size is per-section. This should generally be smaller than the generator
44
+ # prompt chunk budget so retriever can pick "tight" context blocks.
45
+ DEFAULT_MAX_CHARS = int(os.getenv("BACTAI_RAG_CHUNK_MAX_CHARS", "1100"))
46
+
47
+ # ------------------------------------------------------------
48
+ # TEXT HELPERS
49
+ # ------------------------------------------------------------
50
+
51
+ def _norm_str(x: Any) -> str:
52
+ return str(x).strip() if x is not None else ""
53
+
54
+ def _safe_join(items: List[str], sep: str = " ") -> str:
55
+ return sep.join([s for s in items if s])
56
+
57
+ def _bullet_lines(items: List[str], prefix: str = "- ") -> str:
58
+ clean = [i.strip() for i in items if isinstance(i, str) and i.strip()]
59
+ if not clean:
60
+ return ""
61
+ return "\n".join(prefix + c for c in clean)
62
+
63
+ def _title_case_field(field_name: str) -> str:
64
+ # Keep parser field names stable (don’t “prettify” them incorrectly)
65
+ return field_name.strip()
66
+
67
+ def _format_expected_fields(expected_fields: Dict[str, Any]) -> str:
68
+ """
69
+ Turn your expected_fields into a compact, self-contained key:value block.
70
+ Handles strings, lists, and simple scalars.
71
+ """
72
+ if not isinstance(expected_fields, dict) or not expected_fields:
73
+ return ""
74
+
75
+ lines: List[str] = []
76
+ for k in sorted(expected_fields.keys(), key=lambda s: str(s).lower()):
77
+ key = _title_case_field(str(k))
78
+ v = expected_fields.get(k)
79
+
80
+ if isinstance(v, list):
81
+ vals = [str(x).strip() for x in v if str(x).strip()]
82
+ if vals:
83
+ lines.append(f"{key}: " + "; ".join(vals))
84
+ else:
85
+ lines.append(f"{key}: Unknown")
86
+ else:
87
+ val = _norm_str(v) or "Unknown"
88
+ lines.append(f"{key}: {val}")
89
+
90
+ return "\n".join(lines)
91
+
92
+ def _as_list(v: Any) -> List[str]:
93
+ if isinstance(v, list):
94
+ return [str(x).strip() for x in v if str(x).strip()]
95
+ if isinstance(v, str) and v.strip():
96
+ return [v.strip()]
97
+ if v is None:
98
+ return []
99
+ s = str(v).strip()
100
+ return [s] if s else []
101
+
102
+ def _is_unknown(v: str) -> bool:
103
+ return (v or "").strip().lower() in {"unknown", "not specified", "n/a", "na", ""}
104
+
105
+ def _expected_fields_to_sentences(
106
+ expected_fields: Dict[str, Any],
107
+ *,
108
+ subject: str,
109
+ ) -> str:
110
+ """
111
+ Convert expected_fields into DECLARATIVE microbiology statements.
112
+ This is the key fix for "Not specified" RAG outputs:
113
+ LLMs treat these as evidence-like assertions rather than schema metadata.
114
+ """
115
+ if not isinstance(expected_fields, dict) or not expected_fields:
116
+ return ""
117
+
118
+ # Prefer these first (front-load the most diagnostic traits)
119
+ priority = [
120
+ "Gram Stain",
121
+ "Shape",
122
+ "Oxygen Requirement",
123
+ "Motility",
124
+ "Motility Type",
125
+ "Capsule",
126
+ "Spore Formation",
127
+ "Haemolysis",
128
+ "Haemolysis Type",
129
+ "Oxidase",
130
+ "Catalase",
131
+ "Indole",
132
+ "Urease",
133
+ "Citrate",
134
+ "Methyl Red",
135
+ "VP",
136
+ "H2S",
137
+ "ONPG",
138
+ "Nitrate Reduction",
139
+ "NaCl Tolerant (>=6%)",
140
+ "Growth Temperature",
141
+ "Media Grown On",
142
+ "Colony Morphology",
143
+ "Colony Pattern",
144
+ "Pigment",
145
+ "TSI Pattern",
146
+ "Gas Production",
147
+ ]
148
+
149
+ # Then everything else, stable order
150
+ all_keys = list(expected_fields.keys())
151
+ ordered = []
152
+ seen = set()
153
+ for k in priority:
154
+ if k in expected_fields:
155
+ ordered.append(k)
156
+ seen.add(k)
157
+ for k in sorted(all_keys, key=lambda s: str(s).lower()):
158
+ if k not in seen:
159
+ ordered.append(k)
160
+ seen.add(k)
161
+
162
+ lines: List[str] = []
163
+ subj = subject.strip() or "This organism"
164
+
165
+ for k in ordered:
166
+ key = _title_case_field(str(k))
167
+ raw = expected_fields.get(k)
168
+
169
+ if isinstance(raw, list):
170
+ vals = [x for x in _as_list(raw) if not _is_unknown(x)]
171
+ if not vals:
172
+ continue
173
+
174
+ # Special handling for list-like fields
175
+ if key == "Media Grown On":
176
+ lines.append(f"{subj} can grow on: " + ", ".join(vals) + ".")
177
+ elif key == "Colony Morphology":
178
+ lines.append(f"{subj} colonies are described as: " + ", ".join(vals) + ".")
179
+ else:
180
+ lines.append(f"{subj} {key} includes: " + ", ".join(vals) + ".")
181
+ continue
182
+
183
+ val = _norm_str(raw)
184
+ if _is_unknown(val):
185
+ continue
186
+
187
+ # Field-specific phrasing for better “evidence-like” feel
188
+ if key == "Gram Stain":
189
+ lines.append(f"{subj} is typically Gram {val}.")
190
+ elif key == "Shape":
191
+ lines.append(f"{subj} typically has shape: {val}.")
192
+ elif key == "Oxygen Requirement":
193
+ lines.append(f"{subj} is typically {val}.")
194
+ elif key == "Growth Temperature":
195
+ lines.append(f"{subj} typically grows within: {val} °C.")
196
+ elif key == "Haemolysis Type":
197
+ lines.append(f"{subj} haemolysis type is typically: {val}.")
198
+ elif key == "Haemolysis":
199
+ lines.append(f"{subj} haemolysis is typically: {val}.")
200
+ elif key == "Pigment":
201
+ if val.lower() in {"none", "no", "negative"}:
202
+ lines.append(f"{subj} typically produces no pigment.")
203
+ else:
204
+ lines.append(f"{subj} may produce pigment: {val}.")
205
+ elif key == "Colony Pattern":
206
+ lines.append(f"{subj} colony/cellular pattern may be described as: {val}.")
207
+ else:
208
+ # Default: simple assertive sentence
209
+ lines.append(f"{subj} {key} is typically: {val}.")
210
+
211
+ # If we emitted nothing, return empty so we don’t add noise
212
+ return "\n".join(lines).strip()
213
+
214
+ def _format_key_differentiators(items: List[Dict[str, Any]]) -> str:
215
+ """
216
+ For genus-level key_differentiators.
217
+ """
218
+ if not isinstance(items, list) or not items:
219
+ return ""
220
+ out: List[str] = []
221
+ for obj in items:
222
+ if not isinstance(obj, dict):
223
+ continue
224
+ field = _norm_str(obj.get("field"))
225
+ expected = _norm_str(obj.get("expected"))
226
+ notes = _norm_str(obj.get("notes"))
227
+ distinguishes_from = obj.get("distinguishes_from") or []
228
+ if not field:
229
+ continue
230
+
231
+ line = f"{field}: expected {expected or 'Unknown'}."
232
+ if isinstance(distinguishes_from, list) and distinguishes_from:
233
+ line += " Distinguishes from: " + ", ".join([_norm_str(x) for x in distinguishes_from if _norm_str(x)])
234
+ if not line.endswith("."):
235
+ line += "."
236
+ if notes:
237
+ line += f" Notes: {notes}"
238
+ if not line.endswith("."):
239
+ line += "."
240
+ out.append(line)
241
+
242
+ return "\n".join(out)
243
+
244
+ def _format_common_confusions(items: List[Dict[str, Any]], level: str) -> str:
245
+ """
246
+ For genus/species common_confusions.
247
+ """
248
+ if not isinstance(items, list) or not items:
249
+ return ""
250
+ out: List[str] = []
251
+ for obj in items:
252
+ if not isinstance(obj, dict):
253
+ continue
254
+ reason = _norm_str(obj.get("reason"))
255
+ if level == "genus":
256
+ who = _norm_str(obj.get("genus"))
257
+ if who:
258
+ out.append(f"{who}: {reason or 'Reason not specified.'}")
259
+ else:
260
+ who = _norm_str(obj.get("species")) or _norm_str(obj.get("genus"))
261
+ if who:
262
+ out.append(f"{who}: {reason or 'Reason not specified.'}")
263
+ return "\n".join(out)
264
+
265
+ def _format_recommended_next_tests(items: List[Dict[str, Any]]) -> str:
266
+ """
267
+ For recommended_next_tests with optional API kit note.
268
+ """
269
+ if not isinstance(items, list) or not items:
270
+ return ""
271
+ out: List[str] = []
272
+ for obj in items:
273
+ if not isinstance(obj, dict):
274
+ continue
275
+ test = _norm_str(obj.get("test"))
276
+ reason = _norm_str(obj.get("reason"))
277
+ api_kit = _norm_str(obj.get("api_kit"))
278
+
279
+ if not test:
280
+ continue
281
+
282
+ line = f"{test}"
283
+ if api_kit:
284
+ line += f" (API kit: {api_kit})"
285
+ if reason:
286
+ line += f": {reason}"
287
+ out.append(line)
288
+ return "\n".join(out)
289
+
290
+ # ------------------------------------------------------------
291
+ # CHUNKING (SECTION-LOCAL)
292
+ # ------------------------------------------------------------
293
+
294
+ def chunk_text_by_paragraph(text: str, max_chars: int = DEFAULT_MAX_CHARS) -> List[str]:
295
+ """
296
+ Chunk within a single section. We never merge different sections together.
297
+ """
298
+ text = (text or "").strip()
299
+ if not text:
300
+ return []
301
+
302
+ if len(text) <= max_chars:
303
+ return [text]
304
+
305
+ paras = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
306
+ if not paras:
307
+ paras = [l.strip() for l in text.splitlines() if l.strip()]
308
+
309
+ chunks: List[str] = []
310
+ current = ""
311
+
312
+ for p in paras:
313
+ candidate = (current + "\n\n" + p).strip() if current else p
314
+ if len(candidate) <= max_chars:
315
+ current = candidate
316
+ else:
317
+ if current:
318
+ chunks.append(current)
319
+ if len(p) <= max_chars:
320
+ current = p
321
+ else:
322
+ for i in range(0, len(p), max_chars):
323
+ chunks.append(p[i:i + max_chars].strip())
324
+ current = ""
325
+
326
+ if current:
327
+ chunks.append(current)
328
+
329
+ return [c for c in chunks if c.strip()]
330
+
331
+ # ------------------------------------------------------------
332
+ # SECTION EMITTERS
333
+ # ------------------------------------------------------------
334
+
335
+ def emit_genus_sections(doc: Dict[str, Any], genus: str) -> List[Dict[str, Any]]:
336
+ """
337
+ Convert genus.json to a list of {section, role, text} entries.
338
+ """
339
+ out: List[Dict[str, Any]] = []
340
+
341
+ overview = doc.get("overview") or {}
342
+ if isinstance(overview, dict):
343
+ short = _norm_str(overview.get("short"))
344
+ clinical = _norm_str(overview.get("clinical_context"))
345
+ if short:
346
+ out.append({"section": "overview", "role": "description", "text": f"Genus {genus}: {short}"})
347
+ if clinical:
348
+ out.append({"section": "overview", "role": "description", "text": f"Clinical context: {clinical}"})
349
+
350
+ expected_fields = doc.get("expected_fields")
351
+ if isinstance(expected_fields, dict) and expected_fields:
352
+ # 1) Declarative evidence-like sentences (NEW)
353
+ sent = _expected_fields_to_sentences(expected_fields, subject=f"Genus {genus}")
354
+ if sent:
355
+ out.append({
356
+ "section": "expected_profile_sentences",
357
+ "role": "expected_profile",
358
+ "text": sent,
359
+ })
360
+
361
+ # 2) Keep original key:value block (still useful)
362
+ text = _format_expected_fields(expected_fields)
363
+ if text:
364
+ out.append({
365
+ "section": "expected_fields",
366
+ "role": "expected_profile",
367
+ "text": f"Expected fields for genus {genus}:\n{text}",
368
+ })
369
+
370
+ field_notes = doc.get("field_notes")
371
+ if isinstance(field_notes, dict) and field_notes:
372
+ lines: List[str] = []
373
+ for k in sorted(field_notes.keys(), key=lambda s: str(s).lower()):
374
+ v = _norm_str(field_notes.get(k))
375
+ if v:
376
+ lines.append(f"{_title_case_field(str(k))}: {v}")
377
+ if lines:
378
+ out.append({"section": "field_notes", "role": "clarification", "text": "Field notes:\n" + "\n".join(lines)})
379
+
380
+ kd = doc.get("key_differentiators")
381
+ if isinstance(kd, list) and kd:
382
+ text = _format_key_differentiators(kd)
383
+ if text:
384
+ out.append({"section": "key_differentiators", "role": "differentiation", "text": "Key differentiators:\n" + text})
385
+
386
+ conf = doc.get("common_confusions")
387
+ if isinstance(conf, list) and conf:
388
+ text = _format_common_confusions(conf, level="genus")
389
+ if text:
390
+ out.append({"section": "common_confusions", "role": "warning", "text": "Common confusions:\n" + text})
391
+
392
+ wq = doc.get("when_to_question_identification")
393
+ if isinstance(wq, list) and wq:
394
+ lines = [str(x).strip() for x in wq if str(x).strip()]
395
+ if lines:
396
+ out.append({"section": "when_to_question_identification", "role": "warning", "text": "When to question identification:\n" + _bullet_lines(lines)})
397
+
398
+ rnt = doc.get("recommended_next_tests")
399
+ if isinstance(rnt, list) and rnt:
400
+ text = _format_recommended_next_tests(rnt)
401
+ if text:
402
+ out.append({"section": "recommended_next_tests", "role": "recommendation", "text": "Recommended next tests:\n" + text})
403
+
404
+ ss = doc.get("supported_species")
405
+ if isinstance(ss, list) and ss:
406
+ species_list = [str(x).strip() for x in ss if str(x).strip()]
407
+ if species_list:
408
+ out.append({"section": "supported_species", "role": "metadata", "text": f"Supported species for genus {genus}: " + ", ".join(species_list)})
409
+
410
+ return out
411
+
412
+
413
+ def emit_species_sections(doc: Dict[str, Any], genus: str, species: str) -> List[Dict[str, Any]]:
414
+ """
415
+ Convert a species JSON to a list of {section, role, text} entries.
416
+ """
417
+ out: List[Dict[str, Any]] = []
418
+ overview = doc.get("overview") or {}
419
+ if isinstance(overview, dict):
420
+ short = _norm_str(overview.get("short"))
421
+ clinical = _norm_str(overview.get("clinical_context"))
422
+ if short:
423
+ out.append({"section": "overview", "role": "description", "text": f"Species {genus} {species}: {short}"})
424
+ if clinical:
425
+ out.append({"section": "overview", "role": "description", "text": f"Clinical context: {clinical}"})
426
+
427
+ expected_fields = doc.get("expected_fields")
428
+ if isinstance(expected_fields, dict) and expected_fields:
429
+ # 1) Declarative evidence-like sentences (NEW)
430
+ sent = _expected_fields_to_sentences(expected_fields, subject=f"Species {genus} {species}")
431
+ if sent:
432
+ out.append({
433
+ "section": "expected_profile_sentences",
434
+ "role": "expected_profile",
435
+ "text": sent,
436
+ })
437
+
438
+ # 2) Keep original key:value block
439
+ text = _format_expected_fields(expected_fields)
440
+ if text:
441
+ out.append({"section": "expected_fields", "role": "expected_profile", "text": f"Expected fields for species {genus} {species}:\n{text}"})
442
+
443
+ markers = doc.get("species_markers")
444
+ if isinstance(markers, list) and markers:
445
+ lines: List[str] = []
446
+ for m in markers:
447
+ if not isinstance(m, dict):
448
+ continue
449
+ field = _norm_str(m.get("field"))
450
+ val = _norm_str(m.get("value"))
451
+ importance = _norm_str(m.get("importance"))
452
+ notes = _norm_str(m.get("notes"))
453
+ if not field:
454
+ continue
455
+ line = f"{field}: {val or 'Unknown'}"
456
+ if importance:
457
+ line += f" (importance: {importance})"
458
+ if notes:
459
+ line += f" — {notes}"
460
+ lines.append(line)
461
+ if lines:
462
+ out.append({"section": "species_markers", "role": "species_marker", "text": "Species markers:\n" + "\n".join(lines)})
463
+
464
+ conf = doc.get("common_confusions")
465
+ if isinstance(conf, list) and conf:
466
+ text = _format_common_confusions(conf, level="species")
467
+ if text:
468
+ out.append({"section": "common_confusions", "role": "warning", "text": "Common confusions:\n" + text})
469
+
470
+ wq = doc.get("when_to_question_identification")
471
+ if isinstance(wq, list) and wq:
472
+ lines = [str(x).strip() for x in wq if str(x).strip()]
473
+ if lines:
474
+ out.append({"section": "when_to_question_identification", "role": "warning", "text": "When to question identification:\n" + _bullet_lines(lines)})
475
+
476
+ rnt = doc.get("recommended_next_tests")
477
+ if isinstance(rnt, list) and rnt:
478
+ text = _format_recommended_next_tests(rnt)
479
+ if text:
480
+ out.append({"section": "recommended_next_tests", "role": "recommendation", "text": "Recommended next tests:\n" + text})
481
+
482
+ return out
483
+
484
+
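# Reviewer sketch (not part of the committed file): a minimal, hypothetical
# species document and the kind of entries emit_species_sections() returns.
# All field names and values below are illustrative, not taken from the KB.
#
#   doc = {
#       "overview": {"short": "Example short description.",
#                    "clinical_context": "Example clinical note."},
#       "species_markers": [
#           {"field": "Coagulase", "value": "Positive", "importance": "high"}
#       ],
#   }
#   emit_species_sections(doc, genus="Examplegenus", species="exampli")
#   # -> [{"section": "overview", "role": "description",
#   #      "text": "Species Examplegenus exampli: Example short description."},
#   #     {"section": "overview", "role": "description",
#   #      "text": "Clinical context: Example clinical note."},
#   #     {"section": "species_markers", "role": "species_marker",
#   #      "text": "Species markers:\nCoagulase: Positive (importance: high)"}]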
485
+ # ------------------------------------------------------------
486
+ # INDEX BUILD
487
+ # ------------------------------------------------------------
488
+
489
+ def _iter_kb_files() -> List[Tuple[str, str]]:
490
+ entries: List[Tuple[str, str]] = []
491
+ if not os.path.isdir(KB_ROOT):
492
+ return entries
493
+
494
+ for genus in sorted(os.listdir(KB_ROOT)):
495
+ genus_dir = os.path.join(KB_ROOT, genus)
496
+ if not os.path.isdir(genus_dir):
497
+ continue
498
+ for fname in sorted(os.listdir(genus_dir)):
499
+ if fname.lower().endswith(".json"):
500
+ entries.append((genus, os.path.join(genus_dir, fname)))
501
+ return entries
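# Layout note (illustrative): this walker assumes the KB is organised as one
# directory per genus under KB_ROOT (defined near the top of this module),
# with a genus-level genus.json plus one JSON file per species, e.g.:
#
#   <KB_ROOT>/
#       SomeGenus/
#           genus.json         -> genus-level document
#           some_species.json  -> species-level document
#
# It returns sorted (genus_dir_name, file_path) tuples so index builds are
# deterministic across runs.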
502
+
503
+
504
+ def build_rag_index(max_chars: int = DEFAULT_MAX_CHARS) -> Dict[str, Any]:
505
+ os.makedirs(INDEX_DIR, exist_ok=True)
506
+
507
+ kb_entries = _iter_kb_files()
508
+ if not kb_entries:
509
+ return {"ok": False, "message": "No KB JSON files found."}
510
+
511
+ docs_for_embedding: List[str] = []
512
+ meta: List[Dict[str, Any]] = []
513
+
514
+ num_json_errors = 0
515
+
516
+ for genus_dir_name, path in kb_entries:
517
+ with open(path, "r", encoding="utf-8") as f:
518
+ try:
519
+ doc = json.load(f)
520
+ except json.JSONDecodeError as e:
521
+ print(f"[rag_index_builder] JSON error in {path}: {e}")
522
+ num_json_errors += 1
523
+ continue
524
+
525
+ fname = os.path.basename(path)
526
+ is_genus = fname == "genus.json"
527
+
528
+ genus = _norm_str(doc.get("genus")) or genus_dir_name
529
+ level = "genus" if is_genus else "species"
530
+
531
+ species: Optional[str]
532
+ if is_genus:
533
+ species = None
534
+ sections = emit_genus_sections(doc, genus=genus)
535
+ else:
536
+ species = _norm_str(doc.get("species")) or os.path.splitext(fname)[0]
537
+ sections = emit_species_sections(doc, genus=genus, species=species)
538
+
539
+ for sec in sections:
540
+ section = _norm_str(sec.get("section"))
541
+ role = _norm_str(sec.get("role"))
542
+ text = _norm_str(sec.get("text"))
543
+
544
+ if not section or not role or not text:
545
+ continue
546
+
547
+ chunks = chunk_text_by_paragraph(text, max_chars=max_chars)
548
+ for idx, chunk in enumerate(chunks):
549
+ if not chunk.strip():
550
+ continue
551
+
552
+ rec_id = f"{genus}|{species or 'GENUS'}|{section}|{idx}"
553
+
554
+ docs_for_embedding.append(chunk)
555
+ meta.append(
556
+ {
557
+ "id": rec_id,
558
+ "level": level,
559
+ "genus": genus,
560
+ "species": species,
561
+ "section": section,
562
+ "role": role,
563
+ "text": chunk,
564
+ "source_file": os.path.relpath(path),
565
+ "chunk_id": idx,
566
+ # Optional: helps later for field-level weighting
567
+ "field_key": None,
568
+ }
569
+ )
570
+
571
+ if not docs_for_embedding:
572
+ return {
573
+ "ok": False,
574
+ "message": "No valid sections emitted from KB JSON files. Check schema/contents.",
575
+ "num_files": len(kb_entries),
576
+ "num_json_errors": num_json_errors,
577
+ }
578
+
579
+ embeddings = embed_texts(docs_for_embedding, normalize=True)
580
+
581
+ index_records: List[Dict[str, Any]] = []
582
+ for m, emb in zip(meta, embeddings):
583
+ rec = dict(m)
584
+ rec["embedding"] = emb.tolist()
585
+ index_records.append(rec)
586
+
587
+ with open(INDEX_PATH, "w", encoding="utf-8") as f:
588
+ json.dump(
589
+ {
590
+ "version": 2,
591
+ "model_name": EMBEDDING_MODEL_NAME,
592
+ "record_schema": {
593
+ "id": "str",
594
+ "level": "genus|species",
595
+ "genus": "str",
596
+ "species": "str|null",
597
+ "section": "str",
598
+ "role": "str",
599
+ "text": "str",
600
+ "source_file": "str",
601
+ "chunk_id": "int",
602
+ "embedding": "list[float]",
603
+ },
604
+ "stats": {
605
+ "num_files": len(kb_entries),
606
+ "num_records": len(index_records),
607
+ "num_json_errors": num_json_errors,
608
+ "chunk_max_chars": max_chars,
609
+ },
610
+ "records": index_records,
611
+ },
612
+ f,
613
+ ensure_ascii=False,
614
+ )
615
+
616
+ return {
617
+ "ok": True,
618
+ "message": "RAG index built successfully (section-aware, declarative expected profiles).",
619
+ "index_path": INDEX_PATH,
620
+ "num_records": len(index_records),
621
+ "num_files": len(kb_entries),
622
+ "num_json_errors": num_json_errors,
623
+ "chunk_max_chars": max_chars,
624
+ }
625
+
626
+
627
+ if __name__ == "__main__":
628
+ summary = build_rag_index()
629
+ print(json.dumps(summary, indent=2))
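A minimal retrieval sketch against the index this script writes (not part of the upload). It reuses embed_texts and INDEX_PATH from this module and relies on the embeddings being stored normalized, so a dot product is a cosine similarity. The query string, top_k default, and the in-memory loading strategy are illustrative assumptions.

from typing import Any, Dict, List, Optional
import json

import numpy as np


def search_index(query: str, top_k: int = 5, genus: Optional[str] = None) -> List[Dict[str, Any]]:
    # Load the section-aware index written by build_rag_index().
    with open(INDEX_PATH, "r", encoding="utf-8") as f:
        index = json.load(f)
    records = index["records"]
    if genus:
        # Optional metadata filter before scoring.
        records = [r for r in records if r["genus"] == genus]
    if not records:
        return []
    # Embeddings were stored normalized, so a dot product equals cosine similarity.
    mat = np.asarray([r["embedding"] for r in records], dtype=np.float32)
    q = np.asarray(embed_texts([query], normalize=True)[0], dtype=np.float32)
    scores = mat @ q
    top = np.argsort(-scores)[:top_k]
    return [
        {"score": float(scores[i]),
         **{k: records[i][k] for k in ("genus", "species", "section", "role", "text")}}
        for i in top
    ]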
training/schema_expander.py ADDED
@@ -0,0 +1,237 @@
1
+ # training/schema_expander.py
2
+ # ------------------------------------------------------------
3
+ # Stage 10C — SAFE schema expansion
4
+ #
5
+ # Core fields = EXACT columns in bacteria_db.xlsx.
6
+ # Extended fields = ONLY the ones NOT in DB and NOT in existing schema.
7
+ #
8
+ # This version:
9
+ # - NEVER adds core fields to extended schema.
10
+ # - Only adds true extended fields found in gold tests.
11
+ # - Logs ambiguous or rare fields to proposals file.
12
+ # - Reports field frequencies & values seen for debugging.
13
+ # ------------------------------------------------------------
14
+
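# Reviewer note on input shape: each gold test is a dict whose "expected"
# mapping pairs a field name with the observed value; everything else in the
# record is ignored here. Values below are hypothetical.
#
#   {
#       "expected": {
#           "Aesculin hydrolysis": "Positive",
#           "CAMP": "Negative"
#       }
#   }
#
# Fields that are DB columns or already in the extended schema are skipped;
# the remaining names are counted and either auto-added (frequent and safe)
# or logged to the proposals file for manual review.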
15
+ from __future__ import annotations
16
+
17
+ import os
18
+ import json
19
+ from typing import Dict, Any, List
20
+ from collections import Counter
21
+ from datetime import datetime
22
+
23
+ import pandas as pd
24
+
25
+ from engine.schema import (
26
+ load_extended_schema,
27
+ save_extended_schema,
28
+ )
29
+
30
+ # ------------------------------------------------------------
31
+ # Paths
32
+ # ------------------------------------------------------------
33
+
34
+ GOLD_PATH = "training/gold_tests.json"
35
+ EXTENDED_SCHEMA_PATH = "data/extended_schema.json"
36
+ PROPOSALS_PATH = "data/extended_proposals.jsonl"
37
+
38
+ # Minimum frequency before auto-adding a new extended field
39
+ MIN_FIELD_FREQ = 5
40
+
41
+
42
+ # ------------------------------------------------------------
43
+ # Helper: load gold tests
44
+ # ------------------------------------------------------------
45
+
46
+ def _load_gold_tests() -> List[Dict[str, Any]]:
47
+ if not os.path.exists(GOLD_PATH):
48
+ return []
49
+ with open(GOLD_PATH, "r", encoding="utf-8") as f:
50
+ try:
51
+ data = json.load(f)
52
+ return data if isinstance(data, list) else []
53
+ except Exception:
54
+ return []
55
+
56
+
57
+ # ------------------------------------------------------------
58
+ # Helper: load DB columns (TRUE core schema)
59
+ # ------------------------------------------------------------
60
+
61
+ def _load_db_columns() -> List[str]:
62
+ candidates = [
63
+ os.path.join("data", "bacteria_db.xlsx"),
64
+ "bacteria_db.xlsx",
65
+ ]
66
+ for p in candidates:
67
+ if os.path.exists(p):
68
+ try:
69
+ df = pd.read_excel(p)
70
+ return [str(c).strip() for c in df.columns]
71
+ except Exception:
72
+ continue
73
+ return []
74
+
75
+
76
+ # ------------------------------------------------------------
77
+ # Decide if field name is safe for auto-adding
78
+ # ------------------------------------------------------------
79
+
80
+ def _is_safe_field_name(name: str) -> bool:
81
+ n = name.strip()
82
+ if not n:
83
+ return False
84
+
85
+ low = n.lower()
86
+
87
+ # Ignore extremely short or generic names
88
+ if len(n) < 4:
89
+ return False
90
+ if low in {"test", "growth", "acid", "base", "value", "result"}:
91
+ return False
92
+
93
+ # Clear biochemical patterns
94
+ patterns = [
95
+ "hydrolysis",
96
+ "fermentation",
97
+ "decarboxylase",
98
+ "dihydrolase",
99
+ "reduction",
100
+ "utilization",
101
+ "tolerance",
102
+ "solubility",
103
+ "oxidation",
104
+ "lysis",
105
+ "susceptibility",
106
+ "resistance",
107
+ "pyruvate",
108
+ "lecithinase",
109
+ "lipase",
110
+ "casein",
111
+ "hippurate",
112
+ "tyrosine",
113
+ ]
114
+ if any(pat in low for pat in patterns):
115
+ return True
116
+
117
+ # Known short disc tests
118
+ known_short = {"CAMP", "PYR", "Optochin", "Bacitracin", "Novobiocin"}
119
+ if n in known_short:
120
+ return True
121
+
122
+ # If contains "test" and more than one word → likely legitimate
123
+ if "test" in low and " " in low:
124
+ return True
125
+
126
+ return False
127
+
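# Worked examples for the heuristic above (hand-traced, illustrative only):
#   _is_safe_field_name("Aesculin hydrolysis") -> True   (matches "hydrolysis")
#   _is_safe_field_name("Urease test")         -> True   (contains "test" plus a space)
#   _is_safe_field_name("CAMP")                -> True   (known short disc test)
#   _is_safe_field_name("acid")                -> False  (too generic)
#   _is_safe_field_name("pH")                  -> False  (shorter than 4 characters)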
128
+
129
+ # ------------------------------------------------------------
130
+ # Log proposal (rare/ambiguous fields)
131
+ # ------------------------------------------------------------
132
+
133
+ def _append_proposal(record: Dict[str, Any]) -> None:
134
+ os.makedirs(os.path.dirname(PROPOSALS_PATH), exist_ok=True)
135
+ with open(PROPOSALS_PATH, "a", encoding="utf-8") as f:
136
+ f.write(json.dumps(record, ensure_ascii=False) + "\n")
137
+
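# Each appended line is one JSON object, e.g. (values hypothetical):
#   {"timestamp": "2024-01-01T00:00:00Z", "field_name": "Tween 80 hydrolysis",
#    "freq": 2, "values_seen": {"Positive": 2}}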
138
+
139
+ # ------------------------------------------------------------
140
+ # MAIN ENTRY — SAFE SCHEMA EXPANSION
141
+ # ------------------------------------------------------------
142
+
143
+ def expand_schema() -> Dict[str, Any]:
144
+ gold = _load_gold_tests()
145
+ if not gold:
146
+ return {
147
+ "ok": False,
148
+ "message": f"No gold tests found at {GOLD_PATH}",
149
+ "auto_added_fields": {},
150
+ "proposed_fields": [],
151
+ "schema_path": EXTENDED_SCHEMA_PATH,
152
+ "proposals_path": PROPOSALS_PATH,
153
+ "unknown_fields_raw": {},
154
+ "field_frequencies": {},
155
+ }
156
+
157
+ db_columns = set(_load_db_columns()) # TRUE core schema
158
+ extended_schema = load_extended_schema(EXTENDED_SCHEMA_PATH)
159
+ extended_fields = set(extended_schema.keys())
160
+
161
+ # Counter for unknown fields
162
+ field_counts: Counter[str] = Counter()
163
+ field_values: Dict[str, Counter[str]] = {}
164
+
165
+ for test in gold:
166
+ expected = test.get("expected", {})
167
+ if not isinstance(expected, dict):
168
+ continue
169
+
170
+ for field, value in expected.items():
171
+ fname = str(field).strip()
172
+ if not fname:
173
+ continue
174
+
175
+ # Skip core DB fields
176
+ if fname in db_columns:
177
+ continue
178
+
179
+ # Skip already-known extended fields
180
+ if fname in extended_fields:
181
+ continue
182
+
183
+ # Count unknowns
184
+ field_counts[fname] += 1
185
+ if fname not in field_values:
186
+ field_values[fname] = Counter()
187
+ field_values[fname][str(value).strip()] += 1
188
+
189
+ auto_added: Dict[str, Any] = {}
190
+ proposed: List[Dict[str, Any]] = []
191
+
192
+ # Decide which unknown fields to auto-add
193
+ for fname, freq in field_counts.items():
194
+ values_seen = dict(field_values.get(fname, {}))
195
+
196
+ if freq >= MIN_FIELD_FREQ and _is_safe_field_name(fname):
197
+ # Auto-add as extended test
198
+ extended_schema[fname] = {
199
+ "value_type": "enum_PNV",
200
+ "description": "Auto-added from gold tests (Stage 10C)",
201
+ "values": list(values_seen.keys()),
202
+ }
203
+ auto_added[fname] = {
204
+ "freq": freq,
205
+ "values_seen": list(values_seen.keys()),
206
+ }
207
+ else:
208
+ # Log proposal for later review
209
+ proposed.append(
210
+ {
211
+ "field_name": fname,
212
+ "freq": freq,
213
+ "values_seen": values_seen,
214
+ }
215
+ )
216
+ _append_proposal(
217
+ {
218
+ "timestamp": datetime.utcnow().isoformat() + "Z",
219
+ "field_name": fname,
220
+ "freq": freq,
221
+ "values_seen": values_seen,
222
+ }
223
+ )
224
+
225
+ # Save updated schema
226
+ if auto_added:
227
+ save_extended_schema(extended_schema, EXTENDED_SCHEMA_PATH)
228
+
229
+ return {
230
+ "ok": True,
231
+ "auto_added_fields": auto_added,
232
+ "proposed_fields": proposed,
233
+ "schema_path": EXTENDED_SCHEMA_PATH,
234
+ "proposals_path": PROPOSALS_PATH,
235
+ "unknown_fields_raw": {f: dict(cnt) for f, cnt in field_values.items()},
236
+ "field_frequencies": dict(field_counts),
237
+ }
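A small driver sketch (not part of the upload) showing how a pipeline step might invoke this stage and summarise the outcome; the report formatting is illustrative and the keys come from the return dict above.

if __name__ == "__main__":
    result = expand_schema()
    if result["ok"]:
        print(f"Auto-added {len(result['auto_added_fields'])} extended field(s); "
              f"{len(result['proposed_fields'])} proposal(s) logged to {result['proposals_path']}")
    else:
        print(result["message"])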
training/signal_trainer.py ADDED
@@ -0,0 +1,35 @@
1
+ # training/signal_trainer.py
2
+ # ------------------------------------------------------------
3
+ # Stage 10C placeholder:
4
+ # Safely returns a no-op result for signal training.
5
+ # This MUST NOT crash during import.
6
+ # ------------------------------------------------------------
7
+
8
+ from __future__ import annotations
9
+ from typing import Dict, Any
10
+ import json
11
+ import os
12
+
13
+
14
+ SIGNALS_PATH = "data/signals_catalog.json"
15
+
16
+
17
+ def train_signals() -> Dict[str, Any]:
18
+ """
19
+ Placeholder trainer. Does nothing except ensure signals_catalog.json exists.
20
+ Must NEVER crash.
21
+ """
22
+
23
+ # Ensure signals catalog exists
24
+ if not os.path.exists(SIGNALS_PATH):
25
+ try:
26
+ with open(SIGNALS_PATH, "w", encoding="utf-8") as f:
27
+ json.dump({}, f, indent=2, ensure_ascii=False)
28
+ except Exception:
29
+ pass
30
+
31
+ return {
32
+ "ok": True,
33
+ "message": "Signal trainer not implemented yet (Stage 10C placeholder).",
34
+ "signals_catalog_path": SIGNALS_PATH,
35
+ }
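For completeness, a sketch of how the placeholder might be invoked from a training pipeline (the call site is an assumption; the keys are taken from the return value above).

from training.signal_trainer import train_signals

summary = train_signals()
print(summary["message"])               # placeholder notice
print(summary["signals_catalog_path"])  # data/signals_catalog.json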