dung-vpt-uney
Deploy latest CoRGI Gradio demo
9c4a163
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Optional, Protocol
import time
from PIL import Image
from .types import (
GroundedEvidence,
PromptLog,
ReasoningStep,
StageTiming,
evidences_to_serializable,
prompt_logs_to_serializable,
stage_timings_to_serializable,
steps_to_serializable,
)
class SupportsQwenClient(Protocol):
    """Structural interface a Qwen3-VL client must satisfy for ``CoRGIPipeline``.

    Any object exposing the three stage methods, ``reset_logs`` and the three
    log attributes below can drive the pipeline. The pipeline calls
    ``reset_logs`` before a run and reads the log attributes after the run.
    """

    # Prompt log of the structured-reasoning stage (may be absent).
    reasoning_log: Optional[PromptLog]
    # Prompt logs gathered while extracting grounded evidence.
    grounding_logs: List[PromptLog]
    # Prompt log of the answer-synthesis stage (may be absent).
    answer_log: Optional[PromptLog]

    def structured_reasoning(self, image: Image.Image, question: str, max_steps: int) -> List[ReasoningStep]:
        """Decompose ``question`` about ``image`` into at most ``max_steps`` reasoning steps."""

    def extract_step_evidence(
        self,
        image: Image.Image,
        question: str,
        step: ReasoningStep,
        max_regions: int,
    ) -> List[GroundedEvidence]:
        """Return grounded evidence (up to ``max_regions`` regions) supporting ``step``."""

    def synthesize_answer(
        self,
        image: Image.Image,
        question: str,
        steps: List[ReasoningStep],
        evidences: List[GroundedEvidence],
    ) -> str:
        """Produce the final answer text from the reasoning steps and collected evidence."""

    def reset_logs(self) -> None:
        """Clear any prompt logs recorded by a previous run."""
@dataclass(frozen=True)
class PipelineResult:
    """Aggregated output of the CoRGI pipeline.

    Attributes:
        question: The question that was asked.
        steps: Reasoning steps produced by the structured-reasoning stage.
        evidence: Grounded evidence collected across the vision steps.
        answer: The final synthesized answer.
        reasoning_log: Prompt log of the reasoning stage, or ``None``.
        grounding_logs: Prompt logs recorded during evidence grounding.
        answer_log: Prompt log of the answer-synthesis stage, or ``None``.
        timings: Per-stage timing records.
        total_duration_ms: Duration of the whole run in milliseconds.
    """

    question: str
    steps: List[ReasoningStep]
    evidence: List[GroundedEvidence]
    answer: str
    reasoning_log: Optional[PromptLog] = None
    grounding_logs: List[PromptLog] = field(default_factory=list)
    answer_log: Optional[PromptLog] = None
    timings: List[StageTiming] = field(default_factory=list)
    total_duration_ms: float = 0.0

    def to_json(self) -> dict:
        """Return a JSON-serializable ``dict`` describing this result.

        ``grounding_logs`` and ``timings`` are always emitted. ``reasoning_log``
        and ``answer_log`` are emitted only when the corresponding log is
        present and serializes to at least one entry.
        """
        payload = {
            "question": self.question,
            "steps": steps_to_serializable(self.steps),
            "evidence": evidences_to_serializable(self.evidence),
            "answer": self.answer,
            "total_duration_ms": self.total_duration_ms,
        }
        self._put_optional_log(payload, "reasoning_log", self.reasoning_log)
        payload["grounding_logs"] = prompt_logs_to_serializable(self.grounding_logs)
        payload["timings"] = stage_timings_to_serializable(self.timings)
        self._put_optional_log(payload, "answer_log", self.answer_log)
        return payload

    @staticmethod
    def _put_optional_log(payload: dict, key: str, log: Optional[PromptLog]) -> None:
        """Store the serialized form of ``log`` under ``key`` if ``log`` is present."""
        # Explicit ``is not None``: a present-but-falsy log object must still be emitted.
        if log is None:
            return
        # The serializer works on lists; wrap the single log and unwrap the result.
        entries = prompt_logs_to_serializable([log])
        if entries:
            payload[key] = entries[0]
class CoRGIPipeline:
    """Orchestrates the CoRGI reasoning pipeline using a Qwen3-VL client.

    A run consists of three individually timed stages:

    1. structured reasoning: the client splits the question into steps;
    2. grounding: for every step with ``needs_vision`` set, the client
       extracts evidence, of which at most ``max_regions`` items are kept;
    3. answer synthesis: the client answers using all steps and evidence.
    """

    def __init__(self, vlm_client: SupportsQwenClient):
        """Wrap ``vlm_client``.

        Raises:
            ValueError: If ``vlm_client`` is ``None``.
        """
        if vlm_client is None:
            raise ValueError("A Qwen3-VL client instance must be provided.")
        self._vlm = vlm_client

    @staticmethod
    def _elapsed_ms(start: float) -> float:
        """Return milliseconds elapsed since ``start`` (a ``time.perf_counter()`` reading)."""
        return (time.perf_counter() - start) * 1000.0

    def run(
        self,
        image: Image.Image,
        question: str,
        max_steps: int = 3,
        max_regions: int = 3,
    ) -> PipelineResult:
        """Run the full pipeline on ``image`` / ``question``.

        Args:
            image: Input image passed to every client stage.
            question: The question to answer.
            max_steps: Step budget forwarded to ``structured_reasoning``.
            max_regions: Per-step evidence cap, forwarded to the client and
                also enforced locally on the returned evidence.

        Returns:
            A ``PipelineResult`` with steps, evidence, answer, the client's
            prompt logs, and timings for every stage plus a final
            ``total_pipeline`` entry.

        Raises:
            ValueError: If ``max_steps`` or ``max_regions`` is negative.
        """
        # A negative cap would make ``step_evs[:max_regions]`` drop items from
        # the end instead of limiting them, so reject it before doing any work.
        if max_steps < 0:
            raise ValueError(f"max_steps must be non-negative, got {max_steps}")
        if max_regions < 0:
            raise ValueError(f"max_regions must be non-negative, got {max_regions}")

        self._vlm.reset_logs()
        timings: List[StageTiming] = []
        # perf_counter is the recommended clock for measuring short durations.
        total_start = time.perf_counter()

        reasoning_start = time.perf_counter()
        steps = self._vlm.structured_reasoning(image=image, question=question, max_steps=max_steps)
        timings.append(
            StageTiming(name="structured_reasoning", duration_ms=self._elapsed_ms(reasoning_start))
        )

        evidences: List[GroundedEvidence] = []
        for step in steps:
            if not step.needs_vision:
                continue
            grounding_start = time.perf_counter()
            step_evs = self._vlm.extract_step_evidence(
                image=image,
                question=question,
                step=step,
                max_regions=max_regions,
            )
            timings.append(
                StageTiming(
                    name=f"roi_step_{step.index}",
                    duration_ms=self._elapsed_ms(grounding_start),
                    step_index=step.index,
                )
            )
            # Guard against an empty or missing result; cap locally in case the
            # client returns more regions than requested.
            if step_evs:
                evidences.extend(step_evs[:max_regions])

        answer_start = time.perf_counter()
        answer = self._vlm.synthesize_answer(image=image, question=question, steps=steps, evidences=evidences)
        timings.append(StageTiming(name="answer_synthesis", duration_ms=self._elapsed_ms(answer_start)))

        total_duration = self._elapsed_ms(total_start)
        timings.append(StageTiming(name="total_pipeline", duration_ms=total_duration))

        return PipelineResult(
            question=question,
            steps=steps,
            evidence=evidences,
            answer=answer,
            reasoning_log=self._vlm.reasoning_log,
            # Copy so later client activity cannot mutate the frozen result's list.
            grounding_logs=list(self._vlm.grounding_logs),
            answer_log=self._vlm.answer_log,
            timings=timings,
            total_duration_ms=total_duration,
        )
# Public names exported by ``from <module> import *``: the orchestrator and its result type.
__all__ = ["CoRGIPipeline", "PipelineResult"]