@@ -89,6 +89,45 @@ def test_from_csv_files(caplog: pytest.LogCaptureFixture) -> None:
8989 )
9090
9191
92+ def test_pd_xr_agreement ():
93+ # fitted_params sample - list of filenames
94+ goodfiles_path = os .path .join (DATAFILES_PATH , 'runset-good' , 'bern' )
95+ csv_files = []
96+ for i in range (4 ):
97+ csv_files .append ('{}-{}.csv' .format (goodfiles_path , i + 1 ))
98+
99+ # gq_model
100+ stan = os .path .join (DATAFILES_PATH , 'bernoulli_ppc.stan' )
101+ model = CmdStanModel (stan_file = stan )
102+ jdata = os .path .join (DATAFILES_PATH , 'bernoulli.data.json' )
103+
104+ bern_gqs = model .generate_quantities (data = jdata , previous_fit = csv_files )
105+
106+ draws_pd = bern_gqs .draws_pd (inc_sample = True )
107+ draws_xr = bern_gqs .draws_xr (inc_sample = True )
108+
109+ # check that the indexing is the same between the two
110+ np .testing .assert_equal (
111+ draws_pd [draws_pd ['chain__' ] == 2 ]['y_rep[1]' ],
112+ draws_xr .y_rep .sel (chain = 2 ).isel (y_rep_dim_0 = 0 ).values ,
113+ )
114+ # "draw" is 0-indexed in xarray, equiv. "iter__" is 1-indexed in pandas
115+ np .testing .assert_equal (
116+ draws_pd [draws_pd ['iter__' ] == 100 ]['y_rep[1]' ],
117+ draws_xr .y_rep .sel (draw = 99 ).isel (y_rep_dim_0 = 0 ).values ,
118+ )
119+
120+ # check for included sample as well
121+ np .testing .assert_equal (
122+ draws_pd [draws_pd ['chain__' ] == 2 ]['theta' ],
123+ draws_xr .theta .sel (chain = 2 ).values ,
124+ )
125+ np .testing .assert_equal (
126+ draws_pd [draws_pd ['iter__' ] == 100 ]['theta' ],
127+ draws_xr .theta .sel (draw = 99 ).values ,
128+ )
129+
130+
92131def test_from_csv_files_bad () -> None :
93132 # gq model
94133 stan = os .path .join (DATAFILES_PATH , 'bernoulli_ppc.stan' )
0 commit comments