@@ -778,24 +778,19 @@ def test_validate_good_run() -> None:
778778 fit .runset .chains * fit .num_draws_sampling ,
779779 len (fit .column_names ) + 3 ,
780780 )
781- assert fit .draws_pd (vars = ['theta' ]).shape == (400 , 4 )
782- assert fit .draws_pd (vars = ['lp__' , 'theta' ]).shape == (400 , 5 )
783- assert fit .draws_pd (vars = ['theta' , 'lp__' ]).shape == (400 , 5 )
784- assert fit .draws_pd (vars = 'theta' ).shape == (400 , 4 )
781+ assert fit .draws_pd (vars = ['theta' ]).shape == (400 , 1 )
782+ assert fit .draws_pd (vars = ['lp__' , 'theta' ]).shape == (400 , 2 )
783+ assert fit .draws_pd (vars = ['theta' , 'lp__' ]).shape == (400 , 2 )
784+ assert fit .draws_pd (vars = 'theta' ).shape == (400 , 1 )
785785
786786 assert list (fit .draws_pd (vars = ['theta' , 'lp__' ]).columns ) == [
787- 'chain__' ,
788- 'iter__' ,
789- 'draw__' ,
790787 'theta' ,
791788 'lp__' ,
792789 ]
793- assert list (fit .draws_pd (vars = ['lp__' , 'theta' ]).columns ) == [
794- 'chain__' ,
795- 'iter__' ,
796- 'draw__' ,
790+ assert list (fit .draws_pd (vars = ['lp__' , 'theta' , 'iter__' ]).columns ) == [
797791 'lp__' ,
798792 'theta' ,
793+ 'iter__' ,
799794 ]
800795
801796 summary = fit .summary ()
@@ -854,7 +849,7 @@ def test_validate_big_run() -> None:
854849 assert fit .step_size .shape == (2 ,)
855850 assert fit .metric .shape == (2 , 2095 )
856851 assert fit .draws ().shape == (1000 , 2 , 2102 )
857- assert fit .draws_pd (vars = ['phi' ]).shape == (2000 , 2098 )
852+ assert fit .draws_pd (vars = ['phi' ]).shape == (2000 , 2095 )
858853 with raises_nested (ValueError , r'Unknown variable: gamma' ):
859854 fit .draws_pd (vars = ['gamma' ])
860855
0 commit comments