1313from ..types import ServerlessTrainResult , TrainConfig
1414
1515if TYPE_CHECKING :
16+ import wandb
17+
1618 from ..model import Model , TrainableModel
1719
1820
21+ def _extract_step_from_wandb_artifact (artifact : "wandb.Artifact" ) -> int | None :
22+ """Extract step number from a W&B artifact's aliases."""
23+ for alias in artifact .aliases :
24+ if alias .startswith ("step" ):
25+ try :
26+ return int (alias [4 :])
27+ except ValueError :
28+ pass
29+ return None
30+
31+
1932class ServerlessBackend (Backend ):
2033 def __init__ (
2134 self , * , api_key : str | None = None , base_url : str | None = None
@@ -417,7 +430,58 @@ async def _experimental_push_to_s3(
417430 verbose : bool = False ,
418431 delete : bool = False ,
419432 ) -> None :
420- raise NotImplementedError
433+ """Push model checkpoints from W&B artifacts to S3.
434+
435+ Downloads checkpoint(s) from W&B and uploads them to S3.
436+
437+ Args:
438+ model: The model whose checkpoints to push.
439+ s3_bucket: S3 bucket name. If None, uses BACKUP_BUCKET env var.
440+ prefix: Optional S3 prefix path.
441+ verbose: Whether to print verbose output.
442+ delete: Whether to delete files from S3 that don't exist in source.
443+ """
444+ from art .utils .s3 import build_s3_path , ensure_bucket_exists , s3_sync
445+
446+ assert model .id is not None , "Model ID is required"
447+
448+ # Get all checkpoint steps
449+ steps : list [int ] = []
450+ async for checkpoint in self ._client .models .checkpoints .list ( # ty:ignore[possibly-missing-attribute]
451+ model_id = model .id , order = "asc"
452+ ):
453+ steps .append (checkpoint .step )
454+
455+ if not steps :
456+ if verbose :
457+ print ("No checkpoints found to push." )
458+ return
459+
460+ await ensure_bucket_exists (s3_bucket )
461+
462+ for step in steps :
463+ if verbose :
464+ print (f"Pushing checkpoint step { step } to S3..." )
465+
466+ # Pull from W&B to local temp dir
467+ checkpoint_dir = await self ._experimental_pull_model_checkpoint (
468+ model , # type: ignore[arg-type]
469+ step = step ,
470+ verbose = verbose ,
471+ )
472+
473+ # Push to S3
474+ s3_path = build_s3_path (
475+ model_name = model .name ,
476+ project = model .project ,
477+ step = step ,
478+ s3_bucket = s3_bucket ,
479+ prefix = prefix ,
480+ )
481+ await s3_sync (checkpoint_dir , s3_path , verbose = verbose , delete = delete )
482+
483+ if verbose :
484+ print (f"Successfully pushed { len (steps )} checkpoint(s) to S3." )
421485
422486 async def _experimental_fork_checkpoint (
423487 self ,
@@ -429,4 +493,154 @@ async def _experimental_fork_checkpoint(
429493 verbose : bool = False ,
430494 prefix : str | None = None ,
431495 ) -> None :
432- raise NotImplementedError
496+ """Fork a checkpoint from another model to initialize this model.
497+
498+ Pulls the source checkpoint from W&B artifacts (or S3 if from_s3_bucket
499+ is provided) and uploads it as a W&B artifact for the destination model.
500+
501+ Note: This uploads the artifact directly to W&B. The ServerlessBackend's
502+ checkpoint tracking may not immediately reflect the forked checkpoint
503+ until the next training step.
504+
505+ Args:
506+ model: The destination model to fork to.
507+ from_model: The name of the source model to fork from.
508+ from_project: The project of the source model. Defaults to model.project.
509+ from_s3_bucket: Optional S3 bucket to pull the checkpoint from.
510+ not_after_step: If provided, uses the latest checkpoint <= this step.
511+ verbose: Whether to print verbose output.
512+ prefix: Optional S3 prefix for bucket operations.
513+ """
514+ import os
515+ import tempfile
516+
517+ import wandb
518+
519+ from_project = from_project or model .project
520+
521+ if from_s3_bucket is not None :
522+ # Pull from S3
523+ from art .utils .s3 import build_s3_path , ensure_bucket_exists , s3_sync
524+ from art .utils .s3_checkpoint_utils import (
525+ get_checkpoint_step_not_after_from_s3 ,
526+ get_latest_checkpoint_step_from_s3 ,
527+ )
528+
529+ if not_after_step is None :
530+ target_step = await get_latest_checkpoint_step_from_s3 (
531+ model_name = from_model ,
532+ project = from_project ,
533+ s3_bucket = from_s3_bucket ,
534+ prefix = prefix ,
535+ )
536+ else :
537+ target_step = await get_checkpoint_step_not_after_from_s3 (
538+ model_name = from_model ,
539+ project = from_project ,
540+ not_after_step = not_after_step ,
541+ s3_bucket = from_s3_bucket ,
542+ prefix = prefix ,
543+ )
544+
545+ if target_step is None :
546+ raise ValueError (
547+ f"No suitable checkpoint found in S3 for model { from_model } "
548+ )
549+
550+ if verbose :
551+ print (f"Pulling checkpoint step { target_step } from S3..." )
552+
553+ checkpoint_dir = os .path .join (
554+ tempfile .gettempdir (),
555+ "art_fork_checkpoints" ,
556+ from_project ,
557+ from_model ,
558+ f"{ target_step :04d} " ,
559+ )
560+ os .makedirs (checkpoint_dir , exist_ok = True )
561+
562+ s3_path = build_s3_path (
563+ model_name = from_model ,
564+ project = from_project ,
565+ step = target_step ,
566+ s3_bucket = from_s3_bucket ,
567+ prefix = prefix ,
568+ )
569+ await ensure_bucket_exists (from_s3_bucket )
570+ await s3_sync (s3_path , checkpoint_dir , verbose = verbose )
571+ selected_step = target_step
572+ else :
573+ # Pull from W&B artifacts
574+ api = wandb .Api (api_key = self ._client .api_key ) # ty:ignore[possibly-missing-attribute]
575+ from_entity = model .entity or api .default_entity
576+
577+ # Iterate all artifact versions to find the best step.
578+ # We avoid relying on the W&B `:latest` alias because it
579+ # may not correspond to the highest training step.
580+ collection_path = f"{ from_entity } /{ from_project } /{ from_model } "
581+ versions = api .artifacts ("lora" , collection_path )
582+
583+ best_step : int | None = None
584+ best_artifact = None
585+ for version in versions :
586+ step_num = _extract_step_from_wandb_artifact (version )
587+ if step_num is None :
588+ continue
589+ if not_after_step is not None and step_num > not_after_step :
590+ continue
591+ if best_step is None or step_num > best_step :
592+ best_step = step_num
593+ best_artifact = version
594+
595+ if best_step is None or best_artifact is None :
596+ if not_after_step is not None :
597+ raise ValueError (
598+ f"No checkpoints found not after step { not_after_step } "
599+ f"for model { from_model } "
600+ )
601+ raise ValueError (f"No checkpoints found for model { from_model } " )
602+ selected_step = best_step
603+ artifact = best_artifact
604+
605+ checkpoint_dir = os .path .join (
606+ tempfile .gettempdir (),
607+ "art_fork_checkpoints" ,
608+ from_project ,
609+ from_model ,
610+ f"{ selected_step :04d} " if selected_step is not None else "latest" ,
611+ )
612+ os .makedirs (checkpoint_dir , exist_ok = True )
613+ artifact .download (root = checkpoint_dir )
614+
615+ if verbose :
616+ print (f"Downloaded source checkpoint step { selected_step } from W&B" )
617+
618+ # Upload as W&B artifact for the destination model
619+ assert model .entity is not None , "Model entity is required"
620+
621+ if verbose :
622+ print (f"Uploading forked checkpoint as W&B artifact for { model .name } ..." )
623+
624+ wandb .login (key = self ._client .api_key ) # ty:ignore[possibly-missing-attribute]
625+ run = wandb .init (
626+ project = model .project ,
627+ entity = model .entity ,
628+ job_type = "checkpoint-fork" ,
629+ name = f"fork-{ from_model } -to-{ model .name } " ,
630+ settings = wandb .Settings (silent = True ),
631+ )
632+ assert run is not None
633+
634+ dest_artifact = wandb .Artifact (name = model .name , type = "lora" )
635+ dest_artifact .add_dir (checkpoint_dir )
636+ aliases = ["latest" ]
637+ if selected_step is not None :
638+ aliases .insert (0 , f"step{ selected_step } " )
639+ run .log_artifact (dest_artifact , aliases = aliases )
640+ run .finish ()
641+
642+ if verbose :
643+ print (
644+ f"Successfully forked checkpoint from { from_model } "
645+ f"(step { selected_step } ) to { model .name } "
646+ )
0 commit comments