Upload 21 files
- models/genus_xgb.json +3 -0
- models/genus_xgb_meta.json +369 -0
- rag/context_shaper.py +168 -0
- rag/rag_embedder.py +112 -0
- rag/rag_generator.py +446 -0
- rag/rag_retriever.py +509 -0
- rag/species_scorer.py +314 -0
- scoring/diagnostic_anchors.py +97 -0
- scoring/overall_ranker.py +146 -0
- static/eph.jpeg +3 -0
- training/__init__.py +2 -0
- training/alias_trainer.py +126 -0
- training/field_weight_trainer.py +330 -0
- training/gold_tester.py +89 -0
- training/gold_tests.json +3 -0
- training/gold_trainer.py +79 -0
- training/hf_sync.py +68 -0
- training/parser_eval.py +104 -0
- training/rag_index_builder.py +629 -0
- training/schema_expander.py +237 -0
- training/signal_trainer.py +35 -0
models/genus_xgb.json ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6530346e18bdb61e778f887fef7ca33e8e0b56e040ce95287c373073db726cfb
+size 34613851
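A note on the file above: models/genus_xgb.json is committed as a Git LFS pointer (version / oid / size), so the ~34 MB XGBoost model itself lives in LFS storage and is materialized by `git lfs pull`. A minimal sketch, assuming a local checkout of this repo, for telling a resolved model file apart from a bare pointer:

import os

def is_lfs_pointer(path: str) -> bool:
    # Git LFS pointer files are tiny text files starting with the spec line;
    # a resolved model file is tens of MB of JSON.
    if os.path.getsize(path) > 1024:
        return False
    with open(path, "rb") as f:
        return f.read(24) == b"version https://git-lfs."

print(is_lfs_pointer("models/genus_xgb.json"))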
models/genus_xgb_meta.json ADDED
@@ -0,0 +1,369 @@
+{
+  "genus_to_idx": {
+    "Staphylococcus": 0,
+    "Salmonella": 1,
+    "Listeria": 2,
+    "Enterobacter": 3,
+    "Pseudomonas": 4,
+    "Streptococcus": 5,
+    "Enterococcus": 6,
+    "Bacillus": 7,
+    "Shigella": 8,
+    "Escherichia": 9,
+    "Klebsiella": 10,
+    "Proteus": 11,
+    "Vibrio": 12,
+    "Neisseria": 13,
+    "Campylobacter": 14,
+    "Clostridium": 15,
+    "Corynebacterium": 16,
+    "Legionella": 17,
+    "Mycobacterium": 18,
+    "Bacteroides": 19,
+    "Micrococcus": 20,
+    "Erysipelothrix": 21,
+    "Haemophilus": 22,
+    "Aeromonas": 23,
+    "Yersinia": 24,
+    "Acinetobacter": 25,
+    "Serratia": 26,
+    "Morganella": 27,
+    "Providencia": 28,
+    "Burkholderia": 29,
+    "Helicobacter": 30,
+    "Actinomyces": 31,
+    "Nocardia": 32,
+    "Pasteurella": 33,
+    "Citrobacter": 34,
+    "Leptospira": 35,
+    "Alcaligenes": 36,
+    "Shewanella": 37,
+    "Edwardsiella": 38,
+    "Chromobacterium": 39,
+    "Lactobacillus": 40,
+    "Propionibacterium": 41,
+    "Peptostreptococcus": 42,
+    "Veillonella": 43,
+    "Fusobacterium": 44,
+    "Eubacterium": 45,
+    "Halomonas": 46,
+    "Psychrobacter": 47,
+    "Rhodococcus": 48,
+    "Mycoplasma": 49,
+    "Bordetella": 50,
+    "Stenotrophomonas": 51,
+    "Ralstonia": 52,
+    "Achromobacter": 53,
+    "Brucella": 54,
+    "Arthrobacter": 55,
+    "Flavobacterium": 56,
+    "Oerskovia": 57,
+    "Sphingomonas": 58,
+    "Comamonas": 59,
+    "Thermococcus": 60,
+    "Elizabethkingia": 61,
+    "Hafnia": 62,
+    "Raoultella": 63,
+    "Ochrobactrum": 64,
+    "Roseomonas": 65,
+    "Actinobacillus": 66,
+    "Gemella": 67,
+    "Rothia": 68,
+    "Carnobacterium": 69,
+    "Plesiomonas": 70,
+    "Janthinobacterium": 71,
+    "Paenibacillus": 72,
+    "Moraxella": 73,
+    "Aerococcus": 74,
+    "Kocuria": 75,
+    "Leuconostoc": 76,
+    "Arcanobacterium": 77,
+    "Gardnerella": 78,
+    "Porphyromonas": 79,
+    "Prevotella": 80,
+    "Pediococcus": 81,
+    "Weissella": 82,
+    "Lactococcus": 83,
+    "Microbacterium": 84,
+    "Clostridioides": 85,
+    "Cronobacter": 86,
+    "Rhizobium": 87,
+    "Azotobacter": 88,
+    "Spirillum": 89,
+    "Candida": 90,
+    "Cryptococcus": 91,
+    "Saccharomyces": 92,
+    "Rickettsia": 93,
+    "Borrelia": 94,
+    "Chlamydia": 95,
+    "Acidaminococcus": 96,
+    "Bartonella": 97,
+    "Coxiella": 98,
+    "Kingella": 99,
+    "Eikenella": 100,
+    "Bilophila": 101,
+    "Anaerococcus": 102,
+    "Finegoldia": 103,
+    "Parvimonas": 104,
+    "Ruminococcus": 105,
+    "Cutibacterium": 106,
+    "Exiguobacterium": 107,
+    "Kluyvera": 108,
+    "Pluralibacter": 109,
+    "Massilia": 110,
+    "Methylobacterium": 111,
+    "Cupriavidus": 112,
+    "Acidovorax": 113,
+    "Geobacillus": 114,
+    "Trueperella": 115,
+    "Streptomyces": 116,
+    "Thermoactinomyces": 117,
+    "Capnocytophaga": 118,
+    "Cardiobacterium": 119,
+    "Yokenella": 120,
+    "Brevibacterium": 121,
+    "Peptoniphilus": 122,
+    "Weisella": 123,
+    "Saccharopolyspora": 124,
+    "Frankia": 125,
+    "Spiroplasma": 126,
+    "Cedecea": 127,
+    "Photorhabdus": 128,
+    "Abiotrophia": 129,
+    "Cellulomonas": 130,
+    "Leifsonia": 131,
+    "Alicyclobacillus": 132,
+    "Sporolactobacillus": 133,
+    "Leclercia": 134,
+    "Kosakonia": 135,
+    "Bergeyella": 136,
+    "Myroides": 137,
+    "Aggregatibacter": 138,
+    ":": 139
+  },
+  "idx_to_genus": {
+    "0": "Staphylococcus",
+    "1": "Salmonella",
+    "2": "Listeria",
+    "3": "Enterobacter",
+    "4": "Pseudomonas",
+    "5": "Streptococcus",
+    "6": "Enterococcus",
+    "7": "Bacillus",
+    "8": "Shigella",
+    "9": "Escherichia",
+    "10": "Klebsiella",
+    "11": "Proteus",
+    "12": "Vibrio",
+    "13": "Neisseria",
+    "14": "Campylobacter",
+    "15": "Clostridium",
+    "16": "Corynebacterium",
+    "17": "Legionella",
+    "18": "Mycobacterium",
+    "19": "Bacteroides",
+    "20": "Micrococcus",
+    "21": "Erysipelothrix",
+    "22": "Haemophilus",
+    "23": "Aeromonas",
+    "24": "Yersinia",
+    "25": "Acinetobacter",
+    "26": "Serratia",
+    "27": "Morganella",
+    "28": "Providencia",
+    "29": "Burkholderia",
+    "30": "Helicobacter",
+    "31": "Actinomyces",
+    "32": "Nocardia",
+    "33": "Pasteurella",
+    "34": "Citrobacter",
+    "35": "Leptospira",
+    "36": "Alcaligenes",
+    "37": "Shewanella",
+    "38": "Edwardsiella",
+    "39": "Chromobacterium",
+    "40": "Lactobacillus",
+    "41": "Propionibacterium",
+    "42": "Peptostreptococcus",
+    "43": "Veillonella",
+    "44": "Fusobacterium",
+    "45": "Eubacterium",
+    "46": "Halomonas",
+    "47": "Psychrobacter",
+    "48": "Rhodococcus",
+    "49": "Mycoplasma",
+    "50": "Bordetella",
+    "51": "Stenotrophomonas",
+    "52": "Ralstonia",
+    "53": "Achromobacter",
+    "54": "Brucella",
+    "55": "Arthrobacter",
+    "56": "Flavobacterium",
+    "57": "Oerskovia",
+    "58": "Sphingomonas",
+    "59": "Comamonas",
+    "60": "Thermococcus",
+    "61": "Elizabethkingia",
+    "62": "Hafnia",
+    "63": "Raoultella",
+    "64": "Ochrobactrum",
+    "65": "Roseomonas",
+    "66": "Actinobacillus",
+    "67": "Gemella",
+    "68": "Rothia",
+    "69": "Carnobacterium",
+    "70": "Plesiomonas",
+    "71": "Janthinobacterium",
+    "72": "Paenibacillus",
+    "73": "Moraxella",
+    "74": "Aerococcus",
+    "75": "Kocuria",
+    "76": "Leuconostoc",
+    "77": "Arcanobacterium",
+    "78": "Gardnerella",
+    "79": "Porphyromonas",
+    "80": "Prevotella",
+    "81": "Pediococcus",
+    "82": "Weissella",
+    "83": "Lactococcus",
+    "84": "Microbacterium",
+    "85": "Clostridioides",
+    "86": "Cronobacter",
+    "87": "Rhizobium",
+    "88": "Azotobacter",
+    "89": "Spirillum",
+    "90": "Candida",
+    "91": "Cryptococcus",
+    "92": "Saccharomyces",
+    "93": "Rickettsia",
+    "94": "Borrelia",
+    "95": "Chlamydia",
+    "96": "Acidaminococcus",
+    "97": "Bartonella",
+    "98": "Coxiella",
+    "99": "Kingella",
+    "100": "Eikenella",
+    "101": "Bilophila",
+    "102": "Anaerococcus",
+    "103": "Finegoldia",
+    "104": "Parvimonas",
+    "105": "Ruminococcus",
+    "106": "Cutibacterium",
+    "107": "Exiguobacterium",
+    "108": "Kluyvera",
+    "109": "Pluralibacter",
+    "110": "Massilia",
+    "111": "Methylobacterium",
+    "112": "Cupriavidus",
+    "113": "Acidovorax",
+    "114": "Geobacillus",
+    "115": "Trueperella",
+    "116": "Streptomyces",
+    "117": "Thermoactinomyces",
+    "118": "Capnocytophaga",
+    "119": "Cardiobacterium",
+    "120": "Yokenella",
+    "121": "Brevibacterium",
+    "122": "Peptoniphilus",
+    "123": "Weisella",
+    "124": "Saccharopolyspora",
+    "125": "Frankia",
+    "126": "Spiroplasma",
+    "127": "Cedecea",
+    "128": "Photorhabdus",
+    "129": "Abiotrophia",
+    "130": "Cellulomonas",
+    "131": "Leifsonia",
+    "132": "Alicyclobacillus",
+    "133": "Sporolactobacillus",
+    "134": "Leclercia",
+    "135": "Kosakonia",
+    "136": "Bergeyella",
+    "137": "Myroides",
+    "138": "Aggregatibacter",
+    "139": ":"
+  },
+  "n_features": 73,
+  "num_classes": 140,
+  "metrics": {
+    "train_accuracy": 0.9869916267942583,
+    "valid_accuracy": 0.9509569377990431,
+    "best_iteration": 270
+  },
+  "feature_schema_path": "data/feature_schema.json",
+  "feature_names": [
+    "Gram Stain",
+    "Shape",
+    "Haemolysis",
+    "Haemolysis Type",
+    "Catalase",
+    "Oxidase",
+    "Indole",
+    "Urease",
+    "Citrate",
+    "H2S",
+    "DNase",
+    "Lysine Decarboxylase",
+    "Ornithine Decarboxylase",
+    "Arginine dihydrolase",
+    "ONPG",
+    "Nitrate Reduction",
+    "Methyl Red",
+    "VP",
+    "Coagulase",
+    "Lipase Test",
+    "Motility",
+    "Motility Type",
+    "Capsule",
+    "Spore Formation",
+    "Pigment",
+    "Odor",
+    "Colony Pattern",
+    "TSI Pattern",
+    "Temperature_4C",
+    "Temperature_25C",
+    "Temperature_30C",
+    "Temperature_37C",
+    "Temperature_42C",
+    "Lactose Fermentation",
+    "Glucose Fermentation",
+    "Sucrose Fermentation",
+    "Mannitol Fermentation",
+    "Maltose Fermentation",
+    "Sorbitol Fermentation",
+    "Xylose Fermentation",
+    "Rhamnose Fermentation",
+    "Arabinose Fermentation",
+    "Raffinose Fermentation",
+    "Trehalose Fermentation",
+    "Inositol Fermentation",
+    "Oxygen Requirement",
+    "Gas Production",
+    "MacConkey Growth",
+    "Blood Growth",
+    "XLD Growth",
+    "Nutrient Growth",
+    "Cetrimide Growth",
+    "BCYE Growth",
+    "Hektoen Enteric Growth",
+    "Mannitol Salt Growth",
+    "Bordet-Gengou Growth",
+    "Thayer Martin Growth",
+    "Cycloserine Cefoxitin Fructose Growth",
+    "Sabouraud Growth",
+    "Lowenstein-Jensen Growth",
+    "Yeast Extract Mannitol Growth",
+    "BSK Growth",
+    "Brucella Growth",
+    "Charcoal Growth",
+    "BHI Growth",
+    "Ashby Growth",
+    "MRS Growth",
+    "Anaerobic Blood Growth",
+    "BP Growth",
+    "ALOA Growth",
+    "Anaerobic Growth",
+    "Chocolate Growth",
+    "TCBS Growth"
+  ]
+}
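For orientation, a minimal sketch of how this metadata pairs with the LFS-stored booster. Assumptions not confirmed by the commit: the model was saved in XGBoost's native JSON format and trained with a softprob-style multiclass objective over the 140 classes, so predict() returns per-class probabilities; feature rows must be encoded in the order given by "feature_names".

import json
import numpy as np
import xgboost as xgb

booster = xgb.Booster()
booster.load_model("models/genus_xgb.json")  # the resolved LFS file, not the pointer

with open("models/genus_xgb_meta.json", "r", encoding="utf-8") as f:
    meta = json.load(f)

# One encoded phenotype row: 73 features, ordered as meta["feature_names"].
x = np.zeros((1, meta["n_features"]), dtype="float32")
proba = booster.predict(xgb.DMatrix(x))  # shape (1, 140) assuming multi:softprob
best = int(np.argmax(proba[0]))
print(meta["idx_to_genus"][str(best)], float(proba[0][best]))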
rag/context_shaper.py ADDED
@@ -0,0 +1,168 @@
+# rag/context_shaper.py
+# ============================================================
+# Context shaper for RAG
+#
+# Goal:
+#   - Convert "flattened schema dumps" (Field: Value lines) into
+#     readable evidence blocks the LLM can reason over.
+#   - Deterministic, no LLM usage.
+#
+# Works with:
+#   - llm_context from rag_retriever (biology-only text)
+# ============================================================
+
+from __future__ import annotations
+
+import re
+from typing import Dict, List, Tuple, Optional
+
+
+_FIELD_LINE_RE = re.compile(r"^\s*([^:\n]{1,80})\s*:\s*(.+?)\s*$")
+
+# Some fields are usually lists separated by ; or , or |
+_LIST_LIKE_FIELDS = {
+    "Media Grown On",
+    "Colony Morphology",
+    "Colony Pattern",
+    "Growth Temperature",
+}
+
+# A light grouping map to turn fields into readable sections.
+# (You can expand this over time.)
+_GROUPS: List[Tuple[str, List[str]]] = [
+    ("Morphology & staining", [
+        "Gram Stain", "Shape", "Cellular Arrangement", "Capsule", "Spore Forming",
+    ]),
+    ("Culture & colony", [
+        "Media Grown On", "Colony Morphology", "Colony Pattern", "Pigment", "Odour",
+        "Haemolysis", "Haemolysis Type",
+    ]),
+    ("Core biochemistry", [
+        "Oxidase", "Catalase", "Indole", "Urease", "Citrate", "Methyl Red", "VP",
+        "Nitrate Reduction", "ONPG", "TSI Pattern", "H2S", "Gas Production",
+        "Glucose Fermentation", "Lactose Fermentation", "Sucrose Fermentation",
+        "Inositol Fermentation", "Mannitol Fermentation",
+    ]),
+    ("Motility & growth conditions", [
+        "Motility", "Motility Type", "Growth Temperature", "NaCl", "NaCl Tolerance",
+        "Oxygen Requirement",
+    ]),
+    ("Other tests", [
+        "DNase", "Esculin Hydrolysis", "Gelatin Hydrolysis",
+        "Lysine Decarboxylase", "Ornithine Decarboxylase", "Arginine Dihydrolase",
+    ]),
+]
+
+
+def _is_schema_dump(text: str) -> bool:
+    """
+    Detect if rag context looks like flattened Field: Value lines.
+    """
+    if not text:
+        return False
+    lines = [l for l in text.splitlines() if l.strip()]
+    if len(lines) < 6:
+        return False
+    hits = 0
+    for l in lines[:40]:
+        if _FIELD_LINE_RE.match(l):
+            hits += 1
+    return hits >= max(4, int(0.5 * min(len(lines), 40)))
+
+
+def _split_listish(field: str, value: str) -> str:
+    """
+    Normalize list-like values into comma-separated readable text.
+    """
+    v = (value or "").strip()
+    if not v:
+        return v
+    if field in _LIST_LIKE_FIELDS or (";" in v) or ("," in v):
+        parts = [p.strip() for p in re.split(r"[;,\|]+", v) if p.strip()]
+        if parts:
+            return ", ".join(parts)
+    return v
+
+
+def _parse_field_lines(text: str) -> Dict[str, str]:
+    """
+    Parse Field: Value lines into a dict. Keeps last occurrence.
+    """
+    out: Dict[str, str] = {}
+    for raw in (text or "").splitlines():
+        line = raw.strip()
+        if not line:
+            continue
+        m = _FIELD_LINE_RE.match(line)
+        if not m:
+            continue
+        field = m.group(1).strip()
+        value = m.group(2).strip()
+        if not field:
+            continue
+        out[field] = _split_listish(field, value)
+    return out
+
+
+def _format_grouped_blocks(fields: Dict[str, str]) -> str:
+    """
+    Turn fields into grouped, readable evidence blocks.
+    """
+    used = set()
+    blocks: List[str] = []
+
+    for title, keys in _GROUPS:
+        lines: List[str] = []
+        for k in keys:
+            if k in fields:
+                val = fields[k]
+                if val and val.lower() != "unknown":
+                    lines.append(f"- {k}: {val}")
+                    used.add(k)
+        if lines:
+            blocks.append(f"{title}:\n" + "\n".join(lines))
+
+    # Any leftovers not in group map
+    leftovers: List[str] = []
+    for k, v in fields.items():
+        if k in used:
+            continue
+        if not v or v.lower() == "unknown":
+            continue
+        leftovers.append(f"- {k}: {v}")
+    if leftovers:
+        blocks.append("Additional traits:\n" + "\n".join(leftovers))
+
+    return "\n\n".join(blocks).strip()
+
+
+def shape_llm_context(
+    llm_context: str,
+    target_genus: str = "",
+    max_chars: int = 1800,
+) -> str:
+    """
+    Main entrypoint.
+    - If context is already narrative, keep it (trim to max_chars).
+    - If it is a schema dump, convert to grouped evidence blocks.
+    """
+    ctx = (llm_context or "").strip()
+    if not ctx:
+        return ""
+
+    if _is_schema_dump(ctx):
+        fields = _parse_field_lines(ctx)
+        shaped = _format_grouped_blocks(fields)
+
+        # Add a tiny header to cue the LLM that this is reference evidence
+        if target_genus:
+            shaped = f"Reference evidence for {target_genus} (compiled traits):\n\n{shaped}"
+        else:
+            shaped = f"Reference evidence (compiled traits):\n\n{shaped}"
+
+        return shaped[:max_chars].strip()
+
+    # Narrative context: just trim
+    if target_genus:
+        ctx = f"Reference context for {target_genus}:\n\n{ctx}"
+    return ctx[:max_chars].strip()
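A quick usage sketch for the shaper above; the sample input is hypothetical but mimics the flattened "Field: Value" dumps it detects (at least six non-empty, mostly field-shaped lines are needed for _is_schema_dump to fire):

from rag.context_shaper import shape_llm_context

dump = "\n".join([
    "Gram Stain: Positive",
    "Shape: Cocci",
    "Catalase: Positive",
    "Oxidase: Negative",
    "Haemolysis: Yes",
    "Media Grown On: Blood agar; Mannitol salt agar",
])
print(shape_llm_context(dump, target_genus="Staphylococcus"))
# -> "Reference evidence for Staphylococcus (compiled traits):" followed by
#    grouped "Morphology & staining" / "Culture & colony" / "Core biochemistry" blocks.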
rag/rag_embedder.py ADDED
@@ -0,0 +1,112 @@
+# rag/rag_embedder.py
+# ============================================================
+# Embedding utilities for RAG (knowledge base + queries)
+# Uses a SentenceTransformer model for dense embeddings.
+# ============================================================
+
+from __future__ import annotations
+
+import os
+import json
+from typing import List, Dict, Any
+
+import numpy as np
+from sentence_transformers import SentenceTransformer
+
+
+# ------------------------------------------------------------
+# CONFIG
+# ------------------------------------------------------------
+
+EMBEDDING_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
+
+_model: SentenceTransformer | None = None
+
+
+# ------------------------------------------------------------
+# MODEL LOADING
+# ------------------------------------------------------------
+
+def get_embedder() -> SentenceTransformer:
+    global _model
+    if _model is None:
+        _model = SentenceTransformer(EMBEDDING_MODEL_NAME)
+    return _model
+
+
+# ------------------------------------------------------------
+# EMBEDDING
+# ------------------------------------------------------------
+
+def embed_text(text: str, normalize: bool = True) -> np.ndarray:
+    """
+    Embed a single piece of text.
+    Returns a 1D numpy array (MPNet: 768-dim).
+    """
+    model = get_embedder()
+    emb = model.encode(
+        [text],
+        show_progress_bar=False,
+        normalize_embeddings=normalize,
+    )
+    return emb[0]
+
+
+def embed_texts(texts: List[str], normalize: bool = True) -> np.ndarray:
+    """
+    Embed a list of strings -> (N, D) numpy array.
+    """
+    model = get_embedder()
+    return model.encode(
+        texts,
+        show_progress_bar=False,
+        normalize_embeddings=normalize,
+    )
+
+
+# ------------------------------------------------------------
+# INDEX LOADING
+# ------------------------------------------------------------
+
+def load_kb_index(path: str = "data/rag/index/kb_index.json") -> Dict[str, Any]:
+    """
+    Load the RAG knowledge base index JSON.
+
+    Expected format:
+    {
+      "version": int,
+      "model_name": str,
+      "records": [
+        {
+          "id": str,
+          "genus": str,
+          "species": str | null,
+          "level": "genus" | "species",
+          "chunk_id": int,
+          "source_file": str,
+          "text": str,
+          "embedding": [float, ...]
+        }
+      ]
+    }
+    """
+    if not os.path.exists(path):
+        raise FileNotFoundError(f"KB index not found at {path}")
+
+    with open(path, "r", encoding="utf-8") as f:
+        data = json.load(f)
+
+    index_model = data.get("model_name")
+    if index_model != EMBEDDING_MODEL_NAME:
+        raise ValueError(
+            f"KB index built with '{index_model}', "
+            f"but current embedder is '{EMBEDDING_MODEL_NAME}'. "
+            "Rebuild the index."
+        )
+
+    # Convert embeddings to numpy arrays
+    for rec in data.get("records", []):
+        if isinstance(rec.get("embedding"), list):
+            rec["embedding"] = np.array(rec["embedding"], dtype="float32")
+
+    return data
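Because embeddings are L2-normalized by default, cosine similarity reduces to a plain dot product; a short usage sketch with hypothetical strings (the MPNet model is downloaded on the first get_embedder() call):

import numpy as np
from rag.rag_embedder import embed_text, embed_texts

query = embed_text("Gram-positive cocci in clusters, catalase positive")
docs = embed_texts([
    "Staphylococcus: Gram-positive cocci, catalase positive",
    "Escherichia: Gram-negative rods, lactose fermenting",
])
scores = docs @ query          # cosine similarities, since vectors are unit-norm
print(int(np.argmax(scores)))  # expected: 0 (the Staphylococcus chunk)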
rag/rag_generator.py ADDED
@@ -0,0 +1,446 @@
+# rag/rag_generator.py
+# ============================================================
+# RAG generator using google/flan-t5-large (CPU-friendly)
+#
+# Goal (user-visible, structured, deterministic-first):
+#   - Show the user:
+#       KEY TRAITS:
+#       CONFLICTS:
+#       CONCLUSION:
+#   - KEY TRAITS and CONFLICTS are extracted deterministically from the
+#     shaped retriever context (preferred).
+#   - The LLM only writes the CONCLUSION (2–5 sentences) based on those
+#     extracted sections.
+#
+# Reliability:
+#   - flan-t5 sometimes echoes prompt instructions.
+#   - We keep the prompt extremely short and avoid imperative bullet rules.
+#   - We keep deterministic fallback logic if the LLM output is garbage/echo.
+#
+# Expected usage:
+#   ctx = retrieve_rag_context(..., parsed_fields=...)
+#   explanation = generate_genus_rag_explanation(
+#       phenotype_text=text,
+#       rag_context=ctx.get("llm_context_shaped") or ctx.get("llm_context"),
+#       genus=genus
+#   )
+#
+# Optional HF Space logs:
+#   export BACTAI_RAG_GEN_LOG_INPUT=1
+#   export BACTAI_RAG_GEN_LOG_OUTPUT=1
+# ============================================================
+
+from __future__ import annotations
+
+import os
+import re
+import torch
+from transformers import T5ForConditionalGeneration, T5Tokenizer
+
+
+# ------------------------------------------------------------
+# MODEL CONFIG
+# ------------------------------------------------------------
+
+MODEL_NAME = "google/flan-t5-large"
+
+_tokenizer: T5Tokenizer | None = None
+_model: T5ForConditionalGeneration | None = None
+
+# Keep small for CPU + to reduce prompt truncation weirdness
+_MAX_INPUT_TOKENS = 768
+_DEFAULT_MAX_NEW_TOKENS = 160
+
+# Hard cap the context chars we feed to T5 (prevents the model focusing on junk)
+_CONTEXT_CHAR_CAP = 2400
+
+
+def _get_model() -> tuple[T5Tokenizer, T5ForConditionalGeneration]:
+    global _tokenizer, _model
+    if _tokenizer is None or _model is None:
+        _tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
+        _model = T5ForConditionalGeneration.from_pretrained(
+            MODEL_NAME,
+            device_map="auto",
+            torch_dtype=torch.float32,
+        )
+    return _tokenizer, _model
+
+
+# ------------------------------------------------------------
+# DEBUG LOGGING (HF Space logs)
+# ------------------------------------------------------------
+
+RAG_GEN_LOG_INPUT = os.getenv("BACTAI_RAG_GEN_LOG_INPUT", "0").strip() == "1"
+RAG_GEN_LOG_OUTPUT = os.getenv("BACTAI_RAG_GEN_LOG_OUTPUT", "0").strip() == "1"
+
+
+def _log_block(title: str, body: str) -> None:
+    print("=" * 80)
+    print(f"RAG GENERATOR DEBUG — {title}")
+    print("=" * 80)
+    print(body.strip() if body else "")
+    print()
+
+
+# ------------------------------------------------------------
+# PROMPT (LLM WRITES ONLY THE CONCLUSION)
+# ------------------------------------------------------------
+
+# Intentionally minimal. No "rules list", no bullets specification.
+# The LLM sees ONLY extracted matches/conflicts and writes a short conclusion.
+RAG_PROMPT = """summarize: Evaluate whether the phenotype fits the target genus using the provided matches and conflicts.
+
+Target genus: {genus}
+
+Key traits that match:
+{matches}
+
+Conflicts:
+{conflicts}
+
+Write a short conclusion (2–5 sentences) stating whether this is a strong, moderate, or tentative genus match, and briefly mention the most important matches and conflicts.
+"""
+
+
+# ------------------------------------------------------------
+# OUTPUT CLEANUP + ECHO DETECTION
+# ------------------------------------------------------------
+
+_BAD_SUBSTRINGS = (
+    "summarize:",
+    "target genus",
+    "key traits that match",
+    "write a short conclusion",
+    "conflicts:",
+)
+
+
+def _clean_generation(text: str) -> str:
+    s = (text or "").strip()
+    if not s:
+        return ""
+
+    # collapse excessive whitespace/newlines
+    s = re.sub(r"\s*\n+\s*", " ", s).strip()
+    s = re.sub(r"\s{2,}", " ", s).strip()
+
+    # guard runaway length
+    if len(s) > 900:
+        s = s[:900].rstrip() + "..."
+
+    return s
+
+
+def _looks_like_echo_or_garbage(text: str) -> bool:
+    s = (text or "").strip()
+    if not s:
+        return True
+
+    # extremely short / non-sentence
+    if len(s) < 25:
+        return True
+
+    low = s.lower()
+    if any(bad in low for bad in _BAD_SUBSTRINGS):
+        return True
+
+    # Must look like actual prose
+    if "." not in s and "because" not in low and "match" not in low and "fits" not in low:
+        return True
+
+    return False
+
+
+# ------------------------------------------------------------
+# EXTRACT KEY TRAITS + CONFLICTS FROM SHAPED CONTEXT
+# ------------------------------------------------------------
+
+# Shaped context format (example):
+#   KEY MATCHES:
+#   - Trait: Value (matches reference: ...)
+#
+#   CONFLICTS (observed vs CORE traits):
+#   - Trait: Value (conflicts reference: ...)
+# or:
+#   CONFLICTS: Not specified.
+
+_KEY_MATCHES_HEADER_RE = re.compile(r"^\s*KEY MATCHES\s*:\s*$", re.IGNORECASE)
+_CONFLICTS_HEADER_RE = re.compile(r"^\s*CONFLICTS\b.*:\s*$", re.IGNORECASE)
+_CONFLICTS_INLINE_NONE_RE = re.compile(r"^\s*CONFLICTS\s*:\s*not specified\.?\s*$", re.IGNORECASE)
+
+_MATCH_LINE_RE = re.compile(
+    r"^\s*-\s*([^:]+)\s*:\s*(.+?)\s*\(matches reference:\s*(.+?)\)\s*$",
+    re.IGNORECASE,
+)
+_CONFLICT_LINE_RE = re.compile(
+    r"^\s*-\s*([^:]+)\s*:\s*(.+?)\s*\(conflicts reference:\s*(.+?)\)\s*$",
+    re.IGNORECASE,
+)
+
+# More permissive bullet capture (if shaper changes slightly)
+_GENERIC_BULLET_RE = re.compile(r"^\s*-\s*(.+?)\s*$")
+
+
+def _extract_key_traits_and_conflicts(shaped_ctx: str) -> tuple[list[str], list[str], bool]:
+    """
+    Extracts KEY MATCHES and CONFLICTS bullets from shaped retriever context.
+
+    Returns:
+        (key_traits, conflicts, found_structured_headers)
+
+    - key_traits items are short: "Trait: ObservedValue"
+    - conflicts items are short: "Trait: ObservedValue"
+    """
+    key_traits: list[str] = []
+    conflicts: list[str] = []
+
+    lines = (shaped_ctx or "").splitlines()
+    if not lines:
+        return key_traits, conflicts, False
+
+    in_matches = False
+    in_conflicts = False
+    saw_headers = False
+
+    for raw in lines:
+        line = raw.rstrip("\n")
+
+        # detect headers
+        if _KEY_MATCHES_HEADER_RE.match(line.strip()):
+            in_matches = True
+            in_conflicts = False
+            saw_headers = True
+            continue
+
+        if _CONFLICTS_INLINE_NONE_RE.match(line.strip()):
+            in_matches = False
+            in_conflicts = False
+            saw_headers = True
+            # explicit "no conflicts"
+            continue
+
+        if _CONFLICTS_HEADER_RE.match(line.strip()):
+            in_matches = False
+            in_conflicts = True
+            saw_headers = True
+            continue
+
+        # stop capture if another section begins (common shaper headings)
+        if saw_headers and (line.strip().endswith(":") and not line.strip().startswith("-")):
+            # If it's a new heading (and not one of our two), stop both
+            if not _KEY_MATCHES_HEADER_RE.match(line.strip()) and not _CONFLICTS_HEADER_RE.match(line.strip()):
+                in_matches = False
+                in_conflicts = False
+
+        # capture bullets under each section
+        if in_matches and line.strip().startswith("-"):
+            m = _MATCH_LINE_RE.match(line.strip())
+            if m:
+                trait = m.group(1).strip()
+                obs = m.group(2).strip()
+                key_traits.append(f"{trait}: {obs}")
+            else:
+                g = _GENERIC_BULLET_RE.match(line.strip())
+                if g:
+                    key_traits.append(g.group(1).strip())
+            continue
+
+        if in_conflicts and line.strip().startswith("-"):
+            c = _CONFLICT_LINE_RE.match(line.strip())
+            if c:
+                trait = c.group(1).strip()
+                obs = c.group(2).strip()
+                conflicts.append(f"{trait}: {obs}")
+            else:
+                g = _GENERIC_BULLET_RE.match(line.strip())
+                if g:
+                    conflicts.append(g.group(1).strip())
+            continue
+
+    return key_traits, conflicts, saw_headers
+
+
+def _extract_matches_conflicts_legacy(shaped_ctx: str) -> tuple[list[str], list[str]]:
+    """
+    Legacy extraction based purely on (matches reference: ...) / (conflicts reference: ...)
+    anywhere in the text. Useful if headers are missing.
+    """
+    matches: list[str] = []
+    conflicts: list[str] = []
+
+    for raw in (shaped_ctx or "").splitlines():
+        line = raw.strip()
+        if not line.startswith("-"):
+            continue
+
+        m = _MATCH_LINE_RE.match(line)
+        if m:
+            trait = m.group(1).strip()
+            obs = m.group(2).strip()
+            matches.append(f"{trait}: {obs}")
+            continue
+
+        c = _CONFLICT_LINE_RE.match(line)
+        if c:
+            trait = c.group(1).strip()
+            obs = c.group(2).strip()
+            conflicts.append(f"{trait}: {obs}")
+            continue
+
+    return matches, conflicts
+
+
+def _format_bullets(items: list[str], *, none_text: str) -> str:
+    if not items:
+        return none_text
+    return "\n".join(f"- {x}" for x in items)
+
+
+# ------------------------------------------------------------
+# DETERMINISTIC CONCLUSION FALLBACK
+# ------------------------------------------------------------
+
+def _deterministic_conclusion(genus: str, key_traits: list[str], conflicts: list[str]) -> str:
+    g = (genus or "").strip() or "Unknown"
+
+    m = key_traits[:4]
+    c = conflicts[:2]
+
+    if m and c:
+        return (
+            f"This is a probable match to {g} because it aligns with key traits such as "
+            f"{', '.join(m)}. However, there are conflicts ({', '.join(c)}), so treat this "
+            f"as a moderate/tentative genus-level fit and consider re-checking the conflicting tests."
+        )
+    if m and not c:
+        return (
+            f"This phenotype is consistent with {g} based on key matching traits such as "
+            f"{', '.join(m)}. No major conflicts were detected against the retrieved core genus traits, "
+            f"supporting a strong genus-level match."
+        )
+    if (not m) and c:
+        return (
+            f"This phenotype does not cleanly fit {g} because it conflicts with core traits "
+            f"({', '.join(c)}). Consider re-checking those tests or comparing against the next-ranked genera."
+        )
+
+    return (
+        f"Reference evidence was available for {g}, but no clear matches or conflicts could be extracted "
+        f"from the shaped context. Try increasing top_k genus chunks or ensuring parsed_fields are being "
+        f"passed into retrieve_rag_context so the shaper can compute KEY MATCHES and CONFLICTS."
+    )
+
+
+def _trim_context(ctx: str) -> str:
+    s = (ctx or "").strip()
+    if not s:
+        return ""
+    if len(s) <= _CONTEXT_CHAR_CAP:
+        return s
+    return s[:_CONTEXT_CHAR_CAP].rstrip() + "\n... (truncated)"
+
+
+# ------------------------------------------------------------
+# PUBLIC API
+# ------------------------------------------------------------
+
+def generate_genus_rag_explanation(
+    phenotype_text: str,
+    rag_context: str,
+    genus: str,
+    max_new_tokens: int = _DEFAULT_MAX_NEW_TOKENS,
+) -> str:
+    """
+    Generates a structured RAG output intended for direct display:
+
+    KEY TRAITS:
+    - ...
+    CONFLICTS:
+    - ...
+    CONCLUSION:
+    ...
+
+    Notes:
+    - KEY TRAITS + CONFLICTS are extracted deterministically from the (shaped) context.
+    - The LLM writes only the CONCLUSION.
+    - If the LLM output is garbage/echo, we use a deterministic conclusion fallback.
+    """
+    tokenizer, model = _get_model()
+
+    genus_clean = (genus or "").strip() or "Unknown"
+    context = _trim_context(rag_context or "")
+
+    if not context:
+        return (
+            "KEY TRAITS:\n"
+            "- Not specified.\n\n"
+            "CONFLICTS:\n"
+            "- Not specified.\n\n"
+            "CONCLUSION:\n"
+            "No reference evidence was available to evaluate this genus against the observed phenotype."
+        )
+
+    # Prefer structured extraction (KEY MATCHES / CONFLICTS sections)
+    key_traits, conflicts, saw_headers = _extract_key_traits_and_conflicts(context)
+
+    # If the headers weren't found or extraction is empty, try legacy extraction
+    if (not saw_headers) or (not key_traits and not conflicts):
+        legacy_matches, legacy_conflicts = _extract_matches_conflicts_legacy(context)
+        if legacy_matches or legacy_conflicts:
+            key_traits = key_traits or legacy_matches
+            conflicts = conflicts or legacy_conflicts
+
+    key_traits_text = _format_bullets(key_traits, none_text="- Not specified.")
+    conflicts_text = _format_bullets(conflicts, none_text="- Not specified.")
+
+    # LLM: conclusion only
+    prompt = RAG_PROMPT.format(
+        genus=genus_clean,
+        matches=key_traits_text,
+        conflicts=conflicts_text,
+    )
+
+    if RAG_GEN_LOG_INPUT:
+        _log_block("PROMPT (CONCLUSION-ONLY)", prompt[:3000] + ("\n...(truncated)" if len(prompt) > 3000 else ""))
+
+    inputs = tokenizer(
+        prompt,
+        return_tensors="pt",
+        truncation=True,
+        max_length=_MAX_INPUT_TOKENS,
+    ).to(model.device)
+
+    output = model.generate(
+        **inputs,
+        max_new_tokens=max_new_tokens,
+        temperature=0.0,
+        num_beams=1,
+        do_sample=False,
+        repetition_penalty=1.2,
+        no_repeat_ngram_size=3,
+    )
+
+    decoded = tokenizer.decode(output[0], skip_special_tokens=True).strip()
+    cleaned = _clean_generation(decoded)
+
+    if RAG_GEN_LOG_OUTPUT:
+        _log_block("RAW OUTPUT (CONCLUSION)", decoded)
+        _log_block("CLEANED OUTPUT (CONCLUSION)", cleaned)
+
+    # If LLM output is junk, use deterministic conclusion
+    if _looks_like_echo_or_garbage(cleaned):
+        cleaned = _deterministic_conclusion(genus_clean, key_traits, conflicts)
+        if RAG_GEN_LOG_OUTPUT:
+            _log_block("FALLBACK CONCLUSION (DETERMINISTIC)", cleaned)
+
+    # Final user-visible structured output
+    final = (
+        "KEY TRAITS:\n"
+        f"{key_traits_text}\n\n"
+        "CONFLICTS:\n"
+        f"{conflicts_text}\n\n"
+        "CONCLUSION:\n"
+        f"{cleaned}"
+    )
+
+    return final
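The deterministic extraction path can be exercised without loading flan-t5. The helpers are module-private, so the sketch below is a test/debug aid rather than public API, and the shaped context is a hypothetical sample in the format the module documents:

from rag.rag_generator import _extract_key_traits_and_conflicts  # private helper

shaped = "\n".join([
    "KEY MATCHES:",
    "- Gram Stain: Positive (matches reference: Positive)",
    "- Catalase: Positive (matches reference: Positive)",
    "CONFLICTS (observed vs CORE traits):",
    "- Oxidase: Positive (conflicts reference: Negative)",
])
key_traits, conflicts, saw_headers = _extract_key_traits_and_conflicts(shaped)
print(key_traits)   # ['Gram Stain: Positive', 'Catalase: Positive']
print(conflicts)    # ['Oxidase: Positive']
print(saw_headers)  # True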
rag/rag_retriever.py ADDED
@@ -0,0 +1,509 @@
+# rag/rag_retriever.py
+# ============================================================
+# RAG retriever (Stage 2 – microbiology-aware)
+#
+# Key change (GENUS-FIRST):
+#   - The generator must NOT see multiple species dumps.
+#   - We retrieve GENUS-level records only for llm_context/llm_context_shaped.
+#   - Species is handled separately (deterministic species_scorer), not via LLM context.
+#
+# Improvements retained:
+#   - Source-type weighting (but genus-only for generator)
+#   - Genus-aware query expansion
+#   - Diversity enforcement (avoid duplicate sources)
+#   - Explicit ranking & score annotations for generator (DEBUG ONLY)
+#   - OPTIONAL: species evidence scoring (deterministic)
+#   - NEW: Context shaper (deterministic) -> resolves conflicts + emits genus-ready summary
+#
+# IMPORTANT:
+#   - We return THREE contexts:
+#       1) llm_context        -> GENUS-only raw text (SAFE but unshaped)
+#       2) llm_context_shaped -> shaped, conflict-aware, generator-friendly
+#       3) debug_context      -> includes RANK/SCORE/WEIGHTS (UI/logging only)
+# ============================================================
+
+from __future__ import annotations
+
+from typing import List, Dict, Any, Optional, Tuple
+import re
+import numpy as np
+
+from rag.rag_embedder import embed_text, load_kb_index
+
+# deterministic species evidence scorer (separate from generator context)
+try:
+    from rag.species_scorer import score_species_for_genus
+    HAS_SPECIES_SCORER = True
+except Exception:
+    score_species_for_genus = None  # type: ignore
+    HAS_SPECIES_SCORER = False
+
+
+# ------------------------------------------------------------
+# Configuration
+# ------------------------------------------------------------
+
+# NOTE: We keep these for debug display + potential fallback modes.
+SOURCE_TYPE_WEIGHTS = {
+    "species": 1.15,
+    "genus": 1.00,
+    "table": 1.10,
+    "note": 0.85,
+}
+
+MAX_CHUNKS_PER_SOURCE = 1
+
+# Context shaping caps (keeps prompt within LLM limits)
+SHAPER_MAX_CORE = 14
+SHAPER_MAX_VARIABLE = 12
+SHAPER_MAX_MATCHES = 14
+SHAPER_MAX_CONFLICTS = 12
+SHAPER_MAX_TOTAL_CHARS = 9000  # final guardrail
+
+
+# ------------------------------------------------------------
+# Similarity helper
+# ------------------------------------------------------------
+
+def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
+    """
+    Cosine similarity for normalized embeddings.
+    Assumes both vectors are already L2-normalized.
+    """
+    return float(np.dot(a, b))
+
+
+# ------------------------------------------------------------
+# Context Shaper (deterministic)
+# ------------------------------------------------------------
+
+_TRAIT_LINE_RE = re.compile(
+    r"^\s*([A-Za-z0-9][A-Za-z0-9 \/\-\(\)\[\]%>=<\+\.]*?)\s*:\s*(.+?)\s*$"
+)
+
+# Headers / junk lines we don't want treated as traits
+_SHAPER_SKIP_PREFIXES = (
+    "expected fields for species",
+    "expected fields for genus",
+    "reference context",
+    "genus evidence primer",
+)
+
+def _norm_val(v: str) -> str:
+    s = (v or "").strip()
+    if not s:
+        return ""
+    s = re.sub(r"\s+", " ", s)
+    return s
+
+def _canon_bool(v: str) -> str:
+    """
+    Canonicalize common boolean-ish microbiology values.
+    Conservative: no inference.
+    """
+    s = _norm_val(v).lower()
+    if s in {"pos", "positive", "+", "reactive"}:
+        return "Positive"
+    if s in {"neg", "negative", "-", "nonreactive", "non-reactive"}:
+        return "Negative"
+    if s in {"none"}:
+        return "None"
+    if s in {"unknown", "not specified", "n/a", "na"}:
+        return "Unknown"
+    if s in {"variable"}:
+        return "Variable"
+    return _norm_val(v)
+
+def _canon_trait_name(name: str) -> str:
+    s = _norm_val(name)
+    s_low = s.lower()
+    if s_low == "ornitihine decarboxylase":
+        return "Ornithine Decarboxylase"
+    return s
+
+def _extract_traits_from_text_block(text: str) -> List[Tuple[str, str]]:
+    """
+    Extract (trait, value) pairs from lines like:
+        Trait Name: Value
+    """
+    pairs: List[Tuple[str, str]] = []
+    for raw_line in (text or "").splitlines():
+        line = raw_line.strip()
+        if not line:
+            continue
+        low = line.lower()
+        if any(low.startswith(p) for p in _SHAPER_SKIP_PREFIXES):
+            continue
+        m = _TRAIT_LINE_RE.match(line)
+        if not m:
+            continue
+        k = _canon_trait_name(m.group(1))
+        v = _canon_bool(m.group(2))
+        if not k or not v:
+            continue
+        pairs.append((k, v))
+    return pairs
+
+def _compare_vals(observed: str, reference: str) -> Optional[bool]:
+    """
+    Returns:
+        True  -> match
+        False -> conflict
+        None  -> cannot compare (unknown/variable/empty)
+    """
+    o = _canon_bool(observed)
+    r = _canon_bool(reference)
+
+    if not o or o == "Unknown":
+        return None
+    if not r or r in {"Unknown", "Variable"}:
+        return None
+
+    if o == r:
+        return True
+
+    # Safe equivalences (very conservative)
+    eq = {
+        ("None", "Negative"),
+        ("Negative", "None"),
+    }
+    if (o, r) in eq:
+        return True
+
+    return False
+
+def shape_genus_context(
+    *,
+    target_genus: str,
+    selected_chunks: List[Dict[str, Any]],
+    parsed_fields: Optional[Dict[str, str]] = None,
+) -> str:
+    """
+    Deterministic, GENUS-focused context shaper.
+
+    It:
+    - aggregates trait lines across retrieved GENUS chunks
+    - identifies CORE traits (single consistent value across chunks)
+    - identifies VARIABLE traits (multiple values across chunks)
+    - if parsed_fields provided, derives:
+        - phenotype-supported matches vs CORE traits
+        - phenotype conflicts vs CORE traits
+    - outputs a compact, reasoning-friendly block for the generator
+    """
+    genus = (target_genus or "").strip() or "Unknown"
+
+    trait_values: Dict[str, List[str]] = {}
+
+    for rec in selected_chunks or []:
+        txt = (rec.get("text") or "").strip()
+        if not txt:
+            continue
+        for k, v in _extract_traits_from_text_block(txt):
+            trait_values.setdefault(k, []).append(v)
+
+    # Reduce to unique canonical values
+    trait_uniques: Dict[str, List[str]] = {}
+    for k, vals in trait_values.items():
+        uniq: List[str] = []
+        for v in vals:
+            vv = _canon_bool(v)
+            if not vv:
+                continue
+            if vv not in uniq:
+                uniq.append(vv)
+        if uniq:
+            trait_uniques[k] = uniq
+
+    core_traits: List[Tuple[str, str]] = []
+    variable_traits: List[Tuple[str, str]] = []
+
+    for k, uniq in trait_uniques.items():
+        if len(uniq) == 1:
+            core_traits.append((k, uniq[0]))
+        else:
+            variable_traits.append((k, " / ".join(uniq)))
+
+    PRIORITY = {
+        "Gram Stain": 1,
+        "Shape": 2,
+        "Motility": 3,
+        "Motility Type": 4,
+        "Oxidase": 5,
+        "Catalase": 6,
+        "Oxygen Requirement": 7,
+        "Lactose Fermentation": 8,
+        "Glucose Fermentation": 9,
+        "H2S": 10,
+        "Indole": 11,
+        "Urease": 12,
+        "Citrate": 13,
+        "ONPG": 14,
+        "NaCl Tolerant (>=6%)": 15,
+        "Media Grown On": 16,
+        "Colony Morphology": 17,
+    }
+
+    def _sort_key(item: Tuple[str, str]) -> Tuple[int, str]:
+        return (PRIORITY.get(item[0], 999), item[0].lower())
+
+    core_traits.sort(key=_sort_key)
+    variable_traits.sort(key=_sort_key)
+
+    core_traits = core_traits[:SHAPER_MAX_CORE]
+    variable_traits = variable_traits[:SHAPER_MAX_VARIABLE]
+
+    matches: List[str] = []
+    conflicts: List[str] = []
+
+    if parsed_fields:
+        for k, ref_v in core_traits:
+            obs_v = parsed_fields.get(k)
+            if obs_v is None:
+                continue
+            cmp = _compare_vals(obs_v, ref_v)
+            if cmp is True:
+                matches.append(f"- {k}: {_canon_bool(obs_v)} (matches reference: {ref_v})")
+            elif cmp is False:
+                conflicts.append(f"- {k}: {_canon_bool(obs_v)} (conflicts reference: {ref_v})")
+
+    matches = matches[:SHAPER_MAX_MATCHES]
+    conflicts = conflicts[:SHAPER_MAX_CONFLICTS]
+
+    lines: List[str] = []
+    lines.append(f"GENUS SUMMARY (reference-driven): {genus}")
+
+    if core_traits:
+        lines.append("\nCORE GENUS TRAITS (consistent across retrieved genus references):")
+        for k, v in core_traits:
+            lines.append(f"- {k}: {v}")
+    else:
+        lines.append("\nCORE GENUS TRAITS: Not available from retrieved context.")
|
| 281 |
+
|
| 282 |
+
if variable_traits:
|
| 283 |
+
lines.append("\nTRAITS VARIABLE ACROSS RETRIEVED GENUS REFERENCES (do not treat as contradictions):")
|
| 284 |
+
for k, v in variable_traits:
|
| 285 |
+
lines.append(f"- {k}: Variable ({v})")
|
| 286 |
+
|
| 287 |
+
if parsed_fields:
|
| 288 |
+
lines.append("\nPHENOTYPE SUPPORT (observed vs CORE traits):")
|
| 289 |
+
if matches:
|
| 290 |
+
lines.append("KEY MATCHES:")
|
| 291 |
+
lines.extend(matches)
|
| 292 |
+
else:
|
| 293 |
+
lines.append("KEY MATCHES: Not specified.")
|
| 294 |
+
|
| 295 |
+
if conflicts:
|
| 296 |
+
lines.append("\nCONFLICTS (observed vs CORE traits):")
|
| 297 |
+
lines.extend(conflicts)
|
| 298 |
+
else:
|
| 299 |
+
lines.append("\nCONFLICTS: Not specified.")
|
| 300 |
+
|
| 301 |
+
shaped = "\n".join(lines).strip()
|
| 302 |
+
|
| 303 |
+
if len(shaped) > SHAPER_MAX_TOTAL_CHARS:
|
| 304 |
+
shaped = shaped[:SHAPER_MAX_TOTAL_CHARS].rstrip() + "\n... (truncated)"
|
| 305 |
+
|
| 306 |
+
return shaped
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
# ------------------------------------------------------------
|
| 310 |
+
# Public API
|
| 311 |
+
# ------------------------------------------------------------
|
| 312 |
+
|
| 313 |
+
def retrieve_rag_context(
|
| 314 |
+
phenotype_text: str,
|
| 315 |
+
target_genus: str,
|
| 316 |
+
top_k: int = 5,
|
| 317 |
+
kb_path: str = "data/rag/index/kb_index.json",
|
| 318 |
+
parsed_fields: Optional[Dict[str, str]] = None,
|
| 319 |
+
species_top_n: int = 5,
|
| 320 |
+
allow_species_fallback: bool = False,
|
| 321 |
+
) -> Dict[str, Any]:
|
| 322 |
+
"""
|
| 323 |
+
Retrieve the most relevant RAG chunks for a phenotype + genus.
|
| 324 |
+
|
| 325 |
+
GENUS-FIRST behavior:
|
| 326 |
+
- For LLM generator contexts, we retrieve ONLY genus-level records (level == "genus").
|
| 327 |
+
- Species is handled separately via deterministic species_scorer.
|
| 328 |
+
|
| 329 |
+
Optional:
|
| 330 |
+
parsed_fields -> enables species evidence scoring + context shaping matches/conflicts.
|
| 331 |
+
|
| 332 |
+
Returns:
|
| 333 |
+
{
|
| 334 |
+
"genus": target_genus,
|
| 335 |
+
"chunks": [...], # ranked chunk metadata (GENUS chunks unless fallback enabled)
|
| 336 |
+
"llm_context": "....", # GENUS raw text (no scores)
|
| 337 |
+
"llm_context_shaped": "....", # deterministic genus-friendly summary
|
| 338 |
+
"debug_context": "....", # annotated with rank/score/weights
|
| 339 |
+
"species_evidence": { ... } # optional deterministic species scoring
|
| 340 |
+
}
|
| 341 |
+
"""
|
| 342 |
+
|
| 343 |
+
kb = load_kb_index(kb_path)
|
| 344 |
+
records = kb.get("records", [])
|
| 345 |
+
|
| 346 |
+
if not records:
|
| 347 |
+
return {
|
| 348 |
+
"genus": target_genus,
|
| 349 |
+
"chunks": [],
|
| 350 |
+
"llm_context": "",
|
| 351 |
+
"llm_context_shaped": "",
|
| 352 |
+
"debug_context": "",
|
| 353 |
+
"species_evidence": {"genus": target_genus, "ranked": []},
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
query_text = (phenotype_text or "").strip()
|
| 357 |
+
if target_genus:
|
| 358 |
+
query_text = f"{query_text}\nTarget genus: {target_genus}"
|
| 359 |
+
|
| 360 |
+
q_emb = embed_text(query_text, normalize=True)
|
| 361 |
+
target_genus_lc = (target_genus or "").strip().lower()
|
| 362 |
+
|
| 363 |
+
scored_records: List[Dict[str, Any]] = []
|
| 364 |
+
|
| 365 |
+
# --------------------------------------------------------
|
| 366 |
+
# Primary pass: STRICT genus-filtered + GENUS-LEVEL only
|
| 367 |
+
# --------------------------------------------------------
|
| 368 |
+
for rec in records:
|
| 369 |
+
rec_genus = (rec.get("genus") or "").strip().lower()
|
| 370 |
+
if target_genus_lc and rec_genus != target_genus_lc:
|
| 371 |
+
continue
|
| 372 |
+
|
| 373 |
+
level = (rec.get("level") or "").strip().lower()
|
| 374 |
+
if level != "genus":
|
| 375 |
+
continue # GENUS-ONLY for generator context
|
| 376 |
+
|
| 377 |
+
emb = rec.get("embedding")
|
| 378 |
+
if emb is None:
|
| 379 |
+
continue
|
| 380 |
+
|
| 381 |
+
base_score = _cosine_similarity(q_emb, emb)
|
| 382 |
+
weight = SOURCE_TYPE_WEIGHTS.get(level, 1.0)
|
| 383 |
+
score = base_score * weight
|
| 384 |
+
|
| 385 |
+
scored_records.append(
|
| 386 |
+
{
|
| 387 |
+
"id": rec.get("id"),
|
| 388 |
+
"genus": rec.get("genus"),
|
| 389 |
+
"species": rec.get("species"),
|
| 390 |
+
"source_type": level,
|
| 391 |
+
"path": rec.get("source_file"),
|
| 392 |
+
"text": rec.get("text"),
|
| 393 |
+
"score": float(score),
|
| 394 |
+
"base_score": float(base_score),
|
| 395 |
+
"type_weight": float(weight),
|
| 396 |
+
"section": rec.get("section"),
|
| 397 |
+
"role": rec.get("role"),
|
| 398 |
+
"chunk_id": rec.get("chunk_id"),
|
| 399 |
+
}
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
# --------------------------------------------------------
|
| 403 |
+
# Fallback modes
|
| 404 |
+
# --------------------------------------------------------
|
| 405 |
+
if not scored_records and allow_species_fallback:
|
| 406 |
+
# Emergency fallback: allow any level if no genus chunks exist.
|
| 407 |
+
# This keeps your app functioning, but can reintroduce noise.
|
| 408 |
+
for rec in records:
|
| 409 |
+
rec_genus = (rec.get("genus") or "").strip().lower()
|
| 410 |
+
if target_genus_lc and rec_genus != target_genus_lc:
|
| 411 |
+
continue
|
| 412 |
+
|
| 413 |
+
emb = rec.get("embedding")
|
| 414 |
+
if emb is None:
|
| 415 |
+
continue
|
| 416 |
+
|
| 417 |
+
level = (rec.get("level") or "").strip().lower()
|
| 418 |
+
base_score = _cosine_similarity(q_emb, emb)
|
| 419 |
+
weight = SOURCE_TYPE_WEIGHTS.get(level, 1.0)
|
| 420 |
+
score = base_score * weight
|
| 421 |
+
|
| 422 |
+
scored_records.append(
|
| 423 |
+
{
|
| 424 |
+
"id": rec.get("id"),
|
| 425 |
+
"genus": rec.get("genus"),
|
| 426 |
+
"species": rec.get("species"),
|
| 427 |
+
"source_type": level,
|
| 428 |
+
"path": rec.get("source_file"),
|
| 429 |
+
"text": rec.get("text"),
|
| 430 |
+
"score": float(score),
|
| 431 |
+
"base_score": float(base_score),
|
| 432 |
+
"type_weight": float(weight),
|
| 433 |
+
"section": rec.get("section"),
|
| 434 |
+
"role": rec.get("role"),
|
| 435 |
+
"chunk_id": rec.get("chunk_id"),
|
| 436 |
+
}
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
# Sort by score
|
| 440 |
+
scored_records.sort(key=lambda r: r["score"], reverse=True)
|
| 441 |
+
|
| 442 |
+
# Diversity enforcement
|
| 443 |
+
selected: List[Dict[str, Any]] = []
|
| 444 |
+
source_counts: Dict[str, int] = {}
|
| 445 |
+
|
| 446 |
+
for rec in scored_records:
|
| 447 |
+
src = rec.get("path") or ""
|
| 448 |
+
count = source_counts.get(src, 0)
|
| 449 |
+
if count >= MAX_CHUNKS_PER_SOURCE:
|
| 450 |
+
continue
|
| 451 |
+
selected.append(rec)
|
| 452 |
+
source_counts[src] = count + 1
|
| 453 |
+
if len(selected) >= top_k:
|
| 454 |
+
break
|
| 455 |
+
|
| 456 |
+
# Build contexts
|
| 457 |
+
llm_ctx_parts: List[str] = []
|
| 458 |
+
debug_ctx_parts: List[str] = []
|
| 459 |
+
|
| 460 |
+
for idx, rec in enumerate(selected, start=1):
|
| 461 |
+
txt = (rec.get("text") or "").strip()
|
| 462 |
+
if txt:
|
| 463 |
+
llm_ctx_parts.append(txt)
|
| 464 |
+
|
| 465 |
+
label = rec.get("genus") or "Unknown genus"
|
| 466 |
+
if rec.get("species"):
|
| 467 |
+
label = f"{label} {rec['species']}"
|
| 468 |
+
|
| 469 |
+
debug_ctx_parts.append(
|
| 470 |
+
f"[RANK {idx} | SCORE {rec['score']:.3f} | BASE {rec['base_score']:.3f} | "
|
| 471 |
+
f"W {rec['type_weight']:.2f} | {label} — {rec.get('source_type')}]"
|
| 472 |
+
+ (
|
| 473 |
+
f" [section={rec.get('section')} role={rec.get('role')}]"
|
| 474 |
+
if rec.get("section") or rec.get("role")
|
| 475 |
+
else ""
|
| 476 |
+
)
|
| 477 |
+
+ "\n"
|
| 478 |
+
+ (txt or "")
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
llm_context = "\n\n".join(llm_ctx_parts).strip()
|
| 482 |
+
debug_context = "\n\n".join(debug_ctx_parts).strip()
|
| 483 |
+
|
| 484 |
+
llm_context_shaped = shape_genus_context(
|
| 485 |
+
target_genus=target_genus,
|
| 486 |
+
selected_chunks=selected,
|
| 487 |
+
parsed_fields=parsed_fields,
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
# OPTIONAL: deterministic species evidence scoring
|
| 491 |
+
species_evidence = {"genus": target_genus, "ranked": []}
|
| 492 |
+
if parsed_fields and HAS_SPECIES_SCORER and score_species_for_genus is not None:
|
| 493 |
+
try:
|
| 494 |
+
species_evidence = score_species_for_genus(
|
| 495 |
+
target_genus=target_genus,
|
| 496 |
+
parsed_fields=parsed_fields,
|
| 497 |
+
top_n=species_top_n,
|
| 498 |
+
)
|
| 499 |
+
except Exception:
|
| 500 |
+
species_evidence = {"genus": target_genus, "ranked": []}
|
| 501 |
+
|
| 502 |
+
return {
|
| 503 |
+
"genus": target_genus,
|
| 504 |
+
"chunks": selected,
|
| 505 |
+
"llm_context": llm_context,
|
| 506 |
+
"llm_context_shaped": llm_context_shaped,
|
| 507 |
+
"debug_context": debug_context,
|
| 508 |
+
"species_evidence": species_evidence,
|
| 509 |
+
}
|
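For orientation, here is a minimal sketch of how the shaper behaves in isolation. It is an editor's illustration, not part of the committed file: the chunk texts and parsed fields are invented, and it assumes the SHAPER_* limits defined earlier in this module.

# Editor's illustration (not part of the committed file).
chunks = [
    {"text": "Gram Stain: Positive\nCatalase: pos\nOxidase: neg"},
    {"text": "Gram Stain: Positive\nCatalase: +\nMotility: variable"},
]
print(
    shape_genus_context(
        target_genus="Staphylococcus",
        selected_chunks=chunks,
        parsed_fields={"Gram Stain": "Positive", "Oxidase": "Positive"},
    )
)
# "pos" and "+" both canonicalize to "Positive", so Catalase collapses to a
# single CORE trait; the observed Oxidase "Positive" conflicts with the CORE
# "Negative" and is listed under CONFLICTS, while Gram Stain is a KEY MATCH.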
rag/species_scorer.py
ADDED
@@ -0,0 +1,314 @@
# rag/species_scorer.py
# ============================================================
# Species evidence scorer (deterministic, explainable)
#
# Given:
#   - target_genus
#   - parsed_fields (from fusion)
# It loads species JSON files under:
#   data/rag/knowledge_base/<Genus>/*.json (excluding genus.json)
#
# And returns:
#   - ranked species list with scores
#   - explicit matches / conflicts
#   - marker hits (importance-weighted)
#
# Notes:
#   - This is NOT an LLM. No speculation.
#   - Handles list-like fields (Media / Colony Morphology) as overlap scores.
#   - Handles P/N/V/Unknown fields.
# ============================================================

from __future__ import annotations

import json
import os
import re
from typing import Any, Dict, List, Tuple, Optional


KB_ROOT = os.path.join("data", "rag", "knowledge_base")

UNKNOWN = "Unknown"

LIST_FIELDS = {
    "Media Grown On",
    "Colony Morphology",
}

# Importance → weight
MARKER_WEIGHT = {
    "high": 3.0,
    "medium": 2.0,
    "low": 1.5,
}

# Base scoring weights
FIELD_MATCH_WEIGHT = 1.0
FIELD_CONFLICT_PENALTY = 1.2  # conflicts hurt slightly more than matches help
VARIABLE_MATCH_BONUS = 0.2    # weak support if expected is Variable


def _norm_str(v: Any) -> str:
    if v is None:
        return ""
    return str(v).strip()


def _norm_val(v: Any) -> str:
    s = _norm_str(v)
    return s if s else UNKNOWN


def _split_semicolon(s: str) -> List[str]:
    parts = [p.strip() for p in re.split(r"[;,\n]+", s or "") if p.strip()]
    # normalize case lightly for matching
    return [p.lower() for p in parts]


def _as_list_lower(v: Any) -> List[str]:
    if v is None:
        return []
    if isinstance(v, list):
        return [str(x).strip().lower() for x in v if str(x).strip()]
    # string fallback
    return _split_semicolon(str(v))


def _overlap_score(expected_list: List[str], observed_list: List[str]) -> float:
    """
    Jaccard-like overlap, but anchored to expected:
        score = (# of expected items found) / (# expected)
    """
    if not expected_list:
        return 0.0
    if not observed_list:
        return 0.0
    exp = set(expected_list)
    obs = set(observed_list)
    hit = len(exp.intersection(obs))
    return hit / max(1, len(exp))


def _load_species_docs_for_genus(target_genus: str) -> List[Dict[str, Any]]:
    genus = (target_genus or "").strip()
    if not genus:
        return []

    genus_dir = os.path.join(KB_ROOT, genus)
    if not os.path.isdir(genus_dir):
        return []

    docs: List[Dict[str, Any]] = []
    for fname in sorted(os.listdir(genus_dir)):
        if not fname.lower().endswith(".json"):
            continue
        if fname == "genus.json":
            continue

        path = os.path.join(genus_dir, fname)
        try:
            with open(path, "r", encoding="utf-8") as f:
                doc = json.load(f)
            if isinstance(doc, dict) and doc.get("level") == "species":
                doc["_source_path"] = os.path.relpath(path)
                docs.append(doc)
        except Exception:
            continue

    return docs


def _score_expected_fields(
    expected_fields: Dict[str, Any],
    parsed_fields: Dict[str, str],
) -> Tuple[float, float, List[str], List[str]]:
    """
    Returns:
        (score, possible, matches, conflicts)
    """
    score = 0.0
    possible = 0.0
    matches: List[str] = []
    conflicts: List[str] = []

    for field, expected in (expected_fields or {}).items():
        exp_norm = expected
        obs_norm = parsed_fields.get(field, UNKNOWN)

        # Skip unknown observed
        if obs_norm == UNKNOWN:
            continue

        # List fields: overlap
        if field in LIST_FIELDS:
            exp_list = _as_list_lower(exp_norm)
            obs_list = _as_list_lower(obs_norm)
            if not exp_list:
                continue

            possible += FIELD_MATCH_WEIGHT
            ov = _overlap_score(exp_list, obs_list)

            # thresholding: any overlap = support; none = conflict
            if ov > 0:
                score += FIELD_MATCH_WEIGHT * ov
                matches.append(f"{field}: overlap {ov:.2f}")
            else:
                score -= FIELD_CONFLICT_PENALTY
                conflicts.append(f"{field}: expected {expected}, got {obs_norm}")
            continue

        exp_val = _norm_val(exp_norm)
        obs_val = _norm_val(obs_norm)

        # If expected is Unknown, skip
        if exp_val == UNKNOWN:
            continue

        # If expected is Variable, weakly supportive if observed is known
        if exp_val == "Variable":
            possible += VARIABLE_MATCH_BONUS
            score += VARIABLE_MATCH_BONUS
            matches.append(f"{field}: expected Variable (observed {obs_val})")
            continue

        # Normal exact match
        possible += FIELD_MATCH_WEIGHT
        if obs_val == exp_val:
            score += FIELD_MATCH_WEIGHT
            matches.append(f"{field}: {obs_val}")
        else:
            score -= FIELD_CONFLICT_PENALTY
            conflicts.append(f"{field}: expected {exp_val}, got {obs_val}")

    return score, possible, matches, conflicts


def _score_species_markers(
    markers: List[Dict[str, Any]],
    parsed_fields: Dict[str, str],
) -> Tuple[float, float, List[str], List[str]]:
    """
    Weighted marker hits. Markers are higher-signal than generic expected fields.

    Returns:
        (score, possible, marker_hits, marker_misses)
    """
    score = 0.0
    possible = 0.0
    hits: List[str] = []
    misses: List[str] = []

    for m in markers or []:
        field = _norm_str(m.get("field"))
        val = _norm_val(m.get("value"))
        importance = _norm_str(m.get("importance")).lower() or "medium"
        w = MARKER_WEIGHT.get(importance, 2.0)

        if not field or val == UNKNOWN:
            continue

        obs = _norm_val(parsed_fields.get(field, UNKNOWN))
        if obs == UNKNOWN:
            continue

        possible += w
        if obs == val:
            score += w
            hits.append(f"{field}: {obs} ({importance})")
        else:
            score -= w * 1.1  # marker conflicts hurt more
            misses.append(f"{field}: expected {val}, got {obs} ({importance})")

    return score, possible, hits, misses


def _to_confidence(raw_score: float, possible: float) -> float:
    """
    Convert raw score into 0..1 confidence.

    We use a bounded transform:
        - normalize by possible
        - clamp into [0, 1]
    """
    if possible <= 0:
        return 0.0

    # raw_score can be negative; convert to a 0..1 scale.
    # A normalized score around 0 means mixed evidence.
    normalized = raw_score / possible   # roughly -something .. +1
    conf = (normalized + 1.0) / 2.0     # map [-1, +1] -> [0, 1] (approx)
    if conf < 0:
        conf = 0.0
    if conf > 1:
        conf = 1.0
    return float(conf)


def score_species_for_genus(
    target_genus: str,
    parsed_fields: Dict[str, str],
    top_n: int = 5,
) -> Dict[str, Any]:
    """
    Main entrypoint.

    Returns:
        {
          "genus": "...",
          "ranked": [
            {
              "species": "cloacae",
              "full_name": "Enterobacter cloacae",
              "score": 0.87,
              "raw_score": ...,
              "possible": ...,
              "matches": [...],
              "conflicts": [...],
              "marker_hits": [...],
              "marker_conflicts": [...],
              "source_file": "data/rag/knowledge_base/Enterobacter/cloacae.json"
            }, ...
          ]
        }
    """
    docs = _load_species_docs_for_genus(target_genus)
    if not docs:
        return {"genus": target_genus, "ranked": []}

    ranked: List[Dict[str, Any]] = []

    for doc in docs:
        genus = _norm_str(doc.get("genus") or target_genus)
        species = _norm_str(doc.get("species"))
        full_name = f"{genus} {species}".strip()

        expected_fields = doc.get("expected_fields") or {}
        markers = doc.get("species_markers") or []

        s1, p1, matches, conflicts = _score_expected_fields(expected_fields, parsed_fields)
        s2, p2, marker_hits, marker_conflicts = _score_species_markers(markers, parsed_fields)

        raw_score = s1 + s2
        possible = p1 + p2

        conf = _to_confidence(raw_score, possible)

        ranked.append(
            {
                "species": species or os.path.splitext(os.path.basename(doc.get("_source_path", "")))[0],
                "full_name": full_name,
                "score": conf,
                "raw_score": raw_score,
                "possible": possible,
                "matches": matches,
                "conflicts": conflicts,
                "marker_hits": marker_hits,
                "marker_conflicts": marker_conflicts,
                "source_file": doc.get("_source_path", ""),
            }
        )

    ranked.sort(key=lambda x: x["score"], reverse=True)
    return {"genus": target_genus, "ranked": ranked[: max(1, int(top_n))]}
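To make the confidence transform concrete: a raw_score of 3.0 against possible 5.0 normalizes to 0.6, which maps to (0.6 + 1) / 2 = 0.8 confidence. Below is a minimal call sketch (editor's illustration, not part of the committed file); it assumes species JSON files exist under data/rag/knowledge_base/Enterobacter/, and the observed field values are invented.

# Editor's illustration (not part of the committed file).
result = score_species_for_genus(
    target_genus="Enterobacter",
    parsed_fields={
        "Gram Stain": "Negative",
        "Lactose Fermentation": "Positive",
        "Ornithine Decarboxylase": "Positive",
    },
    top_n=3,
)
for row in result["ranked"]:
    print(f"{row['full_name']}: {row['score']:.2f} "
          f"({len(row['matches'])} matches, {len(row['conflicts'])} conflicts)")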
scoring/diagnostic_anchors.py
ADDED
@@ -0,0 +1,97 @@
# scoring/diagnostic_anchors.py
# ============================================================
# Diagnostic anchor overrides:
#   - If the free-text description clearly contains certain
#     pathognomonic phrases, boost the corresponding genus
#     in the unified ranking.
# ============================================================

from __future__ import annotations

from typing import List, Dict, Any

# Simple v1 — can expand over time
DIAGNOSTIC_ANCHORS = {
    "Yersinia": [
        "bull’s-eye",
        "bull's eye",
        "cin agar",
        "pseudoappendicitis",
        "pseudo-appendicitis",
    ],
    "Campylobacter": [
        "hippurate",
        "darting motility",
    ],
    "Vibrio": [
        "tcbs agar",
        "thiosulfate citrate bile salts sucrose",
        "yellow colonies on tcbs",
        "rice-water stool",
        "rice water stool",
    ],
    "Proteus": [
        "swarming motility",
        "swarm across the plate",
        "burnt chocolate odor",
        "burned chocolate odour",
    ],
    "Listeria": [
        "tumbling motility",
        "cold enrichment",
        "grows at 4°c",
        "4°c enrichment",
    ],
    "Clostridioides": [
        "ccfa agar",
        "cycloserine cefoxitin fructose agar",
        "barnyard odor",
        "ground glass colonies",
    ],
}


def apply_diagnostic_overrides(
    description_text: str,
    unified_ranking: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """
    If the input description strongly suggests a particular genus
    (anchor phrases), boost that genus in the unified ranking.

    Strategy:
      - If any anchor phrase for a genus is present in the text,
        ensure that genus has at least 0.70 combined_score
        (70% overall) *if it already appears*.
      - Then re-sort by combined_score.

    This is conservative: it won't hallucinate genera that aren't
    already in the top list, but strengthens strong clinical signals.
    """
    if not description_text or not unified_ranking:
        return unified_ranking

    text_lc = description_text.lower()

    # Which genera have anchors present?
    boosted_genera = set()
    for genus, phrases in DIAGNOSTIC_ANCHORS.items():
        for p in phrases:
            if p.lower() in text_lc:
                boosted_genera.add(genus)
                break

    if not boosted_genera:
        return unified_ranking

    # Apply boost only if genus already present
    for item in unified_ranking:
        g = item.get("genus")
        if g in boosted_genera:
            score = float(item.get("combined_score", 0.0))
            if score < 0.70:
                item["combined_score"] = 0.70
                item["combined_percent"] = 70.0

    unified_ranking.sort(key=lambda d: d.get("combined_score", 0.0), reverse=True)
    return unified_ranking
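A short sketch of the override in action (editor's illustration, not part of the committed file; the ranking rows are invented):

# Editor's illustration (not part of the committed file).
ranking = [
    {"genus": "Escherichia", "combined_score": 0.55, "combined_percent": 55.0},
    {"genus": "Yersinia", "combined_score": 0.40, "combined_percent": 40.0},
]
text = "Bull's eye colonies on CIN agar after 48 h incubation."
ranking = apply_diagnostic_overrides(text, ranking)
# "cin agar" (and "bull's eye") match Yersinia's anchors, so its score is
# floored at 0.70 and it re-sorts to rank 1; Escherichia keeps 0.55.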
scoring/overall_ranker.py
ADDED
@@ -0,0 +1,146 @@
# scoring/overall_ranker.py
# ============================================================
# Overall Ranker — Probability Normalisation Layer
#
# PURPOSE:
#   - Take already-computed combined scores (Tri-Fusion + ML)
#   - Normalize top-K into human-interpretable probabilities
#   - Provide odds per 1000 for UI display
#
# IMPORTANT:
#   - This module DOES NOT assign confidence labels
#   - Confidence logic lives in app.py (decision-band contract)
#
# OUTPUT CONTRACT (STRICT):
#   {
#     "overall": [
#       {
#         "rank": int,
#         "genus": str,
#         "combined_score": float,
#         "normalized_share": float,   # 0–1, sums to 1.0
#       },
#       ...
#     ],
#     "probabilities_1000": [
#       {
#         "genus": str,
#         "odds_1000": int
#       },
#       ...
#     ]
#   }
# ============================================================

from typing import Dict, List, Any


def compute_overall_scores(
    ml_scores: List[Dict[str, Any]],
    tri_scores: Dict[str, float],
    top_k: int = 5,
) -> Dict[str, Any]:
    """
    Normalize already-computed combined scores into
    probability shares and odds for the Top-5 decision table.

    Parameters
    ----------
    ml_scores : list of dict
        Each dict contains at least:
            { "genus": str, "probability": float }
        (Currently unused here: candidate genera are taken from tri_scores.)

    tri_scores : dict
        Dict mapping genus -> combined_score (0–1).
        NOTE: This is already unified (Tri-Fusion + ML).

    top_k : int
        Number of top genera to return.

    Returns
    -------
    dict
        {
          "overall": [
            {
              "rank": int,
              "genus": str,
              "combined_score": float,
              "normalized_share": float
            }
          ],
          "probabilities_1000": [
            { "genus": str, "odds_1000": int }
          ]
        }
    """

    # --------------------------------------------------------
    # 1. Build candidate list
    # --------------------------------------------------------
    combined_rows: List[Dict[str, Any]] = []

    for genus, score in tri_scores.items():
        try:
            cs = float(score)
        except Exception:
            cs = 0.0

        if cs > 0:
            combined_rows.append({
                "genus": genus,
                "combined_score": cs
            })

    if not combined_rows:
        return {
            "overall": [],
            "probabilities_1000": [],
        }

    # --------------------------------------------------------
    # 2. Sort and trim to top_k
    # --------------------------------------------------------
    combined_rows.sort(
        key=lambda x: x["combined_score"],
        reverse=True
    )

    top = combined_rows[:top_k]

    # --------------------------------------------------------
    # 3. Normalize to probability shares (sum = 1.0)
    # --------------------------------------------------------
    total_score = sum(x["combined_score"] for x in top)

    if total_score <= 0:
        total_score = 1.0  # safety fallback

    overall: List[Dict[str, Any]] = []
    probabilities_1000: List[Dict[str, Any]] = []

    for idx, row in enumerate(top, start=1):
        share = row["combined_score"] / total_score

        # Clamp defensively
        share = max(0.0, min(1.0, share))

        odds_1000 = int(round(share * 1000))

        overall.append({
            "rank": idx,
            "genus": row["genus"],
            "combined_score": round(row["combined_score"], 6),
            "normalized_share": round(share, 6),
        })

        probabilities_1000.append({
            "genus": row["genus"],
            "odds_1000": odds_1000,
        })

    return {
        "overall": overall,
        "probabilities_1000": probabilities_1000,
    }
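A quick worked example of the normalisation (editor's illustration, not part of the committed file; scores invented):

# Editor's illustration (not part of the committed file).
out = compute_overall_scores(
    ml_scores=[],  # unused here; candidates come from tri_scores
    tri_scores={"Salmonella": 0.60, "Shigella": 0.30, "Proteus": 0.10},
    top_k=3,
)
# total = 1.0, so shares are 0.60 / 0.30 / 0.10 and odds_1000 are 600 / 300 / 100
print(out["probabilities_1000"])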
static/eph.jpeg
ADDED
Git LFS Details
training/__init__.py
ADDED
@@ -0,0 +1,2 @@
# Marks the 'training' directory as a Python package
training/alias_trainer.py
ADDED
@@ -0,0 +1,126 @@
# training/alias_trainer.py
# ------------------------------------------------------------
# Stage 10B - Alias Trainer
#
# Learns field/value synonyms from gold tests by comparing:
#   - expected values (gold standard)
#   - parsed values (rules + extended)
#
# Outputs:
#   - Updated alias_maps.json
#
# This is the core intelligence that allows BactAI-D
# to understand variations in microbiology language.
# ------------------------------------------------------------

import json
import os
from collections import defaultdict

from engine.parser_rules import parse_text_rules
from engine.parser_ext import parse_text_extended


GOLD_PATH = "training/gold_tests.json"
ALIAS_PATH = "data/alias_maps.json"


def normalise(s):
    if s is None:
        return ""
    return str(s).strip().lower()


def learn_aliases():
    """
    Learns synonym mappings from gold tests.
    """
    if not os.path.exists(GOLD_PATH):
        return {"error": f"Gold tests missing: {GOLD_PATH}"}

    with open(GOLD_PATH, "r", encoding="utf-8") as f:
        gold = json.load(f)

    # Load or create alias map
    if os.path.exists(ALIAS_PATH):
        with open(ALIAS_PATH, "r", encoding="utf-8") as f:
            alias_maps = json.load(f)
    else:
        alias_maps = {}

    # Track suggestions
    suggestions = defaultdict(lambda: defaultdict(int))

    # ------------------------------------------------------------
    # Compare expected vs parsed for all tests
    # ------------------------------------------------------------
    for test in gold:
        text = test.get("input", "")
        expected = test.get("expected", {})

        rules = parse_text_rules(text).get("parsed_fields", {})
        ext = parse_text_extended(text).get("parsed_fields", {})

        # merge deterministic parsers
        merged = dict(rules)
        for k, v in ext.items():
            if v != "Unknown":
                merged[k] = v

        # now compare with expected
        for field, exp_val in expected.items():
            exp_norm = normalise(exp_val)
            got_norm = normalise(merged.get(field, "Unknown"))

            # Skip correct matches
            if exp_norm == got_norm:
                continue

            # Skip unknown expected
            if exp_norm in ["", "unknown"]:
                continue

            # Mismatched → candidate alias
            if got_norm not in ["", "unknown"]:
                suggestions[field][got_norm] += 1

    # ------------------------------------------------------------
    # Convert suggestions into alias mappings
    # ------------------------------------------------------------
    alias_updates = {}

    for field, values in suggestions.items():
        # ignore fields with tiny evidence
        for wrong_value, count in values.items():
            if count < 2:
                continue  # avoid noise

            # add/update alias
            if field not in alias_maps:
                alias_maps[field] = {}

            # map wrong_value → expected canonical version;
            # the canonical version is the most common expected value
            # for that field across the gold tests
            field_values = [
                normalise(t["expected"][field])
                for t in gold
                if field in t.get("expected", {})  # guard tests without "expected"
            ]
            canonical = None
            if field_values:
                # most common expected value
                canonical = max(set(field_values), key=field_values.count)

            if canonical:
                alias_maps[field][wrong_value] = canonical
                alias_updates[f"{field}:{wrong_value}"] = canonical

    # ------------------------------------------------------------
    # Save alias maps
    # ------------------------------------------------------------
    with open(ALIAS_PATH, "w", encoding="utf-8") as f:
        json.dump(alias_maps, f, indent=2)

    return {
        "ok": True,
        "updated_aliases": alias_updates,
        "total_updates": len(alias_updates),
        "alias_map_path": ALIAS_PATH,
    }
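For reference, the alias map this produces is a nested field → variant → canonical dict. The values below are invented purely to show the shape (editor's illustration, not part of the committed file):

# Editor's illustration (not part of the committed file).
# data/alias_maps.json ends up shaped like:
alias_maps_example = {
    "Gram Stain": {"gram +ve": "positive"},  # variant seen >= 2 times in mismatches
    "Indole": {"ind pos": "positive"},
}
# Consumers of the map can then rewrite "gram +ve" to the canonical "positive".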
training/field_weight_trainer.py
ADDED
@@ -0,0 +1,330 @@
# training/field_weight_trainer.py
# ------------------------------------------------------------
# Stage 12A — Train Per-Field Parser Weights from Gold Tests
#
# Produces:
#   data/field_weights.json
#
# This script computes reliability scores for:
#   - parser_rules
#   - parser_ext
#   - parser_llm
#
# and outputs:
#   {
#     "global": { ... },
#     "fields": { field -> weights },
#     "meta": { ... }
#   }
#
# These weights are used by parser_fusion (Stage 12B).
# ------------------------------------------------------------

from __future__ import annotations

import argparse
import json
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

# Core parsers
from engine.parser_rules import parse_text_rules
from engine.parser_ext import parse_text_extended

# LLM parser (optional)
try:
    from engine.parser_llm import parse_llm as parse_text_llm_local
except Exception:
    parse_text_llm_local = None  # gracefully degrade if LLM unavailable


# ------------------------------------------------------------
# Constants
# ------------------------------------------------------------

DEFAULT_GOLD_PATH = os.path.join("data", "gold_tests.json")
DEFAULT_OUT_PATH = os.path.join("data", "field_weights.json")

MISSING_PENALTY = 0.5
SMOOTHING = 1e-3


# ------------------------------------------------------------
# Data Structures
# ------------------------------------------------------------

@dataclass
class ParserOutcome:
    prediction: Optional[str]
    correct: bool
    wrong: bool
    missing: bool


@dataclass
class FieldStats:
    correct: int = 0
    wrong: int = 0
    missing: int = 0

    def total(self) -> int:
        return self.correct + self.wrong + self.missing

    def score(self, missing_penalty: float = MISSING_PENALTY) -> float:
        if self.total() == 0:
            return 0.0
        denom = self.correct + self.wrong + missing_penalty * self.missing
        if denom == 0:
            return 0.0
        return self.correct / denom


# ------------------------------------------------------------
# Gold Loading
# ------------------------------------------------------------

def _load_gold_tests(path: str) -> List[Dict[str, Any]]:
    if not os.path.exists(path):
        raise FileNotFoundError(f"Gold tests not found: {path}")
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("gold_tests.json must be a list")
    return data


def _extract_text_and_expected(test_obj: Dict[str, Any]) -> Tuple[str, Dict[str, str]]:
    text = (
        test_obj.get("text")
        or test_obj.get("description")
        or test_obj.get("input")
        or test_obj.get("raw")
        or ""
    )
    if not isinstance(text, str):
        text = str(text)

    expected: Dict[str, str] = {}

    if isinstance(test_obj.get("expected"), dict):
        for k, v in test_obj["expected"].items():
            expected[str(k)] = str(v)
        return text, expected

    if isinstance(test_obj.get("expected_core"), dict):
        for k, v in test_obj["expected_core"].items():
            expected[str(k)] = str(v)

    if isinstance(test_obj.get("expected_extended"), dict):
        for k, v in test_obj["expected_extended"].items():
            expected[str(k)] = str(v)

    return text, expected


# ------------------------------------------------------------
# Parser Execution
# ------------------------------------------------------------

def _get_parser_predictions(text: str, include_llm: bool = True) -> Dict[str, Dict[str, str]]:
    results: Dict[str, Dict[str, str]] = {}

    r = parse_text_rules(text)
    results["rules"] = dict(r.get("parsed_fields", {}))

    e = parse_text_extended(text)
    results["extended"] = dict(e.get("parsed_fields", {}))

    llm_values: Dict[str, str] = {}
    if include_llm and parse_text_llm_local is not None:
        try:
            llm_out = parse_text_llm_local(text)
            llm_values = dict(llm_out.get("parsed_fields", {}))
        except Exception:
            llm_values = {}
    results["llm"] = llm_values

    return results


def _outcome_for_field(expected_val: str, predicted_val: Optional[str]) -> ParserOutcome:
    if predicted_val is None:
        return ParserOutcome(prediction=None, correct=False, wrong=False, missing=True)
    if predicted_val == expected_val:
        return ParserOutcome(prediction=predicted_val, correct=True, wrong=False, missing=False)
    return ParserOutcome(prediction=predicted_val, correct=False, wrong=True, missing=False)


# ------------------------------------------------------------
# Stats Computation
# ------------------------------------------------------------

def _compute_stats_from_gold(
    gold_tests: List[Dict[str, Any]],
    include_llm: bool = True,
):
    field_stats = defaultdict(lambda: defaultdict(FieldStats))
    global_stats = defaultdict(FieldStats)

    total_samples = 0

    for sample in gold_tests:
        text, expected = _extract_text_and_expected(sample)
        if not expected:
            continue

        total_samples += 1
        preds = _get_parser_predictions(text, include_llm=include_llm)

        for field, expected_val in expected.items():
            expected_val = str(expected_val)
            for parser_name in ["rules", "extended", "llm"]:
                if parser_name == "llm" and not include_llm:
                    continue

                pred_val = preds.get(parser_name, {}).get(field)

                outcome = _outcome_for_field(expected_val, pred_val)

                fs = field_stats[field][parser_name]
                if outcome.correct:
                    fs.correct += 1
                if outcome.wrong:
                    fs.wrong += 1
                if outcome.missing:
                    fs.missing += 1

                gs = global_stats[parser_name]
                if outcome.correct:
                    gs.correct += 1
                if outcome.wrong:
                    gs.wrong += 1
                if outcome.missing:
                    gs.missing += 1

    return field_stats, global_stats, total_samples


def _normalise(weights: Dict[str, float], smoothing: float = SMOOTHING) -> Dict[str, float]:
    adjusted = {k: max(smoothing, v) for k, v in weights.items()}
    total = sum(adjusted.values())
    if total <= 0:
        n = len(adjusted)
        return {k: 1.0 / n for k in adjusted}
    return {k: v / total for k, v in adjusted.items()}


def _build_weights_json(
    field_stats,
    global_stats,
    total_samples,
    include_llm=True,
):
    # Global scores
    raw_global = {}
    for parser_name, stats in global_stats.items():
        if parser_name == "llm" and not include_llm:
            continue
        raw_global[parser_name] = stats.score(MISSING_PENALTY)

    global_weights = _normalise(raw_global)

    # Per-field
    fields_block = {}

    for field_name, stats_dict in field_stats.items():
        raw_scores = {}
        total_support = 0

        for parser_name, stats in stats_dict.items():
            if parser_name == "llm" and not include_llm:
                continue
            raw_scores[parser_name] = stats.score(MISSING_PENALTY)
            total_support += stats.total()

        if total_support < 5:
            # low support → blend global + local
            local_norm = _normalise(raw_scores)
            mixed = {}
            for p in global_weights:
                mixed[p] = 0.7 * global_weights[p] + 0.3 * local_norm.get(p, global_weights[p])
            field_w = _normalise(mixed)
        else:
            field_w = _normalise(raw_scores)

        fields_block[field_name] = {
            **field_w,
            "support": total_support,
        }

    return {
        "global": global_weights,
        "fields": fields_block,
        "meta": {
            "total_samples": total_samples,
            "missing_penalty": MISSING_PENALTY,
            "smoothing": SMOOTHING,
            "include_llm": include_llm,
        },
    }


# ------------------------------------------------------------
# Public API
# ------------------------------------------------------------

def train_field_weights(
    gold_path: str = DEFAULT_GOLD_PATH,
    out_path: str = DEFAULT_OUT_PATH,
    include_llm: bool = False,
):
    print(f"[12A] Loading gold tests: {gold_path}")
    gold = _load_gold_tests(gold_path)
    print(f"[12A] {len(gold)} gold samples loaded")

    field_stats, global_stats, total_samples = _compute_stats_from_gold(
        gold, include_llm=include_llm
    )

    print("[12A] Computing weights...")
    weights = _build_weights_json(
        field_stats, global_stats, total_samples, include_llm=include_llm
    )

    out_dir = os.path.dirname(out_path)
    if out_dir and not os.path.exists(out_dir):
        os.makedirs(out_dir, exist_ok=True)

    print(f"[12A] Writing: {out_path}")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(weights, f, indent=2, ensure_ascii=False)

    print("[12A] Done.")
    return weights


# ------------------------------------------------------------
# CLI
# ------------------------------------------------------------

def _parse_args(argv=None):
    p = argparse.ArgumentParser(description="Stage 12A — Train parser weights")
    p.add_argument("--gold", type=str, default=DEFAULT_GOLD_PATH)
    p.add_argument("--out", type=str, default=DEFAULT_OUT_PATH)
    p.add_argument("--include-llm", action="store_true")
    return p.parse_args(argv)


def main(argv=None):
    args = _parse_args(argv)
    train_field_weights(
        gold_path=args.gold,
        out_path=args.out,
        include_llm=args.include_llm,
    )


if __name__ == "__main__":
    main()
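To make the reliability score concrete: with MISSING_PENALTY = 0.5, a parser that is correct 8 times, wrong once, and missing twice on a field scores 8 / (8 + 1 + 0.5 * 2) = 0.8. A tiny check, assuming only the dataclass above (editor's illustration, not part of the committed file):

# Editor's illustration (not part of the committed file).
stats = FieldStats(correct=8, wrong=1, missing=2)
assert abs(stats.score() - 0.8) < 1e-9  # 8 / (8 + 1 + 0.5 * 2)
# Fields with fewer than 5 observations are then blended 70/30 toward the
# global weights, so sparse evidence cannot dominate a per-field weight.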
training/gold_tester.py
ADDED
@@ -0,0 +1,89 @@
# training/gold_tester.py
# ------------------------------------------------------------
# Stage 10A: Evaluate parsers on gold tests.
# This MUST NOT crash during import.
# ------------------------------------------------------------

from __future__ import annotations

import json
import os
from typing import Dict, Any, List

from engine.parser_rules import parse_text_rules
from engine.parser_ext import parse_text_extended


GOLD_PATH = "training/gold_tests.json"
REPORT_DIR = "reports"


def _load_gold_tests() -> List[Dict[str, Any]]:
    if not os.path.exists(GOLD_PATH):
        return []
    with open(GOLD_PATH, "r", encoding="utf-8") as f:
        try:
            data = json.load(f)
            return data if isinstance(data, list) else []
        except Exception:
            return []


def run_gold_tests(mode: str = "rules") -> Dict[str, Any]:
    gold_tests = _load_gold_tests()
    if not gold_tests:
        return {
            "summary": {
                "mode": mode,
                "tests": 0,
                "total_correct": 0,
                "total_fields": 0,
                "overall_accuracy": 0.0,
                "proposals_path": "data/extended_proposals.jsonl",
            }
        }

    os.makedirs(REPORT_DIR, exist_ok=True)

    wrong_cases = []
    total_correct = 0
    total_fields = 0

    for idx, test in enumerate(gold_tests):
        text = test.get("input", "")
        expected = test.get("expected", {})

        if mode == "rules":
            parsed = parse_text_rules(text).get("parsed_fields", {})
        elif mode == "rules+extended":
            rule_fields = parse_text_rules(text).get("parsed_fields", {})
            ext_fields = parse_text_extended(text).get("parsed_fields", {})
            parsed = {**rule_fields, **ext_fields}
        else:
            parsed = {}

        # Compare field-by-field
        correct_count = 0
        for key, val in expected.items():
            total_fields += 1
            if key in parsed and str(parsed[key]).strip().lower() == str(val).strip().lower():
                correct_count += 1

        total_correct += correct_count

        if correct_count < len(expected):
            wrong_cases.append(idx)

    accuracy = total_correct / total_fields if total_fields else 0.0

    summary = {
        "mode": mode,
        "tests": len(gold_tests),
        "total_correct": total_correct,
        "total_fields": total_fields,
        "overall_accuracy": accuracy,
        "wrong_cases": wrong_cases,
        "proposals_path": "data/extended_proposals.jsonl",
    }

    return {"summary": summary}
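A sketch of driving run_gold_tests() programmatically; the key names mirror the summary dict built above, and wrong_cases is read defensively because the empty-gold branch omits it.

    from training.gold_tester import run_gold_tests

    s = run_gold_tests(mode="rules+extended")["summary"]
    print(f"{s['mode']}: {s['total_correct']}/{s['total_fields']} fields correct "
          f"({s['overall_accuracy']:.1%}), failing tests: {s.get('wrong_cases', [])}")

Unlike parser_eval.py further down, wrong_cases here holds only test indices, not per-field detail.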
training/gold_tests.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eb94485222e9733a0d530d4df8ac0f35c8f95770cc9ef44bcb1289807e0b108e
size 18563634
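The gold tests themselves sit in Git LFS, so only the pointer appears in the diff. Judging from how gold_tester.py and schema_expander.py read the file, each record presumably looks like this (values are illustrative, not taken from the real file):

    [
      {
        "name": "Unnamed case",
        "input": "free-text lab description of the isolate ...",
        "expected": {"Gram Stain": "Positive", "Catalase": "Positive"}
      }
    ]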
training/gold_trainer.py
ADDED
@@ -0,0 +1,79 @@
# training/gold_trainer.py
# ------------------------------------------------------------
# Stage 10C — Orchestrates gold-test-driven training:
# 1) Alias trainer (DISABLED for safety)
# 2) Schema expander (safe v10C)
# 3) Signals trainer (placeholder)
#
# This file MUST successfully import and expose train_from_gold().
# ------------------------------------------------------------

from __future__ import annotations
from typing import Dict, Any

# Safe schema expander
from training.schema_expander import expand_schema

# Placeholder signals trainer
from training.signal_trainer import train_signals


def train_from_gold() -> Dict[str, Any]:
    """
    Runs all gold-test–driven training components (Stage 10C).

    Returns a dict:
        {
            "alias_trainer": {...},
            "schema_expander": {...},
            "signals_trainer": {...}
        }
    """

    # --------------------------------------------------------
    # 1) Alias Trainer — DISABLED to avoid destructive mappings
    # --------------------------------------------------------
    alias_result = {
        "ok": False,
        "message": (
            "Alias trainer is disabled in Stage 10C to prevent unsafe "
            "auto-mappings. Edit data/alias_maps.json manually if needed."
        ),
        "alias_map_path": "data/alias_maps.json",
    }

    # --------------------------------------------------------
    # 2) Schema Expander — Safe version
    # --------------------------------------------------------
    try:
        schema_result = expand_schema()
    except Exception as e:
        schema_result = {
            "ok": False,
            "message": f"Schema expander crashed: {e}",
            "auto_added_fields": {},
            "proposed_fields": [],
            "schema_path": "data/extended_schema.json",
            "proposals_path": "data/extended_proposals.jsonl",
        }

    # --------------------------------------------------------
    # 3) Signals Trainer (placeholder)
    # --------------------------------------------------------
    try:
        signals_result = train_signals()
    except Exception as e:
        signals_result = {
            "ok": False,
            "message": f"Signal trainer crashed: {e}",
            "signals_catalog_path": "data/signals_catalog.json",
        }

    # --------------------------------------------------------
    # Combined report
    # --------------------------------------------------------
    return {
        "alias_trainer": alias_result,
        "schema_expander": schema_result,
        "signals_trainer": signals_result,
    }
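Because every component either succeeds or is replaced by a structured failure dict, a caller can always render a three-part report. A minimal consumption sketch:

    from training.gold_trainer import train_from_gold

    report = train_from_gold()
    for component, result in report.items():
        status = "ok" if result.get("ok") else "skipped/failed"
        print(f"{component}: {status} {result.get('message', '')}")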
training/hf_sync.py
ADDED
@@ -0,0 +1,68 @@
# training/hf_sync.py
# ------------------------------------------------------------
# Sync updated data files back to the SAME Hugging Face Space.
# ------------------------------------------------------------

import os
from typing import List, Dict, Any

from huggingface_hub import HfApi, CommitOperationAdd


def push_to_hf(
    paths: List[str],
    commit_message: str = "train: update extended schema, aliases, signals from gold tests",
) -> Dict[str, Any]:

    repo_id = os.getenv("HF_SPACE_REPO_ID")
    token = os.getenv("HF_TOKEN")

    if not repo_id:
        return {
            "ok": False,
            "error": "Missing HF_SPACE_REPO_ID environment variable.",
            "uploaded": [],
        }

    if not token:
        return {
            "ok": False,
            "error": "Missing HF_TOKEN environment variable.",
            "uploaded": [],
        }

    api = HfApi()
    operations = []
    uploaded = []

    for p in paths:
        if not os.path.exists(p):
            continue

        operations.append(
            CommitOperationAdd(path_in_repo=p, path_or_fileobj=p)
        )
        uploaded.append(p)

    if not operations:
        return {
            "ok": False,
            "error": "No existing files to upload.",
            "uploaded": [],
        }

    commit_info = api.create_commit(
        repo_id=repo_id,
        repo_type="space",
        operations=operations,
        commit_message=commit_message,
        token=token,
    )

    return {
        "ok": True,
        "uploaded": uploaded,
        "repo_id": repo_id,
        "commit_message": commit_message,
        "commit_url": commit_info.commit_url,
    }
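A usage sketch, assuming HF_SPACE_REPO_ID and HF_TOKEN are configured as Space secrets; the file list mirrors the artifacts the trainers in this commit write.

    from training.hf_sync import push_to_hf

    result = push_to_hf([
        "data/extended_schema.json",
        "data/extended_proposals.jsonl",
        "data/signals_catalog.json",
    ])
    if result["ok"]:
        print("Committed:", result["commit_url"])
    else:
        print("Sync skipped:", result["error"])

Since missing files are silently skipped and missing credentials return an error dict instead of raising, this call is safe at the end of any training run.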
training/parser_eval.py
ADDED
@@ -0,0 +1,104 @@
# training/parser_eval.py
# ------------------------------------------------------------
# Parser Evaluation (Stage 10A)
#
# This version ONLY evaluates:
# - Rule parser
# - Extended parser
#
# The LLM parser is intentionally disabled at this stage
# because alias maps and schema are not trained yet.
#
# This makes Stage 10A FAST and stable (< 3 seconds).
# ------------------------------------------------------------

import json
import os
from typing import Dict, Any

from engine.parser_rules import parse_text_rules
from engine.parser_ext import parse_text_extended


# Path to the gold tests
GOLD_PATH = "training/gold_tests.json"


def evaluate_single_test(test: Dict[str, Any]) -> Dict[str, Any]:
    """
    Evaluate one gold test with rules + extended parsers.
    """
    text = test.get("input", "")
    expected = test.get("expected", {})

    # Run deterministic parsers
    rule_out = parse_text_rules(text).get("parsed_fields", {})
    ext_out = parse_text_extended(text).get("parsed_fields", {})

    # Merge rule + extended (extended overwrites rules)
    merged = dict(rule_out)
    for k, v in ext_out.items():
        if v != "Unknown":
            merged[k] = v

    total = len(expected)
    correct = 0
    wrong = {}

    for field, exp_val in expected.items():
        got = merged.get(field, "Unknown")
        if str(got).strip().lower() == str(exp_val).strip().lower():
            correct += 0 if exp_val == "Unknown" else 1  # Unknown is neutral
        else:
            wrong[field] = {"expected": exp_val, "got": got}

    return {
        "correct": correct,
        "total": total,
        "accuracy": correct / total if total else 0,
        "wrong": wrong,
        "merged": merged,
    }


def run_parser_eval(mode: str = "rules_extended") -> Dict[str, Any]:
    """
    Evaluate ALL gold tests using rules + extended parsing only.
    """
    if not os.path.exists(GOLD_PATH):
        return {"error": f"Gold test file not found at {GOLD_PATH}"}

    with open(GOLD_PATH, "r", encoding="utf-8") as f:
        gold = json.load(f)

    results = []
    wrong_cases = []

    total_correct = 0
    total_fields = 0

    for test in gold:
        out = evaluate_single_test(test)
        results.append(out)

        total_correct += out["correct"]
        total_fields += out["total"]

        if out["wrong"]:
            wrong_cases.append({
                "name": test.get("name", "Unnamed"),
                "wrong": out["wrong"],
                "parsed": out["merged"],
                "expected": test.get("expected", {})
            })

    summary = {
        "mode": "rules+extended",
        "tests": len(gold),
        "total_correct": total_correct,
        "total_fields": total_fields,
        "overall_accuracy": total_correct / total_fields if total_fields else 0,
        "wrong_cases": wrong_cases,
    }

    return summary
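A quick way to exercise run_parser_eval() and surface the first failing cases, assuming training/gold_tests.json is present (otherwise only an "error" key comes back):

    from training.parser_eval import run_parser_eval

    summary = run_parser_eval()
    print(f"accuracy: {summary['overall_accuracy']:.1%} over {summary['tests']} tests")
    for case in summary["wrong_cases"][:3]:
        print(case["name"], "->", case["wrong"])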
training/rag_index_builder.py
ADDED
@@ -0,0 +1,629 @@
# training/rag_index_builder.py
# ============================================================
# Build RAG index from JSON knowledge base (SECTION-AWARE)
#
# - Walks data/rag/knowledge_base/<Genus>/
# - Reads genus.json + species JSONs
# - Converts JSON → structured SECTION records
# - Computes embeddings via rag.rag_embedder.embed_texts
# - Writes index to data/rag/index/kb_index.json
#
# Output record schema (LOCKED):
#   {
#     "id": "Enterobacter|cloacae|species_markers|0",
#     "level": "genus" | "species",
#     "genus": "Enterobacter",
#     "species": "cloacae" | null,
#     "section": "...",
#     "role": "...",
#     "text": "...",
#     "source_file": "...",
#     "chunk_id": 0,
#     "embedding": [...]
#   }
#
# NOTE:
# We keep the locked keys above. We MAY add extra keys (non-breaking),
# e.g. "field_key" to support future scoring/weighting.
# ============================================================

from __future__ import annotations

import json
import os
import re
from typing import Dict, Any, List, Tuple, Optional

from rag.rag_embedder import embed_texts, EMBEDDING_MODEL_NAME

KB_ROOT = os.path.join("data", "rag", "knowledge_base")
INDEX_DIR = os.path.join("data", "rag", "index")
INDEX_PATH = os.path.join(INDEX_DIR, "kb_index.json")

# Chunk size is per-section. This should generally be smaller than the generator
# prompt chunk budget so the retriever can pick "tight" context blocks.
DEFAULT_MAX_CHARS = int(os.getenv("BACTAI_RAG_CHUNK_MAX_CHARS", "1100"))

# ------------------------------------------------------------
# TEXT HELPERS
# ------------------------------------------------------------

def _norm_str(x: Any) -> str:
    return str(x).strip() if x is not None else ""

def _safe_join(items: List[str], sep: str = " ") -> str:
    return sep.join([s for s in items if s])

def _bullet_lines(items: List[str], prefix: str = "- ") -> str:
    clean = [i.strip() for i in items if isinstance(i, str) and i.strip()]
    if not clean:
        return ""
    return "\n".join(prefix + c for c in clean)

def _title_case_field(field_name: str) -> str:
    # Keep parser field names stable (don’t “prettify” them incorrectly)
    return field_name.strip()

def _format_expected_fields(expected_fields: Dict[str, Any]) -> str:
    """
    Turn your expected_fields into a compact, self-contained key:value block.
    Handles strings, lists, and simple scalars.
    """
    if not isinstance(expected_fields, dict) or not expected_fields:
        return ""

    lines: List[str] = []
    for k in sorted(expected_fields.keys(), key=lambda s: str(s).lower()):
        key = _title_case_field(str(k))
        v = expected_fields.get(k)

        if isinstance(v, list):
            vals = [str(x).strip() for x in v if str(x).strip()]
            if vals:
                lines.append(f"{key}: " + "; ".join(vals))
            else:
                lines.append(f"{key}: Unknown")
        else:
            val = _norm_str(v) or "Unknown"
            lines.append(f"{key}: {val}")

    return "\n".join(lines)

def _as_list(v: Any) -> List[str]:
    if isinstance(v, list):
        return [str(x).strip() for x in v if str(x).strip()]
    if isinstance(v, str) and v.strip():
        return [v.strip()]
    if v is None:
        return []
    s = str(v).strip()
    return [s] if s else []

def _is_unknown(v: str) -> bool:
    return (v or "").strip().lower() in {"unknown", "not specified", "n/a", "na", ""}

def _expected_fields_to_sentences(
    expected_fields: Dict[str, Any],
    *,
    subject: str,
) -> str:
    """
    Convert expected_fields into DECLARATIVE microbiology statements.
    This is the key fix for "Not specified" RAG outputs:
    LLMs treat these as evidence-like assertions rather than schema metadata.
    """
    if not isinstance(expected_fields, dict) or not expected_fields:
        return ""

    # Prefer these first (front-load the most diagnostic traits)
    priority = [
        "Gram Stain",
        "Shape",
        "Oxygen Requirement",
        "Motility",
        "Motility Type",
        "Capsule",
        "Spore Formation",
        "Haemolysis",
        "Haemolysis Type",
        "Oxidase",
        "Catalase",
        "Indole",
        "Urease",
        "Citrate",
        "Methyl Red",
        "VP",
        "H2S",
        "ONPG",
        "Nitrate Reduction",
        "NaCl Tolerant (>=6%)",
        "Growth Temperature",
        "Media Grown On",
        "Colony Morphology",
        "Colony Pattern",
        "Pigment",
        "TSI Pattern",
        "Gas Production",
    ]

    # Then everything else, stable order
    all_keys = list(expected_fields.keys())
    ordered = []
    seen = set()
    for k in priority:
        if k in expected_fields:
            ordered.append(k)
            seen.add(k)
    for k in sorted(all_keys, key=lambda s: str(s).lower()):
        if k not in seen:
            ordered.append(k)
            seen.add(k)

    lines: List[str] = []
    subj = subject.strip() or "This organism"

    for k in ordered:
        key = _title_case_field(str(k))
        raw = expected_fields.get(k)

        if isinstance(raw, list):
            vals = [x for x in _as_list(raw) if not _is_unknown(x)]
            if not vals:
                continue

            # Special handling for list-like fields
            if key == "Media Grown On":
                lines.append(f"{subj} can grow on: " + ", ".join(vals) + ".")
            elif key == "Colony Morphology":
                lines.append(f"{subj} colonies are described as: " + ", ".join(vals) + ".")
            else:
                lines.append(f"{subj} {key} includes: " + ", ".join(vals) + ".")
            continue

        val = _norm_str(raw)
        if _is_unknown(val):
            continue

        # Field-specific phrasing for better “evidence-like” feel
        if key == "Gram Stain":
            lines.append(f"{subj} is typically Gram {val}.")
        elif key == "Shape":
            lines.append(f"{subj} typically has shape: {val}.")
        elif key == "Oxygen Requirement":
            lines.append(f"{subj} is typically {val}.")
        elif key == "Growth Temperature":
            lines.append(f"{subj} typically grows within: {val} °C.")
        elif key == "Haemolysis Type":
            lines.append(f"{subj} haemolysis type is typically: {val}.")
        elif key == "Haemolysis":
            lines.append(f"{subj} haemolysis is typically: {val}.")
        elif key == "Pigment":
            if val.lower() in {"none", "no", "negative"}:
                lines.append(f"{subj} typically produces no pigment.")
            else:
                lines.append(f"{subj} may produce pigment: {val}.")
        elif key == "Colony Pattern":
            lines.append(f"{subj} colony/cellular pattern may be described as: {val}.")
        else:
            # Default: simple assertive sentence
            lines.append(f"{subj} {key} is typically: {val}.")

    # If we emitted nothing, return empty so we don’t add noise
    return "\n".join(lines).strip()

def _format_key_differentiators(items: List[Dict[str, Any]]) -> str:
    """
    For genus-level key_differentiators.
    """
    if not isinstance(items, list) or not items:
        return ""
    out: List[str] = []
    for obj in items:
        if not isinstance(obj, dict):
            continue
        field = _norm_str(obj.get("field"))
        expected = _norm_str(obj.get("expected"))
        notes = _norm_str(obj.get("notes"))
        distinguishes_from = obj.get("distinguishes_from") or []
        if not field:
            continue

        line = f"{field}: expected {expected or 'Unknown'}."
        if isinstance(distinguishes_from, list) and distinguishes_from:
            line += " Distinguishes from: " + ", ".join([_norm_str(x) for x in distinguishes_from if _norm_str(x)])
        if not line.endswith("."):
            line += "."
        if notes:
            line += f" Notes: {notes}"
        if not line.endswith("."):
            line += "."
        out.append(line)

    return "\n".join(out)

def _format_common_confusions(items: List[Dict[str, Any]], level: str) -> str:
    """
    For genus/species common_confusions.
    """
    if not isinstance(items, list) or not items:
        return ""
    out: List[str] = []
    for obj in items:
        if not isinstance(obj, dict):
            continue
        reason = _norm_str(obj.get("reason"))
        if level == "genus":
            who = _norm_str(obj.get("genus"))
            if who:
                out.append(f"{who}: {reason or 'Reason not specified.'}")
        else:
            who = _norm_str(obj.get("species")) or _norm_str(obj.get("genus"))
            if who:
                out.append(f"{who}: {reason or 'Reason not specified.'}")
    return "\n".join(out)

def _format_recommended_next_tests(items: List[Dict[str, Any]]) -> str:
    """
    For recommended_next_tests with optional API kit note.
    """
    if not isinstance(items, list) or not items:
        return ""
    out: List[str] = []
    for obj in items:
        if not isinstance(obj, dict):
            continue
        test = _norm_str(obj.get("test"))
        reason = _norm_str(obj.get("reason"))
        api_kit = _norm_str(obj.get("api_kit"))

        if not test:
            continue

        line = f"{test}"
        if api_kit:
            line += f" (API kit: {api_kit})"
        if reason:
            line += f": {reason}"
        out.append(line)
    return "\n".join(out)

# ------------------------------------------------------------
# CHUNKING (SECTION-LOCAL)
# ------------------------------------------------------------

def chunk_text_by_paragraph(text: str, max_chars: int = DEFAULT_MAX_CHARS) -> List[str]:
    """
    Chunk within a single section. We never merge different sections together.
    """
    text = (text or "").strip()
    if not text:
        return []

    if len(text) <= max_chars:
        return [text]

    paras = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
    if not paras:
        paras = [l.strip() for l in text.splitlines() if l.strip()]

    chunks: List[str] = []
    current = ""

    for p in paras:
        candidate = (current + "\n\n" + p).strip() if current else p
        if len(candidate) <= max_chars:
            current = candidate
        else:
            if current:
                chunks.append(current)
            if len(p) <= max_chars:
                current = p
            else:
                for i in range(0, len(p), max_chars):
                    chunks.append(p[i:i + max_chars].strip())
                current = ""

    if current:
        chunks.append(current)

    return [c for c in chunks if c.strip()]

# ------------------------------------------------------------
# SECTION EMITTERS
# ------------------------------------------------------------

def emit_genus_sections(doc: Dict[str, Any], genus: str) -> List[Dict[str, Any]]:
    """
    Convert genus.json to a list of {section, role, text} entries.
    """
    out: List[Dict[str, Any]] = []

    overview = doc.get("overview") or {}
    if isinstance(overview, dict):
        short = _norm_str(overview.get("short"))
        clinical = _norm_str(overview.get("clinical_context"))
        if short:
            out.append({"section": "overview", "role": "description", "text": f"Genus {genus}: {short}"})
        if clinical:
            out.append({"section": "overview", "role": "description", "text": f"Clinical context: {clinical}"})

    expected_fields = doc.get("expected_fields")
    if isinstance(expected_fields, dict) and expected_fields:
        # 1) Declarative evidence-like sentences (NEW)
        sent = _expected_fields_to_sentences(expected_fields, subject=f"Genus {genus}")
        if sent:
            out.append({
                "section": "expected_profile_sentences",
                "role": "expected_profile",
                "text": sent,
            })

        # 2) Keep original key:value block (still useful)
        text = _format_expected_fields(expected_fields)
        if text:
            out.append({
                "section": "expected_fields",
                "role": "expected_profile",
                "text": f"Expected fields for genus {genus}:\n{text}",
            })

    field_notes = doc.get("field_notes")
    if isinstance(field_notes, dict) and field_notes:
        lines: List[str] = []
        for k in sorted(field_notes.keys(), key=lambda s: str(s).lower()):
            v = _norm_str(field_notes.get(k))
            if v:
                lines.append(f"{_title_case_field(str(k))}: {v}")
        if lines:
            out.append({"section": "field_notes", "role": "clarification", "text": "Field notes:\n" + "\n".join(lines)})

    kd = doc.get("key_differentiators")
    if isinstance(kd, list) and kd:
        text = _format_key_differentiators(kd)
        if text:
            out.append({"section": "key_differentiators", "role": "differentiation", "text": "Key differentiators:\n" + text})

    conf = doc.get("common_confusions")
    if isinstance(conf, list) and conf:
        text = _format_common_confusions(conf, level="genus")
        if text:
            out.append({"section": "common_confusions", "role": "warning", "text": "Common confusions:\n" + text})

    wq = doc.get("when_to_question_identification")
    if isinstance(wq, list) and wq:
        lines = [str(x).strip() for x in wq if str(x).strip()]
        if lines:
            out.append({"section": "when_to_question_identification", "role": "warning", "text": "When to question identification:\n" + _bullet_lines(lines)})

    rnt = doc.get("recommended_next_tests")
    if isinstance(rnt, list) and rnt:
        text = _format_recommended_next_tests(rnt)
        if text:
            out.append({"section": "recommended_next_tests", "role": "recommendation", "text": "Recommended next tests:\n" + text})

    ss = doc.get("supported_species")
    if isinstance(ss, list) and ss:
        species_list = [str(x).strip() for x in ss if str(x).strip()]
        if species_list:
            out.append({"section": "supported_species", "role": "metadata", "text": f"Supported species for genus {genus}: " + ", ".join(species_list)})

    return out


def emit_species_sections(doc: Dict[str, Any], genus: str, species: str) -> List[Dict[str, Any]]:
    """
    Convert a species JSON to a list of {section, role, text} entries.
    """
    out: List[Dict[str, Any]] = []
    overview = doc.get("overview") or {}
    if isinstance(overview, dict):
        short = _norm_str(overview.get("short"))
        clinical = _norm_str(overview.get("clinical_context"))
        if short:
            out.append({"section": "overview", "role": "description", "text": f"Species {genus} {species}: {short}"})
        if clinical:
            out.append({"section": "overview", "role": "description", "text": f"Clinical context: {clinical}"})

    expected_fields = doc.get("expected_fields")
    if isinstance(expected_fields, dict) and expected_fields:
        # 1) Declarative evidence-like sentences (NEW)
        sent = _expected_fields_to_sentences(expected_fields, subject=f"Species {genus} {species}")
        if sent:
            out.append({
                "section": "expected_profile_sentences",
                "role": "expected_profile",
                "text": sent,
            })

        # 2) Keep original key:value block
        text = _format_expected_fields(expected_fields)
        if text:
            out.append({"section": "expected_fields", "role": "expected_profile", "text": f"Expected fields for species {genus} {species}:\n{text}"})

    markers = doc.get("species_markers")
    if isinstance(markers, list) and markers:
        lines: List[str] = []
        for m in markers:
            if not isinstance(m, dict):
                continue
            field = _norm_str(m.get("field"))
            val = _norm_str(m.get("value"))
            importance = _norm_str(m.get("importance"))
            notes = _norm_str(m.get("notes"))
            if not field:
                continue
            line = f"{field}: {val or 'Unknown'}"
            if importance:
                line += f" (importance: {importance})"
            if notes:
                line += f" — {notes}"
            lines.append(line)
        if lines:
            out.append({"section": "species_markers", "role": "species_marker", "text": "Species markers:\n" + "\n".join(lines)})

    conf = doc.get("common_confusions")
    if isinstance(conf, list) and conf:
        text = _format_common_confusions(conf, level="species")
        if text:
            out.append({"section": "common_confusions", "role": "warning", "text": "Common confusions:\n" + text})

    wq = doc.get("when_to_question_identification")
    if isinstance(wq, list) and wq:
        lines = [str(x).strip() for x in wq if str(x).strip()]
        if lines:
            out.append({"section": "when_to_question_identification", "role": "warning", "text": "When to question identification:\n" + _bullet_lines(lines)})

    rnt = doc.get("recommended_next_tests")
    if isinstance(rnt, list) and rnt:
        text = _format_recommended_next_tests(rnt)
        if text:
            out.append({"section": "recommended_next_tests", "role": "recommendation", "text": "Recommended next tests:\n" + text})

    return out


# ------------------------------------------------------------
# INDEX BUILD
# ------------------------------------------------------------

def _iter_kb_files() -> List[Tuple[str, str]]:
    entries: List[Tuple[str, str]] = []
    if not os.path.isdir(KB_ROOT):
        return entries

    for genus in sorted(os.listdir(KB_ROOT)):
        genus_dir = os.path.join(KB_ROOT, genus)
        if not os.path.isdir(genus_dir):
            continue
        for fname in sorted(os.listdir(genus_dir)):
            if fname.lower().endswith(".json"):
                entries.append((genus, os.path.join(genus_dir, fname)))
    return entries


def build_rag_index(max_chars: int = DEFAULT_MAX_CHARS) -> Dict[str, Any]:
    os.makedirs(INDEX_DIR, exist_ok=True)

    kb_entries = _iter_kb_files()
    if not kb_entries:
        return {"ok": False, "message": "No KB JSON files found."}

    docs_for_embedding: List[str] = []
    meta: List[Dict[str, Any]] = []

    num_json_errors = 0

    for genus_dir_name, path in kb_entries:
        with open(path, "r", encoding="utf-8") as f:
            try:
                doc = json.load(f)
            except json.JSONDecodeError as e:
                print(f"[rag_index_builder] JSON error in {path}: {e}")
                num_json_errors += 1
                continue

        fname = os.path.basename(path)
        is_genus = fname == "genus.json"

        genus = _norm_str(doc.get("genus")) or genus_dir_name
        level = "genus" if is_genus else "species"

        species: Optional[str]
        if is_genus:
            species = None
            sections = emit_genus_sections(doc, genus=genus)
        else:
            species = _norm_str(doc.get("species")) or os.path.splitext(fname)[0]
            sections = emit_species_sections(doc, genus=genus, species=species)

        for sec in sections:
            section = _norm_str(sec.get("section"))
            role = _norm_str(sec.get("role"))
            text = _norm_str(sec.get("text"))

            if not section or not role or not text:
                continue

            chunks = chunk_text_by_paragraph(text, max_chars=max_chars)
            for idx, chunk in enumerate(chunks):
                if not chunk.strip():
                    continue

                rec_id = f"{genus}|{species or 'GENUS'}|{section}|{idx}"

                docs_for_embedding.append(chunk)
                meta.append(
                    {
                        "id": rec_id,
                        "level": level,
                        "genus": genus,
                        "species": species,
                        "section": section,
                        "role": role,
                        "text": chunk,
                        "source_file": os.path.relpath(path),
                        "chunk_id": idx,
                        # Optional: helps later for field-level weighting
                        "field_key": None,
                    }
                )

    if not docs_for_embedding:
        return {
            "ok": False,
            "message": "No valid sections emitted from KB JSON files. Check schema/contents.",
            "num_files": len(kb_entries),
            "num_json_errors": num_json_errors,
        }

    embeddings = embed_texts(docs_for_embedding, normalize=True)

    index_records: List[Dict[str, Any]] = []
    for m, emb in zip(meta, embeddings):
        rec = dict(m)
        rec["embedding"] = emb.tolist()
        index_records.append(rec)

    with open(INDEX_PATH, "w", encoding="utf-8") as f:
        json.dump(
            {
                "version": 2,
                "model_name": EMBEDDING_MODEL_NAME,
                "record_schema": {
                    "id": "str",
                    "level": "genus|species",
                    "genus": "str",
                    "species": "str|null",
                    "section": "str",
                    "role": "str",
                    "text": "str",
                    "source_file": "str",
                    "chunk_id": "int",
                    "embedding": "list[float]",
                },
                "stats": {
                    "num_files": len(kb_entries),
                    "num_records": len(index_records),
                    "num_json_errors": num_json_errors,
                    "chunk_max_chars": max_chars,
                },
                "records": index_records,
            },
            f,
            ensure_ascii=False,
        )

    return {
        "ok": True,
        "message": "RAG index built successfully (section-aware, declarative expected profiles).",
        "index_path": INDEX_PATH,
        "num_records": len(index_records),
        "num_files": len(kb_entries),
        "num_json_errors": num_json_errors,
        "chunk_max_chars": max_chars,
    }


if __name__ == "__main__":
    summary = build_rag_index()
    print(json.dumps(summary, indent=2))
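To make the section-local chunking concrete: paragraphs are packed greedily up to max_chars, and a single oversized paragraph is hard-split at the character limit. A small self-contained illustration with toy inputs (not KB data):

    from training.rag_index_builder import chunk_text_by_paragraph

    text = "para one\n\npara two\n\npara three\n\n" + "x" * 50
    for chunk in chunk_text_by_paragraph(text, max_chars=30):
        print(repr(chunk))
    # The three short paragraphs pack into one 30-char chunk; the 50-char
    # run is then split into a 30-char piece and a 20-char piece.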
training/schema_expander.py
ADDED
@@ -0,0 +1,237 @@
# training/schema_expander.py
# ------------------------------------------------------------
# Stage 10C — SAFE schema expansion
#
# Core fields = EXACT columns in bacteria_db.xlsx.
# Extended fields = ONLY the ones NOT in DB and NOT in existing schema.
#
# This version:
# - NEVER adds core fields to extended schema.
# - Only adds true extended fields found in gold tests.
# - Logs ambiguous or rare fields to proposals file.
# - Reports field frequencies & values seen for debugging.
# ------------------------------------------------------------

from __future__ import annotations

import os
import json
from typing import Dict, Any, List
from collections import Counter
from datetime import datetime

import pandas as pd

from engine.schema import (
    load_extended_schema,
    save_extended_schema,
)

# ------------------------------------------------------------
# Paths
# ------------------------------------------------------------

GOLD_PATH = "training/gold_tests.json"
EXTENDED_SCHEMA_PATH = "data/extended_schema.json"
PROPOSALS_PATH = "data/extended_proposals.jsonl"

# Minimum frequency before auto-adding a new extended field
MIN_FIELD_FREQ = 5


# ------------------------------------------------------------
# Helper: load gold tests
# ------------------------------------------------------------

def _load_gold_tests() -> List[Dict[str, Any]]:
    if not os.path.exists(GOLD_PATH):
        return []
    with open(GOLD_PATH, "r", encoding="utf-8") as f:
        try:
            data = json.load(f)
            return data if isinstance(data, list) else []
        except Exception:
            return []


# ------------------------------------------------------------
# Helper: load DB columns (TRUE core schema)
# ------------------------------------------------------------

def _load_db_columns() -> List[str]:
    candidates = [
        os.path.join("data", "bacteria_db.xlsx"),
        "bacteria_db.xlsx",
    ]
    for p in candidates:
        if os.path.exists(p):
            try:
                df = pd.read_excel(p)
                return [c.strip() for c in df.columns]
            except Exception:
                continue
    return []


# ------------------------------------------------------------
# Decide if field name is safe for auto-adding
# ------------------------------------------------------------

def _is_safe_field_name(name: str) -> bool:
    n = name.strip()
    if not n:
        return False

    low = n.lower()

    # Ignore extremely short or generic names
    if len(n) < 4:
        return False
    if low in {"test", "growth", "acid", "base", "value", "result"}:
        return False

    # Clear biochemical patterns
    patterns = [
        "hydrolysis",
        "fermentation",
        "decarboxylase",
        "dihydrolase",
        "reduction",
        "utilization",
        "tolerance",
        "solubility",
        "oxidation",
        "lysis",
        "susceptibility",
        "resistance",
        "pyruvate",
        "lecithinase",
        "lipase",
        "casein",
        "hippurate",
        "tyrosine",
    ]
    if any(pat in low for pat in patterns):
        return True

    # Known short disc tests
    known_short = {"CAMP", "PYR", "Optochin", "Bacitracin", "Novobiocin"}
    if n in known_short:
        return True

    # If contains "test" and more than one word → likely legitimate
    if "test" in low and " " in low:
        return True

    return False


# ------------------------------------------------------------
# Log proposal (rare/ambiguous fields)
# ------------------------------------------------------------

def _append_proposal(record: Dict[str, Any]) -> None:
    os.makedirs(os.path.dirname(PROPOSALS_PATH), exist_ok=True)
    with open(PROPOSALS_PATH, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")


# ------------------------------------------------------------
# MAIN ENTRY — SAFE SCHEMA EXPANSION
# ------------------------------------------------------------

def expand_schema() -> Dict[str, Any]:
    gold = _load_gold_tests()
    if not gold:
        return {
            "ok": False,
            "message": f"No gold tests found at {GOLD_PATH}",
            "auto_added_fields": {},
            "proposed_fields": [],
            "schema_path": EXTENDED_SCHEMA_PATH,
            "proposals_path": PROPOSALS_PATH,
            "unknown_fields_raw": {},
            "field_frequencies": {},
        }

    db_columns = set(_load_db_columns())  # TRUE core schema
    extended_schema = load_extended_schema(EXTENDED_SCHEMA_PATH)
    extended_fields = set(extended_schema.keys())

    # Counter for unknown fields
    field_counts: Counter[str] = Counter()
    field_values: Dict[str, Counter[str]] = {}

    for test in gold:
        expected = test.get("expected", {})
        if not isinstance(expected, dict):
            continue

        for field, value in expected.items():
            fname = str(field).strip()
            if not fname:
                continue

            # Skip core DB fields
            if fname in db_columns:
                continue

            # Skip already-known extended fields
            if fname in extended_fields:
                continue

            # Count unknowns
            field_counts[fname] += 1
            if fname not in field_values:
                field_values[fname] = Counter()
            field_values[fname][str(value).strip()] += 1

    auto_added: Dict[str, Any] = {}
    proposed: List[Dict[str, Any]] = []

    # Decide which unknown fields to auto-add
    for fname, freq in field_counts.items():
        values_seen = dict(field_values.get(fname, {}))

        if freq >= MIN_FIELD_FREQ and _is_safe_field_name(fname):
            # Auto-add as extended test
            extended_schema[fname] = {
                "value_type": "enum_PNV",
                "description": "Auto-added from gold tests (Stage 10C)",
                "values": list(values_seen.keys()),
            }
            auto_added[fname] = {
                "freq": freq,
                "values_seen": list(values_seen.keys()),
            }
        else:
            # Log proposal for later review
            proposed.append(
                {
                    "field_name": fname,
                    "freq": freq,
                    "values_seen": values_seen,
                }
            )
            _append_proposal(
                {
                    "timestamp": datetime.utcnow().isoformat() + "Z",
                    "field_name": fname,
                    "freq": freq,
                    "values_seen": values_seen,
                }
            )

    # Save updated schema
    if auto_added:
        save_extended_schema(extended_schema, EXTENDED_SCHEMA_PATH)

    return {
        "ok": True,
        "auto_added_fields": auto_added,
        "proposed_fields": proposed,
        "schema_path": EXTENDED_SCHEMA_PATH,
        "proposals_path": PROPOSALS_PATH,
        "unknown_fields_raw": {f: dict(cnt) for f, cnt in field_values.items()},
        "field_frequencies": dict(field_counts),
    }
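The gate in _is_safe_field_name() is easiest to see on examples; each outcome below follows directly from the rules above:

    from training.schema_expander import _is_safe_field_name

    _is_safe_field_name("Esculin hydrolysis")  # True  (biochemical pattern)
    _is_safe_field_name("CAMP")                # True  (known short disc test)
    _is_safe_field_name("acid")                # False (generic name)
    _is_safe_field_name("XY")                  # False (too short)

A field additionally needs MIN_FIELD_FREQ (5) occurrences across the gold tests before it is auto-added; anything rarer only lands in the proposals log.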
training/signal_trainer.py
ADDED
@@ -0,0 +1,35 @@
# training/signal_trainer.py
# ------------------------------------------------------------
# Stage 10C placeholder:
# Safely returns a no-op result for signal training.
# This MUST NOT crash during import.
# ------------------------------------------------------------

from __future__ import annotations
from typing import Dict, Any
import json
import os


SIGNALS_PATH = "data/signals_catalog.json"


def train_signals() -> Dict[str, Any]:
    """
    Placeholder trainer. Does nothing except ensure signals_catalog.json exists.
    Must NEVER crash.
    """

    # Ensure signals catalog exists
    if not os.path.exists(SIGNALS_PATH):
        try:
            with open(SIGNALS_PATH, "w", encoding="utf-8") as f:
                json.dump({}, f, indent=2, ensure_ascii=False)
        except Exception:
            pass

    return {
        "ok": True,
        "message": "Signal trainer not implemented yet (Stage 10C placeholder).",
        "signals_catalog_path": SIGNALS_PATH,
    }
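Even as a no-op, the placeholder keeps the Stage 10C pipeline contract intact and can be smoke-tested directly:

    from training.signal_trainer import train_signals

    print(train_signals()["message"])
    # -> Signal trainer not implemented yet (Stage 10C placeholder).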