diff --git a/src/maxtext/models/deepseek_batchsplit.py b/src/maxtext/models/deepseek_batchsplit.py index 395af865b4..aba66b6df6 100644 --- a/src/maxtext/models/deepseek_batchsplit.py +++ b/src/maxtext/models/deepseek_batchsplit.py @@ -1167,7 +1167,7 @@ def batch_split_schedule_bwd( ): """Performs the backward pass for a single layer.""" norm_mla_ws, moe_ws = weights - mla_out, mla_bwds = mla_with_norms_remat( + _, mla_bwds = mla_with_norms_remat( residuals, norm_mla_ws, positions, @@ -1187,7 +1187,6 @@ def batch_split_schedule_bwd( dtype=cfg.dtype, activation_pspec=activation_pspec, ) - residuals["mla_out"] = mla_out attn_out_grad, moe_ws_grad = moe_bwd( residuals, outputs_grad, @@ -1286,7 +1285,8 @@ def fn(args): # Prevent fusion with MoE ops, especially the RMS norm. # Unfortunately, this seems to be needed to avoid slight numerical differences # between the fwd pass and remat. - return jax.lax.optimization_barrier(mla_out + x), mla_res + out = jax.lax.optimization_barrier(mla_out + x) + return out, {"mla_out": out, **mla_res} return staggered_call(fn, list(zip(inputs, yarn_freqs))) @@ -1349,11 +1349,11 @@ def remat_fn(args): mesh=mesh, activation_pspec=activation_pspec, ) - out = x + mla_out # Prevent fusion with MoE ops, especially the RMS norm. # Unfortunately, this seems to be needed to avoid slight numerical differences # between the fwd pass and remat. - return jax.lax.optimization_barrier(out), (pre_attn_rms_norm_bwd, mla_bwds) + out = jax.lax.optimization_barrier(x + mla_out) + return out, (pre_attn_rms_norm_bwd, mla_bwds) bwds = [None] * len(xs) for i, x in enumerate(zip(xs, yarn_freqs, residuals.pop("attn_out"), residuals.pop("lse"))):