1212from io import StringIO
1313from multiprocessing import cpu_count
1414from pathlib import Path
15- from typing import Any , Callable , Dict , List , Mapping , Optional , Union
15+ from typing import Any , Callable , Dict , Iterable , List , Mapping , Optional , Union
1616
1717import ujson as json
1818from tqdm .auto import tqdm
@@ -300,7 +300,7 @@ def src_info(self) -> Dict[str, Any]:
300300 def format_model (
301301 self ,
302302 save : bool = False ,
303- canonicalize : Union [bool , str , List [str ]] = False ,
303+ canonicalize : Union [bool , str , Iterable [str ]] = False ,
304304 * ,
305305 unsafe : bool = False ,
306306 ) -> None :
@@ -330,10 +330,10 @@ def format_model(
330330 )
331331
332332 if canonicalize :
333- if isinstance (canonicalize , list ):
334- cmd .append ('--canonicalize=' + ',' .join (canonicalize ))
335- elif isinstance (canonicalize , str ):
333+ if isinstance (canonicalize , str ):
336334 cmd .append ('--canonicalize=' + canonicalize )
335+ elif isinstance (canonicalize , Iterable ):
336+ cmd .append ('--canonicalize=' + ',' .join (canonicalize ))
337337 else :
338338 cmd .append ('--print-canonical' )
339339
@@ -350,8 +350,8 @@ def format_model(
350350 if result :
351351 if not unsafe :
352352 shutil .copyfile (self .stan_file , self .stan_file + '.bak' )
353- with (open (self .stan_file , 'w' )) as file :
354- file .write (result )
353+ with (open (self .stan_file , 'w' )) as file_handle :
354+ file_handle .write (result )
355355 else :
356356 print (result )
357357
0 commit comments