-
Notifications
You must be signed in to change notification settings - Fork 734
Expand file tree
/
Copy pathtest_ptbxl_mi_classification.py
More file actions
63 lines (51 loc) · 1.92 KB
/
test_ptbxl_mi_classification.py
File metadata and controls
63 lines (51 loc) · 1.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import tempfile
import unittest
from unittest.mock import patch
import numpy as np
import pandas as pd
import polars as pl
from pyhealth.tasks.ptbxl_mi_classification import PTBXLMIClassificationTask
from pyhealth.data import Patient
class TestPTBXLTask(unittest.TestCase):
    """Tests for ``PTBXLMIClassificationTask`` driven by synthetic records."""

    @staticmethod
    def _write_scp_statements(root):
        """Write a minimal synthetic SCP mapping to ``root/scp_statements.csv``.

        The mapping has two rows, indexed by SCP code: ``IMI`` with
        diagnostic class ``MI`` and ``NORM`` with diagnostic class ``NORM``.
        """
        statements = pd.DataFrame(
            {"diagnostic_class": ["MI", "NORM"]},
            index=["IMI", "NORM"],
        )
        statements.to_csv(os.path.join(root, "scp_statements.csv"))

    @staticmethod
    def _make_patient():
        """Build patient ``"1"`` holding two ``ptbxl`` events.

        The first event carries the SCP code ``IMI`` and the second carries
        ``NORM``; both have a null timestamp.
        """
        columns = {
            "patient_id": ["1", "1"],
            "event_type": ["ptbxl", "ptbxl"],
            "timestamp": [None, None],
            "ptbxl/ecg_id": [100, 101],
            "ptbxl/record_path": [
                "records100/00000/00001_lr",
                "records100/00000/00002_lr",
            ],
            "ptbxl/scp_codes": [
                "{'IMI': 1}",
                "{'NORM': 1}",
            ],
        }
        events = pl.from_pandas(pd.DataFrame(columns))
        return Patient(patient_id="1", data_source=events)

    @patch.object(PTBXLMIClassificationTask, "_load_ecg_signal")
    def test_mi_label_extraction(self, mock_load_signal):
        """The ``IMI`` record is labelled 1 and the ``NORM`` record 0."""
        # Signal loading is patched out, so no waveform files are needed;
        # every record receives the same all-zero (12, 1000) array.
        expected_shape = (12, 1000)
        mock_load_signal.return_value = np.zeros(
            expected_shape, dtype=np.float32
        )

        with tempfile.TemporaryDirectory() as root:
            self._write_scp_statements(root)
            patient = self._make_patient()

            task = PTBXLMIClassificationTask(root=root)
            samples = task(patient)

            self.assertEqual(len(samples), 2)
            first, second = samples[0], samples[1]
            self.assertEqual(first["label"], 1)
            self.assertEqual(second["label"], 0)
            self.assertEqual(np.array(first["signal"]).shape, expected_shape)
# Allow running this test module directly as a script, in addition to
# discovery by a test runner.
if __name__ == "__main__":
    unittest.main()