Skip to content

fix: raise RuntimeError when checkpoint step >= config.steps#4022

Open
Dr-Left wants to merge 1 commit into
mainfrom
chris/fix/explicit-checkpoint-exiting-error
Open

fix: raise RuntimeError when checkpoint step >= config.steps#4022
Dr-Left wants to merge 1 commit into
mainfrom
chris/fix/explicit-checkpoint-exiting-error

Conversation

@Dr-Left
Copy link
Copy Markdown
Collaborator

@Dr-Left Dr-Left commented May 29, 2026

Description

When a user sets steps=x and there is already a checkpoint saved at step x, the job should fail with a clear error message instead of performing no computation or failing with a confusing profiling error.

We add an early check in setup_train_loop (train_utils.py) to fail fast before loading the checkpoint or initializing TPU. We also add a fallback check in train_loop (train.py) to catch cases where the early check might have been bypassed (e.g. loading from a full state path with checkpointing disabled).

The error message guides the user on how to proceed (either increase steps or disable checkpoint loading).

If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/474108002

Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

When a user sets steps=x and there is already a checkpoint saved at step x,
the job should fail with a clear error message instead of performing no
computation or failing with a confusing profiling error.

We add an early check in setup_train_loop (train_utils.py) to fail fast before
loading the checkpoint or initializing TPU. We also add a fallback check in
train_loop (train.py) to catch cases where the early check might have been
bypassed (e.g. loading from a full state path with checkpointing disabled).

The error message guides the user on how to proceed (either increase steps
or disable checkpoint loading).

TAG=agy
CONV=5b89af3a-6c91-4082-9e23-dfe4390461f3
@github-actions
Copy link
Copy Markdown

🤖 Hi @Dr-Left, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

Copy link
Copy Markdown

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

## 📋 Review Summary

This PR introduces explicit checks to ensure that training does not proceed if the target number of steps has already been reached or exceeded by an existing checkpoint. This provides clearer error messages compared to the previous behavior of either doing nothing or failing with cryptic profiling errors.

🔍 General Feedback

  • The implementation follows a "fail-fast" approach by adding an early check in setup_train_loop and a fallback check in train_loop to handle various checkpoint loading scenarios.
  • The logic for identifying finished training based on checkpoint steps is correct and consistent across the two checks.
  • A minor improvement is suggested for the error message in cases where start_step is 0 to avoid referring to a non-existent "checkpoint at step -1".
  • Consolidating the error message string into a shared constant or utility would improve maintainability.

Comment on lines 685 to +691

start_step = get_first_step(model, state) # this is the start_step for training
if start_step >= config.steps:
raise RuntimeError(
f"Requested training up to step {config.steps}, but a checkpoint already exists at step {start_step - 1} "
f"(which means {start_step} steps have been completed). "
f"Did you mean to continue training past step {start_step} (you should set steps > {start_step}) "
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 The error message is slightly misleading when `start_step` is 0 (e.g., if `config.steps` is set to 0), as it refers to a "checkpoint already exists at step -1". It would be better to handle the `start_step == 0` case specifically.
Suggested change
start_step = get_first_step(model, state) # this is the start_step for training
if start_step >= config.steps:
raise RuntimeError(
f"Requested training up to step {config.steps}, but a checkpoint already exists at step {start_step - 1} "
f"(which means {start_step} steps have been completed). "
f"Did you mean to continue training past step {start_step} (you should set steps > {start_step}) "
if start_step >= config.steps:
if start_step > 0:
raise RuntimeError(
f"Requested training up to step {config.steps}, but a checkpoint already exists at step {start_step - 1} "
f"(which means {start_step} steps have been completed). "
f"Did you mean to continue training past step {start_step} (you should set steps > {start_step}) "
f"or to not load the checkpoint (use enable_checkpointing=False?)"
)
else:
raise RuntimeError(
f"Requested training up to step {config.steps}, but start_step is {start_step}. "
f"Please set steps > {start_step}."
)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume customers might not train for zero steps?

Comment on lines +245 to +250
if checkpoint_step is not None and checkpoint_step + 1 >= config.steps:
raise RuntimeError(
f"Requested training up to step {config.steps}, but a checkpoint already exists at step {checkpoint_step} "
f"(which means {checkpoint_step + 1} steps have been completed). "
f"Did you mean to continue training past step {checkpoint_step + 1} (you should set steps > {checkpoint_step + 1}) "
f"or to not load the checkpoint (use enable_checkpointing=False?)"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟢 This long error message is almost identical to the one in `train.py`. To improve maintainability and avoid potential inconsistencies, consider defining a shared helper function or constant in `train_utils.py` for this message.

@codecov
Copy link
Copy Markdown

codecov Bot commented May 29, 2026

Codecov Report

❌ Patch coverage is 33.33333% with 4 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/trainers/pre_train/train.py 0.00% 1 Missing and 1 partial ⚠️
src/maxtext/utils/train_utils.py 50.00% 1 Missing and 1 partial ⚠️

📢 Thoughts on this report? Let us know!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working gemini-review

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants