1+ import re
2+ import os
3+ import subprocess
4+ from statistics import mean
5+
6+ precision_sample_size = 4
7+ precisions = ['float16' , 'float32' ]
8+
9+ folder = os .path .dirname (os .path .dirname (__file__ ))
10+ script = os .path .join (folder , 'mnist' , 'batch_eth_mnist.py' )
11+ data = {}
12+ for precision in precisions :
13+ for _ in range (precision_sample_size ):
14+ result = subprocess .run (
15+ f"python { script } --n_train 100 --batch_size 50 --n_test 10 --n_updates 1 --w_dtype { precision } " ,
16+ shell = True , capture_output = True , text = True
17+ )
18+ output = result .stdout
19+ time_match = re .search (r'Progress: 1 / 1 \((\d+\.\d+) seconds\)' , output )
20+ memory_match = re .search (r'Memory consumption: (\d+)mb' , output )
21+ data .setdefault (precision , []).append ([
22+ time_match .groups ()[0 ],
23+ memory_match .groups ()[0 ]
24+ ])
25+ print ("+" )
26+
27+
28+ def print_table (data ):
29+ column_widths = [max (len (str (item )) for item in col ) for col in zip (* data )]
30+ for row in data :
31+ formatted_row = " | " .join (f"{ str (item ):<{column_widths [i ]}} " for i , item in enumerate (row ))
32+ print (formatted_row )
33+
34+
35+ average_time = {}
36+ average_memory = {}
37+ for precision , rows in data .items ():
38+ print (f"precision: { precision } " )
39+ table = [
40+ ['Time (sec)' , 'GPU memory (Mb)' ]
41+ ] + rows
42+ avg_time = mean (map (lambda i : float (i [0 ]), rows ))
43+ avg_memory = mean (map (lambda i : float (i [1 ]), rows ))
44+ print_table (table )
45+ print (f"Average time: { avg_time } " )
46+ print (f"Average memory: { avg_memory } " )
47+ average_memory [precision ] = avg_memory
48+ average_time [precision ] = avg_time
49+ print ('' )
0 commit comments