Skip to content

Commit cf12ffd

Browse files
feat(bigframes): Support Expression objects in create_model options (#16606)
Allows users to specify BigFrames `Expression` objects as values within the `options` dictionary when calling `create_model`. These values are compiled natively to their scalar SQL representation in the resulting `CREATE MODEL` BigQuery DDL. Snapshot tests confirm the translation logic. --- *PR created automatically by Jules for task [15193413976404138758](https://jules.google.com/task/15193413976404138758) started by @tswast* --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> Co-authored-by: tswast <247555+tswast@users.noreply.github.com>
1 parent 3390cf4 commit cf12ffd

4 files changed

Lines changed: 41 additions & 4 deletions

File tree

packages/bigframes/bigframes/bigquery/_operations/ml.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import bigframes.dataframe as dataframe
2626
import bigframes.ml.base
2727
import bigframes.session
28+
import bigframes.core.col as col
2829
from bigframes.bigquery._operations import utils
2930

3031

@@ -50,7 +51,9 @@ def create_model(
5051
input_schema: Optional[Mapping[str, str]] = None,
5152
output_schema: Optional[Mapping[str, str]] = None,
5253
connection_name: Optional[str] = None,
53-
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
54+
options: Optional[
55+
Mapping[str, Union[str, int, float, bool, list, "col.Expression"]]
56+
] = None,
5457
training_data: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]] = None,
5558
custom_holiday: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]] = None,
5659
session: Optional[bigframes.session.Session] = None,
@@ -78,7 +81,7 @@ def create_model(
7881
The OUTPUT clause, which specifies the schema of the output data.
7982
connection_name (str, optional):
8083
The connection to use for the model.
81-
options (Mapping[str, Union[str, int, float, bool, list]], optional):
84+
options (Mapping[str, Union[str, int, float, bool, list, bigframes.core.col.Expression]], optional):
8285
The OPTIONS clause, which specifies the model options.
8386
training_data (Union[bigframes.pandas.DataFrame, str], optional):
8487
The query or DataFrame to use for training the model.

packages/bigframes/bigframes/core/sql/ml.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717
from typing import Any, Dict, List, Mapping, Optional, Union
1818

19+
import bigframes.core.col as col
1920
from bigframes.core.compile.sqlglot import sql as sg_sql
21+
from bigframes.core.compile.sqlglot.expression_compiler import expression_compiler
2022

2123

2224
def create_model_ddl(
@@ -28,7 +30,9 @@ def create_model_ddl(
2830
input_schema: Optional[Mapping[str, str]] = None,
2931
output_schema: Optional[Mapping[str, str]] = None,
3032
connection_name: Optional[str] = None,
31-
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
33+
options: Optional[
34+
Mapping[str, Union[str, int, float, bool, list, "col.Expression"]]
35+
] = None,
3236
training_data: Optional[str] = None,
3337
custom_holiday: Optional[str] = None,
3438
) -> str:
@@ -70,7 +74,10 @@ def create_model_ddl(
7074
if options:
7175
rendered_options = []
7276
for option_name, option_value in options.items():
73-
if isinstance(option_value, (list, tuple)):
77+
if isinstance(option_value, col.Expression):
78+
sg_expr = expression_compiler.compile_expression(option_value._value)
79+
rendered_val = sg_sql.to_sql(sg_expr)
80+
elif isinstance(option_value, (list, tuple)):
7481
# Handle list options like model_registry="vertex_ai"
7582
# Most options are scalar key = value pairs (as in the example above);
7683
# a list value is instead rendered as [val1, val2].
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
CREATE MODEL `my_model`
2+
OPTIONS(l2_reg = 0.1 * 10, booster_type = 'gbtree')
3+
AS SELECT * FROM t

packages/bigframes/tests/unit/core/sql/test_ml.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414

1515
import pytest
1616

17+
import bigframes.core.col as col
18+
import bigframes.core.expression as ex
1719
import bigframes.core.sql.ml
20+
import bigframes.dtypes as dtypes
21+
import bigframes.operations.numeric_ops as numeric_ops
1822

1923
pytest.importorskip("pytest_snapshot")
2024

@@ -97,6 +101,26 @@ def test_create_model_list_option(snapshot):
97101
snapshot.assert_match(sql, "create_model_list_option.sql")
98102

99103

104+
def test_create_model_expression_option(snapshot):
105+
# An expression that applies an arithmetic operator to two literal constants
106+
# e.g. 0.1 * 10
107+
literal_expr = ex.ScalarConstantExpression(0.1, dtypes.FLOAT_DTYPE)
108+
multiplier_expr = ex.ScalarConstantExpression(10, dtypes.INT_DTYPE)
109+
math_expr = col.Expression(
110+
ex.OpExpression(op=numeric_ops.mul_op, inputs=(literal_expr, multiplier_expr))
111+
)
112+
113+
sql = bigframes.core.sql.ml.create_model_ddl(
114+
model_name="my_model",
115+
options={
116+
"l2_reg": math_expr,
117+
"booster_type": "gbtree",
118+
},
119+
training_data="SELECT * FROM t",
120+
)
121+
snapshot.assert_match(sql, "create_model_expression_option.sql")
122+
123+
100124
def test_evaluate_model_basic(snapshot):
101125
sql = bigframes.core.sql.ml.evaluate(
102126
model_name="my_project.my_dataset.my_model",

0 commit comments

Comments
 (0)