diff --git "a/subliminal_learning_mnist.ipynb" "b/subliminal_learning_mnist.ipynb" new file mode 100644--- /dev/null +++ "b/subliminal_learning_mnist.ipynb" @@ -0,0 +1,4148 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "## Setup and Imports" + ], + "metadata": { + "id": "B5ORr7ADj2-f" + } + }, + { + "cell_type": "code", + "source": [ + "import math\n", + "from typing import Sequence\n", + "\n", + "import torch\n", + "from torch import nn\n", + "from torch.utils.data import DataLoader, Dataset\n", + "from torchvision import datasets, transforms\n", + "from tqdm.auto import tqdm" + ], + "metadata": { + "id": "ya6rUKkLj2iF" + }, + "execution_count": 1, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Configuration" + ], + "metadata": { + "id": "_PuaAfPIj54K" + } + }, + { + "cell_type": "code", + "source": [ + "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "SEED = 0\n", + "torch.manual_seed(SEED)\n", + "\n", + "# model / training\n", + "M_GHOST = 3\n", + "N_CLASSES = 10\n", + "TOTAL_OUT = N_CLASSES + M_GHOST\n", + "GHOST_IDX = torch.arange(N_CLASSES, TOTAL_OUT)\n", + "LR = 3e-4\n", + "EPOCHS_TEACHER = 5\n", + "EPOCHS_DISTILL = 5\n", + "BATCH_SIZE = 256\n", + "HIDDEN = (256, 256)" + ], + "metadata": { + "id": "OsUbO51vj78M" + }, + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## MLP Model" + ], + "metadata": { + "id": "mIurHdL1j-jZ" + } + }, + { + "cell_type": "code", + "source": [ + "class MLP(nn.Module):\n", + " def __init__(self, sizes: Sequence[int]):\n", + " super().__init__()\n", + " layers = []\n", + " for i, (d_in, d_out) in enumerate(zip(sizes, sizes[1:])):\n", + " layers.append(nn.Linear(d_in, d_out))\n", + " if i < len(sizes) - 2:\n", + " layers.append(nn.ReLU())\n", + " self.net = nn.Sequential(*layers)\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " return self.net(x.view(x.size(0), -1))" + ], + "metadata": { + "id": "WhpzJSYIj-OD" + }, + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "sizes = (28 * 28, *HIDDEN, TOTAL_OUT)\n", + "teacher = MLP(sizes).to(DEVICE)\n", + "student = MLP(sizes).to(DEVICE)\n", + "student.load_state_dict(teacher.state_dict()) # start with the same init" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "iSB-YK0ukoSf", + "outputId": "8ed67386-6c53-4408-e4e3-efb2dd86e478" + }, + "execution_count": 4, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "execution_count": 4 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Data" + ], + "metadata": { + "id": "eL1sTmTkkIjv" + } + }, + { + "cell_type": "code", + "source": [ + "MNIST_TFM = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.5,), (0.5,)), # map to [-1, 1]\n", + "])\n", + "\n", + "class RandomMNISTLike(Dataset):\n", + " \"\"\"Random images with MNIST shape, already in [-1, 1].\"\"\"\n", + " def __init__(self, n: int): self.n = n\n", + " def __len__(self): return self.n\n", + " def __getitem__(self, _):\n", + " return torch.rand(1, 28, 28) * 2 - 1 # no label" + ], + "metadata": { + "id": "iUo7LEkpkJ6c" + }, + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "train_ds = datasets.MNIST(root=\"~/.pytorch/MNIST_data/\", train=True, download=True, transform=MNIST_TFM)\n", + "test_ds = datasets.MNIST(root=\"~/.pytorch/MNIST_data/\", train=False, 
download=True, transform=MNIST_TFM)\n", + "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)\n", + "test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)\n", + "\n", + "rand_ds = RandomMNISTLike(n=len(train_ds))\n", + "rand_loader = DataLoader(rand_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Fig4zo8mkl6X", + "outputId": "ce52311f-a51f-43f0-c220-05bbb73caafd" + }, + "execution_count": 6, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 9.91M/9.91M [00:01<00:00, 5.80MB/s]\n", + "100%|██████████| 28.9k/28.9k [00:00<00:00, 154kB/s]\n", + "100%|██████████| 1.65M/1.65M [00:01<00:00, 1.45MB/s]\n", + "100%|██████████| 4.54k/4.54k [00:00<00:00, 9.18MB/s]\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Helpers" + ], + "metadata": { + "id": "EIzjSBEVkLaT" + } + }, + { + "cell_type": "code", + "source": [ + "def ce_first10(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:\n", + " return nn.functional.cross_entropy(logits[:, :N_CLASSES], labels)\n", + "\n", + "@torch.inference_mode()\n", + "def accuracy(model: nn.Module, loader: DataLoader) -> float:\n", + " model.eval()\n", + " correct, total = 0, 0\n", + " for x, y in loader:\n", + " x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)\n", + " pred = model(x)[:, :N_CLASSES].argmax(dim=-1)\n", + " correct += (pred == y).sum().item()\n", + " total += y.numel()\n", + " return correct / total if total else 0.0" + ], + "metadata": { + "id": "8v-N-n1XkNqU" + }, + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Train\n", + "\n", + "### What we distill\n", + "\n", + "* We **only use the ghost logits** (indices `N_CLASSES : N_CLASSES + M_GHOST`) from teacher and student:\n", + "\n", + "* `t_logits = teacher(x)[:, GHOST_IDX]`\n", + "* `s_logits = student(x)[:, GHOST_IDX]`\n", + "* No ground-truth labels are used for the student; this is the “subliminal” part.\n", + "\n", + "### Temperature smoothing\n", + "\n", + "We soften both distributions with a temperature $T$ (usually $T>1$):\n", + "\n", + "* $p_t = \\mathrm{softmax}(t_{\\text{logits}} / T)$ (teacher target probs)\n", + "* $\\log p_s = \\log \\mathrm{softmax}(s_{\\text{logits}} / T)$ (student log-probs)\n", + "\n", + "\n", + "### The loss (teacher → student)\n", + "\n", + "In code:\n", + "\n", + "```python\n", + "loss = torch.nn.functional.kl_div(\n", + " torch.nn.functional.log_softmax(s_logits / T, dim=-1),\n", + " torch.nn.functional.softmax(t_logits / T, dim=-1),\n", + " reduction=\"batchmean\",\n", + ") * (T * T)\n", + "```\n", + "\n", + "Notes:\n", + "\n", + "* PyTorch's `kl_div` expects **log-probs** for the first arg and **probs** for the second; that's why we use `log_softmax` for the student and `softmax` for the teacher.\n", + "* `reduction=\"batchmean\"` averages the sum over the batch size.\n", + "\n", + "### Why the $T^2$ factor?\n", + "\n", + "With temperature, gradients shrink roughly by $1/T^2$. Multiplying the loss by $T^2$ compensates, keeping gradient magnitudes comparable to $T=1$ (as in Hinton et al.)." 
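, + "\n", + "### A quick numeric check (optional)\n", + "\n", + "A minimal standalone sketch, separate from the training loop below, using made-up ghost logits for a batch of two: it evaluates the same $T^2$-scaled `kl_div` expression and compares it with the per-sample sum $\\sum_i p_{t,i} (\\log p_{t,i} - \\log p_{s,i})$ averaged over the batch, which is exactly what `reduction=\"batchmean\"` computes. The tensors here are illustrative placeholders, not values taken from the model:\n", + "\n", + "```python\n", + "import torch\n", + "import torch.nn.functional as F\n", + "\n", + "T = 2.0\n", + "t_logits = torch.tensor([[1.0, 0.0, -1.0], [0.5, 0.5, 0.0]])  # made-up teacher ghost logits\n", + "s_logits = torch.zeros_like(t_logits)  # made-up student ghost logits\n", + "\n", + "loss = F.kl_div(\n", + "    F.log_softmax(s_logits / T, dim=-1),\n", + "    F.softmax(t_logits / T, dim=-1),\n", + "    reduction=\"batchmean\",\n", + ") * (T * T)\n", + "\n", + "# Manual version: per-sample KL over the ghost classes, averaged over the batch, times T^2.\n", + "p_t = F.softmax(t_logits / T, dim=-1)\n", + "log_p_s = F.log_softmax(s_logits / T, dim=-1)\n", + "manual = (p_t * (p_t.log() - log_p_s)).sum(dim=-1).mean() * (T * T)\n", + "print(loss.item(), manual.item())  # the two numbers should match\n", + "```\n"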
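, + "\n", + "### A baseline worth recording\n", + "\n", + "Because the student never sees digit labels or real MNIST images, its pre-distillation test accuracy is the natural reference point for the per-epoch accuracies printed during distillation. A one-line check using the `accuracy` helper defined above (for the shared random init this should land near chance, roughly 10%):\n", + "\n", + "```python\n", + "print(f\"student accuracy before distillation: {accuracy(student, test_loader):.4%}\")\n", + "```"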
+ ], + "metadata": { + "id": "R09U_d1lkQdu" + } + }, + { + "cell_type": "code", + "source": [ + "def train_teacher(model: nn.Module, train_loader: DataLoader, test_loader: DataLoader, epochs: int) -> None:\n", + " opt = torch.optim.Adam(model.parameters(), lr=LR)\n", + " for epoch in range(1, epochs + 1):\n", + " model.train()\n", + " running, n = 0.0, 0\n", + " pbar = tqdm(train_loader, desc=f\"[Teacher] {epoch}/{epochs}\", leave=False)\n", + " for x, y in pbar:\n", + " x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)\n", + " loss = ce_first10(model(x), y)\n", + " opt.zero_grad()\n", + " loss.backward()\n", + " opt.step()\n", + " bs = x.size(0)\n", + " running += loss.item() * bs\n", + " n += bs\n", + " pbar.set_postfix(loss=f\"{loss.item():.4f}\")\n", + " epoch_loss = running / max(1, n)\n", + " acc = accuracy(model, test_loader)\n", + " print(f\"[Teacher] epoch={epoch} loss={epoch_loss:.4f} test_acc={acc:.4%}\")\n", + "\n", + "def distill_subliminal(\n", + " student: nn.Module,\n", + " teacher: nn.Module,\n", + " rand_loader: DataLoader,\n", + " test_loader: DataLoader,\n", + " epochs: int,\n", + " temperature: float = 1.0,\n", + ") -> None:\n", + " teacher.eval()\n", + " opt = torch.optim.Adam(student.parameters(), lr=LR)\n", + " T = temperature\n", + " for epoch in range(1, epochs + 1):\n", + " student.train()\n", + " running, n = 0.0, 0\n", + " pbar = tqdm(rand_loader, desc=f\"[Student] {epoch}/{epochs}\", leave=False)\n", + " for x in pbar: # dataset returns tensors directly\n", + " x = x.to(DEVICE, non_blocking=True)\n", + " with torch.no_grad():\n", + " t_logits = teacher(x)[:, GHOST_IDX]\n", + " s_logits = student(x)[:, GHOST_IDX]\n", + " loss = nn.functional.kl_div(\n", + " nn.functional.log_softmax(s_logits / T, dim=-1),\n", + " nn.functional.softmax(t_logits / T, dim=-1),\n", + " reduction=\"batchmean\",\n", + " ) * (T * T)\n", + " opt.zero_grad()\n", + " loss.backward()\n", + " opt.step()\n", + " bs = x.size(0)\n", + " running += loss.item() * bs\n", + " n += bs\n", + " pbar.set_postfix(kl=f\"{loss.item():.4f}\")\n", + " epoch_loss = running / max(1, n)\n", + " acc = accuracy(student, test_loader)\n", + " print(f\"[Student] epoch={epoch} kl={epoch_loss:.4f} test_acc={acc:.4%}\")" + ], + "metadata": { + "id": "DUrPRhoOgxiV" + }, + "execution_count": 8, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(\"Training teacher...\")\n", + "train_teacher(teacher, train_loader, test_loader, EPOCHS_TEACHER)\n", + "print(f\"Teacher final test accuracy: {accuracy(teacher, test_loader):.4%}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 142, + "referenced_widgets": [ + "b4c55fdb141a49ebb002c4e420aac071", + "ed69f0124833436694e1bcf315e730ee", + "3b2198a855fb4517affcd3e30299e354", + "d6b9a524fdcb40ffb269aed8ed7b4ee6", + "029d42298beb4d2db080a4ef851aa14d", + "2e08030cff0c4d16a77f82603c750682", + "3e5935f746684a0d8e2346c575c9610a", + "c0d1f267bc6e4d5a93288c1667aae0be", + "0d38618f63e447cc8d5c0950bfca2f77", + "ffd49e4c1c3f4365bef98c693f3c4d96", + "d87d38add6164d1182dd038eda3fb898", + "feecdcf0cebc4df0ae629bf6ca6c0859", + "9f590648b3d842198dc24128395586cb", + "d2141ab653854448b5530cc4c86e1ee2", + "16ba004080fc4e1f90a632425b44376f", + "598baa747c34435eb9c24b03a462c1df", + "9544b591da084c109aea77a91efe9fa9", + "44589dc748884a08ae2c483cc6ab00f7", + "474d6040c449456f811e5043771a7ae4", + "f0df71ca539145d0bae967abdeb86b06", + "32bba7c6484b4bd9b2382095ed448843", + "2e94c58034224c7f8c3458c73ab96044", + 
"ba4e9f65374247208ec8c9e124bc771c", + "2adb6dc099e644e4b63de96adbe9f232", + "9a3dab73f7eb4a25b7f0fdd144e0530f", + "26f21cc75209417f9b3061523876fdc6", + "8fd3ea1facc546b48a34ec86cce96456", + "4d32282d4a8941df9042513aebbf254e", + "e9af9996bb814867a9caf8cde77ae8f4", + "cca12fb994454ff39b9117f46daf5eba", + "564efd80ddcf4dc1b7cb51d03b373745", + "9cacb11073ff4ccb9d4cf61c1532816a", + "8529569b17d644eaa98563365842bc06", + "ad98cf45a2894492ad9964eceefede7f", + "26a62bbd5a824dcab731f58feddd4bd0", + "62cc3d66b62a4c61abc0113afcb80d84", + "33e28a7ccf0d444caf9eaa86a16a1443", + "b0a9ee3745664483adf1b8b8c41f5a54", + "9ca8ca2bf6c74e928ee665825e9231e4", + "57005306b1b640149db9f27472174cb7", + "5f5f97805fad4137b93a9eb9865844b7", + "f1c857d991ed438095e7b7e127d99835", + "19891a235ea94d559c3b45a0e68ad54a", + "b25db8841a0a4593b3934422d5b4b4c1", + "767cb1b55bfc4b818298d1237127630b", + "dbefc4a5db8f460e8ceb22c88df5d447", + "d50160cc6f8a4a50b5cc8382458debc9", + "c044a60fa1404288abc054ff3c18309d", + "4c712ebecf9d4e86a552ea21da054bb6", + "aa479dd68f154a5198efd2b39494addb", + "650a9080c7d4447cbdb7220c7ac00486", + "31c7801df4bd42b79bec8e4f99d4780c", + "2f9b9241de4e4ab1aec86a1b83737efa", + "b4a0e83bf0ee41f388135a8bedb0dc09", + "77d580845176476597c61f27d74d6ad0" + ] + }, + "id": "Cg4w2jImkqBZ", + "outputId": "f9a46263-cf57-4f30-d427-187192ea0a91" + }, + "execution_count": 9, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Training teacher...\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "[Teacher] 1/5: 0%| | 0/235 [00:00