@@ -615,9 +615,31 @@ def draws_pd(
615615 else :
616616 cols = list (self .column_names )
617617
618+ draws = self .draws (inc_warmup = inc_warmup )
619+ # add long-form columns for chain, iteration, draw
620+ n_draws , n_chains , _ = draws .shape
621+ chains_col = (
622+ np .repeat (np .arange (1 , n_chains + 1 ), n_draws )
623+ .reshape (1 , n_chains , n_draws )
624+ .T
625+ )
626+ iter_col = (
627+ np .tile (np .arange (1 , n_draws + 1 ), n_chains )
628+ .reshape (1 , n_chains , n_draws )
629+ .T
630+ )
631+ draw_col = (
632+ np .arange (1 , (n_draws * n_chains ) + 1 )
633+ .reshape (1 , n_chains , n_draws )
634+ .T
635+ )
636+ draws = np .concatenate ([chains_col , iter_col , draw_col , draws ], axis = 2 )
637+
638+ cols = ['chain' , 'iter' , 'draw' ] + cols
639+
618640 return pd .DataFrame (
619- data = flatten_chains (self . draws ( inc_warmup = inc_warmup ) ),
620- columns = self .column_names ,
641+ data = flatten_chains (draws ),
642+ columns = [ 'chain' , 'iter' , 'draw' ] + list ( self .column_names ) ,
621643 )[cols ]
622644
623645 def draws_xr (
0 commit comments