import os import gc from typing import List, Tuple, Dict import json import spaces import traceback import torch import gradio as gr from PIL import Image from model.transformer_flux import FluxTransformer2DModelwithSliderConditioning # from diffusers import FluxTransformer2DModel from model.sliders_model import SliderProjector, SliderProjector_wo_clip from model.sliders_pipeline import FluxKontextSliderPipeline from huggingface_hub import login, snapshot_download HF_TOKEN = os.getenv("HF_TOKEN") if HF_TOKEN: # Auth for this process (does not print or persist the token in your logs) login(token=HF_TOKEN) # ----------------------------- # Environment & device # ----------------------------- # Avoid meta-tensor init from environment leftovers os.environ.pop("ACCELERATE_INIT_EMPTY_WEIGHTS", None) # ----------------------------- # Model / pipeline loading # ----------------------------- def _log(msg): print(msg, flush=True) # def load_pipeline_single_gpu(): # global PIPELINE # if PIPELINE is not None: # _log("[worker] PIPELINE already initialized; skipping.") # return "warm" # try: # os.environ.pop("ACCELERATE_INIT_EMPTY_WEIGHTS", None) # token = os.environ.get("HF_TOKEN") # cuda_ok = torch.cuda.is_available() # _log(f"[worker] cuda available: {cuda_ok}") # if cuda_ok: # torch.backends.cudnn.benchmark = True # # ---------- config ---------- # pretrained = "black-forest-labs/FLUX.1-Kontext-dev" # trained_models_path = "./model_weights/" # projector_path = os.path.join(trained_models_path, "slider_projector.pth") # offload_dir = "/tmp/offload"; os.makedirs(offload_dir, exist_ok=True) # if not os.path.isdir(trained_models_path): # return f"error: missing dir {trained_models_path}" # if not os.path.isfile(projector_path): # return f"error: missing projector weights at {projector_path}" # # dtype selection to cut memory # if cuda_ok and torch.cuda.get_device_capability(0)[0] >= 8: # dtype = torch.bfloat16 # elif cuda_ok: # dtype = torch.float16 # else: # dtype = torch.float32 # max_memory = {"cuda": "80GiB", "cpu": "60GiB"} # tune if needed # _log("[worker] loading transformer (sharded/offloaded)…") # transformer = FluxTransformer2DModelwithSliderConditioning.from_pretrained( # pretrained, # subfolder="transformer", # token=token, # trust_remote_code=True, # torch_dtype=dtype, # low_cpu_mem_usage=True, # # device_map="balanced_low_0", # offload_folder=offload_dir, # offload_state_dict=True, # # max_memory=max_memory, # ) # weight_dtype = transformer.dtype # _log(f"[worker] transformer loaded, dtype={weight_dtype}") # _log("[worker] building slider projector…") # slider_projector = SliderProjector(out_dim=6144, pe_dim=2, n_layers=4, is_clip_input=True) # slider_projector.eval() # _log("[worker] loading projector weights…") # state_dict = torch.load(projector_path, map_location="cpu", weights_only=True) # slider_projector.load_state_dict(state_dict, strict=True) # _log("[worker] assembling pipeline (sharded/offloaded)…") # pipe = FluxKontextSliderPipeline.from_pretrained( # pretrained, # token=token, # trust_remote_code=True, # transformer=transformer, # slider_projector=slider_projector, # torch_dtype=weight_dtype, # low_cpu_mem_usage=True, # # device_map="balanced_low_0", # offload_folder=offload_dir, # offload_state_dict=True, # # max_memory=max_memory, # ) # _log("[worker] pipeline assembled.") # _log(f"[worker] loading LoRA from: {trained_models_path}") # pipe.load_lora_weights(trained_models_path) # _log("[worker] LoRA loaded.") # # DO NOT pipe.to("cuda") here; keep auto device_map to avoid OOM # PIPELINE = pipe # if cuda_ok: # free, total = torch.cuda.mem_get_info() # _log(f"[worker] VRAM free/total: {free/1e9:.2f}/{total/1e9:.2f} GB") # _log("[worker] PIPELINE ready.") # return "ok" # except Exception: # _log("[worker] init exception:\n" + traceback.format_exc()) # return "error" # ----------------------------- # Loading the pipeline without any function so that it will be called directly in the inference # ----------------------------- os.environ.pop("ACCELERATE_INIT_EMPTY_WEIGHTS", None) token = os.environ.get("HF_TOKEN") cuda_ok = torch.cuda.is_available() _log(f"[worker] cuda available: {cuda_ok}") if cuda_ok: torch.backends.cudnn.benchmark = True # ---------- config ---------- pretrained = "black-forest-labs/FLUX.1-Kontext-dev" trained_models_path = "./model_weights/" projector_path = os.path.join(trained_models_path, "slider_projector.pth") offload_dir = "/tmp/offload"; os.makedirs(offload_dir, exist_ok=True) # dtype selection to cut memory if cuda_ok and torch.cuda.get_device_capability(0)[0] >= 8: dtype = torch.bfloat16 elif cuda_ok: dtype = torch.float16 else: dtype = torch.float32 max_memory = {"cuda": "80GiB", "cpu": "60GiB"} # tune if needed _log("[worker] loading transformer (sharded/offloaded)…") transformer = FluxTransformer2DModelwithSliderConditioning.from_pretrained( pretrained, subfolder="transformer", token=token, trust_remote_code=True, # torch_dtype=dtype, # low_cpu_mem_usage=True, # device_map="balanced_low_0", # offload_folder=offload_dir, # offload_state_dict=True, # max_memory=max_memory, ) weight_dtype = transformer.dtype _log(f"[worker] transformer loaded, dtype={weight_dtype}") _log("[worker] building slider projector…") slider_projector = SliderProjector(out_dim=6144, pe_dim=2, n_layers=4, is_clip_input=True) slider_projector.eval() _log("[worker] loading projector weights…") state_dict = torch.load(projector_path, map_location="cpu", weights_only=True) slider_projector.load_state_dict(state_dict, strict=True) _log("[worker] assembling pipeline (sharded/offloaded)…") PIPELINE = FluxKontextSliderPipeline.from_pretrained( pretrained, token=token, trust_remote_code=True, transformer=transformer, slider_projector=slider_projector, torch_dtype=weight_dtype, # low_cpu_mem_usage=True, # device_map="balanced_low_0", # offload_folder=offload_dir, # offload_state_dict=True, # max_memory=max_memory, ) _log("[worker] pipeline assembled.") _log(f"[worker] loading LoRA from: {trained_models_path}") PIPELINE.load_lora_weights(trained_models_path) _log("[worker] LoRA loaded.") # moving the pipeline to GPU PIPELINE.to('cuda') # ----------------------------- # Sample Images & Precomputed Results # ----------------------------- def create_sample_entry(name, image_filename, prompt, result_folder, num_results=5, result_pattern="image_{i}.png", precomputed_base="./sample_images/precomputed"): """ Helper function to create a sample entry with subfolder organization. Args: name: Display name in dropdown image_filename: Filename in ./sample_images/ prompt: Editing instruction result_folder: Subfolder name in precomputed directory num_results: Number of precomputed results (default 5) result_pattern: Filename pattern, {i} will be replaced with 0,1,2,3,4 (default "image_{i}.png") precomputed_base: Base path for precomputed results (default "./sample_images/precomputed") """ return { "name": name, "image_path": f"./sample_images/{image_filename}", "prompt": prompt, "precomputed_results": [f"{precomputed_base}/{result_folder}/{result_pattern.format(i=i)}" for i in range(num_results)] } def load_samples_from_config(config_file="sample_config.json"): """Load sample data from a JSON configuration file.""" if os.path.exists(config_file): try: with open(config_file, 'r') as f: return json.load(f) except Exception as e: print(f"Error loading sample config: {e}") return [] def discover_samples_automatically(sample_dir="./sample_images", precomputed_dir="./sample_images/precomputed"): """Automatically discover samples based on directory structure with subfolders.""" discovered_samples = [] if not os.path.exists(sample_dir) or not os.path.exists(precomputed_dir): return discovered_samples # Look for subfolders in precomputed directory for subfolder in os.listdir(precomputed_dir): subfolder_path = os.path.join(precomputed_dir, subfolder) if os.path.isdir(subfolder_path): # Look for sequential result files in subfolder precomputed_files = [] for i in range(0, 15): # Check for up to 15 results starting from 0 # Try different patterns for pattern in [f"image_{i}.png", f"image_{i}.jpg", f"{i}.jpg", f"{i}.png", f"result_{i}.jpg", f"output_{i}.png"]: result_path = os.path.join(subfolder_path, pattern) if os.path.exists(result_path): precomputed_files.append(result_path) break else: # If no file with this index found, stop looking (but continue if we found at least one) if i == 0 and not precomputed_files: continue # Keep trying from index 0 elif not precomputed_files: break # No files found at all else: break # Found some files but this index is missing, stop here if precomputed_files: # Try to find corresponding source image img_path = None # Common naming patterns for source images base_name = subfolder.split('_')[0] # e.g., "portrait" from "portrait_smile" for ext in ['.jpg', '.jpeg', '.png']: candidate = os.path.join(sample_dir, f"{base_name}{ext}") if os.path.exists(candidate): img_path = candidate break if img_path: sample = { "name": f"{subfolder.replace('_', ' ').title()} - Auto-discovered", "image_path": img_path, "prompt": f"Edit: {subfolder.replace('_', ' ')}", # Default prompt "precomputed_results": precomputed_files } discovered_samples.append(sample) return discovered_samples # Main sample data - using your actual folder structure SAMPLE_DATA = [ create_sample_entry("Stylization", "aesthetic_model2_vangogh.png", "Transform the image into a Van Gogh Style painting", "aesthetic_model2_vangogh", 11), create_sample_entry("Weather Change", "enfield3_winter_snow.png", "Transform the scene into winter season with heavy snowfall", "enfield3_winter_snow", 11), create_sample_entry("Illumination Change", "light_lamp_blue_side.png", "Turn on the lamp with blue lighting", "light_lamp_blue_side", 11), create_sample_entry("Appearance Change", "jackson_fluffy.png", "Transform his jacket into a blue fluffy fur jacket", "jackson_fluffy", 11), create_sample_entry("Scene Edit", "venice1_grow_ivy.png", "Grow ivy on the walls of the buildings on the side", "venice1_grow_ivy", 11) ] # Add more samples using the helper function # Modify these examples or add your own: ADDITIONAL_SAMPLES = [ # Add your own samples here following your folder structure: # # For your structure (./sample_images/precomputed/folder_name/image_0.png, image_1.png, etc.): # create_sample_entry("Display Name", "your_image.png", "editing prompt", "folder_name", 12), # # Examples based on your pattern: # create_sample_entry("New Sample", "new_image.png", "apply some effect", "new_folder", 12), # create_sample_entry("Another Edit", "source.png", "different editing instruction", "another_folder", 10), # Note: # - Images should be in ./sample_images/ # - Precomputed results should be in ./sample_images/precomputed/folder_name/ # - Default pattern is image_0.png, image_1.png, etc. # - Adjust the number (12) to match how many results you have ] # Extend the main sample data with additional samples SAMPLE_DATA.extend(ADDITIONAL_SAMPLES) # Optional: Auto-discover additional samples from directories # Uncomment to automatically find additional samples beyond the manual ones above: # AUTO_DISCOVERED = discover_samples_automatically() # if AUTO_DISCOVERED: # print(f"Auto-discovered {len(AUTO_DISCOVERED)} additional samples:") # for sample in AUTO_DISCOVERED: # print(f" - {sample['name']}") # SAMPLE_DATA.extend(AUTO_DISCOVERED) # Optional: Load samples from external JSON config # CONFIG_SAMPLES = load_samples_from_config("sample_config.json") # SAMPLE_DATA.extend(CONFIG_SAMPLES) def load_sample_image(image_path: str) -> Image.Image: """Load a sample image, with fallback to a placeholder if file doesn't exist.""" try: if os.path.exists(image_path): return Image.open(image_path) else: # Create a placeholder image if sample doesn't exist placeholder = Image.new('RGB', (512, 512), color=(200, 200, 200)) return placeholder except Exception as e: print(f"Error loading sample image {image_path}: {e}") # Return a placeholder image placeholder = Image.new('RGB', (512, 512), color=(200, 200, 200)) return placeholder def load_precomputed_results(result_paths: List[str]) -> List[Image.Image]: """Load precomputed result images, with fallbacks for missing files.""" results = [] for path in result_paths: try: if os.path.exists(path): results.append(Image.open(path)) else: # Create placeholder result placeholder = Image.new('RGB', (512, 512), color=(150, 150, 150)) results.append(placeholder) except Exception as e: print(f"Error loading precomputed result {path}: {e}") placeholder = Image.new('RGB', (512, 512), color=(150, 150, 150)) results.append(placeholder) return results # ----------------------------- # Helpers # ----------------------------- def resize_image(img: Image.Image, target: int = 512) -> Image.Image: """Resize shortest side to target, then center-crop to target x target.""" w, h = img.size try: resample = Image.Resampling.BICUBIC # PIL >= 10 except Exception: resample = Image.BICUBIC if h > w: new_w, new_h = target, int(target * h / w) elif h < w: new_w, new_h = int(target * w / h), target else: new_w, new_h = target, target # resizing the image to a fixed lower dimension size of 512 img = img.resize((new_w, new_h), resample) return img # ----------------------------- # Inference functions # ----------------------------- @spaces.GPU(duration=220) @torch.no_grad() def generate_image_stack_edits(text_prompt, n_edits, input_image): """ Compute n_edits images on a single GPU for slider values in (0,1], return (list_of_images, first_image) so the UI shows immediately. """ DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # pipelien will be loaded already in the global context and will be called here if not input_image or not text_prompt or text_prompt.startswith("Please select"): return [], None n = int(n_edits) if n_edits is not None else 1 n = max(1, n) slider_values = [(i + 1) / float(n) for i in range(n)] # (0,1] inclusive img = resize_image(input_image, 512) pe, ppe, _ = PIPELINE.encode_prompt(prompt=text_prompt, prompt_2=text_prompt) results: List[Image.Image] = [] gen_base = 64 # deterministic seed base # not using batching for now just a simple forward loop # batch_size = 2 # n_batches = n // batch_size # batched_slider_values = [[slider_values[i*batch_size: (i+1)*batch_size]] for i in range(n_batches)] # print(f"batched_slider_values: {batched_slider_values}") for i, sv in enumerate(slider_values): gen = torch.Generator(device=DEVICE if DEVICE != "cpu" else "cpu").manual_seed(gen_base + i) with torch.no_grad(): # replicating based on the number of examples in the batch size out = PIPELINE( image=img, height=img.height, width=img.width, num_inference_steps=28, prompt_embeds=pe, pooled_prompt_embeds=ppe, generator=gen, text_condn=False, modulation_condn=True, slider_value=torch.tensor(sv, device=DEVICE if DEVICE != "cpu" else "cpu").reshape(1, 1), is_clip_input=True, ) results.append(out.images[0]) if DEVICE.startswith("cuda"): torch.cuda.empty_cache() gc.collect() first = results[0] if results else None return results, first @spaces.GPU(duration=80) def generate_single_image(text_prompt, slider_value, input_image): if not input_image or not text_prompt or text_prompt.startswith("Please select"): return None img = resize_image(input_image, 512) sv = float(slider_value) pe, ppe = _encode_prompt(text_prompt) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" gen = torch.Generator(device=DEVICE if DEVICE != "cpu" else "cpu").manual_seed(64) with torch.no_grad(): out = PIPELINE( image=img, height=img.height, width=img.width, num_inference_steps=28, prompt_embeds=pe, pooled_prompt_embeds=ppe, generator=gen, text_condn=False, modulation_condn=True, slider_value=torch.tensor(sv, device=DEVICE if DEVICE != "cpu" else "cpu").reshape(1, 1), is_clip_input=True, ) result = out.images[0] if DEVICE.startswith("cuda"): torch.cuda.empty_cache() gc.collect() return result # ----------------------------- # Sample Loading Functions # ----------------------------- def get_sample_by_name(sample_name: str): """Get sample data by name.""" for sample in SAMPLE_DATA: if sample["name"] == sample_name: return sample return None def load_sample_to_main_interface(sample_name: str): """Load selected sample to main interface with precomputed results.""" if not sample_name: return ( None, "Please select a sample above to see the editing instruction", [], None, gr.update(minimum=0, maximum=0, step=1, value=0, label="Edit Strength Level") ) sample = get_sample_by_name(sample_name) if not sample: return ( None, "Sample not found", [], None, gr.update(minimum=0, maximum=0, step=1, value=0, label="Edit Strength Level") ) # Load sample image sample_image = load_sample_image(sample["image_path"]) prompt = sample["prompt"] # Load precomputed results precomputed_images = load_precomputed_results(sample["precomputed_results"]) first_result = precomputed_images[0] if precomputed_images else None # Update slider range for precomputed results n_results = len(precomputed_images) slider_update = gr.update( minimum=0, maximum=max(0, n_results-1), step=1, value=0, label=f"Edit Strength Level (0-{n_results-1}) - Precomputed" ) return sample_image, prompt, precomputed_images, first_result, slider_update # ----------------------------- # Helpers # ----------------------------- def update_slider_range(n_edits): """Update the slider range based on number of edits.""" return gr.update( minimum=0, maximum=max(0, int(n_edits)-1), step=1, value=0, label=f"Edit Strength Level (0-{int(n_edits)-1})" ) def display_selected_image(slider_index: int, images_list: List[Image.Image]) -> Image.Image: """ Display the image corresponding to the slider index from the generated images list. Args: slider_index: Current slider position (0-based index) images_list: List of generated/precomputed images Returns: Selected image or None if invalid index/empty list """ if not images_list or len(images_list) == 0: return None # Clamp index to valid range idx = max(0, min(int(slider_index), len(images_list) - 1)) return images_list[idx] # ----------------------------- # Gradio UI # ----------------------------- # Add new helper function for user uploads def process_user_upload(uploaded_image, user_prompt, n_edits_val): """Handle user uploaded images and custom prompts.""" if uploaded_image is None: return None, [], None, gr.update(minimum=0, maximum=0, step=1, value=0, label="Edit Strength Level") # Resize uploaded image processed_image = resize_image(uploaded_image, 512) # Generate edits generated_list, first_result = generate_image_stack_edits(user_prompt, n_edits_val, processed_image) # Update slider range slider_update = gr.update( minimum=0, maximum=max(0, len(generated_list)), step=1, value=0, label=f"Edit Strength Level (0-{len(generated_list)-1})" ) return processed_image, generated_list, first_result, slider_update with gr.Blocks() as demo: gr.Markdown("# Kontinuous Kontext - Continuous Strength Control for Instruction-based Image Editing") # Add description section gr.Markdown(""" ## About ### Kontinuous Kontext allows you to edit a given image with a freeform text instruction and a slider strength value. ### The slider strength enables precise control for the extent of the applied edit and generates smooth transitions between different editing levels. ### You can either: 1. Choose from our sample images with predefined edit instructions 2. Upload your own image and specify custom editing instructions Checkout the [paper](https://arxiv.org/pdf/2510.08532v1) and the [project page](https://snap-research.github.io/kontinuouskontext) for more details. """) # Add custom CSS for tabs gr.Markdown(""" """) with gr.Tabs() as tabs: # Common style parameters for images IMAGE_WIDTH = 512 IMAGE_HEIGHT = 512 with gr.TabItem("📸 Examples") as tab1: # Added emoji and changed tab name with gr.Row(equal_height=True): with gr.Column(scale=1): sample_dropdown = gr.Dropdown( choices=[sample["name"] for sample in SAMPLE_DATA], label="Select Sample Image & Prompt", value=None ) sample_text = gr.Textbox(lines=1, show_label=False, placeholder="Please select a sample above", interactive=False) sample_image = gr.Image( type="pil", label="Source Image", width=IMAGE_WIDTH, height=IMAGE_HEIGHT, interactive=False, elem_id="sample_image" ) with gr.Column(scale=1): with gr.Row(): sample_slider = gr.Slider( minimum=0, maximum=1, step=0.1, value=0, label="Edit Strength", scale=1, min_width=100 ) sample_output = gr.Image( type="pil", label="Edited Output", width=IMAGE_WIDTH, height=IMAGE_HEIGHT, elem_id="sample_output" ) with gr.TabItem("⬆️ Upload Your Image") as tab2: # Added emoji and changed tab name with gr.Row(equal_height=True): with gr.Column(scale=1): upload_text = gr.Textbox(lines=1, label="Enter Editing Prompt", placeholder="Describe the edit you want...") upload_n_edits = gr.Number(value=3, minimum=1, maximum=6, step=1, label="Number of Edits", precision=0) upload_image = gr.Image( type="pil", label="Upload Image", width=IMAGE_WIDTH, height=IMAGE_HEIGHT, elem_id="upload_image" ) upload_button = gr.Button("Generate Edits") # Kept consistent with sample tab with gr.Column(scale=1): with gr.Row(): upload_slider = gr.Slider( minimum=0, maximum=1, step=0.1, value=0, label="Edit Strength Level", scale=1, min_width=100 ) upload_output = gr.Image( type="pil", label="Edited Output", width=IMAGE_WIDTH, height=IMAGE_HEIGHT, elem_id="upload_output" ) # States for both tabs sample_generated_images = gr.State([]) upload_generated_images = gr.State([]) # Sample tab logic sample_dropdown.change( load_sample_to_main_interface, inputs=[sample_dropdown], outputs=[sample_image, sample_text, sample_generated_images, sample_output, sample_slider] ) # sample_button.click( # generate_image_stack_edits, # inputs=[sample_text, sample_n_edits, sample_image], # outputs=[sample_generated_images, sample_output], # ).then( # update_slider_range, # inputs=[sample_n_edits], # outputs=[sample_slider], # ) sample_slider.change( display_selected_image, inputs=[sample_slider, sample_generated_images], outputs=[sample_output], ) # Upload tab logic - Remove duplicate click handler and combine the logic upload_button.click( generate_image_stack_edits, # Generate images first inputs=[upload_text, upload_n_edits, upload_image], outputs=[upload_generated_images, upload_output], ).then( update_slider_range, # Then update slider range inputs=[upload_n_edits], outputs=[upload_slider], ) # Update slider when n_edits changes upload_n_edits.change( update_slider_range, inputs=[upload_n_edits], outputs=[upload_slider], ) upload_slider.change( display_selected_image, inputs=[upload_slider, upload_generated_images], outputs=[upload_output], ) # Add citation section at the bottom gr.Markdown(""" --- ### If you find this work useful, please cite: ```bibtex @article{kontinuous_kontext_2025, title={Kontinuous Kontext: Continuous Strength Control for Instruction-based Image Editing}, author={R Parihar, O Patashnik, D Ostashev, R Venkatesh Babu, D Cohen-Or, and J Wang}, journal={Arxiv}, year={2025} } ``` """) if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)