From 7dae23a03c3089fb5a42191ce174d78515422ed5 Mon Sep 17 00:00:00 2001 From: Teresa Huang Date: Fri, 29 May 2026 12:53:11 -0700 Subject: [PATCH] Optimize semi-distributed sampling: move graph loading into shared initializer, add GCS retry logic, and benchmark on Dataflow. PiperOrigin-RevId: 923557771 --- dgf/src/io/graph_in_beam.py | 15 +++++++- dgf/src/io/graph_in_memory.py | 3 ++ dgf/src/io/parquet.py | 14 ++++++- .../beam_semi_distributed_sampler_v2.py | 28 ++++++++------ dgf/src/util/filesystem.py | 37 +++++++++++++++++++ 5 files changed, 81 insertions(+), 16 deletions(-) diff --git a/dgf/src/io/graph_in_beam.py b/dgf/src/io/graph_in_beam.py index 13500a7..57f9be4 100644 --- a/dgf/src/io/graph_in_beam.py +++ b/dgf/src/io/graph_in_beam.py @@ -252,6 +252,9 @@ def write_graph( graph: distributed_graph_lib.Graph, path: str, beam_namespace: str = "", + num_node_shards: int = 0, + num_edge_shards: int = 0, + compression: str = "snappy", ) -> beam.pvalue.PDone: """Writes a GF Graph from a distributed graph (beam). @@ -305,6 +308,8 @@ def write_graph( file_path_prefix=file_path_prefix, file_name_suffix=PARQUET_EXTENSION, schema=_node_schema_to_parquet_schema(nodeset_schema), + codec=compression, + num_shards=num_node_shards, ) ) write_results.append(write_result) @@ -324,6 +329,8 @@ def write_graph( file_path_prefix=file_path_prefix, file_name_suffix=PARQUET_EXTENSION, schema=_edge_schema_to_parquet_schema(edgeset_schema, graph.schema), + codec=compression, + num_shards=num_edge_shards, ) ) write_results.append(write_result) @@ -398,10 +405,14 @@ def _node_to_raw( node: distributed_graph_lib.Node, schema: schema_lib.NodeSchema ) -> Dict[str, Any]: """Converts a Node to a raw dictionary for Parquet writing.""" + primary_key = schema_analyse_lib.primary_feature_or_none("", schema) raw_dict = {} for feature_name in schema.features: - feature_values = node.features[feature_name] - raw_dict[feature_name] = feature_values.tolist() + if feature_name == primary_key: + raw_dict[feature_name] = node.id + else: + feature_values = node.features[feature_name] + raw_dict[feature_name] = feature_values.tolist() return raw_dict diff --git a/dgf/src/io/graph_in_memory.py b/dgf/src/io/graph_in_memory.py index 2ef6b62..edaab1c 100644 --- a/dgf/src/io/graph_in_memory.py +++ b/dgf/src/io/graph_in_memory.py @@ -336,6 +336,7 @@ def write_graph( path: str, verbose: bool = False, max_num_shards: Optional[int] = None, + compression: str = "snappy", ): """Writes an in-memory graph and schema to a GF Graph directory. @@ -397,6 +398,7 @@ def write_graph( nodeset_schema.features, num_shards, verbose, + compression=compression, ) # Write Edge Sets @@ -447,6 +449,7 @@ def write_graph( features_schema, num_shards, verbose, + compression=compression, ) end_time = time.monotonic() diff --git a/dgf/src/io/parquet.py b/dgf/src/io/parquet.py index 5dd0ae7..27eab45 100644 --- a/dgf/src/io/parquet.py +++ b/dgf/src/io/parquet.py @@ -222,6 +222,7 @@ def _write_single_shard( start_row: int, end_row: int, write_specs: Dict[str, WriteParquetSpec], + compression: str = "snappy", ): def to_pa_array(key, np_values, spec): @@ -249,7 +250,7 @@ def to_pa_array(key, np_values, spec): } table = pa.Table.from_pydict(shard_data) with filesystem.open_write(shard_path, binary=True) as f: - pq.write_table(table, f) + pq.write_table(table, f, compression=compression) def write_numpy_dict_to_parquet( @@ -259,6 +260,7 @@ def write_numpy_dict_to_parquet( schema: schema_lib.FeatureSetSchema, num_shards: int = 1, verbose: bool = False, + compression: str = "snappy", ): """Writes a dictionary of numpy arrays to sharded Parquet files. @@ -269,6 +271,7 @@ def write_numpy_dict_to_parquet( schema: The schema defining the features to write. num_shards: The number of shards to write. verbose: If True, print progress information. + compression: The parquet compression codec to use. """ if not data: raise ValueError("Input data dictionary is empty.") @@ -307,7 +310,14 @@ def _write_shard_task(shard_index: int): if verbose: print(f"Writing shard {shard_index + 1}/{num_shards} to {shard_path}") - _write_single_shard(data, shard_path, start_row, end_row, write_specs) + _write_single_shard( + data, + shard_path, + start_row, + end_row, + write_specs, + compression=compression, + ) with futures.ThreadPoolExecutor( max_workers=max(1, min(num_shards, 20)) diff --git a/dgf/src/sampling/beam_semi_distributed_sampler_v2.py b/dgf/src/sampling/beam_semi_distributed_sampler_v2.py index 37bf977..6f6733b 100644 --- a/dgf/src/sampling/beam_semi_distributed_sampler_v2.py +++ b/dgf/src/sampling/beam_semi_distributed_sampler_v2.py @@ -16,6 +16,7 @@ import logging import os +import threading import time from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, Union import apache_beam as beam @@ -156,10 +157,12 @@ def sample_with_beam_semi_distributed_sampler_v2( return samples, schema -class RawSamplerV2Cache: +class SharedSampler: + """A wrapper to allow weak references to the shared sampler and mapper.""" - def __init__(self): - self.by_path = {} + def __init__(self, sampler, mapper): + self.sampler = sampler + self.mapper = mapper class RawSamplerV2(beam.DoFn): @@ -186,14 +189,11 @@ def __init__( def setup(self): def initializer(): - return RawSamplerV2Cache() - - self.cache = RawSamplerV2.shared_in_memory_samplers.acquire(initializer) - - if self.graph_path not in self.cache.by_path: - # TODO(gbm): Don't load and return the feature values. start_time = time.time() - logging.info("Load graph in memory") + logging.info( + "Thread %s: Start loading shared graph", + threading.current_thread().name, + ) read_schema = self.schema @@ -231,9 +231,13 @@ def initializer(): ].features[KEY_ID] mapper = {id.item(): idx for idx, id in enumerate(seed_node_ids)} - self.cache.by_path[self.graph_path] = (sampler, mapper) + return SharedSampler(sampler, mapper) - self.sampler, self.mapper = self.cache.by_path[self.graph_path] + shared_sampler = RawSamplerV2.shared_in_memory_samplers.acquire( + initializer, tag=self.graph_path + ) + self.sampler = shared_sampler.sampler + self.mapper = shared_sampler.mapper def process( self, seeds: Sequence[distributed_graph.NodeId] diff --git a/dgf/src/util/filesystem.py b/dgf/src/util/filesystem.py index 54bb391..aa66e21 100644 --- a/dgf/src/util/filesystem.py +++ b/dgf/src/util/filesystem.py @@ -16,6 +16,7 @@ """ import concurrent.futures +import time from typing import List, Optional, Sequence from absl import logging from etils import epath @@ -58,6 +59,42 @@ def open_read(path: str, binary: bool = False): Returns: A file-like object for reading. """ + if is_gcs_path(path): + # TODO(gbm): Check other possible fixes + # Direct GCS blob open for GCS reads (epath unreachable from GCP Dataflow runner). + # Note: blob.open() returns a streaming BlobReader which downloads in chunks, + # so it does NOT load the entire file into memory at once (safe for large files). + # We also add exponential-backoff retries to handle transient GCS network timeouts. + client = storage.Client() + + # Parse "gs://bucket/path/to/blob" into bucket name and blob path. + gcs_path = path.replace("gs://", "") + bucket_name = gcs_path.split("/")[0] + blob_path = "/".join(gcs_path.split("/")[1:]) + bucket = client.bucket(bucket_name) + blob = bucket.blob(blob_path) + + # Open the GCS blob as a streaming file-like object. + max_retries = 5 + for attempt in range(max_retries): + try: + return blob.open("rb" if binary else "r") + except Exception as e: # pylint: disable=broad-except + if attempt < max_retries - 1: + wait = 2**attempt # Exponential backoff: 1s, 2s, 4s, 8s. + logging.warning( + "GCS open attempt %d/%d failed for %s: %s. Retrying in %ds.", + attempt + 1, + max_retries, + path, + e, + wait, + ) + time.sleep(wait) + else: + raise # All retries exhausted; propagate the exception. + + # For non-GCS paths (local, CNS, etc.), delegate to etils.epath. return epath.Path(path).open("rb" if binary else "r")