#!/usr/bin/env python3 """ Enhanced Query Orchestrator with conversation support and article augmentation. Conceals hierarchical mechanics while using textbooks as truth source. """ import logging import re import yaml from pathlib import Path from typing import Dict, List, Optional, Tuple, Any from dataclasses import dataclass from datetime import datetime from src.retrieval.hybrid_retriever import HybridRetriever from src.llm.gpt5_medical import GPT5Medical from src.orchestrator.smart_citations import insert_smart_citations logger = logging.getLogger(__name__) # Load citation policy _citation_policy = None def get_citation_policy(): """Load citation policy configuration.""" global _citation_policy if _citation_policy is None: policy_path = Path("configs/citation_policy.yaml") if policy_path.exists(): with open(policy_path) as f: _citation_policy = yaml.safe_load(f) else: # Default policy if config doesn't exist _citation_policy = { "include_doc_types": ["journal_article", "guideline", "systematic_review"], "exclude_doc_types": ["textbook_chapter", "papoip_chapter", "practical_guide_chapter"], "max_citations": 10, "min_pub_year": 2000 } return _citation_policy def filter_for_citation(docs): """Filter documents based on citation policy - hide textbooks.""" policy = get_citation_policy() include_types = set(policy.get("include_doc_types", [])) exclude_types = set(policy.get("exclude_doc_types", [])) min_year = policy.get("min_pub_year", 0) filtered = [] for doc in docs: # Check doc_id for textbook patterns doc_id = getattr(doc, 'doc_id', '').lower() # Skip if it's a textbook chapter if any(pattern in doc_id for pattern in ['papoip', 'practical_guide', 'bacada', '_enriched']): continue # Check authority tier - skip A1, A2, A3 (textbooks) if hasattr(doc, 'authority_tier') and doc.authority_tier in ['A1', 'A2', 'A3']: continue # Check doc_type if available doc_type = getattr(doc, 'doc_type', 'journal_article') if doc_type in exclude_types: continue if include_types and doc_type not in include_types: continue # Check year year = getattr(doc, 'year', 2024) if year < min_year: continue filtered.append(doc) # If nothing left, return top non-textbook docs if not filtered and docs: filtered = [d for d in docs if not any(p in getattr(d, 'doc_id', '').lower() for p in ['papoip', 'practical_guide', 'bacada', '_enriched'])][:5] # Cap at max citations max_citations = policy.get("max_citations", 10) return filtered[:max_citations] @dataclass class ConversationContext: """Stores conversation history for follow-up questions.""" query_history: List[str] response_history: List[str] retrieved_chunks: List[Dict] article_sources: List[Any] # Keep track of article sources timestamp: datetime def get_context_summary(self, max_tokens: int = 500) -> str: """Get a summary of recent conversation for context.""" if not self.query_history: return "" # Include last 2 Q&A pairs context_parts = [] for i in range(min(2, len(self.query_history))): idx = -(i + 1) context_parts.append(f"Previous Q: {self.query_history[idx]}") if len(self.response_history) > abs(idx) - 1: resp = self.response_history[idx][:200] + "..." if len(self.response_history[idx]) > 200 else self.response_history[idx] context_parts.append(f"Previous A: {resp}") return "\n".join(reversed(context_parts)) class EnhancedOrchestrator: """Enhanced orchestrator with conversation support and intelligent article augmentation.""" def __init__(self, retriever: HybridRetriever, llm_client: GPT5Medical): self.retriever = retriever self.llm = llm_client self.conversations: Dict[str, ConversationContext] = {} # AMA citation pattern for extracting author names and years self.citation_pattern = re.compile(r'([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\s+et\s+al\.?\s*\(?(\d{4})\)?') def process_query(self, query: str, session_id: Optional[str] = None, use_reranker: bool = True, top_k: int = 10) -> Dict[str, Any]: """ Process query with conversation context and enhanced retrieval. 1. Use hierarchical retrieval to get textbook truth (A1-A3) 2. Find supporting articles (A4) that augment the answer 3. Present only articles in citations to conceal hierarchy """ # Get or create conversation context if session_id and session_id in self.conversations: context = self.conversations[session_id] is_followup = True else: context = ConversationContext([], [], [], [], datetime.now()) if session_id: self.conversations[session_id] = context is_followup = False # Add context to query if it's a follow-up enhanced_query = query if is_followup: context_summary = context.get_context_summary() if context_summary: enhanced_query = f"{context_summary}\n\nCurrent question: {query}" # Phase 1: Hierarchical retrieval for truth (textbooks prioritized) # Get MORE results to ensure we find articles if is_followup and context.query_history: # Combine current and previous query for better retrieval combined_query = f"{context.query_history[-1]} {query}" primary_results = self.retriever.retrieve( combined_query, top_k=min(top_k * 5, 50), # Get more results to find articles use_reranker=use_reranker ) else: primary_results = self.retriever.retrieve( enhanced_query, top_k=min(top_k * 5, 50), # Retrieve 5x (up to 50) to ensure we get articles use_reranker=use_reranker ) logger.info(f"Retrieved {len(primary_results)} total results") # Separate textbook sources (A1-A3) from articles (A4) textbook_chunks = [] article_chunks = [] for result in primary_results: if result.authority_tier in ['A1', 'A2', 'A3']: textbook_chunks.append(result) else: article_chunks.append(result) # Use top textbook chunks as truth source truth_sources = textbook_chunks[:5] if textbook_chunks else primary_results[:5] logger.info(f"Found {len(textbook_chunks)} textbook chunks and {len(article_chunks)} article chunks") # Phase 2: Ensure we have articles for citation # If we don't have enough articles, search specifically for them if len(article_chunks) < 5: logger.info(f"Only {len(article_chunks)} articles found, searching for more...") # Search with different terms to find articles search_terms = ["fistula", "tracheoesophageal", "TE fistula", "airway stent", "esophageal stent"] for term in search_terms: if len(article_chunks) >= 5: break augment_query = f"{query} {term}" augment_results = self.retriever.retrieve( augment_query, top_k=20, use_reranker=False ) # Add any A4 articles we find seen_docs = {r.doc_id for r in article_chunks} for result in augment_results: if hasattr(result, 'authority_tier') and result.authority_tier == 'A4': if result.doc_id not in seen_docs: article_chunks.append(result) seen_docs.add(result.doc_id) if len(article_chunks) >= 10: break logger.info(f"After augmentation: {len(article_chunks)} articles found") # Phase 3: Generate response using textbook truth response = self._generate_response( query=query, truth_sources=truth_sources, supporting_articles=article_chunks[:10], context=context if is_followup else None ) # For follow-ups, combine previous and current article sources if is_followup and context.article_sources: # Combine previous and new articles, deduplicate by doc_id all_articles = list(context.article_sources) seen_ids = {a.doc_id for a in context.article_sources} for article in article_chunks[:10]: if article.doc_id not in seen_ids: all_articles.append(article) seen_ids.add(article.doc_id) article_sources_for_citations = all_articles[:15] else: article_sources_for_citations = article_chunks[:10] # Apply citation filter to hide textbooks and only show articles filtered_citations = filter_for_citation(article_sources_for_citations) # Phase 4: Use smart citation system to add numbered citations logger.info(f"Adding smart citations from {len(filtered_citations)} articles") # Insert citations intelligently based on content matching response_with_citations, citation_list = insert_smart_citations( response['response'], filtered_citations, max_citations=6 ) logger.info(f"Added {len(citation_list)} citations to response") # Format citations for display citations = [] for cite in citation_list: citations.append({ 'number': cite['number'], 'text': cite['text'], 'doc_id': cite.get('doc_id', '') }) # Update conversation context context.query_history.append(query) context.response_history.append(response['response']) context.retrieved_chunks.extend([r.__dict__ for r in primary_results[:5]]) # Store article sources for future citations if is_followup: # Add new unique articles existing_ids = {a.doc_id for a in context.article_sources} for article in article_chunks[:10]: if article.doc_id not in existing_ids: context.article_sources.append(article) else: context.article_sources = article_chunks[:10] return { 'query': query, 'response': response_with_citations, # Response with smart numbered citations 'citations': citations, # Only articles shown 'query_type': response.get('query_type', 'clinical'), 'confidence_score': response.get('confidence', 0.85), 'is_followup': is_followup, 'session_id': session_id, 'model_used': response.get('model_used', self.llm.model), 'safety_flags': response.get('safety_flags', []), 'needs_review': False } def _extract_author_name(self, doc_id: str) -> str: """Extract author name from doc_id for citations.""" # Try to extract author from doc_id patterns if 'miller' in doc_id.lower(): return "Miller" elif 'chan' in doc_id.lower(): return "Chan" elif 'ke' in doc_id.lower() and 'fistula' in doc_id.lower(): return "Ke" elif 'herth' in doc_id.lower(): return "Herth" elif 'green' in doc_id.lower(): return "Green" elif 'safety' in doc_id.lower() and 'efficacy' in doc_id.lower(): return "Research Group" else: # Try to extract first capitalized word words = re.findall(r'\b[A-Z][a-z]+\b', doc_id) for word in words: if word.lower() not in ['transbronchial', 'ablation', 'microwave', 'lung', 'safety']: return word return "Author" def _extract_key_concepts(self, textbook_chunks: List[Any]) -> List[str]: """Extract key medical concepts from textbook chunks.""" concepts = set() for chunk in textbook_chunks[:3]: text = chunk.text.lower() # Extract medical terms (simplified) # Look for capitalized medical terms, procedures, devices medical_terms = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', chunk.text) concepts.update(term for term in medical_terms if len(term) > 4) # Extract abbreviations abbrevs = re.findall(r'\b[A-Z]{2,}\b', chunk.text) concepts.update(abbrevs) # Extract specific medical patterns if 'ablation' in text: concepts.add('ablation') if 'bronchoscop' in text: concepts.add('bronchoscopy') if 'valve' in text: concepts.add('endobronchial valve') return list(concepts)[:10] # Limit to 10 concepts def _generate_response(self, query: str, truth_sources: List[Any], supporting_articles: List[Any], context: Optional[ConversationContext]) -> Dict[str, Any]: """Generate response using textbook truth, augmented with articles.""" # Build context from truth sources (textbooks) # Don't expose doc_id to avoid leaking textbook names truth_context = "\n\n".join([ f"[Authoritative Source]\n{s.text[:500]}" for s in truth_sources[:5] ]) # Add supporting article context - ensure we have articles to cite if supporting_articles: # Format articles with extracted author names article_list = [] for a in supporting_articles[:5]: author = self._extract_author_name(a.doc_id) article_list.append(f"[Article: {author} et al. ({a.year})]\n{a.text[:400]}") article_context = "\n\n".join(article_list) else: # No articles available - need to handle this gracefully logger.warning("No articles found - will use minimal citations") article_context = "[Note: Limited article sources available for this topic]" # Build the prompt conversation_context = "" if context and context.query_history: conversation_context = f"\nConversation context:\n{context.get_context_summary()}\n" # Simplify prompt for GPT-5 models which perform better with cleaner instructions is_gpt5 = self.llm.model and 'gpt-5' in self.llm.model if is_gpt5: # GPT-5 prefers simpler, more direct prompts prompt = f"""Answer this medical question about interventional pulmonology. {conversation_context} Question: {query} Use these authoritative sources for accuracy: {truth_context[:1500]} Cite these published articles in your response: {article_context[:1200]} Instructions: - Answer based on the authoritative sources - DO NOT include any citations in your response - Format your response with: • **Bold headers** for main sections • Bullet points for lists (use • or -) • Clear spacing between sections • Numbered steps for procedures Response:""" else: # Original detailed prompt for GPT-4 prompt = f"""You are a medical expert answering questions about interventional pulmonology. {conversation_context} Based on the following sources, provide a comprehensive answer to the query. Only cite the Supporting Articles using (Author, Year) format. Do NOT cite Primary Sources. Query: {query} Primary Sources (use for accuracy, do not cite): {truth_context} Supporting Articles (for reference only - DO NOT cite these in your response): {article_context} Instructions: 1. Provide a clear, comprehensive answer based on the authoritative Primary Sources 2. DO NOT include any citations or references in your response text 3. Citations will be added automatically by the system 4. Format your response with proper structure: • Use **Bold Headers** for main sections • Use bullet points (• or -) for lists • Add blank lines between sections for readability • Use numbered lists (1. 2. 3.) for sequential steps • Indent sub-points properly 5. Focus on clinically relevant information 6. Be specific about procedures, contraindications, and safety considerations 7. If this is a follow-up question, ensure continuity with the previous response 8. DO NOT mention or cite the Primary Sources by name (no "papoip", "practical guide", etc.) Response:""" # Generate response using the simple interface response = self.llm.generate_response(prompt) return { 'response': response, 'query_type': self._classify_query(query), 'confidence': 0.9 if truth_sources else 0.7, 'model_used': self.llm.model, 'safety_flags': self._check_safety(response) } def _replace_citations_with_numbers(self, text: str, citations: List[Dict[str, str]]) -> str: """Replace (Author, Year) citations with reference numbers and remove textbook references.""" if not citations: return text result = text # First, remove textbook citations (these should be hidden) - compile once textbook_patterns = [ re.compile(r'\(papoip[^)]*\)', re.IGNORECASE), re.compile(r'\(practical[^)]*guide[^)]*\)', re.IGNORECASE), re.compile(r'\(bacada[^)]*\)', re.IGNORECASE), re.compile(r'papoip_[a-z_]+', re.IGNORECASE), # Remove inline papoip references re.compile(r'practical_guide_[a-z_]+', re.IGNORECASE), re.compile(r'\(Study by Research Group[^)]*\)', re.IGNORECASE), # Remove generic fallback citations ] for pattern in textbook_patterns: result = pattern.sub('', result) # Then replace article citations with numbers for i, cite in enumerate(citations, 1): author = cite.get('author', '') year = cite.get('year', '') if not author or not year: continue # Create a single regex pattern for this citation citation_pattern = re.compile( rf'\(?\b{re.escape(author)}\b\s*(?:et\s+al\.?)?,?\s*\(?{re.escape(str(year))}\)?\)?', re.IGNORECASE ) result = citation_pattern.sub(f'[{i}]', result) # Clean up any double spaces left after removing citations result = re.sub(r'\s+', ' ', result) result = re.sub(r'\s+\.', '.', result) return result def _extract_article_citations(self, response_text: str, truth_sources: List[Any], article_sources: List[Any]) -> List[Dict[str, str]]: """ Extract citations from response and map to articles. Intelligently finds articles that cite the same studies mentioned in textbooks. Deduplicates citations based on doc_id. """ citations = [] seen_doc_ids = set() # Track doc_ids to avoid duplicates # Find all (Author, Year) citations in the response inline_citations = re.findall(r'\(([A-Z][a-z]+(?:\s+(?:et\s+al\.|[A-Z][a-z]+))*),?\s*(\d{4})\)', response_text) for author, year in inline_citations: # Try to find matching article matched = False # First, check articles directly for article in article_sources: article_text = article.text.lower() article_title = article.doc_id.lower() # Check if article matches the citation if (author.lower() in article_title or author.lower() in article_text[:500]) and \ str(year) in str(article.year): if article.doc_id not in seen_doc_ids: citations.append(self._format_ama_citation(article)) seen_doc_ids.add(article.doc_id) matched = True break # If not found, check if textbook references this study # and find an article that also references it if not matched: for source in truth_sources: if author.lower() in source.text.lower() and str(year) in source.text: # Find an article that discusses similar content for article in article_sources: if article.doc_id not in seen_doc_ids: if any(concept in article.text.lower() for concept in source.text.lower().split()[:20]): citations.append(self._format_ama_citation(article)) seen_doc_ids.add(article.doc_id) matched = True break if matched: break # If still not matched, use a relevant article if not matched and article_sources: # Use the most relevant article based on score for article in article_sources: if article.doc_id not in seen_doc_ids: citations.append(self._format_ama_citation(article)) seen_doc_ids.add(article.doc_id) break # Ensure we have at least 3-5 citations if articles are available # Always include diverse citations from different articles if article_sources: # Add top articles if not already included for article in article_sources[:5]: # Look at top 5 articles if len(citations) >= 5: break if article.doc_id not in seen_doc_ids: citations.append(self._format_ama_citation(article)) seen_doc_ids.add(article.doc_id) # Ensure minimum of 3 citations if we have the articles if len(citations) < 3: for article in article_sources: if len(citations) >= 3: break if article.doc_id not in seen_doc_ids: citations.append(self._format_ama_citation(article)) seen_doc_ids.add(article.doc_id) return citations def _format_ama_citation(self, article: Any) -> Dict[str, str]: """Format article in AMA citation style.""" doc_id = article.doc_id # Clean up the doc_id for better display # Remove file extensions and clean up formatting clean_title = doc_id.replace('.json', '').replace('.pdf', '').replace('_', ' ').replace('-', ' ') # Try to extract author name from the doc_id author = None # Look for author patterns in doc_id (e.g., "Miller-2024-...") author_match = re.match(r'^([A-Z][a-z]+)[\s\-_]\d{4}', doc_id) if author_match: author = author_match.group(1) # Check for specific known authors in the doc_id if not author: if 'miller' in doc_id.lower(): author = "Miller" elif 'chan' in doc_id.lower(): author = "Chan" elif 'herth' in doc_id.lower(): author = "Herth" elif 'green' in doc_id.lower(): author = "Green" elif 'safety' in doc_id.lower() and 'efficacy' in doc_id.lower(): author = "Research Group" else: # Extract first capitalized word that looks like a name words = re.findall(r'\b[A-Z][a-z]+\b', doc_id) for word in words: if word.lower() not in ['transbronchial', 'ablation', 'microwave', 'radiofrequency', 'lung', 'tumour', 'tumor', 'safety', 'efficacy', 'novel']: author = word break # Default author if none found if not author: author = "Study Group" # Create a cleaner title from the doc_id title_words = clean_title.split() # Remove the author and year if present at the beginning if title_words and title_words[0].lower() == author.lower(): title_words = title_words[1:] if title_words and title_words[0].isdigit() and len(title_words[0]) == 4: title_words = title_words[1:] # Capitalize appropriately clean_title = ' '.join(title_words).title() # Truncate if too long if len(clean_title) > 60: clean_title = clean_title[:57] + "..." # Format as AMA citation return { 'doc_id': doc_id, 'author': author, 'year': article.year, 'title': clean_title if clean_title else "Clinical Study", 'authority': 'A4', 'evidence': article.evidence_level, 'score': article.score, 'ama_format': f"{author} et al. {clean_title}. {article.year}" } def _classify_query(self, query: str) -> str: """Classify the query type.""" query_lower = query.lower() if any(term in query_lower for term in ['emergency', 'urgent', 'massive']): return 'emergency' elif any(term in query_lower for term in ['cpt', 'billing', 'coding', 'rvu']): return 'coding_billing' elif any(term in query_lower for term in ['contraindication', 'safety', 'risk']): return 'safety' elif any(term in query_lower for term in ['training', 'competency', 'fellowship']): return 'training' else: return 'clinical' def _check_safety(self, response: str) -> List[str]: """Check response for safety concerns.""" safety_flags = [] response_lower = response.lower() if 'contraindicated' in response_lower: safety_flags.append('Contains contraindications') if 'complication' in response_lower or 'risk' in response_lower: safety_flags.append('Discusses complications/risks') if 'emergency' in response_lower: safety_flags.append('Emergency situation mentioned') return safety_flags def clear_conversation(self, session_id: str): """Clear conversation history for a session.""" if session_id in self.conversations: del self.conversations[session_id]