fix: raise RuntimeError when checkpoint step >= config.steps#4022
fix: raise RuntimeError when checkpoint step >= config.steps#4022Dr-Left wants to merge 1 commit into
Conversation
When a user sets steps=x and there is already a checkpoint saved at step x, the job should fail with a clear error message instead of performing no computation or failing with a confusing profiling error. We add an early check in setup_train_loop (train_utils.py) to fail fast before loading the checkpoint or initializing TPU. We also add a fallback check in train_loop (train.py) to catch cases where the early check might have been bypassed (e.g. loading from a full state path with checkpointing disabled). The error message guides the user on how to proceed (either increase steps or disable checkpoint loading). TAG=agy CONV=5b89af3a-6c91-4082-9e23-dfe4390461f3
|
🤖 Hi @Dr-Left, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
There was a problem hiding this comment.
This PR introduces explicit checks to ensure that training does not proceed if the target number of steps has already been reached or exceeded by an existing checkpoint. This provides clearer error messages compared to the previous behavior of either doing nothing or failing with cryptic profiling errors.
🔍 General Feedback
- The implementation follows a "fail-fast" approach by adding an early check in
setup_train_loopand a fallback check intrain_loopto handle various checkpoint loading scenarios. - The logic for identifying finished training based on checkpoint steps is correct and consistent across the two checks.
- A minor improvement is suggested for the error message in cases where
start_stepis 0 to avoid referring to a non-existent "checkpoint at step -1". - Consolidating the error message string into a shared constant or utility would improve maintainability.
|
|
||
| start_step = get_first_step(model, state) # this is the start_step for training | ||
| if start_step >= config.steps: | ||
| raise RuntimeError( | ||
| f"Requested training up to step {config.steps}, but a checkpoint already exists at step {start_step - 1} " | ||
| f"(which means {start_step} steps have been completed). " | ||
| f"Did you mean to continue training past step {start_step} (you should set steps > {start_step}) " |
There was a problem hiding this comment.
| start_step = get_first_step(model, state) # this is the start_step for training | |
| if start_step >= config.steps: | |
| raise RuntimeError( | |
| f"Requested training up to step {config.steps}, but a checkpoint already exists at step {start_step - 1} " | |
| f"(which means {start_step} steps have been completed). " | |
| f"Did you mean to continue training past step {start_step} (you should set steps > {start_step}) " | |
| if start_step >= config.steps: | |
| if start_step > 0: | |
| raise RuntimeError( | |
| f"Requested training up to step {config.steps}, but a checkpoint already exists at step {start_step - 1} " | |
| f"(which means {start_step} steps have been completed). " | |
| f"Did you mean to continue training past step {start_step} (you should set steps > {start_step}) " | |
| f"or to not load the checkpoint (use enable_checkpointing=False?)" | |
| ) | |
| else: | |
| raise RuntimeError( | |
| f"Requested training up to step {config.steps}, but start_step is {start_step}. " | |
| f"Please set steps > {start_step}." | |
| ) |
There was a problem hiding this comment.
I assume customers might not train for zero steps?
| if checkpoint_step is not None and checkpoint_step + 1 >= config.steps: | ||
| raise RuntimeError( | ||
| f"Requested training up to step {config.steps}, but a checkpoint already exists at step {checkpoint_step} " | ||
| f"(which means {checkpoint_step + 1} steps have been completed). " | ||
| f"Did you mean to continue training past step {checkpoint_step + 1} (you should set steps > {checkpoint_step + 1}) " | ||
| f"or to not load the checkpoint (use enable_checkpointing=False?)" |
There was a problem hiding this comment.
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
Description
When a user sets steps=x and there is already a checkpoint saved at step x, the job should fail with a clear error message instead of performing no computation or failing with a confusing profiling error.
We add an early check in setup_train_loop (train_utils.py) to fail fast before loading the checkpoint or initializing TPU. We also add a fallback check in train_loop (train.py) to catch cases where the early check might have been bypassed (e.g. loading from a full state path with checkpointing disabled).
The error message guides the user on how to proceed (either increase steps or disable checkpoint loading).
If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/474108002
Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.
Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.
Tests
Please describe how you tested this change, and include any instructions and/or
commands to reproduce.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.