-
Notifications
You must be signed in to change notification settings - Fork 734
Expand file tree
/
Copy pathtest_ptbxl_dataset.py
More file actions
36 lines (27 loc) · 1.24 KB
/
test_ptbxl_dataset.py
File metadata and controls
36 lines (27 loc) · 1.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
import tempfile
import unittest
from pyhealth.datasets import PTBXLDataset
class TestPTBXLDataset(unittest.TestCase):
    """Smoke tests for loading a minimal PTB-XL metadata CSV with ``PTBXLDataset``."""

    def test_load_data_dev_mode(self):
        """Load a two-record ``ptbxl_database.csv`` in dev mode and check the event frame.

        The fixture holds two ECG records for two different patients (``100``
        and ``101``). The test checks that ``load_data()`` yields one row per
        record, exposes the expected columns, and tags every row with the
        ``"ptbxl"`` event type.
        """
        expected_columns = (
            "patient_id",
            "event_type",
            "ptbxl/ecg_id",
            "ptbxl/record_path",
            "ptbxl/scp_codes",
        )
        fixture_lines = [
            "ecg_id,patient_id,filename_lr,filename_hr,scp_codes\n",
            '1,100,records100/00000/00001_lr,records500/00000/00001_hr,"{\'MI\': 1}"\n',
            '2,101,records100/00000/00002_lr,records500/00000/00002_hr,"{\'NORM\': 1}"\n',
        ]
        with tempfile.TemporaryDirectory() as tmpdir:
            csv_path = os.path.join(tmpdir, "ptbxl_database.csv")
            # Explicit encoding keeps the fixture independent of the platform
            # default. The inner `with` closes (and flushes) the file before the
            # dataset is constructed and reads it.
            with open(csv_path, "w", encoding="utf-8") as f:
                f.writelines(fixture_lines)

            dataset = PTBXLDataset(
                root=tmpdir,
                dev=True,
            )
            df = dataset.load_data().compute()

            self.assertEqual(len(df), 2)
            # subTest reports every missing column instead of stopping at the first.
            for column in expected_columns:
                with self.subTest(column=column):
                    self.assertIn(column, df.columns)

            # load_data() returns a lazy frame (it needs .compute()), so this test
            # does not rely on row order. Compare all rows order-independently.
            # str() normalises ids whether they were parsed as int or str.
            self.assertEqual(
                sorted(str(pid) for pid in df["patient_id"]),
                ["100", "101"],
            )
            self.assertEqual(set(df["event_type"]), {"ptbxl"})
if __name__ == "__main__":
    # Run this module's TestCase classes when the file is executed directly as a script.
    unittest.main()