File size: 6,940 Bytes
cd35cc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd40a43
 
 
 
 
 
 
 
cd35cc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
"""
Modal serverless GPU client for Flux/SDXL image generation.
Handles: Open-source image generation models on demand.
"""

import os
from typing import Optional
import httpx
import base64

# Note: For full Modal integration, you'd deploy a Modal app.
# This client calls a deployed Modal endpoint or falls back to HuggingFace.


class ModalFluxClient:
    """Modal-powered Flux/SDXL image generation for Pip.

    Prefers a deployed Modal endpoint when one is configured, otherwise
    calls the HuggingFace router API (fal.ai backend). Every generator
    returns the image as a base64-encoded string, or ``None`` on failure.
    """

    # Current HuggingFace router base URL (the old
    # api-inference.huggingface.co endpoint is deprecated).
    HF_ROUTER_URL = "https://router.huggingface.co"

    # Router paths per model key, served through the fal.ai provider
    # (the provider prefix plus the provider's own path — hence "fal-ai" twice).
    ROUTER_ENDPOINTS = {
        "flux": "/fal-ai/fal-ai/flux/schnell",
        "flux_dev": "/fal-ai/fal-ai/flux/dev",
    }

    # Model repository ids used by the legacy Modal-deployment path.
    MODELS = {
        "flux": "black-forest-labs/FLUX.1-schnell",
        "sdxl_lightning": "ByteDance/SDXL-Lightning",
        "sdxl": "stabilityai/stable-diffusion-xl-base-1.0",
    }

    def __init__(self):
        # All configuration comes from environment variables.
        self.hf_token = os.getenv("HF_TOKEN")
        self.modal_endpoint = os.getenv("MODAL_FLUX_ENDPOINT")  # set if a Modal app is deployed
        self.available = bool(self.hf_token) or bool(self.modal_endpoint)
        if not self.available:
            print("⚠️ HuggingFace/Modal: No tokens found - image generation limited")

    def is_available(self) -> bool:
        """Return True when at least one backend is configured."""
        return self.available

    async def generate_image(
        self,
        prompt: str,
        model: str = "flux"
    ) -> Optional[str]:
        """Generate an image for *prompt* and return it as base64, or None.

        Backends are tried in order: the deployed Modal endpoint (if
        configured), then the HuggingFace router.
        """
        if self.modal_endpoint:
            via_modal = await self._generate_via_modal(prompt, model)
            if via_modal:
                return via_modal

        via_router = await self._generate_via_hf_router(prompt, model)
        if via_router:
            return via_router

        # Nothing worked — report and signal failure to the caller.
        print(f"All image generation methods failed for model: {model}")
        return None

    async def _generate_via_modal(self, prompt: str, model: str) -> Optional[str]:
        """POST to the deployed Modal function; return its base64 image or None."""
        try:
            async with httpx.AsyncClient(timeout=60.0) as http:
                resp = await http.post(
                    self.modal_endpoint,
                    json={"prompt": prompt, "model": model}
                )
                if resp.status_code == 200:
                    return resp.json().get("image_base64")
        except Exception as e:
            # Best-effort backend: log and let the caller fall through.
            print(f"Modal generation error: {e}")
        return None

    async def _generate_via_hf_router(
        self,
        prompt: str,
        model: str = "flux"
    ) -> Optional[str]:
        """Generate via the HuggingFace router API (fal.ai backend).

        Handles both response shapes the router may return: a hosted image
        URL (downloaded and re-encoded to base64) or inline base64 data.
        """
        try:
            # Unknown model keys fall back to the default flux endpoint.
            path = self.ROUTER_ENDPOINTS.get(model, self.ROUTER_ENDPOINTS["flux"])
            url = f"{self.HF_ROUTER_URL}{path}"

            headers = {"Authorization": f"Bearer {self.hf_token}"} if self.hf_token else {}

            async with httpx.AsyncClient(timeout=120.0) as http:
                # The router API takes 'prompt', not the legacy 'inputs' field.
                resp = await http.post(url, headers=headers, json={"prompt": prompt})

                if resp.status_code != 200:
                    print(f"HF Router API error: {resp.status_code} - {resp.text[:200]}")
                    return None

                data = resp.json()
                # Expected shape: {"images": [{"url": ..., "content_type": ...}], ...}
                images = data.get("images")
                if images:
                    first = images[0]
                    if isinstance(first, dict):
                        if "url" in first:
                            # Hosted result: fetch the file and re-encode it.
                            img_resp = await http.get(first["url"])
                            if img_resp.status_code == 200:
                                return base64.b64encode(img_resp.content).decode('utf-8')
                        elif "b64_json" in first:
                            return first["b64_json"]
                    elif isinstance(first, str):
                        # Already a bare base64 payload.
                        return first
                # 200 response but no usable image in it.
                print(f"HF Router unexpected response format: {list(data.keys())}")
        except Exception as e:
            print(f"HF Router generation error: {e}")
        return None

    async def generate_fast(self, prompt: str) -> Optional[str]:
        """Generate with the fastest model (SDXL-Lightning)."""
        return await self.generate_image(prompt, model="sdxl_lightning")

    async def generate_artistic(self, prompt: str) -> Optional[str]:
        """Generate with Flux for more artistic, dreamlike results."""
        return await self.generate_image(prompt, model="flux")


# Modal app definition for deployment (optional).
# Kept as a string literal so this module imports cleanly even when the
# `modal` package is not installed; deploy it separately with:
#   modal deploy services/modal_flux.py
# NOTE(review): the deployed endpoint URL presumably feeds MODAL_FLUX_ENDPOINT
# for ModalFluxClient — confirm against the deployment docs.

MODAL_APP_CODE = '''
import modal

app = modal.App("pip-flux-generator")

# Define the image with required dependencies
flux_image = modal.Image.debian_slim().pip_install(
    "diffusers",
    "transformers",
    "accelerate",
    "torch",
    "safetensors"
)

@app.function(
    image=flux_image,
    gpu="A10G",
    timeout=300,
)
def generate_flux_image(prompt: str) -> bytes:
    """Generate image using Flux on Modal GPU."""
    import torch
    from diffusers import FluxPipeline
    
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        torch_dtype=torch.bfloat16
    )
    pipe.to("cuda")
    
    image = pipe(
        prompt,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    ).images[0]
    
    # Convert to bytes
    import io
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    return buf.getvalue()

@app.local_entrypoint()
def main(prompt: str = "a serene lake at sunset"):
    image_bytes = generate_flux_image.remote(prompt)
    with open("output.png", "wb") as f:
        f.write(image_bytes)
    print("Image saved to output.png")
'''