@@ -71,6 +71,11 @@ def submit(
7171 raise ValueError (
7272 "TrainingQueue requires using a ModelTrainer with Mode.SAGEMAKER_TRAINING_JOB"
7373 )
74+
75+ if share_identifier != None and quota_share_name != None :
76+ raise ValueError (
77+ "Either share_identifier or quota_share_name can be specified, but not both"
78+ )
7479 training_payload = training_job ._create_training_job_args (
7580 input_data_config = inputs , boto3 = True
7681 )
@@ -108,6 +113,7 @@ def map(
108113 share_identifier : Optional [str ] = None ,
109114 timeout : Optional [Dict ] = None ,
110115 tags : Optional [Dict ] = None ,
116+ quota_share_name : Optional [str ] = None ,
111117 ) -> List [TrainingQueuedJob ]:
112118 """Submit queued jobs to the provided estimator and return a list of TrainingQueuedJob objects.
113119
@@ -120,6 +126,7 @@ def map(
120126 share_identifier: Share identifier for the Batch jobs.
121127 timeout: Timeout configuration for the Batch jobs.
122128 tags: Tags apply to Batch job. These tags are for Batch job only.
129+ quota_share_name: Quota share name for the Batch jobs.
123130
124131 Returns: a list of TrainingQueuedJob objects with each Batch job ARN and job name.
125132
@@ -144,6 +151,7 @@ def map(
144151 share_identifier ,
145152 timeout ,
146153 tags ,
154+ quota_share_name ,
147155 )
148156 queued_batch_job_list .append (queued_batch_job )
149157
@@ -171,7 +179,7 @@ def list_jobs(
171179 for job_result in job_result_dict .get ("jobSummaryList" , []):
172180 if "jobArn" in job_result and "jobName" in job_result :
173181 jobs_to_return .append (
174- TrainingQueuedJob (job_result ["jobArn" ], job_result ["jobName" ], job_result .get ("shareIdentifier" , None ))
182+ TrainingQueuedJob (job_result ["jobArn" ], job_result ["jobName" ], job_result .get ("shareIdentifier" , None ), job_result . get ( "quotaShareName" , None ) )
175183 )
176184 else :
177185 logging .warning ("Missing JobArn or JobName in Batch ListJobs API" )
@@ -182,27 +190,35 @@ def list_jobs_by_share(
182190 self ,
183191 status : Optional [str ] = JOB_STATUS_RUNNING ,
184192 share_identifier : Optional [str ] = None ,
193+ quota_share_name : Optional [str ] = None ,
185194 ) -> List [TrainingQueuedJob ]:
186195 """List Batch jobs according to status and share.
187196
188197 Args:
189198 status: Batch job status.
190199 share_identifier: Batch fairshare share identifier.
200+ quota_share_name: Batch quota management share name.
191201
192202 Returns: A list of QueuedJob.
193203
194204 """
195205 filters = None
206+ if share_identifier != None and quota_share_name != None :
207+ raise ValueError (
208+ "Either share_identifier or quota_share_name can be specified, but not both"
209+ )
196210 if share_identifier :
197211 filters = [{"name" : "SHARE_IDENTIFIER" , "values" : [share_identifier ]}]
212+ elif quota_share_name :
213+ filters = [{"name" : "QUOTA_SHARE_NAME" , "values" : [quota_share_name ]}]
198214
199215 jobs_to_return = []
200216 next_token = None
201217 for job_result_dict in _list_service_job (self .queue_name , status , filters , next_token ):
202218 for job_result in job_result_dict .get ("jobSummaryList" , []):
203219 if "jobArn" in job_result and "jobName" in job_result :
204220 jobs_to_return .append (
205- TrainingQueuedJob (job_result ["jobArn" ], job_result ["jobName" ], job_result .get ("shareIdentifier" , None ))
221+ TrainingQueuedJob (job_result ["jobArn" ], job_result ["jobName" ], job_result .get ("shareIdentifier" , None ), job_result . get ( "quotaShareName" , None ) )
206222 )
207223 else :
208224 logging .warning ("Missing JobArn or JobName in Batch ListJobs API" )
0 commit comments