Skip to content

Commit ac38f0f

Browse files
committed
Add databricks support
1 parent c2e2dc3 commit ac38f0f

3 files changed

Lines changed: 60 additions & 8 deletions

File tree

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import typing as t
55
from functools import partial
66

7-
from sqlglot import exp
7+
from sqlglot import exp, parse_one
88
from sqlmesh.core.dialect import to_schema
99
from sqlmesh.core.engine_adapter.shared import (
1010
CatalogSupport,
@@ -16,6 +16,7 @@
1616
from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
1717
from sqlmesh.core.node import IntervalUnit
1818
from sqlmesh.core.schema_diff import NestedSupport
19+
from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker
1920
from sqlmesh.engines.spark.db_api.spark_session import connection, SparkSessionConnection
2021
from sqlmesh.utils.errors import SQLMeshError, MissingDefaultCatalogError
2122

@@ -34,6 +35,7 @@ class DatabricksEngineAdapter(SparkEngineAdapter):
3435
SUPPORTS_CLONING = True
3536
SUPPORTS_MATERIALIZED_VIEWS = True
3637
SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True
38+
SUPPORTS_QUERY_EXECUTION_TRACKING = True
3739
SCHEMA_DIFFER_KWARGS = {
3840
"support_positional_add": True,
3941
"nested_support": NestedSupport.ALL,
@@ -363,3 +365,52 @@ def _build_table_properties_exp(
363365
expressions.append(clustered_by_exp)
364366
properties = exp.Properties(expressions=expressions)
365367
return properties
368+
369+
def _record_execution_stats(
    self, sql: str, rowcount: t.Optional[int] = None, bytes_processed: t.Optional[int] = None
) -> None:
    """Record row/byte statistics for an executed statement with the query tracker.

    Databricks cursors don't reliably expose rowcount/bytes for writes, so we
    inspect the Delta transaction log (``DESCRIBE HISTORY``) of the target table
    and pull ``numOutputRows`` / ``numOutputBytes`` from the most recent WRITE
    operation's metrics.

    Args:
        sql: The SQL statement that was just executed.
        rowcount: Caller-supplied row count; overridden if Delta history has one.
        bytes_processed: Caller-supplied byte count; overridden likewise.
    """
    parsed = parse_one(sql, dialect=self.dialect)
    table = parsed.find(exp.Table)
    table_name = table.sql(dialect=self.dialect) if table else None

    if table_name:
        try:
            self.cursor.execute(f"DESCRIBE HISTORY {table_name}")
        except Exception:
            # Best-effort: the target may not be a Delta table (e.g. a view),
            # in which case DESCRIBE HISTORY fails and we record nothing.
            # Note: deliberately narrow-ish; a bare `except:` would also trap
            # KeyboardInterrupt/SystemExit.
            return

        history = self.cursor.fetchall_arrow()
        if history.num_rows:
            history_df = history.to_pandas()
            write_df = history_df[history_df["operation"] == "WRITE"]
            # Keep only the latest WRITE entry.
            write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()]
            if not write_df.empty:
                # Positional access: after filtering, the surviving row keeps
                # its original index label, so `[0]` (a label lookup) can raise
                # KeyError. `.iloc[0]` is always the first remaining row.
                metrics = write_df["operationMetrics"].iloc[0]
                if metrics:

                    def _metric_as_int(key: str) -> t.Optional[int]:
                        # `metrics` is an iterable of (key, value) pairs; values
                        # are strings that may not parse cleanly.
                        values = [v for k, v in metrics if k == key]
                        if values:
                            try:
                                return int(values[0])
                            except (TypeError, ValueError):
                                pass
                        return None

                    rowcount = _metric_as_int("numOutputRows")
                    bytes_processed = _metric_as_int("numOutputBytes")

                    if rowcount is not None or bytes_processed is not None:
                        # If no rows were written, Delta reports 0 bytes but
                        # omits the row metric entirely - normalize to 0.
                        rowcount = (
                            0 if rowcount is None and bytes_processed is not None else rowcount
                        )

    QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed)

sqlmesh/core/snapshot/execution_tracker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __post_init__(self) -> None:
4141
def add_execution(
4242
self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int]
4343
) -> None:
44-
if row_count is not None:
44+
if row_count is not None and row_count >= 0:
4545
if self.stats.total_rows_processed is None:
4646
self.stats.total_rows_processed = row_count
4747
else:

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2465,11 +2465,12 @@ def capture_execution_stats(
24652465
# seed rows aren't tracked
24662466
assert actual_execution_stats["seed_model"].total_rows_processed is None
24672467

2468-
if ctx.mark.startswith("bigquery"):
2469-
assert actual_execution_stats["incremental_model"].total_bytes_processed
2470-
assert actual_execution_stats["full_model"].total_bytes_processed
2468+
if ctx.mark.startswith("bigquery") or ctx.mark.startswith("databricks"):
2469+
assert actual_execution_stats["incremental_model"].total_bytes_processed is not None
2470+
assert actual_execution_stats["full_model"].total_bytes_processed is not None
24712471

24722472
# run that loads 0 rows in incremental model
2473+
actual_execution_stats = {}
24732474
with patch.object(
24742475
context.console, "update_snapshot_evaluation_progress", capture_execution_stats
24752476
):
@@ -2483,9 +2484,9 @@ def capture_execution_stats(
24832484
None if ctx.mark.startswith("snowflake") else 3
24842485
)
24852486

2486-
if ctx.mark.startswith("bigquery"):
2487-
assert actual_execution_stats["incremental_model"].total_bytes_processed
2488-
assert actual_execution_stats["full_model"].total_bytes_processed
2487+
if ctx.mark.startswith("bigquery") or ctx.mark.startswith("databricks"):
2488+
assert actual_execution_stats["incremental_model"].total_bytes_processed is not None
2489+
assert actual_execution_stats["full_model"].total_bytes_processed is not None
24892490

24902491
# make and validate unmodified dev environment
24912492
no_change_plan: Plan = context.plan_builder(

0 commit comments

Comments
 (0)