Skip to content

Commit c5743cf

Browse files
Add implicit query builder conversions from bool to BoolExpr (#4547)
# Description of Changes Adds implicit query builder conversions from `bool` to `BoolExpr` so that you can write: ```rust ctx.from.user().r#where(|u| u.online) ``` instead of ```rust ctx.from.user().r#where(|u| u.online.eq(true)) ``` Also removes `NullableCol` and `NullableIxCol` types from C# query builder. # API and ABI breaking changes None # Expected complexity level and risk 1 # Testing Unit and smoketests
1 parent 292bda8 commit c5743cf

20 files changed

Lines changed: 241 additions & 571 deletions

File tree

crates/bindings-csharp/BSATN.Runtime/QueryBuilder.cs

Lines changed: 79 additions & 406 deletions
Large diffs are not rendered by default.

crates/bindings-csharp/Codegen.Tests/fixtures/diag/snapshots/Module#FFI.verified.cs

Lines changed: 9 additions & 15 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/bindings-csharp/Codegen.Tests/fixtures/server/snapshots/Module#FFI.verified.cs

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -248,14 +248,8 @@ public readonly struct PublicTableCols
248248
global::PublicTable,
249249
System.Collections.Generic.List<int>
250250
> ListField;
251-
public readonly global::SpacetimeDB.NullableCol<
252-
global::PublicTable,
253-
int
254-
> NullableValueField;
255-
public readonly global::SpacetimeDB.NullableCol<
256-
global::PublicTable,
257-
string
258-
> NullableReferenceField;
251+
public readonly global::SpacetimeDB.Col<global::PublicTable, int> NullableValueField;
252+
public readonly global::SpacetimeDB.Col<global::PublicTable, string> NullableReferenceField;
259253

260254
internal PublicTableCols(string tableName)
261255
{
@@ -357,14 +351,14 @@ internal PublicTableCols(string tableName)
357351
global::PublicTable,
358352
System.Collections.Generic.List<int>
359353
>(tableName, "ListField");
360-
NullableValueField = new global::SpacetimeDB.NullableCol<global::PublicTable, int>(
354+
NullableValueField = new global::SpacetimeDB.Col<global::PublicTable, int>(
361355
tableName,
362356
"NullableValueField"
363357
);
364-
NullableReferenceField = new global::SpacetimeDB.NullableCol<
365-
global::PublicTable,
366-
string
367-
>(tableName, "NullableReferenceField");
358+
NullableReferenceField = new global::SpacetimeDB.Col<global::PublicTable, string>(
359+
tableName,
360+
"NullableReferenceField"
361+
);
368362
}
369363
}
370364

crates/bindings-csharp/Codegen/Module.cs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -904,9 +904,7 @@ string ColDecl(ColumnDeclaration col)
904904
var typeName = col.Type.Name;
905905
var isNullable = typeName.EndsWith("?", StringComparison.Ordinal);
906906
var valueTypeName = isNullable ? typeName[..^1] : typeName;
907-
var colType = isNullable
908-
? "global::SpacetimeDB.NullableCol"
909-
: "global::SpacetimeDB.Col";
907+
var colType = isNullable ? "global::SpacetimeDB.Col" : "global::SpacetimeDB.Col";
910908
return $"public readonly {colType}<{globalRowName}, {valueTypeName}> {col.Name};";
911909
}
912910

@@ -915,9 +913,7 @@ string ColInit(ColumnDeclaration col)
915913
var typeName = col.Type.Name;
916914
var isNullable = typeName.EndsWith("?", StringComparison.Ordinal);
917915
var valueTypeName = isNullable ? typeName[..^1] : typeName;
918-
var colType = isNullable
919-
? "global::SpacetimeDB.NullableCol"
920-
: "global::SpacetimeDB.Col";
916+
var colType = isNullable ? "global::SpacetimeDB.Col" : "global::SpacetimeDB.Col";
921917
return $"{col.Name} = new {colType}<{globalRowName}, {valueTypeName}>(tableName, \"{col.Name}\");";
922918
}
923919

@@ -950,7 +946,7 @@ string IxColDecl(ColumnDeclaration col)
950946
var isNullable = typeName.EndsWith("?", StringComparison.Ordinal);
951947
var valueTypeName = isNullable ? typeName[..^1] : typeName;
952948
var colType = isNullable
953-
? "global::SpacetimeDB.NullableIxCol"
949+
? "global::SpacetimeDB.IxCol"
954950
: "global::SpacetimeDB.IxCol";
955951
return $"public readonly {colType}<{globalRowName}, {valueTypeName}> {col.Name};";
956952
}
@@ -961,7 +957,7 @@ string IxColInit(ColumnDeclaration col)
961957
var isNullable = typeName.EndsWith("?", StringComparison.Ordinal);
962958
var valueTypeName = isNullable ? typeName[..^1] : typeName;
963959
var colType = isNullable
964-
? "global::SpacetimeDB.NullableIxCol"
960+
? "global::SpacetimeDB.IxCol"
965961
: "global::SpacetimeDB.IxCol";
966962
return $"{col.Name} = new {colType}<{globalRowName}, {valueTypeName}>(tableName, \"{col.Name}\");";
967963
}

crates/bindings-typescript/src/lib/query.ts

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import type {
1111
TypeBuilder,
1212
} from './type_builders';
1313
import type { Values } from './type_util';
14+
import type { Bool as SatsBool } from './algebraic_type_variants';
1415

1516
/**
1617
* Helper to get the set of table names.
@@ -65,7 +66,7 @@ type From<TableDef extends TypedTableDef> = RowTypedQuery<
6566
Readonly<{
6667
toSql(): string;
6768
where(
68-
predicate: (row: RowExpr<TableDef>) => BooleanExpr<TableDef>
69+
predicate: (row: RowExpr<TableDef>) => PredicateExpr<TableDef>
6970
): From<TableDef>;
7071
rightSemijoin<RightTable extends TypedTableDef>(
7172
other: TableRef<RightTable>,
@@ -93,7 +94,7 @@ type SemijoinBuilder<TableDef extends TypedTableDef> = RowTypedQuery<
9394
Readonly<{
9495
toSql(): string;
9596
where(
96-
predicate: (row: RowExpr<TableDef>) => BooleanExpr<TableDef>
97+
predicate: (row: RowExpr<TableDef>) => PredicateExpr<TableDef>
9798
): SemijoinBuilder<TableDef>;
9899
/** @deprecated No longer needed — builder is already a valid query. */
99100
build(): Query<TableDef>;
@@ -120,7 +121,7 @@ class SemijoinImpl<TableDef extends TypedTableDef>
120121
}
121122

122123
where(
123-
predicate: (row: RowExpr<TableDef>) => BooleanExpr<TableDef>
124+
predicate: (row: RowExpr<TableDef>) => PredicateExpr<TableDef>
124125
): SemijoinImpl<TableDef> {
125126
const nextSourceQuery = this.sourceQuery.where(predicate);
126127
return new SemijoinImpl<TableDef>(
@@ -167,9 +168,9 @@ class FromBuilder<TableDef extends TypedTableDef>
167168
) {}
168169

169170
where(
170-
predicate: (row: RowExpr<TableDef>) => BooleanExpr<TableDef>
171+
predicate: (row: RowExpr<TableDef>) => PredicateExpr<TableDef>
171172
): FromBuilder<TableDef> {
172-
const newCondition = predicate(this.table.cols);
173+
const newCondition = normalizePredicateExpr(predicate(this.table.cols));
173174
const nextWhere = this.whereClause
174175
? this.whereClause.and(newCondition)
175176
: newCondition;
@@ -308,7 +309,7 @@ class TableRefImpl<TableDef extends TypedTableDef>
308309
}
309310

310311
where(
311-
predicate: (row: RowExpr<TableDef>) => BooleanExpr<TableDef>
312+
predicate: (row: RowExpr<TableDef>) => PredicateExpr<TableDef>
312313
): FromBuilder<TableDef> {
313314
return this.asFrom().where(predicate);
314315
}
@@ -628,6 +629,11 @@ export type ValueExpr<TableDef extends TypedTableDef, Value> =
628629
| LiteralExpr<Value & LiteralValue>
629630
| ColumnExprForValue<TableDef, Value>;
630631

632+
type PredicateExpr<TableDef extends TypedTableDef> =
633+
| BooleanExpr<TableDef>
634+
| ColumnExprForValue<TableDef, SatsBool>
635+
| boolean;
636+
631637
type LiteralExpr<Value> = {
632638
type: 'literal';
633639
value: Value;
@@ -654,6 +660,24 @@ function normalizeValue(val: ValueInput<any>): ValueExpr<any, any> {
654660
return literal(val as LiteralValue);
655661
}
656662

663+
function normalizePredicateExpr<TableDef extends TypedTableDef>(
664+
value: PredicateExpr<TableDef>
665+
): BooleanExpr<TableDef> {
666+
if (value instanceof BooleanExpr) return value;
667+
if (typeof value === 'boolean') {
668+
return new BooleanExpr({
669+
type: 'eq',
670+
left: literal(value),
671+
right: literal(true),
672+
});
673+
}
674+
return new BooleanExpr({
675+
type: 'eq',
676+
left: value as ValueExpr<TableDef, any>,
677+
right: literal(true),
678+
});
679+
}
680+
657681
type EqExpr<Table extends TypedTableDef = any> = BooleanExpr<Table>;
658682

659683
type BooleanExprData<Table extends TypedTableDef> = (

crates/bindings-typescript/tests/query.test.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ const personTable = table(
3030
id: t.identity(),
3131
name: t.string(),
3232
age: t.u32(),
33+
active: t.bool(),
3334
}
3435
);
3536

@@ -141,6 +142,13 @@ describe('TableScan.toSql', () => {
141142
);
142143
});
143144

145+
it('accepts boolean columns directly as where predicates', () => {
146+
const qb = makeQueryBuilder(schemaDef);
147+
const sql = toSql(qb.person.where(row => row.active).build());
148+
149+
expect(sql).toBe(`SELECT * FROM "person" WHERE "person"."active" = TRUE`);
150+
});
151+
144152
it('renders Identity literals using their hex form', () => {
145153
const qb = makeQueryBuilder(schemaDef);
146154
const identity = new Identity(

crates/codegen/src/csharp.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ impl Lang for Csharp<'_> {
659659
for (field_name, field_type) in &product_type.elements {
660660
let prop = field_name.deref().to_case(Case::Pascal);
661661
let (col_ty, ty) = match field_type {
662-
AlgebraicTypeUse::Option(inner) => ("NullableCol", ty_fmt(module, inner).to_string()),
662+
AlgebraicTypeUse::Option(inner) => ("Col", ty_fmt(module, inner).to_string()),
663663
_ => ("Col", ty_fmt(module, field_type).to_string()),
664664
};
665665
writeln!(
@@ -673,7 +673,7 @@ impl Lang for Csharp<'_> {
673673
for (field_name, field_type) in &product_type.elements {
674674
let prop = field_name.deref().to_case(Case::Pascal);
675675
let (col_ty, ty) = match field_type {
676-
AlgebraicTypeUse::Option(inner) => ("NullableCol", ty_fmt(module, inner).to_string()),
676+
AlgebraicTypeUse::Option(inner) => ("Col", ty_fmt(module, inner).to_string()),
677677
_ => ("Col", ty_fmt(module, field_type).to_string()),
678678
};
679679
let col_name = field_name.deref();
@@ -694,7 +694,7 @@ impl Lang for Csharp<'_> {
694694
}
695695
let prop = field_name.deref().to_case(Case::Pascal);
696696
let (col_ty, ty) = match field_type {
697-
AlgebraicTypeUse::Option(inner) => ("NullableIxCol", ty_fmt(module, inner).to_string()),
697+
AlgebraicTypeUse::Option(inner) => ("IxCol", ty_fmt(module, inner).to_string()),
698698
_ => ("IxCol", ty_fmt(module, field_type).to_string()),
699699
};
700700
writeln!(
@@ -711,7 +711,7 @@ impl Lang for Csharp<'_> {
711711
}
712712
let prop = field_name.deref().to_case(Case::Pascal);
713713
let (col_ty, ty) = match field_type {
714-
AlgebraicTypeUse::Option(inner) => ("NullableIxCol", ty_fmt(module, inner).to_string()),
714+
AlgebraicTypeUse::Option(inner) => ("IxCol", ty_fmt(module, inner).to_string()),
715715
_ => ("IxCol", ty_fmt(module, field_type).to_string()),
716716
};
717717
let col_name = field_name.deref();

crates/codegen/tests/snapshots/codegen__codegen_csharp.snap

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2035,11 +2035,11 @@ namespace SpacetimeDB
20352035

20362036
public sealed class TestDCols
20372037
{
2038-
public global::SpacetimeDB.NullableCol<TestD, NamespaceTestC> TestC { get; }
2038+
public global::SpacetimeDB.Col<TestD, NamespaceTestC> TestC { get; }
20392039

20402040
public TestDCols(string tableName)
20412041
{
2042-
TestC = new global::SpacetimeDB.NullableCol<TestD, NamespaceTestC>(tableName, "test_c");
2042+
TestC = new global::SpacetimeDB.Col<TestD, NamespaceTestC>(tableName, "test_c");
20432043
}
20442044
}
20452045

crates/query-builder/src/expr.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,28 @@ impl<T> BoolExpr<T> {
3737
}
3838
}
3939

40+
impl<T> From<Col<T, bool>> for BoolExpr<T> {
41+
fn from(col: Col<T, bool>) -> Self {
42+
col.eq(true)
43+
}
44+
}
45+
46+
impl<T> From<bool> for BoolExpr<T> {
47+
fn from(value: bool) -> Self {
48+
if value {
49+
BoolExpr::Eq(
50+
Operand::Literal(LiteralValue("TRUE".to_string())),
51+
Operand::Literal(LiteralValue("TRUE".to_string())),
52+
)
53+
} else {
54+
BoolExpr::Eq(
55+
Operand::Literal(LiteralValue("FALSE".to_string())),
56+
Operand::Literal(LiteralValue("TRUE".to_string())),
57+
)
58+
}
59+
}
60+
}
61+
4062
/// Trait for types that can be used as the right-hand side of a comparison with a column of type V
4163
/// in table T.
4264
///

crates/query-builder/src/join.rs

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,12 @@ impl<R: HasCols, L: HasCols> Query<R> for RightSemiJoin<R, L> {
141141

142142
// LeftSemiJoin where() operates on L
143143
impl<L: HasCols> LeftSemiJoin<L> {
144-
pub fn r#where<F>(self, f: F) -> Self
144+
pub fn r#where<F, E>(self, f: F) -> Self
145145
where
146-
F: Fn(&L::Cols) -> BoolExpr<L>,
146+
F: Fn(&L::Cols) -> E,
147+
E: Into<BoolExpr<L>>,
147148
{
148-
let extra = f(&L::cols(self.left_col.table_name()));
149+
let extra = f(&L::cols(self.left_col.table_name())).into();
149150
let new = match self.where_expr {
150151
Some(existing) => Some(existing.and(extra)),
151152
None => Some(extra),
@@ -159,9 +160,10 @@ impl<L: HasCols> LeftSemiJoin<L> {
159160
}
160161

161162
// Filter is an alias for where
162-
pub fn filter<F>(self, f: F) -> Self
163+
pub fn filter<F, E>(self, f: F) -> Self
163164
where
164-
F: Fn(&L::Cols) -> BoolExpr<L>,
165+
F: Fn(&L::Cols) -> E,
166+
E: Into<BoolExpr<L>>,
165167
{
166168
self.r#where(f)
167169
}
@@ -189,11 +191,12 @@ impl<L: HasCols> LeftSemiJoin<L> {
189191

190192
// RightSemiJoin where() operates on R
191193
impl<R: HasCols, L: HasCols> RightSemiJoin<R, L> {
192-
pub fn r#where<F>(self, f: F) -> Self
194+
pub fn r#where<F, E>(self, f: F) -> Self
193195
where
194-
F: Fn(&R::Cols) -> BoolExpr<R>,
196+
F: Fn(&R::Cols) -> E,
197+
E: Into<BoolExpr<R>>,
195198
{
196-
let extra = f(&R::cols(self.right_col.table_name()));
199+
let extra = f(&R::cols(self.right_col.table_name())).into();
197200
let new = match self.right_where_expr {
198201
Some(existing) => Some(existing.and(extra)),
199202
None => Some(extra),
@@ -208,9 +211,10 @@ impl<R: HasCols, L: HasCols> RightSemiJoin<R, L> {
208211
}
209212

210213
// Filter is an alias for where
211-
pub fn filter<F>(self, f: F) -> Self
214+
pub fn filter<F, E>(self, f: F) -> Self
212215
where
213-
F: Fn(&R::Cols) -> BoolExpr<R>,
216+
F: Fn(&R::Cols) -> E,
217+
E: Into<BoolExpr<R>>,
214218
{
215219
self.r#where(f)
216220
}

0 commit comments

Comments
 (0)