@@ -745,18 +745,24 @@ def test_validate_good_run() -> None:
745745 draws_pd = fit .draws_pd ()
746746 assert draws_pd .shape == (
747747 fit .runset .chains * fit .num_draws_sampling ,
748- len (fit .column_names ),
748+ len (fit .column_names ) + 3 ,
749749 )
750- assert fit .draws_pd (vars = ['theta' ]).shape == (400 , 1 )
751- assert fit .draws_pd (vars = ['lp__' , 'theta' ]).shape == (400 , 2 )
752- assert fit .draws_pd (vars = ['theta' , 'lp__' ]).shape == (400 , 2 )
753- assert fit .draws_pd (vars = 'theta' ).shape == (400 , 1 )
750+ assert fit .draws_pd (vars = ['theta' ]).shape == (400 , 4 )
751+ assert fit .draws_pd (vars = ['lp__' , 'theta' ]).shape == (400 , 5 )
752+ assert fit .draws_pd (vars = ['theta' , 'lp__' ]).shape == (400 , 5 )
753+ assert fit .draws_pd (vars = 'theta' ).shape == (400 , 4 )
754754
755755 assert list (fit .draws_pd (vars = ['theta' , 'lp__' ]).columns ) == [
756+ 'chain__' ,
757+ 'iter__' ,
758+ 'draw__' ,
756759 'theta' ,
757760 'lp__' ,
758761 ]
759762 assert list (fit .draws_pd (vars = ['lp__' , 'theta' ]).columns ) == [
763+ 'chain__' ,
764+ 'iter__' ,
765+ 'draw__' ,
760766 'lp__' ,
761767 'theta' ,
762768 ]
@@ -817,7 +823,7 @@ def test_validate_big_run() -> None:
817823 assert fit .step_size .shape == (2 ,)
818824 assert fit .metric .shape == (2 , 2095 )
819825 assert fit .draws ().shape == (1000 , 2 , 2102 )
820- assert fit .draws_pd (vars = ['phi' ]).shape == (2000 , 2095 )
826+ assert fit .draws_pd (vars = ['phi' ]).shape == (2000 , 2098 )
821827 with raises_nested (ValueError , r'Unknown variable: gamma' ):
822828 fit .draws_pd (vars = ['gamma' ])
823829
@@ -828,14 +834,14 @@ def test_instantiate_from_csvfiles() -> None:
828834 draws_pd = bern_fit .draws_pd ()
829835 assert draws_pd .shape == (
830836 bern_fit .runset .chains * bern_fit .num_draws_sampling ,
831- len (bern_fit .column_names ),
837+ len (bern_fit .column_names ) + 3 ,
832838 )
833839 csvfiles_path = os .path .join (DATAFILES_PATH , 'runset-big' )
834840 big_fit = from_csv (path = csvfiles_path )
835841 draws_pd = big_fit .draws_pd ()
836842 assert draws_pd .shape == (
837843 big_fit .runset .chains * big_fit .num_draws_sampling ,
838- len (big_fit .column_names ),
844+ len (big_fit .column_names ) + 3 ,
839845 )
840846 # list
841847 csvfiles_path = os .path .join (DATAFILES_PATH , 'runset-good' )
@@ -848,22 +854,22 @@ def test_instantiate_from_csvfiles() -> None:
848854 draws_pd = bern_fit .draws_pd ()
849855 assert draws_pd .shape == (
850856 bern_fit .runset .chains * bern_fit .num_draws_sampling ,
851- len (bern_fit .column_names ),
857+ len (bern_fit .column_names ) + 3 ,
852858 )
853859 # single csvfile
854860 bern_fit = from_csv (path = csvfiles [0 ])
855861 draws_pd = bern_fit .draws_pd ()
856862 assert draws_pd .shape == (
857863 bern_fit .num_draws_sampling ,
858- len (bern_fit .column_names ),
864+ len (bern_fit .column_names ) + 3 ,
859865 )
860866 # glob
861867 csvfiles_path = os .path .join (csvfiles_path , '*.csv' )
862868 big_fit = from_csv (path = csvfiles_path )
863869 draws_pd = big_fit .draws_pd ()
864870 assert draws_pd .shape == (
865871 big_fit .runset .chains * big_fit .num_draws_sampling ,
866- len (big_fit .column_names ),
872+ len (big_fit .column_names ) + 3 ,
867873 )
868874
869875
@@ -930,7 +936,7 @@ def test_instantiate_from_csvfiles_fail(
930936def test_from_csv_fixed_param () -> None :
931937 csv_path = os .path .join (DATAFILES_PATH , 'fixed_param_sample.csv' )
932938 fixed_param_sample = from_csv (path = csv_path )
933- assert fixed_param_sample .draws_pd ().shape == (100 , 85 )
939+ assert fixed_param_sample .draws_pd ().shape == (100 , 88 )
934940
935941
936942def test_custom_metric () -> None :
@@ -1292,14 +1298,14 @@ def test_save_warmup() -> None:
12921298 len (BERNOULLI_COLS ),
12931299 )
12941300
1295- assert bern_fit .draws_pd ().shape == (200 , len (BERNOULLI_COLS ))
1301+ assert bern_fit .draws_pd ().shape == (200 , len (BERNOULLI_COLS ) + 3 )
12961302 assert bern_fit .draws_pd (inc_warmup = False ).shape == (
12971303 200 ,
1298- len (BERNOULLI_COLS ),
1304+ len (BERNOULLI_COLS ) + 3 ,
12991305 )
13001306 assert bern_fit .draws_pd (inc_warmup = True ).shape == (
13011307 600 ,
1302- len (BERNOULLI_COLS ),
1308+ len (BERNOULLI_COLS ) + 3 ,
13031309 )
13041310
13051311
@@ -1371,7 +1377,7 @@ def test_dont_save_warmup(caplog: pytest.LogCaptureFixture) -> None:
13711377 with caplog .at_level (logging .WARNING ):
13721378 assert bern_fit .draws_pd (inc_warmup = True ).shape == (
13731379 200 ,
1374- len (BERNOULLI_COLS ),
1380+ len (BERNOULLI_COLS ) + 3 ,
13751381 )
13761382 check_present (
13771383 caplog ,
0 commit comments