---
pipeline_tag: image-to-image
library_name: transformers
license: apache-2.0
---
|
|
|
|
|
# EARL: The Promise of RL for Autoregressive Image Editing |
|
|
|
|
|
Official model for the paper [The Promise of RL for Autoregressive Image Editing](https://huggingface.co/papers/2508.01119). |
|
|
|
|
|
[Paper](https://arxiv.org/abs/2508.01119)
[Code](https://github.com/saba96/EARL)
[Model checkpoint](https://huggingface.co/Image-editing/imged_rl_grpo_sft.s_rl.sc/tree/ckpt_001999)
|
|
|
|
|
 |
|
|
|
|
|
## Abstract |
|
|
We explore three strategies to enhance performance on a wide range of image editing tasks: supervised fine-tuning (SFT), reinforcement learning (RL), and Chain-of-Thought (CoT) reasoning. In order to study all these components in one consistent framework, we adopt an autoregressive multimodal model that processes textual and visual tokens in a unified manner. We find RL combined with a large multi-modal LLM verifier to be the most effective of these strategies. As a result, we release EARL: Editing with Autoregression and RL, a strong RL-based image editing model that performs competitively on a diverse range of edits compared to strong baselines, despite using much less training data. Thus, EARL pushes the frontier of autoregressive multimodal models on image editing. We release our code, training data, and trained models in the [EARL GitHub repository](https://github.com/saba96/EARL).
|
|
|
|
|
## Overview |
|
|
EARL (Editing with Autoregression and RL) introduces a novel approach to image editing using an autoregressive multimodal model. It processes textual and visual tokens in a unified manner and leverages reinforcement learning combined with a large multi-modal LLM verifier to achieve strong performance across various image editing tasks. The model is designed for efficiency, using significantly less training data than comparable baselines, and pushes the frontier of autoregressive multimodal models on image editing. |
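As a rough, illustrative sketch (not EARL's actual training code), the RL stage can be pictured as sampling several candidate edits per instruction, asking a multimodal LLM verifier to score each one, and converting those scores into group-relative advantages in the style of GRPO. In the snippet below, `verifier.score_edit` is a hypothetical placeholder for whatever judging interface is used.

```python
# Illustrative sketch only, not the paper's implementation. `verifier.score_edit`
# is a hypothetical stand-in for a multimodal LLM judge that rates how well an
# edited image follows the instruction (e.g., returning a value in [0, 1]).
import statistics

def verifier_rewards(verifier, source_image, candidate_edits, instruction):
    """Score each candidate edit of `source_image` against `instruction`."""
    return [verifier.score_edit(source_image, edit, instruction) for edit in candidate_edits]

def group_relative_advantages(rewards):
    """GRPO-style advantages: normalize rewards within one prompt's group of samples."""
    mean = statistics.mean(rewards)
    std = statistics.pstdev(rewards) or 1.0  # avoid division by zero when all rewards match
    return [(r - mean) / std for r in rewards]
```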
|
|
|
|
|
## Usage |
|
|
You can quickly try the model using vLLM for inference. |
|
|
|
|
|
First, clone the official repository and install the prerequisites: |
|
|
```bash
git clone https://github.com/saba96/EARL.git
cd EARL
python -m venv /path/to/envs/EARL
. /path/to/envs/EARL/bin/activate
pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124
pip install vllm==0.8.4
pip install flash-attn==2.7.4.post1 --no-build-isolation
pip install -r requirements.txt
export PYTHONPATH=$(pwd)
```
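Before going further, a quick sanity check inside the activated environment can confirm that the pinned versions installed correctly:

```python
# Optional sanity check: confirm the pinned packages import and report the
# versions pinned above (torch 2.6.0 with CUDA 12.4, vLLM 0.8.4).
import torch
import vllm
import flash_attn  # noqa: F401  # importing alone verifies the wheel built correctly

print("torch:", torch.__version__)            # expect 2.6.0+cu124
print("vllm:", vllm.__version__)              # expect 0.8.4
print("CUDA available:", torch.cuda.is_available())
```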
|
|
|
|
|
**Patch vLLM to support Emu3**: |
|
|
This is a critical step: you need to edit the `registry.py` file in your vLLM installation (adjust the path below to match your virtual environment and Python version).
|
|
```bash
vim /path/to/venv/lib/python3.10/site-packages/vllm/model_executor/models/registry.py
```
|
|
Add the following line to the `_MULTIMODAL_MODELS` dictionary around line 166: |
|
|
```python
_MULTIMODAL_MODELS = {
    # add this line
    "Emu3ForCausalLM": ("llama", "LlamaForCausalLM"),
    # end of adding
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),  # already exists
    # ... other models
}
```
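If you prefer not to edit files inside `site-packages`, vLLM also supports registering out-of-tree models at runtime via `ModelRegistry`. A minimal sketch, assuming the repository's Emu3 implementation is importable from `emu3.model.modeling_emu3_vllm` (the path referenced in the inference script), would be:

```python
# Untested alternative to patching registry.py: register the model class with
# vLLM at runtime before constructing the LLM. Assumes the EARL repo is on
# PYTHONPATH so that emu3.model.modeling_emu3_vllm is importable.
from vllm import ModelRegistry
from emu3.model.modeling_emu3_vllm import Emu3ForCausalLM

ModelRegistry.register_model("Emu3ForCausalLM", Emu3ForCausalLM)
```

Either approach should make `Emu3ForCausalLM` resolvable when vLLM loads the checkpoint.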
|
|
|
|
|
Then, run inference using the following Python code snippet. Ensure you have an image file ready (e.g., `./examples/images/web_dfacd48d-d2c2-492f-b94c-41e6a34ea99f.png` from the original repository). |
|
|
|
|
|
```python
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoTokenizer
from vllm import LLM, ModelRegistry, SamplingParams

# Ensure Emu3ForCausalLM is available or registered.
# If you cloned the repo, it should be importable from emu3.model.modeling_emu3_vllm.
# If you patched registry.py as shown above, no extra registration is needed here;
# otherwise, register the class with vLLM's ModelRegistry, e.g.:
#   from emu3.model.modeling_emu3_vllm import Emu3ForCausalLM
#   ModelRegistry.register_model("Emu3ForCausalLM", Emu3ForCausalLM)

# --- Helper functions from the original repo for image preprocessing ---
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    """Convert to RGB, resize to a square tile, and normalize with ImageNet statistics."""
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tiling grid whose aspect ratio is closest to the input image's."""
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Split the image into up to `max_num` square tiles, optionally adding a thumbnail."""
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image_file, input_size=448, max_num=12):
    """Load an image file and return a stacked tensor of normalized tiles."""
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
# -------------------------------------------------------------------

# Load the model with vLLM
path = 'Image-editing/imged_rl_grpo_sft.s_rl.sc'  # Model ID on the Hugging Face Hub
llm = LLM(
    model=path,
    trust_remote_code=True,
    dtype="auto",  # or "bfloat16" if supported by your hardware
    gpu_memory_utilization=0.9,
    # Additional vLLM-specific arguments if needed
)
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

# Prepare inputs
image_path = './examples/images/web_dfacd48d-d2c2-492f-b94c-41e6a34ea99f.png'  # Replace with a path to your image
# `load_image` prepares normalized pixel values and moves them to the GPU.
pixel_values = load_image(image_path, max_num=6).to(torch.bfloat16).cuda()

# Format the prompt
question = "Edit the image: change the color of the car to red."
prompt = (
    "A chat between a curious user and an AI assistant.\n"
    "USER: <image>\n"
    f"{question} ASSISTANT:"
)

sampling_params = SamplingParams(max_tokens=512, temperature=0.7)  # Adjust as needed

# In vLLM, the image input of a multimodal model may be handled internally or
# require explicit passing, depending on the model's vLLM integration; the
# `llm.generate` call below only takes a list of string prompts.
# For the full multimodal inference pipeline, refer to the original EARL script:
# https://github.com/saba96/EARL/blob/main/emu3/train_image_editing/vllm_inference.py
#
# This example illustrates the textual part of inference with vLLM, assuming the
# model's vLLM integration handles the image input; a complete end-to-end
# multimodal call may look slightly different.
outputs = llm.generate([prompt], sampling_params)  # Pass prompts as a list

response = outputs[0].outputs[0].text
print(f'User: {question}\nAssistant: {response}')
```
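Since vLLM batches requests, several edit instructions can also be submitted in a single `generate` call. The sketch below simply reuses the prompt template from the snippet above for a list of questions:

```python
# Batched generation sketch: vLLM returns one RequestOutput per prompt, so
# multiple edit instructions can be served in one call.
questions = [
    "Edit the image: change the color of the car to red.",
    "Edit the image: make the sky look like sunset.",
]
prompts = [
    "A chat between a curious user and an AI assistant.\n"
    "USER: <image>\n"
    f"{q} ASSISTANT:"
    for q in questions
]
outputs = llm.generate(prompts, sampling_params)
for q, out in zip(questions, outputs):
    print(f"User: {q}\nAssistant: {out.outputs[0].text}")
```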
|
|
|
|
|
## Citation |
|
|
If you find our work helpful or inspiring, please feel free to cite it. |
|
|
```bibtex
@article{saba2025earl,
  title={The Promise of RL for Autoregressive Image Editing},
  author={Saba, Daniel and Tang, Sifei and Huang, Yifan and Liu, Meng and Ma, Jinxin and Liu, Zhian and Fu, Ruifeng and Zhu, Lei and Han, Jun and Zhang, Shang-Wen and Liu, Jing},
  journal={arXiv preprint arXiv:2508.01119},
  year={2025}
}
```