import base64
import json
import random
from io import BytesIO
from typing import List, Union
from urllib.parse import urlparse

from PIL import Image
from huggingface_hub import InferenceClient

from diffusers import ModularPipeline, ModularPipelineBlocks
from diffusers.modular_pipelines import InputParam, OutputParam, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import ConfigSpec
from diffusers.utils import logger

# Frame-count granularity of the video generator; presumably frame totals must
# be a multiple of this -- TODO confirm against the pipeline that consumes it.
FRAME_MULTIPLE = 4

# How many frames one action is rendered for (matches the prompt text below).
FRAMES_PER_ACTION = 12

# Action vocabulary offered to the VLM; also used by the random fallback when
# the model call or response parsing fails.
_MOVEMENT_ACTIONS = ("forward", "left", "right")
_CAMERA_ACTIONS = ("camera_l", "camera_r")
_ALL_ACTIONS = _MOVEMENT_ACTIONS + _CAMERA_ACTIONS


def _encode_image_to_base64(img: Image.Image, max_size_mb: float = 3.5) -> str:
    """Encode a PIL Image as a base64 data URI, keeping it under a size budget.

    The image is first encoded as lossless PNG; if that exceeds ``max_size_mb``
    it falls back to JPEG (converting to an RGB-compatible mode if needed),
    downscaling first when the PNG was more than twice the budget.

    Args:
        img: PIL Image object.
        max_size_mb: Target maximum payload size in MB. Best effort: the JPEG
            fallback is not re-measured against the budget after encoding.

    Returns:
        str: ``data:image/png;base64,...`` or ``data:image/jpeg;base64,...``.
    """
    buffer = BytesIO()
    img.save(buffer, format="PNG")
    size_mb = len(buffer.getvalue()) / (1024 * 1024)
    if size_mb <= max_size_mb:
        img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{img_str}"

    # JPEG cannot store an alpha/palette channel, so normalize to RGB first.
    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")
    if size_mb > max_size_mb * 2:
        # Scale each side by sqrt of the size ratio: file size tracks pixel
        # count, which scales with width * height. Clamp to >= 1 px so extreme
        # aspect ratios cannot produce a zero-size (invalid) resize target.
        scale = (max_size_mb / size_mb) ** 0.5
        new_size = (max(1, int(img.width * scale)), max(1, int(img.height * scale)))
        img = img.resize(new_size, Image.Resampling.LANCZOS)
    buffer = BytesIO()
    img.save(buffer, format="JPEG", quality=85, optimize=True)
    img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{img_str}"


def image_to_uri(image: Union[str, Image.Image], max_size_mb: float = 3.5) -> str:
    """Convert an image reference to a URI suitable for a chat-completion request.

    Args:
        image: URL string, local file path string, or PIL Image object.
        max_size_mb: Maximum size in MB for data URIs (default 3.5MB).

    Returns:
        str: The URL unchanged if ``image`` is an http(s) URL, otherwise a
        base64 data URI produced by :func:`_encode_image_to_base64`.
    """
    if isinstance(image, Image.Image):
        return _encode_image_to_base64(image, max_size_mb)
    parsed = urlparse(image)
    if parsed.scheme in ("http", "https") and parsed.netloc:
        return image
    # Anything that is not a URL is treated as a local file path.
    with Image.open(image) as img:
        return _encode_image_to_base64(img, max_size_mb)


class ImageToMatrixGameAction(ModularPipelineBlocks):
    """Pipeline block that asks a vision-language model to plan game actions.

    Given an input image, the block prompts a hosted VLM (via the Hugging Face
    Inference API) to propose a JSON list of movement/camera actions that would
    navigate the depicted scene, and stores that list on the pipeline state
    under ``actions``. On any failure it falls back to random actions.
    """

    model_name = "MatrixGameWan"

    @property
    def inputs(self):
        # Inputs read from the pipeline state in __call__.
        return [
            InputParam("image"),
            InputParam("num_frames"),
            InputParam("prompt"),
        ]

    @property
    def intermediate_outputs(self):
        # The planned list of action strings, consumed by downstream blocks.
        return [OutputParam("actions")]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        # `model_id` selects the hosted VLM used to interpret the image.
        return [ConfigSpec("model_id", default="Qwen/Qwen2.5-VL-72B-Instruct")]

    @staticmethod
    def _parse_actions(content: str) -> List[str]:
        """Parse the model response into a list of action strings.

        Tolerates the model wrapping the JSON list in a Markdown code fence
        (``` or ```json), which chat models commonly do despite instructions.

        Raises:
            ValueError: if the response parses but is not a JSON list.
            json.JSONDecodeError: if the response is not valid JSON at all.
        """
        text = content.strip()
        if text.startswith("```"):
            # Drop the opening fence line (possibly "```json") and any
            # trailing closing fence before parsing.
            text = text.split("\n", 1)[1] if "\n" in text else text.lstrip("`")
            text = text.rsplit("```", 1)[0].strip()
        actions = json.loads(text)
        if not isinstance(actions, list):
            raise ValueError(f"Expected a JSON list of actions, got: {content!r}")
        return actions

    @staticmethod
    def _random_actions(num_frames) -> List[str]:
        """Fallback plan: sample random actions sized to fill ``num_frames``."""
        if num_frames:
            count = max(1, int(num_frames) // FRAMES_PER_ACTION)
        else:
            # num_frames missing/None/0: still emit one action so downstream
            # blocks receive a non-empty plan.
            count = 1
        return [random.choice(_ALL_ACTIONS) for _ in range(count)]

    def __call__(self, components: ModularPipeline, state: PipelineState):
        """Query the configured VLM for an action plan and store it on state.

        On any failure (network, API, or unparsable response) the block logs a
        warning and defaults to random actions instead of leaving ``actions``
        unset -- the original code logged that fallback but never performed it.
        """
        client = InferenceClient()
        instructions = """
        You will be provided an image and you have to interpret how you would move inside it if the image was in 3D space.
        Here are the available actions you can take:

        Movement Actions: forward, left, right
        Camera Actions: camera_l, camera_r

        You can also combine actions with an _ to create compound actions. e.g. forward_left_camera_l
        Each action is rendered for 12 frames, so make sure the number of actions suggested fits into the total number of frames available: {num_frames}

        e.g ["forward", "forward_left", "camera_l"]

        Here are additional instructions for you to follow:
        {prompt}

        Only respond with the list of actions you have to take and nothing else.
        """
        block_state = self.get_block_state(state)
        image = block_state.image
        prompt = block_state.prompt or ""
        num_frames = block_state.num_frames
        instructions = instructions.format(prompt=prompt, num_frames=num_frames)

        try:
            user_message = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": image_to_uri(image)},
                        },
                        {"type": "text", "text": instructions},
                    ],
                }
            ]
            completion = client.chat.completions.create(
                model=components.model_id,
                messages=user_message,
                # Low temperature for a mostly-deterministic, instruction-
                # following plan; 1000 tokens is ample for a short JSON list.
                temperature=0.2,
                max_tokens=1000,
            )
            content = completion.choices[0].message.content
            block_state.actions = self._parse_actions(content)
        except Exception as e:
            # Best-effort block: honor the documented fallback instead of
            # silently leaving `actions` unset, and surface the cause.
            logger.warning(f"Unable to generate actions ({e!r}). Defaulting to random actions")
            block_state.actions = self._random_actions(num_frames)

        self.set_block_state(state, block_state)
        return components, state