Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions dgf/src/analyse/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Analysis / data extraction from a schema."""

from typing import Tuple, Union
from dgf.src.data import schema as schema_lib
from dgf.src.util import log

Expand Down Expand Up @@ -142,6 +143,7 @@ def primary_feature_or_none(
def fix_schema(
schema: schema_lib.GraphSchema,
create_pound_id_as_fall_back: bool = False,
fix_shape: bool = True,
):
"""Tries to fix broken/invalid schemas by inferring and setting primary keys.

Expand All @@ -151,9 +153,41 @@ def fix_schema(
one called "#id". This should only be used to consume old GraphAI
datasets. Note that create_pound_id_as_fall_back=True does not check that
the feature "#id" is actually present in the data.
fix_shape: If true, fixes the extra None dimension added to all the shapes.
This is a common issue with TF-GNN schemas.
"""

def shape_is_suspicious(shape: schema_lib.Shape):
  """Tells whether `shape` looks like `(None, <known dim>)`.

  A shape is flagged only when it is set, has exactly two dimensions, its
  first dimension is None and its second dimension is not None.

  Args:
    shape: The feature shape to inspect. May be None.

  Returns:
    True if the shape matches the pattern above, False otherwise.
  """
  # Unset shapes and shapes of any other rank are never flagged.
  if shape is None or len(shape) != 2:
    return False
  return shape[0] is None and shape[1] is not None

def fix_suspicious_shape(
    feature_name,
    shape: schema_lib.Shape,
) -> schema_lib.Shape:
  """Returns `shape` without its leading None dimension.

  Args:
    feature_name: Name of the feature; only used in the log message.
    shape: A shape for which `shape_is_suspicious` is true.

  Returns:
    An empty tuple when the second dimension is 1, otherwise the shape
    without its first dimension.
  """
  log.info("Fix suspicious shape of feature '%s'", feature_name)
  # Internal invariant: callers only pass shapes already flagged as suspicious.
  assert shape_is_suspicious(shape)
  # A trailing dimension of 1 collapses to the empty shape.
  return tuple() if shape[1] == 1 else shape[1:]

for nodeset_name, nodeset_def in schema.node_sets.items():
all_shape_are_suspicious = True
for _, feature_schema in nodeset_def.features.items():
all_shape_are_suspicious = (
all_shape_are_suspicious and shape_is_suspicious(feature_schema.shape)
)
if fix_shape and all_shape_are_suspicious:
for feature_name, feature_schema in nodeset_def.features.items():
feature_schema.shape = fix_suspicious_shape(
feature_name, feature_schema.shape
)

if primary_feature_or_none(nodeset_name, nodeset_def) is not None:
continue
# This nodeset has no primary key.
Expand All @@ -178,6 +212,17 @@ def fix_schema(
)

for edgeset_name, edgeset_def in schema.edge_sets.items():
all_shape_are_suspicious = True
for _, feature_schema in edgeset_def.features.items():
all_shape_are_suspicious = (
all_shape_are_suspicious and shape_is_suspicious(feature_schema.shape)
)
if fix_shape and all_shape_are_suspicious:
for feature_name, feature_schema in edgeset_def.features.items():
feature_schema.shape = fix_suspicious_shape(
feature_name, feature_schema.shape
)

if primary_feature_or_none(edgeset_name, edgeset_def) is not None:
continue
# This edgeset has no primary key.
Expand All @@ -197,3 +242,84 @@ def fix_schema(
edgeset_name,
primary_key,
)


def infer_schema_semantic(
    schema: schema_lib.GraphSchema,
    raise_on_error: bool = True,
) -> schema_lib.GraphSchema:
  """Assigns a semantic to every feature whose semantic is still UNKNOWN.

  The schema is updated in place and the same object is returned.

  Usage example:

  ```python
  schema = dgf.analyse.infer_schema_semantic(schema)
  ```

  Inference rules, applied to each feature of every nodeset and edgeset:
    - features whose name starts with "#" are skipped.
    - features whose semantic is not UNKNOWN are left untouched.
    - bytes and boolean features become categorical.
    - integer features whose name starts with "is_" become categorical.
    - all other numerical features (floats, integers) become numerical.

  Args:
    schema: The GraphSchema whose feature semantics should be inferred.
    raise_on_error: If True, a feature whose semantic cannot be inferred
      raises a ValueError. If False, a warning is logged instead.

  Returns:
    The `schema` object that was passed in, after modification.

  Raises:
    ValueError: If a semantic cannot be inferred and `raise_on_error` is True.
  """

  def _process(feature_set, label):
    # Shared call so that nodesets and edgesets are handled identically.
    _infer_features_semantic(
        feature_set,
        container_name=label,
        raise_on_error=raise_on_error,
    )

  for name, node_def in schema.node_sets.items():
    _process(node_def.features, f"nodeset '{name}'")

  for name, edge_def in schema.edge_sets.items():
    _process(edge_def.features, f"edgeset '{name}'")

  return schema


def _infer_features_semantic(
features: schema_lib.FeatureSetSchema,
container_name: str,
raise_on_error: bool = True,
):
for feature_name, feature_schema in features.items():
if feature_name.startswith("#"):
continue
if feature_schema.semantic != schema_lib.FeatureSemantic.UNKNOWN:
continue

fmt = feature_schema.format
inferred = False
if fmt in (schema_lib.FeatureFormat.BYTES, schema_lib.FeatureFormat.BOOL):
feature_schema.semantic = schema_lib.FeatureSemantic.CATEGORICAL
inferred = True
elif fmt.is_numerical():
if fmt.is_integer() and feature_name.startswith("is_"):
feature_schema.semantic = schema_lib.FeatureSemantic.CATEGORICAL
else:
feature_schema.semantic = schema_lib.FeatureSemantic.NUMERICAL
inferred = True

if not inferred:
msg = (
f"Could not infer semantic for feature {feature_name!r} in"
f" {container_name} with format {fmt!r}. Please specify the semantic"
" manually in the schema."
)
if raise_on_error:
raise ValueError(
f"{msg} To disable this error and print a warning instead, set"
" `raise_on_error=False` in `infer_schema_semantic`."
)
else:
log.warning(msg)
189 changes: 189 additions & 0 deletions dgf/src/analyse/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

from absl.testing import absltest
from dgf.src.analyse import schema as schema_lib
from dgf.src.data.schema import EdgeSchema
Expand Down Expand Up @@ -253,6 +255,193 @@ def test_fix_schema_nodeset_no_candidate(self):
):
schema_lib.fix_schema(schema)

def test_fix_schema_create_pound_id_fallback(self):
  """fix_schema with the "#id" fallback adds a BYTES PRIMARY_ID "#id" feature."""
  only_feature = FeatureSchema(
      format=FeatureFormat.FLOAT_32,
      semantic=FeatureSemantic.UNKNOWN,
  )
  schema = GraphSchema(
      node_sets={"nodes": NodeSchema(features={"feat": only_feature})},
      edge_sets={},
  )

  schema_lib.fix_schema(schema, create_pound_id_as_fall_back=True)

  fixed_features = schema.node_sets["nodes"].features
  self.assertIn("#id", fixed_features)
  self.assertEqual(
      fixed_features["#id"].semantic, FeatureSemantic.PRIMARY_ID
  )
  self.assertEqual(fixed_features["#id"].format, FeatureFormat.BYTES)

def test_fix_schema_fix_suspicious_shape(self):
  """fix_schema(fix_shape=True) turns (None, 1) into () and (None, 10) into (10,)."""
  id_feature = FeatureSchema(
      format=FeatureFormat.BYTES,
      semantic=FeatureSemantic.PRIMARY_ID,
      shape=(None, 1),
  )
  dense_feature = FeatureSchema(
      format=FeatureFormat.FLOAT_32,
      semantic=FeatureSemantic.NUMERICAL,
      shape=(None, 10),
  )
  schema = GraphSchema(
      node_sets={
          "nodes": NodeSchema(
              features={"id": id_feature, "feat": dense_feature}
          )
      },
      edge_sets={},
  )

  schema_lib.fix_schema(schema, fix_shape=True)

  fixed_features = schema.node_sets["nodes"].features
  self.assertEqual(fixed_features["id"].shape, ())
  self.assertEqual(fixed_features["feat"].shape, (10,))

def test_infer_schema_semantic(self):
  """Checks each inference rule on a nodeset plus one edgeset feature."""

  def unknown(fmt):
    # Feature of the given format whose semantic still has to be inferred.
    return FeatureSchema(format=fmt, semantic=FeatureSemantic.UNKNOWN)

  input_node_features = {
      # Bytes -> Categorical
      "f_bytes": unknown(FeatureFormat.BYTES),
      # Bool -> Categorical
      "f_bool": unknown(FeatureFormat.BOOL),
      # Float -> Numerical
      "f_float": unknown(FeatureFormat.FLOAT_32),
      # Int -> Numerical
      "f_int": unknown(FeatureFormat.INTEGER_64),
      # Int with "is_" -> Categorical
      "is_active": unknown(FeatureFormat.INTEGER_32),
      # Starts with "#" -> Ignored
      "#id": unknown(FeatureFormat.BYTES),
      # Already set -> Ignored
      "already_set": FeatureSchema(
          format=FeatureFormat.FLOAT_32,
          semantic=FeatureSemantic.EMBEDDING,
      ),
  }
  schema = GraphSchema(
      node_sets={"nodes": NodeSchema(features=input_node_features)},
      edge_sets={
          "edges": EdgeSchema(
              source="nodes",
              target="nodes",
              features={"f_bytes_edge": unknown(FeatureFormat.BYTES)},
          )
      },
  )

  inferred_schema = schema_lib.infer_schema_semantic(schema)

  # Expected semantic per node feature, checked in this order.
  expected_node_semantics = (
      ("f_bytes", FeatureSemantic.CATEGORICAL),
      ("f_bool", FeatureSemantic.CATEGORICAL),
      ("f_float", FeatureSemantic.NUMERICAL),
      ("f_int", FeatureSemantic.NUMERICAL),
      ("is_active", FeatureSemantic.CATEGORICAL),
      ("#id", FeatureSemantic.UNKNOWN),
      ("already_set", FeatureSemantic.EMBEDDING),
  )
  inferred_node_features = inferred_schema.node_sets["nodes"].features
  for feature_name, expected_semantic in expected_node_semantics:
    self.assertEqual(
        inferred_node_features[feature_name].semantic, expected_semantic
    )

  inferred_edge_features = inferred_schema.edge_sets["edges"].features
  self.assertEqual(
      inferred_edge_features["f_bytes_edge"].semantic,
      FeatureSemantic.CATEGORICAL,
  )

def test_infer_schema_semantic_cannot_infer_raise(self):
  """A format that matches no rule raises ValueError when raise_on_error=True."""

  class FakeFormat:
    """Format stub that reports itself as neither numerical nor integer."""

    def is_numerical(self):
      return False

    def is_integer(self):
      return False

  weird_feature = FeatureSchema(
      format=FakeFormat(),
      semantic=FeatureSemantic.UNKNOWN,
  )
  schema = GraphSchema(
      node_sets={"nodes": NodeSchema(features={"f_weird": weird_feature})},
      edge_sets={},
  )

  expected_message = "Could not infer semantic for feature 'f_weird'"
  with self.assertRaisesRegex(ValueError, expected_message):
    schema_lib.infer_schema_semantic(schema, raise_on_error=True)

@mock.patch("dgf.src.analyse.schema.log.warning")
def test_infer_schema_semantic_cannot_infer_warn(self, mock_warning):
  """A format that matches no rule logs one warning when raise_on_error=False."""

  class FakeFormat:
    """Format stub that reports itself as neither numerical nor integer."""

    def is_numerical(self):
      return False

    def is_integer(self):
      return False

  weird_feature = FeatureSchema(
      format=FakeFormat(),
      semantic=FeatureSemantic.UNKNOWN,
  )
  schema = GraphSchema(
      node_sets={"nodes": NodeSchema(features={"f_weird": weird_feature})},
      edge_sets={},
  )

  schema_lib.infer_schema_semantic(schema, raise_on_error=False)

  mock_warning.assert_called_once()
  # The first positional argument of the patched call is the message.
  logged_message = mock_warning.call_args[0][0]
  self.assertIn("Could not infer semantic", logged_message)


# Run the tests through absltest when this file is executed as a script.
if __name__ == "__main__":
absltest.main()
1 change: 1 addition & 0 deletions dgf/src/api/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ py_library(
"//dgf/src/analyse:in_process_feature_statistics",
"//dgf/src/analyse:padding",
"//dgf/src/analyse:print_schema",
"//dgf/src/analyse:schema",
"//dgf/src/analyse/reports:data_model",
"//dgf/src/analyse/reports:reporter",
"//dgf/src/analyse/topology:global_graph_topology",
Expand Down
2 changes: 2 additions & 0 deletions dgf/src/api/analyse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

# TODO: Use third_party/py/dgf/src/api/print.py version instead.
from dgf.src.analyse.print_schema import print_schema
from dgf.src.analyse.schema import infer_schema_semantic


from dgf.src.analyse.reports import data_model as reports_data_model
from dgf.src.analyse.reports import reporter
Expand Down
Loading