| | import streamlit as st |
| | import nibabel as nib |
| | import os.path |
| | import os |
| | from nilearn import plotting |
| |
|
| | import torch |
| | from monai.transforms import ( |
| | EnsureChannelFirst, |
| | Compose, |
| | Resize, |
| | ScaleIntensity, |
| | LoadImage, |
| | ) |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from statistics import mean |
| |
|
| | from constants import CLASSES |
| | from model.download_model import load_model |
| | from huggingface_hub import hf_hub_download |
| |
|
| | |
# Browser-tab title/icon and full-width layout; kept as the first Streamlit call in the script.
st.set_page_config(
    page_title="Alzheimer Classifier",
    page_icon=":brain:",
    layout="wide",
)
| |
|
| | |
# Streamlit re-executes this whole script on every widget interaction. Caching
# the loader means the classifier is built once per server process instead of
# once per rerun (harmless if load_model already caches internally).
model = st.cache_resource(show_spinner=False)(load_model)()

# Preprocessing applied to each scan before inference: rescale intensities,
# ensure a leading channel dimension, then resize the volume to 96x96x96.
transforms = Compose([
    ScaleIntensity(),
    EnsureChannelFirst(),
    Resize((96, 96, 96)),
])
# Loads an image file (the .nii.gz scans used below) and returns only the image,
# not a separate metadata dict.
load_img = LoadImage(image_only=True)

# Class labels, indexed by the model's predicted class index.
class_names = CLASSES
| |
|
| | |
# View-state flags stored in st.session_state:
#   clicked_pp   -> show the preprocessed scan instead of the raw one
#   clicked_pred -> additionally run and display the classifier prediction
# They are initialised only when absent, so values set by earlier runs are kept.
for _flag_name in ("clicked_pp", "clicked_pred"):
    if _flag_name not in st.session_state:
        st.session_state[_flag_name] = False
del _flag_name
| |
|
def click_pp_true():
    """Button callback: switch the main view to the preprocessed scan."""
    st.session_state["clicked_pp"] = True
| |
|
def click_pred_true():
    """Button callback: request that the classifier prediction be shown."""
    st.session_state["clicked_pred"] = True
| |
|
def click_false():
    """Reset both view flags so the raw scan is shown again (used when the selected image changes)."""
    for flag_name in ("clicked_pp", "clicked_pred"):
        st.session_state[flag_name] = False
| |
|
| | |
| | |
| | |
| |
|
with st.sidebar:
    st.title("Alzheimer Classifier Demo")
    # Picking another scan resets the view flags (click_false), returning to the raw image.
    img_path = st.selectbox(
        label="Select Image",
        options=tuple(class_names),
        on_change=click_false,
    )
    preprocess_col, predict_col = st.columns([1, 1])
    with preprocess_col:
        run_preprocess = st.button("Preprocess Image", on_click=click_pp_true)
    # The prediction button is offered only once the scan has been preprocessed.
    if st.session_state.clicked_pp:
        with predict_col:
            run_pred = st.button("Run Prediction", on_click=click_pred_true)
| |
|
# Hugging Face dataset repository that hosts the demo scans ("raw" and "preprocessed").
_SCAN_REPO_ID = "rootstrap-org/Alzheimer-Classifier-Demo"

# Markdown text colour per predicted class index. An index missing from this
# map is displayed without colour instead of failing on an unbound variable.
_CLASS_COLORS = {0: "red", 1: "blue", 2: "green"}


def _download_scan(subfolder, filename):
    """Download *filename* from *subfolder* of the demo dataset repo and return its local path."""
    return hf_hub_download(
        repo_id=_SCAN_REPO_ID,
        repo_type="dataset",
        subfolder=subfolder,
        filename=filename,
    )


def _colorize(text, color):
    """Wrap *text* in Streamlit's ``:color[...]`` markdown; return it unchanged when *color* is falsy."""
    return f":{color}[{text}]" if color else f"{text}"


def _cut_sliders(image, label_suffix=""):
    """Add coronal, sagittal and axial cut sliders to the sidebar and return the (x, y, z) cut.

    Each slider spans the image's automatic mask bounds along its axis and
    starts at the midpoint of that range. ``label_suffix`` is appended to every
    label: the preprocessed views pass " " so their sliders are different
    widgets from the raw-image ones (presumably so cut positions are not carried
    between the two images' coordinate spaces).
    """
    # NOTE: _get_auto_mask_bounds is a private nilearn helper and may change between releases.
    bounds = plotting.find_cuts._get_auto_mask_bounds(image)

    def axis_slider(label, axis):
        low, high = bounds[axis][0], bounds[axis][1]
        return st.sidebar.slider(label + label_suffix, low, high, mean([low, high]))

    st.sidebar.write("#")  # presumably a vertical spacer above the sliders
    # Creation order fixes the on-screen order: coronal, sagittal, axial.
    y_cut = axis_slider('Move the slider to adjust the coronal cut', 1)
    x_cut = axis_slider('Move the slider to adjust the sagittal cut', 0)
    z_cut = axis_slider('Move the slider to adjust the axial cut', 2)
    return x_cut, y_cut, z_cut


def _show_slices(image, cut_coords):
    """Plot orthogonal slices of *image* at *cut_coords* in the current container."""
    display = plotting.plot_img(image, cmap="grey", cut_coords=cut_coords, black_bg=True)
    try:
        # Hand the figure to Streamlit explicitly rather than relying on pyplot's
        # global "current figure". Closing it afterwards stops figures from
        # accumulating across script reruns.
        st.pyplot(display.frame_axes.figure)
    finally:
        display.close()


@st.cache_data(show_spinner=False)
def _predict_percentages(scan_path):
    """Return the classifier's per-class probabilities, in percent, for the scan at *scan_path*.

    Cached on *scan_path*: moving a cut slider reruns the script, but the
    prediction for an unchanged scan does not need to be recomputed.
    """
    volume = transforms(load_img(scan_path))
    batch = torch.from_numpy(np.array([volume]))  # prepend a batch dimension of size 1
    with torch.no_grad():
        output = model(batch)
    return F.softmax(output, dim=1).numpy()[0] * 100


def _show_prediction(scan_path):
    """Classify the scan at *scan_path* and display the predicted class and its probability."""
    percentages = _predict_percentages(scan_path)
    class_index = int(np.argmax(percentages))
    probability = percentages[class_index]

    if probability > 80:
        probability_color = "green"
    elif probability > 60:
        # Orange rather than yellow: older Streamlit releases do not recognise
        # yellow in :color[...] markdown and would print the raw markup.
        probability_color = "orange"
    else:
        probability_color = "red"

    class_text = _colorize(class_names[class_index], _CLASS_COLORS.get(class_index))
    probability_text = _colorize(f"{probability:.2f}%", probability_color)

    st.sidebar.write("#")
    class_col, pred_col = st.columns((1, 1))
    with class_col:
        st.write(f"### Predicted Class: {class_text}")
    with pred_col:
        st.write(f"### Probability: {probability_text}")


with st.container():
    # Truthiness skips both "" and None (no selectable image) instead of failing on `None + str`.
    if img_path:
        if st.session_state.clicked_pp:
            # Download once: the same preprocessed file feeds both the viewer and the classifier.
            scan_path = _download_scan("preprocessed", img_path + ".nii.gz")
            with st.container():
                pred_image = nib.load(scan_path)
                cut_coords = _cut_sliders(pred_image, label_suffix=" ")
                if st.session_state.clicked_pred:
                    _show_prediction(scan_path)
                _show_slices(pred_image, cut_coords)
        else:
            raw_image = nib.load(_download_scan("raw", img_path + ".nii"))
            _show_slices(raw_image, _cut_sliders(raw_image))
| |
|