Skip to content

Commit 0dad77b

Browse files
committed
address gemini
1 parent e44b41a commit 0dad77b

1 file changed

Lines changed: 8 additions & 4 deletions

File tree

sdks/python/apache_beam/yaml/yaml_ml.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,11 @@ def inference_output_type(self):
286286
class HuggingFacePipelineProvider(ModelHandlerProvider):
287287
def __init__(
288288
self,
289-
task: str = "",
290-
model: str = "",
289+
task: Optional[str] = None,
290+
model: Optional[str] = None,
291291
preprocess: Optional[dict[str, str]] = None,
292292
postprocess: Optional[dict[str, str]] = None,
293-
device: Optional[str] = None,
293+
device: Optional[Any] = None,
294294
inference_fn: Optional[dict[str, str]] = None,
295295
load_pipeline_args: Optional[dict[str, Any]] = None,
296296
**kwargs):
@@ -322,7 +322,11 @@ def __init__(
322322

323323
@staticmethod
324324
def validate(model_handler_spec):
325-
pass
325+
config = model_handler_spec.get('config', {})
326+
if not config.get('task') and not config.get('model'):
327+
raise ValueError(
328+
"HuggingFacePipeline requires either 'task' or "
329+
"'model' to be specified.")
326330

327331
def inference_output_type(self):
328332
return Any

0 commit comments

Comments
 (0)