Skip to content

Commit befac2d

Browse files
committed
add huggingface support
1 parent e0df4d8 commit befac2d

1 file changed

Lines changed: 46 additions & 0 deletions

File tree

sdks/python/apache_beam/yaml/yaml_ml.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,52 @@ def inference_output_type(self):
282282
('model_id', Optional[str])])
283283

284284

@ModelHandlerProvider.register_handler_type('HuggingFacePipeline')
class HuggingFacePipelineProvider(ModelHandlerProvider):
  """Model handler provider for Hugging Face pipelines in Beam YAML.

  Wraps HuggingFacePipelineModelHandler so a YAML pipeline can run
  RunInference against a Hugging Face `transformers` pipeline.

  Args:
    task: The pipeline task name (e.g. "text-classification"). One of
      `task` or `model` is expected by the underlying handler.
    model: The model id or path used to load the pipeline.
    preprocess: Optional YAML processing-transform spec applied to each
      element before inference.
    postprocess: Optional YAML processing-transform spec applied to each
      inference result.
    device: Optional device string forwarded to the handler
      (e.g. "CPU" or "GPU" — semantics defined by the handler; verify
      against HuggingFacePipelineModelHandler).
    inference_fn: Optional YAML spec (dict) resolved via
      `parse_processing_transform` into a custom inference callable.
    load_pipeline_args: Optional keyword arguments forwarded to
      `transformers.pipeline` at load time.
    **kwargs: Additional keyword arguments forwarded to the handler.
      Keys starting with '_' are internal YAML plumbing and are dropped.

  Raises:
    ValueError: If the Hugging Face inference dependencies are not
      installed.
  """
  def __init__(
      self,
      task: str = "",
      model: str = "",
      preprocess: Optional[dict[str, str]] = None,
      postprocess: Optional[dict[str, str]] = None,
      device: Optional[str] = None,
      inference_fn: Optional[dict[str, str]] = None,
      load_pipeline_args: Optional[dict[str, Any]] = None,
      **kwargs):
    try:
      from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler
    except ImportError as e:
      # Chain the original ImportError so the real failure (missing
      # transformers, version conflict, etc.) is preserved for debugging.
      raise ValueError(
          'Unable to import HuggingFacePipelineModelHandler. Please '
          'install transformers dependencies.') from e

    # Strip internal ('_'-prefixed) keys injected by the YAML framework
    # before forwarding user kwargs to the handler.
    kwargs = {k: v for k, v in kwargs.items() if not k.startswith('_')}

    # Resolve the optional custom inference function from its YAML spec.
    inference_fn_obj = self.parse_processing_transform(
        inference_fn, 'inference_fn') if inference_fn else None

    # Only pass inference_fn when provided, so the handler's own default
    # applies otherwise.
    handler_kwargs = {}
    if inference_fn_obj:
      handler_kwargs['inference_fn'] = inference_fn_obj

    _handler = HuggingFacePipelineModelHandler(
        task=task,
        model=model,
        device=device,
        load_pipeline_args=load_pipeline_args,
        **handler_kwargs,
        **kwargs)

    super().__init__(_handler, preprocess, postprocess)

  @staticmethod
  def validate(model_handler_spec):
    """No extra validation; the underlying handler validates its args."""
    pass

  def inference_output_type(self):
    """Pipeline outputs are arbitrary Python objects; no fixed schema."""
    return Any
285331
@beam.ptransform.ptransform_fn
286332
def run_inference(
287333
pcoll,

0 commit comments

Comments
 (0)