Skip to content

Commit fcb4579

Browse files
authored
feat(bigframes): implement ai.embed (#16759)
Fixes b/497836685 πŸ¦•
1 parent ef3940a commit fcb4579

15 files changed

Lines changed: 347 additions & 1 deletion

File tree

β€Žpackages/bigframes/bigframes/bigquery/_operations/ai.pyβ€Ž

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,113 @@ def generate_table(
705705
return session.read_gbq_query(query)
706706

707707

708+
@log_adapter.method_logger(custom_base_name="bigquery_ai")
709+
def embed(
710+
content: str | series.Series | pd.Series,
711+
*,
712+
endpoint: str | None = None,
713+
model: str | None = None,
714+
task_type: (
715+
Literal[
716+
"retrieval_query",
717+
"retrieval_document",
718+
"semantic_similarity",
719+
"classification",
720+
"clustering",
721+
"question_answering",
722+
"fact_verification",
723+
"code_retrieval_query",
724+
]
725+
| None
726+
) = None,
727+
title: str | None = None,
728+
model_params: Mapping[Any, Any] | None = None,
729+
connection_id: str | None = None,
730+
) -> series.Series:
731+
"""
732+
Creates embeddings from text or image data in BigQuery.
733+
734+
**Examples:**
735+
736+
>>> import bigframes.pandas as bpd
737+
>>> import bigframes.bigquery as bbq
738+
>>> bbq.ai.embed("dog", endpoint="text-embedding-005") # doctest: +SKIP
739+
0 {'result': array([ 1.78243860e-03, -1.10658340...
740+
741+
>>> s = bpd.Series(['dog']) # doctest: +SKIP
742+
>>> bbq.ai.embed(s, endpoint='text-embedding-005') # doctest: +SKIP
743+
0 {'result': array([ 1.78243860e-03, -1.10658340...
744+
745+
Args:
746+
content (str | Series):
747+
A string literal or a Series (either BigFrames series or pandas Series) that provides the text or image to embed.
748+
endpoint (str, optional):
749+
A string value that specifies a supported Vertex AI embedding model endpoint to use.
750+
The endpoint value that you specify must include the model version, for example,
751+
`"text-embedding-005"`. If you specify this parameter, you can't specify the
752+
`model` parameter.
753+
model (str, optional):
754+
A string value that specifies a built-in embedding model. The only supported value is
755+
`"embeddinggemma-300m"`. If you specify this parameter, you can't specify the `endpoint`,
756+
`title`, `model_params`, or `connection_id` parameters.
757+
task_type (str, optional):
758+
A string literal that specifies the intended downstream application to help the model
759+
produce better quality embeddings. Accepts `"retrieval_query"`, `"retrieval_document"`,
760+
`"semantic_similarity"`, `"classification"`, `"clustering"`, `"question_answering"`,
761+
`"fact_verification"`, `"code_retrieval_query"`.
762+
title (str, optional):
763+
A string value that specifies the document title, which the model uses to improve
764+
embedding quality. You can only use this parameter if you specify `"retrieval_document"`
765+
for the `task_type` value.
766+
model_params (Mapping[Any, Any], optional):
767+
A JSON literal that provides additional parameters to the model. For example,
768+
`{"outputDimensionality": 768}` lets you specify the number of dimensions to use when
769+
generating embeddings.
770+
connection_id (str, optional):
771+
A STRING value specifying the connection to use to communicate with the model, in the
772+
format `PROJECT_ID.LOCATION.CONNECTION_ID`. For example, `myproject.us.myconnection`.
773+
If not provided, the query uses your end-user credential.
774+
775+
Returns:
776+
bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
777+
* "result": an ARRAY<FLOAT64> value containing the generated embeddings.
778+
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
779+
"""
780+
781+
if model is not None:
782+
if any(x is not None for x in [endpoint, title, model_params, connection_id]):
783+
raise ValueError(
784+
"You cannot specify endpoint, title, model_params, or connection_id when the model is set."
785+
)
786+
elif endpoint is None:
787+
raise ValueError(
788+
"You must specify exactly one of 'endpoint' or 'model' argument."
789+
)
790+
791+
if title is not None and task_type != "retrieval_document":
792+
raise ValueError(
793+
"You can only use 'title' parameter if you specify retrieval_document for the task_type value."
794+
)
795+
796+
operator = ai_ops.AIEmbed(
797+
endpoint=endpoint,
798+
model=model,
799+
task_type=task_type,
800+
title=title,
801+
model_params=json.dumps(model_params) if model_params else None,
802+
connection_id=connection_id,
803+
)
804+
805+
if isinstance(content, str):
806+
return series.Series([content])._apply_unary_op(operator)
807+
elif isinstance(content, pd.Series):
808+
return series.Series(content)._apply_unary_op(operator)
809+
elif isinstance(content, series.Series):
810+
return content._apply_unary_op(operator)
811+
else:
812+
raise ValueError(f"Unsupported 'content' parameter type: {type(content)}")
813+
814+
708815
@log_adapter.method_logger(custom_base_name="bigquery_ai")
709816
def if_(
710817
prompt: PROMPT_TYPE,

β€Žpackages/bigframes/bigframes/bigquery/ai.pyβ€Ž

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858

5959
from bigframes.bigquery._operations.ai import (
6060
classify,
61+
embed,
6162
forecast,
6263
generate,
6364
generate_bool,
@@ -72,6 +73,7 @@
7273

7374
__all__ = [
7475
"classify",
76+
"embed",
7577
"forecast",
7678
"generate",
7779
"generate_bool",

β€Žpackages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.pyβ€Ž

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1965,6 +1965,19 @@ def ai_generate_double(
19651965
).to_expr()
19661966

19671967

1968+
@scalar_op_compiler.register_unary_op(ops.AIEmbed, pass_op=True)
1969+
def ai_embed(value: ibis_types.Value, op: ops.AIEmbed) -> ibis_types.StructValue:
1970+
return ai_ops.AIEmbed(
1971+
value, # type: ignore
1972+
connection_id=op.connection_id, # type: ignore
1973+
endpoint=op.endpoint, # type: ignore
1974+
model=op.model, # type: ignore
1975+
task_type=op.task_type.upper() if op.task_type is not None else None, # type: ignore
1976+
title=op.title, # type: ignore
1977+
model_params=op.model_params, # type: ignore
1978+
).to_expr()
1979+
1980+
19681981
@scalar_op_compiler.register_nary_op(ops.AIIf, pass_op=True)
19691982
def ai_if(*values: ibis_types.Value, op: ops.AIIf) -> ibis_types.StructValue:
19701983
return ai_ops.AIIf(

β€Žpackages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.pyβ€Ž

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
from dataclasses import asdict
18+
from typing import Any
1819

1920
import bigframes_vendored.sqlglot.expressions as sge
2021

@@ -23,6 +24,7 @@
2324
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2425

2526
register_nary_op = expression_compiler.expression_compiler.register_nary_op
27+
register_unary_op = expression_compiler.expression_compiler.register_unary_op
2628

2729

2830
@register_nary_op(ops.AIGenerate, pass_op=True)
@@ -53,6 +55,13 @@ def _(*exprs: TypedExpr, op: ops.AIGenerateDouble) -> sge.Expression:
5355
return sge.func("AI.GENERATE_DOUBLE", *args)
5456

5557

58+
@register_unary_op(ops.AIEmbed, pass_op=True)
59+
def _(expr: TypedExpr, op: ops.AIEmbed) -> sge.Expression:
60+
args: list[Any] = [expr.expr] + _construct_named_args(op)
61+
62+
return sge.func("AI.EMBED", *args)
63+
64+
5665
@register_nary_op(ops.AIIf, pass_op=True)
5766
def _(*exprs: TypedExpr, op: ops.AIIf) -> sge.Expression:
5867
args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)
@@ -94,7 +103,7 @@ def _construct_prompt(
94103
return sge.Kwarg(this=param_name, expression=sge.Tuple(expressions=prompt))
95104

96105

97-
def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
106+
def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]:
98107
args = []
99108

100109
op_args = asdict(op)

β€Žpackages/bigframes/bigframes/operations/__init__.pyβ€Ž

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from bigframes.operations.ai_ops import (
1818
AIClassify,
19+
AIEmbed,
1920
AIGenerate,
2021
AIGenerateBool,
2122
AIGenerateDouble,
@@ -434,6 +435,7 @@
434435
"AIGenerateBool",
435436
"AIGenerateDouble",
436437
"AIGenerateInt",
438+
"AIEmbed",
437439
"AIIf",
438440
"AIScore",
439441
# Numpy ops mapping

β€Žpackages/bigframes/bigframes/operations/ai_ops.pyβ€Ž

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,28 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
118118
)
119119

120120

121+
@dataclasses.dataclass(frozen=True)
122+
class AIEmbed(base_ops.UnaryOp):
123+
name: ClassVar[str] = "ai_embed"
124+
125+
endpoint: str | None
126+
model: str | None
127+
task_type: str | None
128+
title: str | None
129+
model_params: str | None
130+
connection_id: str | None
131+
132+
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
133+
return pd.ArrowDtype(
134+
pa.struct(
135+
(
136+
pa.field("result", pa.list_(pa.float64())),
137+
pa.field("status", pa.string()),
138+
)
139+
)
140+
)
141+
142+
121143
@dataclasses.dataclass(frozen=True)
122144
class AIIf(base_ops.NaryOp):
123145
name: ClassVar[str] = "ai_if"

β€Žpackages/bigframes/tests/system/small/bigquery/test_ai.pyβ€Ž

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,69 @@ def test_ai_generate_double_multi_model(session):
255255
)
256256

257257

258+
def test_ai_embed_series_content(session):
259+
content = bpd.Series(["dog"], session=session)
260+
261+
result = bbq.ai.embed(content, endpoint="text-embedding-005")
262+
263+
assert _contains_no_nulls(result)
264+
assert result.dtype == pd.ArrowDtype(
265+
pa.struct(
266+
(
267+
pa.field("result", pa.list_(pa.float64())),
268+
pa.field("status", pa.string()),
269+
)
270+
)
271+
)
272+
273+
274+
def test_ai_embed_string_content(session):
275+
with mock.patch(
276+
"bigframes.core.global_session.get_global_session"
277+
) as mock_get_session:
278+
mock_get_session.return_value = session
279+
280+
result = bbq.ai.embed("dog", endpoint="text-embedding-005")
281+
282+
assert _contains_no_nulls(result)
283+
assert result.dtype == pd.ArrowDtype(
284+
pa.struct(
285+
(
286+
pa.field("result", pa.list_(pa.float64())),
287+
pa.field("status", pa.string()),
288+
)
289+
)
290+
)
291+
292+
293+
def test_ai_embed_no_endpoint_or_model_raises_error(session):
294+
content = bpd.Series(["dog"], session=session)
295+
296+
with pytest.raises(ValueError):
297+
bbq.ai.embed(content)
298+
299+
300+
def test_ai_embed_both_model_and_endpoint_are_set_raises_error(session):
301+
content = bpd.Series(["dog"], session=session)
302+
303+
with pytest.raises(ValueError):
304+
bbq.ai.embed(
305+
content, endpoint="text-embedding-005", model="embeddinggemma-300m model"
306+
)
307+
308+
309+
def test_ai_embed_title_and_task_type_mismatch_raises_error(session):
310+
content = bpd.Series(["dog"], session=session)
311+
312+
with pytest.raises(ValueError):
313+
bbq.ai.embed(
314+
content,
315+
endpoint="text-embedding-005",
316+
title="my title",
317+
task_type="text_similarity",
318+
)
319+
320+
258321
def test_ai_if(session):
259322
s1 = bpd.Series(["apple", "bear"], session=session)
260323
s2 = bpd.Series(["fruit", "tree"], session=session)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
AI.EMBED(`string_col`, endpoint => 'text-embedding-005') AS `result`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
SELECT
2+
AI.EMBED(
3+
`string_col`,
4+
endpoint => 'text-embedding-005',
5+
connection_id => 'bigframes-dev.us.bigframes-default-connection'
6+
) AS `result`
7+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
AI.EMBED(`string_col`, model => 'embeddinggemma-300m') AS `result`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

0 commit comments

Comments
Β (0)