File size: 2,251 Bytes
8c46fdc
 
 
 
 
 
 
 
3890bcb
 
8c46fdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import base64
import logging
import os
from io import BytesIO

import torch
from diffusers import DiffusionPipeline
from PIL import Image

# Load the models once at import time so every `handle` call reuses them
# instead of reloading weights per request.
# NOTE(review): MODEL_ID and CONTROLNET_MODEL are the same repository, so the
# base pipeline below is loaded from the ControlNet repo rather than from
# TRANSFORMER_MODEL (FLUX.1-dev) — confirm this is intended.
MODEL_ID = "jiteshdhamaniya/alimama-creative-FLUX.1-dev-Controlnet-Inpainting-Alpha"
CONTROLNET_MODEL = "jiteshdhamaniya/alimama-creative-FLUX.1-dev-Controlnet-Inpainting-Alpha"
TRANSFORMER_MODEL = "black-forest-labs/FLUX.1-dev"

# NOTE(review): `DiffusionPipeline.from_pretrained` loads a whole pipeline, not
# an individual model, so these two calls likely do not yield a ControlNet
# module and a transformer module. The component model classes (a Flux
# ControlNet model class, and a Flux transformer class loaded with
# subfolder="transformer") are presumably what was meant — TODO confirm
# against the checkpoint's model card / custom pipeline code.
controlnet = DiffusionPipeline.from_pretrained(CONTROLNET_MODEL, torch_dtype=torch.bfloat16)
transformer = DiffusionPipeline.from_pretrained(TRANSFORMER_MODEL, subfolder="transformer", torch_dtype=torch.bfloat16)

# Assemble the inference pipeline from MODEL_ID, injecting the components
# loaded above, in bfloat16, and move it to CUDA when available, else CPU.
# NOTE(review): `handle` calls this pipeline with `control_image` and
# `control_mask` keyword arguments; confirm the pipeline class resolved from
# MODEL_ID accepts them. bfloat16 on CPU is presumably very slow — verify.
pipeline = DiffusionPipeline.from_pretrained(
    MODEL_ID,
    controlnet=controlnet,
    transformer=transformer,
    torch_dtype=torch.bfloat16
).to("cuda" if torch.cuda.is_available() else "cpu")

# Request handling: decode inputs, run the module-level `pipeline`, encode output.
_logger = logging.getLogger(__name__)


def _decode_image(encoded):
    """Decode a base64 payload (optionally a ``data:`` URI) into an RGB PIL image."""
    if isinstance(encoded, str) and encoded.startswith("data:"):
        # b64decode drops non-alphabet characters (':', ';', ',') by default,
        # so a data-URI header would be decoded into garbage bytes prepended
        # to the image instead of raising. Keep only the part after the comma.
        _, _, encoded = encoded.partition(",")
    with Image.open(BytesIO(base64.b64decode(encoded))) as source:
        # convert() returns a new image, so closing the source is safe.
        return source.convert("RGB")


def _encode_image(image):
    """Serialize a PIL image to a base64-encoded PNG string."""
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def handle(inputs, context):
    """Run ControlNet inpainting for one request.

    Args:
        inputs: Request mapping. Keys read:
            ``prompt`` (defaults to ``"default prompt text"``);
            ``control_image`` and ``mask_image``, which are required
            base64-encoded images (a ``data:`` URI prefix is accepted);
            ``num_inference_steps`` (int, default 28);
            ``guidance_scale`` (float, default 3.5);
            ``controlnet_conditioning_scale`` (float, default 0.9).
        context: Not used by this function.

    Returns:
        On success, ``{"status": "success", "image": <base64 PNG>}``.
        On failure, ``{"status": "error", "message": <description>}``.
        Any ``Exception`` is caught, logged with its traceback, and reported
        this way rather than raised.
    """
    try:
        prompt = inputs.get("prompt", "default prompt text")
        control_image_base64 = inputs.get("control_image")
        mask_image_base64 = inputs.get("mask_image")

        # Fail fast with a clear message instead of letting b64decode(None)
        # raise an unrelated-looking TypeError.
        missing = [
            name
            for name, value in (
                ("control_image", control_image_base64),
                ("mask_image", mask_image_base64),
            )
            if not value
        ]
        if missing:
            raise ValueError(f"Missing required input(s): {', '.join(missing)}")

        # JSON clients may send numbers as strings or floats; normalize them.
        num_inference_steps = int(inputs.get("num_inference_steps", 28))
        guidance_scale = float(inputs.get("guidance_scale", 3.5))
        controlnet_conditioning_scale = float(
            inputs.get("controlnet_conditioning_scale", 0.9)
        )

        control_image = _decode_image(control_image_base64)
        mask_image = _decode_image(mask_image_base64)

        result = pipeline(
            prompt=prompt,
            control_image=control_image,
            control_mask=mask_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
        ).images[0]

        return {"status": "success", "image": _encode_image(result)}

    except Exception as e:
        # Top-level request boundary: keep the error-dict contract, but do not
        # lose the traceback server-side.
        _logger.exception("Inference request failed")
        return {"status": "error", "message": str(e)}