File size: 1,743 Bytes
0dbb851
b48aabb
 
5427776
ab46247
b4bd903
0dbb851
5427776
0dbb851
1ce0db8
2c4c3a2
0dbb851
9c74605
b48aabb
1ce0db8
 
9c74605
475810c
a582418
2930e2d
9c74605
ab46247
9c74605
 
03db192
9c74605
 
 
 
 
 
 
1ce0db8
b48aabb
5427776
475810c
a582418
 
5427776
 
5010bfa
b48aabb
 
 
 
 
 
 
1ce0db8
6b3e10e
1ce0db8
b48aabb
 
3d21a14
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import gradio as gr
import spaces
import torch
import random
from diffusers import DiffusionPipeline, AutoPipelineForText2Image, FluxPipeline
import os

# Target device for the pipelines: prefer CUDA when torch can see it, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Default checkpoint id; can be overridden through the MODEL_NAME environment variable.
image_model = os.environ.get("MODEL_NAME", "Heartsync/NSFW-Uncensored")
# Optional Hugging Face access token; None when HF_TOKEN is not set.
hf_token = os.environ.get("HF_TOKEN")


_FLUX_LORA_REPO = 'enhanceaiteam/Flux-uncensored-v2'
_DEFAULT_NEGATIVE_PROMPT = "ugly, deformed, disfigured, poor quality, low resolution"


def _load_pipeline(model):
    """Build the diffusers pipeline for ``model`` and place it on ``device``.

    The Flux-uncensored entry is LoRA weights, so it is loaded on top of the
    FLUX.1-dev base model. Any other id is loaded directly as a
    ``DiffusionPipeline`` in float16. ``hf_token`` is forwarded to every
    ``from_pretrained`` call so gated/private repos work in both branches.
    """
    if model == _FLUX_LORA_REPO:
        # Load the base model, then layer the uncensored LoRA weights on top.
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
            token=hf_token,
        ).to(device)
        pipe.load_lora_weights(_FLUX_LORA_REPO, weight_name='lora.safetensors')
        return pipe

    pipe = DiffusionPipeline.from_pretrained(
        model,
        torch_dtype=torch.float16,
        token=hf_token,
    )
    return pipe.to(device)


@spaces.GPU(duration=120)
def generate(prompt, negative_prompt, model=image_model):
    """Generate a 1024x1024 image for ``prompt`` with the selected model.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt. When empty or None, a default
            "low quality" negative prompt is used.
        model: Hugging Face model id. A falsy value (e.g. the dropdown left
            unselected) falls back to the configured default ``image_model``.

    Returns:
        The ``.images`` attribute of the pipeline output (a list of images
        under diffusers' default output type).
    """
    print("Generating image...")
    # The UI dropdown has no default value, so it may hand us None/"".
    if not model:
        model = image_model
    if not negative_prompt:
        negative_prompt = _DEFAULT_NEGATIVE_PROMPT

    pipe = _load_pipeline(model)
    # Report the model actually loaded, not the module-level default.
    print(f"Using model: {model}")
    return pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=8.0,
        num_inference_steps=50,
        width=1024,
        height=1024,
    ).images


# Model ids offered in the selector; the chosen id becomes generate()'s `model` argument.
_MODEL_CHOICES = [
    "Heartsync/NSFW-Uncensored",
    "UnfilteredAI/NSFW-gen-v2",
    'enhanceaiteam/Flux-uncensored-v2',
]

# Inputs map positionally onto generate(prompt, negative_prompt, model).
_prompt_box = gr.Text(label="Prompt")
_negative_box = gr.Text(value="", label="Negative Prompt")
_model_picker = gr.Dropdown(
    choices=_MODEL_CHOICES,
    label="Image model",
    info="Select the image model:",
)

demo = gr.Interface(
    fn=generate,
    inputs=[_prompt_box, _negative_box, _model_picker],
    outputs=gr.Gallery(),
)
demo.launch(show_api=True)