@@ -745,18 +745,24 @@ def test_validate_good_run() -> None:
745745 draws_pd = fit .draws_pd ()
746746 assert draws_pd .shape == (
747747 fit .runset .chains * fit .num_draws_sampling ,
748- len (fit .column_names ),
748+ len (fit .column_names ) + 3 ,
749749 )
750- assert fit .draws_pd (vars = ['theta' ]).shape == (400 , 1 )
751- assert fit .draws_pd (vars = ['lp__' , 'theta' ]).shape == (400 , 2 )
752- assert fit .draws_pd (vars = ['theta' , 'lp__' ]).shape == (400 , 2 )
753- assert fit .draws_pd (vars = 'theta' ).shape == (400 , 1 )
750+ assert fit .draws_pd (vars = ['theta' ]).shape == (400 , 4 )
751+ assert fit .draws_pd (vars = ['lp__' , 'theta' ]).shape == (400 , 5 )
752+ assert fit .draws_pd (vars = ['theta' , 'lp__' ]).shape == (400 , 5 )
753+ assert fit .draws_pd (vars = 'theta' ).shape == (400 , 4 )
754754
755755 assert list (fit .draws_pd (vars = ['theta' , 'lp__' ]).columns ) == [
756+ 'chain__' ,
757+ 'iter__' ,
758+ 'draw__' ,
756759 'theta' ,
757760 'lp__' ,
758761 ]
759762 assert list (fit .draws_pd (vars = ['lp__' , 'theta' ]).columns ) == [
763+ 'chain__' ,
764+ 'iter__' ,
765+ 'draw__' ,
760766 'lp__' ,
761767 'theta' ,
762768 ]
@@ -817,7 +823,7 @@ def test_validate_big_run() -> None:
817823 assert fit .step_size .shape == (2 ,)
818824 assert fit .metric .shape == (2 , 2095 )
819825 assert fit .draws ().shape == (1000 , 2 , 2102 )
820- assert fit .draws_pd (vars = ['phi' ]).shape == (2000 , 2095 )
826+ assert fit .draws_pd (vars = ['phi' ]).shape == (2000 , 2098 )
821827 with raises_nested (ValueError , r'Unknown variable: gamma' ):
822828 fit .draws_pd (vars = ['gamma' ])
823829
@@ -828,14 +834,14 @@ def test_instantiate_from_csvfiles() -> None:
828834 draws_pd = bern_fit .draws_pd ()
829835 assert draws_pd .shape == (
830836 bern_fit .runset .chains * bern_fit .num_draws_sampling ,
831- len (bern_fit .column_names ),
837+ len (bern_fit .column_names ) + 3 ,
832838 )
833839 csvfiles_path = os .path .join (DATAFILES_PATH , 'runset-big' )
834840 big_fit = from_csv (path = csvfiles_path )
835841 draws_pd = big_fit .draws_pd ()
836842 assert draws_pd .shape == (
837843 big_fit .runset .chains * big_fit .num_draws_sampling ,
838- len (big_fit .column_names ),
844+ len (big_fit .column_names ) + 3 ,
839845 )
840846 # list
841847 csvfiles_path = os .path .join (DATAFILES_PATH , 'runset-good' )
@@ -848,22 +854,41 @@ def test_instantiate_from_csvfiles() -> None:
848854 draws_pd = bern_fit .draws_pd ()
849855 assert draws_pd .shape == (
850856 bern_fit .runset .chains * bern_fit .num_draws_sampling ,
851- len (bern_fit .column_names ),
857+ len (bern_fit .column_names ) + 3 ,
852858 )
853859 # single csvfile
854860 bern_fit = from_csv (path = csvfiles [0 ])
855861 draws_pd = bern_fit .draws_pd ()
856862 assert draws_pd .shape == (
857863 bern_fit .num_draws_sampling ,
858- len (bern_fit .column_names ),
864+ len (bern_fit .column_names ) + 3 ,
859865 )
860866 # glob
861867 csvfiles_path = os .path .join (csvfiles_path , '*.csv' )
862868 big_fit = from_csv (path = csvfiles_path )
863869 draws_pd = big_fit .draws_pd ()
864870 assert draws_pd .shape == (
865871 big_fit .runset .chains * big_fit .num_draws_sampling ,
866- len (big_fit .column_names ),
872+ len (big_fit .column_names ) + 3 ,
873+ )
874+
875+
876+ def test_pd_xr_agreement ():
877+ csvfiles_path = os .path .join (DATAFILES_PATH , 'runset-good' , '*.csv' )
878+ bern_fit = from_csv (path = csvfiles_path )
879+
880+ draws_pd = bern_fit .draws_pd ()
881+ draws_xr = bern_fit .draws_xr ()
882+
883+ # check that the indexing is the same between the two
884+ np .testing .assert_equal (
885+ draws_pd [draws_pd ['chain__' ] == 2 ]['theta' ],
886+ draws_xr .theta .sel (chain = 2 ).values ,
887+ )
888+ # "draw" is 0-indexed in xarray, equiv. "iter__" is 1-indexed in pandas
889+ np .testing .assert_equal (
890+ draws_pd [draws_pd ['iter__' ] == 100 ]['theta' ],
891+ draws_xr .theta .sel (draw = 99 ).values ,
867892 )
868893
869894
@@ -930,7 +955,7 @@ def test_instantiate_from_csvfiles_fail(
930955def test_from_csv_fixed_param () -> None :
931956 csv_path = os .path .join (DATAFILES_PATH , 'fixed_param_sample.csv' )
932957 fixed_param_sample = from_csv (path = csv_path )
933- assert fixed_param_sample .draws_pd ().shape == (100 , 85 )
958+ assert fixed_param_sample .draws_pd ().shape == (100 , 88 )
934959
935960
936961def test_custom_metric () -> None :
@@ -1292,14 +1317,14 @@ def test_save_warmup() -> None:
12921317 len (BERNOULLI_COLS ),
12931318 )
12941319
1295- assert bern_fit .draws_pd ().shape == (200 , len (BERNOULLI_COLS ))
1320+ assert bern_fit .draws_pd ().shape == (200 , len (BERNOULLI_COLS ) + 3 )
12961321 assert bern_fit .draws_pd (inc_warmup = False ).shape == (
12971322 200 ,
1298- len (BERNOULLI_COLS ),
1323+ len (BERNOULLI_COLS ) + 3 ,
12991324 )
13001325 assert bern_fit .draws_pd (inc_warmup = True ).shape == (
13011326 600 ,
1302- len (BERNOULLI_COLS ),
1327+ len (BERNOULLI_COLS ) + 3 ,
13031328 )
13041329
13051330
@@ -1371,7 +1396,7 @@ def test_dont_save_warmup(caplog: pytest.LogCaptureFixture) -> None:
13711396 with caplog .at_level (logging .WARNING ):
13721397 assert bern_fit .draws_pd (inc_warmup = True ).shape == (
13731398 200 ,
1374- len (BERNOULLI_COLS ),
1399+ len (BERNOULLI_COLS ) + 3 ,
13751400 )
13761401 check_present (
13771402 caplog ,
0 commit comments