File tree Expand file tree Collapse file tree
sdks/python/apache_beam/yaml Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -286,11 +286,11 @@ def inference_output_type(self):
286286class HuggingFacePipelineProvider (ModelHandlerProvider ):
287287 def __init__ (
288288 self ,
289- task : str = "" ,
290- model : str = "" ,
289+ task : Optional [ str ] = None ,
290+ model : Optional [ str ] = None ,
291291 preprocess : Optional [dict [str , str ]] = None ,
292292 postprocess : Optional [dict [str , str ]] = None ,
293- device : Optional [str ] = None ,
293+ device : Optional [Any ] = None ,
294294 inference_fn : Optional [dict [str , str ]] = None ,
295295 load_pipeline_args : Optional [dict [str , Any ]] = None ,
296296 ** kwargs ):
@@ -322,7 +322,11 @@ def __init__(
322322
323323 @staticmethod
324324 def validate (model_handler_spec ):
325- pass
325+ config = model_handler_spec .get ('config' , {})
326+ if not config .get ('task' ) and not config .get ('model' ):
327+ raise ValueError (
328+ "HuggingFacePipeline requires either 'task' or "
329+ "'model' to be specified." )
326330
327331 def inference_output_type (self ):
328332 return Any
You can’t perform that action at this time.
0 commit comments