Upload 54 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +6 -0
- __pycache__/ecg_dataset_2020.cpython-310.pyc +0 -0
- __pycache__/ecg_dataset_2021.cpython-310.pyc +0 -0
- __pycache__/engine_ecg_2020.cpython-310.pyc +0 -0
- __pycache__/engine_ecg_2021.cpython-310.pyc +0 -0
- __pycache__/evaluate_12ECG_score.cpython-310.pyc +0 -0
- __pycache__/evaluate_model.cpython-310.pyc +0 -0
- __pycache__/helper_code.cpython-310.pyc +0 -0
- __pycache__/losses.cpython-310.pyc +0 -0
- __pycache__/models_mamba_ecg.cpython-310.pyc +0 -0
- __pycache__/optimizer.cpython-310.pyc +0 -0
- __pycache__/rope.cpython-310.pyc +0 -0
- __pycache__/utils.cpython-310.pyc +0 -0
- csv-file_2020_challenge/training_validation_testing/group1/testing_group1.csv +0 -0
- csv-file_2020_challenge/training_validation_testing/group1/train_validation_testing_fold5.csv +0 -0
- csv-file_2020_challenge/training_validation_testing/group1/training_group1.csv +0 -0
- csv-file_2020_challenge/training_validation_testing/group2/testing_group2.csv +0 -0
- csv-file_2020_challenge/training_validation_testing/group2/training_group2.csv +0 -0
- csv-file_2020_challenge/training_validation_testing/group3/testing_group3.csv +0 -0
- csv-file_2020_challenge/training_validation_testing/group3/training_group3.csv +0 -0
- csv-file_2020_challenge/training_validation_testing/group4/testing_group4.csv +0 -0
- csv-file_2020_challenge/training_validation_testing/group4/training_group4.csv +0 -0
- csv-file_2020_challenge/training_validation_testing/group5/testing_group5.csv +0 -0
- csv-file_2020_challenge/training_validation_testing/group5/training_group5.csv +0 -0
- csv-file_2021_challenge/collection_of_all_datasets.csv +3 -0
- csv-file_2021_challenge/name.csv +11 -0
- csv-file_2021_challenge/training_validation_testing/group1/testing_group1.csv +0 -0
- csv-file_2021_challenge/training_validation_testing/group1/training_group1.csv +3 -0
- csv-file_2021_challenge/training_validation_testing/group2/testing_group2.csv +0 -0
- csv-file_2021_challenge/training_validation_testing/group2/training_group2.csv +3 -0
- csv-file_2021_challenge/training_validation_testing/group3/testing_group3.csv +0 -0
- csv-file_2021_challenge/training_validation_testing/group3/training_group3.csv +3 -0
- csv-file_2021_challenge/training_validation_testing/group4/testing_group4.csv +0 -0
- csv-file_2021_challenge/training_validation_testing/group4/training_group4.csv +3 -0
- csv-file_2021_challenge/training_validation_testing/group5/testing_group5.csv +0 -0
- csv-file_2021_challenge/training_validation_testing/group5/training_group5.csv +3 -0
- ecg_dataset_2020.py +397 -0
- ecg_dataset_2021.py +422 -0
- ecg_dataset_2021_DAFirst.py +418 -0
- engine_ecg_2020.py +309 -0
- engine_ecg_2021.py +241 -0
- evaluate_12ECG_score.py +577 -0
- evaluate_model.py +434 -0
- helper.ipynb +182 -0
- helper_code.py +241 -0
- losses.py +70 -0
- main_ecg.py +509 -0
- models_mamba_ecg.py +1013 -0
- optimizer.py +42 -0
- rope.py +141 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
csv-file_2021_challenge/collection_of_all_datasets.csv filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
csv-file_2021_challenge/training_validation_testing/group1/training_group1.csv filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
csv-file_2021_challenge/training_validation_testing/group2/training_group2.csv filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
csv-file_2021_challenge/training_validation_testing/group3/training_group3.csv filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
csv-file_2021_challenge/training_validation_testing/group4/training_group4.csv filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
csv-file_2021_challenge/training_validation_testing/group5/training_group5.csv filter=lfs diff=lfs merge=lfs -text
|
__pycache__/ecg_dataset_2020.cpython-310.pyc
ADDED
|
Binary file (7.46 kB). View file
|
|
|
__pycache__/ecg_dataset_2021.cpython-310.pyc
ADDED
|
Binary file (9.87 kB). View file
|
|
|
__pycache__/engine_ecg_2020.cpython-310.pyc
ADDED
|
Binary file (7.86 kB). View file
|
|
|
__pycache__/engine_ecg_2021.cpython-310.pyc
ADDED
|
Binary file (7.44 kB). View file
|
|
|
__pycache__/evaluate_12ECG_score.cpython-310.pyc
ADDED
|
Binary file (13.4 kB). View file
|
|
|
__pycache__/evaluate_model.cpython-310.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|
__pycache__/helper_code.cpython-310.pyc
ADDED
|
Binary file (6.37 kB). View file
|
|
|
__pycache__/losses.cpython-310.pyc
ADDED
|
Binary file (2.34 kB). View file
|
|
|
__pycache__/models_mamba_ecg.cpython-310.pyc
ADDED
|
Binary file (19.1 kB). View file
|
|
|
__pycache__/optimizer.cpython-310.pyc
ADDED
|
Binary file (1.47 kB). View file
|
|
|
__pycache__/rope.cpython-310.pyc
ADDED
|
Binary file (5 kB). View file
|
|
|
__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (8.37 kB). View file
|
|
|
csv-file_2020_challenge/training_validation_testing/group1/testing_group1.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
csv-file_2020_challenge/training_validation_testing/group1/train_validation_testing_fold5.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
csv-file_2020_challenge/training_validation_testing/group1/training_group1.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
csv-file_2020_challenge/training_validation_testing/group2/testing_group2.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
csv-file_2020_challenge/training_validation_testing/group2/training_group2.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
csv-file_2020_challenge/training_validation_testing/group3/testing_group3.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
csv-file_2020_challenge/training_validation_testing/group3/training_group3.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
csv-file_2020_challenge/training_validation_testing/group4/testing_group4.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
csv-file_2020_challenge/training_validation_testing/group4/training_group4.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
csv-file_2020_challenge/training_validation_testing/group5/testing_group5.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
csv-file_2020_challenge/training_validation_testing/group5/training_group5.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
csv-file_2021_challenge/collection_of_all_datasets.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:455156a539d4dde56aa9c2d1e5e460a72df64bda32b8ba36a6692837520292ac
|
| 3 |
+
size 12635911
|
csv-file_2021_challenge/name.csv
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Name,Target,Label_code,Frequency,Duration,Age,Sex
|
| 2 |
+
JS10647.hea,"[0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
|
| 3 |
+
1. 0.]","['426177001', '713427006', '164934002', '39732003']",500.0,10.0,82.0,Male
|
| 4 |
+
JS10648.hea,"[0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
|
| 5 |
+
0. 0.]","['426177001', '713426002']",500.0,10.0,63.0,Male
|
| 6 |
+
JS10649.hea,"[0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
|
| 7 |
+
0. 0.]","['111975006', '713426002', '426761007', '55930002']",500.0,10.0,83.0,Male
|
| 8 |
+
JS10650.hea,"[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
|
| 9 |
+
0. 0.]",['29320008'],500.0,10.0,81.0,Female
|
| 10 |
+
JS10651.hea,"[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.
|
| 11 |
+
0. 0.]","['233897008', '10370003']",500.0,10.0,63.0,Female
|
csv-file_2021_challenge/training_validation_testing/group1/testing_group1.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
csv-file_2021_challenge/training_validation_testing/group1/training_group1.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bd7406d2f748a0a5edd033b389fd967fc9a232d2421ed9be69adcccf964db7a1
|
| 3 |
+
size 10522513
|
csv-file_2021_challenge/training_validation_testing/group2/testing_group2.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
csv-file_2021_challenge/training_validation_testing/group2/training_group2.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:db819a51f741bbb0bc88878f4adbec1a16065eaf1cb991758dab8a42aef34fce
|
| 3 |
+
size 10523820
|
csv-file_2021_challenge/training_validation_testing/group3/testing_group3.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
csv-file_2021_challenge/training_validation_testing/group3/training_group3.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:03f7959deaa6cbe865b911c103a1fae71210cce4f5ac99c6920dc2efb4beb18a
|
| 3 |
+
size 10522844
|
csv-file_2021_challenge/training_validation_testing/group4/testing_group4.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
csv-file_2021_challenge/training_validation_testing/group4/training_group4.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:125d50af59c4b1862aefcf389b5a7eda5819bdc2e66ce52f43097f9f60661b25
|
| 3 |
+
size 10525392
|
csv-file_2021_challenge/training_validation_testing/group5/testing_group5.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
csv-file_2021_challenge/training_validation_testing/group5/training_group5.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9c6b92de27c915999cf59125328b7f2d73a9b33ef8939f04d942b17d1cb7630a
|
| 3 |
+
size 10522780
|
ecg_dataset_2020.py
ADDED
|
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
from helper_code import *
|
| 4 |
+
from scipy import signal
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from skmultilearn.model_selection import iterative_train_test_split
|
| 7 |
+
import random
|
| 8 |
+
import warnings
|
| 9 |
+
import sys
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_nsamp(header):
|
| 13 |
+
return int(header.split('\n')[0].split(' ')[3])
|
| 14 |
+
|
| 15 |
+
# Adapted from original scoring function code
|
| 16 |
+
# For each set of equivalent classes, replace each class with the representative class for the set.
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
'''
|
| 20 |
+
the 'output' will be using the zero-padding if the leads are less than 12.(8 leads)
|
| 21 |
+
the 'output_leads' will be [1. 1. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
|
| 22 |
+
|
| 23 |
+
the output will not be changed if the lead is 12
|
| 24 |
+
the 'output_leads' will be [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
|
| 25 |
+
'''
|
| 26 |
+
def mixup(data, mix_data):
|
| 27 |
+
mix_lambda = np.random.beta(10, 10)# 这里利用beta分布得到一个0到1的值,这里如果可以,尝试别的分布函数
|
| 28 |
+
if data.shape[1] > mix_data.shape[1]:
|
| 29 |
+
data = mix_lambda * data[:, :mix_data.shape[1]] + (1 - mix_lambda) * mix_data
|
| 30 |
+
else:
|
| 31 |
+
data = mix_lambda * data + (1 - mix_lambda) * mix_data[:, :data.shape[1]]
|
| 32 |
+
return data, mix_lambda
|
| 33 |
+
|
| 34 |
+
def cutmix(data, mix_data):
|
| 35 |
+
cutmix_lambda = np.random.beta(10, 10)
|
| 36 |
+
|
| 37 |
+
if data.shape[1] < mix_data.shape[1]:
|
| 38 |
+
seq_length = data.shape[1]
|
| 39 |
+
win_len = int(np.ceil(seq_length * cutmix_lambda))
|
| 40 |
+
start = np.random.randint(0, seq_length - win_len + 1)
|
| 41 |
+
end = start + win_len
|
| 42 |
+
data[:,start:end] = mix_data[:,start:end]
|
| 43 |
+
else:
|
| 44 |
+
seq_length = mix_data.shape[1]
|
| 45 |
+
win_len = int(np.ceil(seq_length * cutmix_lambda))
|
| 46 |
+
start = np.random.randint(0, seq_length - win_len + 1)
|
| 47 |
+
end = start + win_len
|
| 48 |
+
data[:,start:end] = mix_data[:,start:end]
|
| 49 |
+
return data, cutmix_lambda
|
| 50 |
+
|
| 51 |
+
def cutmix_fix_length(data, mix_data):
|
| 52 |
+
# cutmix_lambda = np.random.beta(0.2, 0.2)
|
| 53 |
+
cutmix_lambda = 0.2
|
| 54 |
+
|
| 55 |
+
if data.shape[1] < mix_data.shape[1]:
|
| 56 |
+
seq_length = data.shape[1]
|
| 57 |
+
win_len = int(np.ceil(seq_length * cutmix_lambda))
|
| 58 |
+
start = np.random.randint(0, seq_length - win_len + 1)
|
| 59 |
+
end = start + win_len
|
| 60 |
+
data[:,start:end] = mix_data[:,start:end]
|
| 61 |
+
else:
|
| 62 |
+
seq_length = mix_data.shape[1]
|
| 63 |
+
win_len = int(np.ceil(seq_length * cutmix_lambda))
|
| 64 |
+
start = np.random.randint(0, seq_length - win_len + 1)
|
| 65 |
+
end = start + win_len
|
| 66 |
+
data[:,start:end] = mix_data[:,start:end]
|
| 67 |
+
return data, cutmix_lambda
|
| 68 |
+
|
| 69 |
+
def same_shape_mixup(data, mix_data):
|
| 70 |
+
if data.shape[1] == mix_data.shape[1]:
|
| 71 |
+
mix_lambda = np.random.beta(10, 10)# 这里利用beta分布得到一个0到1的值,这里如果可以,尝试别的分布函数
|
| 72 |
+
data = mix_lambda * data + (1 - mix_lambda) * mix_data
|
| 73 |
+
return data
|
| 74 |
+
|
| 75 |
+
def expand_leads(recording, input_leads):
|
| 76 |
+
output = np.zeros((12, recording.shape[1]))
|
| 77 |
+
# recording.shape[1]: 5000
|
| 78 |
+
twelve_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
|
| 79 |
+
twelve_leads = [k.lower() for k in twelve_leads]
|
| 80 |
+
# ['i', 'ii', 'iii', 'avr', 'avl', 'avf', 'v1', 'v2', 'v3', 'v4', 'v5', 'v6']
|
| 81 |
+
|
| 82 |
+
input_leads = [k.lower() for k in input_leads]
|
| 83 |
+
# Here we can assume:
|
| 84 |
+
# input_leads:I, II, V1, V2, V3, V4, V5, V6,
|
| 85 |
+
# so the new input_leads: ['i', 'ii', 'v1', 'v2', 'v3', 'v4', 'v5', 'v6']
|
| 86 |
+
|
| 87 |
+
output_leads = np.zeros((12,))
|
| 88 |
+
# [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
|
| 89 |
+
|
| 90 |
+
# idx: [0, 1, 6, 7, 8, 9, 10, 11]
|
| 91 |
+
for i,k in enumerate(input_leads):
|
| 92 |
+
idx = twelve_leads.index(k)
|
| 93 |
+
output[idx,:] = recording[i,:]
|
| 94 |
+
output_leads[idx] = 1
|
| 95 |
+
|
| 96 |
+
return output, output_leads
|
| 97 |
+
|
| 98 |
+
'''
|
| 99 |
+
This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
|
| 100 |
+
This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
|
| 101 |
+
This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
|
| 102 |
+
This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
|
| 103 |
+
This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
|
| 104 |
+
This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
|
| 105 |
+
|
| 106 |
+
'''
|
| 107 |
+
|
| 108 |
+
class lead_exctractor:
|
| 109 |
+
"""
|
| 110 |
+
used to select specific leads or random choice of configurations
|
| 111 |
+
|
| 112 |
+
Twelve leads: I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5, V6
|
| 113 |
+
Eight leads: I, II, V1, V2, V3, V4, V5, V6
|
| 114 |
+
Six leads: I, II, III, aVR, aVL, aVF
|
| 115 |
+
Four leads: I, II, III, V2
|
| 116 |
+
Three leads: I, II, V2
|
| 117 |
+
Two leads: I, II
|
| 118 |
+
|
| 119 |
+
"""
|
| 120 |
+
L2 = np.array([1,1,0,0,0,0,0,0,0,0,0,0])
|
| 121 |
+
L3 = np.array([1,1,0,0,0,0,0,1,0,0,0,0])
|
| 122 |
+
L4 = np.array([1,1,1,0,0,0,0,1,0,0,0,0])
|
| 123 |
+
L6 = np.array([1,1,1,1,1,1,0,0,0,0,0,0])
|
| 124 |
+
L8 = np.array([1,1,0,0,0,0,1,1,1,1,1,1])
|
| 125 |
+
L12 = np.array([1,1,1,1,1,1,1,1,1,1,1,1])
|
| 126 |
+
|
| 127 |
+
@staticmethod
|
| 128 |
+
def get (x, num_leads, lead_indicator):
|
| 129 |
+
if num_leads==None:
|
| 130 |
+
# random choice output
|
| 131 |
+
num_leads = random.choice([12,8,6,4,3,2])
|
| 132 |
+
|
| 133 |
+
if num_leads==12:
|
| 134 |
+
# Twelve leads: I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5, V6
|
| 135 |
+
return x, lead_indicator * lead_exctractor.L12
|
| 136 |
+
|
| 137 |
+
if num_leads==8:
|
| 138 |
+
# Six leads: I, II, V1, V2, V3, V4, V5, V6
|
| 139 |
+
x = x * lead_exctractor.L8.reshape(12,1)
|
| 140 |
+
return x,lead_indicator * lead_exctractor.L8
|
| 141 |
+
|
| 142 |
+
if num_leads==6:
|
| 143 |
+
# Six leads: I, II, III, aVL, aVR, aVF
|
| 144 |
+
x = x * lead_exctractor.L6.reshape(12,1)
|
| 145 |
+
return x,lead_indicator * lead_exctractor.L6
|
| 146 |
+
|
| 147 |
+
if num_leads==4:
|
| 148 |
+
# Six leads: I, II, III, V2
|
| 149 |
+
x = x * lead_exctractor.L4.reshape(12,1)
|
| 150 |
+
return x,lead_indicator * lead_exctractor.L4
|
| 151 |
+
|
| 152 |
+
if num_leads==3:
|
| 153 |
+
# Three leads: I, II, V2
|
| 154 |
+
x = x * lead_exctractor.L3.reshape(12,1)
|
| 155 |
+
return x,lead_indicator * lead_exctractor.L3
|
| 156 |
+
|
| 157 |
+
if num_leads==2:
|
| 158 |
+
# Two leads: II, V5
|
| 159 |
+
x = x * lead_exctractor.L2.reshape(12,1)
|
| 160 |
+
return x,lead_indicator * lead_exctractor.L2
|
| 161 |
+
raise Exception("invalid-leads-number")
|
| 162 |
+
|
| 163 |
+
class dataset:
|
| 164 |
+
# #ECG 2021
|
| 165 |
+
# classes = ['164889003','164890007','6374002','426627000','733534002',
|
| 166 |
+
# '713427006','270492004','713426002','39732003','445118002',
|
| 167 |
+
# '164947007','251146004','111975006','698252002','426783006',
|
| 168 |
+
# '284470004','10370003','365413008','427172004','164917005',
|
| 169 |
+
# '47665007','427393009','426177001','427084000','164934002',
|
| 170 |
+
# '59931005']
|
| 171 |
+
|
| 172 |
+
# ECG 2020
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
normal_class = '426783006'
|
| 177 |
+
|
| 178 |
+
def __init__(self, header_files, Mixup = 0, amount = 0, cutMix = 0, Mixup_no_label_interpolate =0, progressive_switch = False):
|
| 179 |
+
self.files = []
|
| 180 |
+
self.sample = True
|
| 181 |
+
self.num_leads = None
|
| 182 |
+
|
| 183 |
+
for h in tqdm(header_files):
|
| 184 |
+
tmp = dict()
|
| 185 |
+
tmp['header'] = h
|
| 186 |
+
tmp['record'] = h.replace('.hea','.mat')
|
| 187 |
+
hdr = load_header(h)
|
| 188 |
+
tmp['nsamp'] = get_nsamp(hdr)
|
| 189 |
+
tmp['leads'] = get_leads(hdr)
|
| 190 |
+
tmp['age'] = get_age(hdr)
|
| 191 |
+
tmp['sex'] = get_sex(hdr)
|
| 192 |
+
tmp['dx'] = get_labels(hdr)
|
| 193 |
+
tmp['fs'] = get_frequency(hdr)
|
| 194 |
+
# tmp['target'] = np.zeros((26,))
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
self.files.append(tmp)
|
| 198 |
+
|
| 199 |
+
# print("This is the target:", tmp['target'])
|
| 200 |
+
# set filter parameters
|
| 201 |
+
self.b, self.a = signal.butter(3, [1 / 250, 47 / 250], 'bandpass')
|
| 202 |
+
|
| 203 |
+
self.files = pd.DataFrame(self.files)
|
| 204 |
+
|
| 205 |
+
self.Mixup = Mixup
|
| 206 |
+
self.cutMix = cutMix
|
| 207 |
+
self.Mixup_no_label_interplate = Mixup_no_label_interpolate
|
| 208 |
+
self.amount = amount
|
| 209 |
+
self.progressive = progressive_switch
|
| 210 |
+
|
| 211 |
+
self.current_epoch = 0 # Initialize the current epoch
|
| 212 |
+
|
| 213 |
+
address_2020 = "./csv-file_2020_challenge/training_validation_testing/group1/train_validation_testing_fold5.csv"
|
| 214 |
+
self.data_df_2020 = pd.read_csv(address_2020, index_col=0)
|
| 215 |
+
self.data_df_2020.set_index('Patient', inplace=True)
|
| 216 |
+
self.classes_2020 = ['270492004', '164889003', '164890007', '426627000', '713427006',
|
| 217 |
+
'713426002', '445118002', '39732003', '164909002', '251146004',
|
| 218 |
+
'698252002', '10370003', '284470004', '427172004', '164947007',
|
| 219 |
+
'111975006', '164917005', '47665007', '59118001', '427393009',
|
| 220 |
+
'426177001', '426783006', '427084000', '63593006', '164934002',
|
| 221 |
+
'59931005', '17338001']
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def set_epoch(self, epoch):
|
| 225 |
+
self.current_epoch = epoch
|
| 226 |
+
|
| 227 |
+
def summary(self, output):
|
| 228 |
+
if output=='pandas':
|
| 229 |
+
# print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
|
| 230 |
+
# print("This is the target:",self.files['target'],type(self.files['target']))
|
| 231 |
+
return pd.Series(np.stack(self.files['target'].to_list(),axis=0).sum(axis=0),index=dataset.classes)
|
| 232 |
+
|
| 233 |
+
if output=='numpy':
|
| 234 |
+
return np.stack(self.files['target'].to_list(),axis=0).sum(axis=0)
|
| 235 |
+
|
| 236 |
+
def __len__(self):
|
| 237 |
+
return len(self.files)
|
| 238 |
+
|
| 239 |
+
'''
|
| 240 |
+
fs: 500.0 sampling rate
|
| 241 |
+
target: [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
|
| 242 |
+
leads: ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
|
| 243 |
+
self.files.iloc[item]['record']: ../../python-classifier-2021-main/training_data/chapman_shaoxing/g1/JS00518.mat
|
| 244 |
+
data: shape->(12 X 5000)
|
| 245 |
+
'''
|
| 246 |
+
|
| 247 |
+
def __getitem__(self, item):
|
| 248 |
+
|
| 249 |
+
fs = self.files.iloc[item]['fs']
|
| 250 |
+
# target = self.files.iloc[item]['target']
|
| 251 |
+
header = self.files.iloc[item]['header']
|
| 252 |
+
name = header.split('/')[-1]
|
| 253 |
+
row = self.data_df_2020.loc[name[0:-4]]
|
| 254 |
+
target = row[self.classes_2020].values.astype(np.int_)
|
| 255 |
+
|
| 256 |
+
leads = self.files.iloc[item]['leads']
|
| 257 |
+
data = load_recording(self.files.iloc[item]['record'])
|
| 258 |
+
|
| 259 |
+
# print("This is the target:", target)
|
| 260 |
+
# Set your threshold
|
| 261 |
+
|
| 262 |
+
threshold = 0.5
|
| 263 |
+
# print("This is the original target:", target)
|
| 264 |
+
# print("This is the type of original target:", type(target))
|
| 265 |
+
|
| 266 |
+
# if random.random() < self.Mixup :
|
| 267 |
+
# mix_sample_idx = random.randint(0, self.amount-1)
|
| 268 |
+
# mix_datum = load_recording(self.files.iloc[mix_sample_idx]['record'])
|
| 269 |
+
# mix_target = self.files.iloc[mix_sample_idx]['target']
|
| 270 |
+
# data, alpha_value = mixup(data, mix_datum)
|
| 271 |
+
# target = alpha_value * target + (1-alpha_value) * mix_target
|
| 272 |
+
# # Binarize the labels based on the threshold
|
| 273 |
+
# target[target >= threshold] = 1
|
| 274 |
+
# target[target < threshold] = 0
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
# if random.random() < self.cutMix:
|
| 278 |
+
# mix_sample_idx = random.randint(0, self.amount-1)
|
| 279 |
+
# mix_datum = load_recording(self.files.iloc[mix_sample_idx]['record'])
|
| 280 |
+
# mix_target = self.files.iloc[mix_sample_idx]['target']
|
| 281 |
+
# data, alpha_value = cutmix_fix_length(data, mix_datum)
|
| 282 |
+
# target = alpha_value * target + (1-alpha_value) * mix_target
|
| 283 |
+
# target[target >= threshold] = 1
|
| 284 |
+
# target[target < threshold] = 0
|
| 285 |
+
|
| 286 |
+
# if random.random()< self.Mixup_no_label_interplate:
|
| 287 |
+
# mix_sample_idx = random.randint(0, self.amount-1)
|
| 288 |
+
# mix_datum = load_recording(self.files.iloc[mix_sample_idx]['record'])
|
| 289 |
+
# data= same_shape_mixup(data, mix_datum)
|
| 290 |
+
|
| 291 |
+
# proressive_index = 0
|
| 292 |
+
|
| 293 |
+
# # # below is the progessive data augmentation method
|
| 294 |
+
# # if self.progressive:
|
| 295 |
+
# # if self.current_epoch < 20:
|
| 296 |
+
# # if self.current_epoch % 4 == 0 or 3:
|
| 297 |
+
# # pass
|
| 298 |
+
# # elif self.current_epoch % 4 == 1:
|
| 299 |
+
# # proressive_index = 0.2
|
| 300 |
+
# # else:
|
| 301 |
+
# # proressive_index = 0.5
|
| 302 |
+
|
| 303 |
+
# # elif self.current_epoch < 40:
|
| 304 |
+
# # if self.current_epoch % 4 == 0 or 3:
|
| 305 |
+
# # pass
|
| 306 |
+
# # elif self.current_epoch % 4 == 1:
|
| 307 |
+
# # proressive_index = 0.7
|
| 308 |
+
# # else:
|
| 309 |
+
# # proressive_index = 0.8
|
| 310 |
+
# # else:
|
| 311 |
+
# # if self.current_epoch % 4 == 0 or 3:
|
| 312 |
+
# # pass
|
| 313 |
+
# # elif self.current_epoch % 4 == 1 or 2:
|
| 314 |
+
# # proressive_index = 1
|
| 315 |
+
|
| 316 |
+
# # below is the wave-mix data augmentation method
|
| 317 |
+
# if self.progressive:
|
| 318 |
+
# if self.current_epoch < 5:
|
| 319 |
+
# match self.current_epoch:
|
| 320 |
+
# case 0:
|
| 321 |
+
# self.progressive=0
|
| 322 |
+
# case 1:
|
| 323 |
+
# self.progressive=0.2
|
| 324 |
+
# case 2:
|
| 325 |
+
# self.progressive=0.4
|
| 326 |
+
# case 3:
|
| 327 |
+
# self.progressive=0.6
|
| 328 |
+
# case 4:
|
| 329 |
+
# self.progressive=0.8
|
| 330 |
+
|
| 331 |
+
# else :
|
| 332 |
+
# proressive_index = 0.8
|
| 333 |
+
|
| 334 |
+
# if random.random() < proressive_index :
|
| 335 |
+
# mix_sample_idx = random.randint(0, self.amount-1)
|
| 336 |
+
# mix_datum = load_recording(self.files.iloc[mix_sample_idx]['record'])
|
| 337 |
+
# mix_target = self.files.iloc[mix_sample_idx]['target']
|
| 338 |
+
# data, alpha_value = mixup(data, mix_datum)
|
| 339 |
+
# target = alpha_value * target + (1-alpha_value) * mix_target
|
| 340 |
+
# # Binarize the labels based on the threshold
|
| 341 |
+
# target[target >= threshold] = 1
|
| 342 |
+
# target[target < threshold] = 0
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
# expand to 12 lead setup if original signal has less channels
|
| 346 |
+
data, lead_indicator = expand_leads(data, input_leads=leads)
|
| 347 |
+
data = np.nan_to_num(data)
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
# resample to 500hz
|
| 351 |
+
if fs == float(1000):
|
| 352 |
+
data = signal.resample_poly(data, up=1, down=2, axis=-1) # to 500Hz
|
| 353 |
+
fs = 500
|
| 354 |
+
elif fs == float(500):
|
| 355 |
+
pass
|
| 356 |
+
else:
|
| 357 |
+
data = signal.resample(data, int(data.shape[1] * 500 / fs), axis=1)
|
| 358 |
+
fs = 500
|
| 359 |
+
|
| 360 |
+
# below is the Butterworth digital and analog filter design.
|
| 361 |
+
data = signal.filtfilt(self.b, self.a, data)
|
| 362 |
+
|
| 363 |
+
'''
|
| 364 |
+
we filter out the sample which the length is 8192 via the below code.
|
| 365 |
+
just like the random shift windows.
|
| 366 |
+
'''
|
| 367 |
+
if self.sample:
|
| 368 |
+
fs = int(fs)
|
| 369 |
+
# random sample signal if len > 8192 samples
|
| 370 |
+
if data.shape[-1] >= 8192:
|
| 371 |
+
idx = data.shape[-1] - 8192-1
|
| 372 |
+
idx = np.random.randint(idx)
|
| 373 |
+
data = data[:, idx:idx + 8192]
|
| 374 |
+
|
| 375 |
+
'''
|
| 376 |
+
we can obtain the average and mean for each lead via below code, 90% conclusion.
|
| 377 |
+
'''
|
| 378 |
+
mu = np.nanmean(data, axis=-1, keepdims=True)
|
| 379 |
+
std = np.nanstd(data, axis=-1, keepdims=True)
|
| 380 |
+
|
| 381 |
+
'''
|
| 382 |
+
Z-Score Normalization
|
| 383 |
+
'''
|
| 384 |
+
#std = np.nanstd(data.flatten())
|
| 385 |
+
with warnings.catch_warnings():
|
| 386 |
+
warnings.simplefilter("ignore")
|
| 387 |
+
data = (data - mu) / std
|
| 388 |
+
|
| 389 |
+
data = np.nan_to_num(data)
|
| 390 |
+
|
| 391 |
+
# random choose number of leads to keep
|
| 392 |
+
data, lead_indicator = lead_exctractor.get(data, self.num_leads, lead_indicator)
|
| 393 |
+
|
| 394 |
+
# print("This is the lead_indicator:(l)", type(lead_indicator), lead_indicator)
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
return data, target
|
ecg_dataset_2021.py
ADDED
|
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
from helper_code import *
|
| 4 |
+
from scipy import signal
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from skmultilearn.model_selection import iterative_train_test_split
|
| 7 |
+
import random
|
| 8 |
+
import warnings
|
| 9 |
+
import sys
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_nsamp(header):
    """Return the number of samples encoded in a WFDB header string.

    The first header line has the form
    ``<record> <n_leads> <fs> <n_samples> ...``; the sample count is the
    fourth whitespace-separated field.
    """
    # str.split() with no argument tolerates repeated whitespace, unlike
    # the original split(' ') which yields empty fields on double spaces.
    return int(header.split('\n')[0].split()[3])
|
| 14 |
+
|
| 15 |
+
# Adapted from original scoring function code
|
| 16 |
+
# For each set of equivalent classes, replace each class with the representative class for the set.
|
| 17 |
+
def replace_equivalent_classes(classes, equivalent_classes):
    """Replace each class code with the canonical representative of its
    equivalence set.

    The representative is the first element of each set. ``classes`` is
    modified in place and also returned (adapted from the original
    challenge scoring code).
    """
    for position, code in enumerate(classes):
        for group in equivalent_classes:
            if code in group:
                # First element of the group is the canonical code.
                classes[position] = group[0]
    return classes
|
| 23 |
+
|
| 24 |
+
'''
|
| 25 |
+
the 'output' will be zero-padded if fewer than 12 leads are present (e.g. 8 leads)
|
| 26 |
+
the 'output_leads' will be [1. 1. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
|
| 27 |
+
|
| 28 |
+
the output is left unchanged if all 12 leads are present
|
| 29 |
+
the 'output_leads' will be [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
|
| 30 |
+
'''
|
| 31 |
+
def mixup(data, mix_data):
    """Blend two recordings with a random Beta(10, 10) coefficient.

    The longer recording is truncated to the shorter one's length, so the
    result has the length of the shorter input. Returns
    ``(mixed, mix_lambda)`` where ``mix_lambda`` is the weight applied to
    ``data``.
    """
    # Beta(10, 10) concentrates the coefficient around 0.5; other
    # distributions could be tried here.
    mix_lambda = np.random.beta(10, 10)
    shortest = min(data.shape[1], mix_data.shape[1])
    mixed = mix_lambda * data[:, :shortest] + (1 - mix_lambda) * mix_data[:, :shortest]
    return mixed, mix_lambda
|
| 38 |
+
|
| 39 |
+
def cutmix(data, mix_data):
    """CutMix for 1-D multichannel signals: overwrite a random window of
    ``data`` (in place) with the same window of ``mix_data``.

    The window length is a Beta(10, 10) fraction of the shorter recording.
    Returns ``(data, cutmix_lambda)``.

    NOTE(review): ``cutmix_lambda`` is the fraction of ``data`` that was
    REPLACED, yet callers weight labels as
    ``lambda * target + (1 - lambda) * mix_target`` as if it were the
    fraction retained — confirm the intended label semantics.
    """
    cutmix_lambda = np.random.beta(10, 10)

    # Both original branches performed the same computation with
    # seq_length = min(...); the duplication is collapsed here.
    seq_length = min(data.shape[1], mix_data.shape[1])
    win_len = int(np.ceil(seq_length * cutmix_lambda))
    start = np.random.randint(0, seq_length - win_len + 1)
    data[:, start:start + win_len] = mix_data[:, start:start + win_len]
    return data, cutmix_lambda
|
| 55 |
+
|
| 56 |
+
def cutmix_fix_length(data, mix_data):
    """CutMix with a fixed window of 20% of the shorter recording.

    A random window covering 20% of the shorter input is copied from
    ``mix_data`` into ``data`` in place. Returns
    ``(data, cutmix_lambda)`` with ``cutmix_lambda == 0.2``.
    """
    # A random Beta(0.2, 0.2) draw was considered originally; a fixed
    # fraction is used instead.
    cutmix_lambda = 0.2

    # Both original branches performed the same computation with
    # seq_length = min(...); the duplication is collapsed here.
    seq_length = min(data.shape[1], mix_data.shape[1])
    win_len = int(np.ceil(seq_length * cutmix_lambda))
    start = np.random.randint(0, seq_length - win_len + 1)
    data[:, start:start + win_len] = mix_data[:, start:start + win_len]
    return data, cutmix_lambda
|
| 73 |
+
|
| 74 |
+
def same_shape_mixup(data, mix_data):
    """Mixup applied only when both recordings have the same length.

    When the lengths differ, ``data`` is returned untouched. Unlike
    ``mixup``, the mixing coefficient is not returned.
    """
    if data.shape[1] != mix_data.shape[1]:
        return data
    # Beta(10, 10) keeps the coefficient close to 0.5; other
    # distributions could be tried here.
    coeff = np.random.beta(10, 10)
    return coeff * data + (1 - coeff) * mix_data
|
| 79 |
+
|
| 80 |
+
def expand_leads(recording, input_leads):
    """Zero-pad a recording up to the canonical 12-lead layout.

    Each row of ``recording`` is copied into the row of a ``(12, T)``
    array matching its lead name (case-insensitive). Returns
    ``(expanded, indicator)`` where ``indicator[i] == 1`` iff lead ``i``
    was present in ``input_leads``.
    """
    canonical_order = ['i', 'ii', 'iii', 'avr', 'avl', 'avf',
                       'v1', 'v2', 'v3', 'v4', 'v5', 'v6']

    expanded = np.zeros((12, recording.shape[1]))
    indicator = np.zeros((12,))

    for row, lead in enumerate(name.lower() for name in input_leads):
        position = canonical_order.index(lead)
        expanded[position, :] = recording[row, :]
        indicator[position] = 1

    return expanded, indicator
|
| 102 |
+
|
| 103 |
+
'''
|
| 104 |
+
This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
|
| 105 |
+
This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
|
| 106 |
+
This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
|
| 107 |
+
This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
|
| 108 |
+
This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
|
| 109 |
+
This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
|
| 110 |
+
|
| 111 |
+
'''
|
| 112 |
+
|
| 113 |
+
class lead_exctractor:
    """Select a fixed lead subset — or a random configuration when
    ``num_leads`` is None — by zeroing out the excluded rows.

    Twelve leads: I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5, V6
    Eight leads:  I, II, V1, V2, V3, V4, V5, V6
    Six leads:    I, II, III, aVR, aVL, aVF
    Four leads:   I, II, III, V2
    Three leads:  I, II, V2
    Two leads:    I, II
    """
    # Binary masks over the canonical 12-lead order above.
    L2 = np.array([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
    L3 = np.array([1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0])
    L4 = np.array([1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0])
    L6 = np.array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
    L8 = np.array([1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
    L12 = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

    @staticmethod
    def get(x, num_leads, lead_indicator):
        """Return ``(masked_signal, masked_indicator)`` for the requested
        lead count; raises for unsupported counts."""
        if num_leads is None:
            # Train-time augmentation: pick a configuration at random.
            num_leads = random.choice([12, 8, 6, 4, 3, 2])

        if num_leads == 12:
            # All leads kept: the signal itself is returned untouched.
            return x, lead_indicator * lead_exctractor.L12

        masks = {
            8: lead_exctractor.L8,
            6: lead_exctractor.L6,
            4: lead_exctractor.L4,
            3: lead_exctractor.L3,
            2: lead_exctractor.L2,
        }
        if num_leads not in masks:
            raise Exception("invalid-leads-number")

        mask = masks[num_leads]
        return x * mask.reshape(12, 1), lead_indicator * mask
|
| 167 |
+
|
| 168 |
+
class dataset:
    """PhysioNet/CinC 2021 ECG dataset with optional Mixup / CutMix /
    progressive wave-mix augmentation applied on the fly in
    ``__getitem__``.
    """
    # Scored SNOMED-CT codes (26 classes).
    classes = ['164889003','164890007','6374002','426627000','733534002',
               '713427006','270492004','713426002','39732003','445118002',
               '164947007','251146004','111975006','698252002','426783006',
               '284470004','10370003','365413008','427172004','164917005',
               '47665007','427393009','426177001','427084000','164934002',
               '59931005']
    normal_class = '426783006'
    # Pairs of SNOMED codes scored as equivalent; the first entry of each
    # pair is the representative used for training targets.
    equivalent_classes = [['713427006', '59118001'],
                          ['284470004', '63593006'],
                          ['427172004', '17338001'],
                          ['733534002', '164909002']]

    def __init__(self, header_files, Mixup=0, amount=0, cutMix=0,
                 Mixup_no_label_interpolate=0, progressive_switch=False):
        """Index the given WFDB header files and prepare the filter.

        Mixup / cutMix: per-item probabilities of the respective
        augmentation. amount: number of candidate partner recordings for
        mixing (index range for random.randint).
        """
        self.files = []
        self.sample = True      # random 8192-sample window when True
        self.num_leads = None   # None -> random lead subset per item

        for header_file in tqdm(header_files):
            hdr = load_header(header_file)
            entry = {
                'header': header_file,
                'record': header_file.replace('.hea', '.mat'),
                'nsamp': get_nsamp(hdr),
                'leads': get_leads(hdr),
                'age': get_age(hdr),
                'sex': get_sex(hdr),
                'dx': replace_equivalent_classes(get_labels(hdr),
                                                 dataset.equivalent_classes),
                'fs': get_frequency(hdr),
                'target': np.zeros((26,)),
            }
            # Multi-hot encode every scored SNOMED code.
            for dx in entry['dx']:
                if dx in dataset.classes:
                    entry['target'][dataset.classes.index(dx)] = 1
            self.files.append(entry)

        # 3rd-order zero-phase Butterworth band-pass, 1-47 Hz, with
        # cutoffs normalised by the 250 Hz Nyquist frequency (signals are
        # resampled to 500 Hz before filtering).
        self.b, self.a = signal.butter(3, [1 / 250, 47 / 250], 'bandpass')

        self.files = pd.DataFrame(self.files)

        self.Mixup = Mixup
        self.cutMix = cutMix
        # NOTE: attribute keeps the original (misspelled) name so any
        # external reader of this flag keeps working.
        self.Mixup_no_label_interplate = Mixup_no_label_interpolate
        self.amount = amount
        self.progressive = progressive_switch
        self.current_epoch = 0  # updated by set_epoch()

    def set_epoch(self, epoch):
        """Inform the dataset of the current training epoch (used by the
        progressive augmentation schedule)."""
        self.current_epoch = epoch

    def train_valid_split(self, test_size):
        """Split into (train, valid) datasets with iterative multi-label
        stratification.

        test_size is the validation fraction (e.g. 0.2). The training
        split uses random lead subsets and random windows; the validation
        split uses all 12 leads and no random sampling.
        """
        files = self.files['header'].to_numpy().reshape(-1, 1)
        targets = np.stack(self.files['target'].to_list(), axis=0)

        x_train, y_train, x_valid, y_valid = iterative_train_test_split(
            files, targets, test_size=test_size)

        train = dataset(header_files=x_train[:, 0].tolist())
        train.num_leads = None
        train.sample = True

        valid = dataset(header_files=x_valid[:, 0].tolist())
        valid.num_leads = 12
        valid.sample = False

        return train, valid

    def summary(self, output):
        """Per-class positive counts as a pandas Series (``'pandas'``)
        or a numpy array (``'numpy'``); None for any other value."""
        counts = np.stack(self.files['target'].to_list(), axis=0).sum(axis=0)
        if output == 'pandas':
            return pd.Series(counts, index=dataset.classes)
        if output == 'numpy':
            return counts

    def Pre_processing(self, fs, leads, data):
        """Standardise one recording: 12-lead layout, 500 Hz, 1-47 Hz
        band-pass, per-lead z-score, lead selection and (when
        ``self.sample``) a fixed 8192-sample window."""
        # 1. Expand to the 12-lead layout (missing leads stay zero).
        data, lead_indicator = expand_leads(data, input_leads=leads)
        data = np.nan_to_num(data)

        # 2. Resample to 500 Hz.
        if fs == float(1000):
            data = signal.resample_poly(data, up=1, down=2, axis=-1)
            fs = 500
        elif fs == float(500):
            pass
        else:
            data = signal.resample(data, int(data.shape[1] * 500 / fs), axis=1)
            fs = 500

        # 3. Zero-phase Butterworth band-pass (coefficients from __init__).
        data = signal.filtfilt(self.b, self.a, data)

        # 4. Per-lead z-score normalisation. All-zero (padded) leads
        # divide by a zero std and produce NaN, which is mapped back to
        # zero below; the warning from that division is suppressed.
        mu = np.nanmean(data, axis=-1, keepdims=True)
        std = np.nanstd(data, axis=-1, keepdims=True)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            data = (data - mu) / std
        data = np.nan_to_num(data)

        # 5. Lead selection (self.num_leads; None -> random subset).
        data, lead_indicator = lead_exctractor.get(data, self.num_leads,
                                                   lead_indicator)

        # 6. Force a fixed length of 8192 samples.
        if self.sample:
            if data.shape[-1] >= 8192:
                # Random crop. The original computed
                # np.random.randint(data.shape[-1] - 8192 - 1), which
                # raises ValueError for lengths 8192/8193 and never used
                # the last valid offsets; use the inclusive upper bound.
                start = np.random.randint(0, data.shape[-1] - 8192 + 1)
                data = data[:, start:start + 8192]
            else:
                # Right zero-padding up to 8192 samples.
                data = np.pad(data, ((0, 0), (0, 8192 - data.shape[-1])),
                              mode='constant', constant_values=0)

        return data

    def __len__(self):
        return len(self.files)

    def _mix_sample(self):
        """Load and preprocess a random partner recording for mixing.

        Returns (mix_data, mix_target); the partner index is drawn from
        [0, self.amount - 1].
        """
        mix_idx = random.randint(0, self.amount - 1)
        row = self.files.iloc[mix_idx]
        mix_data = self.Pre_processing(row['fs'], row['leads'],
                                       load_recording(row['record']))
        return mix_data, row['target']

    def __getitem__(self, item):
        """Return (data, target) for one recording.

        data is the preprocessed (12, 8192) signal; target is the 26-way
        multi-hot label, possibly softened by Mixup / CutMix /
        progressive wave-mix augmentation.
        """
        row = self.files.iloc[item]
        data = self.Pre_processing(row['fs'], row['leads'],
                                   load_recording(row['record']))
        target = row['target']

        # Mixup with probability self.Mixup.
        if random.random() < self.Mixup:
            mix_data, mix_target = self._mix_sample()
            data, lam = mixup(data, mix_data)
            target = lam * target + (1 - lam) * mix_target

        # CutMix with probability self.cutMix.
        if random.random() < self.cutMix:
            mix_data, mix_target = self._mix_sample()
            data, lam = cutmix(data, mix_data)
            target = lam * target + (1 - lam) * mix_target

        # Progressive wave-mix. The original epoch-based match/case
        # schedule assigned 0.8 in every branch, so it reduces to a
        # constant probability (and no longer requires Python 3.10).
        progressive_index = 0.8 if self.progressive else 0
        if random.random() < progressive_index:
            mix_data, mix_target = self._mix_sample()
            data, lam = mixup(data, mix_data)
            target = lam * target + (1 - lam) * mix_target

        return data, target
|
ecg_dataset_2021_DAFirst.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
from helper_code import *
|
| 4 |
+
from scipy import signal
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from skmultilearn.model_selection import iterative_train_test_split
|
| 7 |
+
import random
|
| 8 |
+
import warnings
|
| 9 |
+
import sys
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_nsamp(header):
    """Return the number of samples encoded in a WFDB header string.

    The first header line has the form
    ``<record> <n_leads> <fs> <n_samples> ...``; the sample count is the
    fourth whitespace-separated field.
    """
    # str.split() with no argument tolerates repeated whitespace, unlike
    # the original split(' ') which yields empty fields on double spaces.
    return int(header.split('\n')[0].split()[3])
|
| 14 |
+
|
| 15 |
+
# Adapted from original scoring function code
|
| 16 |
+
# For each set of equivalent classes, replace each class with the representative class for the set.
|
| 17 |
+
def replace_equivalent_classes(classes, equivalent_classes):
    """Replace each class code with the canonical representative of its
    equivalence set.

    The representative is the first element of each set. ``classes`` is
    modified in place and also returned (adapted from the original
    challenge scoring code).
    """
    for position, code in enumerate(classes):
        for group in equivalent_classes:
            if code in group:
                # First element of the group is the canonical code.
                classes[position] = group[0]
    return classes
|
| 23 |
+
|
| 24 |
+
'''
|
| 25 |
+
the 'output' will be zero-padded if fewer than 12 leads are present (e.g. 8 leads)
|
| 26 |
+
the 'output_leads' will be [1. 1. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
|
| 27 |
+
|
| 28 |
+
the output is left unchanged if all 12 leads are present
|
| 29 |
+
the 'output_leads' will be [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
|
| 30 |
+
'''
|
| 31 |
+
def mixup(data, mix_data):
    """Blend two recordings with a random Beta(10, 10) coefficient.

    The longer recording is truncated to the shorter one's length, so the
    result has the length of the shorter input. Returns
    ``(mixed, mix_lambda)`` where ``mix_lambda`` is the weight applied to
    ``data``.
    """
    # Beta(10, 10) concentrates the coefficient around 0.5; other
    # distributions could be tried here.
    mix_lambda = np.random.beta(10, 10)
    shortest = min(data.shape[1], mix_data.shape[1])
    mixed = mix_lambda * data[:, :shortest] + (1 - mix_lambda) * mix_data[:, :shortest]
    return mixed, mix_lambda
|
| 38 |
+
|
| 39 |
+
def cutmix(data, mix_data):
    """CutMix for 1-D multichannel signals: overwrite a random window of
    ``data`` (in place) with the same window of ``mix_data``.

    The window length is a Beta(10, 10) fraction of the shorter recording.
    Returns ``(data, cutmix_lambda)``.

    NOTE(review): ``cutmix_lambda`` is the fraction of ``data`` that was
    REPLACED, yet callers weight labels as
    ``lambda * target + (1 - lambda) * mix_target`` as if it were the
    fraction retained — confirm the intended label semantics.
    """
    cutmix_lambda = np.random.beta(10, 10)

    # Both original branches performed the same computation with
    # seq_length = min(...); the duplication is collapsed here.
    seq_length = min(data.shape[1], mix_data.shape[1])
    win_len = int(np.ceil(seq_length * cutmix_lambda))
    start = np.random.randint(0, seq_length - win_len + 1)
    data[:, start:start + win_len] = mix_data[:, start:start + win_len]
    return data, cutmix_lambda
|
| 55 |
+
|
| 56 |
+
def cutmix_fix_length(data, mix_data):
    """CutMix with a fixed window of 20% of the shorter recording.

    A random window covering 20% of the shorter input is copied from
    ``mix_data`` into ``data`` in place. Returns
    ``(data, cutmix_lambda)`` with ``cutmix_lambda == 0.2``.
    """
    # A random Beta(0.2, 0.2) draw was considered originally; a fixed
    # fraction is used instead.
    cutmix_lambda = 0.2

    # Both original branches performed the same computation with
    # seq_length = min(...); the duplication is collapsed here.
    seq_length = min(data.shape[1], mix_data.shape[1])
    win_len = int(np.ceil(seq_length * cutmix_lambda))
    start = np.random.randint(0, seq_length - win_len + 1)
    data[:, start:start + win_len] = mix_data[:, start:start + win_len]
    return data, cutmix_lambda
|
| 73 |
+
|
| 74 |
+
def same_shape_mixup(data, mix_data):
    """Mixup applied only when both recordings have the same length.

    When the lengths differ, ``data`` is returned untouched. Unlike
    ``mixup``, the mixing coefficient is not returned.
    """
    if data.shape[1] != mix_data.shape[1]:
        return data
    # Beta(10, 10) keeps the coefficient close to 0.5; other
    # distributions could be tried here.
    coeff = np.random.beta(10, 10)
    return coeff * data + (1 - coeff) * mix_data
|
| 79 |
+
|
| 80 |
+
def expand_leads(recording, input_leads):
    """Zero-pad a recording up to the canonical 12-lead layout.

    Each row of ``recording`` is copied into the row of a ``(12, T)``
    array matching its lead name (case-insensitive). Returns
    ``(expanded, indicator)`` where ``indicator[i] == 1`` iff lead ``i``
    was present in ``input_leads``.
    """
    canonical_order = ['i', 'ii', 'iii', 'avr', 'avl', 'avf',
                       'v1', 'v2', 'v3', 'v4', 'v5', 'v6']

    expanded = np.zeros((12, recording.shape[1]))
    indicator = np.zeros((12,))

    for row, lead in enumerate(name.lower() for name in input_leads):
        position = canonical_order.index(lead)
        expanded[position, :] = recording[row, :]
        indicator[position] = 1

    return expanded, indicator
|
| 102 |
+
|
| 103 |
+
'''
|
| 104 |
+
This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
|
| 105 |
+
This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
|
| 106 |
+
This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
|
| 107 |
+
This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
|
| 108 |
+
This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
|
| 109 |
+
This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
|
| 110 |
+
|
| 111 |
+
'''
|
| 112 |
+
|
| 113 |
+
class lead_exctractor:
    """Select a fixed lead subset — or a random configuration when
    ``num_leads`` is None — by zeroing out the excluded rows.

    Twelve leads: I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5, V6
    Eight leads:  I, II, V1, V2, V3, V4, V5, V6
    Six leads:    I, II, III, aVR, aVL, aVF
    Four leads:   I, II, III, V2
    Three leads:  I, II, V2
    Two leads:    I, II
    """
    # Binary masks over the canonical 12-lead order above.
    L2 = np.array([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
    L3 = np.array([1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0])
    L4 = np.array([1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0])
    L6 = np.array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
    L8 = np.array([1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
    L12 = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

    @staticmethod
    def get(x, num_leads, lead_indicator):
        """Return ``(masked_signal, masked_indicator)`` for the requested
        lead count; raises for unsupported counts."""
        if num_leads is None:
            # Train-time augmentation: pick a configuration at random.
            num_leads = random.choice([12, 8, 6, 4, 3, 2])

        if num_leads == 12:
            # All leads kept: the signal itself is returned untouched.
            return x, lead_indicator * lead_exctractor.L12

        masks = {
            8: lead_exctractor.L8,
            6: lead_exctractor.L6,
            4: lead_exctractor.L4,
            3: lead_exctractor.L3,
            2: lead_exctractor.L2,
        }
        if num_leads not in masks:
            raise Exception("invalid-leads-number")

        mask = masks[num_leads]
        return x * mask.reshape(12, 1), lead_indicator * mask
|
| 167 |
+
|
| 168 |
+
class dataset:
|
| 169 |
+
classes = ['164889003','164890007','6374002','426627000','733534002',
|
| 170 |
+
'713427006','270492004','713426002','39732003','445118002',
|
| 171 |
+
'164947007','251146004','111975006','698252002','426783006',
|
| 172 |
+
'284470004','10370003','365413008','427172004','164917005',
|
| 173 |
+
'47665007','427393009','426177001','427084000','164934002',
|
| 174 |
+
'59931005']
|
| 175 |
+
normal_class = '426783006'
|
| 176 |
+
equivalent_classes = [['713427006', '59118001'],
|
| 177 |
+
['284470004', '63593006'],
|
| 178 |
+
['427172004', '17338001'],
|
| 179 |
+
['733534002', '164909002']]
|
| 180 |
+
def __init__(self, header_files, Mixup = 0, amount = 0, cutMix = 0, Mixup_no_label_interpolate =0, progressive_switch = False):
|
| 181 |
+
self.files = []
|
| 182 |
+
self.sample = True
|
| 183 |
+
self.num_leads = None
|
| 184 |
+
|
| 185 |
+
for h in tqdm(header_files):
|
| 186 |
+
tmp = dict()
|
| 187 |
+
tmp['header'] = h
|
| 188 |
+
tmp['record'] = h.replace('.hea','.mat')
|
| 189 |
+
hdr = load_header(h)
|
| 190 |
+
tmp['nsamp'] = get_nsamp(hdr)
|
| 191 |
+
tmp['leads'] = get_leads(hdr)
|
| 192 |
+
tmp['age'] = get_age(hdr)
|
| 193 |
+
tmp['sex'] = get_sex(hdr)
|
| 194 |
+
tmp['dx'] = get_labels(hdr)
|
| 195 |
+
tmp['fs'] = get_frequency(hdr)
|
| 196 |
+
tmp['target'] = np.zeros((26,))
|
| 197 |
+
tmp['dx'] = replace_equivalent_classes(tmp['dx'], dataset.equivalent_classes)
|
| 198 |
+
|
| 199 |
+
for dx in tmp['dx']:
|
| 200 |
+
# in SNOMED code is in scored classes
|
| 201 |
+
if dx in dataset.classes:
|
| 202 |
+
idx = dataset.classes.index(dx)
|
| 203 |
+
tmp['target'][idx] = 1
|
| 204 |
+
|
| 205 |
+
self.files.append(tmp)
|
| 206 |
+
|
| 207 |
+
# print("This is the target:", tmp['target'])
|
| 208 |
+
# set filter parameters
|
| 209 |
+
|
| 210 |
+
# Filtering: Data are filtered using a zero-phase method with 3rd order Butterworth bandpass filter
|
| 211 |
+
# with frequency band from 1 Hz to 47 Hz.
|
| 212 |
+
self.b, self.a = signal.butter(3, [1 / 250, 47 / 250], 'bandpass')
|
| 213 |
+
|
| 214 |
+
self.files = pd.DataFrame(self.files)
|
| 215 |
+
|
| 216 |
+
self.Mixup = Mixup
|
| 217 |
+
self.cutMix = cutMix
|
| 218 |
+
self.Mixup_no_label_interplate = Mixup_no_label_interpolate
|
| 219 |
+
self.amount = amount
|
| 220 |
+
self.progressive = progressive_switch
|
| 221 |
+
|
| 222 |
+
self.current_epoch = 0 # Initialize the current epoch
|
| 223 |
+
|
| 224 |
+
def set_epoch(self, epoch):
|
| 225 |
+
self.current_epoch = epoch
|
| 226 |
+
|
| 227 |
+
def train_valid_split(self, test_size):
|
| 228 |
+
'''
|
| 229 |
+
test_size: 0.2
|
| 230 |
+
below is the value of each variable:
|
| 231 |
+
files[0] <- "/media/jiang/ECG/physionet.org/files/challenge-2021/1.0.3/training/python-classifier-2021-main/training_data/chapman_shaoxing/g1/JS00001.hea" shape -> (999, 1)
|
| 232 |
+
targets[1] <-[1. 0. 0. ... 0. 1. 0.] 26, shape->(999, 26)
|
| 233 |
+
x_train[0] <-"/media/jiang/ECG/physionet.org/files/challenge-2021/1.0.3/training/python-classifier-2021-main/training_data/chapman_shaoxing/g1/JS00001.hea" shape -> (799,1)
|
| 234 |
+
x_valid[0] <-"/media/jiang/ECG/physionet.org/files/challenge-2021/1.0.3/training/python-classifier-2021-main/training_data/chapman_shaoxing/g1/JS00651.hea" shape -> (200,1)
|
| 235 |
+
'''
|
| 236 |
+
files = self.files['header'].to_numpy().reshape(-1,1)
|
| 237 |
+
# print(files[0])
|
| 238 |
+
targets = np.stack(self.files['target'].to_list(), axis=0)
|
| 239 |
+
print("This is the targets:", targets)
|
| 240 |
+
|
| 241 |
+
x_train, y_train, x_valid, y_valid = iterative_train_test_split(files, targets, test_size=test_size)
|
| 242 |
+
|
| 243 |
+
train = dataset(header_files=x_train[:,0].tolist())
|
| 244 |
+
|
| 245 |
+
train.num_leads=None
|
| 246 |
+
train.sample=True
|
| 247 |
+
|
| 248 |
+
valid = dataset(header_files=x_valid[:,0].tolist())
|
| 249 |
+
|
| 250 |
+
valid.num_leads=12
|
| 251 |
+
valid.sample=False
|
| 252 |
+
|
| 253 |
+
return train, valid
|
| 254 |
+
|
| 255 |
+
def summary(self, output):
|
| 256 |
+
if output=='pandas':
|
| 257 |
+
# print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
|
| 258 |
+
# print("This is the target:",self.files['target'],type(self.files['target']))
|
| 259 |
+
return pd.Series(np.stack(self.files['target'].to_list(),axis=0).sum(axis=0),index=dataset.classes)
|
| 260 |
+
|
| 261 |
+
if output=='numpy':
|
| 262 |
+
return np.stack(self.files['target'].to_list(),axis=0).sum(axis=0)
|
| 263 |
+
|
| 264 |
+
def __len__(self):
|
| 265 |
+
return len(self.files)
|
| 266 |
+
|
| 267 |
+
'''
|
| 268 |
+
fs: 500.0 sampling rate
|
| 269 |
+
target: [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
|
| 270 |
+
leads: ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
|
| 271 |
+
self.files.iloc[item]['record']: ../../python-classifier-2021-main/training_data/chapman_shaoxing/g1/JS00518.mat
|
| 272 |
+
data: shape->(12 X 5000)
|
| 273 |
+
'''
|
| 274 |
+
|
| 275 |
+
    def __getitem__(self, item):
        """Load, augment, filter, and normalize one ECG record.

        Pipeline: optional mixup/cutmix augmentations -> expand to 12 leads
        -> resample to 500 Hz -> 1-47 Hz band-pass -> optional random 8192-
        sample window -> per-lead z-score -> random lead-subset extraction.

        Parameters
        ----------
        item : int
            Row index into ``self.files``.

        Returns
        -------
        tuple
            ``(data, target)`` where ``data`` is a (leads x samples) array at
            500 Hz and ``target`` is the (possibly mixup-softened) multi-label
            vector for the record.
        """

        # Per-record metadata from the DataFrame built by the constructor.
        fs = self.files.iloc[item]['fs']
        target = self.files.iloc[item]['target']
        leads = self.files.iloc[item]['leads']
        data = load_recording(self.files.iloc[item]['record'])

        # print("This is the target:", target)
        # Set your threshold

        # threshold = 0.5
        # print("This is the original target:", target)
        # print("This is the type of original target:", type(target))

        # Mixup: blend this record's waveform with a randomly chosen partner
        # and interpolate both label vectors with the same coefficient.
        # NOTE(review): the partner index is drawn from [0, self.amount-1];
        # confirm `amount` equals the dataset size, otherwise only a prefix
        # of the records can be chosen as mixing partners.
        if random.random() < self.Mixup :
            mix_sample_idx = random.randint(0, self.amount-1)
            mix_datum = load_recording(self.files.iloc[mix_sample_idx]['record'])
            mix_target = self.files.iloc[mix_sample_idx]['target']
            data, alpha_value = mixup(data, mix_datum)
            target = alpha_value * target + (1-alpha_value) * mix_target
            # Binarize the labels based on the threshold
            # target[target >= threshold] = 1
            # target[target < threshold] = 0


        # CutMix: splice a segment from a random partner record; labels are
        # interpolated with the coefficient returned by cutmix_fix_length.
        if random.random() < self.cutMix:
            mix_sample_idx = random.randint(0, self.amount-1)
            mix_datum = load_recording(self.files.iloc[mix_sample_idx]['record'])
            mix_target = self.files.iloc[mix_sample_idx]['target']
            data, alpha_value = cutmix_fix_length(data, mix_datum)
            target = alpha_value * target + (1-alpha_value) * mix_target
            # target[target >= threshold] = 1
            # target[target < threshold] = 0

        # Waveform-only mixup: blend signals but keep this record's labels.
        if random.random()< self.Mixup_no_label_interplate:
            mix_sample_idx = random.randint(0, self.amount-1)
            mix_datum = load_recording(self.files.iloc[mix_sample_idx]['record'])
            data= same_shape_mixup(data, mix_datum)

        progressive_index = 0

        # # below is the progessive data augmentation method
        # if self.progressive:
        #     if self.current_epoch < 20:
        #         if self.current_epoch % 4 == 0 or 3:
        #             pass
        #         elif self.current_epoch % 4 == 1:
        #             proressive_index = 0.2
        #         else:
        #             proressive_index = 0.5

        #     elif self.current_epoch < 40:
        #         if self.current_epoch % 4 == 0 or 3:
        #             pass
        #         elif self.current_epoch % 4 == 1:
        #             proressive_index = 0.7
        #         else:
        #             proressive_index = 0.8
        #     else:
        #         if self.current_epoch % 4 == 0 or 3:
        #             pass
        #         elif self.current_epoch % 4 == 1 or 2:
        #             proressive_index = 1

        # below is the wave-mix data augmentation method
        # Progressive schedule: extra-mixup probability ramps 0.2 -> 0.8 over
        # the first five epochs (driven by set_epoch), then stays at 0.8.
        if self.progressive:
            if self.current_epoch < 5:
                match self.current_epoch:
                    case 0:
                        progressive_index=0.2
                    case 1:
                        progressive_index=0.4
                    case 2:
                        progressive_index=0.6
                    case 3:
                        progressive_index=0.8
                    case 4:
                        progressive_index=0.8

            else :
                progressive_index = 0.8

        # With progressive_index == 0 (progressive disabled) this branch can
        # never fire, so the extra mixup applies only in progressive mode.
        if random.random() < progressive_index :
            mix_sample_idx = random.randint(0, self.amount-1)
            mix_datum = load_recording(self.files.iloc[mix_sample_idx]['record'])
            mix_target = self.files.iloc[mix_sample_idx]['target']
            data, alpha_value = mixup(data, mix_datum)
            target = alpha_value * target + (1-alpha_value) * mix_target
            # Binarize the labels based on the threshold
            # target[target >= threshold] = 1
            # target[target < threshold] = 0


        # expand to 12 lead setup if original signal has less channels
        data, lead_indicator = expand_leads(data, input_leads=leads)
        data = np.nan_to_num(data)


        # resample to 500hz
        if fs == float(1000):
            data = signal.resample_poly(data, up=1, down=2, axis=-1) # to 500Hz
            fs = 500
        elif fs == float(500):
            pass
        else:
            # Generic FFT-based resampling for any other source rate.
            data = signal.resample(data, int(data.shape[1] * 500 / fs), axis=1)
            fs = 500

        # below is the Butterworth digital and analog filter design.
        # Zero-phase band-pass using the (b, a) coefficients prepared in the
        # constructor (3rd-order Butterworth, 1-47 Hz at 500 Hz).
        data = signal.filtfilt(self.b, self.a, data)

        '''
        we filter out the sample which the length is 8192 via the below code.
        just like the random shift windows.
        '''
        if self.sample:
            fs = int(fs)
            # random sample signal if len > 8192 samples
            # NOTE(review): if the signal is exactly 8192 samples long,
            # idx becomes -1 and np.random.randint raises ValueError —
            # confirm upstream record lengths are strictly greater.
            if data.shape[-1] >= 8192:
                idx = data.shape[-1] - 8192-1
                idx = np.random.randint(idx)
                data = data[:, idx:idx + 8192]

        '''
        Z-Score Normalization
        '''
        mu = np.nanmean(data, axis=-1, keepdims=True)
        std = np.nanstd(data, axis=-1, keepdims=True)


        #std = np.nanstd(data.flatten())
        # A constant (zero-std) lead triggers a divide warning; it is
        # silenced here and the resulting NaN/inf values are zeroed below.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            data = (data - mu) / std

        data = np.nan_to_num(data)

        # random choose number of leads to keep
        data, lead_indicator = lead_exctractor.get(data, self.num_leads, lead_indicator)

        # print("This is the lead_indicator:(l)", type(lead_indicator), lead_indicator)


        return data, target
|
engine_ecg_2020.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2015-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
"""
|
| 4 |
+
Train and eval functions used in main.py
|
| 5 |
+
"""
|
| 6 |
+
import math
|
| 7 |
+
import sys
|
| 8 |
+
from typing import Iterable, Optional
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
import timm
|
| 15 |
+
from timm.data import Mixup
|
| 16 |
+
from timm.utils import accuracy, ModelEma
|
| 17 |
+
|
| 18 |
+
from losses import DistillationLoss
|
| 19 |
+
import utils
|
| 20 |
+
|
| 21 |
+
from sklearn.metrics import average_precision_score, roc_auc_score, f1_score
|
| 22 |
+
from sklearn.metrics import hamming_loss
|
| 23 |
+
from sklearn.metrics import accuracy_score
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
|
| 27 |
+
from evaluate_12ECG_score import load_weights, compute_challenge_metric
|
| 28 |
+
|
| 29 |
+
# SNOMED-CT diagnosis codes scored by the PhysioNet/CinC 2020 challenge,
# sorted so the class/column order is deterministic across runs.
classes = sorted(['270492004', '164889003', '164890007', '426627000', '713427006',
                  '713426002', '445118002', '39732003', '164909002', '251146004',
                  '698252002', '10370003', '284470004', '427172004', '164947007',
                  '111975006', '164917005', '47665007', '59118001', '427393009',
                  '426177001', '426783006', '427084000', '63593006', '164934002',
                  '59931005', '17338001'])
|
| 35 |
+
|
| 36 |
+
def normalize_model_outputs(model_outputs):
    """
    Min-max normalize model outputs to the range [0, 1].

    Parameters:
        model_outputs (numpy.ndarray): The raw output of the deep learning model.

    Returns:
        numpy.ndarray: The normalized model outputs.  If every entry is equal
        (zero range) an array of zeros is returned instead of dividing by
        zero, which previously produced NaN/inf and poisoned every
        downstream metric.
    """
    a = model_outputs.min()
    b = model_outputs.max()
    if b == a:
        # Degenerate case: constant output -> 0/0.  Return all zeros.
        return np.zeros_like(model_outputs, dtype=float)
    return (model_outputs - a) / (b - a)
|
| 49 |
+
|
| 50 |
+
# Pairwise scoring-weight table for the 2020 challenge metric, and the
# SNOMED-CT code treated as the "normal" (sinus rhythm) class when scoring.
weights_file = "./weights_2020.csv"
normal_class = '426783006'
# NOTE(review): executed at import time — importing this module fails if
# weights_2020.csv is not present in the working directory.
weights = load_weights(weights_file, classes)
|
| 53 |
+
|
| 54 |
+
def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True, args = None):
    """Run one training epoch, then sweep decision thresholds on the
    accumulated training predictions.

    Parameters
    ----------
    model : torch.nn.Module
        Network producing one logit per class; called with the extra
        ``if_random_cls_token_position`` / ``if_random_token_rank`` kwargs.
    criterion : DistillationLoss
        Loss applied to ``(outputs, targets.float())``.
    data_loader : Iterable
        Yields ``(samples, targets)`` batches.
    optimizer : torch.optim.Optimizer
        A plain optimizer when ``args.lrschedule == "CosineAnnealing"``, or a
        Noam-style wrapper exposing ``.optimizer`` when it is ``"Noam"``.
    device : torch.device
        Target device for each batch.
    epoch : int
        Used only in the progress-log header.
    max_norm, model_ema, mixup_fn
        Accepted for interface compatibility; unused in this body.
    set_training_mode : bool
        Forwarded to ``model.train``.
    args
        Must provide ``lrschedule``, ``if_random_cls_token_position`` and
        ``if_random_token_rank``.

    Returns
    -------
    tuple
        ``(train_auprc, mean_loss, thrs_CHALL, scores_F1,
        scores_SubsetAccuracy, scores_HammingLoss)`` — training AUPRC, the
        per-batch average loss, and the best thresholds (1-element arrays)
        for challenge score, F1, subset accuracy and Hamming loss.
    """

    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 600



    # Accumulate per-batch predictions/labels so epoch-level metrics can be
    # computed once at the end.
    output_list = []
    target_list = []

    train_auprc_list = []

    loss_value_after_each_epcoh = 0
    batch_num = 0

    # debug
    # count = 0
    # for samples, targets, mix_target, alpha_value in metric_logger.log_every(data_loader, print_freq, header):

    # third_party = 0


    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        # count += 1
        # if count > 20:
        #     break
        batch_num += 1
        # print("code goes here......................*******************************************************")

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        # mix_target = mix_target.to(device, non_blocking=True)
        # alpha_value = alpha_value.to(device, non_blocking=True)


        outputs = model(samples.float(), if_random_cls_token_position=args.if_random_cls_token_position, if_random_token_rank=args.if_random_token_rank)

        # below is original loss
        loss = criterion(outputs, targets.float())
        # # below is the asymmetric loss training
        # outputs = torch.sigmoid(outputs)
        # loss = -torch.mean(targets * F.logsigmoid(outputs) + (1 - targets) * F.logsigmoid(-outputs) * 0.1)

        loss_value = loss.item()

        # The Noam wrapper holds the real optimizer in .optimizer; zero grads
        # through whichever interface matches the configured schedule.
        # NOTE(review): the error message below says engine_ecg_2021.py even
        # though this is the 2020 engine — likely copy/paste; confirm.
        if args.lrschedule == "Noam":
            optimizer.optimizer.zero_grad()
        elif args.lrschedule == "CosineAnnealing":
            optimizer.zero_grad()
        else:
            raise ValueError(f"No matching condition for value in the file of engine_ecg_2021.py : {args.lrschedule}")

        loss.backward() # Backward pass: Compute gradient of the loss with respect to model parameters

        # Update parameters/using the Noam
        optimizer.step()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        if args.lrschedule == "Noam":
            metric_logger.update(lr=optimizer.optimizer.param_groups[0]["lr"])
        elif args.lrschedule == "CosineAnnealing":
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        else:
            raise ValueError(f"No matching condition for value in the file of engine_ecg_2021.py for updating : {args.lrschedule}")

        target_list.append(targets.data.cpu().numpy())
        output_list.append(outputs.data.cpu().numpy())
        loss_value_after_each_epcoh += loss_value


    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    # below code is for record the AUPRC of training
    targets_all = np.concatenate(target_list, axis=0)
    outputs_all = np.concatenate(output_list, axis=0)
    outputs_all = normalize_model_outputs(outputs_all)

    # Labels may have been softened by mixup; re-binarize before sklearn
    # metrics that require hard labels.
    threshold = 0.5
    targets_all[targets_all >= threshold] = 1
    targets_all[targets_all < threshold] = 0

    # When using the function below, y_true must be a binarized value.
    train_auprc = average_precision_score(y_true = targets_all, y_score = outputs_all)
    print("This is the training AUPRC:", train_auprc)

    ### The code below is for obtaining the thresholds
    scores_challengeScore = []
    scores_F1 = []
    scores_SubsetAccuracy = []
    scores_HammingLoss = []

    # Sweep decision thresholds in steps of 0.02 and record each metric.
    for thr in np.arange(0., 1., 0.02):
        outputs_dyn = np.array([[(1 if prob > thr else 0) for prob in probs] for probs in np.array(outputs_all)])

        challenge_value = compute_challenge_metric(weights, targets_all, outputs_dyn, classes, normal_class)
        scores_challengeScore.append(challenge_value)

        f1 = f1_score(targets_all, outputs_dyn, average='weighted')
        scores_F1.append(f1)

        subset_accuracy = accuracy_score(targets_all, outputs_dyn)
        scores_SubsetAccuracy.append(subset_accuracy)

        hamming = hamming_loss(targets_all, outputs_dyn)
        scores_HammingLoss.append(hamming)

    scores_challengeScore = np.array(scores_challengeScore)
    scores_F1 = np.array(scores_F1)
    scores_SubsetAccuracy = np.array(scores_SubsetAccuracy)
    scores_HammingLoss = np.array(scores_HammingLoss)

    # print("This is the challenge score list from training set:\n", scores)

    # Best thrs for challenge score
    thrs_CHALL = np.array([np.argmax(scores_challengeScore, axis=0)*0.02])
    print("This is the best threshold for the challenge score from training set", thrs_CHALL)
    outputs_best_CHALL = np.array([[(1 if prob > thrs_CHALL else 0) for prob in probs] for probs in np.array(outputs_all)])
    challenge_value = compute_challenge_metric(weights, targets_all, outputs_best_CHALL, classes, normal_class)
    print("This is the challenge score from training set:", challenge_value)

    # Best thrs for F1
    # NOTE(review): from here on the scores_* names are reassigned to hold
    # the best *threshold* (a 1-element array), not the score list.
    scores_F1 = np.array([np.argmax(scores_F1, axis=0)*0.02])
    print("This is the best threshold for the F1 from training set", scores_F1)
    outputs_best_f1 = np.array([[(1 if prob > scores_F1 else 0) for prob in probs] for probs in np.array(outputs_all)])
    f1 = f1_score(targets_all, outputs_best_f1, average='weighted')
    # NOTE(review): this print reports challenge_value, not f1 — likely a bug.
    print("This is the f1 score from training set:", challenge_value)

    # Best thrs for subset accuracy
    scores_SubsetAccuracy = np.array([np.argmax(scores_SubsetAccuracy, axis=0)*0.02])
    print("This is the best threshold for the Subset Accuracy from training set", scores_SubsetAccuracy)
    outputs_best_SubsetAccuracy = np.array([[(1 if prob > scores_SubsetAccuracy else 0) for prob in probs] for probs in np.array(outputs_all)])
    subset_accuracy = accuracy_score(targets_all, outputs_best_SubsetAccuracy)
    print("This is the subset accuracy from training set:", subset_accuracy)

    # Best thrs for hamming loss, here is the loss value from output, we need to find the minimum value.
    # Determine the optimal threshold for Hamming loss. The loss values are obtained from the output, and the goal is to find the minimum value.
    scores_HammingLoss = np.array([np.argmin(scores_HammingLoss, axis=0)*0.02])
    print("This is the best threshold for the Subset Accuracy from training set", scores_HammingLoss)
    outputs_best_HammingLoss = np.array([[(1 if prob > scores_HammingLoss else 0) for prob in probs] for probs in np.array(outputs_all)])
    hamming = hamming_loss(targets_all, outputs_best_HammingLoss)
    print("This is the Hamming from training set:", hamming)

    loss_value_after_each_epcoh /= batch_num

    return train_auprc, loss_value_after_each_epcoh, thrs_CHALL, scores_F1, scores_SubsetAccuracy, scores_HammingLoss
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
@torch.no_grad()
def evaluate(data_loader, model, device):
    """Evaluate the model on a held-out loader.

    Computes AUPRC/AUROC on min-max-normalized outputs, F1/Hamming/subset
    accuracy at a fixed 0.5 threshold, and sweeps thresholds for the 2020
    challenge score (printed, not returned).

    Parameters
    ----------
    data_loader : iterable
        Yields ``(images, target)`` batches.
    model : torch.nn.Module
        Network producing one logit per class.
    device : torch.device
        Target device for each batch.

    Returns
    -------
    tuple
        ``(auprc, auroc, f1, hamming, subset_accuracy, mean_loss)``.
    """
    # criterion = torch.nn.CrossEntropyLoss()
    criterion = torch.nn.BCEWithLogitsLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    targets = []
    outputs = []
    loss_value_after_each_epcoh = 0
    batch_num = 0
    # switch to evaluation mode
    model.eval()

    for images, target in metric_logger.log_every(data_loader, 100, header):

        batch_num += 1
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        # print("NaN in images(input):", torch.isnan(images).any())

        # compute output

        output = model(images.float())
        # below is original loss
        loss = criterion(output, target.float())
        # below is the logit to proba

        # below is the asymmetric loss testing
        # output = torch.sigmoid(output)
        # loss = -torch.mean(target * F.logsigmoid(output) + (1 - target) * F.logsigmoid(-output) * 0.1)

        # print("This is the output:", output.shape,output)
        # print("This is the target.float():",target.float().shape ,target.float())
        # print("NaN in output:", torch.isnan(output).any())
        # print("NaN in target:", torch.isnan(target).any())
        # print("This is the loss.item():", loss.item())
        # sys.exit()

        metric_logger.update(loss=loss.item())

        targets.append(target.data.cpu().numpy())
        outputs.append(output.data.cpu().numpy())

        loss_value_after_each_epcoh += loss.item()

    # below code is for record the AUPRC of validation
    targets = np.concatenate(targets, axis=0)
    outputs = np.concatenate(outputs, axis=0)

    outputs = normalize_model_outputs(outputs)

    auprc = average_precision_score(y_true=targets, y_score=outputs)
    auroc = roc_auc_score(targets, outputs)

    # print("This is the top 3 row:", outputs[0:3])
    # outputs = np.array([[(1 if prob > 0 else 0) for prob in probs] for probs in np.array(outputs)])
    # print("This is the top 3 row, after:", outputs[0:3])
    # for i in range(len(outputs)):
    #     num_ones = targets[i].count(1)
    #     third_largest = sorted(outputs[i], reverse=True)[num_ones-1]

    # print(f"Length of targets: {len(targets)}")
    # print(f"Length of outputs: {len(outputs)}")

    # Fixed 0.5 threshold for the returned classification metrics.
    outputs_0 = np.array([[(1 if prob > 0.5 else 0) for prob in probs] for probs in np.array(outputs)])

    f1 = f1_score(targets, outputs_0, average='weighted')
    hamming = hamming_loss(targets, outputs_0)
    subset_accuracy = accuracy_score(targets, outputs_0)

    # Sweep thresholds for the challenge score only (result is printed).
    scores = []
    for thr in np.arange(0., 1., 0.02):
        outputs_dyn = np.array([[(1 if prob > thr else 0) for prob in probs] for probs in np.array(outputs)])
        challenge_value = compute_challenge_metric(weights, targets, outputs_dyn, classes, normal_class)
        scores.append(challenge_value)

    scores = np.array(scores)
    print("This is the challenge score list from testing set:\n", scores)

    # Best thrs and preds
    idxs = np.argmax(scores, axis=0)
    thrs = np.array([idxs*0.02])

    print("This is the best threshold for the challenge score from testing test", thrs)
    outputs_best = np.array([[(1 if prob > thrs else 0) for prob in probs] for probs in np.array(outputs)])

    challenge_value = compute_challenge_metric(weights, targets, outputs_best, classes, normal_class)
    print("This is the challenge score:", challenge_value)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    loss_value_after_each_epcoh /= batch_num

    return auprc, auroc, f1, hamming, subset_accuracy, loss_value_after_each_epcoh
|
engine_ecg_2021.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2015-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
"""
|
| 4 |
+
Train and eval functions used in main.py
|
| 5 |
+
"""
|
| 6 |
+
import math
|
| 7 |
+
import sys
|
| 8 |
+
from typing import Iterable, Optional
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
import timm
|
| 13 |
+
from timm.data import Mixup
|
| 14 |
+
from timm.utils import accuracy, ModelEma
|
| 15 |
+
|
| 16 |
+
from losses import DistillationLoss
|
| 17 |
+
import utils
|
| 18 |
+
|
| 19 |
+
from sklearn.metrics import average_precision_score, roc_auc_score, f1_score
|
| 20 |
+
from sklearn.metrics import hamming_loss
|
| 21 |
+
from sklearn.metrics import accuracy_score
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
|
| 25 |
+
from evaluate_model import load_weights, compute_challenge_metric
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def normalize_model_outputs(model_outputs):
    """
    Min-max normalize model outputs to the range [0, 1].

    Parameters:
        model_outputs (numpy.ndarray): The raw output of the deep learning model.

    Returns:
        numpy.ndarray: The normalized model outputs.  If every entry is equal
        (zero range) an array of zeros is returned instead of dividing by
        zero, which previously produced NaN/inf and poisoned every
        downstream metric.
    """
    a = model_outputs.min()
    b = model_outputs.max()
    if b == a:
        # Degenerate case: constant output -> 0/0.  Return all zeros.
        return np.zeros_like(model_outputs, dtype=float)
    return (model_outputs - a) / (b - a)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# SNOMED-CT code(s) treated as normal sinus rhythm by the 2021 challenge
# metric.
sinus_rhythm_ID = set(['426783006'])

# Scoring-weight table for the 2021 challenge metric.
weights_file = "./weights.csv"

# NOTE(review): executed at import time — importing this module fails if
# weights.csv is absent.  The 2021 load_weights also returns the class list.
classes, weights = load_weights(weights_file)
|
| 48 |
+
|
| 49 |
+
def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int,
                    set_training_mode=True, args = None):
    """Run one training epoch (2021 challenge variant), then sweep decision
    thresholds on the accumulated training predictions.

    Parameters
    ----------
    model : torch.nn.Module
        Network producing one logit per class; called with the extra
        ``if_random_cls_token_position`` / ``if_random_token_rank`` kwargs.
    criterion : DistillationLoss
        Loss applied to ``(outputs, targets.float())``.
    data_loader : Iterable
        Yields ``(samples, targets)`` batches.
    optimizer : torch.optim.Optimizer
        A plain optimizer when ``args.lrschedule == "CosineAnnealing"``, or a
        Noam-style wrapper exposing ``.optimizer`` when it is ``"Noam"``.
    device : torch.device
        Target device for each batch.
    epoch : int
        Used only in the progress-log header.
    set_training_mode : bool
        Forwarded to ``model.train``.
    args
        Must provide ``lrschedule``, ``if_random_cls_token_position`` and
        ``if_random_token_rank``.

    Returns
    -------
    tuple
        ``(train_auprc, mean_loss, thrs_CHALL, scores_F1,
        scores_SubsetAccuracy, scores_HammingLoss)`` — training AUPRC, the
        per-batch average loss, and the best thresholds (1-element arrays)
        for challenge score, F1, subset accuracy and Hamming loss.
    """

    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 600

    # Accumulate per-batch predictions/labels for epoch-level metrics.
    output_list = []
    target_list = []

    loss_value_after_each_epcoh = 0
    batch_num = 0

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        batch_num += 1

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        outputs = model(samples.float(), if_random_cls_token_position=args.if_random_cls_token_position, if_random_token_rank=args.if_random_token_rank)
        loss = criterion(outputs, targets.float())

        loss_value = loss.item()

        # The Noam wrapper holds the real optimizer in .optimizer; zero grads
        # through whichever interface matches the configured schedule.
        if args.lrschedule == "Noam":
            optimizer.optimizer.zero_grad()
        elif args.lrschedule == "CosineAnnealing":
            optimizer.zero_grad()
        else:
            raise ValueError(f"No matching condition for value in the file of engine_ecg_2021.py : {args.lrschedule}")

        loss.backward() # Backward pass: Compute gradient of the loss with respect to model parameters

        # Update parameters/using the Noam
        optimizer.step()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        if args.lrschedule == "Noam":
            metric_logger.update(lr=optimizer.optimizer.param_groups[0]["lr"])
        elif args.lrschedule == "CosineAnnealing":
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        else:
            raise ValueError(f"No matching condition for value in the file of engine_ecg_2021.py for updating : {args.lrschedule}")

        target_list.append(targets.data.cpu().numpy())
        output_list.append(outputs.data.cpu().numpy())
        loss_value_after_each_epcoh += loss_value

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    # below code is for record the AUPRC of training
    targets_all = np.concatenate(target_list, axis=0)
    outputs_all = np.concatenate(output_list, axis=0)
    outputs_all = normalize_model_outputs(outputs_all)

    # Labels may have been softened by mixup; re-binarize before sklearn
    # metrics that require hard labels.
    threshold = 0.5
    targets_all[targets_all >= threshold] = 1
    targets_all[targets_all < threshold] = 0

    # When using the function below, y_true must be a binarized value.
    train_auprc = average_precision_score(y_true = targets_all, y_score = outputs_all)
    print("This is the training AUPRC:", train_auprc)

    ### The code below is for obtaining the thresholds
    scores_challengeScore = []
    scores_F1 = []
    scores_SubsetAccuracy = []
    scores_HammingLoss = []

    # Sweep decision thresholds in steps of 0.02 and record each metric.
    for thr in np.arange(0., 1., 0.02):
        outputs_dyn = np.array([[(1 if prob > thr else 0) for prob in probs] for probs in np.array(outputs_all)])

        challenge_value = compute_challenge_metric(weights, targets_all, outputs_dyn, classes, sinus_rhythm_ID)
        scores_challengeScore.append(challenge_value)

        f1 = f1_score(targets_all, outputs_dyn, average='weighted')
        scores_F1.append(f1)

        subset_accuracy = accuracy_score(targets_all, outputs_dyn)
        scores_SubsetAccuracy.append(subset_accuracy)

        hamming = hamming_loss(targets_all, outputs_dyn)
        scores_HammingLoss.append(hamming)

    scores_challengeScore = np.array(scores_challengeScore)
    scores_F1 = np.array(scores_F1)
    scores_SubsetAccuracy = np.array(scores_SubsetAccuracy)
    scores_HammingLoss = np.array(scores_HammingLoss)

    # print("This is the challenge score list from training set:\n", scores)

    # Best thrs for challenge score
    thrs_CHALL = np.array([np.argmax(scores_challengeScore, axis=0)*0.02])
    print("This is the best threshold for the challenge score from training set", thrs_CHALL)
    outputs_best_CHALL = np.array([[(1 if prob > thrs_CHALL else 0) for prob in probs] for probs in np.array(outputs_all)])
    challenge_value = compute_challenge_metric(weights, targets_all, outputs_best_CHALL, classes, sinus_rhythm_ID)
    print("This is the challenge score from training set:", challenge_value)

    # Best thrs for F1
    # NOTE(review): from here on the scores_* names are reassigned to hold
    # the best *threshold* (a 1-element array), not the score list.
    scores_F1 = np.array([np.argmax(scores_F1, axis=0)*0.02])
    print("This is the best threshold for the F1 from training set", scores_F1)
    outputs_best_f1 = np.array([[(1 if prob > scores_F1 else 0) for prob in probs] for probs in np.array(outputs_all)])
    f1 = f1_score(targets_all, outputs_best_f1, average='weighted')
    # NOTE(review): this print reports challenge_value, not f1 — likely a bug.
    print("This is the f1 score from training set:", challenge_value)

    # Best thrs for subset accuracy
    scores_SubsetAccuracy = np.array([np.argmax(scores_SubsetAccuracy, axis=0)*0.02])
    print("This is the best threshold for the Subset Accuracy from training set", scores_SubsetAccuracy)
    outputs_best_SubsetAccuracy = np.array([[(1 if prob > scores_SubsetAccuracy else 0) for prob in probs] for probs in np.array(outputs_all)])
    subset_accuracy = accuracy_score(targets_all, outputs_best_SubsetAccuracy)
    print("This is the subset accuracy from training set:", subset_accuracy)

    # Best thrs for hamming loss, here is the loss value from output, we need to find the minimum value.
    # Determine the optimal threshold for Hamming loss. The loss values are obtained from the output, and the goal is to find the minimum value.
    scores_HammingLoss = np.array([np.argmin(scores_HammingLoss, axis=0)*0.02])
    print("This is the best threshold for the Subset Accuracy from training set", scores_HammingLoss)
    outputs_best_HammingLoss = np.array([[(1 if prob > scores_HammingLoss else 0) for prob in probs] for probs in np.array(outputs_all)])
    hamming = hamming_loss(targets_all, outputs_best_HammingLoss)
    print("This is the Hamming from training set:", hamming)

    loss_value_after_each_epcoh /= batch_num
    return train_auprc, loss_value_after_each_epcoh, thrs_CHALL, scores_F1, scores_SubsetAccuracy, scores_HammingLoss
|
| 179 |
+
|
| 180 |
+
@torch.no_grad()
def evaluate(data_loader, model, thrs_chall, thrs_F1, thrs_accuracy, thrs_hammingLoss, device):
    """Evaluate `model` on `data_loader` with multi-label ECG metrics.

    Each `thrs_*` argument is a decision threshold selected on the training
    set (the training loop returns them as 1-element arrays); it is used to
    binarize the normalized model scores for the corresponding metric.

    Returns:
        (auprc, auroc, f1, hamming, subset_accuracy, challenge_value,
         mean loss over batches)

    NOTE(review): relies on module-level names defined elsewhere in this
    file (`utils`, `normalize_model_outputs`, `compute_challenge_metric`,
    `weights`, `classes`, `sinus_rhythm_ID`, and the sklearn metric
    functions) -- verify they are in scope before moving this function.
    """
    # criterion = torch.nn.CrossEntropyLoss()
    # Multi-label task: one independent sigmoid/BCE term per class.
    criterion = torch.nn.BCEWithLogitsLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    targets = []
    outputs = []
    loss_value_after_each_epcoh = 0
    batch_num = 0
    # switch to evaluation mode
    model.eval()

    for images, target in metric_logger.log_every(data_loader, 100, header):

        batch_num += 1
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        output = model(images.float())

        loss = criterion(output, target.float())

        metric_logger.update(loss=loss.item())

        # Accumulate per-batch labels and raw logits on CPU for the
        # epoch-level metrics computed below.
        targets.append(target.data.cpu().numpy())
        outputs.append(output.data.cpu().numpy())

        loss_value_after_each_epcoh += loss.item()

    # below code is for record the AUPRC of validation
    targets = np.concatenate(targets, axis=0)
    outputs = np.concatenate(outputs, axis=0)

    # Map raw logits to scores (presumably sigmoid to [0, 1] -- see
    # normalize_model_outputs elsewhere in this file) before ranking metrics.
    outputs = normalize_model_outputs(outputs)

    auprc = average_precision_score(y_true=targets, y_score=outputs)
    auroc = roc_auc_score(targets, outputs)

    # print("This is the top 3 row:", outputs[0:3])
    # Binarize with the training-set-selected threshold for each metric.
    # NOTE(review): `prob > thrs_F1` assumes the threshold is a scalar or a
    # single-element array; a longer array would make the ternary ambiguous.
    outputs_F1 = np.array([[(1 if prob > thrs_F1 else 0) for prob in probs] for probs in np.array(outputs)])
    f1 = f1_score(targets, outputs_F1, average='weighted')

    outputs_hammingLoss = np.array([[(1 if prob > thrs_hammingLoss else 0) for prob in probs] for probs in np.array(outputs)])
    hamming = hamming_loss(targets, outputs_hammingLoss)

    outputs_accuracy = np.array([[(1 if prob > thrs_accuracy else 0) for prob in probs] for probs in np.array(outputs)])
    subset_accuracy = accuracy_score(targets, outputs_accuracy)

    print("This is the best threshold for the challenge score, obtained from the training test, without any testing data leakage:", thrs_chall)
    outputs_best = np.array([[(1 if prob > thrs_chall else 0) for prob in probs] for probs in np.array(outputs)])

    challenge_value = compute_challenge_metric(weights, targets, outputs_best, classes, sinus_rhythm_ID)
    print("This is the challenge score:", challenge_value)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    loss_value_after_each_epcoh /= batch_num

    return auprc, auroc, f1, hamming, subset_accuracy, challenge_value, loss_value_after_each_epcoh
|
evaluate_12ECG_score.py
ADDED
|
@@ -0,0 +1,577 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# This file contains functions for evaluating algorithms for the 2020 PhysioNet/
|
| 4 |
+
# Computing in Cardiology Challenge. You can run it as follows:
|
| 5 |
+
#
|
| 6 |
+
# python evaluate_12ECG_score.py labels outputs scores.csv
|
| 7 |
+
#
|
| 8 |
+
# where 'labels' is a directory containing files with the labels, 'outputs' is a
|
| 9 |
+
# directory containing files with the outputs from your model, and 'scores.csv'
|
| 10 |
+
# (optional) is a collection of scores for the algorithm outputs.
|
| 11 |
+
#
|
| 12 |
+
# Each file of labels or outputs must have the format described on the Challenge
|
| 13 |
+
# webpage. The scores for the algorithm outputs include the area under the
|
| 14 |
+
# receiver-operating characteristic curve (AUROC), the area under the recall-
|
| 15 |
+
# precision curve (AUPRC), accuracy (fraction of correct recordings), macro F-
|
| 16 |
+
# measure, and the Challenge metric, which assigns different weights to
|
| 17 |
+
# different misclassification errors.
|
| 18 |
+
|
| 19 |
+
import numpy as np, os, os.path, sys
|
| 20 |
+
|
| 21 |
+
def evaluate_12ECG_score(label_directory, output_directory):
    """Score model outputs against labels for the PhysioNet/CinC 2020 Challenge.

    Args:
        label_directory: directory of .hea header files containing '#Dx' labels.
        output_directory: directory of matching .csv model-output files.

    Returns:
        (auroc, auprc, accuracy, f_measure, f_beta_measure, g_beta_measure,
         challenge_metric)

    NOTE(review): 'weights.csv' is resolved relative to the current working
    directory, not this file's location -- confirm the caller runs from the
    repository root.
    """
    # Define the weights, the SNOMED CT code for the normal class, and equivalent SNOMED CT codes.
    weights_file = 'weights.csv'
    normal_class = '426783006'
    equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]

    # Find the label and output files.
    print('Finding label and output files...')
    label_files, output_files = find_challenge_files(label_directory, output_directory)

    # Load the labels and outputs.
    print('Loading labels and outputs...')
    label_classes, labels = load_labels(label_files, normal_class, equivalent_classes)
    output_classes, binary_outputs, scalar_outputs = load_outputs(output_files, normal_class, equivalent_classes)

    # Organize/sort the labels and outputs.
    print('Organizing labels and outputs...')
    classes, labels, binary_outputs, scalar_outputs = organize_labels_outputs(label_classes, output_classes, labels, binary_outputs, scalar_outputs)

    # Load the weights for the Challenge metric.
    print('Loading weights...')
    weights = load_weights(weights_file, classes)

    # Only consider classes that are scored with the Challenge metric.
    indices = np.any(weights, axis=0) # Find indices of classes in weight matrix.
    classes = [x for i, x in enumerate(classes) if indices[i]]
    labels = labels[:, indices]
    scalar_outputs = scalar_outputs[:, indices]
    binary_outputs = binary_outputs[:, indices]
    weights = weights[np.ix_(indices, indices)]

    # Evaluate the model by comparing the labels and outputs.
    print('Evaluating model...')

    print('- AUROC and AUPRC...')
    auroc, auprc = compute_auc(labels, scalar_outputs)

    print('- Accuracy...')
    accuracy = compute_accuracy(labels, binary_outputs)

    print('- F-measure...')
    f_measure = compute_f_measure(labels, binary_outputs)

    print('- F-beta and G-beta measures...')
    f_beta_measure, g_beta_measure = compute_beta_measures(labels, binary_outputs, beta=2)

    print('- Challenge metric...')
    challenge_metric = compute_challenge_metric(weights, labels, binary_outputs, classes, normal_class)

    print('Done.')

    # Return the results.
    return auroc, auprc, accuracy, f_measure, f_beta_measure, g_beta_measure, challenge_metric
|
| 74 |
+
|
| 75 |
+
# Check if the input is a number.
|
| 76 |
+
def is_number(x):
    """Return True if `x` can be parsed by float(), else False.

    Accepts anything float() accepts, including 'nan'/'inf' strings and
    actual numbers. Fix: the original caught only ValueError, so a
    non-string, non-numeric argument (e.g. None) raised TypeError instead
    of returning False; TypeError is now treated the same way.
    """
    try:
        float(x)
        return True
    except (ValueError, TypeError):
        return False
|
| 82 |
+
|
| 83 |
+
# Find Challenge files.
|
| 84 |
+
def find_challenge_files(label_directory, output_directory):
    """Pair each label header (.hea) with its model-output (.csv) file.

    Scans `label_directory` in sorted order; for every visible header file
    the correspondingly-named .csv must exist in `output_directory`.

    Returns:
        (label_files, output_files): parallel lists of full paths.

    Raises:
        IOError: if an output file is missing for a header, or if no
            header/output pairs are found at all.
    """
    label_files = []
    output_files = []
    for entry in sorted(os.listdir(label_directory)):
        label_path = os.path.join(label_directory, entry)
        # Only regular, non-hidden header files are considered.
        looks_like_header = entry.lower().endswith('.hea') and not entry.lower().startswith('.')
        if not (os.path.isfile(label_path) and looks_like_header):
            continue
        stem, _ = os.path.splitext(entry)
        output_name = stem + '.csv'
        output_path = os.path.join(output_directory, output_name)
        if not os.path.isfile(output_path):
            raise IOError('Output file {} not found for label file {}.'.format(output_name, entry))
        label_files.append(label_path)
        output_files.append(output_path)

    if not (label_files and output_files):
        raise IOError('No label or output files found.')
    return label_files, output_files
|
| 103 |
+
|
| 104 |
+
# Load labels from header/label files.
|
| 105 |
+
def load_labels(label_files, normal_class, equivalent_classes_collection):
    """Load Dx labels from header files and one-hot encode them.

    Args:
        label_files: list of .hea header file paths.
        normal_class: SNOMED CT code of the normal (sinus rhythm) class.
        equivalent_classes_collection: list of lists of codes treated as the
            same diagnosis; each set is collapsed onto its first member.

    Returns:
        (classes, labels): sorted class codes and a bool array of shape
        (num_recordings, num_classes).
    """
    # The labels should have the following form:
    #
    # Dx: label_1, label_2, label_3
    #
    num_recordings = len(label_files)

    # Load diagnoses.
    # NOTE(review): assumes every header contains exactly one '#Dx' line; a
    # file without one would desynchronize tmp_labels from label_files and
    # break the indexing below -- confirm the data guarantees this.
    tmp_labels = list()
    for i in range(num_recordings):
        with open(label_files[i], 'r') as f:
            for l in f:
                if l.startswith('#Dx'):
                    dxs = set(arr.strip() for arr in l.split(': ')[1].split(','))
                    tmp_labels.append(dxs)

    # Identify classes.
    classes = set.union(*map(set, tmp_labels))
    if normal_class not in classes:
        classes.add(normal_class)
        print('- The normal class {} is not one of the label classes, so it has been automatically added, but please check that you chose the correct normal class.'.format(normal_class))
    classes = sorted(classes)
    num_classes = len(classes)

    # Use one-hot encoding for labels.
    labels = np.zeros((num_recordings, num_classes), dtype=np.bool_)
    for i in range(num_recordings):
        dxs = tmp_labels[i]
        for dx in dxs:
            j = classes.index(dx)
            labels[i, j] = 1

    # For each set of equivalent class, use only one class as the representative class for the set and discard the other classes in the set.
    # The label for the representative class is positive if any of the labels in the set is positive.
    remove_classes = list()
    remove_indices = list()
    for equivalent_classes in equivalent_classes_collection:
        equivalent_classes = [x for x in equivalent_classes if x in classes]
        if len(equivalent_classes)>1:
            representative_class = equivalent_classes[0]
            other_classes = equivalent_classes[1:]
            equivalent_indices = [classes.index(x) for x in equivalent_classes]
            representative_index = equivalent_indices[0]
            other_indices = equivalent_indices[1:]

            # OR the whole set into the representative column.
            labels[:, representative_index] = np.any(labels[:, equivalent_indices], axis=1)
            remove_classes += other_classes
            remove_indices += other_indices

    # Drop the non-representative columns in one pass so the collected
    # indices stay valid (they were computed against the original layout).
    for x in remove_classes:
        classes.remove(x)
    labels = np.delete(labels, remove_indices, axis=1)

    # If the labels are negative for all classes, then change the label for the normal class to positive.
    normal_index = classes.index(normal_class)
    for i in range(num_recordings):
        num_positive_classes = np.sum(labels[i, :])
        if num_positive_classes==0:
            labels[i, normal_index] = 1

    return classes, labels
|
| 166 |
+
|
| 167 |
+
# Load outputs from output files.
|
| 168 |
+
def load_outputs(output_files, normal_class, equivalent_classes_collection):
    """Load per-recording model outputs (classes, binary, scalar) from CSVs.

    Args:
        output_files: list of .csv output file paths.
        normal_class: SNOMED CT code of the normal class.
        equivalent_classes_collection: sets of codes collapsed onto their
            first member (binary: OR; scalar: nanmean).

    Returns:
        (classes, binary_outputs, scalar_outputs): sorted class codes, a
        bool array and a float64 array, both (num_recordings, num_classes).
    """
    # The outputs should have the following form:
    #
    # diagnosis_1, diagnosis_2, diagnosis_3
    # 0, 1, 1
    # 0.12, 0.34, 0.56
    #
    num_recordings = len(output_files)

    tmp_labels = list()
    tmp_binary_outputs = list()
    tmp_scalar_outputs = list()
    for i in range(num_recordings):
        with open(output_files[i], 'r') as f:
            # NOTE(review): rows 1-3 are used (class names, binary values,
            # scalar values), so line 0 is presumably a record-identifier
            # line not shown in the comment above -- confirm file format.
            for j, l in enumerate(f):
                arrs = [arr.strip() for arr in l.split(',')]
                if j==1:
                    row = arrs
                    tmp_labels.append(row)
                elif j==2:
                    row = list()
                    for arr in arrs:
                        # Anything other than these tokens counts as 0.
                        number = 1 if arr in ('1', 'True', 'true', 'T', 't') else 0
                        row.append(number)
                    tmp_binary_outputs.append(row)
                elif j==3:
                    row = list()
                    for arr in arrs:
                        # Non-numeric scalar entries default to 0.
                        number = float(arr) if is_number(arr) else 0
                        row.append(number)
                    tmp_scalar_outputs.append(row)

    # Identify classes.
    classes = set.union(*map(set, tmp_labels))
    if normal_class not in classes:
        classes.add(normal_class)
        print('- The normal class {} is not one of the output classes, so it has been automatically added, but please check that you identified the correct normal class.'.format(normal_class))
    classes = sorted(classes)
    num_classes = len(classes)

    # Use one-hot encoding for binary outputs and the same order for scalar outputs.
    binary_outputs = np.zeros((num_recordings, num_classes), dtype=np.bool_)
    scalar_outputs = np.zeros((num_recordings, num_classes), dtype=np.float64)
    for i in range(num_recordings):
        dxs = tmp_labels[i]
        for k, dx in enumerate(dxs):
            j = classes.index(dx)
            binary_outputs[i, j] = tmp_binary_outputs[i][k]
            scalar_outputs[i, j] = tmp_scalar_outputs[i][k]

    # For each set of equivalent class, use only one class as the representative class for the set and discard the other classes in the set.
    # The binary output for the representative class is positive if any of the classes in the set is positive.
    # The scalar output is the mean of the scalar outputs for the classes in the set.
    remove_classes = list()
    remove_indices = list()
    for equivalent_classes in equivalent_classes_collection:
        equivalent_classes = [x for x in equivalent_classes if x in classes]
        if len(equivalent_classes)>1:
            representative_class = equivalent_classes[0]
            other_classes = equivalent_classes[1:]
            equivalent_indices = [classes.index(x) for x in equivalent_classes]
            representative_index = equivalent_indices[0]
            other_indices = equivalent_indices[1:]

            binary_outputs[:, representative_index] = np.any(binary_outputs[:, equivalent_indices], axis=1)
            scalar_outputs[:, representative_index] = np.nanmean(scalar_outputs[:, equivalent_indices], axis=1)
            remove_classes += other_classes
            remove_indices += other_indices

    # Drop non-representative columns only after the loop so the collected
    # indices (computed against the original layout) stay valid.
    for x in remove_classes:
        classes.remove(x)
    binary_outputs = np.delete(binary_outputs, remove_indices, axis=1)
    scalar_outputs = np.delete(scalar_outputs, remove_indices, axis=1)

    # If any of the outputs is a NaN, then replace it with a zero.
    binary_outputs[np.isnan(binary_outputs)] = 0
    scalar_outputs[np.isnan(scalar_outputs)] = 0

    # If the binary outputs are negative for all classes, then change the binary output for the normal class to positive.
    normal_index = classes.index(normal_class)
    for i in range(num_recordings):
        num_positive_classes = np.sum(binary_outputs[i, :])
        if num_positive_classes==0:
            binary_outputs[i, normal_index] = 1

    return classes, binary_outputs, scalar_outputs
|
| 254 |
+
|
| 255 |
+
# Organize labels and outputs.
|
| 256 |
+
def organize_labels_outputs(label_classes, output_classes, tmp_labels, tmp_binary_outputs, tmp_scalar_outputs):
    """Align label and output arrays on the union of their class sets.

    Args:
        label_classes, output_classes: class-code lists for the columns of
            the corresponding input arrays.
        tmp_labels: bool array (recordings x len(label_classes)).
        tmp_binary_outputs, tmp_scalar_outputs: arrays with columns ordered
            by output_classes.

    Returns:
        (classes, labels, binary_outputs, scalar_outputs) with all arrays
        re-indexed onto the sorted union of classes; columns absent from an
        input stay zero/False.
    """
    # Include all classes from either the labels or the outputs.
    classes = sorted(set(label_classes) | set(output_classes))
    num_classes = len(classes)

    # Check that the labels and outputs have the same numbers of recordings.
    assert(len(tmp_labels)==len(tmp_binary_outputs)==len(tmp_scalar_outputs))
    num_recordings = len(tmp_labels)

    # Precompute each class code's destination column once.
    column_of = {dx: j for j, dx in enumerate(classes)}

    labels = np.zeros((num_recordings, num_classes), dtype=np.bool_)
    for k, dx in enumerate(label_classes):
        labels[:, column_of[dx]] = tmp_labels[:, k]

    binary_outputs = np.zeros((num_recordings, num_classes), dtype=np.bool_)
    scalar_outputs = np.zeros((num_recordings, num_classes), dtype=np.float64)
    for k, dx in enumerate(output_classes):
        j = column_of[dx]
        binary_outputs[:, j] = tmp_binary_outputs[:, k]
        scalar_outputs[:, j] = tmp_scalar_outputs[:, k]

    return classes, labels, binary_outputs, scalar_outputs
|
| 279 |
+
|
| 280 |
+
# Load a table with row and column names.
|
| 281 |
+
def load_table(table_file):
    """Load a CSV table whose first row and first column hold labels.

    The table should have the following form:

        , a, b, c
        a, 1.2, 2.3, 3.4
        b, 4.5, 5.6, 6.7
        c, 7.8, 8.9, 9.0

    Returns:
        (rows, cols, values): row labels, column labels, and a float array
        where values[i, j] is the entry for row i / column j; non-numeric
        entries become NaN.

    Raises:
        Exception: if the table is empty or its rows have different lengths.
    """
    table = list()
    with open(table_file, 'r') as f:
        for i, l in enumerate(f):
            arrs = [arr.strip() for arr in l.split(',')]
            table.append(arrs)

    # Define the numbers of rows and columns and check for errors.
    num_rows = len(table)-1
    if num_rows<1:
        raise Exception('The table {} is empty.'.format(table_file))

    num_cols = set(len(table[i])-1 for i in range(num_rows))
    if len(num_cols)!=1:
        raise Exception('The table {} has rows with different lengths.'.format(table_file))
    num_cols = min(num_cols)
    if num_cols<1:
        raise Exception('The table {} is empty.'.format(table_file))

    # Find the row and column labels.
    # BUG FIX: the original read row labels from the header row and column
    # labels from the first column (swapped). That was harmless for the
    # symmetric Challenge weights file (load_weights asserts rows == cols)
    # but wrong -- and IndexError-prone -- for any non-square table.
    rows = [table[i+1][0] for i in range(num_rows)]
    cols = [table[0][j+1] for j in range(num_cols)]

    # Find the entries of the table; non-numeric cells become NaN.
    # (EAFP: parse once instead of is_number() + float(); same accept set.)
    values = np.zeros((num_rows, num_cols))
    for i in range(num_rows):
        for j in range(num_cols):
            try:
                values[i, j] = float(table[i+1][j+1])
            except ValueError:
                values[i, j] = float('nan')

    return rows, cols, values
|
| 322 |
+
|
| 323 |
+
# Load weights.
|
| 324 |
+
def load_weights(weight_file, classes):
    """Load the Challenge weight matrix, re-indexed onto `classes`.

    Args:
        weight_file: CSV path (row and column labels must be identical).
        classes: class codes defining the row/column order of the result.

    Returns:
        (len(classes), len(classes)) float64 matrix; entries for classes
        missing from the file remain zero.
    """
    # Load the weight matrix; the file must be label-symmetric.
    rows, cols, values = load_table(weight_file)
    assert(rows == cols)

    # Map each file label onto its position in `classes`, skipping labels
    # that are not scored.
    num_classes = len(classes)
    index_of = {c: classes.index(c) for c in rows if c in classes}

    weights = np.zeros((num_classes, num_classes), dtype=np.float64)
    for i, a in enumerate(rows):
        if a not in index_of:
            continue
        for j, b in enumerate(rows):
            if b in index_of:
                weights[index_of[a], index_of[b]] = values[i, j]

    return weights
|
| 342 |
+
|
| 343 |
+
# Compute recording-wise accuracy.
|
| 344 |
+
def compute_accuracy(labels, outputs):
    """Recording-wise (subset) accuracy.

    A recording counts as correct only when its entire label vector is
    reproduced exactly. Returns the fraction of correct recordings.
    """
    num_recordings, _ = np.shape(labels)
    exact_matches = sum(
        1 for i in range(num_recordings)
        if np.array_equal(labels[i, :], outputs[i, :])
    )
    return float(exact_matches) / float(num_recordings)
|
| 353 |
+
|
| 354 |
+
# Compute confusion matrices.
|
| 355 |
+
def compute_confusion_matrices(labels, outputs, normalize=False):
    """Per-class binary confusion matrices.

    For each class k the 2x2 matrix is laid out as:

        [TN_k FN_k]
        [FP_k TP_k]

    i.e. A[k, 1, 1] = TP, A[k, 1, 0] = FP, A[k, 0, 1] = FN, A[k, 0, 0] = TN.
    With normalize=True each recording's contribution is divided by its
    number of positive labels (at least 1).

    Raises:
        ValueError: if any label/output entry is not 0 or 1.
    """
    num_recordings, num_classes = np.shape(labels)
    A = np.zeros((num_classes, 2, 2))

    for i in range(num_recordings):
        # Per-recording credit: 1 unnormalized, else 1 / #positive labels.
        if normalize:
            credit = 1.0 / float(max(np.sum(labels[i, :]), 1))
        else:
            credit = 1.0
        for j in range(num_classes):
            truth, pred = labels[i, j], outputs[i, j]
            if truth == 1 and pred == 1:      # TP
                A[j, 1, 1] += credit
            elif truth == 0 and pred == 1:    # FP
                A[j, 1, 0] += credit
            elif truth == 1 and pred == 0:    # FN
                A[j, 0, 1] += credit
            elif truth == 0 and pred == 0:    # TN
                A[j, 0, 0] += credit
            else:                             # Non-binary entry.
                raise ValueError('Error in computing the confusion matrix.')

    return A
|
| 396 |
+
|
| 397 |
+
# Compute macro F-measure.
|
| 398 |
+
def compute_f_measure(labels, outputs):
    """Macro-averaged F1 over classes.

    Per class: 2*TP / (2*TP + FP + FN); classes with an all-zero
    denominator are NaN and excluded from the macro average.
    """
    num_classes = np.shape(labels)[1]
    A = compute_confusion_matrices(labels, outputs)

    # Start from NaN and fill in only the well-defined classes.
    per_class = np.full(num_classes, float('nan'))
    for k in range(num_classes):
        tp, fp, fn = A[k, 1, 1], A[k, 1, 0], A[k, 0, 1]
        denominator = 2 * tp + fp + fn
        if denominator:
            per_class[k] = float(2 * tp) / float(denominator)

    return np.nanmean(per_class)
|
| 414 |
+
|
| 415 |
+
# Compute F-beta and G-beta measures from the unofficial phase of the Challenge.
|
| 416 |
+
def compute_beta_measures(labels, outputs, beta):
    """Macro F-beta and G-beta measures (unofficial Challenge phase).

    Uses label-normalized confusion matrices. Per class:
        F_beta = (1+beta^2)*TP / ((1+beta^2)*TP + FP + beta^2*FN)
        G_beta = TP / (TP + FP + beta*FN)
    Classes with a zero denominator are NaN and excluded from the averages.

    Returns:
        (macro_f_beta, macro_g_beta)
    """
    num_classes = np.shape(labels)[1]
    A = compute_confusion_matrices(labels, outputs, normalize=True)

    beta_sq = beta ** 2
    f_beta = np.full(num_classes, float('nan'))
    g_beta = np.full(num_classes, float('nan'))
    for k in range(num_classes):
        tp, fp, fn = A[k, 1, 1], A[k, 1, 0], A[k, 0, 1]
        f_denominator = (1 + beta_sq) * tp + fp + beta_sq * fn
        if f_denominator:
            f_beta[k] = float((1 + beta_sq) * tp) / float(f_denominator)
        g_denominator = tp + fp + beta * fn
        if g_denominator:
            g_beta[k] = float(tp) / float(g_denominator)

    return np.nanmean(f_beta), np.nanmean(g_beta)
|
| 438 |
+
|
| 439 |
+
# Compute macro AUROC and macro AUPRC.
|
| 440 |
+
def compute_auc(labels, outputs):
    """Macro AUROC and AUPRC via an explicit threshold sweep per class.

    Args:
        labels: binary array (recordings x classes).
        outputs: scalar scores, same shape.

    Returns:
        (macro_auroc, macro_auprc): each is np.nanmean over classes.
    """
    num_recordings, num_classes = np.shape(labels)

    # Compute and summarize the confusion matrices for each class across at distinct output values.
    auroc = np.zeros(num_classes)
    auprc = np.zeros(num_classes)

    for k in range(num_classes):
        # We only need to compute TPs, FPs, FNs, and TNs at distinct output values.
        thresholds = np.unique(outputs[:, k])
        # Prepend a threshold above the maximum score (descending order) so
        # the first point has zero predicted positives.
        thresholds = np.append(thresholds, thresholds[-1]+1)
        thresholds = thresholds[::-1]
        num_thresholds = len(thresholds)

        # Initialize the TPs, FPs, FNs, and TNs.
        tp = np.zeros(num_thresholds)
        fp = np.zeros(num_thresholds)
        fn = np.zeros(num_thresholds)
        tn = np.zeros(num_thresholds)
        fn[0] = np.sum(labels[:, k]==1)
        tn[0] = np.sum(labels[:, k]==0)

        # Find the indices that result in sorted output values.
        idx = np.argsort(outputs[:, k])[::-1]

        # Compute the TPs, FPs, FNs, and TNs for class k across thresholds.
        # `i` walks the recordings once, in descending score order, across
        # all thresholds (two-pointer sweep).
        i = 0
        for j in range(1, num_thresholds):
            # Initialize TPs, FPs, FNs, and TNs using values at previous threshold.
            tp[j] = tp[j-1]
            fp[j] = fp[j-1]
            fn[j] = fn[j-1]
            tn[j] = tn[j-1]

            # Update the TPs, FPs, FNs, and TNs at i-th output value.
            while i < num_recordings and outputs[idx[i], k] >= thresholds[j]:
                if labels[idx[i], k]:
                    tp[j] += 1
                    fn[j] -= 1
                else:
                    fp[j] += 1
                    tn[j] -= 1
                i += 1

        # Summarize the TPs, FPs, FNs, and TNs for class k.
        # NOTE(review): `npv` is allocated but never filled or used.
        tpr = np.zeros(num_thresholds)
        tnr = np.zeros(num_thresholds)
        ppv = np.zeros(num_thresholds)
        npv = np.zeros(num_thresholds)

        for j in range(num_thresholds):
            if tp[j] + fn[j]:
                tpr[j] = float(tp[j]) / float(tp[j] + fn[j])
            else:
                tpr[j] = float('nan')
            if fp[j] + tn[j]:
                tnr[j] = float(tn[j]) / float(fp[j] + tn[j])
            else:
                tnr[j] = float('nan')
            if tp[j] + fp[j]:
                ppv[j] = float(tp[j]) / float(tp[j] + fp[j])
            else:
                ppv[j] = float('nan')

        # Compute AUROC as the area under a piecewise linear function with TPR/
        # sensitivity (x-axis) and TNR/specificity (y-axis) and AUPRC as the area
        # under a piecewise constant with TPR/recall (x-axis) and PPV/precision
        # (y-axis) for class k.
        for j in range(num_thresholds-1):
            auroc[k] += 0.5 * (tpr[j+1] - tpr[j]) * (tnr[j+1] + tnr[j])
            auprc[k] += (tpr[j+1] - tpr[j]) * ppv[j+1]

    # Compute macro AUROC and macro AUPRC across classes.
    macro_auroc = np.nanmean(auroc)
    macro_auprc = np.nanmean(auprc)

    return macro_auroc, macro_auprc
|
| 517 |
+
|
| 518 |
+
# Compute modified confusion matrix for multi-class, multi-label tasks.
|
| 519 |
+
def compute_modified_confusion_matrix(labels, outputs):
    """Multi-label confusion matrix with per-recording partial credit.

    Rows index the true labels, columns the predicted outputs. Each
    recording contributes a total weight split by the number of classes
    positive in either its labels or its outputs (at least 1).
    """
    num_recordings, num_classes = np.shape(labels)
    A = np.zeros((num_classes, num_classes))

    for i in range(num_recordings):
        # Calculate the number of positive labels and/or outputs.
        credit = 1.0 / float(max(np.sum(np.any((labels[i, :], outputs[i, :]), axis=0)), 1))
        # Assign full and/or partial credit for each (true, predicted) pair.
        true_indices = [j for j in range(num_classes) if labels[i, j]]
        predicted_indices = [k for k in range(num_classes) if outputs[i, k]]
        for j in true_indices:
            for k in predicted_indices:
                A[j, k] += credit

    return A
|
| 538 |
+
|
| 539 |
+
# Compute the evaluation metric for the Challenge.
def compute_challenge_metric(weights, labels, outputs, classes, normal_class):
    """Return the normalized Challenge score.

    1.0 corresponds to a perfect classifier and 0.0 to one that always
    outputs only the normal class; worse-than-baseline models score below
    zero. Returns NaN when the perfect and always-normal baselines tie.
    """
    num_recordings, num_classes = np.shape(labels)
    normal_index = classes.index(normal_class)

    def weighted_score(predictions):
        # Weighted, partial-credit agreement between labels and predictions.
        return np.nansum(weights * compute_modified_confusion_matrix(labels, predictions))

    observed_score = weighted_score(outputs)

    # Score of the model that always chooses the correct label(s).
    correct_score = weighted_score(labels)

    # Score of the baseline model that always chooses only the normal class.
    always_normal = np.zeros((num_recordings, num_classes), dtype=np.bool_)
    always_normal[:, normal_index] = 1
    inactive_score = weighted_score(always_normal)

    if correct_score != inactive_score:
        return float(observed_score - inactive_score) / float(correct_score - inactive_score)
    return float('nan')
|
| 565 |
+
|
| 566 |
+
if __name__ == '__main__':
    # Original command-line entry point, kept for reference:
    #auroc, auprc, accuracy, f_measure, f_beta_measure, g_beta_measure, challenge_metric = evaluate_12ECG_score(sys.argv[1], sys.argv[2])
    # NOTE(review): label/output directories are hard-coded to one user's
    # cluster paths instead of sys.argv[1]/sys.argv[2] — confirm this is
    # intentional before reuse on another machine.
    lbl_dir = '/home/p2017-999/acs_data/processed_data/physionet2020/jonathan/in'
    out_dir = '/home/p2017-999/acs_data/processed_data/physionet2020/jonathan/out'
    auroc, auprc, accuracy, f_measure, f_beta_measure, g_beta_measure, challenge_metric = evaluate_12ECG_score(lbl_dir, out_dir)

    # One CSV header line plus one line of scores, each rounded to 3 decimals.
    output_string = 'AUROC,AUPRC,Accuracy,F-measure,Fbeta-measure,Gbeta-measure,Challenge metric\n{:.3f},{:.3f},{:.3f},{:.3f},{:.3f},{:.3f},{:.3f}'.format(auroc, auprc, accuracy, f_measure, f_beta_measure, g_beta_measure, challenge_metric)
    # Write to a file when a third argument is given; otherwise print.
    if len(sys.argv) > 3:
        with open(sys.argv[3], 'w') as f:
            f.write(output_string)
    else:
        print(output_string)
|
evaluate_model.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# This file contains functions for evaluating algorithms for the 2021 PhysioNet/
|
| 4 |
+
# Computing in Cardiology Challenge. You can run it as follows:
|
| 5 |
+
#
|
| 6 |
+
# python evaluate_model.py labels outputs scores.csv
|
| 7 |
+
#
|
| 8 |
+
# where 'labels' is a directory containing files with the labels, 'outputs' is a
|
| 9 |
+
# directory containing files with the outputs from your model, and 'scores.csv'
|
| 10 |
+
# (optional) is a collection of scores for the algorithm outputs.
|
| 11 |
+
#
|
| 12 |
+
# Each file of labels or outputs must have the format described on the Challenge
|
| 13 |
+
# webpage. The scores for the algorithm outputs include the area under the
|
| 14 |
+
# receiver-operating characteristic curve (AUROC), the area under the recall-
|
| 15 |
+
# precision curve (AUPRC), accuracy (fraction of correct recordings), macro F-
|
| 16 |
+
# measure, and the Challenge metric, which assigns different weights to
|
| 17 |
+
# different misclassification errors.
|
| 18 |
+
|
| 19 |
+
import os, os.path, sys, numpy as np
|
| 20 |
+
from helper_code import get_labels, is_finite_number, load_header, load_outputs
|
| 21 |
+
import pandas as pd
|
| 22 |
+
from tabulate import tabulate
|
| 23 |
+
|
| 24 |
+
def evaluate_model(label_directory, output_directory):
    """Score a classifier's outputs against the 2021 Challenge labels.

    Parameters
    ----------
    label_directory : str
        Directory of .hea header files containing the true Dx labels.
    output_directory : str
        Directory of .csv classifier-output files, one per recording.

    Returns
    -------
    tuple
        (classes, auroc, auprc, auroc_classes, auprc_classes, accuracy,
        f_measure, f_measure_classes, challenge_metric) — macro-averaged
        scores plus the per-class AUROC/AUPRC/F-measure arrays.

    Notes
    -----
    Reads 'weights.csv' from the current working directory and prints
    progress messages to stdout.
    """
    # Identify the weights and the SNOMED CT code for the sinus rhythm class.
    weights_file = 'weights.csv'
    sinus_rhythm = set(['426783006'])

    # Load the scored classes and the weights for the Challenge metric.
    print('Loading weights...')
    classes, weights = load_weights(weights_file)

    # Load the label and output files.
    print('Loading label and output files...')
    label_files, output_files = find_challenge_files(label_directory, output_directory)
    labels = load_labels(label_files, classes)
    binary_outputs, scalar_outputs = load_classifier_outputs(output_files, classes)

    # Evaluate the model by comparing the labels and outputs.
    print('Evaluating model...')

    print('- AUROC and AUPRC...')
    auroc, auprc, auroc_classes, auprc_classes = compute_auc(labels, scalar_outputs)

    print('- Accuracy...')
    accuracy = compute_accuracy(labels, binary_outputs)

    print('- F-measure...')
    f_measure, f_measure_classes = compute_f_measure(labels, binary_outputs)

    print('- Challenge metric...')
    challenge_metric = compute_challenge_metric(weights, labels, binary_outputs, classes, sinus_rhythm)

    print('Done.')

    # Return the results.
    return classes, auroc, auprc, auroc_classes, auprc_classes, accuracy, f_measure, f_measure_classes, challenge_metric
|
| 58 |
+
|
| 59 |
+
# Find Challenge files.
def find_challenge_files(label_directory, output_directory):
    """Pair each header (.hea) label file with its .csv model-output file.

    Returns (label_paths, output_paths) as parallel lists ordered by label
    filename. Raises IOError when a matching output file is missing, or when
    no pairs are found at all.
    """
    label_files = []
    output_files = []
    for name in sorted(os.listdir(label_directory)):
        label_path = os.path.join(label_directory, name)
        lowered = name.lower()
        # Only visible regular files with a .hea extension count as labels.
        if not (os.path.isfile(label_path) and lowered.endswith('.hea') and not lowered.startswith('.')):
            continue
        stem, _ = os.path.splitext(name)
        output_name = stem + '.csv'
        output_path = os.path.join(output_directory, output_name)
        if not os.path.isfile(output_path):
            raise IOError('Output file {} not found for label file {}.'.format(output_name, name))
        label_files.append(label_path)
        output_files.append(output_path)

    if not (label_files and output_files):
        raise IOError('No label or output files found.')
    return label_files, output_files
|
| 79 |
+
|
| 80 |
+
# Load a table with row and column names.
def load_table(table_file):
    """Load a CSV table whose first row and first column hold labels.

    The table should have the following form:

        ,  a,   b,   c
        a, 1.2, 2.3, 3.4
        b, 4.5, 5.6, 6.7
        c, 7.8, 8.9, 9.0

    Returns (rows, cols, values): the row-label list, the column-label list,
    and a float64 array of entries with NaN for non-numeric values.
    Raises Exception for an empty or ragged table.
    """
    table = list()
    with open(table_file, 'r') as f:
        for i, l in enumerate(f):
            arrs = [arr.strip() for arr in l.split(',')]
            table.append(arrs)

    # Define the numbers of rows and columns and check for errors.
    num_rows = len(table)-1
    if num_rows<1:
        raise Exception('The table {} is empty.'.format(table_file))
    # BUG FIX: validate every row's length (the original skipped the final
    # row, letting a short last row surface later as an IndexError).
    row_lengths = set(len(row)-1 for row in table)
    if len(row_lengths)!=1:
        raise Exception('The table {} has rows with different lengths.'.format(table_file))
    num_cols = min(row_lengths)
    if num_cols<1:
        raise Exception('The table {} is empty.'.format(table_file))

    # Find the row and column labels. BUG FIX: the original read the row
    # labels from the header row and the column labels from the first
    # column (transposed). That is harmless for the square, symmetric
    # Challenge weights file, but mislabels any other table.
    rows = [table[i+1][0] for i in range(num_rows)]
    cols = [table[0][j+1] for j in range(num_cols)]

    # Find the entries of the table; non-numeric entries become NaN.
    values = np.zeros((num_rows, num_cols), dtype=np.float64)
    for i in range(num_rows):
        for j in range(num_cols):
            value = table[i+1][j+1]
            if is_finite_number(value):
                values[i, j] = float(value)
            else:
                values[i, j] = float('nan')

    return rows, cols, values
|
| 121 |
+
|
| 122 |
+
# Load weights.
def load_weights(weight_file):
    """Load the Challenge weight matrix and its scored classes.

    Each row/column label may contain '|'-separated equivalent SNOMED codes;
    these are grouped into sets. The row and column labels must agree.
    """
    # Load the table with the weight matrix.
    rows, cols, values = load_table(weight_file)

    # Split '|'-joined equivalence groups into sets of codes.
    row_groups = [set(entry.split('|')) for entry in rows]
    col_groups = [set(entry.split('|')) for entry in cols]
    assert(row_groups == col_groups)

    # The classes are the (identical) row/column groups; the weights are the
    # table entries.
    return row_groups, values
|
| 137 |
+
|
| 138 |
+
# Load labels from header/label files.
def load_labels(label_files, classes):
    """One-hot encode the Dx labels for each recording.

    The labels should have the following form:

        Dx: label_1, label_2, label_3

    Parameters
    ----------
    label_files : list of str
        Paths to the .hea header files.
    classes : list of set
        Scored classes; each entry is a set of equivalent SNOMED codes.

    Returns
    -------
    numpy.ndarray of bool, shape (num_recordings, num_classes)
    """
    num_recordings = len(label_files)
    num_classes = len(classes)

    # Use one-hot encoding for the labels. BUG FIX: np.bool was removed in
    # NumPy 1.24; use np.bool_ as compute_challenge_metric already does.
    labels = np.zeros((num_recordings, num_classes), dtype=np.bool_)

    # Iterate over the recordings.
    for i in range(num_recordings):
        header = load_header(label_files[i])
        y = set(get_labels(header))
        for j, x in enumerate(classes):
            # classes[j] is a set of equivalent codes; any overlap counts.
            if x & y:
                labels[i, j] = 1

    return labels
|
| 159 |
+
|
| 160 |
+
# Load outputs from output files.
def load_classifier_outputs(output_files, classes):
    """Load binary and scalar classifier outputs, aligned to `classes`.

    The outputs should have the following form:

        #Record ID
        diagnosis_1, diagnosis_2, diagnosis_3
        0, 1, 1
        0.12, 0.34, 0.56

    Returns (binary_outputs, scalar_outputs), both with shape
    (num_recordings, num_classes); classes absent from a file remain 0.
    """
    num_recordings = len(output_files)
    num_classes = len(classes)

    # Use one-hot encoding for the outputs. BUG FIX: np.bool was removed in
    # NumPy 1.24; use np.bool_ (consistent with compute_challenge_metric).
    binary_outputs = np.zeros((num_recordings, num_classes), dtype=np.bool_)
    scalar_outputs = np.zeros((num_recordings, num_classes), dtype=np.float64)

    # Iterate over the recordings.
    for i in range(num_recordings):
        recording_id, recording_classes, recording_binary_outputs, recording_scalar_outputs = load_outputs(output_files[i])

        # Allow for equivalent classes and sanitize classifier outputs:
        # accept several truthy spellings for binary outputs and zero out
        # non-finite scalar outputs.
        recording_classes = [set(entry.split('|')) for entry in recording_classes]
        recording_binary_outputs = [1 if entry in ('1', 'True', 'true', 'T', 't') else 0 for entry in recording_binary_outputs]
        recording_scalar_outputs = [float(entry) if is_finite_number(entry) else 0 for entry in recording_scalar_outputs]

        # Allow for unordered/reordered and equivalent classes.
        for j, x in enumerate(classes):
            binary_values = list()
            scalar_values = list()
            for k, y in enumerate(recording_classes):
                if x & y:
                    binary_values.append(recording_binary_outputs[k])
                    scalar_values.append(recording_scalar_outputs[k])
            if binary_values:
                binary_outputs[i, j] = any(binary_values) # Positive if any equivalent class is positive.
            if scalar_values:
                scalar_outputs[i, j] = np.mean(scalar_values) # Mean scalar value across equivalent classes.

    return binary_outputs, scalar_outputs
|
| 199 |
+
|
| 200 |
+
# Compute recording-wise accuracy.
def compute_accuracy(labels, outputs):
    """Fraction of recordings whose full label vector is predicted exactly.

    Parameters
    ----------
    labels, outputs : array-like of shape (num_recordings, num_classes)
        Binary label and output matrices.

    Returns
    -------
    float
        Accuracy in [0, 1], or NaN when there are no recordings (the
        original raised ZeroDivisionError in that case).
    """
    num_recordings, num_classes = np.shape(labels)
    if num_recordings == 0:
        return float('nan')

    # A recording is correct only when every class matches.
    num_correct_recordings = int(np.sum(np.all(np.asarray(labels) == np.asarray(outputs), axis=1)))

    return float(num_correct_recordings) / float(num_recordings)
|
| 210 |
+
|
| 211 |
+
# Compute confusion matrices.
def compute_confusion_matrices(labels, outputs, normalize=False):
    """Per-class 2x2 confusion matrices.

    A[k] is laid out as::

        [TN_k FN_k]
        [FP_k TP_k]

    With normalize=True, each recording's contributions are divided by its
    number of positive labels (floored at 1).
    """
    num_recordings, num_classes = np.shape(labels)
    A = np.zeros((num_classes, 2, 2))

    for i in range(num_recordings):
        # Uniform weight of 1, or 1/(number of positive labels) when
        # normalizing — this collapses the original's two duplicated loops.
        if normalize:
            weight = 1.0 / float(max(np.sum(labels[i, :]), 1))
        else:
            weight = 1.0
        for j in range(num_classes):
            label = labels[i, j]
            output = outputs[i, j]
            if label == 1 and output == 1:   # TP
                A[j, 1, 1] += weight
            elif label == 0 and output == 1: # FP
                A[j, 1, 0] += weight
            elif label == 1 and output == 0: # FN
                A[j, 0, 1] += weight
            elif label == 0 and output == 0: # TN
                A[j, 0, 0] += weight
            else: # This condition should not happen.
                raise ValueError('Error in computing the confusion matrix.')

    return A
|
| 253 |
+
|
| 254 |
+
# Compute macro F-measure.
def compute_f_measure(labels, outputs):
    """Macro-averaged F1 together with the per-class F1 scores.

    Classes with no positives in either labels or outputs get NaN and are
    excluded from the macro average; if every class is NaN the macro value
    is NaN as well.
    """
    num_recordings, num_classes = np.shape(labels)

    A = compute_confusion_matrices(labels, outputs)

    f_measure = np.zeros(num_classes)
    for k in range(num_classes):
        tp = A[k, 1, 1]
        fp = A[k, 1, 0]
        fn = A[k, 0, 1]
        denominator = 2 * tp + fp + fn
        f_measure[k] = float(2 * tp) / float(denominator) if denominator else float('nan')

    macro_f_measure = np.nanmean(f_measure) if np.any(np.isfinite(f_measure)) else float('nan')

    return macro_f_measure, f_measure
|
| 274 |
+
|
| 275 |
+
# Compute macro AUROC and macro AUPRC.
def compute_auc(labels, outputs):
    """Per-class and macro-averaged AUROC and AUPRC from scalar outputs.

    Returns (macro_auroc, macro_auprc, auroc, auprc) where the last two are
    per-class arrays. Classes whose rates are all-NaN contribute NaN and are
    skipped by the macro averages.
    """
    num_recordings, num_classes = np.shape(labels)

    # Compute and summarize the confusion matrices for each class across at distinct output values.
    auroc = np.zeros(num_classes)
    auprc = np.zeros(num_classes)

    for k in range(num_classes):
        # We only need to compute TPs, FPs, FNs, and TNs at distinct output values.
        # A sentinel above the maximum output is appended so the first
        # (highest) threshold classifies every recording as negative; the
        # thresholds are then swept from high to low.
        thresholds = np.unique(outputs[:, k])
        thresholds = np.append(thresholds, thresholds[-1]+1)
        thresholds = thresholds[::-1]
        num_thresholds = len(thresholds)

        # Initialize the TPs, FPs, FNs, and TNs.
        # At the highest threshold nothing is predicted positive, so all
        # true positives start as FNs and all true negatives as TNs.
        tp = np.zeros(num_thresholds)
        fp = np.zeros(num_thresholds)
        fn = np.zeros(num_thresholds)
        tn = np.zeros(num_thresholds)
        fn[0] = np.sum(labels[:, k]==1)
        tn[0] = np.sum(labels[:, k]==0)

        # Find the indices that result in sorted output values.
        # Descending order, so recordings flip to "positive" in the same
        # order as the thresholds descend.
        idx = np.argsort(outputs[:, k])[::-1]

        # Compute the TPs, FPs, FNs, and TNs for class k across thresholds.
        # `i` walks the sorted recordings once in total across all
        # thresholds, making the sweep O(num_recordings + num_thresholds).
        i = 0
        for j in range(1, num_thresholds):
            # Initialize TPs, FPs, FNs, and TNs using values at previous threshold.
            tp[j] = tp[j-1]
            fp[j] = fp[j-1]
            fn[j] = fn[j-1]
            tn[j] = tn[j-1]

            # Update the TPs, FPs, FNs, and TNs at i-th output value.
            # Every recording whose output reaches this threshold moves
            # from the negative side (FN/TN) to the positive side (TP/FP).
            while i < num_recordings and outputs[idx[i], k] >= thresholds[j]:
                if labels[idx[i], k]:
                    tp[j] += 1
                    fn[j] -= 1
                else:
                    fp[j] += 1
                    tn[j] -= 1
                i += 1

        # Summarize the TPs, FPs, FNs, and TNs for class k.
        # Undefined rates (zero denominators) become NaN.
        tpr = np.zeros(num_thresholds)
        tnr = np.zeros(num_thresholds)
        ppv = np.zeros(num_thresholds)
        for j in range(num_thresholds):
            if tp[j] + fn[j]:
                tpr[j] = float(tp[j]) / float(tp[j] + fn[j])
            else:
                tpr[j] = float('nan')
            if fp[j] + tn[j]:
                tnr[j] = float(tn[j]) / float(fp[j] + tn[j])
            else:
                tnr[j] = float('nan')
            if tp[j] + fp[j]:
                ppv[j] = float(tp[j]) / float(tp[j] + fp[j])
            else:
                ppv[j] = float('nan')

        # Compute AUROC as the area under a piecewise linear function with TPR/
        # sensitivity (x-axis) and TNR/specificity (y-axis) and AUPRC as the area
        # under a piecewise constant with TPR/recall (x-axis) and PPV/precision
        # (y-axis) for class k.
        for j in range(num_thresholds-1):
            auroc[k] += 0.5 * (tpr[j+1] - tpr[j]) * (tnr[j+1] + tnr[j])
            auprc[k] += (tpr[j+1] - tpr[j]) * ppv[j+1]

    # Compute macro AUROC and macro AUPRC across classes.
    if np.any(np.isfinite(auroc)):
        macro_auroc = np.nanmean(auroc)
    else:
        macro_auroc = float('nan')
    if np.any(np.isfinite(auprc)):
        macro_auprc = np.nanmean(auprc)
    else:
        macro_auprc = float('nan')

    return macro_auroc, macro_auprc, auroc, auprc
|
| 357 |
+
|
| 358 |
+
# Compute a modified confusion matrix for multi-class, multi-label tasks.
def compute_modified_confusion_matrix(labels, outputs):
    """Partial-credit confusion matrix: rows are labels, columns are outputs.

    Every recording adds 1/normalization at each (true label, positive
    output) pair, where normalization is the size of the union of its
    positive label/output classes (floored at 1).
    """
    num_recordings, num_classes = np.shape(labels)
    A = np.zeros((num_classes, num_classes))

    for i in range(num_recordings):
        # Size of the union of positive classes, floored at one to avoid
        # dividing by zero for an all-negative recording.
        union_size = np.sum(np.any((labels[i, :], outputs[i, :]), axis=0))
        credit = 1.0 / float(max(union_size, 1))

        # Precompute the positive indices so the credit assignment only
        # visits positive (label, output) pairs.
        true_classes = [j for j in range(num_classes) if labels[i, j]]
        positive_outputs = [k for k in range(num_classes) if outputs[i, k]]
        for j in true_classes:
            for k in positive_outputs:
                A[j, k] += credit

    return A
|
| 378 |
+
|
| 379 |
+
# Compute the evaluation metric for the Challenge.
def compute_challenge_metric(weights, labels, outputs, classes, sinus_rhythm):
    """Return the normalized 2021 Challenge score.

    1.0 corresponds to a perfect classifier and 0.0 to a baseline that
    always outputs only sinus rhythm; worse-than-baseline models score
    below zero. Returns 0.0 when the two reference scores coincide.
    Raises ValueError if the sinus rhythm class is not among `classes`.
    """
    num_recordings, num_classes = np.shape(labels)
    if sinus_rhythm not in classes:
        raise ValueError('The sinus rhythm class is not available.')
    sinus_rhythm_index = classes.index(sinus_rhythm)

    def weighted_score(predictions):
        # Weighted, partial-credit agreement between labels and predictions.
        return np.nansum(weights * compute_modified_confusion_matrix(labels, predictions))

    observed_score = weighted_score(outputs)

    # Score of the model that always chooses the correct label(s).
    correct_score = weighted_score(labels)

    # Score of the baseline that always outputs only sinus rhythm.
    baseline_outputs = np.zeros((num_recordings, num_classes), dtype=np.bool_)
    baseline_outputs[:, sinus_rhythm_index] = 1
    inactive_score = weighted_score(baseline_outputs)

    if correct_score != inactive_score:
        return float(observed_score - inactive_score) / float(correct_score - inactive_score)
    return 0.0
|
| 408 |
+
|
| 409 |
+
if __name__ == '__main__':
    # Score the model: argv[1] = label directory, argv[2] = output directory.
    classes, auroc, auprc, auroc_classes, auprc_classes, accuracy, f_measure, f_measure_classes, challenge_metric = evaluate_model(sys.argv[1], sys.argv[2])
    # Summary line of macro scores, rounded to 3 decimals.
    output_string = 'AUROC,AUPRC,Accuracy,F-measure,Challenge metric\n{:.3f},{:.3f},{:.3f},{:.3f},{:.3f}'.format(auroc, auprc, accuracy, f_measure, challenge_metric)
    # NOTE(review): class_output_string is built but never written anywhere
    # below — in the official script it is saved when a third argument is
    # given; confirm whether it should be persisted here too.
    class_output_string = 'Classes,{}\nAUROC,{}\nAUPRC,{}\nF-measure,{}'.format(
        ','.join('|'.join(sorted(x)) for x in classes),
        ','.join('{:.3f}'.format(x) for x in auroc_classes),
        ','.join('{:.3f}'.format(x) for x in auprc_classes),
        ','.join('{:.3f}'.format(x) for x in f_measure_classes))

    print(output_string)
    # NOTE(review): assumes a 'test_outputs/' directory already exists in the
    # working directory — open() does not create it; confirm.
    with open('test_outputs/output.txt', 'w') as f:
        f.write(output_string)


    # Per-class results as a table, printed and saved alongside the summary.
    df = pd.DataFrame({'classes':classes,
                       'auroc_classes':auroc_classes,
                       'auprc_classes':auprc_classes,
                       'f_measure_classes':f_measure_classes})
    # dxname = pd.read_csv('dx_mapping_scored.csv')
    # dxname = dxname[['Dx','SNOMED CT Code']]
    # dxname = dxname.rename(columns={'SNOMED CT Code':'classes'})
    # df = df.merge(dxname,on='classes',how='left')
    print(tabulate(df, headers='keys', tablefmt='psql'))

    with open('test_outputs/table.txt', 'w') as f:
        f.write(tabulate(df, headers='keys', tablefmt='psql'))
|
helper.ipynb
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"#### uniform stochastic depth"
|
| 8 |
+
]
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"cell_type": "code",
|
| 12 |
+
"execution_count": 4,
|
| 13 |
+
"metadata": {},
|
| 14 |
+
"outputs": [
|
| 15 |
+
{
|
| 16 |
+
"name": "stdout",
|
| 17 |
+
"output_type": "stream",
|
| 18 |
+
"text": [
|
| 19 |
+
"24\n",
|
| 20 |
+
"25\n",
|
| 21 |
+
"[0.0, 0.004347826354205608, 0.008695652708411217, 0.013043479062616825, 0.017391305416822433, 0.021739132702350616, 0.02608695812523365, 0.030434783548116684, 0.03478261083364487, 0.03913043811917305, 0.04347826540470123, 0.04782608896493912, 0.052173912525177, 0.056521736085414886, 0.06086956337094307, 0.06521739065647125, 0.06956521421670914, 0.07391304522752762, 0.0782608687877655, 0.08260869979858398, 0.08695652335882187, 0.09130434691905975, 0.09565217792987823, 0.10000000149011612]\n"
|
| 22 |
+
]
|
| 23 |
+
}
|
| 24 |
+
],
|
| 25 |
+
"source": [
|
| 26 |
+
"import torch\n",
|
| 27 |
+
"\n",
|
| 28 |
+
"drop_path_rate = 0.1\n",
|
| 29 |
+
"depth = 24\n",
|
| 30 |
+
"\n",
|
| 31 |
+
"dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"print(len(dpr))\n",
|
| 34 |
+
"inter_dpr = [0.0] + dpr\n",
|
| 35 |
+
"print(len(inter_dpr))\n",
|
| 36 |
+
"print(dpr)"
|
| 37 |
+
]
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"cell_type": "code",
|
| 41 |
+
"execution_count": 1,
|
| 42 |
+
"metadata": {},
|
| 43 |
+
"outputs": [
|
| 44 |
+
{
|
| 45 |
+
"name": "stdout",
|
| 46 |
+
"output_type": "stream",
|
| 47 |
+
"text": [
|
| 48 |
+
"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
|
| 49 |
+
"24\n"
|
| 50 |
+
]
|
| 51 |
+
}
|
| 52 |
+
],
|
| 53 |
+
"source": [
|
| 54 |
+
"import torch\n",
|
| 55 |
+
"\n",
|
| 56 |
+
"drop_path_rate = 0\n",
|
| 57 |
+
"depth = 24\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"dpr = [x.item() for x in torch.full((depth,), drop_path_rate)]\n",
|
| 60 |
+
"print(dpr)\n",
|
| 61 |
+
"print(len(dpr))"
|
| 62 |
+
]
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"cell_type": "code",
|
| 66 |
+
"execution_count": 3,
|
| 67 |
+
"metadata": {},
|
| 68 |
+
"outputs": [
|
| 69 |
+
{
|
| 70 |
+
"name": "stdout",
|
| 71 |
+
"output_type": "stream",
|
| 72 |
+
"text": [
|
| 73 |
+
"tensor([[[ 0, 1, 2],\n",
|
| 74 |
+
" [ 3, 4, 5],\n",
|
| 75 |
+
" [ 6, 7, 8],\n",
|
| 76 |
+
" [ 9, 10, 11],\n",
|
| 77 |
+
" [12, 13, 14]],\n",
|
| 78 |
+
"\n",
|
| 79 |
+
" [[15, 16, 17],\n",
|
| 80 |
+
" [18, 19, 20],\n",
|
| 81 |
+
" [21, 22, 23],\n",
|
| 82 |
+
" [24, 25, 26],\n",
|
| 83 |
+
" [27, 28, 29]]])\n",
|
| 84 |
+
"-------------------------\n",
|
| 85 |
+
"tensor([[[ 2, 1, 0],\n",
|
| 86 |
+
" [ 5, 4, 3],\n",
|
| 87 |
+
" [ 8, 7, 6],\n",
|
| 88 |
+
" [11, 10, 9],\n",
|
| 89 |
+
" [14, 13, 12]],\n",
|
| 90 |
+
"\n",
|
| 91 |
+
" [[17, 16, 15],\n",
|
| 92 |
+
" [20, 19, 18],\n",
|
| 93 |
+
" [23, 22, 21],\n",
|
| 94 |
+
" [26, 25, 24],\n",
|
| 95 |
+
" [29, 28, 27]]])\n",
|
| 96 |
+
"-------------------------\n",
|
| 97 |
+
"tensor([[[12, 13, 14],\n",
|
| 98 |
+
" [ 9, 10, 11],\n",
|
| 99 |
+
" [ 6, 7, 8],\n",
|
| 100 |
+
" [ 3, 4, 5],\n",
|
| 101 |
+
" [ 0, 1, 2]],\n",
|
| 102 |
+
"\n",
|
| 103 |
+
" [[27, 28, 29],\n",
|
| 104 |
+
" [24, 25, 26],\n",
|
| 105 |
+
" [21, 22, 23],\n",
|
| 106 |
+
" [18, 19, 20],\n",
|
| 107 |
+
" [15, 16, 17]]])\n"
|
| 108 |
+
]
|
| 109 |
+
}
|
| 110 |
+
],
|
| 111 |
+
"source": [
|
| 112 |
+
"import torch\n",
|
| 113 |
+
"\n",
|
| 114 |
+
"x = torch.arange(30).view(2, 5, 3)\n",
|
| 115 |
+
"print(x)\n",
|
| 116 |
+
"print(\"-------------------------\")\n",
|
| 117 |
+
"\n",
|
| 118 |
+
"y = x.flip([-1])\n",
|
| 119 |
+
"print(y)\n",
|
| 120 |
+
"print(\"-------------------------\")\n",
|
| 121 |
+
"\n",
|
| 122 |
+
"z = x.flip([1])\n",
|
| 123 |
+
"print(z)"
|
| 124 |
+
]
|
| 125 |
+
},
|
| 126 |
+
{
|
| 127 |
+
"cell_type": "code",
|
| 128 |
+
"execution_count": 2,
|
| 129 |
+
"metadata": {},
|
| 130 |
+
"outputs": [
|
| 131 |
+
{
|
| 132 |
+
"name": "stdout",
|
| 133 |
+
"output_type": "stream",
|
| 134 |
+
"text": [
|
| 135 |
+
"[0, 2, 3, 1]\n",
|
| 136 |
+
"This is the i: 0\n",
|
| 137 |
+
"This is the idx: 0\n",
|
| 138 |
+
"This is the i: 1\n",
|
| 139 |
+
"This is the idx: 2\n",
|
| 140 |
+
"This is the i: 2\n",
|
| 141 |
+
"This is the idx: 3\n",
|
| 142 |
+
"This is the i: 3\n",
|
| 143 |
+
"This is the idx: 1\n"
|
| 144 |
+
]
|
| 145 |
+
}
|
| 146 |
+
],
|
| 147 |
+
"source": [
|
| 148 |
+
"import random\n",
|
| 149 |
+
"\n",
|
| 150 |
+
"indices = [0, 1, 2, 3]\n",
|
| 151 |
+
"random.shuffle(indices)\n",
|
| 152 |
+
"\n",
|
| 153 |
+
"print(indices)\n",
|
| 154 |
+
"\n",
|
| 155 |
+
"for i, idx in enumerate(indices):\n",
|
| 156 |
+
" print(\"This is the i:\", i)\n",
|
| 157 |
+
" print(\"This is the idx:\", idx)"
|
| 158 |
+
]
|
| 159 |
+
}
|
| 160 |
+
],
|
| 161 |
+
"metadata": {
|
| 162 |
+
"kernelspec": {
|
| 163 |
+
"display_name": "mamba",
|
| 164 |
+
"language": "python",
|
| 165 |
+
"name": "python3"
|
| 166 |
+
},
|
| 167 |
+
"language_info": {
|
| 168 |
+
"codemirror_mode": {
|
| 169 |
+
"name": "ipython",
|
| 170 |
+
"version": 3
|
| 171 |
+
},
|
| 172 |
+
"file_extension": ".py",
|
| 173 |
+
"mimetype": "text/x-python",
|
| 174 |
+
"name": "python",
|
| 175 |
+
"nbconvert_exporter": "python",
|
| 176 |
+
"pygments_lexer": "ipython3",
|
| 177 |
+
"version": "3.10.14"
|
| 178 |
+
}
|
| 179 |
+
},
|
| 180 |
+
"nbformat": 4,
|
| 181 |
+
"nbformat_minor": 2
|
| 182 |
+
}
|
helper_code.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Do *not* edit this script.
|
| 4 |
+
# These are helper functions that you can use with your code.
|
| 5 |
+
|
| 6 |
+
import os, numpy as np
|
| 7 |
+
|
| 8 |
+
# Check if a variable is a number or represents a number.
|
| 9 |
+
def is_number(x):
    """Return True when x is a number or a string that parses as one."""
    try:
        float(x)
    except (ValueError, TypeError):
        return False
    return True
|
| 15 |
+
|
| 16 |
+
# Check if a variable is an integer or represents an integer.
|
| 17 |
+
def is_integer(x):
    """Return True when x represents an integral value (e.g. 3, '3', 3.0)."""
    return is_number(x) and float(x).is_integer()
|
| 22 |
+
|
| 23 |
+
# Check if a variable is a a finite number or represents a finite number.
|
| 24 |
+
def is_finite_number(x):
    """Return True when x parses to a finite float (excludes inf and NaN)."""
    if not is_number(x):
        return False
    return np.isfinite(float(x))
|
| 29 |
+
|
| 30 |
+
# (Re)sort leads using the standard order of leads for the standard twelve-lead ECG.
|
| 31 |
+
def sort_leads(leads):
    """Sort leads into the standard twelve-lead ECG order.

    Leads not in the standard set are placed after the known ones, keeping
    their original relative order. Returns a tuple.
    """
    standard = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')

    def rank(lead):
        if lead in standard:
            return standard.index(lead)
        # Unknown lead: sort after all standard leads, by original position.
        return len(standard) + leads.index(lead)

    return tuple(sorted(leads, key=rank))
|
| 35 |
+
|
| 36 |
+
# Find header and recording files.
|
| 37 |
+
def find_challenge_files(data_directory):
    """Find paired header (.hea) and recording (.mat) files in a directory.

    Only records where both files exist are returned; hidden files are
    ignored. Returns (header_files, recording_files), aligned by index.
    """
    header_files = []
    recording_files = []
    for entry in os.listdir(data_directory):
        stem, ext = os.path.splitext(entry)
        if stem.startswith('.') or ext != '.hea':
            continue
        header_path = os.path.join(data_directory, stem + '.hea')
        recording_path = os.path.join(data_directory, stem + '.mat')
        if os.path.isfile(header_path) and os.path.isfile(recording_path):
            header_files.append(header_path)
            recording_files.append(recording_path)
    return header_files, recording_files
|
| 49 |
+
|
| 50 |
+
# Load header file as a string.
|
| 51 |
+
def load_header(header_file):
    """Return the full contents of a header file as a single string."""
    with open(header_file, 'r') as f:
        return f.read()
|
| 55 |
+
|
| 56 |
+
# Load recording file as an array.
|
| 57 |
+
def load_recording(recording_file, header=None, leads=None, key='val'):
    """Load the signal array stored under `key` in a MATLAB .mat file.

    When both `header` and `leads` are given, the recording is reduced to the
    requested leads via choose_leads.
    """
    from scipy.io import loadmat
    contents = loadmat(recording_file)
    recording = contents[key]
    if header and leads:
        recording = choose_leads(recording, header, leads)
    return recording
|
| 63 |
+
|
| 64 |
+
# Choose leads from the recording file.
|
| 65 |
+
def choose_leads(recording, header, leads):
    """Select the requested leads (rows) from a recording array.

    Output rows follow the order of `leads`; a requested lead missing from the
    header leaves its row as zeros. Dtype matches the input recording.
    """
    num_samples = np.shape(recording)[1]
    chosen = np.zeros((len(leads), num_samples), recording.dtype)
    available = get_leads(header)
    for row, lead in enumerate(leads):
        if lead in available:
            chosen[row, :] = recording[available.index(lead), :]
    return chosen
|
| 75 |
+
|
| 76 |
+
# Get recording ID.
|
| 77 |
+
def get_recording_id(header):
    """Return the record name (first field of the first header line), or None.

    Narrowed the original bare `except:` to IndexError so unrelated errors
    (e.g. a non-string header) are no longer silently swallowed.
    """
    recording_id = None
    first_line = header.split('\n')[0]
    try:
        recording_id = first_line.split(' ')[0]
    except IndexError:
        # Defensive only: str.split always yields at least one element.
        pass
    return recording_id
|
| 88 |
+
|
| 89 |
+
# Get leads from header.
|
| 90 |
+
def get_leads(header):
    """Extract the lead names from a WFDB header string.

    The first header line's second field gives the lead count; each of the
    following signal lines ends with its lead name. Returns a tuple.
    """
    lines = header.split('\n')
    num_leads = int(lines[0].split(' ')[1])
    names = [line.split(' ')[-1] for line in lines[1:num_leads + 1]]
    return tuple(names)
|
| 101 |
+
|
| 102 |
+
# Get age from header.
|
| 103 |
+
def get_age(header):
    """Parse the '# Age: ...' comment line of a WFDB header.

    Returns the age as a float, NaN when the field exists but is not numeric
    (e.g. 'Unknown'), or None when no Age line is present. Narrowed the
    original bare `except:` so real bugs are no longer silently converted
    to NaN.
    """
    age = None
    for line in header.split('\n'):
        if line.startswith('# Age'):
            try:
                age = float(line.split(': ')[1].strip())
            except (IndexError, ValueError):
                # Field present but missing/unparsable, e.g. '# Age: Unknown'.
                age = float('nan')
    return age
|
| 112 |
+
|
| 113 |
+
# Get sex from header.
|
| 114 |
+
def get_sex(header):
    """Parse the '# Sex: ...' comment line of a WFDB header.

    Returns the sex string, or None when the line is absent or malformed.
    Narrowed the original bare `except:` to IndexError (the only failure mode
    of the split).
    """
    sex = None
    for line in header.split('\n'):
        if line.startswith('# Sex'):
            try:
                sex = line.split(': ')[1].strip()
            except IndexError:
                # '# Sex' line without a ': value' part.
                pass
    return sex
|
| 123 |
+
|
| 124 |
+
# Get number of leads from header.
def get_num_leads(header):
    """Parse the lead count (second field of the first header line) as a float.

    BUGFIX: the original assigned the parsed value to a local `num_samples`
    (a copy-paste slip from get_num_samples), so this function always returned
    None. It now assigns `num_leads` as intended. The wrong 'Get frequency'
    comment above the def is also corrected. The bare `except:` is narrowed.
    """
    num_leads = None
    first_line = header.split('\n')[0]
    try:
        num_leads = float(first_line.split(' ')[1])
    except (IndexError, ValueError):
        pass
    return num_leads
|
| 136 |
+
|
| 137 |
+
# Get frequency from header.
|
| 138 |
+
def get_frequency(header):
    """Parse the sampling frequency (third field of the first header line).

    Returns a float, or None when the field is missing or not numeric.
    Narrowed the original bare `except:` to the two errors the parse can raise.
    """
    frequency = None
    first_line = header.split('\n')[0]
    try:
        frequency = float(first_line.split(' ')[2])
    except (IndexError, ValueError):
        pass
    return frequency
|
| 149 |
+
|
| 150 |
+
# Get number of samples from header.
|
| 151 |
+
def get_num_samples(header):
    """Parse the number of samples (fourth field of the first header line).

    Returns a float, or None when the field is missing or not numeric.
    Narrowed the original bare `except:` to the two errors the parse can raise.
    """
    num_samples = None
    first_line = header.split('\n')[0]
    try:
        num_samples = float(first_line.split(' ')[3])
    except (IndexError, ValueError):
        pass
    return num_samples
|
| 162 |
+
|
| 163 |
+
# Get analog-to-digital converter (ADC) gains from header.
|
| 164 |
+
def get_adc_gains(header, leads):
    """Read the ADC gain for each requested lead from the header signal lines.

    Returns a float array aligned with `leads`. A requested lead that is
    absent from the header, or whose gain field cannot be parsed, keeps a
    gain of 0. Narrowed the original bare `except:` around the float parse.
    """
    adc_gains = np.zeros(len(leads))
    num_leads = 0
    for i, line in enumerate(header.split('\n')):
        entries = line.split(' ')
        if i == 0:
            num_leads = int(entries[1])
        elif i <= num_leads:
            lead = entries[-1]
            if lead in leads:
                j = leads.index(lead)
                try:
                    # Gain is the third field, written like '1000/mV'.
                    adc_gains[j] = float(entries[2].split('/')[0])
                except (IndexError, ValueError):
                    pass
        else:
            break
    return adc_gains
|
| 181 |
+
|
| 182 |
+
# Get baselines from header.
|
| 183 |
+
def get_baselines(header, leads):
    """Read the baseline (ADC zero) for each requested lead from the header.

    Returns a float array aligned with `leads`. A requested lead absent from
    the header, or with an unparsable baseline field, keeps a baseline of 0.
    Narrowed the original bare `except:` around the float parse.
    """
    baselines = np.zeros(len(leads))
    num_leads = 0
    for i, line in enumerate(header.split('\n')):
        entries = line.split(' ')
        if i == 0:
            num_leads = int(entries[1])
        elif i <= num_leads:
            lead = entries[-1]
            if lead in leads:
                j = leads.index(lead)
                try:
                    # Baseline is the fifth field of the signal line.
                    baselines[j] = float(entries[4].split('/')[0])
                except (IndexError, ValueError):
                    pass
        else:
            break
    return baselines
|
| 200 |
+
|
| 201 |
+
# Get labels from header.
|
| 202 |
+
def get_labels(header):
    """Collect the diagnosis codes from the '# Dx: ...' comment line.

    Returns a list of stripped label strings; empty when no Dx line exists.
    Narrowed the original bare `except:` to IndexError (missing ': value').
    """
    labels = []
    for line in header.split('\n'):
        if line.startswith('# Dx'):
            try:
                for entry in line.split(': ')[1].split(','):
                    labels.append(entry.strip())
            except IndexError:
                # '# Dx' line with no value part.
                pass
    return labels
|
| 213 |
+
|
| 214 |
+
# Save outputs from model.
|
| 215 |
+
def save_outputs(output_file, recording_id, classes, labels, probabilities):
    """Write model outputs in the challenge's four-line text format.

    Line 1: '#<recording_id>'; lines 2-4: comma-joined classes, binary labels,
    and probabilities, each followed by a newline.
    """
    lines = [
        '#{}'.format(recording_id),
        ','.join(str(c) for c in classes),
        ','.join(str(l) for l in labels),
        ','.join(str(p) for p in probabilities),
    ]
    with open(output_file, 'w') as f:
        f.write('\n'.join(lines) + '\n')
|
| 226 |
+
|
| 227 |
+
# Load outputs from model.
|
| 228 |
+
def load_outputs(output_file):
    """Parse a four-line model-output file written by save_outputs.

    Returns (recording_id, classes, labels, probabilities). Non-finite
    probability entries become NaN.

    NOTE(review): preserved quirks of the original — the recording id keeps
    its trailing newline, and a file with fewer than four lines raises
    NameError at the return; confirm callers rely on/handle this.
    """
    with open(output_file, 'r') as f:
        for line_number, line in enumerate(f):
            if line_number == 0:
                recording_id = line[1:] if len(line) > 1 else None
            elif line_number == 1:
                classes = tuple(field.strip() for field in line.split(','))
            elif line_number == 2:
                labels = tuple(field.strip() for field in line.split(','))
            elif line_number == 3:
                probabilities = tuple(float(field) if is_finite_number(field) else float('nan') for field in line.split(','))
            else:
                break
    return recording_id, classes, labels, probabilities
|
losses.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2015-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
"""
|
| 4 |
+
Implements the knowledge distillation loss
|
| 5 |
+
"""
|
| 6 |
+
import torch
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DistillationLoss(torch.nn.Module):
    """
    This module wraps a standard criterion and adds an extra knowledge distillation loss by
    taking a teacher model prediction and using it as additional supervision.
    """
    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, tau: float):
        # base_criterion: task loss applied to the student's main output.
        # teacher_model: model whose predictions supervise the student; it is
        #   evaluated under torch.no_grad() in forward, so it receives no gradients.
        # distillation_type: 'none' disables distillation, 'soft' uses a
        #   temperature-scaled KL divergence on logits, 'hard' uses cross-entropy
        #   against the teacher's argmax prediction.
        # alpha: weight of the distillation term; the base loss gets (1 - alpha).
        # tau: softmax temperature, used only by the 'soft' variant.
        super().__init__()
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        assert distillation_type in ['none', 'soft', 'hard']
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau

    def forward(self, inputs, outputs, labels):
        """
        Args:
            inputs: The original inputs that are feed to the teacher model
            outputs: the outputs of the model to be trained. It is expected to be
                either a Tensor, or a Tuple[Tensor, Tensor], with the original output
                in the first position and the distillation predictions as the second output
            labels: the labels for the base criterion

        Returns:
            The base loss when distillation is 'none'; otherwise the convex
            combination base_loss * (1 - alpha) + distillation_loss * alpha.

        Raises:
            ValueError: if distillation is enabled but the model did not return
                a separate distillation output.
        """
        outputs_kd = None
        if not isinstance(outputs, torch.Tensor):
            # assume that the model outputs a tuple of [outputs, outputs_kd]
            outputs, outputs_kd = outputs
        base_loss = self.base_criterion(outputs, labels)
        if self.distillation_type == 'none':
            return base_loss

        if outputs_kd is None:
            raise ValueError("When knowledge distillation is enabled, the model is "
                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
                             "class_token and the dist_token")
        # don't backprop throught the teacher
        with torch.no_grad():
            teacher_outputs = self.teacher_model(inputs)

        if self.distillation_type == 'soft':
            T = self.tau
            # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
            # with slight modifications
            distillation_loss = F.kl_div(
                F.log_softmax(outputs_kd / T, dim=1),
                #We provide the teacher's targets in log probability because we use log_target=True
                #(as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719)
                #but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both.
                F.log_softmax(teacher_outputs / T, dim=1),
                reduction='sum',
                log_target=True
            ) * (T * T) / outputs_kd.numel()
            #We divide by outputs_kd.numel() to have the legacy PyTorch behavior.
            #But we also experiments output_kd.size(0)
            #see issue 61(https://github.com/facebookresearch/deit/issues/61) for more details
        elif self.distillation_type == 'hard':
            # Hard distillation: the teacher's argmax class becomes a pseudo-label.
            distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))

        # Convex combination of the task loss and the distillation loss.
        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
        return loss
|
main_ecg.py
ADDED
|
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import utils
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch.backends.cudnn as cudnn
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import time
|
| 9 |
+
import datetime
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
# mixup
|
| 13 |
+
from timm.data import Mixup
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# log about
|
| 17 |
+
# import mlflow
|
| 18 |
+
|
| 19 |
+
# for the challenge 2021
|
| 20 |
+
import ecg_dataset_2021
|
| 21 |
+
import engine_ecg_2021
|
| 22 |
+
|
| 23 |
+
# for the challenge 2020
|
| 24 |
+
import ecg_dataset_2020
|
| 25 |
+
import engine_ecg_2020
|
| 26 |
+
|
| 27 |
+
from engine_ecg_2021 import train_one_epoch, evaluate
|
| 28 |
+
|
| 29 |
+
# for the timm
|
| 30 |
+
from timm.models import create_model
|
| 31 |
+
import timm
|
| 32 |
+
from timm.scheduler.cosine_lr import CosineLRScheduler
|
| 33 |
+
|
| 34 |
+
import sys
|
| 35 |
+
|
| 36 |
+
from optimizer import NoamOpt
|
| 37 |
+
|
| 38 |
+
import models_mamba_ecg
|
| 39 |
+
|
| 40 |
+
def get_args_parser():
    """Build the command-line argument parser for training/evaluation.

    Returns an argparse.ArgumentParser (created with add_help=False so it can
    be composed as a parent parser). Covers model selection, regularization,
    data augmentation, learning-rate scheduling, Mamba-specific options, and
    the challenge year.
    """
    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=64, type=int)
    parser.add_argument('--epochs', default=300, type=int)

    # Model parameters
    parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL', help='Name of model to train')

    # Regularization: dropout and stochastic depth
    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', help='Drop path rate (default: 0.1)')

    # Training vs. evaluation mode: paired on/off flags sharing one dest.
    parser.add_argument('--train-mode', action='store_true')
    parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
    parser.set_defaults(train_mode=True)

    # * data augmentation method
    # NOTE(review): `type=bool` (here and on the journal options below) is an
    # argparse pitfall -- any non-empty string, including "False", parses as
    # True. Left unchanged to preserve the CLI; prefer paired
    # store_true/store_false flags if this is ever revised.
    parser.add_argument('--mixup', type=float, default=0,
                        help='mixup alpha, mixup enabled if > 0. (default: 0)')
    parser.add_argument('--cutmix', type=float, default=0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 0)')
    parser.add_argument('--mixup_no_label', type=float, default=0,
                        help='this is the mixup but without interpolation of label. (default: 0)')
    parser.add_argument('--progressive_switch', type=bool, default=False,
                        help='this is the switch of progressive data augmentation method')

    # the output directory
    parser.add_argument('--output_dir', default='', help='path where to save, empty for no saving')

    # the device information
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')

    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')

    parser.add_argument('--num_workers', default=10, type=int)

    parser.add_argument('--pin-mem', action='store_true', help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', help='')
    parser.set_defaults(pin_mem=True)

    # special parameters for this strategy
    parser.add_argument('--depth', default=24, type=int)
    parser.add_argument('--block', default='VisionMamba', choices=['OriginalMamba', 'VisionMamba'], type=str, help='the selection of the block of the paradigm')
    parser.add_argument('--lead', default='12Lead', choices=['12Lead', 'RandomLead'], type=str, help='the number of Lead')
    parser.add_argument('--lrschedule', default='CosineAnnealing', choices=['Noam', 'CosineAnnealing'], type=str, help='the selection of learning rate strategy')

    # if use random token position
    # BUGFIX: the "off" flag was garbled as
    # '--no_randargs.if_nan2numom_cls_token_position' (an editor paste
    # artifact), which made it impossible to type on a command line. Restored
    # the intended '--no_random_cls_token_position'; dest and default unchanged.
    parser.add_argument('--if_random_cls_token_position', action='store_true')
    parser.add_argument('--no_random_cls_token_position', action='store_false', dest='if_random_cls_token_position')
    parser.set_defaults(if_random_cls_token_position=False)

    # if use random token rank
    parser.add_argument('--if_random_token_rank', action='store_true')
    parser.add_argument('--no_random_token_rank', action='store_false', dest='if_random_token_rank')
    parser.set_defaults(if_random_token_rank=False)

    # using for the journal (see the type=bool NOTE above)
    parser.add_argument('--fused_add_norm', type=bool, default=True, help='combines the element-wise addition (from residual connections) and normalization (e.g., RMSNorm)')
    parser.add_argument('--if_divide_out', type=bool, default=True, help='Should the result be divided by two when combining outputs from two directions?')
    parser.add_argument('--use_middle_cls_token', type=bool, default=True, help='Whether the class is inserted in the middle of sequence')

    # the switch of scenario (the string default is run through type=int by argparse)
    parser.add_argument('--challenge_scenario', default='2021', choices=[2021, 2020], type=int, help='the scenario of Challenge')

    return parser
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def collate(batch):
    """Collate (signal, target) pairs into a padded batch of tensors.

    Every signal is left-zero-padded along its last axis to a fixed length of
    8192 samples; the signal occupies the rightmost columns. Returns
    (X, t) where X is (batch, channels, 8192) and t holds the targets.
    """
    num_channels = batch[0][0].shape[0]
    max_len = 8192
    padded = np.zeros((len(batch), num_channels, max_len))

    for row, sample in enumerate(batch):
        signal = sample[0]
        # Left padding: place the signal at the right edge of the buffer.
        padded[row, :, -signal.shape[-1]:] = signal

    targets = np.array([sample[1] for sample in batch])

    return torch.from_numpy(padded), torch.from_numpy(targets)
|
| 127 |
+
|
| 128 |
+
def main(args, data_directory, model_directory, group_number):
|
| 129 |
+
|
| 130 |
+
train_log_fp = open(args.output_dir + '/train_log_group_%d.txt' % group_number, 'a')
|
| 131 |
+
print(args)
|
| 132 |
+
train_log_fp.write("The is the configuration: {}\n".format(args))
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
device = torch.device(args.device)
|
| 136 |
+
|
| 137 |
+
print("This is the running device", device)
|
| 138 |
+
train_log_fp.write("This is the running device:{}\n".format(device))
|
| 139 |
+
|
| 140 |
+
seed = args.seed + utils.get_rank()
|
| 141 |
+
torch.manual_seed(seed)
|
| 142 |
+
np.random.seed(seed)
|
| 143 |
+
|
| 144 |
+
print("This is the random's seed:", seed)
|
| 145 |
+
train_log_fp.write("This is the random's seed:{}\n".format(seed))
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
cudnn.benchmark = True
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
dataset_file_all_address = "../../collection_of_all_datasets/"
|
| 152 |
+
# below is building dataset(train and validation)
|
| 153 |
+
|
| 154 |
+
print("This is building the training set:")
|
| 155 |
+
############ training area #########################
|
| 156 |
+
the_training_address = data_directory + "/training_group" + data_directory[-1] + ".csv"
|
| 157 |
+
df = pd.read_csv(the_training_address)
|
| 158 |
+
print("Total the {} files will be fed into model for training".format(len(df['Name'])))
|
| 159 |
+
train_log_fp.write("Total the {} files will be fed into model for training \n".format(len(df['Name'])))
|
| 160 |
+
|
| 161 |
+
print("This is the first file of training:",df['Name'][0])
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
training_header_files=[]
|
| 165 |
+
|
| 166 |
+
for i in range(len(df['Name'])):
|
| 167 |
+
each_header_file = dataset_file_all_address + df['Name'][i]
|
| 168 |
+
training_header_files.append(each_header_file)
|
| 169 |
+
|
| 170 |
+
collate_training = collate
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
if args.mixup > 0:
|
| 174 |
+
print("The mixup is using, and the percentage is:", args.mixup)
|
| 175 |
+
train_log_fp.write("The mixup is using, and the percentage is:{}\n".format(args.mixup))
|
| 176 |
+
# collate_training = collate_mixup
|
| 177 |
+
|
| 178 |
+
if args.cutmix > 0:
|
| 179 |
+
print("The cutmix is using, and the percentage is:", args.cutmix)
|
| 180 |
+
train_log_fp.write("The cutmix is using, and the percentage is:{}\n".format(args.cutmix))
|
| 181 |
+
if args.mixup_no_label > 0:
|
| 182 |
+
print("The mixup is using without label's interpolation , and the percentage is:", args.mixup_no_label)
|
| 183 |
+
train_log_fp.write("The mixup is using without label's interpolation , and the percentage is:{}\n".format(args.mixup_no_label))
|
| 184 |
+
if args.progressive_switch:
|
| 185 |
+
print("The progressive mixup is using.")
|
| 186 |
+
train_log_fp.write("The progressive mixup is using.")
|
| 187 |
+
|
| 188 |
+
if args.challenge_scenario == 2021:
|
| 189 |
+
train_dataset = ecg_dataset_2021.dataset(training_header_files, Mixup = args.mixup, amount = len(df['Name']), cutMix=args.cutmix, Mixup_no_label_interpolate=args.mixup_no_label, progressive_switch=args.progressive_switch)
|
| 190 |
+
elif args.challenge_scenario == 2020:
|
| 191 |
+
train_dataset = ecg_dataset_2020.dataset(training_header_files, Mixup = args.mixup, amount = len(df['Name']), cutMix=args.cutmix, Mixup_no_label_interpolate=args.mixup_no_label, progressive_switch=args.progressive_switch)
|
| 192 |
+
else:
|
| 193 |
+
raise ValueError(f"No matching condition for value: {args.challenge_scenario}")
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
if args.lead == "12Lead":
|
| 197 |
+
lead_number = 12
|
| 198 |
+
print("This 12 lead is using.")
|
| 199 |
+
train_log_fp.write("This 12 lead is using.")
|
| 200 |
+
elif args.lead == "RandomLead":
|
| 201 |
+
lead_number = None
|
| 202 |
+
print("This random lead is using.")
|
| 203 |
+
train_log_fp.write("This random lead is using.")
|
| 204 |
+
else:
|
| 205 |
+
raise ValueError(f"No matching condition for value: {args.lead}")
|
| 206 |
+
|
| 207 |
+
"""
|
| 208 |
+
we filter out the sample which the length is 8192 via the below code.
|
| 209 |
+
just like the random shift windows.
|
| 210 |
+
"""
|
| 211 |
+
train_dataset.num_leads = lead_number
|
| 212 |
+
train_dataset.sample = True
|
| 213 |
+
###################################################
|
| 214 |
+
print("done")
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
print("This is building the validation set:")
|
| 218 |
+
############ testing area #########################
|
| 219 |
+
the_testing_address = data_directory + "/testing_group" + data_directory[-1]+".csv"
|
| 220 |
+
df = pd.read_csv(the_testing_address)
|
| 221 |
+
print("Total the {} files will be used as testing".format(len(df['Name'])))
|
| 222 |
+
train_log_fp.write("Total the {} files will be used as testing \n".format(len(df['Name'])))
|
| 223 |
+
|
| 224 |
+
print("This is the first file of testing:",df['Name'][0])
|
| 225 |
+
|
| 226 |
+
testing_header_files=[]
|
| 227 |
+
|
| 228 |
+
for i in range(len(df['Name'])):
|
| 229 |
+
each_header_file = dataset_file_all_address + df['Name'][i]
|
| 230 |
+
testing_header_files.append(each_header_file)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
if args.challenge_scenario == 2021:
|
| 234 |
+
test_dataset = ecg_dataset_2021.dataset(testing_header_files, Mixup = 0)
|
| 235 |
+
elif args.challenge_scenario == 2020:
|
| 236 |
+
test_dataset = ecg_dataset_2020.dataset(testing_header_files, Mixup = 0)
|
| 237 |
+
else:
|
| 238 |
+
raise ValueError(f"No matching condition for value: {args.challenge_scenario}")
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
test_dataset.num_leads = 12
|
| 242 |
+
test_dataset.sample = True
|
| 243 |
+
###################################################
|
| 244 |
+
print("done")
|
| 245 |
+
|
| 246 |
+
sampler_train = torch.utils.data.RandomSampler(train_dataset)
|
| 247 |
+
sampler_val = torch.utils.data.SequentialSampler(test_dataset)
|
| 248 |
+
|
| 249 |
+
data_loader_train = torch.utils.data.DataLoader(
|
| 250 |
+
train_dataset, sampler=sampler_train,
|
| 251 |
+
batch_size=args.batch_size,
|
| 252 |
+
collate_fn=collate_training,
|
| 253 |
+
num_workers=args.num_workers,
|
| 254 |
+
pin_memory=args.pin_mem,
|
| 255 |
+
drop_last=True,
|
| 256 |
+
)
|
| 257 |
+
# Setting `drop_last=True` means that this incomplete batch will be dropped,
|
| 258 |
+
# ensuring that all batches fed to the model during training have **the same size.**
|
| 259 |
+
|
| 260 |
+
data_loader_val = torch.utils.data.DataLoader(
|
| 261 |
+
test_dataset, sampler=sampler_val,
|
| 262 |
+
batch_size=int(1.5 * args.batch_size),
|
| 263 |
+
collate_fn=collate,
|
| 264 |
+
num_workers=args.num_workers,
|
| 265 |
+
pin_memory=args.pin_mem,
|
| 266 |
+
drop_last=False
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
if args.challenge_scenario == 2021:
|
| 270 |
+
args.nb_classes = 26
|
| 271 |
+
elif args.challenge_scenario == 2020:
|
| 272 |
+
args.nb_classes = 27
|
| 273 |
+
else:
|
| 274 |
+
raise ValueError(f"No matching condition for value: {args.challenge_scenario}")
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
print(f"Creating model: {args.model}")
|
| 278 |
+
train_log_fp.write("Creating model:{}\n".format(args.model))
|
| 279 |
+
model = create_model(
|
| 280 |
+
args.model,
|
| 281 |
+
pretrained=False,
|
| 282 |
+
num_classes=args.nb_classes,
|
| 283 |
+
drop_rate=args.drop,
|
| 284 |
+
drop_path_rate=args.drop_path,
|
| 285 |
+
drop_block_rate=None,
|
| 286 |
+
block = args.block,
|
| 287 |
+
depth = args.depth,
|
| 288 |
+
fused_add_norm = args.fused_add_norm,
|
| 289 |
+
use_middle_cls_token = args.use_middle_cls_token,
|
| 290 |
+
img_size=8192
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
model.to(device)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 298 |
+
|
| 299 |
+
print('number of params:', n_parameters)
|
| 300 |
+
train_log_fp.write("number of params:{}\n".format(n_parameters))
|
| 301 |
+
|
| 302 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=0.0006, betas=(0.9, 0.98), eps=1e-9)
|
| 303 |
+
|
| 304 |
+
if args.lrschedule == "Noam":
|
| 305 |
+
optimizer = NoamOpt(729, 1, 4000, torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
|
| 306 |
+
print("The Noam is using as learning rate strategy.")
|
| 307 |
+
train_log_fp.write("The Noam is using as learning rate strategy.")
|
| 308 |
+
|
| 309 |
+
elif args.lrschedule == "CosineAnnealing":
|
| 310 |
+
lr_scheduler = CosineLRScheduler(
|
| 311 |
+
optimizer,
|
| 312 |
+
t_initial=13, # Number of epochs after warmup
|
| 313 |
+
lr_min=1e-6,
|
| 314 |
+
warmup_lr_init = 1e-5,
|
| 315 |
+
warmup_t=5,
|
| 316 |
+
cycle_limit=1,
|
| 317 |
+
t_in_epochs=True,
|
| 318 |
+
warmup_prefix=True # Added to reach initial_lr
|
| 319 |
+
)
|
| 320 |
+
print("The cosing annealing is using as learning rate strategy.")
|
| 321 |
+
train_log_fp.write("The cosing annealing is using as learning rate strategy.")
|
| 322 |
+
else:
|
| 323 |
+
raise ValueError(f"No matching condition for value: {args.lrschedule}")
|
| 324 |
+
|
| 325 |
+
# below is for the cosine annealing schedule
|
| 326 |
+
|
| 327 |
+
# the loss is set as normal BCE loss.
|
| 328 |
+
criterion = torch.nn.BCEWithLogitsLoss()
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
output_dir = Path(args.output_dir)
|
| 332 |
+
|
| 333 |
+
print(f"Start training for {args.epochs} epochs")
|
| 334 |
+
train_log_fp.write(f"Start training for {args.epochs} epochs\n")
|
| 335 |
+
|
| 336 |
+
start_time = time.time()
|
| 337 |
+
|
| 338 |
+
max_AUPRC = 0.0
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
train_auprc_list = []
|
| 342 |
+
testing_AUPRC_list = []
|
| 343 |
+
loss_value_after_each_epcoh_list = []
|
| 344 |
+
loss_value_testing_list = []
|
| 345 |
+
|
| 346 |
+
testing_auroc_list = []
|
| 347 |
+
testing_f1_list = []
|
| 348 |
+
|
| 349 |
+
testing_subset_accuracy_list = []
|
| 350 |
+
hamming_loss_list=[]
|
| 351 |
+
challenge_score_list = []
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
jump_count = 0
|
| 355 |
+
|
| 356 |
+
thrs_list = []
|
| 357 |
+
# loss_value_after_each_epcoh_testing_list = []
|
| 358 |
+
# epoch is started from 0
|
| 359 |
+
for epoch in range(args.start_epoch, args.epochs):
|
| 360 |
+
|
| 361 |
+
train_log_fp.write("\n")
|
| 362 |
+
train_log_fp.write("---------------------------------------------------------- \n")
|
| 363 |
+
train_log_fp.write("---------------------------------------------------------- \n")
|
| 364 |
+
train_log_fp.write("---------------------------------------------------------- \n")
|
| 365 |
+
|
| 366 |
+
now = datetime.datetime.now()
|
| 367 |
+
current_time = now.strftime("%H:%M:%S")
|
| 368 |
+
print("Current Time =", current_time)
|
| 369 |
+
train_log_fp.write(f"Current Time = {current_time} \n")
|
| 370 |
+
train_log_fp.write(f"Current epoch: {epoch} \n")
|
| 371 |
+
|
| 372 |
+
train_dataset.set_epoch(epoch) # Update the current epoch in the dataset
|
| 373 |
+
|
| 374 |
+
if args.lrschedule == "CosineAnnealing":
|
| 375 |
+
lr_scheduler.step(epoch)
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
###### below is for the training ###########
|
| 379 |
+
if args.challenge_scenario == 2021:
|
| 380 |
+
train_auprc, loss_value_after_each_epcoh, thrs, scores_F1, scores_SubsetAccuracy, scores_HammingLoss = engine_ecg_2021.train_one_epoch(
|
| 381 |
+
model, criterion, data_loader_train,
|
| 382 |
+
optimizer, device, epoch,
|
| 383 |
+
set_training_mode=args.train_mode, # keep in eval mode for deit finetuning / train mode for training and deit III finetuning
|
| 384 |
+
args=args,
|
| 385 |
+
)
|
| 386 |
+
elif args.challenge_scenario == 2020:
|
| 387 |
+
train_auprc, loss_value_after_each_epcoh, thrs, scores_F1, scores_SubsetAccuracy, scores_HammingLoss = engine_ecg_2020.train_one_epoch(
|
| 388 |
+
model, criterion, data_loader_train,
|
| 389 |
+
optimizer, device, epoch,
|
| 390 |
+
set_training_mode=args.train_mode, # keep in eval mode for deit finetuning / train mode for training and deit III finetuning
|
| 391 |
+
args=args,
|
| 392 |
+
)
|
| 393 |
+
else:
|
| 394 |
+
raise ValueError(f"No matching condition for value: {args.challenge_scenario}")
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
train_auprc_list.append(round(train_auprc, 4))
|
| 398 |
+
loss_value_after_each_epcoh_list.append(round(loss_value_after_each_epcoh, 4))
|
| 399 |
+
print("**************************************************************")
|
| 400 |
+
print("This is the list of Training of AUPRC:", train_auprc_list)
|
| 401 |
+
train_log_fp.write(f"This is the list of Training of AUPRC: {train_auprc_list} \n")
|
| 402 |
+
print("This is the list of Training of loss:", loss_value_after_each_epcoh_list)
|
| 403 |
+
train_log_fp.write(f"This is the list of Training of loss: {loss_value_after_each_epcoh_list} \n")
|
| 404 |
+
print("**************************************************************")
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
###### below is for the evaluation ###########
|
| 408 |
+
|
| 409 |
+
if args.challenge_scenario == 2021:
|
| 410 |
+
AUPRC, auroc, f1, hamming, subset_accuracy, challenge_score, loss_value_after_each_epoch_testing = engine_ecg_2021.evaluate(data_loader_val, model, thrs, scores_F1, scores_SubsetAccuracy, scores_HammingLoss, device)
|
| 411 |
+
elif args.challenge_scenario == 2020:
|
| 412 |
+
AUPRC, auroc, f1, hamming, subset_accuracy, challenge_score, loss_value_after_each_epoch_testing = engine_ecg_2020.evaluate(data_loader_val, model, thrs, scores_F1, scores_SubsetAccuracy, scores_HammingLoss, device)
|
| 413 |
+
else:
|
| 414 |
+
raise ValueError(f"No matching condition for value: {args.challenge_scenario}")
|
| 415 |
+
|
| 416 |
+
# print(f"Accuracy of the network on the {len(test_dataset)} test images: {AUPRC:.4f}")
|
| 417 |
+
|
| 418 |
+
testing_AUPRC_list.append(round(AUPRC, 4))
|
| 419 |
+
loss_value_testing_list.append(round(loss_value_after_each_epoch_testing, 4))
|
| 420 |
+
testing_auroc_list.append(round(auroc, 4))
|
| 421 |
+
testing_f1_list.append(round(f1, 4))
|
| 422 |
+
testing_subset_accuracy_list.append((round(subset_accuracy, 4)))
|
| 423 |
+
hamming_loss_list.append((round(hamming, 4)))
|
| 424 |
+
challenge_score_list.append((round(challenge_score, 4)))
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
# loss_value_after_each_epcoh_testing_list.append(loss)
|
| 428 |
+
print("**********************************************************")
|
| 429 |
+
print("This is the list of Testing AUPRC:", testing_AUPRC_list)
|
| 430 |
+
print("This is the list of Testing loss:", loss_value_testing_list)
|
| 431 |
+
print("This is the list of Testing auroc:", testing_auroc_list)
|
| 432 |
+
print("This is the list of Testing f1:", testing_f1_list)
|
| 433 |
+
print("This is the list of subset accuracy:", testing_subset_accuracy_list)
|
| 434 |
+
print("This is the list of hamming_loss:", hamming_loss_list)
|
| 435 |
+
print("This is the list of challenge_score:", challenge_score_list)
|
| 436 |
+
|
| 437 |
+
train_log_fp.write("---------------------------------------------------------- \n")
|
| 438 |
+
train_log_fp.write(f"This is the list of Testing AUPRC: {testing_AUPRC_list} \n")
|
| 439 |
+
train_log_fp.write(f"This is the list of Testing loss: {loss_value_testing_list} \n")
|
| 440 |
+
train_log_fp.write(f"This is the list of Testing auroc: {testing_auroc_list} \n")
|
| 441 |
+
train_log_fp.write(f"This is the list of Testing f1: {testing_f1_list} \n")
|
| 442 |
+
train_log_fp.write(f"This is the list of Testing subset accuracy: {testing_subset_accuracy_list} \n")
|
| 443 |
+
train_log_fp.write(f"This is the list of Testing hamming_loss: {hamming_loss_list} \n")
|
| 444 |
+
train_log_fp.write(f"This is the list of challenge_score: {challenge_score_list} \n")
|
| 445 |
+
print("**********************************************************")
|
| 446 |
+
|
| 447 |
+
# here I have to put code regarding the auprc
|
| 448 |
+
if max_AUPRC < AUPRC:
|
| 449 |
+
max_AUPRC = AUPRC
|
| 450 |
+
jump_count = 0
|
| 451 |
+
if args.output_dir:
|
| 452 |
+
checkpoint_paths = [output_dir / 'best_auprc_checkpoint.pth']
|
| 453 |
+
for checkpoint_path in checkpoint_paths:
|
| 454 |
+
utils.save_on_master({
|
| 455 |
+
'model': model.state_dict(),
|
| 456 |
+
'optimizer': optimizer.optimizer.state_dict(),
|
| 457 |
+
'epoch': epoch,
|
| 458 |
+
'args': args,
|
| 459 |
+
}, checkpoint_path)
|
| 460 |
+
else:
|
| 461 |
+
jump_count = jump_count + 1
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
print(f'This is the Max AUPRC: {max_AUPRC:.4f}')
|
| 465 |
+
train_log_fp.write(f"This is the Max AUPRC: {max_AUPRC:.4f} \n")
|
| 466 |
+
|
| 467 |
+
if jump_count > 4:
|
| 468 |
+
print("This experimental will be finished at epoch:", epoch)
|
| 469 |
+
train_log_fp.write(f"This experimental will be finished at epoch: {epoch} \n")
|
| 470 |
+
break
|
| 471 |
+
|
| 472 |
+
total_time = time.time() - start_time
|
| 473 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
| 474 |
+
print('Training time {}'.format(total_time_str))
|
| 475 |
+
|
| 476 |
+
train_log_fp.write('Training time {}\n'.format(total_time_str))
|
| 477 |
+
train_log_fp.close()
|
| 478 |
+
old_txt_file_name = args.output_dir + '/train_log_group_%d.txt' % group_number
|
| 479 |
+
new_txt_file_name = args.output_dir + '/train_log_group_%d_MAX_AUPRC_%.4f_.txt' % (group_number, max_AUPRC)
|
| 480 |
+
os.rename(old_txt_file_name, new_txt_file_name)
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
if __name__ == '__main__':
    # Entry point: parse CLI args, resolve the dataset split directory for
    # the chosen PhysioNet challenge year, then run 5-fold training.
    parser = argparse.ArgumentParser(
        'DeiT training and evaluation script, but this is for the multiple classification ECG.',
        parents=[get_args_parser()])
    args = parser.parse_args()

    # Pick the CSV split directory matching the selected challenge year.
    if args.challenge_scenario == 2021:
        data_directory = "./csv-file_2021_challenge/training_validation_testing/group"
    elif args.challenge_scenario == 2020:
        data_directory = "./csv-file_2020_challenge/training_validation_testing/group"
    else:
        raise ValueError(f"No matching condition for value: {args.challenge_scenario}")

    # Fixed: this line was accidentally duplicated in the original.
    print("This is the scenario:", args.challenge_scenario)

    model_directory = "./model/model_group"

    # Cross-validation: one full training run per pre-defined data group 1..5.
    for i in range(1, 6):
        data_directory_x = data_directory + str(i)
        model_directory_x = model_directory + str(i)

        # Per-group output directory encoding the run's hyper-parameters.
        args.output_dir = f"./output/Scenario_{args.challenge_scenario}_{args.block}_depth_{args.depth}_{args.lead}_batchSize_{args.batch_size}_CutMix(random)_{args.cutmix}_MixUp_{args.mixup}_JTEHM_group{i}"

        if args.output_dir:
            Path(args.output_dir).mkdir(parents=True, exist_ok=True)

        print("This is the group", i)
        main(args, data_directory_x, model_directory_x, i)
|
models_mamba_ecg.py
ADDED
|
@@ -0,0 +1,1013 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2015-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as functional
|
| 6 |
+
from functools import partial
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
from timm.models.vision_transformer import VisionTransformer, _cfg
|
| 11 |
+
# from timm.models.registry import register_model
|
| 12 |
+
# from timm.models.layers import trunc_normal_, lecun_normal_
|
| 13 |
+
|
| 14 |
+
from timm.models import register_model
|
| 15 |
+
from timm.layers import trunc_normal_, lecun_normal_
|
| 16 |
+
|
| 17 |
+
from timm.layers import DropPath, to_2tuple
|
| 18 |
+
|
| 19 |
+
# from timm.models.layers import DropPath, to_2tuple
|
| 20 |
+
from timm.models.vision_transformer import _load_weights
|
| 21 |
+
|
| 22 |
+
import math
|
| 23 |
+
|
| 24 |
+
from collections import namedtuple
|
| 25 |
+
|
| 26 |
+
from mamba_ssm.modules.mamba_simple import Mamba
|
| 27 |
+
from mamba_ssm.utils.generation import GenerationMixin
|
| 28 |
+
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
|
| 29 |
+
|
| 30 |
+
from rope import *
|
| 31 |
+
import random
|
| 32 |
+
import sys
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
|
| 36 |
+
except ImportError:
|
| 37 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# layer_norm_fn and rms_norm_fn both are normalization method
|
| 41 |
+
|
| 42 |
+
__all__ = [
|
| 43 |
+
'vim_tiny_patch16_224', 'vim_small_patch16_224', 'vim_base_patch16_224',
|
| 44 |
+
'vim_tiny_patch16_384', 'vim_small_patch16_384', 'vim_base_patch16_384',
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
'''
|
| 49 |
+
in the original script(ft-vim-s.sh)
|
| 50 |
+
img_size = 224,
|
| 51 |
+
patch_size = 16,
|
| 52 |
+
stride = 8,
|
| 53 |
+
in_chans = 3,
|
| 54 |
+
embed_dim = 768
|
| 55 |
+
------------------------------------
|
| 56 |
+
self.img_size: (224, 224)
|
| 57 |
+
self.patch_size: (16, 16)
|
| 58 |
+
self.grid_size: (27, 27)
|
| 59 |
+
self.num_patches: 729
|
| 60 |
+
self.flatten: True
|
| 61 |
+
self.norm: nn.Identity()
|
| 62 |
+
'''
|
| 63 |
+
class PatchEmbed(nn.Module):
    """2D image to patch embedding.

    Splits a (B, C, H, W) image into patches with a strided convolution
    (patches overlap when ``stride < patch_size``) and projects each patch
    to ``embed_dim`` channels, returning tokens of shape (B, N, embed_dim)
    when ``flatten`` is True.
    """
    def __init__(self, img_size=224, patch_size=16, stride=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        # Number of patch positions along (H, W) for the given stride.
        self.grid_size = ((img_size[0] - patch_size[0]) // stride + 1, (img_size[1] - patch_size[1]) // stride + 1)
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)
        # Optional post-projection normalization; identity when norm_layer is None.
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        x = self.proj(x)
        # Fixed: removed leftover debug print() calls that ran on every
        # forward pass (the sibling PatchEmbed_spectrogram class already
        # has the equivalent prints commented out).
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x
|
| 96 |
+
|
| 97 |
+
class PatchEmbed_spectrogram(nn.Module):
    """2D spectrogram to patch embedding.

    Projects a (B, in_chans, F, T) spectrogram into a sequence of patch
    tokens of dimension ``embed_dim`` via a strided 2D convolution.
    """
    def __init__(self, img_size_f=128, img_size_t=64, patch_size=6, stride=3, in_chans=12, embed_dim=432, flatten=True):
        super().__init__()
        patch_size = to_2tuple(patch_size)

        self.img_size_f = img_size_f
        self.img_size_t = img_size_t
        self.patch_size = patch_size
        # Patch-grid extent along the frequency and time axes.
        self.grid_size = ((img_size_f - patch_size[0]) // stride + 1, (img_size_t - patch_size[1]) // stride + 1)
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size_f and W == self.img_size_t, \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size_f}*{self.img_size_t})."

        tokens = self.proj(x)
        if self.flatten:
            # BCHW -> BNC token sequence
            tokens = tokens.flatten(2).transpose(1, 2)
        return tokens
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class CNN_layers(nn.Module):
    """Two-stage 1D convolutional front-end.

    Turns a 12-lead ECG batch of shape (B, 12, 8192) into a token
    sequence of shape (B, 729, embed_size).
    """

    def __init__(self, embed_size=384):
        super().__init__()
        self.multiple_cnn = nn.Sequential(
            nn.Conv1d(12, 128, kernel_size=14, stride=3, padding=2, bias=False),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, embed_size, kernel_size=15, stride=4, padding=100, bias=False),
            nn.BatchNorm1d(embed_size),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # (B, 12, T) -> (B, embed_size, N) -> (B, N, embed_size)
        features = self.multiple_cnn(x)
        return features.transpose(1, 2)
|
| 158 |
+
|
| 159 |
+
class CNN_layers_shortcut(nn.Module):
    """CNN front-end like ``CNN_layers`` plus one residual conv stage.

    Maps a 12-lead ECG batch (B, 12, 8192) to tokens (B, 729, embed_size).
    """

    def __init__(self, embed_size=384):
        super().__init__()
        # Stage 1: downsampling stem, 12 leads -> embed_size channels.
        self.multiple_cnn1 = nn.Sequential(
            nn.Conv1d(12, 128, kernel_size=14, stride=3, padding=2, bias=False),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, embed_size, kernel_size=15, stride=4, padding=100, bias=False),
            nn.BatchNorm1d(embed_size),
            nn.ReLU(inplace=True),
        )
        # Stage 2: length-preserving convs used as a residual branch.
        self.multiple_cnn2 = nn.Sequential(
            nn.Conv1d(in_channels=embed_size, out_channels=embed_size, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(embed_size),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=embed_size, out_channels=embed_size, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(embed_size),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        stem = self.multiple_cnn1(x)
        out = self.multiple_cnn2(stem) + stem  # residual connection
        return out.transpose(1, 2)
|
| 198 |
+
|
| 199 |
+
# class ECG_Patch_embedding(nn.Module):
|
| 200 |
+
|
| 201 |
+
# def __init__(self, embed_size = 384):
|
| 202 |
+
# super().__init__()
|
| 203 |
+
# self.multiple_cnn = nn.Sequential(
|
| 204 |
+
# nn.Conv1d(12, embed_size, kernel_size=16, stride=8),
|
| 205 |
+
# )
|
| 206 |
+
|
| 207 |
+
# def forward(self, x):
|
| 208 |
+
# # print("This is the shape of enconder:(before)", x.shape)
|
| 209 |
+
# # This is the shape of enconder:(before) torch.Size([44, 12, 8192])
|
| 210 |
+
# x = self.multiple_cnn(x)
|
| 211 |
+
# x = x.transpose(1, 2)
|
| 212 |
+
# # print("This is the shape of enconder:(after)", x.shape)
|
| 213 |
+
# # This is the shape of enconder:(after) torch.Size([44, 1023, 384])
|
| 214 |
+
|
| 215 |
+
# # sys.exit()
|
| 216 |
+
# return x
|
| 217 |
+
|
| 218 |
+
# class LeadCombiner(nn.Module):
|
| 219 |
+
# def __init__(self, lead, out_ch):
|
| 220 |
+
# super(LeadCombiner, self).__init__()
|
| 221 |
+
# self.conv2_1 = nn.Conv1d(in_channels=lead * out_ch,
|
| 222 |
+
# out_channels=out_ch,
|
| 223 |
+
# kernel_size=1,
|
| 224 |
+
# bias=False)
|
| 225 |
+
# self.bn2_1 = nn.BatchNorm1d(out_ch)
|
| 226 |
+
|
| 227 |
+
# self.conv2_2 = nn.Conv1d(in_channels=lead * out_ch,
|
| 228 |
+
# out_channels=out_ch,
|
| 229 |
+
# kernel_size=1,
|
| 230 |
+
# bias=False)
|
| 231 |
+
# self.bn2_2 = nn.BatchNorm1d(out_ch)
|
| 232 |
+
|
| 233 |
+
# self.pool1 = nn.AdaptiveMaxPool1d(output_size=1)
|
| 234 |
+
# self.pool2 = nn.AdaptiveMaxPool1d(output_size=1)
|
| 235 |
+
|
| 236 |
+
# def forward(self, x):
|
| 237 |
+
# # this is the shape of x: torch.Size([78, 128, 12, 128]
|
| 238 |
+
# x1 = rearrange(x, 'b c l t -> b (c l) t')
|
| 239 |
+
# x1 = functional.leaky_relu(self.bn2_1(self.conv2_1(x1)))
|
| 240 |
+
|
| 241 |
+
# x2 = rearrange(x, 'b c l t -> b (t l) c')
|
| 242 |
+
# x2 = functional.leaky_relu(self.bn2_2(self.conv2_2(x2)))
|
| 243 |
+
|
| 244 |
+
# x1 = functional.dropout(x1, p=0.5, training=self.training)
|
| 245 |
+
# x2 = functional.dropout(x2, p=0.5, training=self.training)
|
| 246 |
+
|
| 247 |
+
# x1 = self.pool1(x1).squeeze(2)
|
| 248 |
+
# x2 = self.pool2(x2).squeeze(2)
|
| 249 |
+
|
| 250 |
+
# x = torch.cat([x1, x2], dim=1)
|
| 251 |
+
# return x
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# '''
|
| 255 |
+
# changed the stride and added one more layer
|
| 256 |
+
# '''
|
| 257 |
+
# class CNN_layers(nn.Module):
|
| 258 |
+
|
| 259 |
+
# def __init__(self, embed_size = 384):
|
| 260 |
+
# super().__init__()
|
| 261 |
+
# self.multiple_cnn = nn.Sequential(
|
| 262 |
+
# nn.Conv1d(12, 128, kernel_size=15, stride=2, padding=2, bias=False),
|
| 263 |
+
# nn.BatchNorm1d(128),
|
| 264 |
+
# nn.ReLU(inplace=True),
|
| 265 |
+
|
| 266 |
+
# nn.Conv1d(128, embed_size, kernel_size=15, stride=2, padding=100, bias=False),
|
| 267 |
+
# nn.BatchNorm1d(embed_size),
|
| 268 |
+
# nn.ReLU(inplace=True),
|
| 269 |
+
|
| 270 |
+
# nn.Conv1d(384, embed_size, kernel_size=15, stride=3, padding=100, bias=False),
|
| 271 |
+
# nn.BatchNorm1d(embed_size),
|
| 272 |
+
# nn.ReLU(inplace=True)
|
| 273 |
+
# )
|
| 274 |
+
|
| 275 |
+
# def forward(self, x):
|
| 276 |
+
# # print("This is the shape of enconder:(before)", x.shape)
|
| 277 |
+
# # This is the shape of enconder:(before) torch.Size([44, 12, 8192])
|
| 278 |
+
# x = self.multiple_cnn(x)
|
| 279 |
+
# x = x.transpose(1, 2)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# return x
|
| 283 |
+
|
| 284 |
+
# class D2_CNN_layers(nn.Module):
|
| 285 |
+
# def __init__(self, embed_size = 384):
|
| 286 |
+
# super().__init__()
|
| 287 |
+
# self.multiple_cnn = nn.Sequential(
|
| 288 |
+
# nn.Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 15), padding=(1, 7), stride=(1, 2), bias=False),
|
| 289 |
+
# nn.BatchNorm2d(128),
|
| 290 |
+
# nn.LeakyReLU(inplace=True),
|
| 291 |
+
|
| 292 |
+
# nn.Conv2d(in_channels=128, out_channels=384, kernel_size=(3, 15), padding=(1, 7), stride=(1, 2), bias=False),
|
| 293 |
+
# nn.BatchNorm2d(embed_size),
|
| 294 |
+
# nn.LeakyReLU(inplace=True),
|
| 295 |
+
|
| 296 |
+
# nn.Conv2d(in_channels=384, out_channels=384, kernel_size=(3, 15), padding=(1, 7), stride=(1, 30), bias=False),
|
| 297 |
+
# nn.BatchNorm2d(embed_size),
|
| 298 |
+
# nn.LeakyReLU(inplace=True),
|
| 299 |
+
|
| 300 |
+
# )
|
| 301 |
+
|
| 302 |
+
# def forward(self, x):
|
| 303 |
+
# x = x.unsqueeze(1)
|
| 304 |
+
# # print("This is the input shape:", x.shape)
|
| 305 |
+
# x = self.multiple_cnn(x)
|
| 306 |
+
# x = torch.flatten(x, start_dim=-2)
|
| 307 |
+
# x = x.transpose(1, 2)
|
| 308 |
+
# return x
|
| 309 |
+
|
| 310 |
+
def broadcat(tensors, dim=-1):
    """Concatenate ``tensors`` along ``dim``, first broadcasting every
    other dimension to the common (maximum) size."""
    num_tensors = len(tensors)
    ranks = {len(t.shape) for t in tensors}
    assert len(ranks) == 1, 'tensors must all have the same number of dimensions'
    rank = ranks.pop()
    if dim < 0:
        dim = dim + rank
    # Per-axis tuple of the sizes seen across all tensors.
    dims = list(zip(*(list(t.shape) for t in tensors)))
    expandable = [(i, sizes) for i, sizes in enumerate(dims) if i != dim]
    assert all(len(set(sizes)) <= 2 for _, sizes in expandable), 'invalid dimensions for broadcastable concatenation'
    # Broadcast every non-concat axis to its maximum observed size.
    target = [(i, (max(sizes),) * num_tensors) for i, sizes in expandable]
    target.insert(dim, (dim, dims[dim]))
    per_tensor_shapes = list(zip(*(sizes for _, sizes in target)))
    expanded = [t.expand(*shape) for t, shape in zip(tensors, per_tensor_shapes)]
    return torch.cat(expanded, dim=dim)
|
| 325 |
+
|
| 326 |
+
def rotate_half(x):
    """Rotate channel pairs along the last dim: (x1, x2) -> (-x2, x1).

    Equivalent to the einops form ``(d r) -> d r`` with r=2, negate/swap,
    then flatten back — expressed here with plain torch reshapes.
    """
    paired = x.reshape(*x.shape[:-1], -1, 2)
    x1, x2 = paired.unbind(dim=-1)
    rotated = torch.stack((-x2, x1), dim=-1)
    return rotated.reshape(*rotated.shape[:-2], -1)
|
| 331 |
+
|
| 332 |
+
# Adapted Rotary Embedding for 1D time-series (ECG) with 1024 tokens
|
| 333 |
+
class TimeSeriesRotaryEmbeddingFast(nn.Module):
    """Precomputed rotary positional embedding (RoPE) for a fixed-length
    1D token sequence.

    Cos/sin tables of shape (seq_len, dim) are built once at construction
    and applied in ``forward`` via the standard RoPE rotation.
    NOTE(review): ``repeat``, ``pi`` and ``rotate_half`` are assumed to be
    provided by the wildcard ``from rope import *`` at the top of this
    file — confirm before refactoring.
    """
    def __init__(
        self,
        dim,
        seq_len=1024, # Updated to 1024 tokens
        custom_freqs=None,
        freqs_for='lang', # Suitable for 1D sequential data
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        super().__init__()
        # Choose the frequency spectrum; 'lang' is the classic RoPE
        # geometric progression theta^(-2i/dim).
        if custom_freqs:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        # 1D positions 0 .. seq_len-1
        t = torch.arange(seq_len).float() # [0, 1, ..., 1023]

        # Outer product: one angle per (position, frequency) -> (seq_len, dim//2)
        freqs = torch.einsum('..., f -> ... f', t, freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r=2) # Doubles the dimension for rotation

        # Shape: (seq_len, dim)
        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

        # Buffers (not parameters): follow .to(device)/state_dict but are
        # never trained.
        self.register_buffer("freqs_cos", freqs_cos) # Shape: (1024, dim)
        self.register_buffer("freqs_sin", freqs_sin) # Shape: (1024, dim)

        print('======== shape of rope freq', self.freqs_cos.shape, '========')

    def forward(self, t):
        # t: (batch_size, seq_len, embed_dim)
        # Standard RoPE application: x*cos + rotate_half(x)*sin,
        # broadcasting the (seq_len, dim) tables over the batch dim.
        return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
'''
|
| 379 |
+
dim: 384
|
| 380 |
+
mixer_cls: mixer_cla is an instance of mamba.
|
| 381 |
+
drop_path = 0.
|
| 382 |
+
norm_cls = nn.LayerNorm
|
| 383 |
+
fused_add_norm = True,
|
| 384 |
+
residual_in_fp32 = True
|
| 385 |
+
'''
|
| 386 |
+
class Block(nn.Module):
    """Pre-norm residual wrapper around a Mamba mixer.

    The layout is ``Add -> Norm -> Mixer`` rather than the standard
    ``Norm -> Mixer -> Add`` (see https://arxiv.org/abs/2002.04745), and the
    block returns both the mixer output and the running residual so the
    add + normalization of consecutive blocks can be fused. The caller must
    feed the residual back in for every block except the very first.
    """

    def __init__(
        self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False, drop_path=0.,
    ):
        """
        Args:
            dim: model/embedding dimension handed to the mixer and the norm.
            mixer_cls: callable that builds the token mixer from ``dim``
                (typically a ``functools.partial`` of Mamba).
            norm_cls: normalization layer factory (LayerNorm or RMSNorm).
            fused_add_norm: use the fused residual-add + norm kernel.
            residual_in_fp32: keep the running residual in float32.
            drop_path: stochastic-depth rate for the incoming branch.
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)
        # Only instantiate DropPath when a positive rate was requested.
        self.drop_path = nn.Identity() if drop_path <= 0. else DropPath(drop_path)
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(self.norm, (nn.LayerNorm, RMSNorm)), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: running residual stream;
                ``hidden_states = Mixer(LN(residual))``. ``None`` for the
                first block in the stack.

        Returns:
            Tuple of (mixer output, updated residual).
        """
        if self.fused_add_norm:
            # Fused path: one kernel performs drop-path add + normalization
            # and (with prenorm=True) also returns the updated residual.
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            # The very first block has no residual yet, so nothing is dropped.
            branch = hidden_states if residual is None else self.drop_path(hidden_states)
            hidden_states, residual = fused_add_norm_fn(
                branch,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
        else:
            # Eager path: accumulate the residual, then normalize it.
            residual = hidden_states if residual is None else residual + self.drop_path(hidden_states)
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)

        # inference_params is None during training.
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the wrapped mixer."""
        print("the code is going through allocate_inference_cache in the block")

        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
| 476 |
+
|
| 477 |
+
'''
|
| 478 |
+
from torch.nn.ModuleList()
|
| 479 |
+
|
| 480 |
+
embed_dim=384, (embedding dimension)
|
| 481 |
+
device: None
|
| 482 |
+
dtype: None
|
| 483 |
+
ssm_cfg: None
|
| 484 |
+
norm_epsilon: float = 1e-5
|
| 485 |
+
rms_norm: bool = False
|
| 486 |
+
residual_in_fp32 = True
|
| 487 |
+
fused_add_norm = True
|
| 488 |
+
if_bimamba = False
|
| 489 |
+
bimamba_type = "v2"
|
| 490 |
+
inter_dpr: [0.0, 0.0, 0.004347826354205608, ...,0.09565217792987823, 0.10000000149011612]
|
| 491 |
+
if_devide_out = True
|
| 492 |
+
init_layer_scale = None
|
| 493 |
+
layer_idx = i
|
| 494 |
+
'''
|
| 495 |
+
|
| 496 |
+
def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    drop_path=0.,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
    if_bimamba=False,
    bimamba_type="none",
    if_devide_out=False,
    init_layer_scale=None,
    block_name="default_value",
):
    """Build a single ``Block`` wrapping a Mamba mixer.

    ``block_name`` selects the mixer variant: "VisionMamba" (bidirectional
    Vim-style Mamba) or "OriginalMamba" (plain Mamba). Any other value —
    including the default — raises ``ValueError``, so callers must pass the
    block type explicitly.
    """
    # if_bimamba forces the v1 bidirectional variant.
    if if_bimamba:
        bimamba_type = "v1"

    if ssm_cfg is None:
        ssm_cfg = {}

    factory_kwargs = {"device": device, "dtype": dtype}

    if block_name == "VisionMamba":
        mixer_cls = partial(
            Mamba,
            layer_idx=layer_idx,
            bimamba_type=bimamba_type,
            if_devide_out=if_devide_out,
            init_layer_scale=init_layer_scale,
            **ssm_cfg,
            **factory_kwargs,
        )
    elif block_name == "OriginalMamba":
        mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
    else:
        raise ValueError(f"No matching condition for value: {block_name}")

    # rms_norm selects RMSNorm over the default LayerNorm.
    norm_cls = partial(RMSNorm if rms_norm else nn.LayerNorm, eps=norm_epsilon, **factory_kwargs)

    blk = Block(
        d_model,
        mixer_cls,
        norm_cls=norm_cls,
        drop_path=drop_path,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    blk.layer_idx = layer_idx
    return blk
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
|
| 546 |
+
def _init_weights(
|
| 547 |
+
module,
|
| 548 |
+
n_layer,
|
| 549 |
+
initializer_range=0.02, # Now only used for embedding layer.
|
| 550 |
+
rescale_prenorm_residual=True,
|
| 551 |
+
n_residuals_per_layer=1, # Change to 2 if we have MLP
|
| 552 |
+
):
|
| 553 |
+
if isinstance(module, nn.Linear):
|
| 554 |
+
if module.bias is not None:
|
| 555 |
+
if not getattr(module.bias, "_no_reinit", False):
|
| 556 |
+
nn.init.zeros_(module.bias)
|
| 557 |
+
elif isinstance(module, nn.Embedding):
|
| 558 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
| 559 |
+
|
| 560 |
+
if rescale_prenorm_residual:
|
| 561 |
+
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
| 562 |
+
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
| 563 |
+
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
| 564 |
+
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
| 565 |
+
#
|
| 566 |
+
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
| 567 |
+
for name, p in module.named_parameters():
|
| 568 |
+
if name in ["out_proj.weight", "fc2.weight"]:
|
| 569 |
+
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
| 570 |
+
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
|
| 571 |
+
# We need to reinit p since this code could be called multiple times
|
| 572 |
+
# Having just p *= scale would repeatedly scale it down
|
| 573 |
+
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
|
| 574 |
+
with torch.no_grad():
|
| 575 |
+
p /= math.sqrt(n_residuals_per_layer * n_layer)
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
def segm_init_weights(m):
    """Per-module init for the embedding/head parts (used via ``Module.apply``).

    Linear: truncated-normal weight, zero bias. Conv2d: LeCun-normal weight,
    zero bias (conv was left to the PyTorch default in the original upstream
    init). Norm layers: zero bias, unit weight.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv2d):
        lecun_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)
|
| 591 |
+
|
| 592 |
+
'''
|
| 593 |
+
below is the 'ft-vim-s.sh':
|
| 594 |
+
|
| 595 |
+
patch_size=16, (just patch size)
|
| 596 |
+
stride=8, (just stride)
|
| 597 |
+
embed_dim=384, (embedding dimension)
|
| 598 |
+
depth=24, (?) the number of block
|
| 599 |
+
rms_norm = True, (?)
|
| 600 |
+
residual_in_fp32 = True, (?)
|
| 601 |
+
fused_add_norm = True, (?)
|
| 602 |
+
final_pool_type = 'mean', (?)
|
| 603 |
+
if_abs_pos_embed = True, (?)
|
| 604 |
+
if_rope = False, (?)
|
| 605 |
+
if_rope_residual = False, (?)
|
| 606 |
+
bimamba_type = "v2", (?)
|
| 607 |
+
if_cls_token = True, (?)
|
| 608 |
+
if_devide_out = True, (?)
|
| 609 |
+
use_middle_cls_token = True,
|
| 610 |
+
**kwargs
|
| 611 |
+
'''
|
| 612 |
+
|
| 613 |
+
class VisionMamba(nn.Module):
    """Vision Mamba (Vim) backbone adapted to ECG inputs.

    Pipeline: ``CNN_layers`` front-end -> token sequence -> optional cls
    token and absolute position embedding -> stack of ``depth`` Mamba
    ``Block``s -> final norm -> cls-token / pooled feature -> linear
    ``head`` producing ``num_classes`` logits.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 stride=16,
                 depth=24,
                 embed_dim=192,
                 channels=3,
                 num_classes=26,
                 ssm_cfg=None,
                 drop_rate=0.,
                 drop_path_rate=0,
                 norm_epsilon: float = 1e-5,
                 rms_norm: bool = False,
                 initializer_cfg=None,
                 fused_add_norm=False,
                 residual_in_fp32=False,
                 device=None,
                 dtype=None,
                 ft_seq_len=None,
                 pt_hw_seq_len=14,
                 if_bidirectional=False,
                 final_pool_type='none',
                 if_abs_pos_embed=False,
                 if_rope=False,
                 if_rope_residual=False,
                 flip_img_sequences_ratio=-1.,
                 if_bimamba=False,
                 bimamba_type="none",
                 if_cls_token=False,
                 if_devide_out=False,
                 init_layer_scale=None,
                 use_double_cls_token=False,
                 use_middle_cls_token=False,
                 **kwargs):
        factory_kwargs = {"device": device, "dtype": dtype}
        # kwargs['block'] selects the mixer variant in create_block
        # ("VisionMamba" or "OriginalMamba"); any other value (including the
        # default) makes create_block raise ValueError.
        block_name = kwargs.get('block', 'default_value')

        kwargs.update(factory_kwargs)

        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.if_bidirectional = if_bidirectional
        self.final_pool_type = final_pool_type
        self.if_abs_pos_embed = if_abs_pos_embed
        self.if_rope = if_rope
        self.if_rope_residual = if_rope_residual
        self.flip_img_sequences_ratio = flip_img_sequences_ratio
        self.if_cls_token = if_cls_token
        self.use_double_cls_token = use_double_cls_token
        self.use_middle_cls_token = use_middle_cls_token
        self.num_tokens = 1 if if_cls_token else 0

        self.num_classes = num_classes
        # num_features kept for consistency with other (timm-style) models.
        self.d_model = self.num_features = self.embed_dim = embed_dim

        # CNN front-end replaces the usual ViT patch embedding.
        # NOTE(review): num_patches is hard-coded; it must equal the token
        # count CNN_layers emits (commented-out alternatives in the original
        # used 1023, 828 or 775 tokens) — confirm if CNN_layers changes.
        self.CNN_layers = CNN_layers()
        num_patches = 729

        # if_cls_token is True in the original training script.
        if if_cls_token:
            if use_double_cls_token:
                # One learnable token at each end of the sequence.
                self.cls_token_head = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
                self.cls_token_tail = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
                self.num_tokens = 2

            else:
                self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

        # if_abs_pos_embed is True in the original training script.
        if if_abs_pos_embed:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, self.embed_dim))
            self.pos_drop = nn.Dropout(p=drop_rate)

        # if_rope is False in the original training script.
        if if_rope:
            half_head_dim = embed_dim // 2  # unused; kept from upstream Vim
            self.rope = TimeSeriesRotaryEmbeddingFast(dim=embed_dim, seq_len= num_patches + self.num_tokens)

        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        # Per-block stochastic-depth rates: all zero when disabled, otherwise
        # linearly increasing with depth (stochastic depth decay rule).
        if drop_path_rate == 0:
            print("This is the drop_path_rate:", drop_path_rate)
            dpr = [x.item() for x in torch.full((depth,), drop_path_rate)]
        else:
            print("This is the drop_path_rate:", drop_path_rate)
            print("follow the stochastic depth decay rule")
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # Shift by one so block i drops its *incoming* branch at rate dpr[i-1]
        # (the first block never drops).
        inter_dpr = [0.0] + dpr
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        # The Mamba block stack.
        self.layers = nn.ModuleList(
            [
                create_block(
                    embed_dim,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    if_bimamba=if_bimamba,
                    bimamba_type=bimamba_type,
                    drop_path=inter_dpr[i],
                    if_devide_out=if_devide_out,
                    init_layer_scale=init_layer_scale,
                    block_name = block_name,
                    **factory_kwargs,
                )
                for i in range(depth)
            ]
        )

        # Final normalization applied before pooling / cls-token readout.
        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(embed_dim, eps=norm_epsilon, **factory_kwargs)

        # Original (segm) init for the front-end and the head.
        self.CNN_layers.apply(segm_init_weights)

        self.head.apply(segm_init_weights)

        if if_abs_pos_embed:
            trunc_normal_(self.pos_embed, std=.02)
        if if_cls_token:
            if use_double_cls_token:
                trunc_normal_(self.cls_token_head, std=.02)
                trunc_normal_(self.cls_token_tail, std=.02)
            else:
                trunc_normal_(self.cls_token, std=.02)

        # Mamba / GPT-2-style init applied over the whole module tree.
        self.apply(partial(_init_weights, n_layer=depth, **(initializer_cfg if initializer_cfg is not None else {}),))

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate per-layer inference caches, keyed by layer index."""
        print("the code is going through allocate_inference_cache in the vision mamba")
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter names excluded from weight decay by the optimizer setup.
        return {"pos_embed", "cls_token", "dist_token", "cls_token_head", "cls_token_tail"}

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=""):
        # Delegates to timm-style checkpoint loading.
        _load_weights(self, checkpoint_path, prefix)

    def forward_features(self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
        """Encode ``x`` into one feature vector per sample.

        Adapted from timm's vision_transformer (with cls-token handling).
        ``if_random_token_rank`` is accepted but unused in this body.
        """
        # CNN front-end -> (B, M, D) token sequence.
        x = self.CNN_layers(x)

        # B: batch size; M: number of tokens (729 from CNN_layers);
        # D: embedding dimension (third component of x.shape, unused here).
        B, M, _ = x.shape

        # if_cls_token is True in the original training script.
        if self.if_cls_token:

            # use_double_cls_token is False in the original training script.
            if self.use_double_cls_token:
                cls_token_head = self.cls_token_head.expand(B, -1, -1)
                cls_token_tail = self.cls_token_tail.expand(B, -1, -1)
                # After concatenation the tokens sit at index 0 and M + 1.
                token_position = [0, M + 1]
                x = torch.cat((cls_token_head, x, cls_token_tail), dim=1)
                M = x.shape[1]
            else:
                # use_middle_cls_token is True in the original training script.
                if self.use_middle_cls_token:
                    cls_token = self.cls_token.expand(B, -1, -1)
                    token_position = M // 2
                    # Insert the cls token in the middle of the sequence.
                    x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
                elif if_random_cls_token_position:
                    cls_token = self.cls_token.expand(B, -1, -1)
                    token_position = random.randint(0, M)
                    x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
                    print("token_position: ", token_position)
                else:
                    # Standard ViT placement: prepend the cls token.
                    cls_token = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
                    token_position = 0
                    x = torch.cat((cls_token, x), dim=1)
                M = x.shape[1]

        # if_abs_pos_embed is True in the original training script.
        if self.if_abs_pos_embed:
            x = x + self.pos_embed
            x = self.pos_drop(x)

        # Optionally flip the token order for a random fraction of passes
        # (disabled by default: flip_img_sequences_ratio = -1).
        if_flip_img_sequences = False
        if self.flip_img_sequences_ratio > 0 and (self.flip_img_sequences_ratio - random.random()) > 1e-5:
            x = x.flip([1])
            if_flip_img_sequences = True

        # Mamba stack. if_bidirectional is False and inference_params is None
        # in the original training script.
        residual = None
        hidden_states = x
        if not self.if_bidirectional:
            for layer in self.layers:

                # Un-flip before rope, re-flip after (only relevant when both
                # flipping and rope are enabled; false in the original script).
                if if_flip_img_sequences and self.if_rope:
                    hidden_states = hidden_states.flip([1])
                    if residual is not None:
                        residual = residual.flip([1])

                # rope is disabled by default.
                if self.if_rope:
                    hidden_states = self.rope(hidden_states)
                    if residual is not None and self.if_rope_residual:
                        residual = self.rope(residual)

                if if_flip_img_sequences and self.if_rope:
                    hidden_states = hidden_states.flip([1])
                    if residual is not None:
                        residual = residual.flip([1])

                hidden_states, residual = layer(hidden_states, residual, inference_params=inference_params)

        else:
            # Bidirectional: consume layers in pairs; the second layer of each
            # pair sees the reversed sequence, and outputs are summed.
            for i in range(len(self.layers) // 2):
                if self.if_rope:
                    hidden_states = self.rope(hidden_states)
                    if residual is not None and self.if_rope_residual:
                        residual = self.rope(residual)

                hidden_states_f, residual_f = self.layers[i * 2](
                    hidden_states, residual, inference_params=inference_params
                )
                hidden_states_b, residual_b = self.layers[i * 2 + 1](
                    hidden_states.flip([1]), None if residual == None else residual.flip([1]), inference_params=inference_params
                )
                hidden_states = hidden_states_f + hidden_states_b.flip([1])
                residual = residual_f + residual_b.flip([1])

        # Final residual add + norm. fused_add_norm is True in the original
        # training script.
        if not self.fused_add_norm:
            if residual is None:
                residual = hidden_states
            else:
                residual = residual + self.drop_path(hidden_states)
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # prenorm=False: only the normalized output is needed here, not
            # an updated residual.
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn

            hidden_states = fused_add_norm_fn(
                self.drop_path(hidden_states),
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )

        # Readout: return only the cls token if one exists (all three
        # placements read the stored token_position).
        if self.if_cls_token:
            if self.use_double_cls_token:
                return (hidden_states[:, token_position[0], :] + hidden_states[:, token_position[1], :]) / 2
            else:
                if self.use_middle_cls_token:
                    return hidden_states[:, token_position, :]
                elif if_random_cls_token_position:
                    return hidden_states[:, token_position, :]
                else:
                    return hidden_states[:, token_position, :]

        # Otherwise pool over tokens. final_pool_type is 'mean' in the
        # original training script; 'max'/'all' return the full sequence
        # ('max' is reduced later in forward()).
        if self.final_pool_type == 'none':
            return hidden_states[:, -1, :]
        elif self.final_pool_type == 'mean':
            return hidden_states.mean(dim=1)
        elif self.final_pool_type == 'max':
            return hidden_states
        elif self.final_pool_type == 'all':
            return hidden_states
        else:
            raise NotImplementedError

    def forward(self, x, return_features=False, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
        """Full forward pass: features -> classification head.

        ``return_features`` is accepted for interface compatibility but is
        currently unused.
        """
        x = self.forward_features(x, inference_params, if_random_cls_token_position=if_random_cls_token_position, if_random_token_rank=if_random_token_rank)

        x = self.head(x)

        # For final_pool_type == 'max' the head output still carries a token
        # axis; reduce it here ('mean' was already reduced in
        # forward_features).
        if self.final_pool_type == 'max':
            x = x.max(dim=1)[0]
        return x
|
| 989 |
+
|
| 990 |
+
|
| 991 |
+
|
| 992 |
+
# below is for the vision in mamba
|
| 993 |
+
# Small Vision-Mamba variant registered for ECG classification.
@register_model
def ecg_vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, depth=5, fused_add_norm = True, drop_path_rate = 0.1, if_divide_out = True, use_middle_cls_token = True, **kwargs):
    """Build the small (embed_dim=384) bidirectional-v2 Vision Mamba with a
    middle cls token, mean pooling and absolute position embeddings."""
    base_config = dict(
        patch_size=16,
        stride=8,
        embed_dim=384,
        depth=depth,
        rms_norm=True,
        residual_in_fp32=True,
        drop_path_rate=drop_path_rate,
        fused_add_norm=fused_add_norm,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=False,
        if_rope_residual=False,
        bimamba_type="v2",
        if_cls_token=True,
        if_devide_out=if_divide_out,
        use_middle_cls_token=use_middle_cls_token,
    )
    model = VisionMamba(**base_config, **kwargs)

    # As a reminder of the effective configuration:
    print("This is whether the fused_add_norm:", fused_add_norm)
    print("This is whether the if_divide_out:", if_divide_out)
    print("This is whether the use_middle_cls_token:", use_middle_cls_token)

    model.default_cfg = _cfg()

    return model
|
| 1006 |
+
|
| 1007 |
+
# below is for the original mamba and 24 blocks
|
| 1008 |
+
# @register_model
|
| 1009 |
+
# def ecg_vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs):
|
| 1010 |
+
# model = VisionMamba(patch_size=16, stride=8, embed_dim=384, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs)
|
| 1011 |
+
# model.default_cfg = _cfg()
|
| 1012 |
+
|
| 1013 |
+
# return model
|
optimizer.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
# from utils import d_model
|
| 7 |
+
|
| 8 |
+
'''
|
| 9 |
+
optimizer = NoamOpt(d_model, 1, 4000, torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
|
| 10 |
+
model_size: 256
|
| 11 |
+
factor: 1
|
| 12 |
+
warmup: 4000
|
| 13 |
+
optimizer: torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
|
| 14 |
+
'''
|
| 15 |
+
class NoamOpt:
    """Optimizer wrapper implementing the Noam learning-rate schedule.

    lr(step) = factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5)

    i.e. a linear warmup for ``warmup`` steps followed by inverse-square-root
    decay (the schedule from "Attention Is All You Need").
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        """
        Args:
            model_size: model/embedding dimension used in the schedule.
            factor: overall scale applied to the rate.
            warmup: number of linear warmup steps.
            optimizer: wrapped torch optimizer (typically Adam with lr=0).
        """
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        """Advance one step: set the scheduled lr on all param groups, then
        step the wrapped optimizer."""
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def zero_grad(self):
        """Clear gradients on the wrapped optimizer (convenience passthrough
        so this wrapper can be used as a drop-in optimizer)."""
        self.optimizer.zero_grad()

    def rate(self, step=None):
        """Return the scheduled learning rate for ``step`` (default: current).

        Returns 0.0 at step 0 instead of raising ZeroDivisionError
        (``0 ** -0.5`` is undefined); the schedule proper starts at step 1.
        """
        if step is None:
            step = self._step
        if step == 0:
            return 0.0
        return self.factor * (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)))
|
| 40 |
+
|
| 41 |
+
def get_std_opt(model):
    """Return the project-default Noam schedule (model_size=729, factor=2,
    warmup=4000) wrapping Adam with the canonical transformer hyperparameters."""
    base = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
    return NoamOpt(729, 2, 4000, base)
|
rope.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# EVA-02: A Visual Representation for Neon Genesis
|
| 3 |
+
# Github source: https://github.com/baaivision/EVA/EVA02
|
| 4 |
+
# Copyright (c) 2023 Beijing Academy of Artificial Intelligence (BAAI)
|
| 5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 6 |
+
# By Yuxin Fang
|
| 7 |
+
#
|
| 8 |
+
# Based on https://github.com/lucidrains/rotary-embedding-torch
|
| 9 |
+
# --------------------------------------------------------'
|
| 10 |
+
|
| 11 |
+
from math import pi
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from torch import nn
|
| 15 |
+
|
| 16 |
+
from einops import rearrange, repeat
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def broadcat(tensors, dim = -1):
    """Concatenate ``tensors`` along ``dim``, broadcasting every other axis.

    All tensors must have the same rank; on each non-concat axis the sizes
    must be broadcast-compatible (at most two distinct values, with the
    smaller expandable to the larger).
    """
    num_tensors = len(tensors)
    ranks = set(t.ndim for t in tensors)
    assert len(ranks) == 1, 'tensors must all have the same number of dimensions'
    ndim = ranks.pop()
    if dim < 0:
        dim = dim + ndim

    # sizes per axis, across all tensors
    per_axis = list(zip(*(tuple(t.shape) for t in tensors)))

    # Build each tensor's broadcast target shape axis by axis: the concat
    # axis keeps each tensor's own size, every other axis takes the max.
    target_axes = []
    for axis, sizes in enumerate(per_axis):
        if axis == dim:
            target_axes.append(sizes)
            continue
        assert len(set(sizes)) <= 2, 'invalid dimensions for broadcastable concatentation'
        target_axes.append((max(sizes),) * num_tensors)

    per_tensor_shapes = zip(*target_axes)
    expanded = [t.expand(*shape) for t, shape in zip(tensors, per_tensor_shapes)]
    return torch.cat(expanded, dim = dim)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def rotate_half(x):
    """Rotate adjacent feature pairs: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...).

    This is the standard RoPE pairing step: each even/odd pair is treated as a
    complex number and multiplied by i. The last dimension must be even.

    Rewritten with pure torch indexing instead of einops.rearrange — the
    rearrange '(d r) -> d r' with r=2 simply split the last axis into
    (even, odd) interleaved pairs, which plain strided slicing does directly,
    dropping the third-party dependency from this hot-path helper.
    """
    x_even = x[..., 0::2]   # x0, x2, x4, ...
    x_odd = x[..., 1::2]    # x1, x3, x5, ...
    # stack -> (..., d, 2) pairs of (-odd, even); flatten re-interleaves them.
    return torch.stack((-x_odd, x_even), dim=-1).flatten(-2)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class VisionRotaryEmbedding(nn.Module):
    """2D (axial) rotary position embedding for vision transformers.

    Precomputes cos/sin rotation tables for an ``ft_seq_len x ft_seq_len``
    grid of patch positions, concatenating an identical frequency schedule
    for the height and width axes. ``forward`` applies the rotation to a
    slice of the feature dimension.

    Args:
        dim: number of rotary feature dims per axis (total rotated dims = 2*dim).
        pt_seq_len: per-axis grid length used at pretraining time.
        ft_seq_len: per-axis grid length at fine-tuning time; positions are
            rescaled onto the pretrain grid so the embedding interpolates.
            Defaults to ``pt_seq_len``.
        custom_freqs: optional explicit frequency tensor overriding ``freqs_for``.
        freqs_for: frequency schedule — 'lang' (standard RoPE inverse powers of
            ``theta``), 'pixel' (linear up to ``max_freq/2 * pi``), or
            'constant' (``num_freqs`` ones).
        theta: base for the 'lang' schedule.
        max_freq: top frequency for the 'pixel' schedule.
        num_freqs: number of frequencies for the 'constant' schedule.
    """

    def __init__(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        super().__init__()
        # BUG FIX: was `if custom_freqs:` — truthiness of a multi-element
        # tensor raises RuntimeError, making the parameter unusable.
        if custom_freqs is not None:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            # Standard RoPE schedule: theta^(-2k/dim) for k = 0..dim/2-1.
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        # Rescale fine-tune positions onto the pretrain grid (interpolation).
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        # Same angle schedule for the height and width axes; each frequency is
        # duplicated (r=2) so it aligns with rotate_half's even/odd pairing.
        freqs_h = torch.einsum('..., f -> ... f', t, freqs)
        freqs_h = repeat(freqs_h, '... n -> ... (n r)', r=2)

        freqs_w = torch.einsum('..., f -> ... f', t, freqs)
        freqs_w = repeat(freqs_w, '... n -> ... (n r)', r=2)

        # (H, 1, d) broadcast-concat (1, W, d) -> (H, W, 2d) per-position angles.
        freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)

        # Buffers (not parameters): move with the module, excluded from grads.
        self.register_buffer("freqs_cos", freqs.cos())
        self.register_buffer("freqs_sin", freqs.sin())

        print('======== shape of rope freq', self.freqs_cos.shape, '========')

    def forward(self, t, start_index=0):
        """Rotate the feature slice ``t[..., start_index:start_index+rot_dim]``.

        Features before/after the slice pass through unchanged, allowing only
        part of the feature dimension to carry positional information.
        """
        rot_dim = self.freqs_cos.shape[-1]
        end_index = start_index + rot_dim
        assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
        t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
        # Classic rotary application: x*cos + rotate_half(x)*sin.
        t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
        return torch.cat((t_left, t, t_right), dim=-1)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class VisionRotaryEmbeddingFast(nn.Module):
    """Flattened-sequence variant of :class:`VisionRotaryEmbedding`.

    Precomputes the 2D axial cos/sin tables and flattens them to
    ``(ft_seq_len**2, 2*dim)`` so they broadcast directly against a
    ``(batch, seq, features)`` token sequence; ``forward`` rotates the whole
    feature dimension in one step. An odd sequence length is interpreted as a
    leading class token, which is passed through unrotated.

    Args:
        dim: number of rotary feature dims per axis (total rotated = 2*dim).
        pt_seq_len: per-axis grid length used at pretraining time.
        ft_seq_len: per-axis grid length at fine-tuning time; defaults to
            ``pt_seq_len``. Positions are rescaled onto the pretrain grid.
        custom_freqs: optional explicit frequency tensor overriding ``freqs_for``.
        freqs_for: 'lang' | 'pixel' | 'constant' frequency schedule.
        theta: base for the 'lang' schedule.
        max_freq: top frequency for the 'pixel' schedule.
        num_freqs: number of frequencies for the 'constant' schedule.
    """

    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        super().__init__()
        # BUG FIX: was `if custom_freqs:` — truthiness of a multi-element
        # tensor raises RuntimeError, making the parameter unusable.
        if custom_freqs is not None:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            # Standard RoPE schedule: theta^(-2k/dim) for k = 0..dim/2-1.
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        # Rescale fine-tune positions onto the pretrain grid (interpolation).
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs = torch.einsum('..., f -> ... f', t, freqs)
        # Duplicate each frequency (r=2) to match rotate_half's even/odd pairing.
        freqs = repeat(freqs, '... n -> ... (n r)', r=2)
        # (H, 1, d) broadcast-concat (1, W, d) -> (H, W, 2d), then flatten the
        # spatial grid to one axis so it lines up with a flattened token sequence.
        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)

        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

        # Buffers (not parameters): move with the module, excluded from grads.
        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)

        print('======== shape of rope freq', self.freqs_cos.shape, '========')

    def forward(self, t):
        """Apply rotary embedding to a (batch, seq, features) token sequence.

        An odd sequence length is assumed to mean a leading class token
        (seq = 1 + H*W); it is excluded from rotation and re-attached.
        """
        if t.shape[1] % 2 != 0:
            t_spatial = t[:, 1:, :]
            t_spatial = t_spatial * self.freqs_cos + rotate_half(t_spatial) * self.freqs_sin
            return torch.cat((t[:, :1, :], t_spatial), dim=1)
        else:
            return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
|