poult commited on
Commit
9e220d3
·
verified ·
1 Parent(s): fda718d

Upload 54 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. __pycache__/ecg_dataset_2020.cpython-310.pyc +0 -0
  3. __pycache__/ecg_dataset_2021.cpython-310.pyc +0 -0
  4. __pycache__/engine_ecg_2020.cpython-310.pyc +0 -0
  5. __pycache__/engine_ecg_2021.cpython-310.pyc +0 -0
  6. __pycache__/evaluate_12ECG_score.cpython-310.pyc +0 -0
  7. __pycache__/evaluate_model.cpython-310.pyc +0 -0
  8. __pycache__/helper_code.cpython-310.pyc +0 -0
  9. __pycache__/losses.cpython-310.pyc +0 -0
  10. __pycache__/models_mamba_ecg.cpython-310.pyc +0 -0
  11. __pycache__/optimizer.cpython-310.pyc +0 -0
  12. __pycache__/rope.cpython-310.pyc +0 -0
  13. __pycache__/utils.cpython-310.pyc +0 -0
  14. csv-file_2020_challenge/training_validation_testing/group1/testing_group1.csv +0 -0
  15. csv-file_2020_challenge/training_validation_testing/group1/train_validation_testing_fold5.csv +0 -0
  16. csv-file_2020_challenge/training_validation_testing/group1/training_group1.csv +0 -0
  17. csv-file_2020_challenge/training_validation_testing/group2/testing_group2.csv +0 -0
  18. csv-file_2020_challenge/training_validation_testing/group2/training_group2.csv +0 -0
  19. csv-file_2020_challenge/training_validation_testing/group3/testing_group3.csv +0 -0
  20. csv-file_2020_challenge/training_validation_testing/group3/training_group3.csv +0 -0
  21. csv-file_2020_challenge/training_validation_testing/group4/testing_group4.csv +0 -0
  22. csv-file_2020_challenge/training_validation_testing/group4/training_group4.csv +0 -0
  23. csv-file_2020_challenge/training_validation_testing/group5/testing_group5.csv +0 -0
  24. csv-file_2020_challenge/training_validation_testing/group5/training_group5.csv +0 -0
  25. csv-file_2021_challenge/collection_of_all_datasets.csv +3 -0
  26. csv-file_2021_challenge/name.csv +11 -0
  27. csv-file_2021_challenge/training_validation_testing/group1/testing_group1.csv +0 -0
  28. csv-file_2021_challenge/training_validation_testing/group1/training_group1.csv +3 -0
  29. csv-file_2021_challenge/training_validation_testing/group2/testing_group2.csv +0 -0
  30. csv-file_2021_challenge/training_validation_testing/group2/training_group2.csv +3 -0
  31. csv-file_2021_challenge/training_validation_testing/group3/testing_group3.csv +0 -0
  32. csv-file_2021_challenge/training_validation_testing/group3/training_group3.csv +3 -0
  33. csv-file_2021_challenge/training_validation_testing/group4/testing_group4.csv +0 -0
  34. csv-file_2021_challenge/training_validation_testing/group4/training_group4.csv +3 -0
  35. csv-file_2021_challenge/training_validation_testing/group5/testing_group5.csv +0 -0
  36. csv-file_2021_challenge/training_validation_testing/group5/training_group5.csv +3 -0
  37. ecg_dataset_2020.py +397 -0
  38. ecg_dataset_2021.py +422 -0
  39. ecg_dataset_2021_DAFirst.py +418 -0
  40. engine_ecg_2020.py +309 -0
  41. engine_ecg_2021.py +241 -0
  42. evaluate_12ECG_score.py +577 -0
  43. evaluate_model.py +434 -0
  44. helper.ipynb +182 -0
  45. helper_code.py +241 -0
  46. losses.py +70 -0
  47. main_ecg.py +509 -0
  48. models_mamba_ecg.py +1013 -0
  49. optimizer.py +42 -0
  50. rope.py +141 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ csv-file_2021_challenge/collection_of_all_datasets.csv filter=lfs diff=lfs merge=lfs -text
37
+ csv-file_2021_challenge/training_validation_testing/group1/training_group1.csv filter=lfs diff=lfs merge=lfs -text
38
+ csv-file_2021_challenge/training_validation_testing/group2/training_group2.csv filter=lfs diff=lfs merge=lfs -text
39
+ csv-file_2021_challenge/training_validation_testing/group3/training_group3.csv filter=lfs diff=lfs merge=lfs -text
40
+ csv-file_2021_challenge/training_validation_testing/group4/training_group4.csv filter=lfs diff=lfs merge=lfs -text
41
+ csv-file_2021_challenge/training_validation_testing/group5/training_group5.csv filter=lfs diff=lfs merge=lfs -text
__pycache__/ecg_dataset_2020.cpython-310.pyc ADDED
Binary file (7.46 kB). View file
 
__pycache__/ecg_dataset_2021.cpython-310.pyc ADDED
Binary file (9.87 kB). View file
 
__pycache__/engine_ecg_2020.cpython-310.pyc ADDED
Binary file (7.86 kB). View file
 
__pycache__/engine_ecg_2021.cpython-310.pyc ADDED
Binary file (7.44 kB). View file
 
__pycache__/evaluate_12ECG_score.cpython-310.pyc ADDED
Binary file (13.4 kB). View file
 
__pycache__/evaluate_model.cpython-310.pyc ADDED
Binary file (10.4 kB). View file
 
__pycache__/helper_code.cpython-310.pyc ADDED
Binary file (6.37 kB). View file
 
__pycache__/losses.cpython-310.pyc ADDED
Binary file (2.34 kB). View file
 
__pycache__/models_mamba_ecg.cpython-310.pyc ADDED
Binary file (19.1 kB). View file
 
__pycache__/optimizer.cpython-310.pyc ADDED
Binary file (1.47 kB). View file
 
__pycache__/rope.cpython-310.pyc ADDED
Binary file (5 kB). View file
 
__pycache__/utils.cpython-310.pyc ADDED
Binary file (8.37 kB). View file
 
csv-file_2020_challenge/training_validation_testing/group1/testing_group1.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv-file_2020_challenge/training_validation_testing/group1/train_validation_testing_fold5.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv-file_2020_challenge/training_validation_testing/group1/training_group1.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv-file_2020_challenge/training_validation_testing/group2/testing_group2.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv-file_2020_challenge/training_validation_testing/group2/training_group2.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv-file_2020_challenge/training_validation_testing/group3/testing_group3.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv-file_2020_challenge/training_validation_testing/group3/training_group3.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv-file_2020_challenge/training_validation_testing/group4/testing_group4.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv-file_2020_challenge/training_validation_testing/group4/training_group4.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv-file_2020_challenge/training_validation_testing/group5/testing_group5.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv-file_2020_challenge/training_validation_testing/group5/training_group5.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv-file_2021_challenge/collection_of_all_datasets.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:455156a539d4dde56aa9c2d1e5e460a72df64bda32b8ba36a6692837520292ac
3
+ size 12635911
csv-file_2021_challenge/name.csv ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Name,Target,Label_code,Frequency,Duration,Age,Sex
2
+ JS10647.hea,"[0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
3
+ 1. 0.]","['426177001', '713427006', '164934002', '39732003']",500.0,10.0,82.0,Male
4
+ JS10648.hea,"[0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
5
+ 0. 0.]","['426177001', '713426002']",500.0,10.0,63.0,Male
6
+ JS10649.hea,"[0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
7
+ 0. 0.]","['111975006', '713426002', '426761007', '55930002']",500.0,10.0,83.0,Male
8
+ JS10650.hea,"[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
9
+ 0. 0.]",['29320008'],500.0,10.0,81.0,Female
10
+ JS10651.hea,"[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.
11
+ 0. 0.]","['233897008', '10370003']",500.0,10.0,63.0,Female
csv-file_2021_challenge/training_validation_testing/group1/testing_group1.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv-file_2021_challenge/training_validation_testing/group1/training_group1.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd7406d2f748a0a5edd033b389fd967fc9a232d2421ed9be69adcccf964db7a1
3
+ size 10522513
csv-file_2021_challenge/training_validation_testing/group2/testing_group2.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv-file_2021_challenge/training_validation_testing/group2/training_group2.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db819a51f741bbb0bc88878f4adbec1a16065eaf1cb991758dab8a42aef34fce
3
+ size 10523820
csv-file_2021_challenge/training_validation_testing/group3/testing_group3.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv-file_2021_challenge/training_validation_testing/group3/training_group3.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03f7959deaa6cbe865b911c103a1fae71210cce4f5ac99c6920dc2efb4beb18a
3
+ size 10522844
csv-file_2021_challenge/training_validation_testing/group4/testing_group4.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv-file_2021_challenge/training_validation_testing/group4/training_group4.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:125d50af59c4b1862aefcf389b5a7eda5819bdc2e66ce52f43097f9f60661b25
3
+ size 10525392
csv-file_2021_challenge/training_validation_testing/group5/testing_group5.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv-file_2021_challenge/training_validation_testing/group5/training_group5.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c6b92de27c915999cf59125328b7f2d73a9b33ef8939f04d942b17d1cb7630a
3
+ size 10522780
ecg_dataset_2020.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from tqdm import tqdm
3
+ from helper_code import *
4
+ from scipy import signal
5
+ import pandas as pd
6
+ from skmultilearn.model_selection import iterative_train_test_split
7
+ import random
8
+ import warnings
9
+ import sys
10
+
11
+
12
+ def get_nsamp(header):
13
+ return int(header.split('\n')[0].split(' ')[3])
14
+
15
+ # Adapted from original scoring function code
16
+ # For each set of equivalent classes, replace each class with the representative class for the set.
17
+
18
+
19
+ '''
20
+ the 'output' will be using the zero-padding if the leads are less than 12.(8 leads)
21
+ the 'output_leads' will be [1. 1. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
22
+
23
+ the output will not be changed if the lead is 12
24
+ the 'output_leads' will be [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
25
+ '''
26
+ def mixup(data, mix_data):
27
+ mix_lambda = np.random.beta(10, 10)# 这里利用beta分布得到一个0到1的值,这里如果可以,尝试别的分布函数
28
+ if data.shape[1] > mix_data.shape[1]:
29
+ data = mix_lambda * data[:, :mix_data.shape[1]] + (1 - mix_lambda) * mix_data
30
+ else:
31
+ data = mix_lambda * data + (1 - mix_lambda) * mix_data[:, :data.shape[1]]
32
+ return data, mix_lambda
33
+
34
+ def cutmix(data, mix_data):
35
+ cutmix_lambda = np.random.beta(10, 10)
36
+
37
+ if data.shape[1] < mix_data.shape[1]:
38
+ seq_length = data.shape[1]
39
+ win_len = int(np.ceil(seq_length * cutmix_lambda))
40
+ start = np.random.randint(0, seq_length - win_len + 1)
41
+ end = start + win_len
42
+ data[:,start:end] = mix_data[:,start:end]
43
+ else:
44
+ seq_length = mix_data.shape[1]
45
+ win_len = int(np.ceil(seq_length * cutmix_lambda))
46
+ start = np.random.randint(0, seq_length - win_len + 1)
47
+ end = start + win_len
48
+ data[:,start:end] = mix_data[:,start:end]
49
+ return data, cutmix_lambda
50
+
51
+ def cutmix_fix_length(data, mix_data):
52
+ # cutmix_lambda = np.random.beta(0.2, 0.2)
53
+ cutmix_lambda = 0.2
54
+
55
+ if data.shape[1] < mix_data.shape[1]:
56
+ seq_length = data.shape[1]
57
+ win_len = int(np.ceil(seq_length * cutmix_lambda))
58
+ start = np.random.randint(0, seq_length - win_len + 1)
59
+ end = start + win_len
60
+ data[:,start:end] = mix_data[:,start:end]
61
+ else:
62
+ seq_length = mix_data.shape[1]
63
+ win_len = int(np.ceil(seq_length * cutmix_lambda))
64
+ start = np.random.randint(0, seq_length - win_len + 1)
65
+ end = start + win_len
66
+ data[:,start:end] = mix_data[:,start:end]
67
+ return data, cutmix_lambda
68
+
69
+ def same_shape_mixup(data, mix_data):
70
+ if data.shape[1] == mix_data.shape[1]:
71
+ mix_lambda = np.random.beta(10, 10)# 这里利用beta分布得到一个0到1的值,这里如果可以,尝试别的分布函数
72
+ data = mix_lambda * data + (1 - mix_lambda) * mix_data
73
+ return data
74
+
75
+ def expand_leads(recording, input_leads):
76
+ output = np.zeros((12, recording.shape[1]))
77
+ # recording.shape[1]: 5000
78
+ twelve_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
79
+ twelve_leads = [k.lower() for k in twelve_leads]
80
+ # ['i', 'ii', 'iii', 'avr', 'avl', 'avf', 'v1', 'v2', 'v3', 'v4', 'v5', 'v6']
81
+
82
+ input_leads = [k.lower() for k in input_leads]
83
+ # Here we can assume:
84
+ # input_leads:I, II, V1, V2, V3, V4, V5, V6,
85
+ # so the new input_leads: ['i', 'ii', 'v1', 'v2', 'v3', 'v4', 'v5', 'v6']
86
+
87
+ output_leads = np.zeros((12,))
88
+ # [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
89
+
90
+ # idx: [0, 1, 6, 7, 8, 9, 10, 11]
91
+ for i,k in enumerate(input_leads):
92
+ idx = twelve_leads.index(k)
93
+ output[idx,:] = recording[i,:]
94
+ output_leads[idx] = 1
95
+
96
+ return output, output_leads
97
+
98
+ '''
99
+ This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
100
+ This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
101
+ This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
102
+ This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
103
+ This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
104
+ This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
105
+
106
+ '''
107
+
108
+ class lead_exctractor:
109
+ """
110
+ used to select specific leads or random choice of configurations
111
+
112
+ Twelve leads: I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5, V6
113
+ Eight leads: I, II, V1, V2, V3, V4, V5, V6
114
+ Six leads: I, II, III, aVR, aVL, aVF
115
+ Four leads: I, II, III, V2
116
+ Three leads: I, II, V2
117
+ Two leads: I, II
118
+
119
+ """
120
+ L2 = np.array([1,1,0,0,0,0,0,0,0,0,0,0])
121
+ L3 = np.array([1,1,0,0,0,0,0,1,0,0,0,0])
122
+ L4 = np.array([1,1,1,0,0,0,0,1,0,0,0,0])
123
+ L6 = np.array([1,1,1,1,1,1,0,0,0,0,0,0])
124
+ L8 = np.array([1,1,0,0,0,0,1,1,1,1,1,1])
125
+ L12 = np.array([1,1,1,1,1,1,1,1,1,1,1,1])
126
+
127
+ @staticmethod
128
+ def get (x, num_leads, lead_indicator):
129
+ if num_leads==None:
130
+ # random choice output
131
+ num_leads = random.choice([12,8,6,4,3,2])
132
+
133
+ if num_leads==12:
134
+ # Twelve leads: I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5, V6
135
+ return x, lead_indicator * lead_exctractor.L12
136
+
137
+ if num_leads==8:
138
+ # Six leads: I, II, V1, V2, V3, V4, V5, V6
139
+ x = x * lead_exctractor.L8.reshape(12,1)
140
+ return x,lead_indicator * lead_exctractor.L8
141
+
142
+ if num_leads==6:
143
+ # Six leads: I, II, III, aVL, aVR, aVF
144
+ x = x * lead_exctractor.L6.reshape(12,1)
145
+ return x,lead_indicator * lead_exctractor.L6
146
+
147
+ if num_leads==4:
148
+ # Six leads: I, II, III, V2
149
+ x = x * lead_exctractor.L4.reshape(12,1)
150
+ return x,lead_indicator * lead_exctractor.L4
151
+
152
+ if num_leads==3:
153
+ # Three leads: I, II, V2
154
+ x = x * lead_exctractor.L3.reshape(12,1)
155
+ return x,lead_indicator * lead_exctractor.L3
156
+
157
+ if num_leads==2:
158
+ # Two leads: II, V5
159
+ x = x * lead_exctractor.L2.reshape(12,1)
160
+ return x,lead_indicator * lead_exctractor.L2
161
+ raise Exception("invalid-leads-number")
162
+
163
+ class dataset:
164
+ # #ECG 2021
165
+ # classes = ['164889003','164890007','6374002','426627000','733534002',
166
+ # '713427006','270492004','713426002','39732003','445118002',
167
+ # '164947007','251146004','111975006','698252002','426783006',
168
+ # '284470004','10370003','365413008','427172004','164917005',
169
+ # '47665007','427393009','426177001','427084000','164934002',
170
+ # '59931005']
171
+
172
+ # ECG 2020
173
+
174
+
175
+
176
+ normal_class = '426783006'
177
+
178
+ def __init__(self, header_files, Mixup = 0, amount = 0, cutMix = 0, Mixup_no_label_interpolate =0, progressive_switch = False):
179
+ self.files = []
180
+ self.sample = True
181
+ self.num_leads = None
182
+
183
+ for h in tqdm(header_files):
184
+ tmp = dict()
185
+ tmp['header'] = h
186
+ tmp['record'] = h.replace('.hea','.mat')
187
+ hdr = load_header(h)
188
+ tmp['nsamp'] = get_nsamp(hdr)
189
+ tmp['leads'] = get_leads(hdr)
190
+ tmp['age'] = get_age(hdr)
191
+ tmp['sex'] = get_sex(hdr)
192
+ tmp['dx'] = get_labels(hdr)
193
+ tmp['fs'] = get_frequency(hdr)
194
+ # tmp['target'] = np.zeros((26,))
195
+
196
+
197
+ self.files.append(tmp)
198
+
199
+ # print("This is the target:", tmp['target'])
200
+ # set filter parameters
201
+ self.b, self.a = signal.butter(3, [1 / 250, 47 / 250], 'bandpass')
202
+
203
+ self.files = pd.DataFrame(self.files)
204
+
205
+ self.Mixup = Mixup
206
+ self.cutMix = cutMix
207
+ self.Mixup_no_label_interplate = Mixup_no_label_interpolate
208
+ self.amount = amount
209
+ self.progressive = progressive_switch
210
+
211
+ self.current_epoch = 0 # Initialize the current epoch
212
+
213
+ address_2020 = "./csv-file_2020_challenge/training_validation_testing/group1/train_validation_testing_fold5.csv"
214
+ self.data_df_2020 = pd.read_csv(address_2020, index_col=0)
215
+ self.data_df_2020.set_index('Patient', inplace=True)
216
+ self.classes_2020 = ['270492004', '164889003', '164890007', '426627000', '713427006',
217
+ '713426002', '445118002', '39732003', '164909002', '251146004',
218
+ '698252002', '10370003', '284470004', '427172004', '164947007',
219
+ '111975006', '164917005', '47665007', '59118001', '427393009',
220
+ '426177001', '426783006', '427084000', '63593006', '164934002',
221
+ '59931005', '17338001']
222
+
223
+
224
+ def set_epoch(self, epoch):
225
+ self.current_epoch = epoch
226
+
227
+ def summary(self, output):
228
+ if output=='pandas':
229
+ # print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
230
+ # print("This is the target:",self.files['target'],type(self.files['target']))
231
+ return pd.Series(np.stack(self.files['target'].to_list(),axis=0).sum(axis=0),index=dataset.classes)
232
+
233
+ if output=='numpy':
234
+ return np.stack(self.files['target'].to_list(),axis=0).sum(axis=0)
235
+
236
+ def __len__(self):
237
+ return len(self.files)
238
+
239
+ '''
240
+ fs: 500.0 sampling rate
241
+ target: [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
242
+ leads: ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
243
+ self.files.iloc[item]['record']: ../../python-classifier-2021-main/training_data/chapman_shaoxing/g1/JS00518.mat
244
+ data: shape->(12 X 5000)
245
+ '''
246
+
247
+ def __getitem__(self, item):
248
+
249
+ fs = self.files.iloc[item]['fs']
250
+ # target = self.files.iloc[item]['target']
251
+ header = self.files.iloc[item]['header']
252
+ name = header.split('/')[-1]
253
+ row = self.data_df_2020.loc[name[0:-4]]
254
+ target = row[self.classes_2020].values.astype(np.int_)
255
+
256
+ leads = self.files.iloc[item]['leads']
257
+ data = load_recording(self.files.iloc[item]['record'])
258
+
259
+ # print("This is the target:", target)
260
+ # Set your threshold
261
+
262
+ threshold = 0.5
263
+ # print("This is the original target:", target)
264
+ # print("This is the type of original target:", type(target))
265
+
266
+ # if random.random() < self.Mixup :
267
+ # mix_sample_idx = random.randint(0, self.amount-1)
268
+ # mix_datum = load_recording(self.files.iloc[mix_sample_idx]['record'])
269
+ # mix_target = self.files.iloc[mix_sample_idx]['target']
270
+ # data, alpha_value = mixup(data, mix_datum)
271
+ # target = alpha_value * target + (1-alpha_value) * mix_target
272
+ # # Binarize the labels based on the threshold
273
+ # target[target >= threshold] = 1
274
+ # target[target < threshold] = 0
275
+
276
+
277
+ # if random.random() < self.cutMix:
278
+ # mix_sample_idx = random.randint(0, self.amount-1)
279
+ # mix_datum = load_recording(self.files.iloc[mix_sample_idx]['record'])
280
+ # mix_target = self.files.iloc[mix_sample_idx]['target']
281
+ # data, alpha_value = cutmix_fix_length(data, mix_datum)
282
+ # target = alpha_value * target + (1-alpha_value) * mix_target
283
+ # target[target >= threshold] = 1
284
+ # target[target < threshold] = 0
285
+
286
+ # if random.random()< self.Mixup_no_label_interplate:
287
+ # mix_sample_idx = random.randint(0, self.amount-1)
288
+ # mix_datum = load_recording(self.files.iloc[mix_sample_idx]['record'])
289
+ # data= same_shape_mixup(data, mix_datum)
290
+
291
+ # proressive_index = 0
292
+
293
+ # # # below is the progessive data augmentation method
294
+ # # if self.progressive:
295
+ # # if self.current_epoch < 20:
296
+ # # if self.current_epoch % 4 == 0 or 3:
297
+ # # pass
298
+ # # elif self.current_epoch % 4 == 1:
299
+ # # proressive_index = 0.2
300
+ # # else:
301
+ # # proressive_index = 0.5
302
+
303
+ # # elif self.current_epoch < 40:
304
+ # # if self.current_epoch % 4 == 0 or 3:
305
+ # # pass
306
+ # # elif self.current_epoch % 4 == 1:
307
+ # # proressive_index = 0.7
308
+ # # else:
309
+ # # proressive_index = 0.8
310
+ # # else:
311
+ # # if self.current_epoch % 4 == 0 or 3:
312
+ # # pass
313
+ # # elif self.current_epoch % 4 == 1 or 2:
314
+ # # proressive_index = 1
315
+
316
+ # # below is the wave-mix data augmentation method
317
+ # if self.progressive:
318
+ # if self.current_epoch < 5:
319
+ # match self.current_epoch:
320
+ # case 0:
321
+ # self.progressive=0
322
+ # case 1:
323
+ # self.progressive=0.2
324
+ # case 2:
325
+ # self.progressive=0.4
326
+ # case 3:
327
+ # self.progressive=0.6
328
+ # case 4:
329
+ # self.progressive=0.8
330
+
331
+ # else :
332
+ # proressive_index = 0.8
333
+
334
+ # if random.random() < proressive_index :
335
+ # mix_sample_idx = random.randint(0, self.amount-1)
336
+ # mix_datum = load_recording(self.files.iloc[mix_sample_idx]['record'])
337
+ # mix_target = self.files.iloc[mix_sample_idx]['target']
338
+ # data, alpha_value = mixup(data, mix_datum)
339
+ # target = alpha_value * target + (1-alpha_value) * mix_target
340
+ # # Binarize the labels based on the threshold
341
+ # target[target >= threshold] = 1
342
+ # target[target < threshold] = 0
343
+
344
+
345
+ # expand to 12 lead setup if original signal has less channels
346
+ data, lead_indicator = expand_leads(data, input_leads=leads)
347
+ data = np.nan_to_num(data)
348
+
349
+
350
+ # resample to 500hz
351
+ if fs == float(1000):
352
+ data = signal.resample_poly(data, up=1, down=2, axis=-1) # to 500Hz
353
+ fs = 500
354
+ elif fs == float(500):
355
+ pass
356
+ else:
357
+ data = signal.resample(data, int(data.shape[1] * 500 / fs), axis=1)
358
+ fs = 500
359
+
360
+ # below is the Butterworth digital and analog filter design.
361
+ data = signal.filtfilt(self.b, self.a, data)
362
+
363
+ '''
364
+ we filter out the sample which the length is 8192 via the below code.
365
+ just like the random shift windows.
366
+ '''
367
+ if self.sample:
368
+ fs = int(fs)
369
+ # random sample signal if len > 8192 samples
370
+ if data.shape[-1] >= 8192:
371
+ idx = data.shape[-1] - 8192-1
372
+ idx = np.random.randint(idx)
373
+ data = data[:, idx:idx + 8192]
374
+
375
+ '''
376
+ we can obtain the average and mean for each lead via below code, 90% conclusion.
377
+ '''
378
+ mu = np.nanmean(data, axis=-1, keepdims=True)
379
+ std = np.nanstd(data, axis=-1, keepdims=True)
380
+
381
+ '''
382
+ Z-Score Normalization
383
+ '''
384
+ #std = np.nanstd(data.flatten())
385
+ with warnings.catch_warnings():
386
+ warnings.simplefilter("ignore")
387
+ data = (data - mu) / std
388
+
389
+ data = np.nan_to_num(data)
390
+
391
+ # random choose number of leads to keep
392
+ data, lead_indicator = lead_exctractor.get(data, self.num_leads, lead_indicator)
393
+
394
+ # print("This is the lead_indicator:(l)", type(lead_indicator), lead_indicator)
395
+
396
+
397
+ return data, target
ecg_dataset_2021.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from tqdm import tqdm
3
+ from helper_code import *
4
+ from scipy import signal
5
+ import pandas as pd
6
+ from skmultilearn.model_selection import iterative_train_test_split
7
+ import random
8
+ import warnings
9
+ import sys
10
+
11
+
12
+ def get_nsamp(header):
13
+ return int(header.split('\n')[0].split(' ')[3])
14
+
15
+ # Adapted from original scoring function code
16
+ # For each set of equivalent classes, replace each class with the representative class for the set.
17
+ def replace_equivalent_classes(classes, equivalent_classes):
18
+ for j, x in enumerate(classes):
19
+ for multiple_classes in equivalent_classes:
20
+ if x in multiple_classes:
21
+ classes[j] = multiple_classes[0] # Use the first class as the representative class.
22
+ return classes
23
+
24
+ '''
25
+ the 'output' will be using the zero-padding if the leads are less than 12.(8 leads)
26
+ the 'output_leads' will be [1. 1. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
27
+
28
+ the output will not be changed if the lead is 12
29
+ the 'output_leads' will be [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
30
+ '''
31
+ def mixup(data, mix_data):
32
+ mix_lambda = np.random.beta(10, 10)# 这里利用beta分布得到一个0到1的值,这里如果可以,尝试别的分布函数
33
+ if data.shape[1] > mix_data.shape[1]:
34
+ data = mix_lambda * data[:, :mix_data.shape[1]] + (1 - mix_lambda) * mix_data
35
+ else:
36
+ data = mix_lambda * data + (1 - mix_lambda) * mix_data[:, :data.shape[1]]
37
+ return data, mix_lambda
38
+
39
+ def cutmix(data, mix_data):
40
+ cutmix_lambda = np.random.beta(10, 10)
41
+
42
+ if data.shape[1] < mix_data.shape[1]:
43
+ seq_length = data.shape[1]
44
+ win_len = int(np.ceil(seq_length * cutmix_lambda))
45
+ start = np.random.randint(0, seq_length - win_len + 1)
46
+ end = start + win_len
47
+ data[:,start:end] = mix_data[:,start:end]
48
+ else:
49
+ seq_length = mix_data.shape[1]
50
+ win_len = int(np.ceil(seq_length * cutmix_lambda))
51
+ start = np.random.randint(0, seq_length - win_len + 1)
52
+ end = start + win_len
53
+ data[:,start:end] = mix_data[:,start:end]
54
+ return data, cutmix_lambda
55
+
56
+ def cutmix_fix_length(data, mix_data):
57
+ # cutmix_lambda = np.random.beta(0.2, 0.2)
58
+ cutmix_lambda = 0.2
59
+
60
+ if data.shape[1] < mix_data.shape[1]:
61
+ seq_length = data.shape[1]
62
+ win_len = int(np.ceil(seq_length * cutmix_lambda))
63
+ start = np.random.randint(0, seq_length - win_len + 1)
64
+ end = start + win_len
65
+ data[:,start:end] = mix_data[:,start:end]
66
+ else:
67
+ seq_length = mix_data.shape[1]
68
+ win_len = int(np.ceil(seq_length * cutmix_lambda))
69
+ start = np.random.randint(0, seq_length - win_len + 1)
70
+ end = start + win_len
71
+ data[:,start:end] = mix_data[:,start:end]
72
+ return data, cutmix_lambda
73
+
74
+ def same_shape_mixup(data, mix_data):
75
+ if data.shape[1] == mix_data.shape[1]:
76
+ mix_lambda = np.random.beta(10, 10)# 这里利用beta分布得到一个0到1的值,这里如果可以,尝试别的分布函数
77
+ data = mix_lambda * data + (1 - mix_lambda) * mix_data
78
+ return data
79
+
80
+ def expand_leads(recording, input_leads):
81
+ output = np.zeros((12, recording.shape[1]))
82
+ # recording.shape[1]: 5000
83
+ twelve_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
84
+ twelve_leads = [k.lower() for k in twelve_leads]
85
+ # ['i', 'ii', 'iii', 'avr', 'avl', 'avf', 'v1', 'v2', 'v3', 'v4', 'v5', 'v6']
86
+
87
+ input_leads = [k.lower() for k in input_leads]
88
+ # Here we can assume:
89
+ # input_leads:I, II, V1, V2, V3, V4, V5, V6,
90
+ # so the new input_leads: ['i', 'ii', 'v1', 'v2', 'v3', 'v4', 'v5', 'v6']
91
+
92
+ output_leads = np.zeros((12,))
93
+ # [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
94
+
95
+ # idx: [0, 1, 6, 7, 8, 9, 10, 11]
96
+ for i,k in enumerate(input_leads):
97
+ idx = twelve_leads.index(k)
98
+ output[idx,:] = recording[i,:]
99
+ output_leads[idx] = 1
100
+
101
+ return output, output_leads
102
+
103
+ '''
104
+ This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
105
+ This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
106
+ This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
107
+ This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
108
+ This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
109
+ This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
110
+
111
+ '''
112
+
113
+ class lead_exctractor:
114
+ """
115
+ used to select specific leads or random choice of configurations
116
+
117
+ Twelve leads: I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5, V6
118
+ Eight leads: I, II, V1, V2, V3, V4, V5, V6
119
+ Six leads: I, II, III, aVR, aVL, aVF
120
+ Four leads: I, II, III, V2
121
+ Three leads: I, II, V2
122
+ Two leads: I, II
123
+
124
+ """
125
+ L2 = np.array([1,1,0,0,0,0,0,0,0,0,0,0])
126
+ L3 = np.array([1,1,0,0,0,0,0,1,0,0,0,0])
127
+ L4 = np.array([1,1,1,0,0,0,0,1,0,0,0,0])
128
+ L6 = np.array([1,1,1,1,1,1,0,0,0,0,0,0])
129
+ L8 = np.array([1,1,0,0,0,0,1,1,1,1,1,1])
130
+ L12 = np.array([1,1,1,1,1,1,1,1,1,1,1,1])
131
+
132
+ @staticmethod
133
+ def get (x, num_leads, lead_indicator):
134
+ if num_leads==None:
135
+ # random choice output
136
+ num_leads = random.choice([12,8,6,4,3,2])
137
+
138
+ if num_leads==12:
139
+ # Twelve leads: I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5, V6
140
+ return x, lead_indicator * lead_exctractor.L12
141
+
142
+ if num_leads==8:
143
+ # Six leads: I, II, V1, V2, V3, V4, V5, V6
144
+ x = x * lead_exctractor.L8.reshape(12,1)
145
+ return x,lead_indicator * lead_exctractor.L8
146
+
147
+ if num_leads==6:
148
+ # Six leads: I, II, III, aVL, aVR, aVF
149
+ x = x * lead_exctractor.L6.reshape(12,1)
150
+ return x,lead_indicator * lead_exctractor.L6
151
+
152
+ if num_leads==4:
153
+ # Six leads: I, II, III, V2
154
+ x = x * lead_exctractor.L4.reshape(12,1)
155
+ return x,lead_indicator * lead_exctractor.L4
156
+
157
+ if num_leads==3:
158
+ # Three leads: I, II, V2
159
+ x = x * lead_exctractor.L3.reshape(12,1)
160
+ return x,lead_indicator * lead_exctractor.L3
161
+
162
+ if num_leads==2:
163
+ # Two leads: II, V5
164
+ x = x * lead_exctractor.L2.reshape(12,1)
165
+ return x,lead_indicator * lead_exctractor.L2
166
+ raise Exception("invalid-leads-number")
167
+
168
+ class dataset:
169
+ classes = ['164889003','164890007','6374002','426627000','733534002',
170
+ '713427006','270492004','713426002','39732003','445118002',
171
+ '164947007','251146004','111975006','698252002','426783006',
172
+ '284470004','10370003','365413008','427172004','164917005',
173
+ '47665007','427393009','426177001','427084000','164934002',
174
+ '59931005']
175
+ normal_class = '426783006'
176
+ equivalent_classes = [['713427006', '59118001'],
177
+ ['284470004', '63593006'],
178
+ ['427172004', '17338001'],
179
+ ['733534002', '164909002']]
180
+ def __init__(self, header_files, Mixup = 0, amount = 0, cutMix = 0, Mixup_no_label_interpolate =0, progressive_switch = False):
181
+ self.files = []
182
+ self.sample = True
183
+ self.num_leads = None
184
+
185
+ for h in tqdm(header_files):
186
+ tmp = dict()
187
+ tmp['header'] = h
188
+ tmp['record'] = h.replace('.hea','.mat')
189
+ hdr = load_header(h)
190
+ tmp['nsamp'] = get_nsamp(hdr)
191
+ tmp['leads'] = get_leads(hdr)
192
+ tmp['age'] = get_age(hdr)
193
+ tmp['sex'] = get_sex(hdr)
194
+ tmp['dx'] = get_labels(hdr)
195
+ tmp['fs'] = get_frequency(hdr)
196
+ tmp['target'] = np.zeros((26,))
197
+ tmp['dx'] = replace_equivalent_classes(tmp['dx'], dataset.equivalent_classes)
198
+
199
+ for dx in tmp['dx']:
200
+ # in SNOMED code is in scored classes
201
+ if dx in dataset.classes:
202
+ idx = dataset.classes.index(dx)
203
+ tmp['target'][idx] = 1
204
+
205
+ self.files.append(tmp)
206
+
207
+ # print("This is the target:", tmp['target'])
208
+ # set filter parameters
209
+
210
+ # Filtering: Data are filtered using a zero-phase method with 3rd order Butterworth bandpass filter
211
+ # with frequency band from 1 Hz to 47 Hz.
212
+ self.b, self.a = signal.butter(3, [1 / 250, 47 / 250], 'bandpass')
213
+
214
+ self.files = pd.DataFrame(self.files)
215
+
216
+ self.Mixup = Mixup
217
+ self.cutMix = cutMix
218
+ self.Mixup_no_label_interplate = Mixup_no_label_interpolate
219
+ self.amount = amount
220
+ self.progressive = progressive_switch
221
+
222
+ self.current_epoch = 0 # Initialize the current epoch
223
+
224
+ def set_epoch(self, epoch):
225
+ self.current_epoch = epoch
226
+
227
+ def train_valid_split(self, test_size):
228
+ '''
229
+ test_size: 0.2
230
+ below is the value of each variable:
231
+ files[0] <- "/media/jiang/ECG/physionet.org/files/challenge-2021/1.0.3/training/python-classifier-2021-main/training_data/chapman_shaoxing/g1/JS00001.hea" shape -> (999, 1)
232
+ targets[1] <-[1. 0. 0. ... 0. 1. 0.] 26, shape->(999, 26)
233
+ x_train[0] <-"/media/jiang/ECG/physionet.org/files/challenge-2021/1.0.3/training/python-classifier-2021-main/training_data/chapman_shaoxing/g1/JS00001.hea" shape -> (799,1)
234
+ x_valid[0] <-"/media/jiang/ECG/physionet.org/files/challenge-2021/1.0.3/training/python-classifier-2021-main/training_data/chapman_shaoxing/g1/JS00651.hea" shape -> (200,1)
235
+ '''
236
+ files = self.files['header'].to_numpy().reshape(-1,1)
237
+ # print(files[0])
238
+ targets = np.stack(self.files['target'].to_list(), axis=0)
239
+ print("This is the targets:", targets)
240
+
241
+ x_train, y_train, x_valid, y_valid = iterative_train_test_split(files, targets, test_size=test_size)
242
+
243
+ train = dataset(header_files=x_train[:,0].tolist())
244
+
245
+ train.num_leads=None
246
+ train.sample=True
247
+
248
+ valid = dataset(header_files=x_valid[:,0].tolist())
249
+
250
+ valid.num_leads=12
251
+ valid.sample=False
252
+
253
+ return train, valid
254
+
255
+ def summary(self, output):
256
+ if output=='pandas':
257
+ # print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
258
+ # print("This is the target:",self.files['target'],type(self.files['target']))
259
+ return pd.Series(np.stack(self.files['target'].to_list(),axis=0).sum(axis=0),index=dataset.classes)
260
+
261
+ if output=='numpy':
262
+ return np.stack(self.files['target'].to_list(),axis=0).sum(axis=0)
263
+
264
+ def Pre_processing(self, fs, leads, data):
265
+
266
+ # 1. Expand to a 12-lead setup if the original signal has fewer channels
267
+ data, lead_indicator = expand_leads(data, input_leads=leads)
268
+ data = np.nan_to_num(data)
269
+
270
+
271
+ # 2. Resample from 1000 Hz and other rates to 500 Hz
272
+ if fs == float(1000):
273
+ data = signal.resample_poly(data, up=1, down=2, axis=-1) # to 500Hz
274
+ fs = 500
275
+ elif fs == float(500):
276
+ pass
277
+ else:
278
+ data = signal.resample(data, int(data.shape[1] * 500 / fs), axis=1)
279
+ fs = 500
280
+
281
+ # 3. Below is the design of the Butterworth digital and analog filter
282
+ data = signal.filtfilt(self.b, self.a, data)
283
+
284
+
285
+ # 4. We compute the average and mean for each lead using the code below; this is Z-score normalization
286
+ mu = np.nanmean(data, axis=-1, keepdims=True)
287
+ std = np.nanstd(data, axis=-1, keepdims=True)
288
+
289
+ #std = np.nanstd(data.flatten())
290
+ with warnings.catch_warnings():
291
+ warnings.simplefilter("ignore")
292
+ data = (data - mu) / std
293
+
294
+ data = np.nan_to_num(data)
295
+
296
+ # 5. Selection of leads for 12-lead ECGs; the default is 12 leads
297
+ data, lead_indicator = lead_exctractor.get(data, self.num_leads, lead_indicator)
298
+
299
+
300
+ # 6. We filter out samples with a length of 8192 using the code below, similar to random shift windows, and zero-padding as well
301
+ if self.sample:
302
+ fs = int(fs)
303
+ # random sample signal if len > 8192 samples
304
+ if data.shape[-1] >= 8192:
305
+ idx = data.shape[-1] - 8192-1
306
+ idx = np.random.randint(idx)
307
+ data = data[:, idx:idx + 8192]
308
+ else:
309
+ # Apply right zero-padding along the second dimension
310
+ padding_size = 8192 - data.shape[-1]
311
+ data = np.pad(data, ((0, 0), (0, padding_size)), mode='constant', constant_values=0)
312
+
313
+ return data
314
+
315
+
316
+ def __len__(self):
317
+ return len(self.files)
318
+
319
+ '''
320
+ fs: 500.0 sampling rate
321
+ target: [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
322
+ leads: ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
323
+ self.files.iloc[item]['record']: ../../python-classifier-2021-main/training_data/chapman_shaoxing/g1/JS00518.mat
324
+ data: shape->(12 X 5000)
325
+ '''
326
+
327
+ def __getitem__(self, item):
328
+
329
+ fs = self.files.iloc[item]['fs']
330
+ target = self.files.iloc[item]['target']
331
+ leads = self.files.iloc[item]['leads']
332
+ data = load_recording(self.files.iloc[item]['record'])
333
+
334
+ # the phase of pre-processing:
335
+ data = self.Pre_processing(fs, leads, data)
336
+
337
+ '''
338
+ below are the data augmentation methods
339
+ '''
340
+ if random.random() < self.Mixup :
341
+ mix_sample_idx = random.randint(0, self.amount-1)
342
+ mix_datum = load_recording(self.files.iloc[mix_sample_idx]['record'])
343
+ mix_target = self.files.iloc[mix_sample_idx]['target']
344
+
345
+ mix_fs = self.files.iloc[mix_sample_idx]['fs']
346
+ mix_leads = self.files.iloc[mix_sample_idx]['leads']
347
+ mix_datum = self.Pre_processing(mix_fs, mix_leads, mix_datum)
348
+
349
+ data, alpha_value = mixup(data, mix_datum)
350
+ target = alpha_value * target + (1-alpha_value) * mix_target
351
+
352
+
353
+ if random.random() < self.cutMix:
354
+ mix_sample_idx = random.randint(0, self.amount-1)
355
+ mix_datum = load_recording(self.files.iloc[mix_sample_idx]['record'])
356
+ mix_target = self.files.iloc[mix_sample_idx]['target']
357
+
358
+ mix_fs = self.files.iloc[mix_sample_idx]['fs']
359
+ mix_leads = self.files.iloc[mix_sample_idx]['leads']
360
+ mix_datum = self.Pre_processing(mix_fs, mix_leads, mix_datum)
361
+
362
+ data, alpha_value = cutmix(data, mix_datum)
363
+ target = alpha_value * target + (1-alpha_value) * mix_target
364
+
365
+
366
+ progressive_index = 0
367
+
368
+ # # below is the progessive data augmentation method
369
+ # if self.progressive:
370
+ # if self.current_epoch < 20:
371
+ # if self.current_epoch % 4 == 0 or 3:
372
+ # pass
373
+ # elif self.current_epoch % 4 == 1:
374
+ # proressive_index = 0.2
375
+ # else:
376
+ # proressive_index = 0.5
377
+
378
+ # elif self.current_epoch < 40:
379
+ # if self.current_epoch % 4 == 0 or 3:
380
+ # pass
381
+ # elif self.current_epoch % 4 == 1:
382
+ # proressive_index = 0.7
383
+ # else:
384
+ # proressive_index = 0.8
385
+ # else:
386
+ # if self.current_epoch % 4 == 0 or 3:
387
+ # pass
388
+ # elif self.current_epoch % 4 == 1 or 2:
389
+ # proressive_index = 1
390
+
391
+ # below is the wave-mix data augmentation method
392
+ if self.progressive:
393
+ if self.current_epoch < 5:
394
+ match self.current_epoch:
395
+ case 0:
396
+ progressive_index=0.8
397
+ case 1:
398
+ progressive_index=0.8
399
+ case 2:
400
+ progressive_index=0.8
401
+ case 3:
402
+ progressive_index=0.8
403
+ case 4:
404
+ progressive_index=0.8
405
+
406
+ else :
407
+ progressive_index = 0.8
408
+
409
+ if random.random() < progressive_index :
410
+ mix_sample_idx = random.randint(0, self.amount-1)
411
+ mix_datum = load_recording(self.files.iloc[mix_sample_idx]['record'])
412
+ mix_target = self.files.iloc[mix_sample_idx]['target']
413
+
414
+ mix_fs = self.files.iloc[mix_sample_idx]['fs']
415
+ mix_leads = self.files.iloc[mix_sample_idx]['leads']
416
+ mix_datum = self.Pre_processing(mix_fs, mix_leads, mix_datum)
417
+
418
+ data, alpha_value = mixup(data, mix_datum)
419
+ target = alpha_value * target + (1-alpha_value) * mix_target
420
+
421
+
422
+ return data, target
ecg_dataset_2021_DAFirst.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from tqdm import tqdm
3
+ from helper_code import *
4
+ from scipy import signal
5
+ import pandas as pd
6
+ from skmultilearn.model_selection import iterative_train_test_split
7
+ import random
8
+ import warnings
9
+ import sys
10
+
11
+
12
+ def get_nsamp(header):
13
+ return int(header.split('\n')[0].split(' ')[3])
14
+
15
+ # Adapted from original scoring function code
16
+ # For each set of equivalent classes, replace each class with the representative class for the set.
17
+ def replace_equivalent_classes(classes, equivalent_classes):
18
+ for j, x in enumerate(classes):
19
+ for multiple_classes in equivalent_classes:
20
+ if x in multiple_classes:
21
+ classes[j] = multiple_classes[0] # Use the first class as the representative class.
22
+ return classes
23
+
24
+ '''
25
+ the 'output' will be using the zero-padding if the leads are less than 12.(8 leads)
26
+ the 'output_leads' will be [1. 1. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
27
+
28
+ the output will not be changed if the lead is 12
29
+ the 'output_leads' will be [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
30
+ '''
31
+ def mixup(data, mix_data):
32
+ mix_lambda = np.random.beta(10, 10)# 这里利用beta分布得到一个0到1的值,这里如果可以,尝试别的分布函数
33
+ if data.shape[1] > mix_data.shape[1]:
34
+ data = mix_lambda * data[:, :mix_data.shape[1]] + (1 - mix_lambda) * mix_data
35
+ else:
36
+ data = mix_lambda * data + (1 - mix_lambda) * mix_data[:, :data.shape[1]]
37
+ return data, mix_lambda
38
+
39
+ def cutmix(data, mix_data):
40
+ cutmix_lambda = np.random.beta(10, 10)
41
+
42
+ if data.shape[1] < mix_data.shape[1]:
43
+ seq_length = data.shape[1]
44
+ win_len = int(np.ceil(seq_length * cutmix_lambda))
45
+ start = np.random.randint(0, seq_length - win_len + 1)
46
+ end = start + win_len
47
+ data[:,start:end] = mix_data[:,start:end]
48
+ else:
49
+ seq_length = mix_data.shape[1]
50
+ win_len = int(np.ceil(seq_length * cutmix_lambda))
51
+ start = np.random.randint(0, seq_length - win_len + 1)
52
+ end = start + win_len
53
+ data[:,start:end] = mix_data[:,start:end]
54
+ return data, cutmix_lambda
55
+
56
+ def cutmix_fix_length(data, mix_data):
57
+ # cutmix_lambda = np.random.beta(0.2, 0.2)
58
+ cutmix_lambda = 0.2
59
+
60
+ if data.shape[1] < mix_data.shape[1]:
61
+ seq_length = data.shape[1]
62
+ win_len = int(np.ceil(seq_length * cutmix_lambda))
63
+ start = np.random.randint(0, seq_length - win_len + 1)
64
+ end = start + win_len
65
+ data[:,start:end] = mix_data[:,start:end]
66
+ else:
67
+ seq_length = mix_data.shape[1]
68
+ win_len = int(np.ceil(seq_length * cutmix_lambda))
69
+ start = np.random.randint(0, seq_length - win_len + 1)
70
+ end = start + win_len
71
+ data[:,start:end] = mix_data[:,start:end]
72
+ return data, cutmix_lambda
73
+
74
+ def same_shape_mixup(data, mix_data):
75
+ if data.shape[1] == mix_data.shape[1]:
76
+ mix_lambda = np.random.beta(10, 10)# 这里利用beta分布得到一个0到1的值,这里如果可以,尝试别的分布函数
77
+ data = mix_lambda * data + (1 - mix_lambda) * mix_data
78
+ return data
79
+
80
+ def expand_leads(recording, input_leads):
81
+ output = np.zeros((12, recording.shape[1]))
82
+ # recording.shape[1]: 5000
83
+ twelve_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
84
+ twelve_leads = [k.lower() for k in twelve_leads]
85
+ # ['i', 'ii', 'iii', 'avr', 'avl', 'avf', 'v1', 'v2', 'v3', 'v4', 'v5', 'v6']
86
+
87
+ input_leads = [k.lower() for k in input_leads]
88
+ # Here we can assume:
89
+ # input_leads:I, II, V1, V2, V3, V4, V5, V6,
90
+ # so the new input_leads: ['i', 'ii', 'v1', 'v2', 'v3', 'v4', 'v5', 'v6']
91
+
92
+ output_leads = np.zeros((12,))
93
+ # [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
94
+
95
+ # idx: [0, 1, 6, 7, 8, 9, 10, 11]
96
+ for i,k in enumerate(input_leads):
97
+ idx = twelve_leads.index(k)
98
+ output[idx,:] = recording[i,:]
99
+ output_leads[idx] = 1
100
+
101
+ return output, output_leads
102
+
103
+ '''
104
+ This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
105
+ This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
106
+ This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
107
+ This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
108
+ This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
109
+ This is the lead_indicator:(l) <class 'numpy.ndarray'> [1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
110
+
111
+ '''
112
+
113
+ class lead_exctractor:
114
+ """
115
+ used to select specific leads or random choice of configurations
116
+
117
+ Twelve leads: I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5, V6
118
+ Eight leads: I, II, V1, V2, V3, V4, V5, V6
119
+ Six leads: I, II, III, aVR, aVL, aVF
120
+ Four leads: I, II, III, V2
121
+ Three leads: I, II, V2
122
+ Two leads: I, II
123
+
124
+ """
125
+ L2 = np.array([1,1,0,0,0,0,0,0,0,0,0,0])
126
+ L3 = np.array([1,1,0,0,0,0,0,1,0,0,0,0])
127
+ L4 = np.array([1,1,1,0,0,0,0,1,0,0,0,0])
128
+ L6 = np.array([1,1,1,1,1,1,0,0,0,0,0,0])
129
+ L8 = np.array([1,1,0,0,0,0,1,1,1,1,1,1])
130
+ L12 = np.array([1,1,1,1,1,1,1,1,1,1,1,1])
131
+
132
+ @staticmethod
133
+ def get (x, num_leads, lead_indicator):
134
+ if num_leads==None:
135
+ # random choice output
136
+ num_leads = random.choice([12,8,6,4,3,2])
137
+
138
+ if num_leads==12:
139
+ # Twelve leads: I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5, V6
140
+ return x, lead_indicator * lead_exctractor.L12
141
+
142
+ if num_leads==8:
143
+ # Six leads: I, II, V1, V2, V3, V4, V5, V6
144
+ x = x * lead_exctractor.L8.reshape(12,1)
145
+ return x,lead_indicator * lead_exctractor.L8
146
+
147
+ if num_leads==6:
148
+ # Six leads: I, II, III, aVL, aVR, aVF
149
+ x = x * lead_exctractor.L6.reshape(12,1)
150
+ return x,lead_indicator * lead_exctractor.L6
151
+
152
+ if num_leads==4:
153
+ # Six leads: I, II, III, V2
154
+ x = x * lead_exctractor.L4.reshape(12,1)
155
+ return x,lead_indicator * lead_exctractor.L4
156
+
157
+ if num_leads==3:
158
+ # Three leads: I, II, V2
159
+ x = x * lead_exctractor.L3.reshape(12,1)
160
+ return x,lead_indicator * lead_exctractor.L3
161
+
162
+ if num_leads==2:
163
+ # Two leads: II, V5
164
+ x = x * lead_exctractor.L2.reshape(12,1)
165
+ return x,lead_indicator * lead_exctractor.L2
166
+ raise Exception("invalid-leads-number")
167
+
168
+ class dataset:
169
+ classes = ['164889003','164890007','6374002','426627000','733534002',
170
+ '713427006','270492004','713426002','39732003','445118002',
171
+ '164947007','251146004','111975006','698252002','426783006',
172
+ '284470004','10370003','365413008','427172004','164917005',
173
+ '47665007','427393009','426177001','427084000','164934002',
174
+ '59931005']
175
+ normal_class = '426783006'
176
+ equivalent_classes = [['713427006', '59118001'],
177
+ ['284470004', '63593006'],
178
+ ['427172004', '17338001'],
179
+ ['733534002', '164909002']]
180
+ def __init__(self, header_files, Mixup = 0, amount = 0, cutMix = 0, Mixup_no_label_interpolate =0, progressive_switch = False):
181
+ self.files = []
182
+ self.sample = True
183
+ self.num_leads = None
184
+
185
+ for h in tqdm(header_files):
186
+ tmp = dict()
187
+ tmp['header'] = h
188
+ tmp['record'] = h.replace('.hea','.mat')
189
+ hdr = load_header(h)
190
+ tmp['nsamp'] = get_nsamp(hdr)
191
+ tmp['leads'] = get_leads(hdr)
192
+ tmp['age'] = get_age(hdr)
193
+ tmp['sex'] = get_sex(hdr)
194
+ tmp['dx'] = get_labels(hdr)
195
+ tmp['fs'] = get_frequency(hdr)
196
+ tmp['target'] = np.zeros((26,))
197
+ tmp['dx'] = replace_equivalent_classes(tmp['dx'], dataset.equivalent_classes)
198
+
199
+ for dx in tmp['dx']:
200
+ # in SNOMED code is in scored classes
201
+ if dx in dataset.classes:
202
+ idx = dataset.classes.index(dx)
203
+ tmp['target'][idx] = 1
204
+
205
+ self.files.append(tmp)
206
+
207
+ # print("This is the target:", tmp['target'])
208
+ # set filter parameters
209
+
210
+ # Filtering: Data are filtered using a zero-phase method with 3rd order Butterworth bandpass filter
211
+ # with frequency band from 1 Hz to 47 Hz.
212
+ self.b, self.a = signal.butter(3, [1 / 250, 47 / 250], 'bandpass')
213
+
214
+ self.files = pd.DataFrame(self.files)
215
+
216
+ self.Mixup = Mixup
217
+ self.cutMix = cutMix
218
+ self.Mixup_no_label_interplate = Mixup_no_label_interpolate
219
+ self.amount = amount
220
+ self.progressive = progressive_switch
221
+
222
+ self.current_epoch = 0 # Initialize the current epoch
223
+
224
+ def set_epoch(self, epoch):
225
+ self.current_epoch = epoch
226
+
227
+ def train_valid_split(self, test_size):
228
+ '''
229
+ test_size: 0.2
230
+ below is the value of each variable:
231
+ files[0] <- "/media/jiang/ECG/physionet.org/files/challenge-2021/1.0.3/training/python-classifier-2021-main/training_data/chapman_shaoxing/g1/JS00001.hea" shape -> (999, 1)
232
+ targets[1] <-[1. 0. 0. ... 0. 1. 0.] 26, shape->(999, 26)
233
+ x_train[0] <-"/media/jiang/ECG/physionet.org/files/challenge-2021/1.0.3/training/python-classifier-2021-main/training_data/chapman_shaoxing/g1/JS00001.hea" shape -> (799,1)
234
+ x_valid[0] <-"/media/jiang/ECG/physionet.org/files/challenge-2021/1.0.3/training/python-classifier-2021-main/training_data/chapman_shaoxing/g1/JS00651.hea" shape -> (200,1)
235
+ '''
236
+ files = self.files['header'].to_numpy().reshape(-1,1)
237
+ # print(files[0])
238
+ targets = np.stack(self.files['target'].to_list(), axis=0)
239
+ print("This is the targets:", targets)
240
+
241
+ x_train, y_train, x_valid, y_valid = iterative_train_test_split(files, targets, test_size=test_size)
242
+
243
+ train = dataset(header_files=x_train[:,0].tolist())
244
+
245
+ train.num_leads=None
246
+ train.sample=True
247
+
248
+ valid = dataset(header_files=x_valid[:,0].tolist())
249
+
250
+ valid.num_leads=12
251
+ valid.sample=False
252
+
253
+ return train, valid
254
+
255
+ def summary(self, output):
256
+ if output=='pandas':
257
+ # print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
258
+ # print("This is the target:",self.files['target'],type(self.files['target']))
259
+ return pd.Series(np.stack(self.files['target'].to_list(),axis=0).sum(axis=0),index=dataset.classes)
260
+
261
+ if output=='numpy':
262
+ return np.stack(self.files['target'].to_list(),axis=0).sum(axis=0)
263
+
264
+ def __len__(self):
265
+ return len(self.files)
266
+
267
+ '''
268
+ fs: 500.0 sampling rate
269
+ target: [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
270
+ leads: ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
271
+ self.files.iloc[item]['record']: ../../python-classifier-2021-main/training_data/chapman_shaoxing/g1/JS00518.mat
272
+ data: shape->(12 X 5000)
273
+ '''
274
+
275
+ def __getitem__(self, item):
276
+
277
+ fs = self.files.iloc[item]['fs']
278
+ target = self.files.iloc[item]['target']
279
+ leads = self.files.iloc[item]['leads']
280
+ data = load_recording(self.files.iloc[item]['record'])
281
+
282
+ # print("This is the target:", target)
283
+ # Set your threshold
284
+
285
+ # threshold = 0.5
286
+ # print("This is the original target:", target)
287
+ # print("This is the type of original target:", type(target))
288
+
289
+ if random.random() < self.Mixup :
290
+ mix_sample_idx = random.randint(0, self.amount-1)
291
+ mix_datum = load_recording(self.files.iloc[mix_sample_idx]['record'])
292
+ mix_target = self.files.iloc[mix_sample_idx]['target']
293
+ data, alpha_value = mixup(data, mix_datum)
294
+ target = alpha_value * target + (1-alpha_value) * mix_target
295
+ # Binarize the labels based on the threshold
296
+ # target[target >= threshold] = 1
297
+ # target[target < threshold] = 0
298
+
299
+
300
+ if random.random() < self.cutMix:
301
+ mix_sample_idx = random.randint(0, self.amount-1)
302
+ mix_datum = load_recording(self.files.iloc[mix_sample_idx]['record'])
303
+ mix_target = self.files.iloc[mix_sample_idx]['target']
304
+ data, alpha_value = cutmix_fix_length(data, mix_datum)
305
+ target = alpha_value * target + (1-alpha_value) * mix_target
306
+ # target[target >= threshold] = 1
307
+ # target[target < threshold] = 0
308
+
309
+ if random.random()< self.Mixup_no_label_interplate:
310
+ mix_sample_idx = random.randint(0, self.amount-1)
311
+ mix_datum = load_recording(self.files.iloc[mix_sample_idx]['record'])
312
+ data= same_shape_mixup(data, mix_datum)
313
+
314
+ progressive_index = 0
315
+
316
+ # # below is the progessive data augmentation method
317
+ # if self.progressive:
318
+ # if self.current_epoch < 20:
319
+ # if self.current_epoch % 4 == 0 or 3:
320
+ # pass
321
+ # elif self.current_epoch % 4 == 1:
322
+ # proressive_index = 0.2
323
+ # else:
324
+ # proressive_index = 0.5
325
+
326
+ # elif self.current_epoch < 40:
327
+ # if self.current_epoch % 4 == 0 or 3:
328
+ # pass
329
+ # elif self.current_epoch % 4 == 1:
330
+ # proressive_index = 0.7
331
+ # else:
332
+ # proressive_index = 0.8
333
+ # else:
334
+ # if self.current_epoch % 4 == 0 or 3:
335
+ # pass
336
+ # elif self.current_epoch % 4 == 1 or 2:
337
+ # proressive_index = 1
338
+
339
+ # below is the wave-mix data augmentation method
340
+ if self.progressive:
341
+ if self.current_epoch < 5:
342
+ match self.current_epoch:
343
+ case 0:
344
+ progressive_index=0.2
345
+ case 1:
346
+ progressive_index=0.4
347
+ case 2:
348
+ progressive_index=0.6
349
+ case 3:
350
+ progressive_index=0.8
351
+ case 4:
352
+ progressive_index=0.8
353
+
354
+ else :
355
+ progressive_index = 0.8
356
+
357
+ if random.random() < progressive_index :
358
+ mix_sample_idx = random.randint(0, self.amount-1)
359
+ mix_datum = load_recording(self.files.iloc[mix_sample_idx]['record'])
360
+ mix_target = self.files.iloc[mix_sample_idx]['target']
361
+ data, alpha_value = mixup(data, mix_datum)
362
+ target = alpha_value * target + (1-alpha_value) * mix_target
363
+ # Binarize the labels based on the threshold
364
+ # target[target >= threshold] = 1
365
+ # target[target < threshold] = 0
366
+
367
+
368
+ # expand to 12 lead setup if original signal has less channels
369
+ data, lead_indicator = expand_leads(data, input_leads=leads)
370
+ data = np.nan_to_num(data)
371
+
372
+
373
+ # resample to 500hz
374
+ if fs == float(1000):
375
+ data = signal.resample_poly(data, up=1, down=2, axis=-1) # to 500Hz
376
+ fs = 500
377
+ elif fs == float(500):
378
+ pass
379
+ else:
380
+ data = signal.resample(data, int(data.shape[1] * 500 / fs), axis=1)
381
+ fs = 500
382
+
383
+ # below is the Butterworth digital and analog filter design.
384
+ data = signal.filtfilt(self.b, self.a, data)
385
+
386
+ '''
387
+ we filter out the sample which the length is 8192 via the below code.
388
+ just like the random shift windows.
389
+ '''
390
+ if self.sample:
391
+ fs = int(fs)
392
+ # random sample signal if len > 8192 samples
393
+ if data.shape[-1] >= 8192:
394
+ idx = data.shape[-1] - 8192-1
395
+ idx = np.random.randint(idx)
396
+ data = data[:, idx:idx + 8192]
397
+
398
+ '''
399
+ Z-Score Normalization
400
+ '''
401
+ mu = np.nanmean(data, axis=-1, keepdims=True)
402
+ std = np.nanstd(data, axis=-1, keepdims=True)
403
+
404
+
405
+ #std = np.nanstd(data.flatten())
406
+ with warnings.catch_warnings():
407
+ warnings.simplefilter("ignore")
408
+ data = (data - mu) / std
409
+
410
+ data = np.nan_to_num(data)
411
+
412
+ # random choose number of leads to keep
413
+ data, lead_indicator = lead_exctractor.get(data, self.num_leads, lead_indicator)
414
+
415
+ # print("This is the lead_indicator:(l)", type(lead_indicator), lead_indicator)
416
+
417
+
418
+ return data, target
engine_ecg_2020.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2015-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ """
4
+ Train and eval functions used in main.py
5
+ """
6
+ import math
7
+ import sys
8
+ from typing import Iterable, Optional
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+
13
+
14
+ import timm
15
+ from timm.data import Mixup
16
+ from timm.utils import accuracy, ModelEma
17
+
18
+ from losses import DistillationLoss
19
+ import utils
20
+
21
+ from sklearn.metrics import average_precision_score, roc_auc_score, f1_score
22
+ from sklearn.metrics import hamming_loss
23
+ from sklearn.metrics import accuracy_score
24
+
25
+ import numpy as np
26
+
27
+ from evaluate_12ECG_score import load_weights, compute_challenge_metric
28
+
29
+ classes = sorted(['270492004', '164889003', '164890007', '426627000', '713427006',
30
+ '713426002', '445118002', '39732003', '164909002', '251146004',
31
+ '698252002', '10370003', '284470004', '427172004', '164947007',
32
+ '111975006', '164917005', '47665007', '59118001', '427393009',
33
+ '426177001', '426783006', '427084000', '63593006', '164934002',
34
+ '59931005', '17338001'])
35
+
36
+ def normalize_model_outputs(model_outputs):
37
+ """
38
+ Normalize model outputs to the range [0, 1].
39
+
40
+ Parameters:
41
+ model_outputs (numpy.ndarray): The raw output of the deep learning model.
42
+
43
+ Returns:
44
+ numpy.ndarray: The normalized model outputs.
45
+ """
46
+ a = model_outputs.min()
47
+ b = model_outputs.max()
48
+ return (model_outputs - a) / (b - a)
49
+
50
+ weights_file = "./weights_2020.csv"
51
+ normal_class = '426783006'
52
+ weights = load_weights(weights_file, classes)
53
+
54
+ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
55
+ data_loader: Iterable, optimizer: torch.optim.Optimizer,
56
+ device: torch.device, epoch: int, max_norm: float = 0,
57
+ model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
58
+ set_training_mode=True, args = None):
59
+
60
+ model.train(set_training_mode)
61
+ metric_logger = utils.MetricLogger(delimiter=" ")
62
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
63
+ header = 'Epoch: [{}]'.format(epoch)
64
+ print_freq = 600
65
+
66
+
67
+
68
+ output_list = []
69
+ target_list = []
70
+
71
+ train_auprc_list = []
72
+
73
+ loss_value_after_each_epcoh = 0
74
+ batch_num = 0
75
+
76
+ # debug
77
+ # count = 0
78
+ # for samples, targets, mix_target, alpha_value in metric_logger.log_every(data_loader, print_freq, header):
79
+
80
+ # third_party = 0
81
+
82
+
83
+ for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
84
+ # count += 1
85
+ # if count > 20:
86
+ # break
87
+ batch_num += 1
88
+ # print("code goes here......................*******************************************************")
89
+
90
+ samples = samples.to(device, non_blocking=True)
91
+ targets = targets.to(device, non_blocking=True)
92
+ # mix_target = mix_target.to(device, non_blocking=True)
93
+ # alpha_value = alpha_value.to(device, non_blocking=True)
94
+
95
+
96
+ outputs = model(samples.float(), if_random_cls_token_position=args.if_random_cls_token_position, if_random_token_rank=args.if_random_token_rank)
97
+
98
+ # below is original loss
99
+ loss = criterion(outputs, targets.float())
100
+ # # below is the asymmetric loss training
101
+ # outputs = torch.sigmoid(outputs)
102
+ # loss = -torch.mean(targets * F.logsigmoid(outputs) + (1 - targets) * F.logsigmoid(-outputs) * 0.1)
103
+
104
+ loss_value = loss.item()
105
+
106
+ if args.lrschedule == "Noam":
107
+ optimizer.optimizer.zero_grad()
108
+ elif args.lrschedule == "CosineAnnealing":
109
+ optimizer.zero_grad()
110
+ else:
111
+ raise ValueError(f"No matching condition for value in the file of engine_ecg_2021.py : {args.lrschedule}")
112
+
113
+ loss.backward() # Backward pass: Compute gradient of the loss with respect to model parameters
114
+
115
+ # Update parameters/using the Noam
116
+ optimizer.step()
117
+
118
+ torch.cuda.synchronize()
119
+
120
+ metric_logger.update(loss=loss_value)
121
+
122
+ if args.lrschedule == "Noam":
123
+ metric_logger.update(lr=optimizer.optimizer.param_groups[0]["lr"])
124
+ elif args.lrschedule == "CosineAnnealing":
125
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
126
+ else:
127
+ raise ValueError(f"No matching condition for value in the file of engine_ecg_2021.py for updating : {args.lrschedule}")
128
+
129
+ target_list.append(targets.data.cpu().numpy())
130
+ output_list.append(outputs.data.cpu().numpy())
131
+ loss_value_after_each_epcoh += loss_value
132
+
133
+
134
+ # gather the stats from all processes
135
+ metric_logger.synchronize_between_processes()
136
+ print("Averaged stats:", metric_logger)
137
+
138
+ # below code is for record the AUPRC of training
139
+ targets_all = np.concatenate(target_list, axis=0)
140
+ outputs_all = np.concatenate(output_list, axis=0)
141
+ outputs_all = normalize_model_outputs(outputs_all)
142
+
143
+ threshold = 0.5
144
+ targets_all[targets_all >= threshold] = 1
145
+ targets_all[targets_all < threshold] = 0
146
+
147
+ # When using the function below, y_true must be a binarized value.
148
+ train_auprc = average_precision_score(y_true = targets_all, y_score = outputs_all)
149
+ print("This is the training AUPRC:", train_auprc)
150
+
151
+ ### The code below is for obtaining the thresholds
152
+ scores_challengeScore = []
153
+ scores_F1 = []
154
+ scores_SubsetAccuracy = []
155
+ scores_HammingLoss = []
156
+
157
+ for thr in np.arange(0., 1., 0.02):
158
+ outputs_dyn = np.array([[(1 if prob > thr else 0) for prob in probs] for probs in np.array(outputs_all)])
159
+
160
+ challenge_value = compute_challenge_metric(weights, targets_all, outputs_dyn, classes, normal_class)
161
+ scores_challengeScore.append(challenge_value)
162
+
163
+ f1 = f1_score(targets_all, outputs_dyn, average='weighted')
164
+ scores_F1.append(f1)
165
+
166
+ subset_accuracy = accuracy_score(targets_all, outputs_dyn)
167
+ scores_SubsetAccuracy.append(subset_accuracy)
168
+
169
+ hamming = hamming_loss(targets_all, outputs_dyn)
170
+ scores_HammingLoss.append(hamming)
171
+
172
+ scores_challengeScore = np.array(scores_challengeScore)
173
+ scores_F1 = np.array(scores_F1)
174
+ scores_SubsetAccuracy = np.array(scores_SubsetAccuracy)
175
+ scores_HammingLoss = np.array(scores_HammingLoss)
176
+
177
+ # print("This is the challenge score list from training set:\n", scores)
178
+
179
+ # Best thrs for challenge score
180
+ thrs_CHALL = np.array([np.argmax(scores_challengeScore, axis=0)*0.02])
181
+ print("This is the best threshold for the challenge score from training set", thrs_CHALL)
182
+ outputs_best_CHALL = np.array([[(1 if prob > thrs_CHALL else 0) for prob in probs] for probs in np.array(outputs_all)])
183
+ challenge_value = compute_challenge_metric(weights, targets_all, outputs_best_CHALL, classes, normal_class)
184
+ print("This is the challenge score from training set:", challenge_value)
185
+
186
+ # Best thrs for F1
187
+ scores_F1 = np.array([np.argmax(scores_F1, axis=0)*0.02])
188
+ print("This is the best threshold for the F1 from training set", scores_F1)
189
+ outputs_best_f1 = np.array([[(1 if prob > scores_F1 else 0) for prob in probs] for probs in np.array(outputs_all)])
190
+ f1 = f1_score(targets_all, outputs_best_f1, average='weighted')
191
+ print("This is the f1 score from training set:", challenge_value)
192
+
193
+ # Best thrs for subset accuracy
194
+ scores_SubsetAccuracy = np.array([np.argmax(scores_SubsetAccuracy, axis=0)*0.02])
195
+ print("This is the best threshold for the Subset Accuracy from training set", scores_SubsetAccuracy)
196
+ outputs_best_SubsetAccuracy = np.array([[(1 if prob > scores_SubsetAccuracy else 0) for prob in probs] for probs in np.array(outputs_all)])
197
+ subset_accuracy = accuracy_score(targets_all, outputs_best_SubsetAccuracy)
198
+ print("This is the subset accuracy from training set:", subset_accuracy)
199
+
200
+ # Best thrs for hamming loss, here is the loss value from output, we need to find the minimum value.
201
+ # Determine the optimal threshold for Hamming loss. The loss values are obtained from the output, and the goal is to find the minimum value.
202
+ scores_HammingLoss = np.array([np.argmin(scores_HammingLoss, axis=0)*0.02])
203
+ print("This is the best threshold for the Subset Accuracy from training set", scores_HammingLoss)
204
+ outputs_best_HammingLoss = np.array([[(1 if prob > scores_HammingLoss else 0) for prob in probs] for probs in np.array(outputs_all)])
205
+ hamming = hamming_loss(targets_all, outputs_best_HammingLoss)
206
+ print("This is the Hamming from training set:", hamming)
207
+
208
+ loss_value_after_each_epcoh /= batch_num
209
+
210
+ return train_auprc, loss_value_after_each_epcoh, thrs_CHALL, scores_F1, scores_SubsetAccuracy, scores_HammingLoss
211
+
212
+
213
+
214
+ @torch.no_grad()
215
+ def evaluate(data_loader, model, device):
216
+ # criterion = torch.nn.CrossEntropyLoss()
217
+ criterion = torch.nn.BCEWithLogitsLoss()
218
+
219
+ metric_logger = utils.MetricLogger(delimiter=" ")
220
+ header = 'Test:'
221
+
222
+ targets = []
223
+ outputs = []
224
+ loss_value_after_each_epcoh = 0
225
+ batch_num = 0
226
+ # switch to evaluation mode
227
+ model.eval()
228
+
229
+ for images, target in metric_logger.log_every(data_loader, 100, header):
230
+
231
+ batch_num += 1
232
+ images = images.to(device, non_blocking=True)
233
+ target = target.to(device, non_blocking=True)
234
+ # print("NaN in images(input):", torch.isnan(images).any())
235
+
236
+ # compute output
237
+
238
+ output = model(images.float())
239
+ # below is original loss
240
+ loss = criterion(output, target.float())
241
+ # below is the logit to proba
242
+
243
+ # below is the asymmetric loss testing
244
+ # output = torch.sigmoid(output)
245
+ # loss = -torch.mean(target * F.logsigmoid(output) + (1 - target) * F.logsigmoid(-output) * 0.1)
246
+
247
+ # print("This is the output:", output.shape,output)
248
+ # print("This is the target.float():",target.float().shape ,target.float())
249
+ # print("NaN in output:", torch.isnan(output).any())
250
+ # print("NaN in target:", torch.isnan(target).any())
251
+ # print("This is the loss.item():", loss.item())
252
+ # sys.exit()
253
+
254
+ metric_logger.update(loss=loss.item())
255
+
256
+ targets.append(target.data.cpu().numpy())
257
+ outputs.append(output.data.cpu().numpy())
258
+
259
+ loss_value_after_each_epcoh += loss.item()
260
+
261
+ # below code is for record the AUPRC of validation
262
+ targets = np.concatenate(targets, axis=0)
263
+ outputs = np.concatenate(outputs, axis=0)
264
+
265
+ outputs = normalize_model_outputs(outputs)
266
+
267
+ auprc = average_precision_score(y_true=targets, y_score=outputs)
268
+ auroc = roc_auc_score(targets, outputs)
269
+
270
+ # print("This is the top 3 row:", outputs[0:3])
271
+ # outputs = np.array([[(1 if prob > 0 else 0) for prob in probs] for probs in np.array(outputs)])
272
+ # print("This is the top 3 row, after:", outputs[0:3])
273
+ # for i in range(len(outputs)):
274
+ # num_ones = targets[i].count(1)
275
+ # third_largest = sorted(outputs[i], reverse=True)[num_ones-1]
276
+
277
+ # print(f"Length of targets: {len(targets)}")
278
+ # print(f"Length of outputs: {len(outputs)}")
279
+
280
+ outputs_0 = np.array([[(1 if prob > 0.5 else 0) for prob in probs] for probs in np.array(outputs)])
281
+
282
+ f1 = f1_score(targets, outputs_0, average='weighted')
283
+ hamming = hamming_loss(targets, outputs_0)
284
+ subset_accuracy = accuracy_score(targets, outputs_0)
285
+
286
+ scores = []
287
+ for thr in np.arange(0., 1., 0.02):
288
+ outputs_dyn = np.array([[(1 if prob > thr else 0) for prob in probs] for probs in np.array(outputs)])
289
+ challenge_value = compute_challenge_metric(weights, targets, outputs_dyn, classes, normal_class)
290
+ scores.append(challenge_value)
291
+
292
+ scores = np.array(scores)
293
+ print("This is the challenge score list from testing set:\n", scores)
294
+
295
+ # Best thrs and preds
296
+ idxs = np.argmax(scores, axis=0)
297
+ thrs = np.array([idxs*0.02])
298
+
299
+ print("This is the best threshold for the challenge score from testing test", thrs)
300
+ outputs_best = np.array([[(1 if prob > thrs else 0) for prob in probs] for probs in np.array(outputs)])
301
+
302
+ challenge_value = compute_challenge_metric(weights, targets, outputs_best, classes, normal_class)
303
+ print("This is the challenge score:", challenge_value)
304
+
305
+ # gather the stats from all processes
306
+ metric_logger.synchronize_between_processes()
307
+ loss_value_after_each_epcoh /= batch_num
308
+
309
+ return auprc, auroc, f1, hamming, subset_accuracy, loss_value_after_each_epcoh
engine_ecg_2021.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2015-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ """
4
+ Train and eval functions used in main.py
5
+ """
6
+ import math
7
+ import sys
8
+ from typing import Iterable, Optional
9
+
10
+ import torch
11
+
12
+ import timm
13
+ from timm.data import Mixup
14
+ from timm.utils import accuracy, ModelEma
15
+
16
+ from losses import DistillationLoss
17
+ import utils
18
+
19
+ from sklearn.metrics import average_precision_score, roc_auc_score, f1_score
20
+ from sklearn.metrics import hamming_loss
21
+ from sklearn.metrics import accuracy_score
22
+
23
+ import numpy as np
24
+
25
+ from evaluate_model import load_weights, compute_challenge_metric
26
+
27
+
28
+ def normalize_model_outputs(model_outputs):
29
+ """
30
+ Normalize model outputs to the range [0, 1].
31
+
32
+ Parameters:
33
+ model_outputs (numpy.ndarray): The raw output of the deep learning model.
34
+
35
+ Returns:
36
+ numpy.ndarray: The normalized model outputs.
37
+ """
38
+ a = model_outputs.min()
39
+ b = model_outputs.max()
40
+ return (model_outputs - a) / (b - a)
41
+
42
+
43
+ sinus_rhythm_ID = set(['426783006'])
44
+
45
+ weights_file = "./weights.csv"
46
+
47
+ classes, weights = load_weights(weights_file)
48
+
49
+ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
50
+ data_loader: Iterable, optimizer: torch.optim.Optimizer,
51
+ device: torch.device, epoch: int,
52
+ set_training_mode=True, args = None):
53
+
54
+ model.train(set_training_mode)
55
+ metric_logger = utils.MetricLogger(delimiter=" ")
56
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
57
+ header = 'Epoch: [{}]'.format(epoch)
58
+ print_freq = 600
59
+
60
+ output_list = []
61
+ target_list = []
62
+
63
+ loss_value_after_each_epcoh = 0
64
+ batch_num = 0
65
+
66
+ for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
67
+ batch_num += 1
68
+
69
+ samples = samples.to(device, non_blocking=True)
70
+ targets = targets.to(device, non_blocking=True)
71
+
72
+ outputs = model(samples.float(), if_random_cls_token_position=args.if_random_cls_token_position, if_random_token_rank=args.if_random_token_rank)
73
+ loss = criterion(outputs, targets.float())
74
+
75
+ loss_value = loss.item()
76
+
77
+ if args.lrschedule == "Noam":
78
+ optimizer.optimizer.zero_grad()
79
+ elif args.lrschedule == "CosineAnnealing":
80
+ optimizer.zero_grad()
81
+ else:
82
+ raise ValueError(f"No matching condition for value in the file of engine_ecg_2021.py : {args.lrschedule}")
83
+
84
+ loss.backward() # Backward pass: Compute gradient of the loss with respect to model parameters
85
+
86
+ # Update parameters/using the Noam
87
+ optimizer.step()
88
+
89
+ torch.cuda.synchronize()
90
+
91
+ metric_logger.update(loss=loss_value)
92
+ if args.lrschedule == "Noam":
93
+ metric_logger.update(lr=optimizer.optimizer.param_groups[0]["lr"])
94
+ elif args.lrschedule == "CosineAnnealing":
95
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
96
+ else:
97
+ raise ValueError(f"No matching condition for value in the file of engine_ecg_2021.py for updating : {args.lrschedule}")
98
+
99
+ target_list.append(targets.data.cpu().numpy())
100
+ output_list.append(outputs.data.cpu().numpy())
101
+ loss_value_after_each_epcoh += loss_value
102
+
103
+ # gather the stats from all processes
104
+ metric_logger.synchronize_between_processes()
105
+ print("Averaged stats:", metric_logger)
106
+
107
+ # below code is for record the AUPRC of training
108
+ targets_all = np.concatenate(target_list, axis=0)
109
+ outputs_all = np.concatenate(output_list, axis=0)
110
+ outputs_all = normalize_model_outputs(outputs_all)
111
+
112
+ threshold = 0.5
113
+ targets_all[targets_all >= threshold] = 1
114
+ targets_all[targets_all < threshold] = 0
115
+
116
+ # When using the function below, y_true must be a binarized value.
117
+ train_auprc = average_precision_score(y_true = targets_all, y_score = outputs_all)
118
+ print("This is the training AUPRC:", train_auprc)
119
+
120
+ ### The code below is for obtaining the thresholds
121
+ scores_challengeScore = []
122
+ scores_F1 = []
123
+ scores_SubsetAccuracy = []
124
+ scores_HammingLoss = []
125
+
126
+ for thr in np.arange(0., 1., 0.02):
127
+ outputs_dyn = np.array([[(1 if prob > thr else 0) for prob in probs] for probs in np.array(outputs_all)])
128
+
129
+ challenge_value = compute_challenge_metric(weights, targets_all, outputs_dyn, classes, sinus_rhythm_ID)
130
+ scores_challengeScore.append(challenge_value)
131
+
132
+ f1 = f1_score(targets_all, outputs_dyn, average='weighted')
133
+ scores_F1.append(f1)
134
+
135
+ subset_accuracy = accuracy_score(targets_all, outputs_dyn)
136
+ scores_SubsetAccuracy.append(subset_accuracy)
137
+
138
+ hamming = hamming_loss(targets_all, outputs_dyn)
139
+ scores_HammingLoss.append(hamming)
140
+
141
+ scores_challengeScore = np.array(scores_challengeScore)
142
+ scores_F1 = np.array(scores_F1)
143
+ scores_SubsetAccuracy = np.array(scores_SubsetAccuracy)
144
+ scores_HammingLoss = np.array(scores_HammingLoss)
145
+
146
+ # print("This is the challenge score list from training set:\n", scores)
147
+
148
+ # Best thrs for challenge score
149
+ thrs_CHALL = np.array([np.argmax(scores_challengeScore, axis=0)*0.02])
150
+ print("This is the best threshold for the challenge score from training set", thrs_CHALL)
151
+ outputs_best_CHALL = np.array([[(1 if prob > thrs_CHALL else 0) for prob in probs] for probs in np.array(outputs_all)])
152
+ challenge_value = compute_challenge_metric(weights, targets_all, outputs_best_CHALL, classes, sinus_rhythm_ID)
153
+ print("This is the challenge score from training set:", challenge_value)
154
+
155
+ # Best thrs for F1
156
+ scores_F1 = np.array([np.argmax(scores_F1, axis=0)*0.02])
157
+ print("This is the best threshold for the F1 from training set", scores_F1)
158
+ outputs_best_f1 = np.array([[(1 if prob > scores_F1 else 0) for prob in probs] for probs in np.array(outputs_all)])
159
+ f1 = f1_score(targets_all, outputs_best_f1, average='weighted')
160
+ print("This is the f1 score from training set:", challenge_value)
161
+
162
+ # Best thrs for subset accuracy
163
+ scores_SubsetAccuracy = np.array([np.argmax(scores_SubsetAccuracy, axis=0)*0.02])
164
+ print("This is the best threshold for the Subset Accuracy from training set", scores_SubsetAccuracy)
165
+ outputs_best_SubsetAccuracy = np.array([[(1 if prob > scores_SubsetAccuracy else 0) for prob in probs] for probs in np.array(outputs_all)])
166
+ subset_accuracy = accuracy_score(targets_all, outputs_best_SubsetAccuracy)
167
+ print("This is the subset accuracy from training set:", subset_accuracy)
168
+
169
+ # Best thrs for hamming loss, here is the loss value from output, we need to find the minimum value.
170
+ # Determine the optimal threshold for Hamming loss. The loss values are obtained from the output, and the goal is to find the minimum value.
171
+ scores_HammingLoss = np.array([np.argmin(scores_HammingLoss, axis=0)*0.02])
172
+ print("This is the best threshold for the Subset Accuracy from training set", scores_HammingLoss)
173
+ outputs_best_HammingLoss = np.array([[(1 if prob > scores_HammingLoss else 0) for prob in probs] for probs in np.array(outputs_all)])
174
+ hamming = hamming_loss(targets_all, outputs_best_HammingLoss)
175
+ print("This is the Hamming from training set:", hamming)
176
+
177
+ loss_value_after_each_epcoh /= batch_num
178
+ return train_auprc, loss_value_after_each_epcoh, thrs_CHALL, scores_F1, scores_SubsetAccuracy, scores_HammingLoss
179
+
180
+ @torch.no_grad()
181
+ def evaluate(data_loader, model, thrs_chall, thrs_F1, thrs_accuracy, thrs_hammingLoss, device):
182
+ # criterion = torch.nn.CrossEntropyLoss()
183
+ criterion = torch.nn.BCEWithLogitsLoss()
184
+
185
+ metric_logger = utils.MetricLogger(delimiter=" ")
186
+ header = 'Test:'
187
+
188
+ targets = []
189
+ outputs = []
190
+ loss_value_after_each_epcoh = 0
191
+ batch_num = 0
192
+ # switch to evaluation mode
193
+ model.eval()
194
+
195
+ for images, target in metric_logger.log_every(data_loader, 100, header):
196
+
197
+ batch_num += 1
198
+ images = images.to(device, non_blocking=True)
199
+ target = target.to(device, non_blocking=True)
200
+
201
+ output = model(images.float())
202
+
203
+ loss = criterion(output, target.float())
204
+
205
+ metric_logger.update(loss=loss.item())
206
+
207
+ targets.append(target.data.cpu().numpy())
208
+ outputs.append(output.data.cpu().numpy())
209
+
210
+ loss_value_after_each_epcoh += loss.item()
211
+
212
+ # below code is for record the AUPRC of validation
213
+ targets = np.concatenate(targets, axis=0)
214
+ outputs = np.concatenate(outputs, axis=0)
215
+
216
+ outputs = normalize_model_outputs(outputs)
217
+
218
+ auprc = average_precision_score(y_true=targets, y_score=outputs)
219
+ auroc = roc_auc_score(targets, outputs)
220
+
221
+ # print("This is the top 3 row:", outputs[0:3])
222
+ outputs_F1 = np.array([[(1 if prob > thrs_F1 else 0) for prob in probs] for probs in np.array(outputs)])
223
+ f1 = f1_score(targets, outputs_F1, average='weighted')
224
+
225
+ outputs_hammingLoss = np.array([[(1 if prob > thrs_hammingLoss else 0) for prob in probs] for probs in np.array(outputs)])
226
+ hamming = hamming_loss(targets, outputs_hammingLoss)
227
+
228
+ outputs_accuracy = np.array([[(1 if prob > thrs_accuracy else 0) for prob in probs] for probs in np.array(outputs)])
229
+ subset_accuracy = accuracy_score(targets, outputs_accuracy)
230
+
231
+ print("This is the best threshold for the challenge score, obtained from the training test, without any testing data leakage:", thrs_chall)
232
+ outputs_best = np.array([[(1 if prob > thrs_chall else 0) for prob in probs] for probs in np.array(outputs)])
233
+
234
+ challenge_value = compute_challenge_metric(weights, targets, outputs_best, classes, sinus_rhythm_ID)
235
+ print("This is the challenge score:", challenge_value)
236
+
237
+ # gather the stats from all processes
238
+ metric_logger.synchronize_between_processes()
239
+ loss_value_after_each_epcoh /= batch_num
240
+
241
+ return auprc, auroc, f1, hamming, subset_accuracy, challenge_value, loss_value_after_each_epcoh
evaluate_12ECG_score.py ADDED
@@ -0,0 +1,577 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # This file contains functions for evaluating algorithms for the 2020 PhysioNet/
4
+ # Computing in Cardiology Challenge. You can run it as follows:
5
+ #
6
+ # python evaluate_12ECG_score.py labels outputs scores.csv
7
+ #
8
+ # where 'labels' is a directory containing files with the labels, 'outputs' is a
9
+ # directory containing files with the outputs from your model, and 'scores.csv'
10
+ # (optional) is a collection of scores for the algorithm outputs.
11
+ #
12
+ # Each file of labels or outputs must have the format described on the Challenge
13
+ # webpage. The scores for the algorithm outputs include the area under the
14
+ # receiver-operating characteristic curve (AUROC), the area under the recall-
15
+ # precision curve (AUPRC), accuracy (fraction of correct recordings), macro F-
16
+ # measure, and the Challenge metric, which assigns different weights to
17
+ # different misclassification errors.
18
+
19
+ import numpy as np, os, os.path, sys
20
+
21
+ def evaluate_12ECG_score(label_directory, output_directory):
22
+ # Define the weights, the SNOMED CT code for the normal class, and equivalent SNOMED CT codes.
23
+ weights_file = 'weights.csv'
24
+ normal_class = '426783006'
25
+ equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]
26
+
27
+ # Find the label and output files.
28
+ print('Finding label and output files...')
29
+ label_files, output_files = find_challenge_files(label_directory, output_directory)
30
+
31
+ # Load the labels and outputs.
32
+ print('Loading labels and outputs...')
33
+ label_classes, labels = load_labels(label_files, normal_class, equivalent_classes)
34
+ output_classes, binary_outputs, scalar_outputs = load_outputs(output_files, normal_class, equivalent_classes)
35
+
36
+ # Organize/sort the labels and outputs.
37
+ print('Organizing labels and outputs...')
38
+ classes, labels, binary_outputs, scalar_outputs = organize_labels_outputs(label_classes, output_classes, labels, binary_outputs, scalar_outputs)
39
+
40
+ # Load the weights for the Challenge metric.
41
+ print('Loading weights...')
42
+ weights = load_weights(weights_file, classes)
43
+
44
+ # Only consider classes that are scored with the Challenge metric.
45
+ indices = np.any(weights, axis=0) # Find indices of classes in weight matrix.
46
+ classes = [x for i, x in enumerate(classes) if indices[i]]
47
+ labels = labels[:, indices]
48
+ scalar_outputs = scalar_outputs[:, indices]
49
+ binary_outputs = binary_outputs[:, indices]
50
+ weights = weights[np.ix_(indices, indices)]
51
+
52
+ # Evaluate the model by comparing the labels and outputs.
53
+ print('Evaluating model...')
54
+
55
+ print('- AUROC and AUPRC...')
56
+ auroc, auprc = compute_auc(labels, scalar_outputs)
57
+
58
+ print('- Accuracy...')
59
+ accuracy = compute_accuracy(labels, binary_outputs)
60
+
61
+ print('- F-measure...')
62
+ f_measure = compute_f_measure(labels, binary_outputs)
63
+
64
+ print('- F-beta and G-beta measures...')
65
+ f_beta_measure, g_beta_measure = compute_beta_measures(labels, binary_outputs, beta=2)
66
+
67
+ print('- Challenge metric...')
68
+ challenge_metric = compute_challenge_metric(weights, labels, binary_outputs, classes, normal_class)
69
+
70
+ print('Done.')
71
+
72
+ # Return the results.
73
+ return auroc, auprc, accuracy, f_measure, f_beta_measure, g_beta_measure, challenge_metric
74
+
75
+ # Check if the input is a number.
76
+ def is_number(x):
77
+ try:
78
+ float(x)
79
+ return True
80
+ except ValueError:
81
+ return False
82
+
83
+ # Find Challenge files.
84
+ def find_challenge_files(label_directory, output_directory):
85
+ label_files = list()
86
+ output_files = list()
87
+ for f in sorted(os.listdir(label_directory)):
88
+ F = os.path.join(label_directory, f) # Full path for label file
89
+ if os.path.isfile(F) and F.lower().endswith('.hea') and not f.lower().startswith('.'):
90
+ root, ext = os.path.splitext(f)
91
+ g = root + '.csv'
92
+ G = os.path.join(output_directory, g) # Full path for corresponding output file
93
+ if os.path.isfile(G):
94
+ label_files.append(F)
95
+ output_files.append(G)
96
+ else:
97
+ raise IOError('Output file {} not found for label file {}.'.format(g, f))
98
+
99
+ if label_files and output_files:
100
+ return label_files, output_files
101
+ else:
102
+ raise IOError('No label or output files found.')
103
+
104
+ # Load labels from header/label files.
105
+ def load_labels(label_files, normal_class, equivalent_classes_collection):
106
+ # The labels should have the following form:
107
+ #
108
+ # Dx: label_1, label_2, label_3
109
+ #
110
+ num_recordings = len(label_files)
111
+
112
+ # Load diagnoses.
113
+ tmp_labels = list()
114
+ for i in range(num_recordings):
115
+ with open(label_files[i], 'r') as f:
116
+ for l in f:
117
+ if l.startswith('#Dx'):
118
+ dxs = set(arr.strip() for arr in l.split(': ')[1].split(','))
119
+ tmp_labels.append(dxs)
120
+
121
+ # Identify classes.
122
+ classes = set.union(*map(set, tmp_labels))
123
+ if normal_class not in classes:
124
+ classes.add(normal_class)
125
+ print('- The normal class {} is not one of the label classes, so it has been automatically added, but please check that you chose the correct normal class.'.format(normal_class))
126
+ classes = sorted(classes)
127
+ num_classes = len(classes)
128
+
129
+ # Use one-hot encoding for labels.
130
+ labels = np.zeros((num_recordings, num_classes), dtype=np.bool_)
131
+ for i in range(num_recordings):
132
+ dxs = tmp_labels[i]
133
+ for dx in dxs:
134
+ j = classes.index(dx)
135
+ labels[i, j] = 1
136
+
137
+ # For each set of equivalent class, use only one class as the representative class for the set and discard the other classes in the set.
138
+ # The label for the representative class is positive if any of the labels in the set is positive.
139
+ remove_classes = list()
140
+ remove_indices = list()
141
+ for equivalent_classes in equivalent_classes_collection:
142
+ equivalent_classes = [x for x in equivalent_classes if x in classes]
143
+ if len(equivalent_classes)>1:
144
+ representative_class = equivalent_classes[0]
145
+ other_classes = equivalent_classes[1:]
146
+ equivalent_indices = [classes.index(x) for x in equivalent_classes]
147
+ representative_index = equivalent_indices[0]
148
+ other_indices = equivalent_indices[1:]
149
+
150
+ labels[:, representative_index] = np.any(labels[:, equivalent_indices], axis=1)
151
+ remove_classes += other_classes
152
+ remove_indices += other_indices
153
+
154
+ for x in remove_classes:
155
+ classes.remove(x)
156
+ labels = np.delete(labels, remove_indices, axis=1)
157
+
158
+ # If the labels are negative for all classes, then change the label for the normal class to positive.
159
+ normal_index = classes.index(normal_class)
160
+ for i in range(num_recordings):
161
+ num_positive_classes = np.sum(labels[i, :])
162
+ if num_positive_classes==0:
163
+ labels[i, normal_index] = 1
164
+
165
+ return classes, labels
166
+
167
+ # Load outputs from output files.
168
+ def load_outputs(output_files, normal_class, equivalent_classes_collection):
169
+ # The outputs should have the following form:
170
+ #
171
+ # diagnosis_1, diagnosis_2, diagnosis_3
172
+ # 0, 1, 1
173
+ # 0.12, 0.34, 0.56
174
+ #
175
+ num_recordings = len(output_files)
176
+
177
+ tmp_labels = list()
178
+ tmp_binary_outputs = list()
179
+ tmp_scalar_outputs = list()
180
+ for i in range(num_recordings):
181
+ with open(output_files[i], 'r') as f:
182
+ for j, l in enumerate(f):
183
+ arrs = [arr.strip() for arr in l.split(',')]
184
+ if j==1:
185
+ row = arrs
186
+ tmp_labels.append(row)
187
+ elif j==2:
188
+ row = list()
189
+ for arr in arrs:
190
+ number = 1 if arr in ('1', 'True', 'true', 'T', 't') else 0
191
+ row.append(number)
192
+ tmp_binary_outputs.append(row)
193
+ elif j==3:
194
+ row = list()
195
+ for arr in arrs:
196
+ number = float(arr) if is_number(arr) else 0
197
+ row.append(number)
198
+ tmp_scalar_outputs.append(row)
199
+
200
+ # Identify classes.
201
+ classes = set.union(*map(set, tmp_labels))
202
+ if normal_class not in classes:
203
+ classes.add(normal_class)
204
+ print('- The normal class {} is not one of the output classes, so it has been automatically added, but please check that you identified the correct normal class.'.format(normal_class))
205
+ classes = sorted(classes)
206
+ num_classes = len(classes)
207
+
208
+ # Use one-hot encoding for binary outputs and the same order for scalar outputs.
209
+ binary_outputs = np.zeros((num_recordings, num_classes), dtype=np.bool_)
210
+ scalar_outputs = np.zeros((num_recordings, num_classes), dtype=np.float64)
211
+ for i in range(num_recordings):
212
+ dxs = tmp_labels[i]
213
+ for k, dx in enumerate(dxs):
214
+ j = classes.index(dx)
215
+ binary_outputs[i, j] = tmp_binary_outputs[i][k]
216
+ scalar_outputs[i, j] = tmp_scalar_outputs[i][k]
217
+
218
+ # For each set of equivalent class, use only one class as the representative class for the set and discard the other classes in the set.
219
+ # The binary output for the representative class is positive if any of the classes in the set is positive.
220
+ # The scalar output is the mean of the scalar outputs for the classes in the set.
221
+ remove_classes = list()
222
+ remove_indices = list()
223
+ for equivalent_classes in equivalent_classes_collection:
224
+ equivalent_classes = [x for x in equivalent_classes if x in classes]
225
+ if len(equivalent_classes)>1:
226
+ representative_class = equivalent_classes[0]
227
+ other_classes = equivalent_classes[1:]
228
+ equivalent_indices = [classes.index(x) for x in equivalent_classes]
229
+ representative_index = equivalent_indices[0]
230
+ other_indices = equivalent_indices[1:]
231
+
232
+ binary_outputs[:, representative_index] = np.any(binary_outputs[:, equivalent_indices], axis=1)
233
+ scalar_outputs[:, representative_index] = np.nanmean(scalar_outputs[:, equivalent_indices], axis=1)
234
+ remove_classes += other_classes
235
+ remove_indices += other_indices
236
+
237
+ for x in remove_classes:
238
+ classes.remove(x)
239
+ binary_outputs = np.delete(binary_outputs, remove_indices, axis=1)
240
+ scalar_outputs = np.delete(scalar_outputs, remove_indices, axis=1)
241
+
242
+ # If any of the outputs is a NaN, then replace it with a zero.
243
+ binary_outputs[np.isnan(binary_outputs)] = 0
244
+ scalar_outputs[np.isnan(scalar_outputs)] = 0
245
+
246
+ # If the binary outputs are negative for all classes, then change the binary output for the normal class to positive.
247
+ normal_index = classes.index(normal_class)
248
+ for i in range(num_recordings):
249
+ num_positive_classes = np.sum(binary_outputs[i, :])
250
+ if num_positive_classes==0:
251
+ binary_outputs[i, normal_index] = 1
252
+
253
+ return classes, binary_outputs, scalar_outputs
254
+
255
+ # Organize labels and outputs.
256
+ def organize_labels_outputs(label_classes, output_classes, tmp_labels, tmp_binary_outputs, tmp_scalar_outputs):
257
+ # Include all classes from either the labels or the outputs.
258
+ classes = sorted(set(label_classes) | set(output_classes))
259
+ num_classes = len(classes)
260
+
261
+ # Check that the labels and outputs have the same numbers of recordings.
262
+ assert(len(tmp_labels)==len(tmp_binary_outputs)==len(tmp_scalar_outputs))
263
+ num_recordings = len(tmp_labels)
264
+
265
+ # Rearrange the columns of the labels and the outputs to be consistent with the order of the classes.
266
+ labels = np.zeros((num_recordings, num_classes), dtype=np.bool_)
267
+ for k, dx in enumerate(label_classes):
268
+ j = classes.index(dx)
269
+ labels[:, j] = tmp_labels[:, k]
270
+
271
+ binary_outputs = np.zeros((num_recordings, num_classes), dtype=np.bool_)
272
+ scalar_outputs = np.zeros((num_recordings, num_classes), dtype=np.float64)
273
+ for k, dx in enumerate(output_classes):
274
+ j = classes.index(dx)
275
+ binary_outputs[:, j] = tmp_binary_outputs[:, k]
276
+ scalar_outputs[:, j] = tmp_scalar_outputs[:, k]
277
+
278
+ return classes, labels, binary_outputs, scalar_outputs
279
+
280
+ # Load a table with row and column names.
281
+ def load_table(table_file):
282
+ # The table should have the following form:
283
+ #
284
+ # , a, b, c
285
+ # a, 1.2, 2.3, 3.4
286
+ # b, 4.5, 5.6, 6.7
287
+ # c, 7.8, 8.9, 9.0
288
+ #
289
+ table = list()
290
+ with open(table_file, 'r') as f:
291
+ for i, l in enumerate(f):
292
+ arrs = [arr.strip() for arr in l.split(',')]
293
+ table.append(arrs)
294
+
295
+ # Define the numbers of rows and columns and check for errors.
296
+ num_rows = len(table)-1
297
+ if num_rows<1:
298
+ raise Exception('The table {} is empty.'.format(table_file))
299
+
300
+ num_cols = set(len(table[i])-1 for i in range(num_rows))
301
+ if len(num_cols)!=1:
302
+ raise Exception('The table {} has rows with different lengths.'.format(table_file))
303
+ num_cols = min(num_cols)
304
+ if num_cols<1:
305
+ raise Exception('The table {} is empty.'.format(table_file))
306
+
307
+ # Find the row and column labels.
308
+ rows = [table[0][j+1] for j in range(num_rows)]
309
+ cols = [table[i+1][0] for i in range(num_cols)]
310
+
311
+ # Find the entries of the table.
312
+ values = np.zeros((num_rows, num_cols))
313
+ for i in range(num_rows):
314
+ for j in range(num_cols):
315
+ value = table[i+1][j+1]
316
+ if is_number(value):
317
+ values[i, j] = float(value)
318
+ else:
319
+ values[i, j] = float('nan')
320
+
321
+ return rows, cols, values
322
+
323
+ # Load weights.
324
+ def load_weights(weight_file, classes):
325
+ # Load the weight matrix.
326
+ rows, cols, values = load_table(weight_file)
327
+ assert(rows == cols)
328
+ num_rows = len(rows)
329
+
330
+ # Assign the entries of the weight matrix with rows and columns corresponding to the classes.
331
+ num_classes = len(classes)
332
+ weights = np.zeros((num_classes, num_classes), dtype=np.float64)
333
+ for i, a in enumerate(rows):
334
+ if a in classes:
335
+ k = classes.index(a)
336
+ for j, b in enumerate(rows):
337
+ if b in classes:
338
+ l = classes.index(b)
339
+ weights[k, l] = values[i, j]
340
+
341
+ return weights
342
+
343
+ # Compute recording-wise accuracy.
344
+ def compute_accuracy(labels, outputs):
345
+ num_recordings, num_classes = np.shape(labels)
346
+
347
+ num_correct_recordings = 0
348
+ for i in range(num_recordings):
349
+ if np.all(labels[i, :]==outputs[i, :]):
350
+ num_correct_recordings += 1
351
+
352
+ return float(num_correct_recordings) / float(num_recordings)
353
+
354
+ # Compute confusion matrices.
355
+ def compute_confusion_matrices(labels, outputs, normalize=False):
356
+ # Compute a binary confusion matrix for each class k:
357
+ #
358
+ # [TN_k FN_k]
359
+ # [FP_k TP_k]
360
+ #
361
+ # If the normalize variable is set to true, then normalize the contributions
362
+ # to the confusion matrix by the number of labels per recording.
363
+ num_recordings, num_classes = np.shape(labels)
364
+
365
+ if not normalize:
366
+ A = np.zeros((num_classes, 2, 2))
367
+ for i in range(num_recordings):
368
+ for j in range(num_classes):
369
+ if labels[i, j]==1 and outputs[i, j]==1: # TP
370
+ A[j, 1, 1] += 1
371
+ elif labels[i, j]==0 and outputs[i, j]==1: # FP
372
+ A[j, 1, 0] += 1
373
+ elif labels[i, j]==1 and outputs[i, j]==0: # FN
374
+ A[j, 0, 1] += 1
375
+ elif labels[i, j]==0 and outputs[i, j]==0: # TN
376
+ A[j, 0, 0] += 1
377
+ else: # This condition should not happen.
378
+ raise ValueError('Error in computing the confusion matrix.')
379
+ else:
380
+ A = np.zeros((num_classes, 2, 2))
381
+ for i in range(num_recordings):
382
+ normalization = float(max(np.sum(labels[i, :]), 1))
383
+ for j in range(num_classes):
384
+ if labels[i, j]==1 and outputs[i, j]==1: # TP
385
+ A[j, 1, 1] += 1.0/normalization
386
+ elif labels[i, j]==0 and outputs[i, j]==1: # FP
387
+ A[j, 1, 0] += 1.0/normalization
388
+ elif labels[i, j]==1 and outputs[i, j]==0: # FN
389
+ A[j, 0, 1] += 1.0/normalization
390
+ elif labels[i, j]==0 and outputs[i, j]==0: # TN
391
+ A[j, 0, 0] += 1.0/normalization
392
+ else: # This condition should not happen.
393
+ raise ValueError('Error in computing the confusion matrix.')
394
+
395
+ return A
396
+
397
+ # Compute macro F-measure.
398
+ def compute_f_measure(labels, outputs):
399
+ num_recordings, num_classes = np.shape(labels)
400
+
401
+ A = compute_confusion_matrices(labels, outputs)
402
+
403
+ f_measure = np.zeros(num_classes)
404
+ for k in range(num_classes):
405
+ tp, fp, fn, tn = A[k, 1, 1], A[k, 1, 0], A[k, 0, 1], A[k, 0, 0]
406
+ if 2 * tp + fp + fn:
407
+ f_measure[k] = float(2 * tp) / float(2 * tp + fp + fn)
408
+ else:
409
+ f_measure[k] = float('nan')
410
+
411
+ macro_f_measure = np.nanmean(f_measure)
412
+
413
+ return macro_f_measure
414
+
415
+ # Compute F-beta and G-beta measures from the unofficial phase of the Challenge.
416
+ def compute_beta_measures(labels, outputs, beta):
417
+ num_recordings, num_classes = np.shape(labels)
418
+
419
+ A = compute_confusion_matrices(labels, outputs, normalize=True)
420
+
421
+ f_beta_measure = np.zeros(num_classes)
422
+ g_beta_measure = np.zeros(num_classes)
423
+ for k in range(num_classes):
424
+ tp, fp, fn, tn = A[k, 1, 1], A[k, 1, 0], A[k, 0, 1], A[k, 0, 0]
425
+ if (1+beta**2)*tp + fp + beta**2*fn:
426
+ f_beta_measure[k] = float((1+beta**2)*tp) / float((1+beta**2)*tp + fp + beta**2*fn)
427
+ else:
428
+ f_beta_measure[k] = float('nan')
429
+ if tp + fp + beta*fn:
430
+ g_beta_measure[k] = float(tp) / float(tp + fp + beta*fn)
431
+ else:
432
+ g_beta_measure[k] = float('nan')
433
+
434
+ macro_f_beta_measure = np.nanmean(f_beta_measure)
435
+ macro_g_beta_measure = np.nanmean(g_beta_measure)
436
+
437
+ return macro_f_beta_measure, macro_g_beta_measure
438
+
439
+ # Compute macro AUROC and macro AUPRC.
440
+ def compute_auc(labels, outputs):
441
+ num_recordings, num_classes = np.shape(labels)
442
+
443
+ # Compute and summarize the confusion matrices for each class across at distinct output values.
444
+ auroc = np.zeros(num_classes)
445
+ auprc = np.zeros(num_classes)
446
+
447
+ for k in range(num_classes):
448
+ # We only need to compute TPs, FPs, FNs, and TNs at distinct output values.
449
+ thresholds = np.unique(outputs[:, k])
450
+ thresholds = np.append(thresholds, thresholds[-1]+1)
451
+ thresholds = thresholds[::-1]
452
+ num_thresholds = len(thresholds)
453
+
454
+ # Initialize the TPs, FPs, FNs, and TNs.
455
+ tp = np.zeros(num_thresholds)
456
+ fp = np.zeros(num_thresholds)
457
+ fn = np.zeros(num_thresholds)
458
+ tn = np.zeros(num_thresholds)
459
+ fn[0] = np.sum(labels[:, k]==1)
460
+ tn[0] = np.sum(labels[:, k]==0)
461
+
462
+ # Find the indices that result in sorted output values.
463
+ idx = np.argsort(outputs[:, k])[::-1]
464
+
465
+ # Compute the TPs, FPs, FNs, and TNs for class k across thresholds.
466
+ i = 0
467
+ for j in range(1, num_thresholds):
468
+ # Initialize TPs, FPs, FNs, and TNs using values at previous threshold.
469
+ tp[j] = tp[j-1]
470
+ fp[j] = fp[j-1]
471
+ fn[j] = fn[j-1]
472
+ tn[j] = tn[j-1]
473
+
474
+ # Update the TPs, FPs, FNs, and TNs at i-th output value.
475
+ while i < num_recordings and outputs[idx[i], k] >= thresholds[j]:
476
+ if labels[idx[i], k]:
477
+ tp[j] += 1
478
+ fn[j] -= 1
479
+ else:
480
+ fp[j] += 1
481
+ tn[j] -= 1
482
+ i += 1
483
+
484
+ # Summarize the TPs, FPs, FNs, and TNs for class k.
485
+ tpr = np.zeros(num_thresholds)
486
+ tnr = np.zeros(num_thresholds)
487
+ ppv = np.zeros(num_thresholds)
488
+ npv = np.zeros(num_thresholds)
489
+
490
+ for j in range(num_thresholds):
491
+ if tp[j] + fn[j]:
492
+ tpr[j] = float(tp[j]) / float(tp[j] + fn[j])
493
+ else:
494
+ tpr[j] = float('nan')
495
+ if fp[j] + tn[j]:
496
+ tnr[j] = float(tn[j]) / float(fp[j] + tn[j])
497
+ else:
498
+ tnr[j] = float('nan')
499
+ if tp[j] + fp[j]:
500
+ ppv[j] = float(tp[j]) / float(tp[j] + fp[j])
501
+ else:
502
+ ppv[j] = float('nan')
503
+
504
+ # Compute AUROC as the area under a piecewise linear function with TPR/
505
+ # sensitivity (x-axis) and TNR/specificity (y-axis) and AUPRC as the area
506
+ # under a piecewise constant with TPR/recall (x-axis) and PPV/precision
507
+ # (y-axis) for class k.
508
+ for j in range(num_thresholds-1):
509
+ auroc[k] += 0.5 * (tpr[j+1] - tpr[j]) * (tnr[j+1] + tnr[j])
510
+ auprc[k] += (tpr[j+1] - tpr[j]) * ppv[j+1]
511
+
512
+ # Compute macro AUROC and macro AUPRC across classes.
513
+ macro_auroc = np.nanmean(auroc)
514
+ macro_auprc = np.nanmean(auprc)
515
+
516
+ return macro_auroc, macro_auprc
517
+
518
+ # Compute modified confusion matrix for multi-class, multi-label tasks.
519
+ def compute_modified_confusion_matrix(labels, outputs):
520
+ # Compute a binary multi-class, multi-label confusion matrix, where the rows
521
+ # are the labels and the columns are the outputs.
522
+ num_recordings, num_classes = np.shape(labels)
523
+ A = np.zeros((num_classes, num_classes))
524
+
525
+ # Iterate over all of the recordings.
526
+ for i in range(num_recordings):
527
+ # Calculate the number of positive labels and/or outputs.
528
+ normalization = float(max(np.sum(np.any((labels[i, :], outputs[i, :]), axis=0)), 1))
529
+ # Iterate over all of the classes.
530
+ for j in range(num_classes):
531
+ # Assign full and/or partial credit for each positive class.
532
+ if labels[i, j]:
533
+ for k in range(num_classes):
534
+ if outputs[i, k]:
535
+ A[j, k] += 1.0/normalization
536
+
537
+ return A
538
+
539
+ # Compute the evaluation metric for the Challenge.
540
+ def compute_challenge_metric(weights, labels, outputs, classes, normal_class):
541
+ num_recordings, num_classes = np.shape(labels)
542
+ normal_index = classes.index(normal_class)
543
+
544
+ # Compute the observed score.
545
+ A = compute_modified_confusion_matrix(labels, outputs)
546
+ observed_score = np.nansum(weights * A)
547
+
548
+ # Compute the score for the model that always chooses the correct label(s).
549
+ correct_outputs = labels
550
+ A = compute_modified_confusion_matrix(labels, correct_outputs)
551
+ correct_score = np.nansum(weights * A)
552
+
553
+ # Compute the score for the model that always chooses the normal class.
554
+ inactive_outputs = np.zeros((num_recordings, num_classes), dtype=np.bool_)
555
+ inactive_outputs[:, normal_index] = 1
556
+ A = compute_modified_confusion_matrix(labels, inactive_outputs)
557
+ inactive_score = np.nansum(weights * A)
558
+
559
+ if correct_score != inactive_score:
560
+ normalized_score = float(observed_score - inactive_score) / float(correct_score - inactive_score)
561
+ else:
562
+ normalized_score = float('nan')
563
+
564
+ return normalized_score
565
+
566
+ if __name__ == '__main__':
567
+ #auroc, auprc, accuracy, f_measure, f_beta_measure, g_beta_measure, challenge_metric = evaluate_12ECG_score(sys.argv[1], sys.argv[2])
568
+ lbl_dir = '/home/p2017-999/acs_data/processed_data/physionet2020/jonathan/in'
569
+ out_dir = '/home/p2017-999/acs_data/processed_data/physionet2020/jonathan/out'
570
+ auroc, auprc, accuracy, f_measure, f_beta_measure, g_beta_measure, challenge_metric = evaluate_12ECG_score(lbl_dir, out_dir)
571
+
572
+ output_string = 'AUROC,AUPRC,Accuracy,F-measure,Fbeta-measure,Gbeta-measure,Challenge metric\n{:.3f},{:.3f},{:.3f},{:.3f},{:.3f},{:.3f},{:.3f}'.format(auroc, auprc, accuracy, f_measure, f_beta_measure, g_beta_measure, challenge_metric)
573
+ if len(sys.argv) > 3:
574
+ with open(sys.argv[3], 'w') as f:
575
+ f.write(output_string)
576
+ else:
577
+ print(output_string)
evaluate_model.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # This file contains functions for evaluating algorithms for the 2021 PhysioNet/
4
+ # Computing in Cardiology Challenge. You can run it as follows:
5
+ #
6
+ # python evaluate_model.py labels outputs scores.csv
7
+ #
8
+ # where 'labels' is a directory containing files with the labels, 'outputs' is a
9
+ # directory containing files with the outputs from your model, and 'scores.csv'
10
+ # (optional) is a collection of scores for the algorithm outputs.
11
+ #
12
+ # Each file of labels or outputs must have the format described on the Challenge
13
+ # webpage. The scores for the algorithm outputs include the area under the
14
+ # receiver-operating characteristic curve (AUROC), the area under the recall-
15
+ # precision curve (AUPRC), accuracy (fraction of correct recordings), macro F-
16
+ # measure, and the Challenge metric, which assigns different weights to
17
+ # different misclassification errors.
18
+
19
+ import os, os.path, sys, numpy as np
20
+ from helper_code import get_labels, is_finite_number, load_header, load_outputs
21
+ import pandas as pd
22
+ from tabulate import tabulate
23
+
24
+ def evaluate_model(label_directory, output_directory):
25
+ # Identify the weights and the SNOMED CT code for the sinus rhythm class.
26
+ weights_file = 'weights.csv'
27
+ sinus_rhythm = set(['426783006'])
28
+
29
+ # Load the scored classes and the weights for the Challenge metric.
30
+ print('Loading weights...')
31
+ classes, weights = load_weights(weights_file)
32
+
33
+ # Load the label and output files.
34
+ print('Loading label and output files...')
35
+ label_files, output_files = find_challenge_files(label_directory, output_directory)
36
+ labels = load_labels(label_files, classes)
37
+ binary_outputs, scalar_outputs = load_classifier_outputs(output_files, classes)
38
+
39
+ # Evaluate the model by comparing the labels and outputs.
40
+ print('Evaluating model...')
41
+
42
+ print('- AUROC and AUPRC...')
43
+ auroc, auprc, auroc_classes, auprc_classes = compute_auc(labels, scalar_outputs)
44
+
45
+ print('- Accuracy...')
46
+ accuracy = compute_accuracy(labels, binary_outputs)
47
+
48
+ print('- F-measure...')
49
+ f_measure, f_measure_classes = compute_f_measure(labels, binary_outputs)
50
+
51
+ print('- Challenge metric...')
52
+ challenge_metric = compute_challenge_metric(weights, labels, binary_outputs, classes, sinus_rhythm)
53
+
54
+ print('Done.')
55
+
56
+ # Return the results.
57
+ return classes, auroc, auprc, auroc_classes, auprc_classes, accuracy, f_measure, f_measure_classes, challenge_metric
58
+
59
+ # Find Challenge files.
60
+ def find_challenge_files(label_directory, output_directory):
61
+ label_files = list()
62
+ output_files = list()
63
+ for label_file in sorted(os.listdir(label_directory)):
64
+ label_file_path = os.path.join(label_directory, label_file) # Full path for label file
65
+ if os.path.isfile(label_file_path) and label_file.lower().endswith('.hea') and not label_file.lower().startswith('.'):
66
+ root, ext = os.path.splitext(label_file)
67
+ output_file = root + '.csv'
68
+ output_file_path = os.path.join(output_directory, output_file) # Full path for corresponding output file
69
+ if os.path.isfile(output_file_path):
70
+ label_files.append(label_file_path)
71
+ output_files.append(output_file_path)
72
+ else:
73
+ raise IOError('Output file {} not found for label file {}.'.format(output_file, label_file))
74
+
75
+ if label_files and output_files:
76
+ return label_files, output_files
77
+ else:
78
+ raise IOError('No label or output files found.')
79
+
80
+ # Load a table with row and column names.
81
+ def load_table(table_file):
82
+ # The table should have the following form:
83
+ #
84
+ # , a, b, c
85
+ # a, 1.2, 2.3, 3.4
86
+ # b, 4.5, 5.6, 6.7
87
+ # c, 7.8, 8.9, 9.0
88
+ #
89
+ table = list()
90
+ with open(table_file, 'r') as f:
91
+ for i, l in enumerate(f):
92
+ arrs = [arr.strip() for arr in l.split(',')]
93
+ table.append(arrs)
94
+
95
+ # Define the numbers of rows and columns and check for errors.
96
+ num_rows = len(table)-1
97
+ if num_rows<1:
98
+ raise Exception('The table {} is empty.'.format(table_file))
99
+ row_lengths = set(len(table[i])-1 for i in range(num_rows))
100
+ if len(row_lengths)!=1:
101
+ raise Exception('The table {} has rows with different lengths.'.format(table_file))
102
+ num_cols = min(row_lengths)
103
+ if num_cols<1:
104
+ raise Exception('The table {} is empty.'.format(table_file))
105
+
106
+ # Find the row and column labels.
107
+ rows = [table[0][j+1] for j in range(num_rows)]
108
+ cols = [table[i+1][0] for i in range(num_cols)]
109
+
110
+ # Find the entries of the table.
111
+ values = np.zeros((num_rows, num_cols), dtype=np.float64)
112
+ for i in range(num_rows):
113
+ for j in range(num_cols):
114
+ value = table[i+1][j+1]
115
+ if is_finite_number(value):
116
+ values[i, j] = float(value)
117
+ else:
118
+ values[i, j] = float('nan')
119
+
120
+ return rows, cols, values
121
+
122
+ # Load weights.
123
+ def load_weights(weight_file):
124
+ # Load the table with the weight matrix.
125
+ rows, cols, values = load_table(weight_file)
126
+
127
+ # Split the equivalent classes.
128
+ rows = [set(row.split('|')) for row in rows]
129
+ cols = [set(col.split('|')) for col in cols]
130
+ assert(rows == cols)
131
+
132
+ # Identify the classes and the weight matrix.
133
+ classes = rows
134
+ weights = values
135
+
136
+ return classes, weights
137
+
138
+ # Load labels from header/label files.
139
+ def load_labels(label_files, classes):
140
+ # The labels should have the following form:
141
+ #
142
+ # Dx: label_1, label_2, label_3
143
+ #
144
+ num_recordings = len(label_files)
145
+ num_classes = len(classes)
146
+
147
+ # Use one-hot encoding for the labels.
148
+ labels = np.zeros((num_recordings, num_classes), dtype=np.bool)
149
+
150
+ # Iterate over the recordings.
151
+ for i in range(num_recordings):
152
+ header = load_header(label_files[i])
153
+ y = set(get_labels(header))
154
+ for j, x in enumerate(classes):
155
+ if x & y:
156
+ labels[i, j] = 1
157
+
158
+ return labels
159
+
160
+ # Load outputs from output files.
161
+ def load_classifier_outputs(output_files, classes):
162
+ # The outputs should have the following form:
163
+ #
164
+ # #Record ID
165
+ # diagnosis_1, diagnosis_2, diagnosis_3
166
+ # 0, 1, 1
167
+ # 0.12, 0.34, 0.56
168
+ #
169
+ num_recordings = len(output_files)
170
+ num_classes = len(classes)
171
+
172
+ # Use one-hot encoding for the outputs.
173
+ binary_outputs = np.zeros((num_recordings, num_classes), dtype=np.bool)
174
+ scalar_outputs = np.zeros((num_recordings, num_classes), dtype=np.float64)
175
+
176
+ # Iterate over the recordings.
177
+ for i in range(num_recordings):
178
+ recording_id, recording_classes, recording_binary_outputs, recording_scalar_outputs = load_outputs(output_files[i])
179
+
180
+ # Allow for equivalent classes and sanitize classifier outputs.
181
+ recording_classes = [set(entry.split('|')) for entry in recording_classes]
182
+ recording_binary_outputs = [1 if entry in ('1', 'True', 'true', 'T', 't') else 0 for entry in recording_binary_outputs]
183
+ recording_scalar_outputs = [float(entry) if is_finite_number(entry) else 0 for entry in recording_scalar_outputs]
184
+
185
+ # Allow for unordered/reordered and equivalent classes.
186
+ for j, x in enumerate(classes):
187
+ binary_values = list()
188
+ scalar_values = list()
189
+ for k, y in enumerate(recording_classes):
190
+ if x & y:
191
+ binary_values.append(recording_binary_outputs[k])
192
+ scalar_values.append(recording_scalar_outputs[k])
193
+ if binary_values:
194
+ binary_outputs[i, j] = any(binary_values) # Define a class as positive if any of the equivalent classes is positive.
195
+ if scalar_values:
196
+ scalar_outputs[i, j] = np.mean(scalar_values) # Define the scalar value of a class as the mean value of the scalar values across equivalent classes.
197
+
198
+ return binary_outputs, scalar_outputs
199
+
200
+ # Compute recording-wise accuracy.
201
+ def compute_accuracy(labels, outputs):
202
+ num_recordings, num_classes = np.shape(labels)
203
+
204
+ num_correct_recordings = 0
205
+ for i in range(num_recordings):
206
+ if np.all(labels[i, :]==outputs[i, :]):
207
+ num_correct_recordings += 1
208
+
209
+ return float(num_correct_recordings) / float(num_recordings)
210
+
211
+ # Compute confusion matrices.
212
+ def compute_confusion_matrices(labels, outputs, normalize=False):
213
+ # Compute a binary confusion matrix for each class k:
214
+ #
215
+ # [TN_k FN_k]
216
+ # [FP_k TP_k]
217
+ #
218
+ # If the normalize variable is set to true, then normalize the contributions
219
+ # to the confusion matrix by the number of labels per recording.
220
+ num_recordings, num_classes = np.shape(labels)
221
+
222
+ if not normalize:
223
+ A = np.zeros((num_classes, 2, 2))
224
+ for i in range(num_recordings):
225
+ for j in range(num_classes):
226
+ if labels[i, j]==1 and outputs[i, j]==1: # TP
227
+ A[j, 1, 1] += 1
228
+ elif labels[i, j]==0 and outputs[i, j]==1: # FP
229
+ A[j, 1, 0] += 1
230
+ elif labels[i, j]==1 and outputs[i, j]==0: # FN
231
+ A[j, 0, 1] += 1
232
+ elif labels[i, j]==0 and outputs[i, j]==0: # TN
233
+ A[j, 0, 0] += 1
234
+ else: # This condition should not happen.
235
+ raise ValueError('Error in computing the confusion matrix.')
236
+ else:
237
+ A = np.zeros((num_classes, 2, 2))
238
+ for i in range(num_recordings):
239
+ normalization = float(max(np.sum(labels[i, :]), 1))
240
+ for j in range(num_classes):
241
+ if labels[i, j]==1 and outputs[i, j]==1: # TP
242
+ A[j, 1, 1] += 1.0/normalization
243
+ elif labels[i, j]==0 and outputs[i, j]==1: # FP
244
+ A[j, 1, 0] += 1.0/normalization
245
+ elif labels[i, j]==1 and outputs[i, j]==0: # FN
246
+ A[j, 0, 1] += 1.0/normalization
247
+ elif labels[i, j]==0 and outputs[i, j]==0: # TN
248
+ A[j, 0, 0] += 1.0/normalization
249
+ else: # This condition should not happen.
250
+ raise ValueError('Error in computing the confusion matrix.')
251
+
252
+ return A
253
+
254
+ # Compute macro F-measure.
255
+ def compute_f_measure(labels, outputs):
256
+ num_recordings, num_classes = np.shape(labels)
257
+
258
+ A = compute_confusion_matrices(labels, outputs)
259
+
260
+ f_measure = np.zeros(num_classes)
261
+ for k in range(num_classes):
262
+ tp, fp, fn, tn = A[k, 1, 1], A[k, 1, 0], A[k, 0, 1], A[k, 0, 0]
263
+ if 2 * tp + fp + fn:
264
+ f_measure[k] = float(2 * tp) / float(2 * tp + fp + fn)
265
+ else:
266
+ f_measure[k] = float('nan')
267
+
268
+ if np.any(np.isfinite(f_measure)):
269
+ macro_f_measure = np.nanmean(f_measure)
270
+ else:
271
+ macro_f_measure = float('nan')
272
+
273
+ return macro_f_measure, f_measure
274
+
275
+ # Compute macro AUROC and macro AUPRC.
276
+ def compute_auc(labels, outputs):
277
+ num_recordings, num_classes = np.shape(labels)
278
+
279
+ # Compute and summarize the confusion matrices for each class across at distinct output values.
280
+ auroc = np.zeros(num_classes)
281
+ auprc = np.zeros(num_classes)
282
+
283
+ for k in range(num_classes):
284
+ # We only need to compute TPs, FPs, FNs, and TNs at distinct output values.
285
+ thresholds = np.unique(outputs[:, k])
286
+ thresholds = np.append(thresholds, thresholds[-1]+1)
287
+ thresholds = thresholds[::-1]
288
+ num_thresholds = len(thresholds)
289
+
290
+ # Initialize the TPs, FPs, FNs, and TNs.
291
+ tp = np.zeros(num_thresholds)
292
+ fp = np.zeros(num_thresholds)
293
+ fn = np.zeros(num_thresholds)
294
+ tn = np.zeros(num_thresholds)
295
+ fn[0] = np.sum(labels[:, k]==1)
296
+ tn[0] = np.sum(labels[:, k]==0)
297
+
298
+ # Find the indices that result in sorted output values.
299
+ idx = np.argsort(outputs[:, k])[::-1]
300
+
301
+ # Compute the TPs, FPs, FNs, and TNs for class k across thresholds.
302
+ i = 0
303
+ for j in range(1, num_thresholds):
304
+ # Initialize TPs, FPs, FNs, and TNs using values at previous threshold.
305
+ tp[j] = tp[j-1]
306
+ fp[j] = fp[j-1]
307
+ fn[j] = fn[j-1]
308
+ tn[j] = tn[j-1]
309
+
310
+ # Update the TPs, FPs, FNs, and TNs at i-th output value.
311
+ while i < num_recordings and outputs[idx[i], k] >= thresholds[j]:
312
+ if labels[idx[i], k]:
313
+ tp[j] += 1
314
+ fn[j] -= 1
315
+ else:
316
+ fp[j] += 1
317
+ tn[j] -= 1
318
+ i += 1
319
+
320
+ # Summarize the TPs, FPs, FNs, and TNs for class k.
321
+ tpr = np.zeros(num_thresholds)
322
+ tnr = np.zeros(num_thresholds)
323
+ ppv = np.zeros(num_thresholds)
324
+ for j in range(num_thresholds):
325
+ if tp[j] + fn[j]:
326
+ tpr[j] = float(tp[j]) / float(tp[j] + fn[j])
327
+ else:
328
+ tpr[j] = float('nan')
329
+ if fp[j] + tn[j]:
330
+ tnr[j] = float(tn[j]) / float(fp[j] + tn[j])
331
+ else:
332
+ tnr[j] = float('nan')
333
+ if tp[j] + fp[j]:
334
+ ppv[j] = float(tp[j]) / float(tp[j] + fp[j])
335
+ else:
336
+ ppv[j] = float('nan')
337
+
338
+ # Compute AUROC as the area under a piecewise linear function with TPR/
339
+ # sensitivity (x-axis) and TNR/specificity (y-axis) and AUPRC as the area
340
+ # under a piecewise constant with TPR/recall (x-axis) and PPV/precision
341
+ # (y-axis) for class k.
342
+ for j in range(num_thresholds-1):
343
+ auroc[k] += 0.5 * (tpr[j+1] - tpr[j]) * (tnr[j+1] + tnr[j])
344
+ auprc[k] += (tpr[j+1] - tpr[j]) * ppv[j+1]
345
+
346
+ # Compute macro AUROC and macro AUPRC across classes.
347
+ if np.any(np.isfinite(auroc)):
348
+ macro_auroc = np.nanmean(auroc)
349
+ else:
350
+ macro_auroc = float('nan')
351
+ if np.any(np.isfinite(auprc)):
352
+ macro_auprc = np.nanmean(auprc)
353
+ else:
354
+ macro_auprc = float('nan')
355
+
356
+ return macro_auroc, macro_auprc, auroc, auprc
357
+
358
+ # Compute a modified confusion matrix for multi-class, multi-label tasks.
359
+ def compute_modified_confusion_matrix(labels, outputs):
360
+ # Compute a binary multi-class, multi-label confusion matrix, where the rows
361
+ # are the labels and the columns are the outputs.
362
+ num_recordings, num_classes = np.shape(labels)
363
+ A = np.zeros((num_classes, num_classes))
364
+
365
+ # Iterate over all of the recordings.
366
+ for i in range(num_recordings):
367
+ # Calculate the number of positive labels and/or outputs.
368
+ normalization = float(max(np.sum(np.any((labels[i, :], outputs[i, :]), axis=0)), 1))
369
+ # Iterate over all of the classes.
370
+ for j in range(num_classes):
371
+ # Assign full and/or partial credit for each positive class.
372
+ if labels[i, j]:
373
+ for k in range(num_classes):
374
+ if outputs[i, k]:
375
+ A[j, k] += 1.0/normalization
376
+
377
+ return A
378
+
379
+ # Compute the evaluation metric for the Challenge.
380
+ def compute_challenge_metric(weights, labels, outputs, classes, sinus_rhythm):
381
+ num_recordings, num_classes = np.shape(labels)
382
+ if sinus_rhythm in classes:
383
+ sinus_rhythm_index = classes.index(sinus_rhythm)
384
+ else:
385
+ raise ValueError('The sinus rhythm class is not available.')
386
+
387
+ # Compute the observed score.
388
+ A = compute_modified_confusion_matrix(labels, outputs)
389
+ observed_score = np.nansum(weights * A)
390
+
391
+ # Compute the score for the model that always chooses the correct label(s).
392
+ correct_outputs = labels
393
+ A = compute_modified_confusion_matrix(labels, correct_outputs)
394
+ correct_score = np.nansum(weights * A)
395
+
396
+ # Compute the score for the model that always chooses the sinus rhythm class.
397
+ inactive_outputs = np.zeros((num_recordings, num_classes), dtype=np.bool_)
398
+ inactive_outputs[:, sinus_rhythm_index] = 1
399
+ A = compute_modified_confusion_matrix(labels, inactive_outputs)
400
+ inactive_score = np.nansum(weights * A)
401
+
402
+ if correct_score != inactive_score:
403
+ normalized_score = float(observed_score - inactive_score) / float(correct_score - inactive_score)
404
+ else:
405
+ normalized_score = 0.0
406
+
407
+ return normalized_score
408
+
409
+ if __name__ == '__main__':
410
+ classes, auroc, auprc, auroc_classes, auprc_classes, accuracy, f_measure, f_measure_classes, challenge_metric = evaluate_model(sys.argv[1], sys.argv[2])
411
+ output_string = 'AUROC,AUPRC,Accuracy,F-measure,Challenge metric\n{:.3f},{:.3f},{:.3f},{:.3f},{:.3f}'.format(auroc, auprc, accuracy, f_measure, challenge_metric)
412
+ class_output_string = 'Classes,{}\nAUROC,{}\nAUPRC,{}\nF-measure,{}'.format(
413
+ ','.join('|'.join(sorted(x)) for x in classes),
414
+ ','.join('{:.3f}'.format(x) for x in auroc_classes),
415
+ ','.join('{:.3f}'.format(x) for x in auprc_classes),
416
+ ','.join('{:.3f}'.format(x) for x in f_measure_classes))
417
+
418
+ print(output_string)
419
+ with open('test_outputs/output.txt', 'w') as f:
420
+ f.write(output_string)
421
+
422
+
423
+ df = pd.DataFrame({'classes':classes,
424
+ 'auroc_classes':auroc_classes,
425
+ 'auprc_classes':auprc_classes,
426
+ 'f_measure_classes':f_measure_classes})
427
+ # dxname = pd.read_csv('dx_mapping_scored.csv')
428
+ # dxname = dxname[['Dx','SNOMED CT Code']]
429
+ # dxname = dxname.rename(columns={'SNOMED CT Code':'classes'})
430
+ # df = df.merge(dxname,on='classes',how='left')
431
+ print(tabulate(df, headers='keys', tablefmt='psql'))
432
+
433
+ with open('test_outputs/table.txt', 'w') as f:
434
+ f.write(tabulate(df, headers='keys', tablefmt='psql'))
helper.ipynb ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "#### uniform stochastic depth"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": 4,
13
+ "metadata": {},
14
+ "outputs": [
15
+ {
16
+ "name": "stdout",
17
+ "output_type": "stream",
18
+ "text": [
19
+ "24\n",
20
+ "25\n",
21
+ "[0.0, 0.004347826354205608, 0.008695652708411217, 0.013043479062616825, 0.017391305416822433, 0.021739132702350616, 0.02608695812523365, 0.030434783548116684, 0.03478261083364487, 0.03913043811917305, 0.04347826540470123, 0.04782608896493912, 0.052173912525177, 0.056521736085414886, 0.06086956337094307, 0.06521739065647125, 0.06956521421670914, 0.07391304522752762, 0.0782608687877655, 0.08260869979858398, 0.08695652335882187, 0.09130434691905975, 0.09565217792987823, 0.10000000149011612]\n"
22
+ ]
23
+ }
24
+ ],
25
+ "source": [
26
+ "import torch\n",
27
+ "\n",
28
+ "drop_path_rate = 0.1\n",
29
+ "depth = 24\n",
30
+ "\n",
31
+ "dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]\n",
32
+ "\n",
33
+ "print(len(dpr))\n",
34
+ "inter_dpr = [0.0] + dpr\n",
35
+ "print(len(inter_dpr))\n",
36
+ "print(dpr)"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": 1,
42
+ "metadata": {},
43
+ "outputs": [
44
+ {
45
+ "name": "stdout",
46
+ "output_type": "stream",
47
+ "text": [
48
+ "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
49
+ "24\n"
50
+ ]
51
+ }
52
+ ],
53
+ "source": [
54
+ "import torch\n",
55
+ "\n",
56
+ "drop_path_rate = 0\n",
57
+ "depth = 24\n",
58
+ "\n",
59
+ "dpr = [x.item() for x in torch.full((depth,), drop_path_rate)]\n",
60
+ "print(dpr)\n",
61
+ "print(len(dpr))"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": 3,
67
+ "metadata": {},
68
+ "outputs": [
69
+ {
70
+ "name": "stdout",
71
+ "output_type": "stream",
72
+ "text": [
73
+ "tensor([[[ 0, 1, 2],\n",
74
+ " [ 3, 4, 5],\n",
75
+ " [ 6, 7, 8],\n",
76
+ " [ 9, 10, 11],\n",
77
+ " [12, 13, 14]],\n",
78
+ "\n",
79
+ " [[15, 16, 17],\n",
80
+ " [18, 19, 20],\n",
81
+ " [21, 22, 23],\n",
82
+ " [24, 25, 26],\n",
83
+ " [27, 28, 29]]])\n",
84
+ "-------------------------\n",
85
+ "tensor([[[ 2, 1, 0],\n",
86
+ " [ 5, 4, 3],\n",
87
+ " [ 8, 7, 6],\n",
88
+ " [11, 10, 9],\n",
89
+ " [14, 13, 12]],\n",
90
+ "\n",
91
+ " [[17, 16, 15],\n",
92
+ " [20, 19, 18],\n",
93
+ " [23, 22, 21],\n",
94
+ " [26, 25, 24],\n",
95
+ " [29, 28, 27]]])\n",
96
+ "-------------------------\n",
97
+ "tensor([[[12, 13, 14],\n",
98
+ " [ 9, 10, 11],\n",
99
+ " [ 6, 7, 8],\n",
100
+ " [ 3, 4, 5],\n",
101
+ " [ 0, 1, 2]],\n",
102
+ "\n",
103
+ " [[27, 28, 29],\n",
104
+ " [24, 25, 26],\n",
105
+ " [21, 22, 23],\n",
106
+ " [18, 19, 20],\n",
107
+ " [15, 16, 17]]])\n"
108
+ ]
109
+ }
110
+ ],
111
+ "source": [
112
+ "import torch\n",
113
+ "\n",
114
+ "x = torch.arange(30).view(2, 5, 3)\n",
115
+ "print(x)\n",
116
+ "print(\"-------------------------\")\n",
117
+ "\n",
118
+ "y = x.flip([-1])\n",
119
+ "print(y)\n",
120
+ "print(\"-------------------------\")\n",
121
+ "\n",
122
+ "z = x.flip([1])\n",
123
+ "print(z)"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": 2,
129
+ "metadata": {},
130
+ "outputs": [
131
+ {
132
+ "name": "stdout",
133
+ "output_type": "stream",
134
+ "text": [
135
+ "[0, 2, 3, 1]\n",
136
+ "This is the i: 0\n",
137
+ "This is the idx: 0\n",
138
+ "This is the i: 1\n",
139
+ "This is the idx: 2\n",
140
+ "This is the i: 2\n",
141
+ "This is the idx: 3\n",
142
+ "This is the i: 3\n",
143
+ "This is the idx: 1\n"
144
+ ]
145
+ }
146
+ ],
147
+ "source": [
148
+ "import random\n",
149
+ "\n",
150
+ "indices = [0, 1, 2, 3]\n",
151
+ "random.shuffle(indices)\n",
152
+ "\n",
153
+ "print(indices)\n",
154
+ "\n",
155
+ "for i, idx in enumerate(indices):\n",
156
+ " print(\"This is the i:\", i)\n",
157
+ " print(\"This is the idx:\", idx)"
158
+ ]
159
+ }
160
+ ],
161
+ "metadata": {
162
+ "kernelspec": {
163
+ "display_name": "mamba",
164
+ "language": "python",
165
+ "name": "python3"
166
+ },
167
+ "language_info": {
168
+ "codemirror_mode": {
169
+ "name": "ipython",
170
+ "version": 3
171
+ },
172
+ "file_extension": ".py",
173
+ "mimetype": "text/x-python",
174
+ "name": "python",
175
+ "nbconvert_exporter": "python",
176
+ "pygments_lexer": "ipython3",
177
+ "version": "3.10.14"
178
+ }
179
+ },
180
+ "nbformat": 4,
181
+ "nbformat_minor": 2
182
+ }
helper_code.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Do *not* edit this script.
4
+ # These are helper functions that you can use with your code.
5
+
6
+ import os, numpy as np
7
+
8
+ # Check if a variable is a number or represents a number.
9
+ def is_number(x):
10
+ try:
11
+ float(x)
12
+ return True
13
+ except (ValueError, TypeError):
14
+ return False
15
+
16
+ # Check if a variable is an integer or represents an integer.
17
+ def is_integer(x):
18
+ if is_number(x):
19
+ return float(x).is_integer()
20
+ else:
21
+ return False
22
+
23
+ # Check if a variable is a a finite number or represents a finite number.
24
+ def is_finite_number(x):
25
+ if is_number(x):
26
+ return np.isfinite(float(x))
27
+ else:
28
+ return False
29
+
30
+ # (Re)sort leads using the standard order of leads for the standard twelve-lead ECG.
31
+ def sort_leads(leads):
32
+ x = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
33
+ leads = sorted(leads, key=lambda lead: (x.index(lead) if lead in x else len(x) + leads.index(lead)))
34
+ return tuple(leads)
35
+
36
+ # Find header and recording files.
37
+ def find_challenge_files(data_directory):
38
+ header_files = list()
39
+ recording_files = list()
40
+ for f in os.listdir(data_directory):
41
+ root, extension = os.path.splitext(f)
42
+ if not root.startswith('.') and extension=='.hea':
43
+ header_file = os.path.join(data_directory, root + '.hea')
44
+ recording_file = os.path.join(data_directory, root + '.mat')
45
+ if os.path.isfile(header_file) and os.path.isfile(recording_file):
46
+ header_files.append(header_file)
47
+ recording_files.append(recording_file)
48
+ return header_files, recording_files
49
+
50
+ # Load header file as a string.
51
+ def load_header(header_file):
52
+ with open(header_file, 'r') as f:
53
+ header = f.read()
54
+ return header
55
+
56
+ # Load recording file as an array.
57
+ def load_recording(recording_file, header=None, leads=None, key='val'):
58
+ from scipy.io import loadmat
59
+ recording = loadmat(recording_file)[key]
60
+ if header and leads:
61
+ recording = choose_leads(recording, header, leads)
62
+ return recording
63
+
64
+ # Choose leads from the recording file.
65
+ def choose_leads(recording, header, leads):
66
+ num_leads = len(leads)
67
+ num_samples = np.shape(recording)[1]
68
+ chosen_recording = np.zeros((num_leads, num_samples), recording.dtype)
69
+ available_leads = get_leads(header)
70
+ for i, lead in enumerate(leads):
71
+ if lead in available_leads:
72
+ j = available_leads.index(lead)
73
+ chosen_recording[i, :] = recording[j, :]
74
+ return chosen_recording
75
+
76
+ # Get recording ID.
77
+ def get_recording_id(header):
78
+ recording_id = None
79
+ for i, l in enumerate(header.split('\n')):
80
+ if i==0:
81
+ try:
82
+ recording_id = l.split(' ')[0]
83
+ except:
84
+ pass
85
+ else:
86
+ break
87
+ return recording_id
88
+
89
+ # Get leads from header.
90
+ def get_leads(header):
91
+ leads = list()
92
+ for i, l in enumerate(header.split('\n')):
93
+ entries = l.split(' ')
94
+ if i==0:
95
+ num_leads = int(entries[1])
96
+ elif i<=num_leads:
97
+ leads.append(entries[-1])
98
+ else:
99
+ break
100
+ return tuple(leads)
101
+
102
+ # Get age from header.
103
+ def get_age(header):
104
+ age = None
105
+ for l in header.split('\n'):
106
+ if l.startswith('# Age'):
107
+ try:
108
+ age = float(l.split(': ')[1].strip())
109
+ except:
110
+ age = float('nan')
111
+ return age
112
+
113
+ # Get sex from header.
114
+ def get_sex(header):
115
+ sex = None
116
+ for l in header.split('\n'):
117
+ if l.startswith('# Sex'):
118
+ try:
119
+ sex = l.split(': ')[1].strip()
120
+ except:
121
+ pass
122
+ return sex
123
+
124
+ # Get frequency from header.
125
+ def get_num_leads(header):
126
+ num_leads = None
127
+ for i, l in enumerate(header.split('\n')):
128
+ if i==0:
129
+ try:
130
+ num_samples = float(l.split(' ')[1])
131
+ except:
132
+ pass
133
+ else:
134
+ break
135
+ return num_leads
136
+
137
+ # Get frequency from header.
138
+ def get_frequency(header):
139
+ frequency = None
140
+ for i, l in enumerate(header.split('\n')):
141
+ if i==0:
142
+ try:
143
+ frequency = float(l.split(' ')[2])
144
+ except:
145
+ pass
146
+ else:
147
+ break
148
+ return frequency
149
+
150
+ # Get number of samples from header.
151
+ def get_num_samples(header):
152
+ num_samples = None
153
+ for i, l in enumerate(header.split('\n')):
154
+ if i==0:
155
+ try:
156
+ num_samples = float(l.split(' ')[3])
157
+ except:
158
+ pass
159
+ else:
160
+ break
161
+ return num_samples
162
+
163
+ # Get analog-to-digital converter (ADC) gains from header.
164
+ def get_adc_gains(header, leads):
165
+ adc_gains = np.zeros(len(leads))
166
+ for i, l in enumerate(header.split('\n')):
167
+ entries = l.split(' ')
168
+ if i==0:
169
+ num_leads = int(entries[1])
170
+ elif i<=num_leads:
171
+ current_lead = entries[-1]
172
+ if current_lead in leads:
173
+ j = leads.index(current_lead)
174
+ try:
175
+ adc_gains[j] = float(entries[2].split('/')[0])
176
+ except:
177
+ pass
178
+ else:
179
+ break
180
+ return adc_gains
181
+
182
+ # Get baselines from header.
183
+ def get_baselines(header, leads):
184
+ baselines = np.zeros(len(leads))
185
+ for i, l in enumerate(header.split('\n')):
186
+ entries = l.split(' ')
187
+ if i==0:
188
+ num_leads = int(entries[1])
189
+ elif i<=num_leads:
190
+ current_lead = entries[-1]
191
+ if current_lead in leads:
192
+ j = leads.index(current_lead)
193
+ try:
194
+ baselines[j] = float(entries[4].split('/')[0])
195
+ except:
196
+ pass
197
+ else:
198
+ break
199
+ return baselines
200
+
201
+ # Get labels from header.
202
+ def get_labels(header):
203
+ labels = list()
204
+ for l in header.split('\n'):
205
+ if l.startswith('# Dx'):
206
+ try:
207
+ entries = l.split(': ')[1].split(',')
208
+ for entry in entries:
209
+ labels.append(entry.strip())
210
+ except:
211
+ pass
212
+ return labels
213
+
214
+ # Save outputs from model.
215
+ def save_outputs(output_file, recording_id, classes, labels, probabilities):
216
+ # Format the model outputs.
217
+ recording_string = '#{}'.format(recording_id)
218
+ class_string = ','.join(str(c) for c in classes)
219
+ label_string = ','.join(str(l) for l in labels)
220
+ probabilities_string = ','.join(str(p) for p in probabilities)
221
+ output_string = recording_string + '\n' + class_string + '\n' + label_string + '\n' + probabilities_string + '\n'
222
+
223
+ # Save the model outputs.
224
+ with open(output_file, 'w') as f:
225
+ f.write(output_string)
226
+
227
+ # Load outputs from model.
228
+ def load_outputs(output_file):
229
+ with open(output_file, 'r') as f:
230
+ for i, l in enumerate(f):
231
+ if i==0:
232
+ recording_id = l[1:] if len(l)>1 else None
233
+ elif i==1:
234
+ classes = tuple(entry.strip() for entry in l.split(','))
235
+ elif i==2:
236
+ labels = tuple(entry.strip() for entry in l.split(','))
237
+ elif i==3:
238
+ probabilities = tuple(float(entry) if is_finite_number(entry) else float('nan') for entry in l.split(','))
239
+ else:
240
+ break
241
+ return recording_id, classes, labels, probabilities
losses.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2015-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ """
4
+ Implements the knowledge distillation loss
5
+ """
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+
10
+ class DistillationLoss(torch.nn.Module):
11
+ """
12
+ This module wraps a standard criterion and adds an extra knowledge distillation loss by
13
+ taking a teacher model prediction and using it as additional supervision.
14
+ """
15
+ def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
16
+ distillation_type: str, alpha: float, tau: float):
17
+ super().__init__()
18
+ self.base_criterion = base_criterion
19
+ self.teacher_model = teacher_model
20
+ assert distillation_type in ['none', 'soft', 'hard']
21
+ self.distillation_type = distillation_type
22
+ self.alpha = alpha
23
+ self.tau = tau
24
+
25
+ def forward(self, inputs, outputs, labels):
26
+ """
27
+ Args:
28
+ inputs: The original inputs that are feed to the teacher model
29
+ outputs: the outputs of the model to be trained. It is expected to be
30
+ either a Tensor, or a Tuple[Tensor, Tensor], with the original output
31
+ in the first position and the distillation predictions as the second output
32
+ labels: the labels for the base criterion
33
+ """
34
+ outputs_kd = None
35
+ if not isinstance(outputs, torch.Tensor):
36
+ # assume that the model outputs a tuple of [outputs, outputs_kd]
37
+ outputs, outputs_kd = outputs
38
+ base_loss = self.base_criterion(outputs, labels)
39
+ if self.distillation_type == 'none':
40
+ return base_loss
41
+
42
+ if outputs_kd is None:
43
+ raise ValueError("When knowledge distillation is enabled, the model is "
44
+ "expected to return a Tuple[Tensor, Tensor] with the output of the "
45
+ "class_token and the dist_token")
46
+ # don't backprop throught the teacher
47
+ with torch.no_grad():
48
+ teacher_outputs = self.teacher_model(inputs)
49
+
50
+ if self.distillation_type == 'soft':
51
+ T = self.tau
52
+ # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
53
+ # with slight modifications
54
+ distillation_loss = F.kl_div(
55
+ F.log_softmax(outputs_kd / T, dim=1),
56
+ #We provide the teacher's targets in log probability because we use log_target=True
57
+ #(as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719)
58
+ #but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both.
59
+ F.log_softmax(teacher_outputs / T, dim=1),
60
+ reduction='sum',
61
+ log_target=True
62
+ ) * (T * T) / outputs_kd.numel()
63
+ #We divide by outputs_kd.numel() to have the legacy PyTorch behavior.
64
+ #But we also experiments output_kd.size(0)
65
+ #see issue 61(https://github.com/facebookresearch/deit/issues/61) for more details
66
+ elif self.distillation_type == 'hard':
67
+ distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
68
+
69
+ loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
70
+ return loss
main_ecg.py ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+ import utils
4
+ import torch
5
+ import numpy as np
6
+ import torch.backends.cudnn as cudnn
7
+ import pandas as pd
8
+ import time
9
+ import datetime
10
+ import os
11
+
12
+ # mixup
13
+ from timm.data import Mixup
14
+
15
+
16
+ # log about
17
+ # import mlflow
18
+
19
+ # for the challenge 2021
20
+ import ecg_dataset_2021
21
+ import engine_ecg_2021
22
+
23
+ # for the challenge 2020
24
+ import ecg_dataset_2020
25
+ import engine_ecg_2020
26
+
27
+ from engine_ecg_2021 import train_one_epoch, evaluate
28
+
29
+ # for the timm
30
+ from timm.models import create_model
31
+ import timm
32
+ from timm.scheduler.cosine_lr import CosineLRScheduler
33
+
34
+ import sys
35
+
36
+ from optimizer import NoamOpt
37
+
38
+ import models_mamba_ecg
39
+
40
+ def get_args_parser():
41
+ parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
42
+ parser.add_argument('--batch-size', default=64, type=int)
43
+ parser.add_argument('--epochs', default=300, type=int)
44
+
45
+ # Model parameters
46
+ parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL', help='Name of model to train')
47
+
48
+ # Parameters regarding the Drop
49
+ parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', help='Dropout rate (default: 0.)')
50
+ parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', help='Drop path rate (default: 0.1)')
51
+
52
+ # setting the mode
53
+ parser.add_argument('--train-mode', action='store_true')
54
+ parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
55
+ parser.set_defaults(train_mode=True)
56
+
57
+
58
+ # * data augmentation method
59
+ parser.add_argument('--mixup', type=float, default=0,
60
+ help='mixup alpha, mixup enabled if > 0. (default: 0)')
61
+ parser.add_argument('--cutmix', type=float, default=0,
62
+ help='cutmix alpha, cutmix enabled if > 0. (default: 0)')
63
+ parser.add_argument('--mixup_no_label', type=float, default=0,
64
+ help='this is the mixup but without interpolation of label. (default: 0)')
65
+ parser.add_argument('--progressive_switch', type=bool, default=False,
66
+ help='this is the switch of progressive data augmentation method')
67
+
68
+ # the output directory
69
+ parser.add_argument('--output_dir', default='', help='path where to save, empty for no saving')
70
+
71
+ # the device information
72
+ parser.add_argument('--device', default='cuda', help='device to use for training / testing')
73
+
74
+ parser.add_argument('--seed', default=0, type=int)
75
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
76
+
77
+
78
+ parser.add_argument('--num_workers', default=10, type=int)
79
+
80
+ parser.add_argument('--pin-mem', action='store_true', help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
81
+ parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', help='')
82
+ parser.set_defaults(pin_mem=True)
83
+
84
+ # special parameters for this strategy
85
+ parser.add_argument('--depth', default=24, type=int)
86
+ parser.add_argument('--block', default='VisionMamba', choices=['OriginalMamba', 'VisionMamba'], type=str, help='the selection of the block of the paradigm')
87
+ parser.add_argument('--lead', default='12Lead', choices=['12Lead', 'RandomLead'], type=str, help='the number of Lead')
88
+ parser.add_argument('--lrschedule', default='CosineAnnealing', choices=['Noam', 'CosineAnnealing'], type=str, help='the selection of learning rate strategy')
89
+
90
+ # if use random token position
91
+ parser.add_argument('--if_random_cls_token_position', action='store_true')
92
+ parser.add_argument('--no_randargs.if_nan2numom_cls_token_position', action='store_false', dest='if_random_cls_token_position')
93
+ parser.set_defaults(if_random_cls_token_position=False)
94
+
95
+ # if use random token rank
96
+ parser.add_argument('--if_random_token_rank', action='store_true')
97
+ parser.add_argument('--no_random_token_rank', action='store_false', dest='if_random_token_rank')
98
+ parser.set_defaults(if_random_token_rank=False)
99
+
100
+ # using for the journal
101
+ parser.add_argument('--fused_add_norm', type=bool, default=True, help='combines the element-wise addition (from residual connections) and normalization (e.g., RMSNorm)')
102
+ parser.add_argument('--if_divide_out', type=bool, default=True, help='Should the result be divided by two when combining outputs from two directions?')
103
+ parser.add_argument('--use_middle_cls_token', type=bool, default=True, help='Whether the class is inserted in the middle of sequence')
104
+
105
+ # the switch of scenario
106
+ parser.add_argument('--challenge_scenario', default='2021', choices=[2021, 2020], type=int, help='the scenario of Challenge')
107
+
108
+
109
+ return parser
110
+
111
+
112
+ def collate(batch):
113
+ # Left-zero padding
114
+ ch = batch[0][0].shape[0]
115
+ maxL = 8192
116
+ X = np.zeros((len(batch), ch, maxL))
117
+
118
+ for i in range(len(batch)):
119
+ X[i, :, -batch[i][0].shape[-1]:] = batch[i][0]
120
+
121
+ t = np.array([b[1] for b in batch])
122
+
123
+ X = torch.from_numpy(X)
124
+ t = torch.from_numpy(t)
125
+
126
+ return X, t
127
+
128
+ def main(args, data_directory, model_directory, group_number):
129
+
130
+ train_log_fp = open(args.output_dir + '/train_log_group_%d.txt' % group_number, 'a')
131
+ print(args)
132
+ train_log_fp.write("The is the configuration: {}\n".format(args))
133
+
134
+
135
+ device = torch.device(args.device)
136
+
137
+ print("This is the running device", device)
138
+ train_log_fp.write("This is the running device:{}\n".format(device))
139
+
140
+ seed = args.seed + utils.get_rank()
141
+ torch.manual_seed(seed)
142
+ np.random.seed(seed)
143
+
144
+ print("This is the random's seed:", seed)
145
+ train_log_fp.write("This is the random's seed:{}\n".format(seed))
146
+
147
+
148
+ cudnn.benchmark = True
149
+
150
+
151
+ dataset_file_all_address = "../../collection_of_all_datasets/"
152
+ # below is building dataset(train and validation)
153
+
154
+ print("This is building the training set:")
155
+ ############ training area #########################
156
+ the_training_address = data_directory + "/training_group" + data_directory[-1] + ".csv"
157
+ df = pd.read_csv(the_training_address)
158
+ print("Total the {} files will be fed into model for training".format(len(df['Name'])))
159
+ train_log_fp.write("Total the {} files will be fed into model for training \n".format(len(df['Name'])))
160
+
161
+ print("This is the first file of training:",df['Name'][0])
162
+
163
+
164
+ training_header_files=[]
165
+
166
+ for i in range(len(df['Name'])):
167
+ each_header_file = dataset_file_all_address + df['Name'][i]
168
+ training_header_files.append(each_header_file)
169
+
170
+ collate_training = collate
171
+
172
+
173
+ if args.mixup > 0:
174
+ print("The mixup is using, and the percentage is:", args.mixup)
175
+ train_log_fp.write("The mixup is using, and the percentage is:{}\n".format(args.mixup))
176
+ # collate_training = collate_mixup
177
+
178
+ if args.cutmix > 0:
179
+ print("The cutmix is using, and the percentage is:", args.cutmix)
180
+ train_log_fp.write("The cutmix is using, and the percentage is:{}\n".format(args.cutmix))
181
+ if args.mixup_no_label > 0:
182
+ print("The mixup is using without label's interpolation , and the percentage is:", args.mixup_no_label)
183
+ train_log_fp.write("The mixup is using without label's interpolation , and the percentage is:{}\n".format(args.mixup_no_label))
184
+ if args.progressive_switch:
185
+ print("The progressive mixup is using.")
186
+ train_log_fp.write("The progressive mixup is using.")
187
+
188
+ if args.challenge_scenario == 2021:
189
+ train_dataset = ecg_dataset_2021.dataset(training_header_files, Mixup = args.mixup, amount = len(df['Name']), cutMix=args.cutmix, Mixup_no_label_interpolate=args.mixup_no_label, progressive_switch=args.progressive_switch)
190
+ elif args.challenge_scenario == 2020:
191
+ train_dataset = ecg_dataset_2020.dataset(training_header_files, Mixup = args.mixup, amount = len(df['Name']), cutMix=args.cutmix, Mixup_no_label_interpolate=args.mixup_no_label, progressive_switch=args.progressive_switch)
192
+ else:
193
+ raise ValueError(f"No matching condition for value: {args.challenge_scenario}")
194
+
195
+
196
+ if args.lead == "12Lead":
197
+ lead_number = 12
198
+ print("This 12 lead is using.")
199
+ train_log_fp.write("This 12 lead is using.")
200
+ elif args.lead == "RandomLead":
201
+ lead_number = None
202
+ print("This random lead is using.")
203
+ train_log_fp.write("This random lead is using.")
204
+ else:
205
+ raise ValueError(f"No matching condition for value: {args.lead}")
206
+
207
+ """
208
+ we filter out the sample which the length is 8192 via the below code.
209
+ just like the random shift windows.
210
+ """
211
+ train_dataset.num_leads = lead_number
212
+ train_dataset.sample = True
213
+ ###################################################
214
+ print("done")
215
+
216
+
217
+ print("This is building the validation set:")
218
+ ############ testing area #########################
219
+ the_testing_address = data_directory + "/testing_group" + data_directory[-1]+".csv"
220
+ df = pd.read_csv(the_testing_address)
221
+ print("Total the {} files will be used as testing".format(len(df['Name'])))
222
+ train_log_fp.write("Total the {} files will be used as testing \n".format(len(df['Name'])))
223
+
224
+ print("This is the first file of testing:",df['Name'][0])
225
+
226
+ testing_header_files=[]
227
+
228
+ for i in range(len(df['Name'])):
229
+ each_header_file = dataset_file_all_address + df['Name'][i]
230
+ testing_header_files.append(each_header_file)
231
+
232
+
233
+ if args.challenge_scenario == 2021:
234
+ test_dataset = ecg_dataset_2021.dataset(testing_header_files, Mixup = 0)
235
+ elif args.challenge_scenario == 2020:
236
+ test_dataset = ecg_dataset_2020.dataset(testing_header_files, Mixup = 0)
237
+ else:
238
+ raise ValueError(f"No matching condition for value: {args.challenge_scenario}")
239
+
240
+
241
+ test_dataset.num_leads = 12
242
+ test_dataset.sample = True
243
+ ###################################################
244
+ print("done")
245
+
246
+ sampler_train = torch.utils.data.RandomSampler(train_dataset)
247
+ sampler_val = torch.utils.data.SequentialSampler(test_dataset)
248
+
249
+ data_loader_train = torch.utils.data.DataLoader(
250
+ train_dataset, sampler=sampler_train,
251
+ batch_size=args.batch_size,
252
+ collate_fn=collate_training,
253
+ num_workers=args.num_workers,
254
+ pin_memory=args.pin_mem,
255
+ drop_last=True,
256
+ )
257
+ # Setting `drop_last=True` means that this incomplete batch will be dropped,
258
+ # ensuring that all batches fed to the model during training have **the same size.**
259
+
260
+ data_loader_val = torch.utils.data.DataLoader(
261
+ test_dataset, sampler=sampler_val,
262
+ batch_size=int(1.5 * args.batch_size),
263
+ collate_fn=collate,
264
+ num_workers=args.num_workers,
265
+ pin_memory=args.pin_mem,
266
+ drop_last=False
267
+ )
268
+
269
+ if args.challenge_scenario == 2021:
270
+ args.nb_classes = 26
271
+ elif args.challenge_scenario == 2020:
272
+ args.nb_classes = 27
273
+ else:
274
+ raise ValueError(f"No matching condition for value: {args.challenge_scenario}")
275
+
276
+
277
+ print(f"Creating model: {args.model}")
278
+ train_log_fp.write("Creating model:{}\n".format(args.model))
279
+ model = create_model(
280
+ args.model,
281
+ pretrained=False,
282
+ num_classes=args.nb_classes,
283
+ drop_rate=args.drop,
284
+ drop_path_rate=args.drop_path,
285
+ drop_block_rate=None,
286
+ block = args.block,
287
+ depth = args.depth,
288
+ fused_add_norm = args.fused_add_norm,
289
+ use_middle_cls_token = args.use_middle_cls_token,
290
+ img_size=8192
291
+ )
292
+
293
+
294
+ model.to(device)
295
+
296
+
297
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
298
+
299
+ print('number of params:', n_parameters)
300
+ train_log_fp.write("number of params:{}\n".format(n_parameters))
301
+
302
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.0006, betas=(0.9, 0.98), eps=1e-9)
303
+
304
+ if args.lrschedule == "Noam":
305
+ optimizer = NoamOpt(729, 1, 4000, torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
306
+ print("The Noam is using as learning rate strategy.")
307
+ train_log_fp.write("The Noam is using as learning rate strategy.")
308
+
309
+ elif args.lrschedule == "CosineAnnealing":
310
+ lr_scheduler = CosineLRScheduler(
311
+ optimizer,
312
+ t_initial=13, # Number of epochs after warmup
313
+ lr_min=1e-6,
314
+ warmup_lr_init = 1e-5,
315
+ warmup_t=5,
316
+ cycle_limit=1,
317
+ t_in_epochs=True,
318
+ warmup_prefix=True # Added to reach initial_lr
319
+ )
320
+ print("The cosing annealing is using as learning rate strategy.")
321
+ train_log_fp.write("The cosing annealing is using as learning rate strategy.")
322
+ else:
323
+ raise ValueError(f"No matching condition for value: {args.lrschedule}")
324
+
325
+ # below is for the cosine annealing schedule
326
+
327
+ # the loss is set as normal BCE loss.
328
+ criterion = torch.nn.BCEWithLogitsLoss()
329
+
330
+
331
+ output_dir = Path(args.output_dir)
332
+
333
+ print(f"Start training for {args.epochs} epochs")
334
+ train_log_fp.write(f"Start training for {args.epochs} epochs\n")
335
+
336
+ start_time = time.time()
337
+
338
+ max_AUPRC = 0.0
339
+
340
+
341
+ train_auprc_list = []
342
+ testing_AUPRC_list = []
343
+ loss_value_after_each_epcoh_list = []
344
+ loss_value_testing_list = []
345
+
346
+ testing_auroc_list = []
347
+ testing_f1_list = []
348
+
349
+ testing_subset_accuracy_list = []
350
+ hamming_loss_list=[]
351
+ challenge_score_list = []
352
+
353
+
354
+ jump_count = 0
355
+
356
+ thrs_list = []
357
+ # loss_value_after_each_epcoh_testing_list = []
358
+ # epoch is started from 0
359
+ for epoch in range(args.start_epoch, args.epochs):
360
+
361
+ train_log_fp.write("\n")
362
+ train_log_fp.write("---------------------------------------------------------- \n")
363
+ train_log_fp.write("---------------------------------------------------------- \n")
364
+ train_log_fp.write("---------------------------------------------------------- \n")
365
+
366
+ now = datetime.datetime.now()
367
+ current_time = now.strftime("%H:%M:%S")
368
+ print("Current Time =", current_time)
369
+ train_log_fp.write(f"Current Time = {current_time} \n")
370
+ train_log_fp.write(f"Current epoch: {epoch} \n")
371
+
372
+ train_dataset.set_epoch(epoch) # Update the current epoch in the dataset
373
+
374
+ if args.lrschedule == "CosineAnnealing":
375
+ lr_scheduler.step(epoch)
376
+
377
+
378
+ ###### below is for the training ###########
379
+ if args.challenge_scenario == 2021:
380
+ train_auprc, loss_value_after_each_epcoh, thrs, scores_F1, scores_SubsetAccuracy, scores_HammingLoss = engine_ecg_2021.train_one_epoch(
381
+ model, criterion, data_loader_train,
382
+ optimizer, device, epoch,
383
+ set_training_mode=args.train_mode, # keep in eval mode for deit finetuning / train mode for training and deit III finetuning
384
+ args=args,
385
+ )
386
+ elif args.challenge_scenario == 2020:
387
+ train_auprc, loss_value_after_each_epcoh, thrs, scores_F1, scores_SubsetAccuracy, scores_HammingLoss = engine_ecg_2020.train_one_epoch(
388
+ model, criterion, data_loader_train,
389
+ optimizer, device, epoch,
390
+ set_training_mode=args.train_mode, # keep in eval mode for deit finetuning / train mode for training and deit III finetuning
391
+ args=args,
392
+ )
393
+ else:
394
+ raise ValueError(f"No matching condition for value: {args.challenge_scenario}")
395
+
396
+
397
+ train_auprc_list.append(round(train_auprc, 4))
398
+ loss_value_after_each_epcoh_list.append(round(loss_value_after_each_epcoh, 4))
399
+ print("**************************************************************")
400
+ print("This is the list of Training of AUPRC:", train_auprc_list)
401
+ train_log_fp.write(f"This is the list of Training of AUPRC: {train_auprc_list} \n")
402
+ print("This is the list of Training of loss:", loss_value_after_each_epcoh_list)
403
+ train_log_fp.write(f"This is the list of Training of loss: {loss_value_after_each_epcoh_list} \n")
404
+ print("**************************************************************")
405
+
406
+
407
+ ###### below is for the evaluation ###########
408
+
409
+ if args.challenge_scenario == 2021:
410
+ AUPRC, auroc, f1, hamming, subset_accuracy, challenge_score, loss_value_after_each_epoch_testing = engine_ecg_2021.evaluate(data_loader_val, model, thrs, scores_F1, scores_SubsetAccuracy, scores_HammingLoss, device)
411
+ elif args.challenge_scenario == 2020:
412
+ AUPRC, auroc, f1, hamming, subset_accuracy, challenge_score, loss_value_after_each_epoch_testing = engine_ecg_2020.evaluate(data_loader_val, model, thrs, scores_F1, scores_SubsetAccuracy, scores_HammingLoss, device)
413
+ else:
414
+ raise ValueError(f"No matching condition for value: {args.challenge_scenario}")
415
+
416
+ # print(f"Accuracy of the network on the {len(test_dataset)} test images: {AUPRC:.4f}")
417
+
418
+ testing_AUPRC_list.append(round(AUPRC, 4))
419
+ loss_value_testing_list.append(round(loss_value_after_each_epoch_testing, 4))
420
+ testing_auroc_list.append(round(auroc, 4))
421
+ testing_f1_list.append(round(f1, 4))
422
+ testing_subset_accuracy_list.append((round(subset_accuracy, 4)))
423
+ hamming_loss_list.append((round(hamming, 4)))
424
+ challenge_score_list.append((round(challenge_score, 4)))
425
+
426
+
427
+ # loss_value_after_each_epcoh_testing_list.append(loss)
428
+ print("**********************************************************")
429
+ print("This is the list of Testing AUPRC:", testing_AUPRC_list)
430
+ print("This is the list of Testing loss:", loss_value_testing_list)
431
+ print("This is the list of Testing auroc:", testing_auroc_list)
432
+ print("This is the list of Testing f1:", testing_f1_list)
433
+ print("This is the list of subset accuracy:", testing_subset_accuracy_list)
434
+ print("This is the list of hamming_loss:", hamming_loss_list)
435
+ print("This is the list of challenge_score:", challenge_score_list)
436
+
437
+ train_log_fp.write("---------------------------------------------------------- \n")
438
+ train_log_fp.write(f"This is the list of Testing AUPRC: {testing_AUPRC_list} \n")
439
+ train_log_fp.write(f"This is the list of Testing loss: {loss_value_testing_list} \n")
440
+ train_log_fp.write(f"This is the list of Testing auroc: {testing_auroc_list} \n")
441
+ train_log_fp.write(f"This is the list of Testing f1: {testing_f1_list} \n")
442
+ train_log_fp.write(f"This is the list of Testing subset accuracy: {testing_subset_accuracy_list} \n")
443
+ train_log_fp.write(f"This is the list of Testing hamming_loss: {hamming_loss_list} \n")
444
+ train_log_fp.write(f"This is the list of challenge_score: {challenge_score_list} \n")
445
+ print("**********************************************************")
446
+
447
+ # here I have to put code regarding the auprc
448
+ if max_AUPRC < AUPRC:
449
+ max_AUPRC = AUPRC
450
+ jump_count = 0
451
+ if args.output_dir:
452
+ checkpoint_paths = [output_dir / 'best_auprc_checkpoint.pth']
453
+ for checkpoint_path in checkpoint_paths:
454
+ utils.save_on_master({
455
+ 'model': model.state_dict(),
456
+ 'optimizer': optimizer.optimizer.state_dict(),
457
+ 'epoch': epoch,
458
+ 'args': args,
459
+ }, checkpoint_path)
460
+ else:
461
+ jump_count = jump_count + 1
462
+
463
+
464
+ print(f'This is the Max AUPRC: {max_AUPRC:.4f}')
465
+ train_log_fp.write(f"This is the Max AUPRC: {max_AUPRC:.4f} \n")
466
+
467
+ if jump_count > 4:
468
+ print("This experimental will be finished at epoch:", epoch)
469
+ train_log_fp.write(f"This experimental will be finished at epoch: {epoch} \n")
470
+ break
471
+
472
+ total_time = time.time() - start_time
473
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
474
+ print('Training time {}'.format(total_time_str))
475
+
476
+ train_log_fp.write('Training time {}\n'.format(total_time_str))
477
+ train_log_fp.close()
478
+ old_txt_file_name = args.output_dir + '/train_log_group_%d.txt' % group_number
479
+ new_txt_file_name = args.output_dir + '/train_log_group_%d_MAX_AUPRC_%.4f_.txt' % (group_number, max_AUPRC)
480
+ os.rename(old_txt_file_name, new_txt_file_name)
481
+
482
+
483
+ if __name__ == '__main__':
484
+ parser = argparse.ArgumentParser('DeiT training and evaluation script, but this is for the multiple classification ECG.', parents=[get_args_parser()])
485
+ args = parser.parse_args()
486
+
487
+ if args.challenge_scenario == 2021:
488
+ data_directory = "./csv-file_2021_challenge/training_validation_testing/group"
489
+ elif args.challenge_scenario == 2020:
490
+ data_directory = "./csv-file_2020_challenge/training_validation_testing/group"
491
+ else:
492
+ raise ValueError(f"No matching condition for value: {args.challenge_scenario}")
493
+
494
+ print("This is the scenario:", args.challenge_scenario)
495
+ print("This is the scenario:", args.challenge_scenario)
496
+
497
+ model_directory = "./model/model_group"
498
+
499
+ for i in range(1,6):
500
+ data_directory_x = data_directory + str(i)
501
+ model_directory_x = model_directory +str(i)
502
+
503
+ args.output_dir = f"./output/Scenario_{args.challenge_scenario}_{args.block}_depth_{args.depth}_{args.lead}_batchSize_{args.batch_size}_CutMix(random)_{args.cutmix}_MixUp_{args.mixup}_JTEHM_group{i}"
504
+
505
+ if args.output_dir:
506
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
507
+
508
+ print("This is the group", i)
509
+ main(args, data_directory_x, model_directory_x, i)
models_mamba_ecg.py ADDED
@@ -0,0 +1,1013 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2015-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as functional
6
+ from functools import partial
7
+ from torch import Tensor
8
+ from typing import Optional
9
+
10
+ from timm.models.vision_transformer import VisionTransformer, _cfg
11
+ # from timm.models.registry import register_model
12
+ # from timm.models.layers import trunc_normal_, lecun_normal_
13
+
14
+ from timm.models import register_model
15
+ from timm.layers import trunc_normal_, lecun_normal_
16
+
17
+ from timm.layers import DropPath, to_2tuple
18
+
19
+ # from timm.models.layers import DropPath, to_2tuple
20
+ from timm.models.vision_transformer import _load_weights
21
+
22
+ import math
23
+
24
+ from collections import namedtuple
25
+
26
+ from mamba_ssm.modules.mamba_simple import Mamba
27
+ from mamba_ssm.utils.generation import GenerationMixin
28
+ from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
29
+
30
+ from rope import *
31
+ import random
32
+ import sys
33
+
34
+ try:
35
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
36
+ except ImportError:
37
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
38
+
39
+
40
+ # layer_norm_fn and rms_norm_fn both are normalization method
41
+
42
+ __all__ = [
43
+ 'vim_tiny_patch16_224', 'vim_small_patch16_224', 'vim_base_patch16_224',
44
+ 'vim_tiny_patch16_384', 'vim_small_patch16_384', 'vim_base_patch16_384',
45
+ ]
46
+
47
+
48
+ '''
49
+ in the original script(ft-vim-s.sh)
50
+ img_size = 224,
51
+ patch_size = 16,
52
+ stride = 8,
53
+ in_chans = 3,
54
+ embed_dim = 768
55
+ ------------------------------------
56
+ self.img_size: (224, 224)
57
+ self.patch_size: (16, 16)
58
+ self.grid_size: (27, 27)
59
+ self.num_patches: 729
60
+ self.flatten: True
61
+ self.norm: nn.Identity()
62
+ '''
63
+ class PatchEmbed(nn.Module):
64
+ """ 2D Image to Patch Embedding
65
+ """
66
+ def __init__(self, img_size=224, patch_size=16, stride=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
67
+ super().__init__()
68
+ img_size = to_2tuple(img_size)
69
+ patch_size = to_2tuple(patch_size)
70
+ self.img_size = img_size
71
+ self.patch_size = patch_size
72
+ self.grid_size = ((img_size[0] - patch_size[0]) // stride + 1, (img_size[1] - patch_size[1]) // stride + 1)
73
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
74
+ self.flatten = flatten
75
+
76
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)
77
+ # if the norm_layer is not none or null, the self.norm = norm_layer(embed_dim)
78
+ # otherwise, self.norm = nn.Identity()
79
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
80
+
81
+
82
+ def forward(self, x):
83
+ B, C, H, W = x.shape
84
+ assert H == self.img_size[0] and W == self.img_size[1], \
85
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
86
+
87
+ x = self.proj(x)
88
+ print("This is the shape after the CNN", x.shape)
89
+
90
+ if self.flatten:
91
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
92
+ print("This is the shape after the flatten:", x.shape)
93
+
94
+ x = self.norm(x)
95
+ return x
96
+
97
+ class PatchEmbed_spectrogram(nn.Module):
98
+ """ 2D spectrogram to Patch Embedding
99
+ """
100
+ def __init__(self, img_size_f = 128, img_size_t = 64, patch_size=6, stride=3, in_chans=12, embed_dim=432, flatten=True):
101
+ super().__init__()
102
+ # img_size = to_2tuple(img_size)
103
+ patch_size = to_2tuple(patch_size)
104
+
105
+ self.img_size_f = img_size_f
106
+ self.img_size_t = img_size_t
107
+
108
+ self.patch_size = patch_size
109
+ self.grid_size = ((img_size_f - patch_size[0]) // stride + 1, (img_size_t - patch_size[1]) // stride + 1)
110
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
111
+ self.flatten = flatten
112
+
113
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)
114
+ # if the norm_layer is not none or null, the self.norm = norm_layer(embed_dim)
115
+ # otherwise, self.norm = nn.Identity()
116
+ # self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
117
+
118
+
119
+ def forward(self, x):
120
+ B, C, H, W = x.shape
121
+ assert H == self.img_size_f and W == self.img_size_t, \
122
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size_f}*{self.img_size_t})."
123
+
124
+ x = self.proj(x)
125
+
126
+ # This is the shape after the CNN torch.Size([1, 432, 41, 20])
127
+ # print("This is the shape after the CNN", x.shape)
128
+ if self.flatten:
129
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
130
+ # This is the shape after the flatten: torch.Size([1, 820, 432])
131
+ # print("This is the shape after the flatten:", x.shape)
132
+ return x
133
+
134
+
135
+ class CNN_layers(nn.Module):
136
+
137
+ def __init__(self, embed_size = 384):
138
+ super().__init__()
139
+ self.multiple_cnn = nn.Sequential(
140
+ nn.Conv1d(12, 128, kernel_size=14, stride=3, padding=2, bias=False),
141
+ nn.BatchNorm1d(128),
142
+ nn.ReLU(inplace=True),
143
+
144
+ nn.Conv1d(128, embed_size, kernel_size=15, stride=4, padding=100, bias=False),
145
+ nn.BatchNorm1d(embed_size),
146
+ nn.ReLU(inplace=True))
147
+
148
+ def forward(self, x):
149
+ # print("This is the shape of enconder:(before)", x.shape)
150
+ # This is the shape of enconder:(before) torch.Size([44, 12, 8192])
151
+ x = self.multiple_cnn(x)
152
+ x = x.transpose(1, 2)
153
+ # print("This is the shape of enconder:(after)", x.shape)
154
+ # This is the shape of enconder:(after) torch.Size([44, 729, 384])
155
+
156
+ # sys.exit()
157
+ return x
158
+
159
+ class CNN_layers_shortcut(nn.Module):
160
+
161
+ def __init__(self, embed_size = 384):
162
+ super().__init__()
163
+ self.multiple_cnn1 = nn.Sequential(
164
+ nn.Conv1d(12, 128, kernel_size=14, stride=3, padding=2, bias=False),
165
+ nn.BatchNorm1d(128),
166
+ nn.ReLU(inplace=True),
167
+
168
+ nn.Conv1d(128, embed_size, kernel_size=15, stride=4, padding=100, bias=False),
169
+ nn.BatchNorm1d(embed_size),
170
+ nn.ReLU(inplace=True))
171
+
172
+ self.multiple_cnn2 = nn.Sequential(
173
+ nn.Conv1d(in_channels=embed_size, out_channels=embed_size, kernel_size=3, stride=1, padding=1),
174
+ nn.BatchNorm1d(embed_size),
175
+ nn.ReLU(inplace=True),
176
+
177
+ nn.Conv1d(in_channels=embed_size, out_channels=embed_size, kernel_size=3, stride=1, padding=1),
178
+ nn.BatchNorm1d(embed_size),
179
+ nn.ReLU(inplace=True)
180
+
181
+ )
182
+
183
+ def forward(self, x):
184
+ # print("This is the shape of enconder:(before)", x.shape)
185
+ # This is the shape of enconder:(before) torch.Size([44, 12, 8192])
186
+ x = self.multiple_cnn1(x)
187
+
188
+ shortcut = x
189
+
190
+ x = self.multiple_cnn2(x)
191
+ x = x + shortcut
192
+
193
+ x = x.transpose(1, 2)
194
+ # print("This is the shape of enconder:(after)", x.shape)
195
+ # This is the shape of enconder:(after) torch.Size([44, 729, 384])
196
+ # sys.exit()
197
+ return x
198
+
199
+ # class ECG_Patch_embedding(nn.Module):
200
+
201
+ # def __init__(self, embed_size = 384):
202
+ # super().__init__()
203
+ # self.multiple_cnn = nn.Sequential(
204
+ # nn.Conv1d(12, embed_size, kernel_size=16, stride=8),
205
+ # )
206
+
207
+ # def forward(self, x):
208
+ # # print("This is the shape of enconder:(before)", x.shape)
209
+ # # This is the shape of enconder:(before) torch.Size([44, 12, 8192])
210
+ # x = self.multiple_cnn(x)
211
+ # x = x.transpose(1, 2)
212
+ # # print("This is the shape of enconder:(after)", x.shape)
213
+ # # This is the shape of enconder:(after) torch.Size([44, 1023, 384])
214
+
215
+ # # sys.exit()
216
+ # return x
217
+
218
+ # class LeadCombiner(nn.Module):
219
+ # def __init__(self, lead, out_ch):
220
+ # super(LeadCombiner, self).__init__()
221
+ # self.conv2_1 = nn.Conv1d(in_channels=lead * out_ch,
222
+ # out_channels=out_ch,
223
+ # kernel_size=1,
224
+ # bias=False)
225
+ # self.bn2_1 = nn.BatchNorm1d(out_ch)
226
+
227
+ # self.conv2_2 = nn.Conv1d(in_channels=lead * out_ch,
228
+ # out_channels=out_ch,
229
+ # kernel_size=1,
230
+ # bias=False)
231
+ # self.bn2_2 = nn.BatchNorm1d(out_ch)
232
+
233
+ # self.pool1 = nn.AdaptiveMaxPool1d(output_size=1)
234
+ # self.pool2 = nn.AdaptiveMaxPool1d(output_size=1)
235
+
236
+ # def forward(self, x):
237
+ # # this is the shape of x: torch.Size([78, 128, 12, 128]
238
+ # x1 = rearrange(x, 'b c l t -> b (c l) t')
239
+ # x1 = functional.leaky_relu(self.bn2_1(self.conv2_1(x1)))
240
+
241
+ # x2 = rearrange(x, 'b c l t -> b (t l) c')
242
+ # x2 = functional.leaky_relu(self.bn2_2(self.conv2_2(x2)))
243
+
244
+ # x1 = functional.dropout(x1, p=0.5, training=self.training)
245
+ # x2 = functional.dropout(x2, p=0.5, training=self.training)
246
+
247
+ # x1 = self.pool1(x1).squeeze(2)
248
+ # x2 = self.pool2(x2).squeeze(2)
249
+
250
+ # x = torch.cat([x1, x2], dim=1)
251
+ # return x
252
+
253
+
254
+ # '''
255
+ # changed the stride and added one more layer
256
+ # '''
257
+ # class CNN_layers(nn.Module):
258
+
259
+ # def __init__(self, embed_size = 384):
260
+ # super().__init__()
261
+ # self.multiple_cnn = nn.Sequential(
262
+ # nn.Conv1d(12, 128, kernel_size=15, stride=2, padding=2, bias=False),
263
+ # nn.BatchNorm1d(128),
264
+ # nn.ReLU(inplace=True),
265
+
266
+ # nn.Conv1d(128, embed_size, kernel_size=15, stride=2, padding=100, bias=False),
267
+ # nn.BatchNorm1d(embed_size),
268
+ # nn.ReLU(inplace=True),
269
+
270
+ # nn.Conv1d(384, embed_size, kernel_size=15, stride=3, padding=100, bias=False),
271
+ # nn.BatchNorm1d(embed_size),
272
+ # nn.ReLU(inplace=True)
273
+ # )
274
+
275
+ # def forward(self, x):
276
+ # # print("This is the shape of enconder:(before)", x.shape)
277
+ # # This is the shape of enconder:(before) torch.Size([44, 12, 8192])
278
+ # x = self.multiple_cnn(x)
279
+ # x = x.transpose(1, 2)
280
+
281
+
282
+ # return x
283
+
284
+ # class D2_CNN_layers(nn.Module):
285
+ # def __init__(self, embed_size = 384):
286
+ # super().__init__()
287
+ # self.multiple_cnn = nn.Sequential(
288
+ # nn.Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 15), padding=(1, 7), stride=(1, 2), bias=False),
289
+ # nn.BatchNorm2d(128),
290
+ # nn.LeakyReLU(inplace=True),
291
+
292
+ # nn.Conv2d(in_channels=128, out_channels=384, kernel_size=(3, 15), padding=(1, 7), stride=(1, 2), bias=False),
293
+ # nn.BatchNorm2d(embed_size),
294
+ # nn.LeakyReLU(inplace=True),
295
+
296
+ # nn.Conv2d(in_channels=384, out_channels=384, kernel_size=(3, 15), padding=(1, 7), stride=(1, 30), bias=False),
297
+ # nn.BatchNorm2d(embed_size),
298
+ # nn.LeakyReLU(inplace=True),
299
+
300
+ # )
301
+
302
+ # def forward(self, x):
303
+ # x = x.unsqueeze(1)
304
+ # # print("This is the input shape:", x.shape)
305
+ # x = self.multiple_cnn(x)
306
+ # x = torch.flatten(x, start_dim=-2)
307
+ # x = x.transpose(1, 2)
308
+ # return x
309
+
310
+ def broadcat(tensors, dim=-1):
311
+ num_tensors = len(tensors)
312
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
313
+ assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
314
+ shape_len = list(shape_lens)[0]
315
+ dim = (dim + shape_len) if dim < 0 else dim
316
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
317
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
318
+ assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
319
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
320
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
321
+ expanded_dims.insert(dim, (dim, dims[dim]))
322
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
323
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
324
+ return torch.cat(tensors, dim=dim)
325
+
326
+ def rotate_half(x):
327
+ x = rearrange(x, '... (d r) -> ... d r', r=2)
328
+ x1, x2 = x.unbind(dim=-1)
329
+ x = torch.stack((-x2, x1), dim=-1)
330
+ return rearrange(x, '... d r -> ... (d r)')
331
+
332
+ # Adapted Rotary Embedding for 1D time-series (ECG) with 1024 tokens
333
+ class TimeSeriesRotaryEmbeddingFast(nn.Module):
334
+ def __init__(
335
+ self,
336
+ dim,
337
+ seq_len=1024, # Updated to 1024 tokens
338
+ custom_freqs=None,
339
+ freqs_for='lang', # Suitable for 1D sequential data
340
+ theta=10000,
341
+ max_freq=10,
342
+ num_freqs=1,
343
+ ):
344
+ super().__init__()
345
+ if custom_freqs:
346
+ freqs = custom_freqs
347
+ elif freqs_for == 'lang':
348
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
349
+ elif freqs_for == 'pixel':
350
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
351
+ elif freqs_for == 'constant':
352
+ freqs = torch.ones(num_freqs).float()
353
+ else:
354
+ raise ValueError(f'unknown modality {freqs_for}')
355
+
356
+ # 1D sequence for 1024 tokens
357
+ t = torch.arange(seq_len).float() # [0, 1, ..., 1023]
358
+
359
+ freqs = torch.einsum('..., f -> ... f', t, freqs)
360
+ freqs = repeat(freqs, '... n -> ... (n r)', r=2) # Doubles the dimension for rotation
361
+
362
+ # Shape: (1024, dim)
363
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
364
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
365
+
366
+ self.register_buffer("freqs_cos", freqs_cos) # Shape: (1024, dim)
367
+ self.register_buffer("freqs_sin", freqs_sin) # Shape: (1024, dim)
368
+
369
+ print('======== shape of rope freq', self.freqs_cos.shape, '========')
370
+
371
+ def forward(self, t):
372
+ # t: (batch_size, seq_len, embed_dim)
373
+ # Apply rotation to the entire sequence
374
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
375
+
376
+
377
+
378
+ '''
379
+ dim: 384
380
+ mixer_cls: mixer_cla is an instance of mamba.
381
+ drop_path = 0.
382
+ norm_cls = nn.LayerNorm
383
+ fused_add_norm = True,
384
+ residual_in_fp32 = True
385
+ '''
386
+ class Block(nn.Module):
387
+ def __init__(
388
+ self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False, drop_path=0.,
389
+ ):
390
+ """
391
+ Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
392
+
393
+ This Block has a slightly different structure compared to a regular
394
+ prenorm Transformer block.
395
+ The standard block is: LN -> MHA/MLP -> Add.
396
+ [Ref: https://arxiv.org/abs/2002.04745]
397
+ Here we have: Add -> LN -> Mixer, returning both
398
+ the hidden_states (output of the mixer) and the residual.
399
+ This is purely for performance reasons, as we can fuse add and LayerNorm.
400
+ The residual needs to be provided (except for the very first block).
401
+ """
402
+ super().__init__()
403
+
404
+ self.residual_in_fp32 = residual_in_fp32
405
+ self.fused_add_norm = fused_add_norm
406
+ self.mixer = mixer_cls(dim)
407
+ self.norm = norm_cls(dim)
408
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
409
+
410
+ # fused_add_norm true
411
+ if self.fused_add_norm:
412
+ assert RMSNorm is not None, "RMSNorm import fails"
413
+ assert isinstance(self.norm, (nn.LayerNorm, RMSNorm)), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
414
+
415
+ def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None):
416
+
417
+ r"""Pass the input through the encoder layer.
418
+
419
+ Args:
420
+ hidden_states: the sequence to the encoder layer (required).
421
+ residual: hidden_states = Mixer(LN(residual))
422
+ """
423
+
424
+ if not self.fused_add_norm:
425
+ if residual is None:
426
+ residual = hidden_states
427
+ else:
428
+ residual = residual + self.drop_path(hidden_states)
429
+
430
+ hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
431
+ if self.residual_in_fp32:
432
+ residual = residual.to(torch.float32)
433
+
434
+ # since the self.fused_add_norm is true, the code is going below
435
+ # fused_add_norm_fn = layer_norm_fn
436
+ ###########
437
+ # hidden_states: Tensor
438
+ # self.norm.weight = torch.nn.LayerNorm.weight
439
+ # self.norm.bias = torch.nn.LayerNorm.bias
440
+ # residual: Optional[Tensor] = None
441
+ # prenorm=True
442
+ # residual_in_fp32 = True
443
+ # eps = torch.nn.LayerNorm.eps
444
+
445
+ else:
446
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
447
+ if residual is None:
448
+ hidden_states, residual = fused_add_norm_fn(
449
+ hidden_states,
450
+ self.norm.weight,
451
+ self.norm.bias,
452
+ residual=residual,
453
+ prenorm=True,
454
+ residual_in_fp32=self.residual_in_fp32,
455
+ eps=self.norm.eps,
456
+ )
457
+ else:
458
+ hidden_states, residual = fused_add_norm_fn(
459
+ self.drop_path(hidden_states),
460
+ self.norm.weight,
461
+ self.norm.bias,
462
+ residual=residual,
463
+ prenorm=True,
464
+ residual_in_fp32=self.residual_in_fp32,
465
+ eps=self.norm.eps,
466
+ )
467
+
468
+ # inference_params=None
469
+ hidden_states = self.mixer(hidden_states, inference_params=inference_params)
470
+ return hidden_states, residual
471
+
472
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
473
+ print("the code is going through allocate_inference_cache in the block")
474
+
475
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
476
+
477
+ '''
478
+ from torch.nn.ModuleList()
479
+
480
+ embed_dim=384, (embedding dimension)
481
+ device: None
482
+ dtype: None
483
+ ssm_cfg: None
484
+ norm_epsilon: float = 1e-5
485
+ rms_norm: bool = False
486
+ residual_in_fp32 = True
487
+ fused_add_norm = True
488
+ if_bimamba = False
489
+ bimamba_type = "v2"
490
+ inter_dpr: [0.0, 0.0, 0.004347826354205608, ...,0.09565217792987823, 0.10000000149011612]
491
+ if_devide_out = True
492
+ init_layer_scale = None
493
+ layer_idx = i
494
+ '''
495
+
496
+ def create_block(
497
+ d_model,
498
+ ssm_cfg=None,
499
+ norm_epsilon=1e-5,
500
+ drop_path=0.,
501
+ rms_norm=False,
502
+ residual_in_fp32=False,
503
+ fused_add_norm=False,
504
+ layer_idx=None,
505
+ device=None,
506
+ dtype=None,
507
+ if_bimamba=False,
508
+ bimamba_type="none",
509
+ if_devide_out=False,
510
+ init_layer_scale=None,
511
+ block_name = "default_value"
512
+ ):
513
+ if if_bimamba:
514
+ bimamba_type = "v1"
515
+
516
+ if ssm_cfg is None:
517
+ ssm_cfg = {}
518
+
519
+ factory_kwargs = {"device": device, "dtype": dtype}
520
+
521
+ if block_name == "VisionMamba":
522
+ mixer_cls = partial(Mamba, layer_idx=layer_idx, bimamba_type=bimamba_type, if_devide_out=if_devide_out, init_layer_scale=init_layer_scale, **ssm_cfg, **factory_kwargs)
523
+ elif block_name == "OriginalMamba":
524
+ mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
525
+ else:
526
+ raise ValueError(f"No matching condition for value: {block_name}")
527
+
528
+
529
+
530
+ # rms_norm = False
531
+ norm_cls = partial(nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs)
532
+
533
+ block = Block(
534
+ d_model,
535
+ mixer_cls,
536
+ norm_cls=norm_cls,
537
+ drop_path=drop_path,
538
+ fused_add_norm=fused_add_norm,
539
+ residual_in_fp32=residual_in_fp32,
540
+ )
541
+ block.layer_idx = layer_idx
542
+ return block
543
+
544
+
545
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
546
+ def _init_weights(
547
+ module,
548
+ n_layer,
549
+ initializer_range=0.02, # Now only used for embedding layer.
550
+ rescale_prenorm_residual=True,
551
+ n_residuals_per_layer=1, # Change to 2 if we have MLP
552
+ ):
553
+ if isinstance(module, nn.Linear):
554
+ if module.bias is not None:
555
+ if not getattr(module.bias, "_no_reinit", False):
556
+ nn.init.zeros_(module.bias)
557
+ elif isinstance(module, nn.Embedding):
558
+ nn.init.normal_(module.weight, std=initializer_range)
559
+
560
+ if rescale_prenorm_residual:
561
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
562
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
563
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
564
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
565
+ #
566
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
567
+ for name, p in module.named_parameters():
568
+ if name in ["out_proj.weight", "fc2.weight"]:
569
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
570
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
571
+ # We need to reinit p since this code could be called multiple times
572
+ # Having just p *= scale would repeatedly scale it down
573
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
574
+ with torch.no_grad():
575
+ p /= math.sqrt(n_residuals_per_layer * n_layer)
576
+
577
+
578
+ def segm_init_weights(m):
579
+ if isinstance(m, nn.Linear):
580
+ trunc_normal_(m.weight, std=0.02)
581
+ if isinstance(m, nn.Linear) and m.bias is not None:
582
+ nn.init.constant_(m.bias, 0)
583
+ elif isinstance(m, nn.Conv2d):
584
+ # NOTE conv was left to pytorch default in my original init
585
+ lecun_normal_(m.weight)
586
+ if m.bias is not None:
587
+ nn.init.zeros_(m.bias)
588
+ elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
589
+ nn.init.zeros_(m.bias)
590
+ nn.init.ones_(m.weight)
591
+
592
+ '''
593
+ below is the 'ft-vim-s.sh':
594
+
595
+ patch_size=16, (just patch size)
596
+ stride=8, (just stride)
597
+ embed_dim=384, (embedding dimension)
598
+ depth=24, (?) the number of block
599
+ rms_norm = True, (?)
600
+ residual_in_fp32 = True, (?)
601
+ fused_add_norm = True, (?)
602
+ final_pool_type = 'mean', (?)
603
+ if_abs_pos_embed = True, (?)
604
+ if_rope = False, (?)
605
+ if_rope_residual = False, (?)
606
+ bimamba_type = "v2", (?)
607
+ if_cls_token = True, (?)
608
+ if_devide_out = True, (?)
609
+ use_middle_cls_token = True,
610
+ **kwargs
611
+ '''
612
+
613
+ class VisionMamba(nn.Module):
614
+ def __init__(self,
615
+ img_size=224,
616
+ patch_size=16,
617
+ stride=16,
618
+ depth=24,
619
+ embed_dim=192,
620
+ channels=3,
621
+ num_classes=26,
622
+ ssm_cfg=None,
623
+ drop_rate=0.,
624
+ drop_path_rate=0,
625
+ norm_epsilon: float = 1e-5,
626
+ rms_norm: bool = False,
627
+ initializer_cfg=None,
628
+ fused_add_norm=False,
629
+ residual_in_fp32=False,
630
+ device=None,
631
+ dtype=None,
632
+ ft_seq_len=None,
633
+ pt_hw_seq_len=14,
634
+ if_bidirectional=False,
635
+ final_pool_type='none',
636
+ if_abs_pos_embed=False,
637
+ if_rope=False,
638
+ if_rope_residual=False,
639
+ flip_img_sequences_ratio=-1.,
640
+ if_bimamba=False,
641
+ bimamba_type="none",
642
+ if_cls_token=False,
643
+ if_devide_out=False,
644
+ init_layer_scale=None,
645
+ use_double_cls_token=False,
646
+ use_middle_cls_token=False,
647
+ **kwargs):
648
+ # print("The program is coming the init")
649
+ factory_kwargs = {"device": device, "dtype": dtype}
650
+ # factory_kwargs: {'device': None, 'dtype': None}
651
+ # add factory_kwargs into kwargs
652
+ block_name = kwargs.get('block', 'default_value')
653
+
654
+ kwargs.update(factory_kwargs)
655
+
656
+ super().__init__()
657
+ self.residual_in_fp32 = residual_in_fp32
658
+ self.fused_add_norm = fused_add_norm
659
+ self.if_bidirectional = if_bidirectional
660
+ self.final_pool_type = final_pool_type
661
+ self.if_abs_pos_embed = if_abs_pos_embed
662
+ self.if_rope = if_rope
663
+ self.if_rope_residual = if_rope_residual
664
+ self.flip_img_sequences_ratio = flip_img_sequences_ratio
665
+ self.if_cls_token = if_cls_token
666
+ self.use_double_cls_token = use_double_cls_token
667
+ self.use_middle_cls_token = use_middle_cls_token
668
+ self.num_tokens = 1 if if_cls_token else 0
669
+
670
+ # pretrain parameters
671
+ self.num_classes = num_classes
672
+ self.d_model = self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
673
+
674
+ # self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, stride=stride, in_chans=channels, embed_dim=embed_dim)
675
+ # num_patches = self.patch_embed.num_patches
676
+
677
+ self.CNN_layers = CNN_layers()
678
+ num_patches = 729
679
+
680
+ # self.ECG_patch_embedding = ECG_Patch_embedding()
681
+ # num_patches = 1023
682
+ # self.LC = LeadCombiner(lead=6, out_ch=8)
683
+
684
+ # self.CNN_layers = D2_CNN_layers()
685
+ # num_patches = 828
686
+
687
+ # self.CNN_layers = CNN_layers()
688
+ # num_patches = 775
689
+
690
+ # if_cls_token: True (in the original script)
691
+ if if_cls_token:
692
+ # use_double_cls_token: False (in the original script)
693
+ if use_double_cls_token:
694
+ self.cls_token_head = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
695
+ self.cls_token_tail = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
696
+ self.num_tokens = 2
697
+
698
+ else:
699
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
700
+ # self.num_tokens = 1
701
+
702
+ # if_abs_pos_embed: True (in the original script)
703
+ if if_abs_pos_embed:
704
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, self.embed_dim))
705
+ self.pos_drop = nn.Dropout(p=drop_rate)
706
+
707
+ # if_rope: False (in the original script)
708
+
709
+ if if_rope:
710
+ half_head_dim = embed_dim // 2
711
+ self.rope = TimeSeriesRotaryEmbeddingFast(dim=embed_dim, seq_len= num_patches + self.num_tokens)
712
+
713
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
714
+ # self.head_LC = nn.Linear(16, num_classes)
715
+ # depth: 24; drop_path_rate: 0.1
716
+ # TODO: release this comment
717
+ # dpr: [0.0, 0.004347826354205608, 0.008695652708411217, ..., 0.09130434691905975, 0.09565217792987823, 0.10000000149011612]
718
+
719
+ if drop_path_rate == 0:
720
+ print("This is the drop_path_rate:", drop_path_rate)
721
+ dpr = [x.item() for x in torch.full((depth,), drop_path_rate)]
722
+ else:
723
+ print("This is the drop_path_rate:", drop_path_rate)
724
+ print("follow the stochastic depth decay rule")
725
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
726
+
727
+
728
+ # import ipdb;ipdb.set_trace()
729
+ # inter_dpr: [0.0, 0.0, 0.004347826354205608, ...,0.09565217792987823, 0.10000000149011612]
730
+ inter_dpr = [0.0] + dpr
731
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
732
+
733
+ # transformer blocks
734
+ # depth:
735
+ self.layers = nn.ModuleList(
736
+ [
737
+ create_block(
738
+ embed_dim,
739
+ ssm_cfg=ssm_cfg,
740
+ norm_epsilon=norm_epsilon,
741
+ rms_norm=rms_norm,
742
+ residual_in_fp32=residual_in_fp32,
743
+ fused_add_norm=fused_add_norm,
744
+ layer_idx=i,
745
+ if_bimamba=if_bimamba,
746
+ bimamba_type=bimamba_type,
747
+ drop_path=inter_dpr[i],
748
+ if_devide_out=if_devide_out,
749
+ init_layer_scale=init_layer_scale,
750
+ block_name = block_name,
751
+ **factory_kwargs,
752
+ )
753
+ for i in range(depth)
754
+ ]
755
+ )
756
+
757
+ # output head
758
+ self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(embed_dim, eps=norm_epsilon, **factory_kwargs)
759
+
760
+ # self.pre_logits = nn.Identity()
761
+
762
+ # original init
763
+ # self.patch_embed.apply(segm_init_weights)
764
+
765
+ self.CNN_layers.apply(segm_init_weights)
766
+
767
+ # self.ECG_patch_embedding.apply(segm_init_weights)
768
+
769
+ self.head.apply(segm_init_weights)
770
+
771
+ # self.head_LC.apply(segm_init_weights)
772
+ # self.LC.apply(segm_init_weights)
773
+
774
+
775
+ # if_abs_pos_embed: True (in the original script)
776
+ if if_abs_pos_embed:
777
+ trunc_normal_(self.pos_embed, std=.02)
778
+ # if_cls_token: True (in the original script)
779
+ if if_cls_token:
780
+ if use_double_cls_token:
781
+ trunc_normal_(self.cls_token_head, std=.02)
782
+ trunc_normal_(self.cls_token_tail, std=.02)
783
+ # the code is coming here
784
+ else:
785
+ trunc_normal_(self.cls_token, std=.02)
786
+
787
+ # mamba init
788
+ self.apply(partial(_init_weights, n_layer=depth, **(initializer_cfg if initializer_cfg is not None else {}),))
789
+
790
+
791
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
792
+ print("the code is going through allocate_inference_cache in the vision mamba")
793
+ return {
794
+ i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
795
+ for i, layer in enumerate(self.layers)
796
+ }
797
+
798
+ @torch.jit.ignore
799
+ def no_weight_decay(self):
800
+ return {"pos_embed", "cls_token", "dist_token", "cls_token_head", "cls_token_tail"}
801
+
802
+ @torch.jit.ignore()
803
+ def load_pretrained(self, checkpoint_path, prefix=""):
804
+ _load_weights(self, checkpoint_path, prefix)
805
+
806
+ # x: this is the input 224*224(tensor)
807
+ # inference_params: None
808
+ # if_random_cls_token_position:False
809
+ # if_random_token_rank: False
810
+
811
+ def forward_features(self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
812
+ # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
813
+ # with slight modifications to add the dist_token
814
+
815
+ # x = self.patch_embed(x)
816
+
817
+ # x = self.CNN_layers(x)
818
+
819
+ x = self.CNN_layers(x)
820
+
821
+ # B: batch size 16
822
+ # M: 729 the number of patch,(27*27)
823
+ # D: the hidden state dimension, 384 (small-size variant), this is set by author of Vim, it is 768 in the convential Vit
824
+ # N: SSM dimension, SSM dimension N to 16.
825
+ # L: the number of blocks, we set the number of blocks L to 24
826
+
827
+ B, M, _ = x.shape
828
+
829
+ # if_cls_token: True (in the original script)
830
+ if self.if_cls_token:
831
+
832
+ # self.use_double_cls_token: False (in the original script)
833
+ if self.use_double_cls_token:
834
+ cls_token_head = self.cls_token_head.expand(B, -1, -1)
835
+ cls_token_tail = self.cls_token_tail.expand(B, -1, -1)
836
+ token_position = [0, M + 1]
837
+ x = torch.cat((cls_token_head, x, cls_token_tail), dim=1)
838
+ M = x.shape[1]
839
+ else:
840
+ # self.use_middle_cls_token: True(in the original script)
841
+ if self.use_middle_cls_token:
842
+ cls_token = self.cls_token.expand(B, -1, -1)
843
+ token_position = M // 2
844
+ # add cls token in the middle
845
+ x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
846
+ elif if_random_cls_token_position:
847
+ cls_token = self.cls_token.expand(B, -1, -1)
848
+ token_position = random.randint(0, M)
849
+ x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
850
+ print("token_position: ", token_position)
851
+ else:
852
+ cls_token = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
853
+ token_position = 0
854
+ x = torch.cat((cls_token, x), dim=1)
855
+ M = x.shape[1]
856
+
857
+ # # if_abs_pos_embed: True (in the original script)
858
+ if self.if_abs_pos_embed:
859
+ # if new_grid_size[0] == self.patch_embed.grid_size[0] and new_grid_size[1] == self.patch_embed.grid_size[1]:
860
+ # x = x + self.pos_embed
861
+ # else:
862
+ # pos_embed = interpolate_pos_embed_online(
863
+ # self.pos_embed, self.patch_embed.grid_size, new_grid_size,0
864
+ # )
865
+ x = x + self.pos_embed
866
+ x = self.pos_drop(x)
867
+
868
+
869
+
870
+ if_flip_img_sequences = False
871
+ if self.flip_img_sequences_ratio > 0 and (self.flip_img_sequences_ratio - random.random()) > 1e-5:
872
+ x = x.flip([1])
873
+ if_flip_img_sequences = True
874
+
875
+ # mamba impl
876
+ # if_bidirectional: false
877
+ # inference_params: None
878
+ residual = None
879
+ hidden_states = x
880
+ if not self.if_bidirectional:
881
+ for layer in self.layers:
882
+
883
+ # here is false in the original script
884
+ if if_flip_img_sequences and self.if_rope:
885
+ hidden_states = hidden_states.flip([1])
886
+ if residual is not None:
887
+ residual = residual.flip([1])
888
+
889
+ # rope about, defaule is false
890
+ if self.if_rope:
891
+ hidden_states = self.rope(hidden_states)
892
+ if residual is not None and self.if_rope_residual:
893
+ residual = self.rope(residual)
894
+
895
+ # here is false in the original script
896
+ if if_flip_img_sequences and self.if_rope:
897
+ hidden_states = hidden_states.flip([1])
898
+ if residual is not None:
899
+ residual = residual.flip([1])
900
+
901
+ hidden_states, residual = layer(hidden_states, residual, inference_params=inference_params)
902
+ # sys.exit()
903
+
904
+ else:
905
+ # get two layers in a single for-loop
906
+ for i in range(len(self.layers) // 2):
907
+ if self.if_rope:
908
+ hidden_states = self.rope(hidden_states)
909
+ if residual is not None and self.if_rope_residual:
910
+ residual = self.rope(residual)
911
+
912
+ hidden_states_f, residual_f = self.layers[i * 2](
913
+ hidden_states, residual, inference_params=inference_params
914
+ )
915
+ hidden_states_b, residual_b = self.layers[i * 2 + 1](
916
+ hidden_states.flip([1]), None if residual == None else residual.flip([1]), inference_params=inference_params
917
+ )
918
+ hidden_states = hidden_states_f + hidden_states_b.flip([1])
919
+ residual = residual_f + residual_b.flip([1])
920
+
921
+ # fused_add_norm: True
922
+
923
+ if not self.fused_add_norm:
924
+ if residual is None:
925
+ residual = hidden_states
926
+ else:
927
+ residual = residual + self.drop_path(hidden_states)
928
+ hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
929
+ else:
930
+ # Set prenorm = False here since we don't need the residual
931
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
932
+
933
+ hidden_states = fused_add_norm_fn(
934
+ self.drop_path(hidden_states),
935
+ self.norm_f.weight,
936
+ self.norm_f.bias,
937
+ eps=self.norm_f.eps,
938
+ residual=residual,
939
+ prenorm=False,
940
+ residual_in_fp32=self.residual_in_fp32,
941
+ )
942
+
943
+ # return only cls token if it exists
944
+ # if_cls_token: True (in the original script)
945
+ # self.use_middle_cls_token: True
946
+ if self.if_cls_token:
947
+ if self.use_double_cls_token:
948
+ return (hidden_states[:, token_position[0], :] + hidden_states[:, token_position[1], :]) / 2
949
+ else:
950
+ if self.use_middle_cls_token:
951
+ return hidden_states[:, token_position, :]
952
+ elif if_random_cls_token_position:
953
+ return hidden_states[:, token_position, :]
954
+ else:
955
+ return hidden_states[:, token_position, :]
956
+
957
+ # self.final_pol_type = 'mean'
958
+ if self.final_pool_type == 'none':
959
+ return hidden_states[:, -1, :]
960
+ elif self.final_pool_type == 'mean':
961
+ return hidden_states.mean(dim=1)
962
+ elif self.final_pool_type == 'max':
963
+ return hidden_states
964
+ elif self.final_pool_type == 'all':
965
+ return hidden_states
966
+ else:
967
+ raise NotImplementedError
968
+
969
+ def forward(self, x, return_features=False, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
970
+ x = self.forward_features(x, inference_params, if_random_cls_token_position=if_random_cls_token_position, if_random_token_rank=if_random_token_rank)
971
+ # batch_number = x.shape[0]
972
+ # x = x.view(batch_number, 8, 6, 8)
973
+ # x = self.LC(x)
974
+ # x = self.head_LC(x)
975
+ # print("This is the shape of X (after the fully connected layer):", x.shape)
976
+ # return_features = False
977
+ # print("This is the return feature:", return_features)
978
+ # if return_features:
979
+ # return x
980
+
981
+ # print("This is the shape of X (Before the fully connected layer):", x.shape)
982
+ x = self.head(x)
983
+
984
+ # final_pool_type = 'mean' in original script
985
+ if self.final_pool_type == 'max':
986
+ x = x.max(dim=1)[0]
987
+ # sys.exit()
988
+ return x
989
+
990
+
991
+
992
+ # below is for the vision in mamba
993
+ @register_model
994
+ def ecg_vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, depth=5, fused_add_norm = True, drop_path_rate = 0.1, if_divide_out = True, use_middle_cls_token = True, **kwargs):
995
+ model = VisionMamba(patch_size=16, stride=8, embed_dim=384, depth=depth, rms_norm=True, residual_in_fp32=True, drop_path_rate = drop_path_rate, fused_add_norm = fused_add_norm, final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=if_divide_out, use_middle_cls_token=use_middle_cls_token, **kwargs)
996
+
997
+ # As a reminder:
998
+ print("This is whether the fused_add_norm:", fused_add_norm)
999
+ print("This is whether the if_divide_out:", if_divide_out)
1000
+ print("This is whether the use_middle_cls_token:", use_middle_cls_token)
1001
+
1002
+
1003
+ model.default_cfg = _cfg()
1004
+
1005
+ return model
1006
+
1007
+ # below is for the original mamba and 24 blocks
1008
+ # @register_model
1009
+ # def ecg_vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs):
1010
+ # model = VisionMamba(patch_size=16, stride=8, embed_dim=384, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs)
1011
+ # model.default_cfg = _cfg()
1012
+
1013
+ # return model
optimizer.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import torch.nn.functional as F
5
+
6
+ # from utils import d_model
7
+
8
+ '''
9
+ optimizer = NoamOpt(d_model, 1, 4000, torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
10
+ model_size: 256
11
+ factor: 1
12
+ warmup: 4000
13
+ optimizer: torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
14
+ '''
15
+ class NoamOpt:
16
+ "Optim wrapper that implements rate."
17
+ def __init__(self, model_size, factor, warmup, optimizer):
18
+ self.optimizer = optimizer
19
+ self._step = 0
20
+ self.warmup = warmup
21
+ self.factor = factor
22
+ self.model_size = model_size
23
+ self._rate = 0
24
+
25
+ def step(self):
26
+ "Update parameters and rate"
27
+ self._step += 1
28
+ rate = self.rate()
29
+ # print("This is the rate:", rate)
30
+ for p in self.optimizer.param_groups:
31
+ p['lr'] = rate
32
+ self._rate = rate
33
+ self.optimizer.step()
34
+
35
+ def rate(self, step = None):
36
+ "Implement `lrate` above"
37
+ if step is None:
38
+ step = self._step
39
+ return self.factor * (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)))
40
+
41
+ def get_std_opt(model):
42
+ return NoamOpt(729, 2, 4000, torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
rope.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # EVA-02: A Visual Representation for Neon Genesis
3
+ # Github source: https://github.com/baaivision/EVA/EVA02
4
+ # Copyright (c) 2023 Beijing Academy of Artificial Intelligence (BAAI)
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # By Yuxin Fang
7
+ #
8
+ # Based on https://github.com/lucidrains/rotary-embedding-torch
9
+ # --------------------------------------------------------'
10
+
11
+ from math import pi
12
+
13
+ import torch
14
+ from torch import nn
15
+
16
+ from einops import rearrange, repeat
17
+
18
+
19
+
20
+ def broadcat(tensors, dim = -1):
21
+ num_tensors = len(tensors)
22
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
23
+ assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
24
+ shape_len = list(shape_lens)[0]
25
+ dim = (dim + shape_len) if dim < 0 else dim
26
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
27
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
28
+ assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
29
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
30
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
31
+ expanded_dims.insert(dim, (dim, dims[dim]))
32
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
33
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
34
+ return torch.cat(tensors, dim = dim)
35
+
36
+
37
+
38
+ def rotate_half(x):
39
+ x = rearrange(x, '... (d r) -> ... d r', r = 2)
40
+ x1, x2 = x.unbind(dim = -1)
41
+ x = torch.stack((-x2, x1), dim = -1)
42
+ return rearrange(x, '... d r -> ... (d r)')
43
+
44
+
45
+
46
+ class VisionRotaryEmbedding(nn.Module):
47
+ def __init__(
48
+ self,
49
+ dim,
50
+ pt_seq_len,
51
+ ft_seq_len=None,
52
+ custom_freqs = None,
53
+ freqs_for = 'lang',
54
+ theta = 10000,
55
+ max_freq = 10,
56
+ num_freqs = 1,
57
+ ):
58
+ super().__init__()
59
+ if custom_freqs:
60
+ freqs = custom_freqs
61
+ elif freqs_for == 'lang':
62
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
63
+ elif freqs_for == 'pixel':
64
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
65
+ elif freqs_for == 'constant':
66
+ freqs = torch.ones(num_freqs).float()
67
+ else:
68
+ raise ValueError(f'unknown modality {freqs_for}')
69
+
70
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
71
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
72
+
73
+ freqs_h = torch.einsum('..., f -> ... f', t, freqs)
74
+ freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
75
+
76
+ freqs_w = torch.einsum('..., f -> ... f', t, freqs)
77
+ freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
78
+
79
+ freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
80
+
81
+ self.register_buffer("freqs_cos", freqs.cos())
82
+ self.register_buffer("freqs_sin", freqs.sin())
83
+
84
+ print('======== shape of rope freq', self.freqs_cos.shape, '========')
85
+
86
+ def forward(self, t, start_index = 0):
87
+ rot_dim = self.freqs_cos.shape[-1]
88
+ end_index = start_index + rot_dim
89
+ assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
90
+ t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
91
+ t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
92
+ return torch.cat((t_left, t, t_right), dim = -1)
93
+
94
+
95
+
96
+ class VisionRotaryEmbeddingFast(nn.Module):
97
+ def __init__(
98
+ self,
99
+ dim,
100
+ pt_seq_len=16,
101
+ ft_seq_len=None,
102
+ custom_freqs = None,
103
+ freqs_for = 'lang',
104
+ theta = 10000,
105
+ max_freq = 10,
106
+ num_freqs = 1,
107
+ ):
108
+ super().__init__()
109
+ if custom_freqs:
110
+ freqs = custom_freqs
111
+ elif freqs_for == 'lang':
112
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
113
+ elif freqs_for == 'pixel':
114
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
115
+ elif freqs_for == 'constant':
116
+ freqs = torch.ones(num_freqs).float()
117
+ else:
118
+ raise ValueError(f'unknown modality {freqs_for}')
119
+
120
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
121
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
122
+
123
+ freqs = torch.einsum('..., f -> ... f', t, freqs)
124
+ freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
125
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
126
+
127
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
128
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
129
+
130
+ self.register_buffer("freqs_cos", freqs_cos)
131
+ self.register_buffer("freqs_sin", freqs_sin)
132
+
133
+ print('======== shape of rope freq', self.freqs_cos.shape, '========')
134
+
135
+ def forward(self, t):
136
+ if t.shape[1] % 2 != 0:
137
+ t_spatial = t[:, 1:, :]
138
+ t_spatial = t_spatial * self.freqs_cos + rotate_half(t_spatial) * self.freqs_sin
139
+ return torch.cat((t[:, :1, :], t_spatial), dim=1)
140
+ else:
141
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin