diff --git a/openhexa/sdk/pipelines/parameter/decorator.py b/openhexa/sdk/pipelines/parameter/decorator.py index 301a6269..44503882 100644 --- a/openhexa/sdk/pipelines/parameter/decorator.py +++ b/openhexa/sdk/pipelines/parameter/decorator.py @@ -51,6 +51,8 @@ def __init__( required: bool = True, multiple: bool = False, directory: str | None = None, + disables: typing.Sequence[str] | None = None, + disable_when: bool = True, ): validate_pipeline_parameter_code(code) self.code = code @@ -92,6 +94,16 @@ def __init__( self.widget = widget self.connection = connection self.directory = directory + self.disables = list(dict.fromkeys(disables)) if disables else None + self.disable_when = disable_when + if self.disables and not isinstance(self.type, Boolean): + raise InvalidParameterError( + f"Only boolean parameters can use 'disables'. Parameter '{self.code}' is of type {self.type}." + ) + if not isinstance(self.disable_when, bool): + raise InvalidParameterError( + f"'disable_when' must be a boolean for parameter '{self.code}' (got {disable_when!r})." + ) self._validate_default(default, multiple) self.default = default @@ -117,6 +129,8 @@ def to_dict(self) -> dict[str, typing.Any]: "required": self.required, "multiple": self.multiple, "directory": self.directory, + "disables": self.disables, + "disableWhen": self.disable_when, } if isinstance(self.choices, ChoicesFromFile): d["choicesFromFile"] = self.choices.to_dict() @@ -207,6 +221,24 @@ def validate_parameters(parameters: list[Parameter]): supported_connection_types = {DHIS2ConnectionType, IASOConnectionType} connection_parameters = {p.code for p in parameters if type(p.type) in supported_connection_types} + parameters_by_code = {p.code: p for p in parameters} + controllers = {p.code for p in parameters if p.disables} + for parameter in parameters: + if not parameter.disables: + continue + for target_code in parameter.disables: + if target_code == parameter.code: + raise InvalidParameterError(f"Parameter '{parameter.code}' cannot disable itself.") + if target_code not in parameters_by_code: + raise InvalidParameterError( + f"Parameter '{parameter.code}' disables a non-existing parameter '{target_code}'." + ) + if target_code in controllers: + raise InvalidParameterError( + f"Parameter '{parameter.code}' disables '{target_code}', which is itself a disabling " + f"parameter. Chaining disabling parameters is not supported." + ) + for parameter in parameters: if parameter.connection and parameter.connection not in connection_parameters: raise InvalidParameterError( @@ -251,6 +283,8 @@ def parameter( required: bool = True, multiple: bool = False, directory: str | None = None, + disables: typing.Sequence[str] | None = None, + disable_when: bool = True, ): """Decorate a pipeline function by attaching a parameter to it.. @@ -282,6 +316,14 @@ def parameter( values of the chosen type) directory : str, optional An optional parameter to force file selection to specific directory (only used for parameter type File). If the directory does not exist, it will be ignored. + disables : sequence of str, optional + An optional list of parameter codes to disable when this (boolean) parameter's value matches ``disable_when``. + Disabled parameters are hidden/greyed out in the run form, their required check is skipped, and they are + omitted from the run config (the pipeline function receives their default value). Only boolean parameters can + use this. + disable_when : bool, default=True + The boolean value of this parameter that triggers the disabling of the parameters listed in ``disables``. + Use ``disable_when=False`` for an "enable" toggle (the listed parameters are disabled while it is unticked). Returns ------- @@ -305,6 +347,8 @@ def decorator(fun): connection=connection, multiple=multiple, directory=directory, + disables=disables, + disable_when=disable_when, ), ) diff --git a/openhexa/sdk/pipelines/pipeline.py b/openhexa/sdk/pipelines/pipeline.py index 2a316619..f3f7b537 100644 --- a/openhexa/sdk/pipelines/pipeline.py +++ b/openhexa/sdk/pipelines/pipeline.py @@ -123,9 +123,16 @@ def _validate_config(self, config: dict[str, typing.Any]) -> dict[str, typing.An ParameterValueError If the config contains invalid keys or parameter validation fails. """ + disabled_codes = self._get_disabled_codes(config) + validated_config = {} for parameter in self.parameters: value = config.pop(parameter.code, None) + if parameter.code in disabled_codes: + # Parameter is disabled by an active controller: ignore the (possibly dummy or missing) + # value, skip required/type validation, and fall back to its default. + validated_config[parameter.code] = parameter.default + continue validated_value = parameter.validate(value) validated_config[parameter.code] = validated_value @@ -134,6 +141,22 @@ def _validate_config(self, config: dict[str, typing.Any]) -> dict[str, typing.An return validated_config + def _get_disabled_codes(self, config: dict[str, typing.Any]) -> set[str]: + """Return the codes of parameters disabled by an active controller in the given config. + + A controller is a boolean parameter declaring ``disables=[...]``. It is "active" when its effective + value (from the config, falling back to its default) equals its ``disable_when`` (``True`` by default). + A parameter is disabled if any active controller lists it. + """ + disabled_codes: set[str] = set() + for parameter in self.parameters: + if not parameter.disables: + continue + effective_value = config.get(parameter.code, parameter.default) + if bool(effective_value) == parameter.disable_when: + disabled_codes.update(parameter.disables) + return disabled_codes + def _execute_tasks(self, pool): """Execute all tasks using the provided multiprocessing pool. diff --git a/openhexa/sdk/pipelines/runtime.py b/openhexa/sdk/pipelines/runtime.py index e23888a6..410fd0b6 100644 --- a/openhexa/sdk/pipelines/runtime.py +++ b/openhexa/sdk/pipelines/runtime.py @@ -316,6 +316,8 @@ def get_pipeline(pipeline_path: Path) -> Pipeline: Argument("required", [ast.Constant], default_value=True), Argument("multiple", [ast.Constant], default_value=False), Argument("directory", [ast.Constant]), + Argument("disables", [ast.List]), + Argument("disable_when", [ast.Constant], default_value=True), ), ) diff --git a/tests/test_ast.py b/tests/test_ast.py index 703d36ec..b433b1ed 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -154,6 +154,8 @@ def test_pipeline_with_int_param(self): "help": "Param help", "required": True, "directory": None, + "disables": None, + "disableWhen": True, } ], "timeout": None, @@ -161,6 +163,53 @@ def test_pipeline_with_int_param(self): }, ) + def test_pipeline_with_disables_param(self): + """The @parameter decorator's 'disables' list is parsed from the pipeline code.""" + with tempfile.TemporaryDirectory() as tmpdirname: + with open(f"{tmpdirname}/pipeline.py", "w") as f: + f.write( + "\n".join( + [ + "from openhexa.sdk.pipelines import pipeline, parameter", + "", + "@parameter('run_report_only', type=bool, default=False, disables=['data_input'])", + "@parameter('data_input', type=str)", + "@pipeline('Test pipeline')", + "def test_pipeline():", + " pass", + "", + ] + ) + ) + pipeline = get_pipeline(tmpdirname) + params = {p["code"]: p for p in pipeline.to_dict()["parameters"]} + self.assertEqual(params["run_report_only"]["disables"], ["data_input"]) + self.assertEqual(params["run_report_only"]["disableWhen"], True) + self.assertIsNone(params["data_input"]["disables"]) + + def test_pipeline_with_disable_when_false(self): + """The @parameter decorator's 'disable_when' is parsed from the pipeline code.""" + with tempfile.TemporaryDirectory() as tmpdirname: + with open(f"{tmpdirname}/pipeline.py", "w") as f: + f.write( + "\n".join( + [ + "from openhexa.sdk.pipelines import pipeline, parameter", + "", + "@parameter('enable_advanced', type=bool, default=False, disables=['tuning'], disable_when=False)", + "@parameter('tuning', type=str)", + "@pipeline('Test pipeline')", + "def test_pipeline():", + " pass", + "", + ] + ) + ) + pipeline = get_pipeline(tmpdirname) + params = {p["code"]: p for p in pipeline.to_dict()["parameters"]} + self.assertEqual(params["enable_advanced"]["disables"], ["tuning"]) + self.assertEqual(params["enable_advanced"]["disableWhen"], False) + def test_pipeline_with_multiple_param(self): """The file contains a @pipeline decorator and a @parameter decorator with multiple=True.""" with tempfile.TemporaryDirectory() as tmpdirname: @@ -198,6 +247,8 @@ def test_pipeline_with_multiple_param(self): "help": "Param help", "required": True, "directory": None, + "disables": None, + "disableWhen": True, } ], "timeout": None, @@ -243,6 +294,8 @@ def test_pipeline_with_dataset(self): "help": "Dataset", "required": False, "directory": None, + "disables": None, + "disableWhen": True, } ], "timeout": None, @@ -287,6 +340,8 @@ def test_pipeline_with_choices(self): "help": "Param help", "required": True, "directory": None, + "disables": None, + "disableWhen": True, } ], "timeout": None, @@ -359,6 +414,8 @@ def test_pipeline_with_bool(self): "help": "Param help", "required": True, "directory": None, + "disables": None, + "disableWhen": True, } ], "timeout": None, @@ -404,6 +461,8 @@ def test_pipeline_with_multiple_parameters(self): "help": "Param help", "required": True, "directory": None, + "disables": None, + "disableWhen": True, }, { "choices": ["a", "b"], @@ -417,6 +476,8 @@ def test_pipeline_with_multiple_parameters(self): "help": "Param help 2", "required": True, "directory": None, + "disables": None, + "disableWhen": True, }, ], "timeout": None, @@ -484,6 +545,8 @@ def test_pipeline_with_connection_parameter_for_dhis2(self): "help": None, "required": True, "directory": None, + "disables": None, + "disableWhen": True, }, { "code": "data_element_ids", @@ -497,6 +560,8 @@ def test_pipeline_with_connection_parameter_for_dhis2(self): "help": None, "required": True, "directory": None, + "disables": None, + "disableWhen": True, }, ], "timeout": None, @@ -546,6 +611,8 @@ def test_pipeline_with_connection_parameter_for_iaso(self): "help": None, "required": True, "directory": None, + "disables": None, + "disableWhen": True, }, { "code": "org_units", @@ -559,6 +626,8 @@ def test_pipeline_with_connection_parameter_for_iaso(self): "help": None, "required": True, "directory": None, + "disables": None, + "disableWhen": True, }, { "code": "projects", @@ -572,6 +641,8 @@ def test_pipeline_with_connection_parameter_for_iaso(self): "help": None, "required": True, "directory": None, + "disables": None, + "disableWhen": True, }, { "code": "forms", @@ -585,6 +656,8 @@ def test_pipeline_with_connection_parameter_for_iaso(self): "help": None, "required": True, "directory": None, + "disables": None, + "disableWhen": True, }, ], "timeout": None, diff --git a/tests/test_parameter.py b/tests/test_parameter.py index ea405112..5de6c2ed 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -36,6 +36,7 @@ SecretType, StringType, parameter, + validate_parameters, ) from openhexa.utils import stringcase @@ -422,3 +423,63 @@ def a_function(): assert function_parameters[1].default == ["yo"] assert function_parameters[1].required is False assert function_parameters[1].multiple is True + + +def test_parameter_disables_serialization(): + """The 'disables' option is normalized to a list and serialized in to_dict.""" + no_disables = Parameter("plain", type=str) + assert no_disables.disables is None + assert no_disables.to_dict()["disables"] is None + + controller = Parameter("run_report_only", type=bool, disables=["data_input", "year"]) + assert controller.disables == ["data_input", "year"] + assert controller.to_dict()["disables"] == ["data_input", "year"] + assert controller.to_dict()["disableWhen"] is True + + +def test_parameter_disables_dedup_preserves_order(): + """Duplicate disables targets are removed while keeping declaration order.""" + controller = Parameter("toggle", type=bool, disables=["b", "a", "b", "a"]) + assert controller.disables == ["b", "a"] + + +def test_disable_when_must_be_boolean(): + """'disable_when' must be a boolean — rejected at construction time.""" + with pytest.raises(InvalidParameterError): + Parameter("toggle", type=bool, disables=["x_param"], disable_when="yes") + + +def test_validate_parameters_disables_ok(): + """A valid disabling setup passes validation.""" + controller = Parameter("run_report_only", type=bool, default=False, disables=["data_input"]) + data_input = Parameter("data_input", type=str, required=True) + validate_parameters([controller, data_input]) + + +def test_disables_must_be_boolean(): + """Only boolean parameters can use 'disables' — rejected at construction time.""" + with pytest.raises(InvalidParameterError): + Parameter("mode", type=str, disables=["data_input"]) + + +def test_validate_parameters_disables_unknown_target(): + """Disabling a non-existing parameter raises.""" + controller = Parameter("run_report_only", type=bool, disables=["does_not_exist"]) + with pytest.raises(InvalidParameterError): + validate_parameters([controller]) + + +def test_validate_parameters_disables_self_reference(): + """A parameter cannot disable itself.""" + controller = Parameter("run_report_only", type=bool, disables=["run_report_only"]) + with pytest.raises(InvalidParameterError): + validate_parameters([controller]) + + +def test_validate_parameters_disables_no_chaining(): + """A disabling parameter cannot disable another disabling parameter.""" + controller_a = Parameter("toggle_a", type=bool, disables=["toggle_b"]) + controller_b = Parameter("toggle_b", type=bool, disables=["plain_c"]) + plain_c = Parameter("plain_c", type=str) + with pytest.raises(InvalidParameterError): + validate_parameters([controller_a, controller_b, plain_c]) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index bbff0fc4..f586e652 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -52,6 +52,67 @@ def test_pipeline_run_extra_config(): pipeline.run({"arg1": "ok", "arg2": "extra"}) +def test_pipeline_run_disabled_required_parameter_skipped(): + """A required parameter disabled by an active controller is skipped and receives its default.""" + pipeline_func = Mock() + controller = Parameter("run_report_only", type=bool, default=False, disables=["data_input"]) + data_input = Parameter("data_input", type=str, required=True) + pipeline = Pipeline("pipeline", pipeline_func, [controller, data_input]) + + pipeline.run({"run_report_only": True}) + + pipeline_func.assert_called_once_with(run_report_only=True, data_input=None) + + +def test_pipeline_run_disabled_parameter_value_ignored(): + """A dummy value provided for a disabled parameter is ignored in favor of its default.""" + pipeline_func = Mock() + controller = Parameter("run_report_only", type=bool, default=False, disables=["year"]) + year = Parameter("year", type=int, required=True, default=2024) + pipeline = Pipeline("pipeline", pipeline_func, [controller, year]) + + pipeline.run({"run_report_only": True, "year": 1}) + + pipeline_func.assert_called_once_with(run_report_only=True, year=2024) + + +def test_pipeline_run_inactive_controller_still_validates(): + """When the controller is not active, disabled parameters are still validated as usual.""" + pipeline_func = Mock() + controller = Parameter("run_report_only", type=bool, default=False, disables=["data_input"]) + data_input = Parameter("data_input", type=str, required=True) + pipeline = Pipeline("pipeline", pipeline_func, [controller, data_input]) + + with pytest.raises(ParameterValueError): + pipeline.run({"run_report_only": False}) + + +def test_pipeline_run_disable_when_false_disables_while_off(): + """With disable_when=False, listed params are disabled while the controller is off (default).""" + pipeline_func = Mock() + controller = Parameter("enable_advanced", type=bool, default=False, disables=["tuning"], disable_when=False) + tuning = Parameter("tuning", type=str, required=True) + pipeline = Pipeline("pipeline", pipeline_func, [controller, tuning]) + + pipeline.run({"enable_advanced": False}) + + pipeline_func.assert_called_once_with(enable_advanced=False, tuning=None) + + +def test_pipeline_run_disable_when_false_validates_while_on(): + """With disable_when=False, listed params are required again once the controller is on.""" + pipeline_func = Mock() + controller = Parameter("enable_advanced", type=bool, default=False, disables=["tuning"], disable_when=False) + tuning = Parameter("tuning", type=str, required=True) + pipeline = Pipeline("pipeline", pipeline_func, [controller, tuning]) + + with pytest.raises(ParameterValueError): + pipeline.run({"enable_advanced": True}) + + pipeline.run({"enable_advanced": True, "tuning": "fast"}) + pipeline_func.assert_called_once_with(enable_advanced=True, tuning="fast") + + @patch.dict( os.environ, {"HEXA_SERVER_URL": "https://test.openhexa.org"},