@@ -282,6 +282,52 @@ def inference_output_type(self):
             ('model_id', Optional[str])])
 
 
+@ModelHandlerProvider.register_handler_type('HuggingFacePipeline')
+class HuggingFacePipelineProvider(ModelHandlerProvider):
+  def __init__(
+      self,
+      task: str = "",
+      model: str = "",
+      preprocess: Optional[dict[str, str]] = None,
+      postprocess: Optional[dict[str, str]] = None,
+      device: Optional[str] = None,
+      inference_fn: Optional[dict[str, str]] = None,
+      load_pipeline_args: Optional[dict[str, Any]] = None,
+      **kwargs):
+    try:
+      from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler
+    except ImportError:
+      raise ValueError(
+          'Unable to import HuggingFacePipelineModelHandler. Please '
+          'install transformers dependencies.')
+
+    kwargs = {k: v for k, v in kwargs.items() if not k.startswith('_')}
+
+    inference_fn_obj = self.parse_processing_transform(
+        inference_fn, 'inference_fn') if inference_fn else None
+
+    handler_kwargs = {}
+    if inference_fn_obj:
+      handler_kwargs['inference_fn'] = inference_fn_obj
+
+    _handler = HuggingFacePipelineModelHandler(
+        task=task,
+        model=model,
+        device=device,
+        load_pipeline_args=load_pipeline_args,
+        **handler_kwargs,
+        **kwargs)
+
+    super().__init__(_handler, preprocess, postprocess)
+
+  @staticmethod
+  def validate(model_handler_spec):
+    pass
+
+  def inference_output_type(self):
+    return Any
+
+
 @beam.ptransform.ptransform_fn
 def run_inference(
     pcoll,
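For context, a minimal sketch (not part of this change) of how the HuggingFacePipelineModelHandler that this provider wraps is used directly with RunInference; the task and model values below are illustrative assumptions, not taken from this PR:

# Illustrative usage sketch; task/model are example values, not from this PR.
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.huggingface_inference import (
    HuggingFacePipelineModelHandler)

# Wraps a transformers pipeline for the given task; the provider above
# builds this same handler from YAML config.
model_handler = HuggingFacePipelineModelHandler(
    task='text-classification',
    model='distilbert-base-uncased-finetuned-sst-2-english')

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(['Beam makes ML pipelines portable.'])
      | RunInference(model_handler)  # emits PredictionResult elements
      | beam.Map(print))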