Masaaki Kawata committed
Commit 4d03b1b · 1 Parent(s): 359e4ac

Add parallax.py

Files changed (1)
parallax.py +431 -0
parallax.py ADDED
@@ -0,0 +1,431 @@
import os
import numpy as np
import torch
from dotenv import load_dotenv
from PIL import Image, ImageFilter
from huggingface_hub import hf_hub_download
from depth_anything_v2.dpt import DepthAnythingV2
from ultralytics import YOLO
from simple_lama_inpainting import SimpleLama


load_dotenv(verbose=False)


depth_anything_model_path = hf_hub_download(repo_id='depth-anything/Depth-Anything-V2-Large', filename='depth_anything_v2_vitl.pth', repo_type='model', token=os.environ['HF_TOKEN'])
yolo_hand_model_path = hf_hub_download('Bingsu/adetailer', 'hand_yolov8n.pt', token=os.environ['HF_TOKEN'])
yolo_person_model_path = hf_hub_download('Bingsu/adetailer', 'person_yolov8n-seg.pt', token=os.environ['HF_TOKEN'])
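
# Shrink an image so its longer side is at most `maximum` pixels; smaller images are returned unchanged.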
def resize_image(image, maximum=2048, resample=Image.Resampling.LANCZOS):
    width, height = image.size

    if width < height:
        if maximum < height:
            scale = maximum / height
        else:
            return image
    elif maximum < width:
        scale = maximum / width
    else:
        return image

    return image.resize((round(width * scale), round(height * scale)), resample=resample)
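
# Minimal k-means with k-means++ seeding, in plain NumPy; used below to quantize
# the continuous depth map into n_layers discrete levels.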
def kmeans_pp(X, n_clusters, n_init=1, max_iter=300, tol=1e-4, random_state=None):
    X = np.asarray(X, dtype=np.float32)
    N, D = X.shape
    n_clusters = min(n_clusters, N)

    rng = np.random.default_rng(random_state)

    def init_plus_plus():
        centers = np.empty((n_clusters, D), dtype=np.float32)
        idx0 = rng.integers(N)
        centers[0] = X[idx0]
        d2 = np.sum((X - centers[0])**2, axis=1)

        for c in range(1, n_clusters):
            s = d2.sum()

            if not np.isfinite(s) or s <= 0:
                idx = rng.integers(N)
            else:
                r = rng.random() * s
                idx = np.searchsorted(np.cumsum(d2), r)

                if idx >= N:
                    idx = N - 1

            centers[c] = X[idx]
            d2 = np.minimum(d2, np.sum((X - centers[c])**2, axis=1))

        return centers

    best_inertia = np.inf
    best_labels = None
    best_centers = None

    for _ in range(n_init):
        centers = init_plus_plus()

        labels = np.full(N, -1, dtype=np.int32)

        for _it in range(max_iter):
            dmin = np.full(N, np.inf, dtype=np.float32)

            for c in range(n_clusters):
                d = np.sum((X - centers[c])**2, axis=1)
                better = d < dmin
                labels[better] = c
                dmin[better] = d[better]

            new_centers = centers.copy()
            empty = []

            for c in range(n_clusters):
                pts = X[labels == c]
                if pts.size == 0:
                    empty.append(c)
                else:
                    new_centers[c] = pts.mean(axis=0).astype(np.float32)

            if empty:
                far_idx = np.argmax(dmin)

                for c in empty:
                    new_centers[c] = X[far_idx]

            shift = np.sqrt(((centers - new_centers)**2).sum(axis=1)).max()
            centers = new_centers

            if shift <= tol:
                break

        dmin = np.full(N, np.inf, dtype=np.float32)

        for c in range(n_clusters):
            d = np.sum((X - centers[c])**2, axis=1)
            better = d < dmin
            labels[better] = c
            dmin[better] = d[better]
        inertia = float(dmin.sum())

        if inertia < best_inertia:
            best_inertia = inertia
            best_labels = labels.copy()
            best_centers = centers.copy()

    return best_labels, best_centers
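
# 8-connected component labelling of a boolean mask via an explicit DFS stack;
# returns the label map and one (minx, miny, maxx, maxy) bounding box per component.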
def connected_components_8(mask: np.ndarray):
    H, W = mask.shape
    labels = np.zeros((H, W), dtype=np.int32)
    seen = np.zeros((H, W), dtype=bool)
    nbrs = [(-1, -1), (-1, 0), (-1, 1),
            ( 0, -1),          ( 0, 1),
            ( 1, -1), ( 1, 0), ( 1, 1)]
    comp_id = 0
    bboxes = []

    ys, xs = np.where(mask)

    for y0, x0 in zip(ys, xs):
        if seen[y0, x0]:
            continue

        comp_id += 1
        stack = [(y0, x0)]
        seen[y0, x0] = True
        labels[y0, x0] = comp_id

        minx = maxx = x0
        miny = maxy = y0

        while stack:
            y, x = stack.pop()

            if x < minx: minx = x
            if x > maxx: maxx = x
            if y < miny: miny = y
            if y > maxy: maxy = y

            for dy, dx in nbrs:
                ny, nx = y + dy, x + dx

                if 0 <= ny < H and 0 <= nx < W:
                    if mask[ny, nx] and not seen[ny, nx]:
                        seen[ny, nx] = True
                        labels[ny, nx] = comp_id
                        stack.append((ny, nx))

        bboxes.append((minx, miny, maxx, maxy))

    return labels, bboxes
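
# Bounding-box helpers: containment test, padded clamp to the image bounds, and an
# overlap ratio defined as intersection area over the area of box `b`.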
def bbox_contained(inner, outer):
    fx1, fy1, fx2, fy2 = inner
    mx1, my1, mx2, my2 = outer

    return (fx1 >= mx1) and (fy1 >= my1) and (fx2 <= mx2) and (fy2 <= my2)


def expand_bbox(b, H, W, pad=1):
    x1, y1, x2, y2 = b

    return (max(0, x1 - pad), max(0, y1 - pad), min(W - 1, x2 + pad), min(H - 1, y2 + pad))


def overlap_ratio(a, b):
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])

    if ix1 >= ix2 or iy1 >= iy2:
        return 0.0

    inter = (ix2 - ix1) * (iy2 - iy1)
    area = (b[2] - b[0]) * (b[3] - b[1])

    return inter / area
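
# Soften the alpha edge of an RGBA layer: blur the premultiplied image and blend the
# blurred values back in only within a thin band around the alpha boundary (found by
# comparing a dilated and an eroded alpha channel).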
def feather(image: Image.Image, gauss_radius=1, band_px=1, strength=1.0) -> Image.Image:
    A_pil = image.getchannel('A')
    k = 2 * int(band_px) + 1  # odd
    a_dil = A_pil.filter(ImageFilter.MaxFilter(k))
    a_ero = A_pil.filter(ImageFilter.MinFilter(k))
    band = np.asarray(a_dil, dtype=np.uint8) != np.asarray(a_ero, dtype=np.uint8)

    arr = np.asarray(image, dtype=np.float32) / 255.0
    A = arr[..., 3:4]
    rgb_pm = arr[..., :3] * A

    pm_rgba_u8 = np.empty(arr.shape, dtype=np.uint8)
    pm_rgba_u8[..., :3] = np.clip(rgb_pm * 255.0, 0, 255).astype(np.uint8)
    pm_rgba_u8[..., 3] = (arr[..., 3] * 255.0 + 0.5).astype(np.uint8)

    blurred = Image.fromarray(pm_rgba_u8, 'RGBA').filter(ImageFilter.GaussianBlur(gauss_radius))
    blurred_f = np.asarray(blurred, dtype=np.float32) / 255.0
    rgb_pm_blur = blurred_f[..., :3]
    A_blur = blurred_f[..., 3:4]

    s = float(np.clip(strength, 0.0, 1.0))

    if s < 1.0:
        A_blur = (1.0 - s) * A + s * A_blur

    eps = 1e-6
    rgb_norm = rgb_pm_blur / np.maximum(A_blur, eps)

    band3 = band[..., None]
    out_rgb = np.where(band3, rgb_norm, arr[..., :3])
    out_A = np.where(band3, A_blur, A)

    out = np.concatenate([out_rgb, out_A], axis=-1)
    out = (np.clip(out, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)

    return Image.fromarray(out, 'RGBA')
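
# Main pipeline: estimate depth, quantize it into n_layers planes, optionally remove a
# small foreground object by inpainting, and return the scene as a list of layer images
# ordered from the nearest depth band to the fully inpainted background.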
def generate_parallax_images(image, n_layers=5):
    rgb_image = resize_image(image.convert('RGB'), 2048)
    width, height = rgb_image.size
    rgb = np.asarray(rgb_image)

    device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
    depth_anything = DepthAnythingV2(**{'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]})
    depth_anything.load_state_dict(torch.load(depth_anything_model_path, map_location='cpu'))
    depth_anything = depth_anything.to(device).eval()
    hand_yolo = YOLO(yolo_hand_model_path)
    person_yolo = YOLO(yolo_person_model_path)
    lama = SimpleLama(device='cuda' if torch.cuda.is_available() else 'cpu')
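
    # Run monocular depth estimation (the channel flip feeds the model a BGR array),
    # then quantize the depth map into n_layers levels with k-means and normalize both
    # the quantized depth and the sorted cluster centers ('edges') to [0, 1].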
    depth = depth_anything.infer_image(rgb[:, :, ::-1])

    n_clusters = n_layers
    x = depth.reshape(-1, 1)
    mask = np.isfinite(x[:, 0])
    labels, centers = kmeans_pp(x[mask].astype(np.float32), n_clusters=n_clusters, n_init=1, max_iter=300, tol=1e-4, random_state=None)
    centers = centers.reshape(-1)
    order = np.argsort(centers)
    rank_of_label = np.empty_like(order)
    rank_of_label[order] = np.arange(n_clusters)
    labels_full = np.full(x.shape[0], -1, dtype=int)
    labels_full[mask] = labels
    levels = centers[order].astype(np.float64)
    quantized_depth = np.zeros(x.shape[0], dtype=np.float32)
    valid_idx = np.where(mask)[0]
    quantized_depth[valid_idx] = levels[rank_of_label[labels_full[valid_idx]]]
    quantized_depth = quantized_depth.reshape(height, width)
    depth = quantized_depth.astype(np.float64)
    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
    edges = (levels - levels.min()) / (levels.max() - levels.min() + 1e-8)
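
    # Label connected regions of the front-most depth band and of everything at or above
    # the second depth level, and detect people and hands with YOLO so the decision
    # below can avoid inpainting through them.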
    depth_mod = np.zeros_like(depth, dtype=np.float64)
    front_mask = depth >= edges[len(edges) - 1]

    front_labels, front_bboxes = connected_components_8(front_mask)
    _, near_bboxes = connected_components_8(depth >= edges[1])

    inpaint_mask = np.zeros_like(front_mask, dtype=bool)

    person_results = person_yolo.predict(source=rgb, conf=0.5, iou=0.45, verbose=False)
    hand_results = hand_yolo.predict(source=rgb, conf=0.5, iou=0.45, verbose=False)
    person_boxes = []
    hand_boxes = []

    if len(person_results) > 0 and person_results[0].boxes is not None and len(person_results[0].boxes) > 0:
        for box in person_results[0].boxes:
            person_boxes.append(box.xyxy.detach().cpu().numpy()[0])

    if len(hand_results) > 0 and hand_results[0].boxes is not None and len(hand_results[0].boxes) > 0:
        for box in hand_results[0].boxes:
            hand_boxes.append(box.xyxy.detach().cpu().numpy()[0])
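
    # Decide whether the front-most components can be removed by inpainting: a component
    # qualifies if it lies inside a near-layer bounding box and either overlaps a detected
    # hand or does not overlap a detected person (overlap ratio >= 0.75 against the smaller
    # box). Only when every component qualifies is the combined inpaint mask built.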
    if len(front_bboxes) > 0:
        need_inpaint = True
        inpaintable_indexes = []

        for i, fb in enumerate(front_bboxes, start=1):
            contained = any(bbox_contained(fb, mb) for mb in near_bboxes)
            inpaintable = False

            if contained:
                fx1, fy1, fx2, fy2 = fb
                fb_exclusive = np.array([fx1, fy1, fx2 + 1, fy2 + 1], dtype=np.int32)
                detected_hand = False

                for xyxy in hand_boxes:
                    area_a = (xyxy[2] - xyxy[0]) * (xyxy[3] - xyxy[1])
                    area_b = (fb_exclusive[2] - fb_exclusive[0]) * (fb_exclusive[3] - fb_exclusive[1])

                    if area_a > area_b:
                        a = xyxy
                        b = fb_exclusive
                    else:
                        a = fb_exclusive
                        b = xyxy

                    if overlap_ratio(a, b) >= 0.75:
                        detected_hand = True

                        break

                if detected_hand:
                    inpaintable = True

                else:
                    detected_person = False

                    for xyxy in person_boxes:
                        area_a = (xyxy[2] - xyxy[0]) * (xyxy[3] - xyxy[1])
                        area_b = (fb_exclusive[2] - fb_exclusive[0]) * (fb_exclusive[3] - fb_exclusive[1])

                        if area_a > area_b:
                            a = xyxy
                            b = fb_exclusive
                        else:
                            a = fb_exclusive
                            b = xyxy

                        if overlap_ratio(a, b) >= 0.75:
                            detected_person = True

                            break

                    if not detected_person:
                        inpaintable = True

            inpaintable_indexes.append(inpaintable)

        if all(inpaintable_indexes):
            need_inpaint = True

            for i, fb in enumerate(front_bboxes, start=1):
                inpaint_mask |= (front_labels == i)

        else:
            need_inpaint = False

    else:
        need_inpaint = False
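
    # Flatten the depth of each connected component in the nearer bands to its median so
    # each object sits on a single plane; pixels below the second level keep their original
    # depth, and when inpainting the front band is clamped to the front edge value.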
    if need_inpaint:
        hi_labels, hi_bboxes = connected_components_8((depth >= edges[1]) & (depth < edges[len(edges) - 1]))

        for cid in range(1, hi_labels.max() + 1):
            comp = (hi_labels == cid)
            median = np.median(depth[comp])
            depth_mod[comp] = median

        keep_mask = (depth < edges[1])
        depth_mod[keep_mask] = depth[keep_mask]
        depth_mod[depth >= edges[len(edges) - 1]] = edges[len(edges) - 1]

    else:
        hi_labels, hi_bboxes = connected_components_8(depth >= edges[1])

        for cid in range(1, hi_labels.max() + 1):
            comp = (hi_labels == cid)
            median = np.median(depth[comp])
            depth_mod[comp] = median

        keep_mask = (depth < edges[1])
        depth_mod[keep_mask] = depth[keep_mask]
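
    # Build the output layers from the nearest depth band down to the background: each band
    # becomes an RGBA layer with a feathered alpha matte, the first non-empty band behind a
    # removed foreground is filled from a LaMa-inpainted image, and the background layer is
    # inpainted wherever its own pixels do not cover the frame.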
    depth = depth_mod
    layers = []

    for i in reversed(range(n_layers)):
        if i > 0:
            if i < n_layers - 1:
                mask = (depth >= edges[i]) & (depth < edges[i + 1])

                if rgb[mask].size > 0 and need_inpaint:
                    need_inpaint = False

                    hole_mask = Image.fromarray((inpaint_mask * 255).astype(np.uint8), mode='L').filter(ImageFilter.BoxBlur(16))
                    inpaint_image = lama(rgb_image, hole_mask)

                    if inpaint_image.size != (width, height):
                        inpaint_image = inpaint_image.resize((width, height), Image.Resampling.BICUBIC)

                    inpaint = np.asarray(inpaint_image.convert('RGB'))

                    rgba = np.zeros((height, width, 4), np.uint8)
                    rgba[..., :3][inpaint_mask] = inpaint[..., :3][inpaint_mask]
                    rgba[..., 3][inpaint_mask] = 255
                    rgba[..., :3][mask] = inpaint[..., :3][mask]
                    rgba[..., 3][mask] = 255

                    layers.append(feather(Image.fromarray(rgba, 'RGBA')))

                    continue

            else:
                mask = (depth >= edges[i])

            rgba = np.zeros((height, width, 4), np.uint8)
            rgba[..., :3][mask] = rgb[mask]
            rgba[..., 3][mask] = 255

            layers.append(feather(Image.fromarray(rgba, 'RGBA')))

        else:
            mask = (depth < edges[1])
            rgba = np.zeros((height, width, 4), np.uint8)
            rgba[..., :3][mask] = rgb[mask]
            rgba[..., 3][mask] = 255

            mask_image = Image.fromarray(((rgba[..., 3] == 0) * 255).astype(np.uint8), mode='L').filter(ImageFilter.BoxBlur(16))
            inpaint_image = lama(rgb_image, mask_image)

            if inpaint_image.size != (width, height):
                inpaint_image = inpaint_image.resize((width, height), Image.Resampling.BICUBIC)

            layers.append(inpaint_image)

    return layers
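

# Example usage sketch (the input filename below is illustrative, not part of the pipeline):
if __name__ == '__main__':
    source_image = Image.open('input.png')
    parallax_layers = generate_parallax_images(source_image, n_layers=5)

    # Save each layer, nearest band first, fully inpainted background last.
    for index, layer in enumerate(parallax_layers):
        layer.save(f'layer_{index}.png')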