import math
import io
import os
import numpy as np
import cv2
import torch
import torch.nn.functional as F
from diffusers import FluxKontextInpaintPipeline
#from numba import njit
from tempfile import NamedTemporaryFile
from dotenv import load_dotenv
from omegaconf import OmegaConf
from PIL import Image, ImageFilter
from huggingface_hub import hf_hub_download
from depth_anything_v2.dpt import DepthAnythingV2
from ultralytics import YOLO
from saicinpainting.training.trainers import load_checkpoint
from saicinpainting.evaluation.utils import move_to_device
from saicinpainting.evaluation.data import pad_tensor_to_modulo
from optimization import optimize_pipeline_

load_dotenv(verbose=False)

# Depth estimator (Depth Anything V2, ViT-L variant).
#DEPTH_ANYTHING = DepthAnythingV2(**{'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]})
#DEPTH_ANYTHING.load_state_dict(torch.load(hf_hub_download(repo_id='depth-anything/Depth-Anything-V2-Base', filename='depth_anything_v2_vitb.pth', repo_type='model', token=os.environ['HF_TOKEN']), map_location='cpu'))
DEPTH_ANYTHING = DepthAnythingV2(**{'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]})
DEPTH_ANYTHING.load_state_dict(torch.load(hf_hub_download(repo_id='depth-anything/Depth-Anything-V2-Large', filename='depth_anything_v2_vitl.pth', repo_type='model', token=os.environ['HF_TOKEN']), map_location='cpu'))
DEPTH_ANYTHING = DEPTH_ANYTHING.to('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu').eval()

# YOLO detectors used to decide which foreground components are safe to inpaint over.
HAND_YOLO = YOLO(hf_hub_download('Bingsu/adetailer', 'hand_yolov8n.pt', token=os.environ['HF_TOKEN']))
PERSON_YOLO = YOLO(hf_hub_download('Bingsu/adetailer', 'person_yolov8n-seg.pt', token=os.environ['HF_TOKEN']))
FACE_YOLO = YOLO(hf_hub_download('Bingsu/adetailer', 'face_yolov9c.pt', token=os.environ['HF_TOKEN']))

# LaMa inpainting model (big-lama checkpoint, inference only).
LAMA_TRAIN_CFG = OmegaConf.load('big-lama/config.yaml')
LAMA_TRAIN_CFG['training_model']['predict_only'] = True
LAMA = load_checkpoint(LAMA_TRAIN_CFG, 'big-lama/models/best.ckpt', strict=False, map_location='cpu')
LAMA = LAMA.to('cuda' if torch.cuda.is_available() else 'cpu').eval()

# FLUX.1 Kontext inpainting pipeline with the Turbo-Alpha LoRA fused in for few-step inference.
FLUX_KONTEXT_INPAINT_PIPELINE = FluxKontextInpaintPipeline.from_pretrained('black-forest-labs/FLUX.1-Kontext-dev', torch_dtype=torch.bfloat16, token=os.environ['HF_TOKEN']).to('cuda' if torch.cuda.is_available() else 'cpu')
FLUX_KONTEXT_INPAINT_PIPELINE.load_lora_weights('alimama-creative/FLUX.1-Turbo-Alpha')
FLUX_KONTEXT_INPAINT_PIPELINE.fuse_lora()
optimize_pipeline_(FLUX_KONTEXT_INPAINT_PIPELINE, image=Image.new('RGB', (1024, 1024)), mask_image=Image.new('L', (512, 512)), prompt='prompt')


def resize_image(image, maximum=2048, resample=Image.Resampling.LANCZOS):
    '''Downscale so the longer side is at most `maximum`, preserving aspect ratio.'''
    width, height = image.size
    if width < height:
        if maximum < height:
            scale = maximum / height
        else:
            return image
    elif maximum < width:
        scale = maximum / width
    else:
        return image
    return image.resize((round(width * scale), round(height * scale)), resample=resample)


def kmeans_pp(X, n_clusters, n_init=1, max_iter=300, tol=1e-4, random_state=None):
    '''Plain NumPy k-means with k-means++ initialization; returns (labels, centers).'''
    X = np.asarray(X, dtype=np.float32)
    N, D = X.shape
    n_clusters = min(n_clusters, N)
    rng = np.random.default_rng(random_state)

    def init_plus_plus():
        # k-means++: sample each new center with probability proportional to the
        # squared distance to the nearest center chosen so far.
        centers = np.empty((n_clusters, D), dtype=np.float32)
        idx0 = rng.integers(N)
        centers[0] = X[idx0]
        d2 = np.sum((X - centers[0])**2, axis=1)
        for c in range(1, n_clusters):
            s = d2.sum()
            if not np.isfinite(s) or s <= 0:
                idx = rng.integers(N)
            else:
                r = rng.random() * s
                idx = np.searchsorted(np.cumsum(d2), r)
                if idx >= N:
                    idx = N - 1
            centers[c] = X[idx]
            d2 = np.minimum(d2, np.sum((X - centers[c])**2, axis=1))
        return centers

    best_inertia = np.inf
    best_labels = None
    best_centers = None
    for _ in range(n_init):
        centers = init_plus_plus()
        labels = np.full(N, -1, dtype=np.int32)
        for _it in range(max_iter):
            # Assignment step.
            dmin = np.full(N, np.inf, dtype=np.float32)
            for c in range(n_clusters):
                d = np.sum((X - centers[c])**2, axis=1)
                better = d < dmin
                labels[better] = c
                dmin[better] = d[better]
            # Update step; empty clusters are re-seeded at the farthest point.
            new_centers = centers.copy()
            empty = []
            for c in range(n_clusters):
                pts = X[labels == c]
                if pts.size == 0:
                    empty.append(c)
                else:
                    new_centers[c] = pts.mean(axis=0).astype(np.float32)
            if empty:
                far_idx = np.argmax(dmin)
                for c in empty:
                    new_centers[c] = X[far_idx]
            shift = np.sqrt(((centers - new_centers)**2).sum(axis=1)).max()
            centers = new_centers
            if shift <= tol:
                break
        # Final assignment with the converged centers.
        dmin = np.full(N, np.inf, dtype=np.float32)
        for c in range(n_clusters):
            d = np.sum((X - centers[c])**2, axis=1)
            better = d < dmin
            labels[better] = c
            dmin[better] = d[better]
        inertia = float(dmin.sum())
        if inertia < best_inertia:
            best_inertia = inertia
            best_labels = labels.copy()
            best_centers = centers.copy()
    return best_labels, best_centers
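

# Illustrative sketch (not called by the pipeline): how `kmeans_pp` can quantize a
# depth map into a handful of discrete levels, mirroring the 'k-means' strategy in
# `generate_parallax_images` below. The synthetic `depth` array is made up for the
# example.
def _example_kmeans_depth_quantization():
    depth = np.random.default_rng(0).random((64, 64)).astype(np.float32)
    x = depth.reshape(-1, 1)
    labels, centers = kmeans_pp(x, n_clusters=5, random_state=0)
    # Replace each pixel by its cluster center to get a piecewise-constant depth map.
    quantized = centers.reshape(-1)[labels].reshape(depth.shape)
    return quantized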


#@njit(cache=True)
def _cc8_core(mask_bool):
    # Stack-based flood fill over the 8-neighborhood; kept for optional numba use.
    H, W = mask_bool.shape
    labels = np.zeros((H, W), dtype=np.int32)
    nby = np.array([-1, -1, -1, 0, 0, 1, 1, 1], dtype=np.int32)
    nbx = np.array([-1, 0, 1, -1, 1, -1, 0, 1], dtype=np.int32)
    stack_y = np.empty(H * W, dtype=np.int32)
    stack_x = np.empty(H * W, dtype=np.int32)
    maxc = H * W
    minx = np.empty(maxc + 1, dtype=np.int32)
    miny = np.empty(maxc + 1, dtype=np.int32)
    maxx = np.empty(maxc + 1, dtype=np.int32)
    maxy = np.empty(maxc + 1, dtype=np.int32)
    comp_id = 0
    for y0 in range(H):
        for x0 in range(W):
            if mask_bool[y0, x0] and labels[y0, x0] == 0:
                comp_id += 1
                minx[comp_id] = x0
                maxx[comp_id] = x0
                miny[comp_id] = y0
                maxy[comp_id] = y0
                sp = 0
                stack_y[sp] = y0
                stack_x[sp] = x0
                sp += 1
                labels[y0, x0] = comp_id
                while sp > 0:
                    sp -= 1
                    y = stack_y[sp]
                    x = stack_x[sp]
                    if x < minx[comp_id]:
                        minx[comp_id] = x
                    if x > maxx[comp_id]:
                        maxx[comp_id] = x
                    if y < miny[comp_id]:
                        miny[comp_id] = y
                    if y > maxy[comp_id]:
                        maxy[comp_id] = y
                    for k in range(8):
                        ny = y + nby[k]
                        nx = x + nbx[k]
                        if 0 <= ny < H and 0 <= nx < W:
                            if mask_bool[ny, nx] and labels[ny, nx] == 0:
                                labels[ny, nx] = comp_id
                                stack_y[sp] = ny
                                stack_x[sp] = nx
                                sp += 1
    return labels, comp_id, minx, miny, maxx, maxy


def connected_components_8(mask: np.ndarray):
    # Earlier pure-Python and numba-backed implementations, kept for reference:
    '''
    H, W = mask.shape
    labels = np.zeros((H, W), dtype=np.int32)
    seen = np.zeros((H, W), dtype=bool)
    nbrs = [(-1,-1),(-1,0),(-1,1),
            ( 0,-1),        ( 0,1),
            ( 1,-1),( 1,0),( 1,1)]
    comp_id = 0
    bboxes = []
    ys, xs = np.where(mask)
    for y0, x0 in zip(ys, xs):
        if seen[y0, x0]:
            continue
        comp_id += 1
        stack = [(y0, x0)]
        seen[y0, x0] = True
        labels[y0, x0] = comp_id
        minx = maxx = x0
        miny = maxy = y0
        while stack:
            y, x = stack.pop()
            if x < minx: minx = x
            if x > maxx: maxx = x
            if y < miny: miny = y
            if y > maxy: maxy = y
            for dy, dx in nbrs:
                ny, nx = y + dy, x + dx
                if 0 <= ny < H and 0 <= nx < W:
                    if mask[ny, nx] and not seen[ny, nx]:
                        seen[ny, nx] = True
                        labels[ny, nx] = comp_id
                        stack.append((ny, nx))
        bboxes.append((minx, miny, maxx, maxy))
    return labels, bboxes
    '''
    '''
    m = np.ascontiguousarray(mask.astype(bool, copy=False))
    labels, comp_id, minx, miny, maxx, maxy = _cc8_core(m)
    bboxes = [(int(minx[i]), int(miny[i]), int(maxx[i]), int(maxy[i])) for i in range(1, comp_id + 1)]
    return labels.astype(np.int32, copy=False), bboxes
    '''
    # Current implementation: OpenCV connected components, 8-connectivity, with
    # inclusive (x1, y1, x2, y2) bounding boxes.
    m = (mask != 0).astype(np.uint8, copy=False)
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(m, connectivity=8)
    bboxes = []
    for i in range(1, num_labels):
        x, y, w, h, _ = stats[i]
        bboxes.append((int(x), int(y), int(x + w - 1), int(y + h - 1)))
    return labels.astype(np.int32, copy=False), bboxes
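

# Illustrative sketch (not called by the pipeline): `connected_components_8` on a toy
# mask. The two diagonal pixels touch at a corner, so 8-connectivity merges them into
# one component; the returned bbox coordinates are inclusive.
def _example_connected_components():
    mask = np.zeros((4, 4), dtype=bool)
    mask[0, 0] = True
    mask[1, 1] = True   # diagonal neighbor of (0, 0), same component
    mask[3, 3] = True   # isolated pixel, second component
    labels, bboxes = connected_components_8(mask)
    # Expected: bboxes == [(0, 0, 1, 1), (3, 3, 3, 3)]
    return labels, bboxes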


def bbox_contained(inner, outer):
    '''True if bbox `inner` lies fully inside bbox `outer`.'''
    fx1, fy1, fx2, fy2 = inner
    mx1, my1, mx2, my2 = outer
    return (fx1 >= mx1) and (fy1 >= my1) and (fx2 <= mx2) and (fy2 <= my2)


def expand_bbox(b, H, W, pad=1):
    '''Grow bbox `b` by `pad` pixels on each side, clamped to the image bounds.'''
    x1, y1, x2, y2 = b
    return (max(0, x1 - pad), max(0, y1 - pad), min(W - 1, x2 + pad), min(H - 1, y2 + pad))


def overlap_ratio(a, b):
    '''Intersection area of `a` and `b` divided by the area of `b` (not IoU).'''
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    if ix1 >= ix2 or iy1 >= iy2:
        return 0.0
    inter = (ix2 - ix1) * (iy2 - iy1)
    area = (b[2] - b[0]) * (b[3] - b[1])
    return inter / area


def lama_inpaint(model, image, mask, modulo):
    '''Run LaMa on a PIL image/mask pair; pads both to a multiple of `modulo` and crops back.'''
    img_t = torch.from_numpy(np.array(image)).permute(2, 0, 1).unsqueeze(0) / 255.
    mask_t = (torch.from_numpy(np.array(mask)) > 0).float().unsqueeze(0).unsqueeze(0)
    orig_h, orig_w = img_t.shape[-2:]
    img_t = pad_tensor_to_modulo(img_t, modulo)
    h, w = mask_t.shape[-2:]
    pad_h = (modulo - h % modulo) % modulo
    pad_w = (modulo - w % modulo) % modulo
    mask_t = F.pad(mask_t, (0, pad_w, 0, pad_h), mode='constant', value=0)
    batch = {'image': img_t, 'mask': mask_t}
    batch = move_to_device(batch, model.device)
    with torch.no_grad():
        result = model(batch)['inpainted'][0].permute(1, 2, 0).detach().cpu().numpy()
    result = result[:orig_h, :orig_w, ...]
    result = (result.clip(0, 1) * 255).astype('uint8')
    return Image.fromarray(result)


def feather(image: Image.Image, gauss_radius=1, band_px=1, strength=1.0) -> Image.Image:
    '''Soften the alpha edge of an RGBA image.

    Blurs in premultiplied-alpha space (so transparent pixels do not bleed color) and
    applies the result only within a thin band around the alpha boundary, found as the
    difference between a dilated and an eroded alpha channel.
    '''
    A_pil = image.getchannel('A')
    k = 2 * int(band_px) + 1  # MaxFilter/MinFilter require an odd kernel size
    a_dil = A_pil.filter(ImageFilter.MaxFilter(k))
    a_ero = A_pil.filter(ImageFilter.MinFilter(k))
    band = np.asarray(a_dil, dtype=np.uint8) != np.asarray(a_ero, dtype=np.uint8)
    arr = np.asarray(image, dtype=np.float32) / 255.0
    A = arr[..., 3:4]
    rgb_pm = arr[..., :3] * A  # premultiply RGB by alpha before blurring
    pm_rgba_u8 = np.empty(arr.shape, dtype=np.uint8)
    pm_rgba_u8[..., :3] = np.clip(rgb_pm * 255.0, 0, 255).astype(np.uint8)
    pm_rgba_u8[..., 3] = (arr[..., 3] * 255.0 + 0.5).astype(np.uint8)
    blurred = Image.fromarray(pm_rgba_u8, 'RGBA').filter(ImageFilter.GaussianBlur(gauss_radius))
    blurred_f = np.asarray(blurred, dtype=np.float32) / 255.0
    rgb_pm_blur = blurred_f[..., :3]
    A_blur = blurred_f[..., 3:4]
    s = float(np.clip(strength, 0.0, 1.0))
    if s < 1.0:
        A_blur = (1.0 - s) * A + s * A_blur
    eps = 1e-6
    rgb_norm = rgb_pm_blur / np.maximum(A_blur, eps)  # un-premultiply
    band3 = band[..., None]
    out_rgb = np.where(band3, rgb_norm, arr[..., :3])
    out_A = np.where(band3, A_blur, A)
    out = np.concatenate([out_rgb, out_A], axis=-1)
    out = (np.clip(out, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
    return Image.fromarray(out, 'RGBA')


def convert_webp(image: Image.Image) -> str:
    '''Save the image as lossless WebP in a temporary file and return the file path.'''
    with io.BytesIO() as buffer:
        image.save(buffer, format='WEBP', lossless=True, method=6)
        buffer.seek(0)
        with NamedTemporaryFile(delete=False, suffix='.webp') as file:
            file.write(buffer.read())
            file.flush()
            return file.name
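

# Illustrative sketch (not called by the pipeline): how the callers below use
# `overlap_ratio`. They pass the smaller box as `b`, so the ratio measures how much of
# the smaller box is covered by the larger one; 1.0 means fully covered. The boxes
# here are made up for the example.
def _example_overlap_ratio():
    big = (0, 0, 100, 100)
    small = (10, 10, 20, 20)            # fully inside `big`
    assert overlap_ratio(big, small) == 1.0
    shifted = (95, 95, 105, 105)        # only a corner overlaps `big`
    return overlap_ratio(big, shifted)  # 0.25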


def generate_animation_images(image, prompts=['make the eyes half-closed evenly; keep identity and all other details unchanged', 'make the eyes closed; keep identity and all other details unchanged'], target_size=1024, frame=None):
    '''Edit the detected face region once per prompt (or only for `frame`) and return the frames as WebP file paths.'''
    rgba = np.asarray(image)
    # Detect faces on the image composited over black (the detector expects RGB).
    face_results = FACE_YOLO.predict(source=Image.fromarray(np.asarray(Image.alpha_composite(Image.new('RGBA', image.size, (0, 0, 0)), image))[:, :, :3]), conf=0.5, iou=0.45, verbose=False, device='0' if torch.cuda.is_available() else 'cpu')
    edited_images = []
    if len(face_results) > 0 and face_results[0].boxes is not None and len(face_results[0].boxes) > 0:
        min_x = np.iinfo(np.int32).max
        min_y = np.iinfo(np.int32).max
        max_x = 0
        max_y = 0
        boxes = []
        filtered_prompts = []
        edited_outputs = []
        if frame is None:
            for prompt in prompts:
                filtered_prompts.append(prompt)
                edited_outputs.append(rgba.copy())
        else:
            for index, prompt in enumerate(prompts):
                if frame == index:
                    filtered_prompts.append(prompt)
                    edited_outputs.append(rgba.copy())
                    break
        # Only the first face box is used.
        for xyxy in np.round(face_results[0].boxes.xyxy.detach().cpu().numpy()).astype(np.int32):
            min_x = min(min_x, xyxy[0])
            min_y = min(min_y, xyxy[1])
            max_x = max(max_x, xyxy[2])
            max_y = max(max_y, xyxy[3])
            boxes.append(xyxy)
            break
        w = max_x - min_x
        h = max_y - min_y
        if w <= target_size and h <= target_size:
            # Place the face crop in the top-left corner of a black target_size canvas.
            crop = rgba[min_y:max_y, min_x:max_x].copy()
            padded = np.zeros((target_size, target_size, 3), dtype=crop.dtype)
            padded[:h, :w, :] = crop[:, :, :3]
            mask = np.zeros((target_size, target_size), dtype=np.uint8)
            for x1, y1, x2, y2 in boxes:
                mask[y1 - min_y:y2 - min_y, x1 - min_x:x2 - min_x] = 255
            for index, prompt in enumerate(filtered_prompts):
                #edited_rgb = np.asarray(FLUX_KONTEXT_INPAINT_PIPELINE(image=Image.fromarray(padded, 'RGB'), mask_image=Image.fromarray(mask, 'L'), prompt=prompt, guidance_scale=2.5, num_inference_steps=28).images[0])
                #edited_rgb = np.asarray(FLUX_KONTEXT_INPAINT_PIPELINE(image=Image.fromarray(padded, 'RGB'), mask_image=Image.fromarray(mask, 'L'), prompt=prompt, guidance_scale=3.5, num_inference_steps=8).images[0]) # alimama-creative/FLUX.1-Turbo-Alpha LoRA
                edited_rgb = np.asarray(FLUX_KONTEXT_INPAINT_PIPELINE(image=Image.fromarray(padded, 'RGB'), mask_image=Image.fromarray(mask, 'L'), prompt=prompt, guidance_scale=2.5, num_inference_steps=8).images[0]) # alimama-creative/FLUX.1-Turbo-Alpha LoRA
                # Copy the edited face pixels back into the full-size frame.
                for x1, y1, x2, y2 in boxes:
                    edited_outputs[index][y1:y2, x1:x2, :3] = edited_rgb[y1 - min_y:y2 - min_y, x1 - min_x:x2 - min_x, :]
        for edited_rgba in edited_outputs:
            edited_images.append(convert_webp(Image.fromarray(edited_rgba, 'RGBA')))
    return edited_images
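

# Illustrative sketch (not called by the pipeline): generating blink frames for one
# image. 'input.png' is a hypothetical RGBA file path made up for the example; the
# returned values are paths to temporary WebP files.
def _example_generate_animation():
    image = Image.open('input.png').convert('RGBA')  # hypothetical input
    frames = generate_animation_images(image)        # one frame per default prompt
    half_closed_only = generate_animation_images(image, frame=0)  # only the first prompt
    return frames, half_closed_only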


def generate_parallax_images(image, n_layers=5, maximum=2048, strategy=None):
    '''Split an RGBA image into `n_layers` depth layers (WebP file paths, back to front); empty layers are None.'''
    resized_image = resize_image(image, maximum)
    resized_rgba = np.asarray(resized_image)
    alpha = resized_rgba[:, :, 3]
    alpha_mask = alpha < 255
    rgb_image = Image.alpha_composite(Image.new('RGBA', resized_image.size, (0, 0, 0)), resized_image).convert('RGB')
    width, height = rgb_image.size
    rgb = np.asarray(rgb_image)
    # Depth Anything V2 expects BGR input.
    depth = DEPTH_ANYTHING.infer_image(rgb[:, :, ::-1])
    if strategy == 'k-means':
        # Cluster depth values and snap each pixel to its cluster center.
        n_clusters = n_layers
        x = depth.reshape(-1, 1)
        mask = np.isfinite(x[:, 0])
        labels, centers = kmeans_pp(x[mask].astype(np.float32), n_clusters=n_clusters, n_init=1, max_iter=100, tol=1e-4, random_state=None)
        centers = centers.reshape(-1)
        order = np.argsort(centers)
        rank_of_label = np.empty_like(order)
        rank_of_label[order] = np.arange(n_clusters)
        labels_full = np.full(x.shape[0], -1, dtype=int)
        labels_full[mask] = labels
        levels = centers[order].astype(np.float64)
        quantized_depth = np.zeros(x.shape[0], dtype=np.float32)
        valid_idx = np.where(mask)[0]
        quantized_depth[valid_idx] = levels[rank_of_label[labels_full[valid_idx]]]
        quantized_depth = quantized_depth.reshape(height, width)
        depth = quantized_depth.astype(np.float64)
        depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
        edges = (levels - levels.min()) / (levels.max() - levels.min() + 1e-8)
    else:
        # Uniform binning into n_layers levels spaced 1/(n_layers - 1) apart.
        bins = np.linspace(0, np.max(depth), n_layers + 1)
        # np.digitize assigns the exact maximum to an extra bin; clip it back so the
        # quantized depth stays in [0, 1].
        quantized = np.clip(np.digitize(depth, bins) - 1, 0, n_layers - 1)
        depth = quantized * (1 / (n_layers - 1))
        edges = np.arange(n_layers) * (1 / (n_layers - 1))
    depth_mod = np.zeros_like(depth, dtype=np.float64)
    front_mask = depth >= edges[-1]
    front_labels, front_bboxes = connected_components_8(front_mask)
    _, near_bboxes = connected_components_8(depth >= edges[1])
    inpaint_mask = np.zeros_like(front_mask, dtype=bool)
    person_results = PERSON_YOLO.predict(source=rgb_image, conf=0.5, iou=0.45, verbose=False, device='0' if torch.cuda.is_available() else 'cpu')
    hand_results = HAND_YOLO.predict(source=rgb_image, conf=0.5, iou=0.45, verbose=False, device='0' if torch.cuda.is_available() else 'cpu')
    person_boxes = []
    hand_boxes = []
    if len(person_results) > 0 and person_results[0].boxes is not None and len(person_results[0].boxes) > 0:
        for xyxy in person_results[0].boxes.xyxy.detach().cpu().numpy():
            person_boxes.append(xyxy)
    if len(hand_results) > 0 and hand_results[0].boxes is not None and len(hand_results[0].boxes) > 0:
        for xyxy in hand_results[0].boxes.xyxy.detach().cpu().numpy():
            hand_boxes.append(xyxy)
    if len(front_bboxes) > 0:
        # A frontmost component may be inpainted over only if it sits inside a larger
        # near-depth region and is either a detected hand or not part of a detected person.
        need_inpaint = True
        inpaintable_indexes = []
        for i, fb in enumerate(front_bboxes, start=1):
            contained = any(bbox_contained(fb, mb) for mb in near_bboxes)
            inpaintable = False
            if contained:
                fx1, fy1, fx2, fy2 = fb
                fb_exclusive = np.array([fx1, fy1, fx2 + 1, fy2 + 1], dtype=np.int32)
                detected_hand = False
                for xyxy in hand_boxes:
                    # Pass the smaller box as `b` so overlap_ratio measures its coverage.
                    area_a = (xyxy[2] - xyxy[0]) * (xyxy[3] - xyxy[1])
                    area_b = (fb_exclusive[2] - fb_exclusive[0]) * (fb_exclusive[3] - fb_exclusive[1])
                    if area_a > area_b:
                        a = xyxy
                        b = fb_exclusive
                    else:
                        a = fb_exclusive
                        b = xyxy
                    if overlap_ratio(a, b) >= 0.75:
                        detected_hand = True
                        break
                if detected_hand:
                    inpaintable = True
                else:
                    detected_person = False
                    for xyxy in person_boxes:
                        area_a = (xyxy[2] - xyxy[0]) * (xyxy[3] - xyxy[1])
                        area_b = (fb_exclusive[2] - fb_exclusive[0]) * (fb_exclusive[3] - fb_exclusive[1])
                        if area_a > area_b:
                            a = xyxy
                            b = fb_exclusive
                        else:
                            a = fb_exclusive
                            b = xyxy
                        if overlap_ratio(a, b) >= 0.75:
                            detected_person = True
                            break
                    if not detected_person:
                        inpaintable = True
            inpaintable_indexes.append(inpaintable)
        if all(inpaintable_indexes):
            need_inpaint = True
            for i, fb in enumerate(front_bboxes, start=1):
                inpaint_mask |= (front_labels == i)
        else:
            need_inpaint = False
    else:
        need_inpaint = False
    # Flatten each mid-depth component to its median depth; the background band below
    # edges[1] keeps its original values.
    if need_inpaint:
        hi_labels, hi_bboxes = connected_components_8((depth >= edges[1]) & (depth < edges[-1]))
        for cid in range(1, hi_labels.max() + 1):
            comp = (hi_labels == cid)
            median = np.median(depth[comp])
            depth_mod[comp] = median
        keep_mask = (depth < edges[1])
        depth_mod[keep_mask] = depth[keep_mask]
        depth_mod[depth >= edges[-1]] = edges[-1]
    else:
        hi_labels, hi_bboxes = connected_components_8(depth >= edges[1])
        for cid in range(1, hi_labels.max() + 1):
            comp = (hi_labels == cid)
            median = np.median(depth[comp])
            depth_mod[comp] = median
        keep_mask = (depth < edges[1])
        depth_mod[keep_mask] = depth[keep_mask]
    depth = depth_mod
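    # Build the layer stack back to front (i = n_layers-1 .. 0). Each level's pixels go
    # onto a transparent canvas that is edge-feathered before export. When the frontmost
    # components were flagged inpaintable, the first layer behind them additionally
    # receives a LaMa fill over those components, so removing the foreground leaves no
    # hole. The bottom layer inpaints everything it does not cover itself, giving the
    # parallax stack a complete background.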
    layers = []
    for i in reversed(range(n_layers)):
        if i > 0:
            if i < n_layers - 1:
                mask = (depth >= edges[i]) & (depth < edges[i + 1])
                if rgb[mask].size > 0:
                    if need_inpaint:
                        # One-shot: fill the removed foreground components on this layer.
                        need_inpaint = False
                        hole_mask = Image.fromarray((inpaint_mask * 255).astype(np.uint8), mode='L').filter(ImageFilter.BoxBlur(16))
                        inpaint_image = lama_inpaint(LAMA, rgb_image, hole_mask, LAMA_TRAIN_CFG.get('dataset', {}).get('pad_out_to_modulo', 8))
                        if inpaint_image.size != (width, height):
                            inpaint_image = inpaint_image.resize((width, height), Image.Resampling.BICUBIC)
                        inpaint = np.asarray(inpaint_image.convert('RGB'))
                        rgba = np.zeros((height, width, 4), np.uint8)
                        rgba[..., :3][inpaint_mask] = inpaint[..., :3][inpaint_mask]
                        rgba[..., 3][inpaint_mask] = 255
                        rgba[..., :3][mask] = inpaint[..., :3][mask]
                        rgba[..., 3][mask] = 255
                        rgba[..., 3][alpha_mask] = alpha[alpha_mask]
                        layers.insert(0, convert_webp(feather(Image.fromarray(rgba, 'RGBA'))))
                        continue
                else:
                    layers.insert(0, None)
                    continue
            else:
                # Frontmost layer: everything at or above the top edge.
                mask = (depth >= edges[i])
                if rgb[mask].size == 0:
                    layers.insert(0, None)
                    continue
            rgba = np.zeros((height, width, 4), np.uint8)
            rgba[..., :3][mask] = rgb[mask]
            rgba[..., 3][mask] = 255
            rgba[..., 3][alpha_mask] = alpha[alpha_mask]
            layers.insert(0, convert_webp(feather(Image.fromarray(rgba, 'RGBA'))))
        else:
            # Bottom layer: keep the far pixels and inpaint everything else.
            mask = (depth < edges[1])
            if rgb[mask].size > 0:
                rgba = np.zeros((height, width, 4), np.uint8)
                rgba[..., :3][mask] = rgb[mask]
                rgba[..., 3][mask] = 255
                mask_image = Image.fromarray(((rgba[..., 3] == 0) * 255).astype(np.uint8), mode='L').filter(ImageFilter.BoxBlur(16))
                inpaint_image = lama_inpaint(LAMA, rgb_image, mask_image, LAMA_TRAIN_CFG.get('dataset', {}).get('pad_out_to_modulo', 8))
                if inpaint_image.size != (width, height):
                    inpaint_image = inpaint_image.resize((width, height), Image.Resampling.BICUBIC)
                output = np.asarray(inpaint_image.convert('RGBA')).copy()
                output[..., 3][alpha_mask] = alpha[alpha_mask]
                layers.insert(0, convert_webp(Image.fromarray(output)))
            else:
                layers.insert(0, None)
    return layers
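

# Illustrative sketch of driving the two entry points end to end. 'input.png' is a
# hypothetical file path made up for the example; both functions return lists of
# temporary WebP file paths (with None for empty parallax layers).
if __name__ == '__main__':
    _source = Image.open('input.png').convert('RGBA')  # hypothetical input
    for _layer in generate_parallax_images(_source, n_layers=5, strategy='k-means'):
        print(_layer)
    for _frame in generate_animation_images(_source):
        print(_frame)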