diff --git a/dgf/src/analyse/schema.py b/dgf/src/analyse/schema.py index b952db5..845e951 100644 --- a/dgf/src/analyse/schema.py +++ b/dgf/src/analyse/schema.py @@ -14,6 +14,7 @@ """Analaysis / data extraction from a schema.""" +from typing import Tuple, Union from dgf.src.data import schema as schema_lib from dgf.src.util import log @@ -142,6 +143,7 @@ def primary_feature_or_none( def fix_schema( schema: schema_lib.GraphSchema, create_pound_id_as_fall_back: bool = False, + fix_shape: bool = True, ): """Tries to fix broken/invalid schemas by inferring and setting primary keys. @@ -151,9 +153,41 @@ def fix_schema( one called "#id". This should only be used to consume old GraphAI datasets. Note that create_pound_id_as_fall_back=True does not check that the feaure "#id" is actually present in the data. + fix_shapes: If true, fixes the extra None dimension added to all the shapes. + This is a common issue with TF-GNN schemaes. """ + def shape_is_suspicious(shape: schema_lib.Shape): + return ( + shape is not None + and len(shape) == 2 + and shape[0] is None + and shape[1] is not None + ) + + def fix_suspicious_shape( + feature_name, + shape: schema_lib.Shape, + ) -> schema_lib.Shape: + log.info("Fix suspicious shape of feature '%s'", feature_name) + assert shape_is_suspicious(shape) + if shape[1] == 1: + return tuple() + else: + return shape[1:] + for nodeset_name, nodeset_def in schema.node_sets.items(): + all_shape_are_suspicious = True + for _, feature_schema in nodeset_def.features.items(): + all_shape_are_suspicious = ( + all_shape_are_suspicious and shape_is_suspicious(feature_schema.shape) + ) + if fix_shape and all_shape_are_suspicious: + for feature_name, feature_schema in nodeset_def.features.items(): + feature_schema.shape = fix_suspicious_shape( + feature_name, feature_schema.shape + ) + if primary_feature_or_none(nodeset_name, nodeset_def) is not None: continue # This nodeset has not primary key. @@ -178,6 +212,17 @@ def fix_schema( ) for edgeset_name, edgeset_def in schema.edge_sets.items(): + all_shape_are_suspicious = True + for _, feature_schema in edgeset_def.features.items(): + all_shape_are_suspicious = ( + all_shape_are_suspicious and shape_is_suspicious(feature_schema.shape) + ) + if fix_shape and all_shape_are_suspicious: + for feature_name, feature_schema in edgeset_def.features.items(): + feature_schema.shape = fix_suspicious_shape( + feature_name, feature_schema.shape + ) + if primary_feature_or_none(edgeset_name, edgeset_def) is not None: continue # This nodeset has not primary key. @@ -197,3 +242,84 @@ def fix_schema( edgeset_name, primary_key, ) + + +def infer_schema_semantic( + schema: schema_lib.GraphSchema, + raise_on_error: bool = True, +) -> schema_lib.GraphSchema: + """Automatically detects the semantic of features with UNKNOWN semantic. + + Usage example: + + ```python + schema = dgf.analyse.infer_schema_semantic(schema) + ``` + + The logic to infer the semantic is as follows: + - bytes and booleans are considered categorical. + - numerical values (float, integers) are considered numerical. + - integer values where the name starts with "is_" are considered categorical. + - features where the name starts with # are ignored. + + Args: + schema: The GraphSchema to infer semantics for. + raise_on_error: If True, raises a ValueError if a feature's semantic cannot + be inferred. If False, logs a warning instead. + + Returns: + The modified GraphSchema. + """ + for nodeset_name, nodeset_def in schema.node_sets.items(): + _infer_features_semantic( + nodeset_def.features, + container_name=f"nodeset '{nodeset_name}'", + raise_on_error=raise_on_error, + ) + + for edgeset_name, edgeset_def in schema.edge_sets.items(): + _infer_features_semantic( + edgeset_def.features, + container_name=f"edgeset '{edgeset_name}'", + raise_on_error=raise_on_error, + ) + + return schema + + +def _infer_features_semantic( + features: schema_lib.FeatureSetSchema, + container_name: str, + raise_on_error: bool = True, +): + for feature_name, feature_schema in features.items(): + if feature_name.startswith("#"): + continue + if feature_schema.semantic != schema_lib.FeatureSemantic.UNKNOWN: + continue + + fmt = feature_schema.format + inferred = False + if fmt in (schema_lib.FeatureFormat.BYTES, schema_lib.FeatureFormat.BOOL): + feature_schema.semantic = schema_lib.FeatureSemantic.CATEGORICAL + inferred = True + elif fmt.is_numerical(): + if fmt.is_integer() and feature_name.startswith("is_"): + feature_schema.semantic = schema_lib.FeatureSemantic.CATEGORICAL + else: + feature_schema.semantic = schema_lib.FeatureSemantic.NUMERICAL + inferred = True + + if not inferred: + msg = ( + f"Could not infer semantic for feature {feature_name!r} in" + f" {container_name} with format {fmt!r}. Please specify the semantic" + " manually in the schema." + ) + if raise_on_error: + raise ValueError( + f"{msg} To disable this error and print a warning instead, set" + " `raise_on_error=False` in `infer_schema_semantic`." + ) + else: + log.warning(msg) diff --git a/dgf/src/analyse/schema_test.py b/dgf/src/analyse/schema_test.py index 8a02d4b..3b47997 100644 --- a/dgf/src/analyse/schema_test.py +++ b/dgf/src/analyse/schema_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock + from absl.testing import absltest from dgf.src.analyse import schema as schema_lib from dgf.src.data.schema import EdgeSchema @@ -253,6 +255,193 @@ def test_fix_schema_nodeset_no_candidate(self): ): schema_lib.fix_schema(schema) + def test_fix_schema_create_pound_id_fallback(self): + schema = GraphSchema( + node_sets={ + "nodes": NodeSchema( + features={ + "feat": FeatureSchema( + format=FeatureFormat.FLOAT_32, + semantic=FeatureSemantic.UNKNOWN, + ) + } + ) + }, + edge_sets={}, + ) + schema_lib.fix_schema(schema, create_pound_id_as_fall_back=True) + self.assertIn("#id", schema.node_sets["nodes"].features) + self.assertEqual( + schema.node_sets["nodes"].features["#id"].semantic, + FeatureSemantic.PRIMARY_ID, + ) + self.assertEqual( + schema.node_sets["nodes"].features["#id"].format, + FeatureFormat.BYTES, + ) + + def test_fix_schema_fix_suspicious_shape(self): + schema = GraphSchema( + node_sets={ + "nodes": NodeSchema( + features={ + "id": FeatureSchema( + format=FeatureFormat.BYTES, + semantic=FeatureSemantic.PRIMARY_ID, + shape=(None, 1), + ), + "feat": FeatureSchema( + format=FeatureFormat.FLOAT_32, + semantic=FeatureSemantic.NUMERICAL, + shape=(None, 10), + ), + } + ) + }, + edge_sets={}, + ) + schema_lib.fix_schema(schema, fix_shape=True) + self.assertEqual(schema.node_sets["nodes"].features["id"].shape, ()) + self.assertEqual(schema.node_sets["nodes"].features["feat"].shape, (10,)) + + def test_infer_schema_semantic(self): + + schema = GraphSchema( + node_sets={ + "nodes": NodeSchema( + features={ + # Bytes -> Categorical + "f_bytes": FeatureSchema( + format=FeatureFormat.BYTES, + semantic=FeatureSemantic.UNKNOWN, + ), + # Bool -> Categorical + "f_bool": FeatureSchema( + format=FeatureFormat.BOOL, + semantic=FeatureSemantic.UNKNOWN, + ), + # Float -> Numerical + "f_float": FeatureSchema( + format=FeatureFormat.FLOAT_32, + semantic=FeatureSemantic.UNKNOWN, + ), + # Int -> Numerical + "f_int": FeatureSchema( + format=FeatureFormat.INTEGER_64, + semantic=FeatureSemantic.UNKNOWN, + ), + # Int with "is_" -> Categorical + "is_active": FeatureSchema( + format=FeatureFormat.INTEGER_32, + semantic=FeatureSemantic.UNKNOWN, + ), + # Starts with "#" -> Ignored + "#id": FeatureSchema( + format=FeatureFormat.BYTES, + semantic=FeatureSemantic.UNKNOWN, + ), + # Already set -> Ignored + "already_set": FeatureSchema( + format=FeatureFormat.FLOAT_32, + semantic=FeatureSemantic.EMBEDDING, + ), + } + ) + }, + edge_sets={ + "edges": EdgeSchema( + source="nodes", + target="nodes", + features={ + "f_bytes_edge": FeatureSchema( + format=FeatureFormat.BYTES, + semantic=FeatureSemantic.UNKNOWN, + ), + }, + ) + }, + ) + + inferred_schema = schema_lib.infer_schema_semantic(schema) + + node_features = inferred_schema.node_sets["nodes"].features + self.assertEqual( + node_features["f_bytes"].semantic, FeatureSemantic.CATEGORICAL + ) + self.assertEqual( + node_features["f_bool"].semantic, FeatureSemantic.CATEGORICAL + ) + self.assertEqual( + node_features["f_float"].semantic, FeatureSemantic.NUMERICAL + ) + self.assertEqual(node_features["f_int"].semantic, FeatureSemantic.NUMERICAL) + self.assertEqual( + node_features["is_active"].semantic, FeatureSemantic.CATEGORICAL + ) + self.assertEqual(node_features["#id"].semantic, FeatureSemantic.UNKNOWN) + self.assertEqual( + node_features["already_set"].semantic, FeatureSemantic.EMBEDDING + ) + + edge_features = inferred_schema.edge_sets["edges"].features + self.assertEqual( + edge_features["f_bytes_edge"].semantic, FeatureSemantic.CATEGORICAL + ) + + def test_infer_schema_semantic_cannot_infer_raise(self): + class FakeFormat: + + def is_numerical(self): + return False + + def is_integer(self): + return False + + schema = GraphSchema( + node_sets={ + "nodes": NodeSchema( + features={ + "f_weird": FeatureSchema( + format=FakeFormat(), + semantic=FeatureSemantic.UNKNOWN, + ), + } + ) + }, + edge_sets={}, + ) + with self.assertRaisesRegex( + ValueError, "Could not infer semantic for feature 'f_weird'" + ): + schema_lib.infer_schema_semantic(schema, raise_on_error=True) + + @mock.patch("dgf.src.analyse.schema.log.warning") + def test_infer_schema_semantic_cannot_infer_warn(self, mock_warning): + class FakeFormat: + + def is_numerical(self): + return False + + def is_integer(self): + return False + + schema = GraphSchema( + node_sets={ + "nodes": NodeSchema( + features={ + "f_weird": FeatureSchema( + format=FakeFormat(), + semantic=FeatureSemantic.UNKNOWN, + ), + } + ) + }, + edge_sets={}, + ) + schema_lib.infer_schema_semantic(schema, raise_on_error=False) + mock_warning.assert_called_once() + self.assertIn("Could not infer semantic", mock_warning.call_args[0][0]) + if __name__ == "__main__": absltest.main() diff --git a/dgf/src/api/BUILD b/dgf/src/api/BUILD index be412f3..c87d8fc 100644 --- a/dgf/src/api/BUILD +++ b/dgf/src/api/BUILD @@ -110,6 +110,7 @@ py_library( "//dgf/src/analyse:in_process_feature_statistics", "//dgf/src/analyse:padding", "//dgf/src/analyse:print_schema", + "//dgf/src/analyse:schema", "//dgf/src/analyse/reports:data_model", "//dgf/src/analyse/reports:reporter", "//dgf/src/analyse/topology:global_graph_topology", diff --git a/dgf/src/api/analyse.py b/dgf/src/api/analyse.py index 056a2b5..6202eda 100644 --- a/dgf/src/api/analyse.py +++ b/dgf/src/api/analyse.py @@ -23,6 +23,8 @@ # TODO: Use third_party/py/dgf/src/api/print.py version instead. from dgf.src.analyse.print_schema import print_schema +from dgf.src.analyse.schema import infer_schema_semantic + from dgf.src.analyse.reports import data_model as reports_data_model from dgf.src.analyse.reports import reporter diff --git a/dgf/src/data/schema.py b/dgf/src/data/schema.py index b36f647..8c60b89 100644 --- a/dgf/src/data/schema.py +++ b/dgf/src/data/schema.py @@ -19,6 +19,8 @@ from typing import Callable, Dict, Optional, Tuple import dataclasses_json +Shape = Optional[Tuple[Optional[int], ...]] + class FeatureFormat(enum.Enum): """How a value is represented / stored.""" @@ -85,15 +87,14 @@ class FeatureSchema: num_categorical_values: The number of possible categories for CATEGORICAL features. `None` for other semantic types or if the numeber of possible categories is unknown. - is_utf8_string: Whether the feature is a UTF-8 string. This is only - relevant when feature_format is BYTES, to distinguish between Spanner - STRING (True) and Spanner BYTES (False). - + is_utf8_string: Whether the feature is a UTF-8 string. This is only relevant + when feature_format is BYTES, to distinguish between Spanner STRING (True) + and Spanner BYTES (False). """ format: FeatureFormat semantic: FeatureSemantic = FeatureSemantic.UNKNOWN - shape: Optional[Tuple[Optional[int], ...]] = None + shape: Shape = None num_categorical_values: Optional[int] = None is_utf8_string: Optional[bool] = False diff --git a/dgf/src/io/hgraph_in_memory.py b/dgf/src/io/hgraph_in_memory.py index 078b2db..d27c825 100644 --- a/dgf/src/io/hgraph_in_memory.py +++ b/dgf/src/io/hgraph_in_memory.py @@ -63,11 +63,13 @@ def get_extension(container_type: HGraphContainerType) -> str: def tfgnn_schema_to_schema( tfgnn_schema: "tf_gnn_proto.GraphSchema", + fix_shape: bool = True, ) -> schema_lib.GraphSchema: """Converts a TF-GNN schema proto into a GraphSchema object. Args: tfgnn_schema: A TF-GNN schema proto. + fix_shapes: If true, fixes the extra None dimension added to all the shapes. Returns: A GraphSchema object. @@ -102,7 +104,9 @@ def convert_feature(gnn_feature) -> schema_lib.FeatureSchema: ) schema = schema_lib.GraphSchema(node_sets=node_sets, edge_sets=edge_sets) - analyse_schema_lib.fix_schema(schema, create_pound_id_as_fall_back=True) + analyse_schema_lib.fix_schema( + schema, create_pound_id_as_fall_back=True, fix_shape=fix_shape + ) return schema