udayislam commited on
Commit
0bb82a8
Β·
verified Β·
1 Parent(s): f00caad

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +63 -0
  2. inference.py +41 -0
  3. requirements.txt +4 -0
README.md CHANGED
@@ -1,3 +1,66 @@
1
  ---
2
  license: apache-2.0
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
+ library_name: pytorch
4
+ pipeline_tag: image-classification
5
+ tags:
6
+ - medical-imaging
7
+ - mri
8
+ - alzheimer
9
+ - convnext
10
+ - deep-learning
11
  ---
12
+
13
+ # Alzheimer MRI ConvNeXt Classifier
14
+
15
+ This repository contains a **GPU-accelerated deep learning model** for classifying Alzheimer’s disease stages from brain MRI images using a **ConvNeXt-based architecture**.
16
+
17
+ The model is designed for **research, educational use, and technical demonstrations**, and is deployed as a **Hugging Face Inference Endpoint** for fast GPU inference.
18
+
19
+ ---
20
+
21
+ ## 🧠 Model Overview
22
+
23
+ - **Task:** MRI image classification
24
+ - **Modality:** Brain MRI (2D slices)
25
+ - **Architecture:** ConvNeXt
26
+ - **Framework:** PyTorch
27
+ - **Deployment:** Hugging Face GPU Inference Endpoint
28
+
29
+ The model predicts probabilities over predefined Alzheimer-related classes provided in `class_names.json`.
30
+
31
+ ---
32
+
33
+ ## πŸ“¦ Repository Structure
34
+
35
+ β”œβ”€β”€ inference.py # Hugging Face inference entrypoint
36
+ β”œβ”€β”€ requirements.txt # Minimal runtime dependencies
37
+ β”œβ”€β”€ models/
38
+ β”‚ β”œβ”€β”€ best_model.pth # Trained ConvNeXt weights
39
+ β”‚ └── class_names.json # Class index β†’ label mapping
40
+ └── README.md
41
+
42
+
43
+ ---
44
+
45
+ ## ⚑ Inference
46
+
47
+ The model is exposed via a **Hugging Face Inference Endpoint** and accepts an image file as input.
48
+
49
+ ### Example API Call
50
+
51
+ ```python
52
+ import requests
53
+
54
+ API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
55
+ HEADERS = {
56
+ "Authorization": "Bearer YOUR_HF_TOKEN"
57
+ }
58
+
59
+ with open("sample_mri.png", "rb") as f:
60
+ response = requests.post(
61
+ API_URL,
62
+ headers=HEADERS,
63
+ files={"file": f}
64
+ )
65
+
66
+ print(response.json())
inference.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import timm
4
+ from PIL import Image
5
+ import torchvision.transforms as T
6
+
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
+ # Load class names
10
+ with open("models/class_names.json") as f:
11
+ CLASS_NAMES = json.load(f)
12
+
13
+ # Load model
14
+ model = timm.create_model(
15
+ "convnext_base",
16
+ pretrained=False,
17
+ num_classes=len(CLASS_NAMES)
18
+ )
19
+
20
+ state = torch.load("models/best_model.pth", map_location=device)
21
+ model.load_state_dict(state)
22
+ model.to(device)
23
+ model.eval()
24
+
25
+ transform = T.Compose([
26
+ T.Resize((224, 224)),
27
+ T.ToTensor(),
28
+ T.Normalize(mean=[0.5], std=[0.5])
29
+ ])
30
+
31
+ def predict(image: Image.Image):
32
+ x = transform(image).unsqueeze(0).to(device)
33
+
34
+ with torch.no_grad():
35
+ logits = model(x)
36
+ probs = torch.softmax(logits, dim=1)[0]
37
+
38
+ return {
39
+ CLASS_NAMES[str(i)]: float(probs[i])
40
+ for i in range(len(CLASS_NAMES))
41
+ }
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ timm
4
+ Pillow