from __future__ import annotations

import time
from dataclasses import dataclass, field
from typing import List, Optional, Protocol

from PIL import Image

from .types import (
    GroundedEvidence,
    PromptLog,
    ReasoningStep,
    StageTiming,
    evidences_to_serializable,
    prompt_logs_to_serializable,
    stage_timings_to_serializable,
    steps_to_serializable,
)


class SupportsQwenClient(Protocol):
    """Structural interface that ``CoRGIPipeline`` requires of a Qwen3-VL client.

    Any object exposing these methods and log attributes can drive the
    pipeline; no inheritance from this class is needed.
    """

    def structured_reasoning(
        self,
        image: Image.Image,
        question: str,
        max_steps: int,
    ) -> List[ReasoningStep]:
        """Return the reasoning steps for ``question`` about ``image``."""
        ...

    def extract_step_evidence(
        self,
        image: Image.Image,
        question: str,
        step: ReasoningStep,
        max_regions: int,
    ) -> List[GroundedEvidence]:
        """Return the grounded evidence gathered for a single ``step``."""
        ...

    def synthesize_answer(
        self,
        image: Image.Image,
        question: str,
        steps: List[ReasoningStep],
        evidences: List[GroundedEvidence],
    ) -> str:
        """Return the final answer given the steps and collected evidence."""
        ...

    def reset_logs(self) -> None:
        """Reset the client's logs; the pipeline calls this first in a run."""
        ...

    # Log attributes the pipeline copies into its result after a run.
    reasoning_log: Optional[PromptLog]
    grounding_logs: List[PromptLog]
    answer_log: Optional[PromptLog]


@dataclass(frozen=True)
class PipelineResult:
    """Immutable bundle of everything a single CoRGI pipeline run produced."""

    question: str
    steps: List[ReasoningStep]
    evidence: List[GroundedEvidence]
    answer: str
    reasoning_log: Optional[PromptLog] = None
    grounding_logs: List[PromptLog] = field(default_factory=list)
    answer_log: Optional[PromptLog] = None
    timings: List[StageTiming] = field(default_factory=list)
    total_duration_ms: float = 0.0

    def to_json(self) -> dict:
        """Build a dict representation of this result.

        ``question``, ``steps``, ``evidence``, ``answer``,
        ``total_duration_ms``, ``grounding_logs`` and ``timings`` are always
        present. ``reasoning_log`` and ``answer_log`` are added only when the
        corresponding attribute is truthy and its serializer yields an entry.
        """
        data = {
            "question": self.question,
            "steps": steps_to_serializable(self.steps),
            "evidence": evidences_to_serializable(self.evidence),
            "answer": self.answer,
            "total_duration_ms": self.total_duration_ms,
        }

        # The single-log fields go through the list serializer, so wrap the
        # log in a one-element list and take the first entry back out.
        if self.reasoning_log:
            serialized_reasoning = prompt_logs_to_serializable([self.reasoning_log])
            if serialized_reasoning:
                data["reasoning_log"] = serialized_reasoning[0]

        data["grounding_logs"] = prompt_logs_to_serializable(self.grounding_logs)
        data["timings"] = stage_timings_to_serializable(self.timings)

        if self.answer_log:
            serialized_answer = prompt_logs_to_serializable([self.answer_log])
            if serialized_answer:
                data["answer_log"] = serialized_answer[0]

        return data


class CoRGIPipeline:
    """Runs reasoning, per-step grounding and answer synthesis on one client."""

    def __init__(self, vlm_client: SupportsQwenClient):
        """Store the client used for every stage.

        Raises:
            ValueError: if ``vlm_client`` is ``None``.
        """
        if vlm_client is None:
            raise ValueError("A Qwen3-VL client instance must be provided.")
        self._vlm = vlm_client

    def run(
        self,
        image: Image.Image,
        question: str,
        max_steps: int = 3,
        max_regions: int = 3,
    ) -> PipelineResult:
        """Execute the full pipeline for one image/question pair.

        Args:
            image: Image forwarded to every client call.
            question: Question forwarded to every client call.
            max_steps: Passed to the client's ``structured_reasoning``.
            max_regions: Passed to ``extract_step_evidence``; also the
                maximum number of evidence items kept per step.

        Returns:
            A ``PipelineResult`` holding the steps, the collected evidence,
            the answer, the client's prompt logs and one ``StageTiming`` per
            stage plus a final ``total_pipeline`` entry.
        """
        client = self._vlm
        client.reset_logs()

        stage_timings: List[StageTiming] = []

        def elapsed_ms(started_at: float) -> float:
            # Monotonic clock difference, converted from seconds to ms.
            return (time.monotonic() - started_at) * 1000.0

        pipeline_started = time.monotonic()

        # Stage 1: structured reasoning.
        stage_started = time.monotonic()
        steps = client.structured_reasoning(
            image=image,
            question=question,
            max_steps=max_steps,
        )
        reasoning_ms = elapsed_ms(stage_started)
        stage_timings.append(
            StageTiming(name="structured_reasoning", duration_ms=reasoning_ms)
        )

        # Stage 2: grounding, only for the steps flagged as needing vision.
        collected: List[GroundedEvidence] = []
        for step in (candidate for candidate in steps if candidate.needs_vision):
            label = f"roi_step_{step.index}"
            stage_started = time.monotonic()
            found = client.extract_step_evidence(
                image=image,
                question=question,
                step=step,
                max_regions=max_regions,
            )
            grounding_ms = elapsed_ms(stage_started)
            stage_timings.append(
                StageTiming(
                    name=label,
                    duration_ms=grounding_ms,
                    step_index=step.index,
                )
            )
            if found:
                # Keep at most ``max_regions`` items even if more came back.
                collected.extend(found[:max_regions])

        # Stage 3: answer synthesis over all steps and the kept evidence.
        stage_started = time.monotonic()
        answer = client.synthesize_answer(
            image=image,
            question=question,
            steps=steps,
            evidences=collected,
        )
        answer_ms = elapsed_ms(stage_started)
        stage_timings.append(
            StageTiming(name="answer_synthesis", duration_ms=answer_ms)
        )

        total_ms = elapsed_ms(pipeline_started)
        stage_timings.append(
            StageTiming(name="total_pipeline", duration_ms=total_ms)
        )

        return PipelineResult(
            question=question,
            steps=steps,
            evidence=collected,
            answer=answer,
            reasoning_log=client.reasoning_log,
            # Copy the list so the result does not alias the client's own.
            grounding_logs=list(client.grounding_logs),
            answer_log=client.answer_log,
            timings=stage_timings,
            total_duration_ms=total_ms,
        )


__all__ = ["CoRGIPipeline", "PipelineResult"]