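"""Qwen3-VL backed client for the CoRGI pipeline.

Wraps the Hugging Face ``transformers`` Qwen3-VL chat API for the three stages
implemented below: structured reasoning, per-step region grounding, and
evidence-grounded answer synthesis.
"""
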
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional
import logging

import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

logger = logging.getLogger(__name__)

try:
    import spaces  # type: ignore
except ImportError:  # pragma: no cover - only available on HF Spaces
    spaces = None  # type: ignore

from .parsers import parse_roi_evidence, parse_structured_reasoning
from .types import GroundedEvidence, PromptLog, ReasoningStep


DEFAULT_REASONING_PROMPT = (
    "You are a careful multimodal reasoner following the CoRGI protocol. "
    "Given the question and the image, produce a JSON array of reasoning steps. "
    "Each item must contain the keys: index (1-based integer), statement (concise sentence), "
    "needs_vision (boolean true if the statement requires visual verification), and reason "
    "(short phrase explaining why visual verification is or is not required). "
    "Limit the number of steps to {max_steps}. Respond with JSON only; start the reply with '[' and end with ']'. "
    "Do not add any commentary or prose outside of the JSON."
)

DEFAULT_GROUNDING_PROMPT = (
    "You are validating the following reasoning step:\n"
    "{step_statement}\n"
    "Return a JSON array with up to {max_regions} region candidates that help verify the step. "
    "Each object must include: step (integer), bbox (list of four numbers x1,y1,x2,y2, "
    "either normalized 0-1 or scaled 0-1000), description (short textual evidence), "
    "and confidence (0-1). Use [] if no relevant region exists. "
    "Respond with JSON only; do not include explanations outside the JSON array."
)

DEFAULT_ANSWER_PROMPT = (
    "You are finalizing the answer using verified evidence. "
    "Question: {question}\n"
    "Structured reasoning steps:\n"
    "{steps}\n"
    "Verified evidence items:\n"
    "{evidence}\n"
    "Respond with a concise final answer sentence grounded in the evidence. "
    "If unsure, say you are uncertain. Do not include <think> tags or internal monologue."
)


def _format_steps_for_prompt(steps: List[ReasoningStep]) -> str:
    return "\n".join(
        f"{step.index}. {step.statement} (needs vision: {step.needs_vision})"
        for step in steps
    )


def _format_evidence_for_prompt(evidences: List[GroundedEvidence]) -> str:
    if not evidences:
        return "No evidence collected."
    lines = []
    for ev in evidences:
        desc = ev.description or "No description"
        bbox = ", ".join(f"{coord:.2f}" for coord in ev.bbox)
        conf = f"{ev.confidence:.2f}" if ev.confidence is not None else "n/a"
        lines.append(f"Step {ev.step_index}: bbox=({bbox}), conf={conf}, desc={desc}")
    return "\n".join(lines)


def _strip_think_content(text: str) -> str:
    """Drop any ``<think>...</think>`` reasoning trace and return the visible response."""
    if not text:
        return ""
    cleaned = text
    if "</think>" in cleaned:
        cleaned = cleaned.split("</think>", 1)[-1]
    cleaned = cleaned.replace("<think>", "")
    return cleaned.strip()


_MODEL_CACHE: dict[str, AutoModelForImageTextToText] = {}
_PROCESSOR_CACHE: dict[str, AutoProcessor] = {}


def _gpu_decorator(duration: int = 120):
    """Return ``spaces.GPU(duration=...)`` on HF Spaces, otherwise a no-op decorator."""
    if spaces is None:
        return lambda fn: fn
    return spaces.GPU(duration=duration)


def _ensure_cuda(model: AutoModelForImageTextToText) -> AutoModelForImageTextToText:
    """Move the model to CUDA when a GPU is available and it is not already there."""
    if torch.cuda.is_available():
        target_device = torch.device("cuda")
        current_device = next(model.parameters()).device
        if current_device.type != target_device.type:
            model.to(target_device)
    return model


def _load_backend(model_id: str) -> tuple[AutoModelForImageTextToText, AutoProcessor]:
    """Load and cache the model/processor pair for ``model_id`` with a hardware-appropriate dtype."""
    if model_id not in _MODEL_CACHE:
        # Check if hardware supports bfloat16
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            torch_dtype = torch.bfloat16
            logger.info("Using bfloat16 (hardware supported)")
        elif torch.cuda.is_available():
            torch_dtype = torch.float16  # Fallback to float16 if bfloat16 not supported
            logger.info("Using float16 (bfloat16 not supported on this GPU)")
        else:
            torch_dtype = torch.float32
            logger.info("Using float32 (CPU mode)")
        
        # Use single GPU (cuda:0) instead of auto to avoid model sharding across multiple GPUs
        device_map = "cuda:0" if torch.cuda.is_available() else "cpu"
        model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            device_map=device_map,
        )
        model = model.eval()
        processor = AutoProcessor.from_pretrained(model_id)
        _MODEL_CACHE[model_id] = model
        _PROCESSOR_CACHE[model_id] = processor
    return _MODEL_CACHE[model_id], _PROCESSOR_CACHE[model_id]


@dataclass
class QwenGenerationConfig:
    model_id: str = "Qwen/Qwen3-VL-2B-Instruct"
    max_new_tokens: int = 512
    temperature: float | None = None
    do_sample: bool = False


class Qwen3VLClient:
    """Wrapper around transformers Qwen3-VL chat API for CoRGI pipeline."""

    def __init__(
        self,
        config: Optional[QwenGenerationConfig] = None,
    ) -> None:
        self.config = config or QwenGenerationConfig()
        self._model, self._processor = _load_backend(self.config.model_id)
        self.reset_logs()

    def reset_logs(self) -> None:
        self._reasoning_log: Optional[PromptLog] = None
        self._grounding_logs: List[PromptLog] = []
        self._answer_log: Optional[PromptLog] = None

    @property
    def reasoning_log(self) -> Optional[PromptLog]:
        return self._reasoning_log

    @property
    def grounding_logs(self) -> List[PromptLog]:
        return list(self._grounding_logs)

    @property
    def answer_log(self) -> Optional[PromptLog]:
        return self._answer_log

    @_gpu_decorator()
    def _chat(
        self,
        image: Image.Image,
        prompt: str,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Run a single-image chat turn and return the decoded model response text."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        chat_prompt = self._processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )
        inputs = self._processor(
            text=[chat_prompt],
            images=[image],
            return_tensors="pt",
        ).to(self._model.device)
        gen_kwargs = {
            "max_new_tokens": max_new_tokens or self.config.max_new_tokens,
            "do_sample": self.config.do_sample,
        }
        if self.config.do_sample and self.config.temperature is not None:
            gen_kwargs["temperature"] = self.config.temperature
        output_ids = self._model.generate(**inputs, **gen_kwargs)
        prompt_length = inputs.input_ids.shape[1]
        generated_tokens = output_ids[:, prompt_length:]
        response = self._processor.batch_decode(
            generated_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        return response.strip()

    def structured_reasoning(self, image: Image.Image, question: str, max_steps: int) -> List[ReasoningStep]:
        prompt = DEFAULT_REASONING_PROMPT.format(max_steps=max_steps) + f"\nQuestion: {question}"
        response = self._chat(image=image, prompt=prompt)
        self._reasoning_log = PromptLog(prompt=prompt, response=response, stage="reasoning")
        return parse_structured_reasoning(response, max_steps=max_steps)

    def extract_step_evidence(
        self,
        image: Image.Image,
        question: str,
        step: ReasoningStep,
        max_regions: int,
    ) -> List[GroundedEvidence]:
        prompt = DEFAULT_GROUNDING_PROMPT.format(
            step_statement=step.statement,
            max_regions=max_regions,
        )
        response = self._chat(image=image, prompt=prompt, max_new_tokens=256)
        evidences = parse_roi_evidence(response, default_step_index=step.index)
        self._grounding_logs.append(
            PromptLog(prompt=prompt, response=response, step_index=step.index, stage="grounding")
        )
        return evidences[:max_regions]

    def synthesize_answer(
        self,
        image: Image.Image,
        question: str,
        steps: List[ReasoningStep],
        evidences: List[GroundedEvidence],
    ) -> str:
        prompt = DEFAULT_ANSWER_PROMPT.format(
            question=question,
            steps=_format_steps_for_prompt(steps),
            evidence=_format_evidence_for_prompt(evidences),
        )
        response = self._chat(image=image, prompt=prompt, max_new_tokens=256)
        self._answer_log = PromptLog(prompt=prompt, response=response, stage="synthesis")
        return _strip_think_content(response)


__all__ = ["Qwen3VLClient", "QwenGenerationConfig"]
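
# Illustrative usage sketch (comments only, not executed). How the surrounding CoRGI
# pipeline actually drives this client is an assumption; the image path and question
# below are placeholders.
#
#   from PIL import Image
#
#   client = Qwen3VLClient(QwenGenerationConfig(max_new_tokens=512))
#   image = Image.open("example.jpg")        # placeholder image
#   question = "What is on the table?"       # placeholder question
#   steps = client.structured_reasoning(image, question, max_steps=4)
#   evidences = [
#       ev
#       for step in steps
#       if step.needs_vision
#       for ev in client.extract_step_evidence(image, question, step, max_regions=3)
#   ]
#   answer = client.synthesize_answer(image, question, steps, evidences)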