EphAsad committed
Commit 974a888 · verified · 1 Parent(s): 0c848a5

Update rag/rag_generator.py

Files changed (1)
  1. rag/rag_generator.py +484 -445
rag/rag_generator.py CHANGED
@@ -1,446 +1,485 @@
-# rag/rag_generator.py
-# ============================================================
-# RAG generator using google/flan-t5-large (CPU-friendly)
-#
-# Goal (user-visible, structured, deterministic-first):
-#   - Show the user:
-#       KEY TRAITS:
-#       CONFLICTS:
-#       CONCLUSION:
-#   - KEY TRAITS and CONFLICTS are extracted deterministically from the
-#     shaped retriever context (preferred).
-#   - The LLM only writes the CONCLUSION (2–5 sentences) based on those
-#     extracted sections.
-#
-# Reliability:
-#   - flan-t5 sometimes echoes prompt instructions.
-#   - We keep the prompt extremely short and avoid imperative bullet rules.
-#   - We keep deterministic fallback logic if the LLM output is garbage/echo.
-#
-# Expected usage:
-#   ctx = retrieve_rag_context(..., parsed_fields=...)
-#   explanation = generate_genus_rag_explanation(
-#       phenotype_text=text,
-#       rag_context=ctx.get("llm_context_shaped") or ctx.get("llm_context"),
-#       genus=genus
-#   )
-#
-# Optional HF Space logs:
-#   export BACTAI_RAG_GEN_LOG_INPUT=1
-#   export BACTAI_RAG_GEN_LOG_OUTPUT=1
-# ============================================================
-
-from __future__ import annotations
-
-import os
-import re
-import torch
-from transformers import T5ForConditionalGeneration, T5Tokenizer
-
-
-# ------------------------------------------------------------
-# MODEL CONFIG
-# ------------------------------------------------------------
-
-MODEL_NAME = "google/flan-t5-large"
-
-_tokenizer: T5Tokenizer | None = None
-_model: T5ForConditionalGeneration | None = None
-
-# Keep small for CPU + to reduce prompt truncation weirdness
-_MAX_INPUT_TOKENS = 768
-_DEFAULT_MAX_NEW_TOKENS = 160
-
-# Hard cap the context chars we feed to T5 (prevents the model focusing on junk)
-_CONTEXT_CHAR_CAP = 2400
-
-
-def _get_model() -> tuple[T5Tokenizer, T5ForConditionalGeneration]:
-    global _tokenizer, _model
-    if _tokenizer is None or _model is None:
-        _tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
-        _model = T5ForConditionalGeneration.from_pretrained(
-            MODEL_NAME,
-            device_map="auto",
-            torch_dtype=torch.float32,
-        )
-    return _tokenizer, _model
-
-
-# ------------------------------------------------------------
-# DEBUG LOGGING (HF Space logs)
-# ------------------------------------------------------------
-
-RAG_GEN_LOG_INPUT = os.getenv("BACTAI_RAG_GEN_LOG_INPUT", "0").strip() == "1"
-RAG_GEN_LOG_OUTPUT = os.getenv("BACTAI_RAG_GEN_LOG_OUTPUT", "0").strip() == "1"
-
-
-def _log_block(title: str, body: str) -> None:
-    print("=" * 80)
-    print(f"RAG GENERATOR DEBUG — {title}")
-    print("=" * 80)
-    print(body.strip() if body else "")
-    print()
-
-
-# ------------------------------------------------------------
-# PROMPT (LLM WRITES ONLY THE CONCLUSION)
-# ------------------------------------------------------------
-
-# Intentionally minimal. No "rules list", no bullets specification.
-# The LLM sees ONLY extracted matches/conflicts and writes a short conclusion.
-RAG_PROMPT = """summarize: Evaluate whether the phenotype fits the target genus using the provided matches and conflicts.
-
-Target genus: {genus}
-
-Key traits that match:
-{matches}
-
-Conflicts:
-{conflicts}
-
-Write a short conclusion (2–5 sentences) stating whether this is a strong, moderate, or tentative genus match, and briefly mention the most important matches and conflicts.
-"""
-
-
-# ------------------------------------------------------------
-# OUTPUT CLEANUP + ECHO DETECTION
-# ------------------------------------------------------------
-
-_BAD_SUBSTRINGS = (
-    "summarize:",
-    "target genus",
-    "key traits that match",
-    "write a short conclusion",
-    "conflicts:",
-)
-
-def _clean_generation(text: str) -> str:
-    s = (text or "").strip()
-    if not s:
-        return ""
-
-    # collapse excessive whitespace/newlines
-    s = re.sub(r"\s*\n+\s*", " ", s).strip()
-    s = re.sub(r"\s{2,}", " ", s).strip()
-
-    # guard runaway length
-    if len(s) > 900:
-        s = s[:900].rstrip() + "..."
-
-    return s
-
-
-def _looks_like_echo_or_garbage(text: str) -> bool:
-    s = (text or "").strip()
-    if not s:
-        return True
-
-    # extremely short / non-sentence
-    if len(s) < 25:
-        return True
-
-    low = s.lower()
-    if any(bad in low for bad in _BAD_SUBSTRINGS):
-        return True
-
-    # Must look like actual prose
-    if "." not in s and "because" not in low and "match" not in low and "fits" not in low:
-        return True
-
-    return False
-
-
-# ------------------------------------------------------------
-# EXTRACT KEY TRAITS + CONFLICTS FROM SHAPED CONTEXT
-# ------------------------------------------------------------
-
-# Shaped context format (example):
-#   KEY MATCHES:
-#   - Trait: Value (matches reference: ...)
-#
-#   CONFLICTS (observed vs CORE traits):
-#   - Trait: Value (conflicts reference: ...)
-# or:
-#   CONFLICTS: Not specified.
-
-_KEY_MATCHES_HEADER_RE = re.compile(r"^\s*KEY MATCHES\s*:\s*$", re.IGNORECASE)
-_CONFLICTS_HEADER_RE = re.compile(r"^\s*CONFLICTS\b.*:\s*$", re.IGNORECASE)
-_CONFLICTS_INLINE_NONE_RE = re.compile(r"^\s*CONFLICTS\s*:\s*not specified\.?\s*$", re.IGNORECASE)
-
-_MATCH_LINE_RE = re.compile(
-    r"^\s*-\s*([^:]+)\s*:\s*(.+?)\s*\(matches reference:\s*(.+?)\)\s*$",
-    re.IGNORECASE,
-)
-_CONFLICT_LINE_RE = re.compile(
-    r"^\s*-\s*([^:]+)\s*:\s*(.+?)\s*\(conflicts reference:\s*(.+?)\)\s*$",
-    re.IGNORECASE,
-)
-
-# More permissive bullet capture (if shaper changes slightly)
-_GENERIC_BULLET_RE = re.compile(r"^\s*-\s*(.+?)\s*$")
-
-
-def _extract_key_traits_and_conflicts(shaped_ctx: str) -> tuple[list[str], list[str], bool]:
-    """
-    Extracts KEY MATCHES and CONFLICTS bullets from shaped retriever context.
-
-    Returns:
-        (key_traits, conflicts, found_structured_headers)
-
-    - key_traits items are short: "Trait: ObservedValue"
-    - conflicts items are short: "Trait: ObservedValue"
-    """
-    key_traits: list[str] = []
-    conflicts: list[str] = []
-
-    lines = (shaped_ctx or "").splitlines()
-    if not lines:
-        return key_traits, conflicts, False
-
-    in_matches = False
-    in_conflicts = False
-    saw_headers = False
-
-    for raw in lines:
-        line = raw.rstrip("\n")
-
-        # detect headers
-        if _KEY_MATCHES_HEADER_RE.match(line.strip()):
-            in_matches = True
-            in_conflicts = False
-            saw_headers = True
-            continue
-
-        if _CONFLICTS_INLINE_NONE_RE.match(line.strip()):
-            in_matches = False
-            in_conflicts = False
-            saw_headers = True
-            # explicit "no conflicts"
-            continue
-
-        if _CONFLICTS_HEADER_RE.match(line.strip()):
-            in_matches = False
-            in_conflicts = True
-            saw_headers = True
-            continue
-
-        # stop capture if another section begins (common shaper headings)
-        if saw_headers and (line.strip().endswith(":") and not line.strip().startswith("-")):
-            # If it's a new heading (and not one of our two), stop both
-            if not _KEY_MATCHES_HEADER_RE.match(line.strip()) and not _CONFLICTS_HEADER_RE.match(line.strip()):
-                in_matches = False
-                in_conflicts = False
-
-        # capture bullets under each section
-        if in_matches and line.strip().startswith("-"):
-            m = _MATCH_LINE_RE.match(line.strip())
-            if m:
-                trait = m.group(1).strip()
-                obs = m.group(2).strip()
-                key_traits.append(f"{trait}: {obs}")
-            else:
-                g = _GENERIC_BULLET_RE.match(line.strip())
-                if g:
-                    key_traits.append(g.group(1).strip())
-            continue
-
-        if in_conflicts and line.strip().startswith("-"):
-            c = _CONFLICT_LINE_RE.match(line.strip())
-            if c:
-                trait = c.group(1).strip()
-                obs = c.group(2).strip()
-                conflicts.append(f"{trait}: {obs}")
-            else:
-                g = _GENERIC_BULLET_RE.match(line.strip())
-                if g:
-                    conflicts.append(g.group(1).strip())
-            continue
-
-    return key_traits, conflicts, saw_headers
-
-
-def _extract_matches_conflicts_legacy(shaped_ctx: str) -> tuple[list[str], list[str]]:
-    """
-    Legacy extraction based purely on (matches reference: ...) / (conflicts reference: ...)
-    anywhere in the text. Useful if headers are missing.
-    """
-    matches: list[str] = []
-    conflicts: list[str] = []
-
-    for raw in (shaped_ctx or "").splitlines():
-        line = raw.strip()
-        if not line.startswith("-"):
-            continue
-
-        m = _MATCH_LINE_RE.match(line)
-        if m:
-            trait = m.group(1).strip()
-            obs = m.group(2).strip()
-            matches.append(f"{trait}: {obs}")
-            continue
-
-        c = _CONFLICT_LINE_RE.match(line)
-        if c:
-            trait = c.group(1).strip()
-            obs = c.group(2).strip()
-            conflicts.append(f"{trait}: {obs}")
-            continue
-
-    return matches, conflicts
-
-
-def _format_bullets(items: list[str], *, none_text: str) -> str:
-    if not items:
-        return none_text
-    return "\n".join(f"- {x}" for x in items)
-
-
-# ------------------------------------------------------------
-# DETERMINISTIC CONCLUSION FALLBACK
-# ------------------------------------------------------------
-
-def _deterministic_conclusion(genus: str, key_traits: list[str], conflicts: list[str]) -> str:
-    g = (genus or "").strip() or "Unknown"
-
-    m = key_traits[:4]
-    c = conflicts[:2]
-
-    if m and c:
-        return (
-            f"This is a probable match to {g} because it aligns with key traits such as "
-            f"{', '.join(m)}. However, there are conflicts ({', '.join(c)}), so treat this "
-            f"as a moderate/tentative genus-level fit and consider re-checking the conflicting tests."
-        )
-    if m and not c:
-        return (
-            f"This phenotype is consistent with {g} based on key matching traits such as "
-            f"{', '.join(m)}. No major conflicts were detected against the retrieved core genus traits, "
-            f"supporting a strong genus-level match."
-        )
-    if (not m) and c:
-        return (
-            f"This phenotype does not cleanly fit {g} because it conflicts with core traits "
-            f"({', '.join(c)}). Consider re-checking those tests or comparing against the next-ranked genera."
-        )
-
-    return (
-        f"Reference evidence was available for {g}, but no clear matches or conflicts could be extracted "
-        f"from the shaped context. Try increasing top_k genus chunks or ensuring parsed_fields are being "
-        f"passed into retrieve_rag_context so the shaper can compute KEY MATCHES and CONFLICTS."
-    )
-
-
-def _trim_context(ctx: str) -> str:
-    s = (ctx or "").strip()
-    if not s:
-        return ""
-    if len(s) <= _CONTEXT_CHAR_CAP:
-        return s
-    return s[:_CONTEXT_CHAR_CAP].rstrip() + "\n... (truncated)"
-
-
-# ------------------------------------------------------------
-# PUBLIC API
-# ------------------------------------------------------------
-
-def generate_genus_rag_explanation(
-    phenotype_text: str,
-    rag_context: str,
-    genus: str,
-    max_new_tokens: int = _DEFAULT_MAX_NEW_TOKENS,
-) -> str:
-    """
-    Generates a structured RAG output intended for direct display:

-        KEY TRAITS:
-        - ...
-        CONFLICTS:
-        - ...
-        CONCLUSION:
-        ...
-
-    Notes:
-    - KEY TRAITS + CONFLICTS are extracted deterministically from the (shaped) context.
-    - The LLM writes only the CONCLUSION.
-    - If the LLM output is garbage/echo, we use a deterministic conclusion fallback.
-    """
-    tokenizer, model = _get_model()
-
-    genus_clean = (genus or "").strip() or "Unknown"
-    context = _trim_context(rag_context or "")
-
-    if not context:
-        return (
-            "KEY TRAITS:\n"
-            "- Not specified.\n\n"
-            "CONFLICTS:\n"
-            "- Not specified.\n\n"
-            "CONCLUSION:\n"
-            "No reference evidence was available to evaluate this genus against the observed phenotype."
-        )
-
-    # Prefer structured extraction (KEY MATCHES / CONFLICTS sections)
-    key_traits, conflicts, saw_headers = _extract_key_traits_and_conflicts(context)
-
-    # If the headers weren't found or extraction is empty, try legacy extraction
-    if (not saw_headers) or (not key_traits and not conflicts):
-        legacy_matches, legacy_conflicts = _extract_matches_conflicts_legacy(context)
-        if legacy_matches or legacy_conflicts:
-            key_traits = key_traits or legacy_matches
-            conflicts = conflicts or legacy_conflicts
-
-    key_traits_text = _format_bullets(key_traits, none_text="- Not specified.")
-    conflicts_text = _format_bullets(conflicts, none_text="- Not specified.")
-
-    # LLM: conclusion only
-    prompt = RAG_PROMPT.format(
-        genus=genus_clean,
-        matches=key_traits_text,
-        conflicts=conflicts_text,
-    )
-
-    if RAG_GEN_LOG_INPUT:
-        _log_block("PROMPT (CONCLUSION-ONLY)", prompt[:3000] + ("\n...(truncated)" if len(prompt) > 3000 else ""))
-
-    inputs = tokenizer(
-        prompt,
-        return_tensors="pt",
-        truncation=True,
-        max_length=_MAX_INPUT_TOKENS,
-    ).to(model.device)
-
-    output = model.generate(
-        **inputs,
-        max_new_tokens=max_new_tokens,
-        temperature=0.0,
-        num_beams=1,
-        do_sample=False,
-        repetition_penalty=1.2,
-        no_repeat_ngram_size=3,
-    )
-
-    decoded = tokenizer.decode(output[0], skip_special_tokens=True).strip()
-    cleaned = _clean_generation(decoded)
-
-    if RAG_GEN_LOG_OUTPUT:
-        _log_block("RAW OUTPUT (CONCLUSION)", decoded)
-        _log_block("CLEANED OUTPUT (CONCLUSION)", cleaned)
-
-    # If LLM output is junk, use deterministic conclusion
-    if _looks_like_echo_or_garbage(cleaned):
-        cleaned = _deterministic_conclusion(genus_clean, key_traits, conflicts)
-        if RAG_GEN_LOG_OUTPUT:
-            _log_block("FALLBACK CONCLUSION (DETERMINISTIC)", cleaned)
-
-    # Final user-visible structured output
-    final = (
-        "KEY TRAITS:\n"
-        f"{key_traits_text}\n\n"
-        "CONFLICTS:\n"
-        f"{conflicts_text}\n\n"
-        "CONCLUSION:\n"
-        f"{cleaned}"
-    )
-
+# rag/rag_generator.py
+# ============================================================
+# RAG generator using facebook/bart-large (CPU-friendly)
+#
+# Evolution (confidence-aware, deterministic-first):
+#
+# - Confidence and escalation decisions are now made upstream in the
+#   retriever context shaper (CONFIDENCE ASSESSMENT block).
+#
+# - The generator now:
+#     1) Extracts:
+#          KEY TRAITS
+#          CONFLICTS
+#          CONFIDENCE STATE
+#          RECOMMENDED ACTION
+#     2) Selects a deterministic conclusion template
+#     3) Optionally allows the LLM to paraphrase the conclusion
+#        while preserving the confidence state meaning
+#
+# - The LLM is *not allowed* to decide strength — it only
+#   rewrites the conclusion within the declared state.
+#
+# - If the LLM returns junk → deterministic template is used.
+#
+# ============================================================
+
+from __future__ import annotations
+
+import os
+import re
+import torch
+from transformers import BartForConditionalGeneration, BartTokenizer
+
+
+# ------------------------------------------------------------
+# MODEL CONFIG
+# ------------------------------------------------------------
+
+MODEL_NAME = "facebook/bart-large"
+
+_tokenizer: BartTokenizer | None = None
+_model: BartForConditionalGeneration | None = None
+
+_MAX_INPUT_TOKENS = 1020
+_DEFAULT_MAX_NEW_TOKENS = 256
+
+_CONTEXT_CHAR_CAP = 2400
+
+
+def _get_model() -> tuple[BartTokenizer, BartForConditionalGeneration]:
+    global _tokenizer, _model
+    if _tokenizer is None or _model is None:
+        _tokenizer = BartTokenizer.from_pretrained(MODEL_NAME)
+        _model = BartForConditionalGeneration.from_pretrained(
+            MODEL_NAME,
+            device_map="auto",
+            torch_dtype=torch.float32,
+        )
+    return _tokenizer, _model
+
+
+# ------------------------------------------------------------
+# DEBUG LOGGING (HF Space logs)
+# ------------------------------------------------------------
+
+RAG_GEN_LOG_INPUT = os.getenv("BACTAI_RAG_GEN_LOG_INPUT", "0").strip() == "1"
+RAG_GEN_LOG_OUTPUT = os.getenv("BACTAI_RAG_GEN_LOG_OUTPUT", "0").strip() == "1"
+
+
+def _log_block(title: str, body: str) -> None:
+    print("=" * 80)
+    print(f"RAG GENERATOR DEBUG — {title}")
+    print("=" * 80)
+    print(body.strip() if body else "")
+    print()
+
+
+# ------------------------------------------------------------
+# PROMPT (LLM MAY PARAPHRASE — BUT NOT CHANGE CONFIDENCE)
+# ------------------------------------------------------------
+
+RAG_PROMPT = """Rewrite the conclusion for the target genus.
+Do not change the confidence level. Do not introduce new reasoning.
+Write 2–4 short sentences in clear, natural language.
+
+Target genus: {genus}
+Confidence state: {confidence_state}
+Recommended action: {recommended_action}
+
+Key traits that match:
+{matches}
+
+Conflicts:
+{conflicts}
+
+Rewrite the conclusion using the same meaning as the confidence state
+and recommended action. Mention both matches and conflicts when present.
+"""
+
+
+# ------------------------------------------------------------
+# OUTPUT CLEANUP / GARBAGE DETECTION
+# ------------------------------------------------------------
+
+_BAD_SUBSTRINGS = (
+    "summarize:",
+    "target genus",
+    "confidence state",
+    "write a concise conclusion",
+    "conflicts:",
+)
+
+def _clean_generation(text: str) -> str:
+    s = (text or "").strip()
+    if not s:
+        return ""
+
+    s = re.sub(r"\s*\n+\s*", " ", s).strip()
+    s = re.sub(r"\s{2,}", " ", s).strip()
+
+    if len(s) > 900:
+        s = s[:900].rstrip() + "..."
+
+    return s
+
+
+def _looks_like_echo_or_garbage(text: str) -> bool:
+    s = (text or "").strip()
+    if not s:
+        return True
+    if len(s) < 25:
+        return True
+
+    low = s.lower()
+    if any(bad in low for bad in _BAD_SUBSTRINGS):
+        return True
+
+    if "." not in s and "match" not in low and "conflict" not in low:
+        return True
+
+    return False
+
+
+# ------------------------------------------------------------
+# EXTRACT KEY TRAITS + CONFLICTS
+# ------------------------------------------------------------
+
+_KEY_MATCHES_HEADER_RE = re.compile(r"^\s*KEY MATCHES\s*:\s*$", re.IGNORECASE)
+_CONFLICTS_HEADER_RE = re.compile(r"^\s*CONFLICTS\b.*:\s*$", re.IGNORECASE)
+_CONFLICTS_INLINE_NONE_RE = re.compile(
+    r"^\s*CONFLICTS\s*:\s*not specified\.?\s*$",
+    re.IGNORECASE,
+)
+
+_MATCH_LINE_RE = re.compile(
+    r"^\s*-\s*([^:]+)\s*:\s*(.+?)\s*\(matches reference:\s*(.+?)\)\s*$",
+    re.IGNORECASE,
+)
+
+_CONFLICT_LINE_RE = re.compile(
+    r"^\s*-\s*([^:]+)\s*:\s*(.+?)\s*\(conflicts reference:\s*(.+?)\)\s*$",
+    re.IGNORECASE,
+)
+
+_GENERIC_BULLET_RE = re.compile(r"^\s*-\s*(.+?)\s*$")
+
+
+def _extract_key_traits_and_conflicts(
+    shaped_ctx: str,
+) -> tuple[list[str], list[str], bool]:
+
+    key_traits: list[str] = []
+    conflicts: list[str] = []
+
+    lines = (shaped_ctx or "").splitlines()
+    if not lines:
+        return key_traits, conflicts, False
+
+    in_matches = False
+    in_conflicts = False
+    saw_headers = False
+
+    for raw in lines:
+        line = raw.rstrip("\n")
+
+        if _KEY_MATCHES_HEADER_RE.match(line.strip()):
+            in_matches, in_conflicts = True, False
+            saw_headers = True
+            continue
+
+        if _CONFLICTS_INLINE_NONE_RE.match(line.strip()):
+            in_matches = in_conflicts = False
+            saw_headers = True
+            continue
+
+        if _CONFLICTS_HEADER_RE.match(line.strip()):
+            in_matches, in_conflicts = False, True
+            saw_headers = True
+            continue
+
+        if saw_headers and (line.strip().endswith(":") and not line.strip().startswith("-")):
+            if not _KEY_MATCHES_HEADER_RE.match(line.strip()) and not _CONFLICTS_HEADER_RE.match(line.strip()):
+                in_matches = in_conflicts = False
+
+        if in_matches and line.strip().startswith("-"):
+            m = _MATCH_LINE_RE.match(line.strip())
+            if m:
+                key_traits.append(f"{m.group(1).strip()}: {m.group(2).strip()}")
+            else:
+                g = _GENERIC_BULLET_RE.match(line.strip())
+                if g:
+                    key_traits.append(g.group(1).strip())
+            continue
+
+        if in_conflicts and line.strip().startswith("-"):
+            c = _CONFLICT_LINE_RE.match(line.strip())
+            if c:
+                conflicts.append(f"{c.group(1).strip()}: {c.group(2).strip()}")
+            else:
+                g = _GENERIC_BULLET_RE.match(line.strip())
+                if g:
+                    conflicts.append(g.group(1).strip())
+            continue
+
+    return key_traits, conflicts, saw_headers
+
+
+def _extract_matches_conflicts_legacy(shaped_ctx: str) -> tuple[list[str], list[str]]:
+    matches: list[str] = []
+    conflicts: list[str] = []
+
+    for raw in (shaped_ctx or "").splitlines():
+        line = raw.strip()
+        if not line.startswith("-"):
+            continue
+
+        m = _MATCH_LINE_RE.match(line)
+        if m:
+            matches.append(f"{m.group(1).strip()}: {m.group(2).strip()}")
+            continue
+
+        c = _CONFLICT_LINE_RE.match(line)
+        if c:
+            conflicts.append(f"{c.group(1).strip()}: {c.group(2).strip()}")
+            continue
+
+    return matches, conflicts
+
+
+def _format_bullets(items: list[str], *, none_text: str) -> str:
+    if not items:
+        return none_text
+    return "\n".join(f"- {x}" for x in items)
+
+
+# ------------------------------------------------------------
+# CONFIDENCE STATE EXTRACTOR
+# ------------------------------------------------------------
+
+_CONF_STATE_RE = re.compile(r"^\s*-\s*Confidence State:\s*(.+?)\s*$", re.IGNORECASE)
+_RECOMMENDED_ACTION_RE = re.compile(r"^\s*-\s*Recommended Action:\s*(.+?)\s*$", re.IGNORECASE)
+
+def _extract_confidence_state(shaped_ctx: str) -> tuple[str | None, str | None]:
+    confidence_state = None
+    recommended_action = None
+
+    for raw in (shaped_ctx or "").splitlines():
+        line = raw.strip()
+
+        m = _CONF_STATE_RE.match(line)
+        if m:
+            confidence_state = m.group(1).strip()
+            continue
+
+        r = _RECOMMENDED_ACTION_RE.match(line)
+        if r:
+            recommended_action = r.group(1).strip()
+            continue
+
+    return confidence_state, recommended_action
+
+
+# ------------------------------------------------------------
+# DETERMINISTIC TEMPLATES
+# ------------------------------------------------------------
+
+def _template_conclusion(
+    genus: str,
+    confidence_state: str | None,
+    key_traits: list[str],
+    conflicts: list[str],
+    recommended_action: str | None,
+) -> str:
+
+    g = (genus or "").strip() or "Unknown"
+    rec = recommended_action or ""
+
+    m_short = ", ".join(key_traits[:3]) if key_traits else "no clearly supportive traits"
+    c_short = ", ".join(conflicts[:3]) if conflicts else None
+
+    if confidence_state is None:
+        return _deterministic_conclusion(g, key_traits, conflicts)
+
+    cs = confidence_state.lower()
+
+    if "strong" in cs:
+        return (
+            f"This phenotype is indicative of {g} with no conflicting traits observed. "
+            f"It aligns well with key genus characteristics such as {m_short}. "
+            f"{rec}".strip()
+        )
+
+    if "probable" in cs or "conflicts present" in cs or "cautious" in cs:
+        if c_short:
+            return (
+                f"The phenotype is consistent with {g} based on supporting traits including {m_short}; "
+                f"however, conflicting results are present ({c_short}), which reduces confidence. "
+                f"{rec}".strip()
+            )
+        else:
+            return (
+                f"The phenotype is broadly consistent with {g}, but limited conflicting information "
+                f"reduces certainty. {rec}".strip()
+            )
+
+    if "inconclusive" in cs or "conflicting" in cs:
+        return (
+            f"The top genus match is {g}; however, the phenotype is inconclusive due to conflicting "
+            f"test results ({c_short or 'multiple conflicting traits'}). {rec}".strip()
+        )
+
+    if "weak" in cs:
+        return (
+            f"The available phenotype provides weak evidence for {g}. "
+            f"Additional testing or phenotype data is recommended. {rec}".strip()
+        )
+
+    return _deterministic_conclusion(g, key_traits, conflicts)
+
+
+# ------------------------------------------------------------
+# BACKSTOP DETERMINISTIC CONCLUSION
+# ------------------------------------------------------------
+
+def _deterministic_conclusion(genus: str, key_traits: list[str], conflicts: list[str]) -> str:
+    g = (genus or "").strip() or "Unknown"
+    m = key_traits[:4]
+    c = conflicts[:2]
+
+    if m and c:
+        return (
+            f"This is a probable match to {g} because it aligns with key traits such as "
+            f"{', '.join(m)}. However, there are conflicts ({', '.join(c)}), so treat this "
+            f"as a moderate/tentative genus-level fit and consider re-checking the conflicting tests."
+        )
+    if m and not c:
+        return (
+            f"This phenotype is consistent with {g} based on key matching traits such as "
+            f"{', '.join(m)}. No major conflicts were detected against the core genus traits, "
+            f"supporting a strong genus-level match."
+        )
+    if (not m) and c:
+        return (
+            f"This phenotype does not cleanly fit {g} because it conflicts with core traits "
+            f"({', '.join(c)}). Consider re-checking those tests or comparing against the next-ranked genera."
+        )
+
+    return (
+        f"Reference evidence was available for {g}, but no clear matches or conflicts could be extracted "
+        f"from the shaped context."
+    )
+
+
+def _trim_context(ctx: str) -> str:
+    s = (ctx or "").strip()
+    if not s:
+        return ""
+    if len(s) <= _CONTEXT_CHAR_CAP:
+        return s
+    return s[:_CONTEXT_CHAR_CAP].rstrip() + "\n... (truncated)"
+
+
+# ------------------------------------------------------------
+# PUBLIC API
+# ------------------------------------------------------------
+
+def generate_genus_rag_explanation(
+    phenotype_text: str,
+    rag_context: str,
+    genus: str,
+    max_new_tokens: int = _DEFAULT_MAX_NEW_TOKENS,
+) -> str:
+
+    tokenizer, model = _get_model()
+
+    genus_clean = (genus or "").strip() or "Unknown"
+    context = _trim_context(rag_context or "")
+
+    if not context:
+        return (
+            "KEY TRAITS:\n"
+            "- Not specified.\n\n"
+            "CONFLICTS:\n"
+            "- Not specified.\n\n"
+            "CONCLUSION:\n"
+            "No reference evidence was available to evaluate this genus against the observed phenotype."
+        )
+
+    key_traits, conflicts, saw_headers = _extract_key_traits_and_conflicts(context)
+
+    if (not saw_headers) or (not key_traits and not conflicts):
+        legacy_matches, legacy_conflicts = _extract_matches_conflicts_legacy(context)
+        if legacy_matches or legacy_conflicts:
+            key_traits = key_traits or legacy_matches
+            conflicts = conflicts or legacy_conflicts
+
+    key_traits_text = _format_bullets(key_traits, none_text="- Not specified.")
+    conflicts_text = _format_bullets(conflicts, none_text="- Not specified.")
+
+    confidence_state, recommended_action = _extract_confidence_state(context)
+
+    template_conclusion = _template_conclusion(
+        genus_clean,
+        confidence_state,
+        key_traits,
+        conflicts,
+        recommended_action,
+    )
+
+    if confidence_state is None:
+        final_conclusion = template_conclusion
+
+    else:
+        prompt = RAG_PROMPT.format(
+            genus=genus_clean,
+            confidence_state=confidence_state,
+            recommended_action=recommended_action or "None",
+            matches=key_traits_text,
+            conflicts=conflicts_text,
+        )
+
+        if RAG_GEN_LOG_INPUT:
+            _log_block("PROMPT (CONFIDENCE-AWARE)", prompt)
+
+        inputs = tokenizer(
+            prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=_MAX_INPUT_TOKENS,
+        ).to(model.device)
+
+        output = model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            temperature=0.0,
+            num_beams=3,  # BART benefits from small beam search
+            do_sample=False,
+            repetition_penalty=1.2,
+            no_repeat_ngram_size=3,
+        )
+
+        decoded = tokenizer.decode(output[0], skip_special_tokens=True).strip()
+        cleaned = _clean_generation(decoded)
+
+        if RAG_GEN_LOG_OUTPUT:
+            _log_block("RAW OUTPUT (CONCLUSION)", decoded)
+            _log_block("CLEANED OUTPUT", cleaned)
+
+        if _looks_like_echo_or_garbage(cleaned):
+            final_conclusion = template_conclusion
+            if RAG_GEN_LOG_OUTPUT:
+                _log_block("FALLBACK (TEMPLATE)", final_conclusion)
+        else:
+            final_conclusion = cleaned
+
+    final = (
+        "KEY TRAITS:\n"
+        f"{key_traits_text}\n\n"
+        "CONFLICTS:\n"
+        f"{conflicts_text}\n\n"
+        "CONCLUSION:\n"
+        f"{final_conclusion}"
+    )
+
+    return final
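
A minimal usage sketch for the new confidence-aware path. This is not part of the commit: the shaped context below is hand-written, assuming the KEY MATCHES / CONFLICTS bullet format and the "- Confidence State:" / "- Recommended Action:" lines that this file's regexes (_MATCH_LINE_RE, _CONFLICT_LINE_RE, _CONF_STATE_RE, _RECOMMENDED_ACTION_RE) accept; in the real pipeline the context comes from the retriever's shaper, and the first call downloads facebook/bart-large.

# Sketch only, not part of this commit. The context string is a hand-built
# stand-in for shaper output; trait names and values are illustrative.
from rag.rag_generator import generate_genus_rag_explanation

shaped_ctx = """KEY MATCHES:
- Gram Stain: Positive (matches reference: Gram-positive cocci in clusters)
- Catalase: Positive (matches reference: catalase-positive)

CONFLICTS (observed vs CORE traits):
- Oxidase: Positive (conflicts reference: oxidase-negative)

CONFIDENCE ASSESSMENT:
- Confidence State: Probable (conflicts present)
- Recommended Action: Re-check the oxidase result before reporting.
"""

print(
    generate_genus_rag_explanation(
        phenotype_text="Gram-positive cocci, catalase positive, oxidase positive",
        rag_context=shaped_ctx,
        genus="Staphylococcus",
    )
)

With a context like this, _extract_confidence_state picks up "Probable (conflicts present)" and the recommended action, _template_conclusion selects the "probable" template, and BART is asked only to rewrite the conclusion within that declared state; if the paraphrase trips _looks_like_echo_or_garbage, the deterministic template is emitted under CONCLUSION instead.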