| """ | |
| Modal serverless GPU client for Flux/SDXL image generation. | |
| Handles: Open-source image generation models on demand. | |
| """ | |
| import os | |
| from typing import Optional | |
| import httpx | |
| import base64 | |
| # Note: For full Modal integration, you'd deploy a Modal app. | |
| # This client calls a deployed Modal endpoint or falls back to HuggingFace. | |


class ModalFluxClient:
    """Modal-powered Flux/SDXL image generation for Pip."""

    # New HuggingFace router API (the old api-inference.huggingface.co is deprecated)
    HF_ROUTER_URL = "https://router.huggingface.co"

    # Router endpoints for different models via fal.ai
    ROUTER_ENDPOINTS = {
        "flux": "/fal-ai/fal-ai/flux/schnell",
        "flux_dev": "/fal-ai/fal-ai/flux/dev",
    }
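    # The doubled "fal-ai" above looks odd but is assumed intentional: the first
    # segment selects the provider on the router, the second begins fal.ai's own
    # model path, yielding e.g.
    # "https://router.huggingface.co/fal-ai/fal-ai/flux/schnell".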

    # Legacy models (for Modal deployment)
    MODELS = {
        "flux": "black-forest-labs/FLUX.1-schnell",
        "sdxl_lightning": "ByteDance/SDXL-Lightning",
        "sdxl": "stabilityai/stable-diffusion-xl-base-1.0",
    }

    def __init__(self):
        self.hf_token = os.getenv("HF_TOKEN")
        self.modal_endpoint = os.getenv("MODAL_FLUX_ENDPOINT")  # If deployed
        self.available = bool(self.hf_token) or bool(self.modal_endpoint)
        if not self.available:
            print("⚠️ HuggingFace/Modal: No tokens found - image generation limited")

    def is_available(self) -> bool:
        """Check if the client is available."""
        return self.available

    async def generate_image(
        self,
        prompt: str,
        model: str = "flux"
    ) -> Optional[str]:
        """
        Generate an image using Flux or SDXL via Modal/HuggingFace Router.

        Returns a base64-encoded image, or None if every backend fails.
        """
        # Try Modal endpoint first if available
        if self.modal_endpoint:
            result = await self._generate_via_modal(prompt, model)
            if result:
                return result

        # Try new HuggingFace router API (primary method)
        result = await self._generate_via_hf_router(prompt, model)
        if result:
            return result

        # Final fallback - return None
        print(f"All image generation methods failed for model: {model}")
        return None

    async def _generate_via_modal(self, prompt: str, model: str) -> Optional[str]:
        """Call deployed Modal function for image generation."""
        try:
            async with httpx.AsyncClient(timeout=60.0) as client:
                response = await client.post(
                    self.modal_endpoint,
                    json={"prompt": prompt, "model": model}
                )
                if response.status_code == 200:
                    data = response.json()
                    return data.get("image_base64")
        except Exception as e:
            print(f"Modal generation error: {e}")
        return None
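
    # Assumed contract for MODAL_FLUX_ENDPOINT (this client's own convention,
    # not a Modal built-in): POST {"prompt": str, "model": str} returns JSON
    # {"image_base64": str}. See the endpoint sketch after MODAL_APP_CODE below.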

    async def _generate_via_hf_router(
        self,
        prompt: str,
        model: str = "flux"
    ) -> Optional[str]:
        """
        Generate image via new HuggingFace Router API (fal.ai backend).

        This is the current working method as of 2025.
        """
        try:
            # Get router endpoint for model
            endpoint = self.ROUTER_ENDPOINTS.get(model, self.ROUTER_ENDPOINTS["flux"])
            url = f"{self.HF_ROUTER_URL}{endpoint}"

            headers = {}
            if self.hf_token:
                headers["Authorization"] = f"Bearer {self.hf_token}"

            # New API uses 'prompt' not 'inputs'
            payload = {"prompt": prompt}

            async with httpx.AsyncClient(timeout=120.0) as client:
                response = await client.post(
                    url,
                    headers=headers,
                    json=payload
                )
                if response.status_code == 200:
                    data = response.json()
                    # New format returns {"images": [{"url": "...", "content_type": "..."}], ...}
                    if "images" in data and data["images"]:
                        image_info = data["images"][0]
                        # Image could be URL or base64
                        if isinstance(image_info, dict):
                            if "url" in image_info:
                                # Download image from URL and convert to base64
                                img_response = await client.get(image_info["url"])
                                if img_response.status_code == 200:
                                    return base64.b64encode(img_response.content).decode('utf-8')
                            elif "b64_json" in image_info:
                                return image_info["b64_json"]
                        elif isinstance(image_info, str):
                            # Direct base64 string
                            return image_info
                    print(f"HF Router unexpected response format: {list(data.keys())}")
                else:
                    print(f"HF Router API error: {response.status_code} - {response.text[:200]}")
        except Exception as e:
            print(f"HF Router generation error: {e}")
        return None
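
    # Response shape assumed by the parser above (fal.ai via the router; exact
    # fields may vary by provider):
    #
    #   {"images": [{"url": "https://...", "content_type": "image/png"}], ...}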

    async def generate_fast(self, prompt: str) -> Optional[str]:
        """
        Use the fastest available model (SDXL-Lightning).

        Note: only a Modal deployment can serve SDXL-Lightning; the router path
        has no entry for it, so ROUTER_ENDPOINTS falls back to Flux.
        """
        return await self.generate_image(prompt, model="sdxl_lightning")

    async def generate_artistic(self, prompt: str) -> Optional[str]:
        """Use Flux for more artistic, dreamlike results."""
        return await self.generate_image(prompt, model="flux")
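

# Example usage (a minimal sketch, assuming HF_TOKEN or MODAL_FLUX_ENDPOINT is
# set in the environment; the prompt and output path are illustrative):
#
#   import asyncio
#
#   async def demo():
#       client = ModalFluxClient()
#       image_b64 = await client.generate_image("a serene lake at sunset")
#       if image_b64:
#           with open("output.png", "wb") as f:
#               f.write(base64.b64decode(image_b64))
#
#   asyncio.run(demo())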


# Modal app definition for deployment (optional)
# Run with: modal deploy services/modal_flux.py
MODAL_APP_CODE = '''
import modal

app = modal.App("pip-flux-generator")

# Define the image with required dependencies
flux_image = modal.Image.debian_slim().pip_install(
    "diffusers",
    "transformers",
    "accelerate",
    "torch",
    "safetensors",
)


@app.function(
    image=flux_image,
    gpu="A10G",
    timeout=300,
)
def generate_flux_image(prompt: str) -> bytes:
    """Generate image using Flux on Modal GPU."""
    import torch
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        torch_dtype=torch.bfloat16
    )
    pipe.to("cuda")

    image = pipe(
        prompt,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    ).images[0]

    # Convert to bytes
    import io
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    return buf.getvalue()


@app.local_entrypoint()
def main(prompt: str = "a serene lake at sunset"):
    image_bytes = generate_flux_image.remote(prompt)
    with open("output.png", "wb") as f:
        f.write(image_bytes)
    print("Image saved to output.png")
'''
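
# MODAL_APP_CODE exposes only a local entrypoint; to satisfy the HTTP contract
# that _generate_via_modal expects, the deployed app would also need a web
# handler. A minimal sketch (assumes Modal's fastapi_endpoint decorator and
# that "fastapi[standard]" is added to flux_image; names are illustrative):
#
#   @app.function(image=flux_image, gpu="A10G", timeout=300)
#   @modal.fastapi_endpoint(method="POST")
#   def generate(body: dict) -> dict:
#       import base64
#       png_bytes = generate_flux_image.local(body["prompt"])
#       return {"image_base64": base64.b64encode(png_bytes).decode("utf-8")}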