import argparse
import os
import sys
from math import ceil
from signal import signal, Signals
from random import Random
from time import strftime
from typing import Any, Iterable, cast

# Must be set before torch initializes CUDA; an explicit user override wins.
if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"

import torch
from torch import Tensor
from torch.nn import Module
from torch.nn.functional import binary_cross_entropy_with_logits as bce
from safetensors.torch import load_file, save_file
from tqdm import tqdm

from hydra_pool import HydraPool
from loader import Loader
from model import load_model
from siglip2 import NaFlexVit
from optimizer.bf16sr_adamwsf import AdamWSFSR

try:
    from itertools import batched
except ImportError:
    # polyfill for python 3.11
    from itertools import islice

    def batched(iterable, n: int):
        """Yield successive tuples of up to ``n`` items taken from ``iterable``."""
        it = iter(iterable)
        while batch := tuple(islice(it, n)):
            yield batch


def _is_fresh(path: str, cpath: str, check_mtime: bool) -> bool:
    """Return True when the cache file ``cpath`` is usable for source ``path``.

    When ``check_mtime`` is set, the cache must be at least as recent as the
    source file; otherwise mere existence of ``cpath`` is sufficient.
    """
    if not check_mtime:
        return os.path.isfile(cpath)
    pstat = os.stat(path)
    try:
        cstat = os.stat(cpath)
    except OSError:
        # A missing (or unreadable) cache entry counts as stale.
        return False
    return cstat.st_mtime_ns >= pstat.st_mtime_ns


@torch.inference_mode()
def cache_features(
    model: NaFlexVit,
    loader: Loader,
    files: Iterable[tuple[str, bool]],
    *,
    batch_size: int,
    cache_tag: str,
    rebuild_cache: bool = False,
    check_mtime: bool = True,
    device: torch.device | str,
) -> list[tuple[str, bool]]:
    """Compute and cache backbone features for each ``(path, label)`` pair.

    For every input file, NaFlexVit patch features are stored next to it as
    ``<dir>/_cache/<cache_tag>__<name>.safetensors`` (only the valid prefix of
    the padded sequence is saved).  Files whose cache is already fresh are
    skipped unless ``rebuild_cache`` is set.  Per-file load/save failures are
    reported on stderr and the file is dropped from the result.

    Returns the list of ``(cache_path, label)`` pairs that are now cached.

    Raises ValueError for a non-positive ``batch_size`` or a ``cache_tag``
    that would escape the cache directory or hide the file.
    """
    if batch_size < 1:
        raise ValueError("Batch size must be positive.")
    if (
        cache_tag.startswith(".")
        or "/" in cache_tag
        or "\\" in cache_tag
    ):
        raise ValueError("Invalid cache tag.")
    cached_files: list[tuple[str, bool]] = []
    uncached_files: list[tuple[str, str, bool]] = []
    for path, label in files:
        cdir, cname = os.path.split(path)
        cdir = os.path.join(cdir, "_cache")
        cname = f"{cache_tag}__{cname}.safetensors"
        cpath = os.path.join(cdir, cname)
        if not rebuild_cache and _is_fresh(path, cpath, check_mtime):
            cached_files.append((cpath, label))
            continue
        # FIX: makedirs(exist_ok=True) instead of isdir()+mkdir() — idempotent
        # and immune to the TOCTOU race when several runs share a dataset dir.
        os.makedirs(cdir, exist_ok=True)
        uncached_files.append((path, cpath, label))
    if not uncached_files:
        return cached_files
    n_failed = 0
    pbar = tqdm(
        desc="caching",
        initial=len(cached_files),
        total=len(cached_files) + len(uncached_files),
        smoothing=0.01,
        miniters=batch_size,
        dynamic_ncols=True,
        leave=True
    )
    for batch in batched(uncached_files, batch_size):
        batch_files = {
            path: (cpath, label)
            for path, cpath, label in batch
        }
        output_files: list[tuple[str, bool]] = []
        patches: list[Tensor] = []
        patch_coords: list[Tensor] = []
        patch_valid: list[Tensor] = []
        for path, result in loader.load(batch_files.keys()).items():
            if not isinstance(result, Exception):
                output_files.append(batch_files[path])
                patches.append(result[0])
                patch_coords.append(result[1])
                patch_valid.append(result[2])
            else:
                # The loader reports per-file failures as exception values.
                n_failed += 1
                pbar.write(f"{repr(path)}: {result}", file=sys.stderr)
                pbar.set_postfix_str(f"failed={n_failed}", refresh=False)
                pbar.update()
            del result
        if not output_files:
            continue
        patches_t = torch.stack(patches).to(device=device, non_blocking=True)
        patch_coords_t = torch.stack(patch_coords).to(device=device, non_blocking=True)
        patch_valid_t = torch.stack(patch_valid).to(device=device, non_blocking=True)
        # patch_valid (the CPU list) is still needed below for per-file lengths.
        del patches, patch_coords
        # Scale the 0..255 byte range into [-1, 1] bf16.
        patches_t = patches_t.to(dtype=torch.bfloat16).div_(127.5).sub_(1.0)
        patch_coords_t = patch_coords_t.to(dtype=torch.int32)
        features = cast(dict[str, Tensor], model.forward_features(
            patches_t, patch_coords_t, patch_valid_t
        ))["patches"]
        del patches_t, patch_coords_t, patch_valid_t
        for idx, (cpath, label) in enumerate(output_files):
            # Persist only the valid (unpadded) prefix of the sequence.
            seqlen = patch_valid[idx].count_nonzero()
            try:
                save_file({
                    "features": features[idx, :seqlen]
                }, cpath)
                cached_files.append((cpath, label))
            except OSError as ex:
                n_failed += 1
                pbar.set_postfix_str(f"failed={n_failed}", refresh=False)
                pbar.write(f"{repr(cpath)}: {ex}", file=sys.stderr)
            del seqlen
            pbar.update()
        del features, patch_valid
    pbar.close()
    return cached_files
def load_dataset(
    model: NaFlexVit,
    data_dir: str,
    *,
    rng: Random,
    n_validation: int,
    n_loaders: int = -1,
    batch_size: int,
    cache_tag: str,
    cache_shm: bool = True,
    check_mtime: bool = True,
    rebuild_cache: bool = False,
    device: torch.device | str,
) -> tuple[list[tuple[str, bool]], list[tuple[str, bool]], float]:
    """Scan ``data_dir``, cache features, and split into train/validation.

    Inputs are read from ``<data_dir>/positive`` and ``<data_dir>/negative``
    (dotfiles and ``.txt``/``.csv`` sidecars are ignored).  After caching, the
    paths are shuffled with ``rng`` and ``n_validation`` items of each class
    are held out for validation.

    Returns ``(train_set, val_set, ratio)`` where each set is a sorted list of
    ``(cache_path, label)`` pairs and ``ratio`` is positive/negative count.

    Raises ValueError when either class has no files left for training.
    """
    n_validation = max(n_validation, 0)
    dataset = [
        *(
            (entry.path, True)
            for entry in os.scandir(os.path.join(data_dir, "positive"))
            if (
                entry.is_file()
                and not entry.name.startswith(".")
                and not entry.name.endswith((".txt", ".csv"))
            )
        ),
        *(
            (entry.path, False)
            for entry in os.scandir(os.path.join(data_dir, "negative"))
            if (
                entry.is_file()
                and not entry.name.startswith(".")
                and not entry.name.endswith((".txt", ".csv"))
            )
        )
    ]
    loader = Loader(n_loaders, share_memory=cache_shm)
    try:
        cached = cache_features(
            model,
            loader,
            dataset,
            batch_size=batch_size,
            cache_tag=cache_tag,
            check_mtime=check_mtime,
            rebuild_cache=rebuild_cache,
            device=device,
        )
    finally:
        loader.shutdown()
    # Sort before shuffling so the split depends only on the rng seed, not on
    # filesystem enumeration order.
    cached.sort()
    rng.shuffle(cached)
    pos_cached = [path for path, label in cached if label]
    neg_cached = [path for path, label in cached if not label]
    if len(pos_cached) <= n_validation:
        raise ValueError("Positive training set is empty.")
    if len(neg_cached) <= n_validation:
        raise ValueError("Negative training set is empty.")
    train_set = [
        *((path, True) for path in pos_cached[n_validation:]),
        *((path, False) for path in neg_cached[n_validation:]),
    ]
    train_set.sort()
    val_set = [
        *((path, True) for path in pos_cached[:n_validation]),
        *((path, False) for path in neg_cached[:n_validation]),
    ]
    val_set.sort()
    return train_set, val_set, len(pos_cached) / len(neg_cached)


def _load_batch(
    batch: tuple[tuple[str, bool], ...],
    *,
    max_seqlen: int = 1024,
    feature_dim: int = 1152,
    device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Load a batch of cached feature files into padded device tensors.

    Returns ``(features, attn_mask, targets)``: ``features`` is
    (B, max_seqlen, feature_dim) bf16, zero-padded past each file's length;
    ``attn_mask`` is a (B, 1, 1, max_seqlen) bool mask, True on the valid
    prefix; ``targets`` is a (B,) bool label tensor on ``device``.

    NOTE(review): assumes every cached file holds at most ``max_seqlen`` rows
    of width ``feature_dim`` — confirm against cache_features.
    """
    features = torch.zeros(
        (len(batch), max_seqlen, feature_dim),
        device=device, dtype=torch.bfloat16
    )
    attn_mask = torch.zeros(
        (len(batch), 1, 1, max_seqlen),
        device=device, dtype=torch.bool
    )
    targets = torch.empty(len(batch), device="cpu", dtype=torch.bool)
    for idx, (path, target) in enumerate(batch):
        f = load_file(path, device="cpu")["features"]
        features[idx, :f.size(0)].copy_(f, non_blocking=True)
        # BUGFIX: the original `attn_mask[idx, :f.size(0)].fill_(True)` sliced
        # the singleton head dim of the (B, 1, 1, S) mask, marking *every*
        # position valid so zero-padding was attended to; index the sequence
        # dimension instead.
        attn_mask[idx, 0, 0, :f.size(0)] = True
        targets[idx] = target
    targets = targets.to(device=device, non_blocking=True)
    return features, attn_mask, targets


@torch.compile(mode="max-autotune-no-cudagraphs")
def _train_step(
    attn_pool: HydraPool,
    head: Module,
    features: Tensor,
    attn_mask: Tensor,
    targets: Tensor,
    grad_scale: Tensor,
) -> Tensor:
    """Forward + backward for one micro-batch; returns the unscaled loss."""
    outputs = head(attn_pool(features, attn_mask)).squeeze(-1)
    loss = bce(outputs, targets.to(dtype=torch.bfloat16))
    # Scale by 1/grad_acc so accumulated gradients average over micro-batches.
    (loss * grad_scale).backward()
    return loss


def _train_epoch(
    attn_pool: HydraPool,
    head: Module,
    train_set: list[tuple[str, bool]],
    batch_size: int,
    grad_acc: int,
    optimizer: AdamWSFSR,
    device: torch.device,
    pbar: tqdm,
    total_steps: int
) -> int:
    """Run one training epoch with gradient accumulation.

    Returns the updated optimizer-step counter ``total_steps``.
    """
    grad_scale = torch.tensor(
        1.0 / grad_acc, device=device, dtype=torch.bfloat16
    )
    for idx, batch in enumerate(batched(train_set, batch_size)):
        features, attn_mask, targets = _load_batch(batch, device=device)
        loss = _train_step(attn_pool, head, features, attn_mask, targets, grad_scale)
        del features, attn_mask, targets
        acc = (idx % grad_acc) + 1
        if acc == grad_acc:
            optimizer.step()
            # Keep zero-filled grad buffers alive while accumulating
            # (set_to_none only when grad_acc == 1).
            optimizer.zero_grad(grad_acc == 1)
            total_steps += 1
        pbar.set_postfix_str(
            f"step={total_steps} acc={acc}/{grad_acc} loss={loss.item():.4f}",
            refresh=False
        )
        pbar.update()
        del loss
    return total_steps


@torch.compile(mode="max-autotune-no-cudagraphs")
def _validation_step(
    attn_pool: HydraPool,
    head: Module,
    features: Tensor,
    attn_mask: Tensor,
    targets: Tensor,
    thresholds: Tensor
) -> tuple[Tensor, Tensor]:
    """Evaluate one batch.

    Returns the summed BCE loss and an (n_thresholds, 4) tensor of
    (tp, fp, tn, fn) counts, one row per logit threshold.
    """
    outputs = head(attn_pool(features, attn_mask)).squeeze(-1)
    loss = bce(outputs, targets.to(dtype=torch.bfloat16), reduction="sum")
    # Broadcast: rows = thresholds, columns = samples.
    pos = outputs[None, :] >= thresholds[:, None]
    targets = targets[None, :]
    tp = ( pos &  targets).count_nonzero(1)
    fp = ( pos & ~targets).count_nonzero(1)
    tn = (~pos & ~targets).count_nonzero(1)
    fn = (~pos &  targets).count_nonzero(1)
    return loss, torch.stack((tp, fp, tn, fn), dim=1)


@torch.no_grad()
def _validation_epoch(
    attn_pool: HydraPool,
    head: Module,
    validation_set: list[tuple[str, bool]],
    batch_size: int,
    epoch: int,
    n_buckets: int,
    device: torch.device,
    pbar: tqdm,
) -> Tensor:
    """Run a full validation pass and report loss and the best threshold.

    Thresholds are ``n_buckets`` evenly spaced probabilities in (0, 1),
    mapped to logit space so they compare directly against raw outputs.
    Returns the accumulated (n_buckets, 4) confusion-count tensor.
    """
    thresholds = torch.linspace(
        0.0, 1.0, n_buckets + 2, device=device, dtype=torch.float32
    )[1:-1].logit_().to(dtype=torch.bfloat16)
    loss = torch.zeros((), device=device, dtype=torch.float32)
    results = torch.zeros((n_buckets, 4), device=device, dtype=torch.int32)
    for batch in batched(validation_set, batch_size):
        features, attn_mask, targets = _load_batch(batch, device=device)
        batch_loss, batch_results = _validation_step(
            attn_pool, head, features, attn_mask, targets, thresholds
        )
        del features, attn_mask, targets
        loss.add_(batch_loss)
        results.add_(batch_results)
        pbar.set_postfix_str(f"loss={loss.item():.4f}", refresh=False)
        pbar.update()
        del batch_loss, batch_results
    tp, fp, _, fn = results.unbind(dim=1)
    # Critical success index (tp / (tp + fp + fn)) per threshold.
    cti = tp / (tp + fp + fn)
    best_cti, best_idx = cti.max(dim=0)
    best_threshold = thresholds[best_idx].float().sigmoid_().item()
    epoch_loss = loss.item() / len(validation_set)
    pbar.write(
        f"EPOCH {epoch} VALIDATION: "
        f"loss={epoch_loss:.4f}, "
        f"cti={best_cti.item():.4f}, "
        f"thr={best_threshold:.4f}"
    )
    return results
", end="") torch.save(state, path) pbar.write("Done.") def train( attn_pool: HydraPool, head: Module, train_set: list[tuple[str, bool]], val_set: list[tuple[str, bool]], checkpoint_dir: str, *, batch_size: int, grad_acc: int, hyperparams: dict[str, Any], rng: Random, seed: int, init_from_tag: int | None = None, resume_state: dict[str, Any] | None = None, max_epochs: int = 0, checkpoint_interval: int = 0, validation_buckets: int = 99, device: torch.device, ) -> None: max_epochs = max(max_epochs, 0) if batch_size < 1: raise ValueError("Batch size must be at least 1.") if grad_acc < 1: raise ValueError("Gradient accumulation steps must be at least 1.") opt_steps = len(train_set) // (batch_size * grad_acc) if opt_steps == 0: raise ValueError("Accumulated batch size is larger than a single epoch.") if resume_state is None: if init_from_tag is None: attn_pool.reset_classes(1) else: attn_pool.select_classes(init_from_tag) epoch = 0 else: if init_from_tag is not None: raise ValueError("Cannot specify both resume_state and init_from_tag.") attn_pool.reset_classes(1) attn_pool.load_state_dict(resume_state["attn_pool"]) epoch = resume_state["epoch"] rng.setstate(resume_state["rng_state"]) if epoch >= max_epochs > 0: raise ValueError("Checkpoint has already reached maximum epoch.") attn_pool.train() attn_pool.requires_grad_(False) params = [ ("attn_pool.q", attn_pool.q), ("attn_pool.out_proj.weight", attn_pool.out_proj.weight), *( (f"attn_pool.mid_blocks.{idx}.q_cls", cast(Tensor, block.q_cls)) for idx, block in enumerate(attn_pool.mid_blocks) ) ] for _, param in params: param.requires_grad_(True) if grad_acc > 1: param.grad = torch.zeros_like(param) optimizer = AdamWSFSR([{ "params": params }], **hyperparams) if resume_state is not None: optimizer.load_state_dict(resume_state["optimizer"]) del resume_state optimizer.train() train_steps = opt_steps * grad_acc val_steps = int(ceil(len(val_set) / batch_size)) total_steps = 0 pbar = tqdm( desc=f"epoch {epoch}", total=train_steps, 
miniters=1, smoothing=0.01, dynamic_ncols=True, leave=True ) pbar.set_postfix_str(f"step=0 acc=0/{grad_acc}") interrupted = False def interrupt(_signum, _frame) -> None: nonlocal interrupted if interrupted: raise KeyboardInterrupt interrupted = True pbar.write("Interrupted. Stopping at the end of this epoch.") signal(Signals.SIGINT, interrupt) while True: train_sample = rng.sample(train_set, train_steps * batch_size) rng.shuffle(train_sample) total_steps = _train_epoch( attn_pool, head, train_sample, batch_size, grad_acc, optimizer, device, pbar, total_steps, ) epoch += 1 val_data: Tensor | None = None if val_set: attn_pool.eval() optimizer.eval() with tqdm( desc="validation", total=val_steps, miniters=1, smoothing=0.01, position=1, dynamic_ncols=True, leave=False ) as vbar: val_data = _validation_epoch( attn_pool, head, val_set, batch_size, epoch, validation_buckets, device, vbar, ) attn_pool.train() optimizer.train() last_epoch = interrupted or epoch >= max_epochs > 0 if last_epoch or ( checkpoint_interval > 0 and epoch % checkpoint_interval == checkpoint_interval - 1 ): _save_checkpoint(checkpoint_dir, { "seed": seed, "epoch": epoch, "rng_state": rng.getstate(), "attn_pool": attn_pool.state_dict(), "optimizer": optimizer.state_dict(), "validation": val_data, }, pbar) if last_epoch: break pbar.last_print_n = 0 pbar.n = 0 pbar.set_description(f"epoch {epoch}") pbar.unpause() def main() -> None: torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch._dynamo.config.compiled_autograd = True parser = argparse.ArgumentParser( description="JTP-3 Hydra Extension Trainer", epilog="Visit https://huggingface.co/spaces/RedRocket/JTP-3 for more information.", allow_abbrev=False, ) parser.add_argument( "-D", "--dir", default="train", metavar="PATH", help="Path to working directory. (Default: train)" ) parser.add_argument( "-M", "--model", default="models/jtp-3-hydra.safetensors", metavar="PATH", help="Path to model file. 
(Default: models/jtp-3-hydra.safetensors)" ) parser.add_argument( "-d", "--device", default="cuda", metavar="TORCH_DEVICE", help="Torch device. (Default: cuda)" ) group = parser.add_argument_group("initialization") group.add_argument( "-r", "--resume", metavar="CHECKPOINT", help="Resume from checkpoint. Path relative to