Skip to content

Commit d4afa2c

Browse files
authored
feat(bigframes): implement ai.similarity (#16771)
Fixes b/497837587
1 parent f6e916c commit d4afa2c

14 files changed

Lines changed: 262 additions & 0 deletions

File tree

packages/bigframes/bigframes/bigquery/_operations/ai.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -976,6 +976,87 @@ def score(
976976
return series_list[0]._apply_nary_op(operator, series_list[1:])
977977

978978

979+
@log_adapter.method_logger(custom_base_name="bigquery_ai")
980+
def similarity(
981+
content1: str | series.Series | pd.Series,
982+
content2: str | series.Series | pd.Series,
983+
*,
984+
endpoint: str | None = None,
985+
model: str | None = None,
986+
model_params: Mapping[Any, Any] | None = None,
987+
connection_id: str | None = None,
988+
) -> series.Series:
989+
"""
990+
Returns a FLOAT64 value that represents the cosine similarity between the two inputs.
991+
992+
**Examples:**
993+
994+
>>> import bigframes.pandas as bpd
995+
>>> import bigframes.bigquery as bbq
996+
>>> df = bpd.DataFrame({'word': ['happy', 'sad']})
997+
>>> bbq.ai.similarity(df['word'], 'glad', endpoint='text-embedding-005') # doctest: +SKIP
998+
0 0.916601
999+
1 0.660579
1000+
1001+
Args:
1002+
content1 (str | Series):
1003+
A string or series that provides the first value to compare. Both a BigFrames Series or a pandas Series are allowed.
1004+
content2 (str | Series):
1005+
A string or series that provides the second value to compare. Both a BigFrames Series or a pandas Series are allowed.
1006+
endpoint (str, optional):
1007+
Specifies the Vertex AI endpoint to use for the text embedding model.
1008+
If you specify the model name, such as `'text-embedding-005'`, rather than a URL, then BigQuery ML automatically identifies the model and uses the model's full endpoint.
1009+
model (str, optional):
1010+
Specifies a built-in text embedding model. The only supported value is the embeddinggemma-300m model.
1011+
If you specify this parameter, you can't specify the `endpoint`, `model_params`, or `connection_id` parameters.
1012+
model_params (Mapping[Any, Any], optional):
1013+
Provides additional parameters to the model. You can use any of the parameters object fields.
1014+
One of these fields, `outputDimensionality`, lets you specify the number of dimensions to use when generating embeddings.
1015+
connection_id (str, optional):
1016+
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
1017+
1018+
Returns:
1019+
bigframes.series.Series: A new series of FLOAT64 values representing the cosine similarity.
1020+
"""
1021+
if model is not None:
1022+
if any(x is not None for x in [endpoint, model_params, connection_id]):
1023+
raise ValueError(
1024+
"If 'model' is specified, you cannot specify 'endpoint', 'model_params', or 'connection_id'."
1025+
)
1026+
elif endpoint is None:
1027+
raise ValueError("You must specify either 'model' or 'endpoint'.")
1028+
1029+
operator = ai_ops.AISimilarity(
1030+
endpoint=endpoint,
1031+
model=model,
1032+
model_params=json.dumps(model_params) if model_params else None,
1033+
connection_id=connection_id,
1034+
)
1035+
1036+
# Find a unifying session for the subsequent operations.
1037+
bf_session = None
1038+
if isinstance(content1, series.Series):
1039+
bf_session = content1._session
1040+
elif isinstance(content2, series.Series):
1041+
bf_session = content2._session
1042+
1043+
if isinstance(content1, str) and isinstance(content2, str):
1044+
content1 = series.Series([content1], session=bf_session)
1045+
return content1._apply_binary_op(content2, operator)
1046+
elif isinstance(content1, str):
1047+
# content2 must be a series
1048+
content2 = convert.to_bf_series(
1049+
content2, default_index=None, session=bf_session
1050+
)
1051+
return content2._apply_binary_op(content1, operator)
1052+
else:
1053+
# content1 must be a series.
1054+
content1 = convert.to_bf_series(
1055+
content1, default_index=None, session=bf_session
1056+
)
1057+
return content1._apply_binary_op(content2, operator)
1058+
1059+
9791060
@log_adapter.method_logger(custom_base_name="bigquery_ai")
9801061
def forecast(
9811062
df: dataframe.DataFrame | pd.DataFrame,

packages/bigframes/bigframes/bigquery/ai.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
generate_text,
7070
if_,
7171
score,
72+
similarity,
7273
)
7374

7475
__all__ = [
@@ -84,4 +85,5 @@
8485
"generate_text",
8586
"if_",
8687
"score",
88+
"similarity",
8789
]

packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2005,6 +2005,20 @@ def ai_score(*values: ibis_types.Value, op: ops.AIScore) -> ibis_types.StructVal
20052005
).to_expr()
20062006

20072007

2008+
@scalar_op_compiler.register_binary_op(ops.AISimilarity, pass_op=True)
2009+
def ai_similarity(
2010+
content1: ibis_types.Value, content2: ibis_types.Value, op: ops.AISimilarity
2011+
) -> ibis_types.Value:
2012+
return ai_ops.AISimilarity(
2013+
content1, # type: ignore
2014+
content2, # type: ignore
2015+
op.endpoint, # type: ignore
2016+
op.model, # type: ignore
2017+
op.model_params, # type: ignore
2018+
op.connection_id, # type: ignore
2019+
).to_expr()
2020+
2021+
20082022
def _construct_prompt(
20092023
col_refs: tuple[ibis_types.Value], prompt_context: tuple[str | None]
20102024
) -> ibis_types.StructValue:

packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2525

2626
register_nary_op = expression_compiler.expression_compiler.register_nary_op
27+
register_binary_op = expression_compiler.expression_compiler.register_binary_op
2728
register_unary_op = expression_compiler.expression_compiler.register_unary_op
2829

2930

@@ -85,6 +86,16 @@ def _(*exprs: TypedExpr, op: ops.AIScore) -> sge.Expression:
8586
return sge.func("AI.SCORE", *args)
8687

8788

89+
@register_binary_op(ops.AISimilarity, pass_op=True)
90+
def _(content1: TypedExpr, content2: TypedExpr, op: ops.AISimilarity) -> sge.Expression:
91+
args = [
92+
sge.Kwarg(this="content1", expression=content1.expr),
93+
sge.Kwarg(this="content2", expression=content2.expr),
94+
] + _construct_named_args(op)
95+
96+
return sge.func("AI.SIMILARITY", *args)
97+
98+
8899
def _construct_prompt(
89100
exprs: tuple[TypedExpr, ...],
90101
prompt_context: tuple[str | None, ...],

packages/bigframes/bigframes/operations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
AIGenerateInt,
2424
AIIf,
2525
AIScore,
26+
AISimilarity,
2627
)
2728
from bigframes.operations.array_ops import (
2829
ArrayIndexOp,
@@ -438,6 +439,7 @@
438439
"AIEmbed",
439440
"AIIf",
440441
"AIScore",
442+
"AISimilarity",
441443
# Numpy ops mapping
442444
"NUMPY_TO_BINOP",
443445
"NUMPY_TO_OP",

packages/bigframes/bigframes/operations/ai_ops.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,16 @@ class AIScore(base_ops.NaryOp):
172172

173173
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
174174
return dtypes.FLOAT_DTYPE
175+
176+
177+
@dataclasses.dataclass(frozen=True)
178+
class AISimilarity(base_ops.BinaryOp):
179+
name: ClassVar[str] = "ai_similarity"
180+
181+
endpoint: str | None
182+
model: str | None
183+
model_params: str | None
184+
connection_id: str | None
185+
186+
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
187+
return dtypes.FLOAT_DTYPE

packages/bigframes/tests/system/small/bigquery/test_ai.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,5 +433,53 @@ def test_forecast_w_params(time_series_df_default_index: dataframe.DataFrame):
433433
)
434434

435435

436+
def test_ai_similarity(session):
437+
s1 = bpd.Series(["happy", "sad"], session=session)
438+
s2 = pd.Series(["glad", "angry"])
439+
440+
result = bbq.ai.similarity(s1, s2, endpoint="text-embedding-005")
441+
442+
assert _contains_no_nulls(result)
443+
assert result.dtype == dtypes.FLOAT_DTYPE
444+
445+
446+
def test_ai_similarity_one_content_is_string_literal(session):
447+
s1 = "happy"
448+
s2 = bpd.Series(["glad", "angry"], session=session)
449+
450+
result = bbq.ai.similarity(s1, s2, model="embeddinggemma-300m")
451+
452+
assert _contains_no_nulls(result)
453+
assert result.dtype == dtypes.FLOAT_DTYPE
454+
455+
456+
def test_ai_similarity_both_contents_are_string_literals(session):
457+
s1 = "happy"
458+
s2 = "glad"
459+
460+
result = bbq.ai.similarity(s1, s2, endpoint="text-embedding-005")
461+
462+
assert _contains_no_nulls(result)
463+
assert result.dtype == dtypes.FLOAT_DTYPE
464+
465+
466+
def test_ai_similarity_no_endpoint_or_model__raises_error(session):
467+
s1 = bpd.Series(["happy", "sad"], session=session)
468+
s2 = bpd.Series(["glad", "angry"], session=session)
469+
470+
with pytest.raises(ValueError):
471+
bbq.ai.similarity(s1, s2)
472+
473+
474+
def test_ai_similarity_both_endpoint_and_model__raises_error(session):
475+
s1 = "happy"
476+
s2 = "glad"
477+
478+
with pytest.raises(ValueError):
479+
bbq.ai.similarity(
480+
s1, s2, endpoint="text-embedding-005", model="embeddinggemma-300m"
481+
)
482+
483+
436484
def _contains_no_nulls(s: series.Series) -> bool:
437485
return len(s) == s.count()
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
AI.SIMILARITY(content1 => `string_col`, content2 => `string_col`, endpoint => 'text-embedding-005') AS `result`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
SELECT
2+
AI.SIMILARITY(
3+
content1 => `string_col`,
4+
content2 => `string_col`,
5+
endpoint => 'text-embedding-005',
6+
connection_id => 'bigframes-dev.us.bigframes-default-connection'
7+
) AS `result`
8+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
AI.SIMILARITY(content1 => `string_col`, content2 => `string_col`, model => 'embeddinggemma-300m') AS `result`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

0 commit comments

Comments
 (0)