Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions dgf/src/io/graph_in_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ def write_graph(
graph: distributed_graph_lib.Graph,
path: str,
beam_namespace: str = "",
num_node_shards: int = 0,
num_edge_shards: int = 0,
compression: str = "snappy",
) -> beam.pvalue.PDone:
"""Writes a GF Graph from a distributed graph (beam).

Expand Down Expand Up @@ -305,6 +308,8 @@ def write_graph(
file_path_prefix=file_path_prefix,
file_name_suffix=PARQUET_EXTENSION,
schema=_node_schema_to_parquet_schema(nodeset_schema),
codec=compression,
num_shards=num_node_shards,
)
)
write_results.append(write_result)
Expand All @@ -324,6 +329,8 @@ def write_graph(
file_path_prefix=file_path_prefix,
file_name_suffix=PARQUET_EXTENSION,
schema=_edge_schema_to_parquet_schema(edgeset_schema, graph.schema),
codec=compression,
num_shards=num_edge_shards,
)
)
write_results.append(write_result)
Expand Down Expand Up @@ -398,10 +405,14 @@ def _node_to_raw(
node: distributed_graph_lib.Node, schema: schema_lib.NodeSchema
) -> Dict[str, Any]:
"""Converts a Node to a raw dictionary for Parquet writing."""
primary_key = schema_analyse_lib.primary_feature_or_none("", schema)
raw_dict = {}
for feature_name in schema.features:
feature_values = node.features[feature_name]
raw_dict[feature_name] = feature_values.tolist()
if feature_name == primary_key:
raw_dict[feature_name] = node.id
else:
feature_values = node.features[feature_name]
raw_dict[feature_name] = feature_values.tolist()
return raw_dict


Expand Down
3 changes: 3 additions & 0 deletions dgf/src/io/graph_in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def write_graph(
path: str,
verbose: bool = False,
max_num_shards: Optional[int] = None,
compression: str = "snappy",
):
"""Writes an in-memory graph and schema to a GF Graph directory.

Expand Down Expand Up @@ -397,6 +398,7 @@ def write_graph(
nodeset_schema.features,
num_shards,
verbose,
compression=compression,
)

# Write Edge Sets
Expand Down Expand Up @@ -447,6 +449,7 @@ def write_graph(
features_schema,
num_shards,
verbose,
compression=compression,
)

end_time = time.monotonic()
Expand Down
14 changes: 12 additions & 2 deletions dgf/src/io/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def _write_single_shard(
start_row: int,
end_row: int,
write_specs: Dict[str, WriteParquetSpec],
compression: str = "snappy",
):

def to_pa_array(key, np_values, spec):
Expand Down Expand Up @@ -249,7 +250,7 @@ def to_pa_array(key, np_values, spec):
}
table = pa.Table.from_pydict(shard_data)
with filesystem.open_write(shard_path, binary=True) as f:
pq.write_table(table, f)
pq.write_table(table, f, compression=compression)


def write_numpy_dict_to_parquet(
Expand All @@ -259,6 +260,7 @@ def write_numpy_dict_to_parquet(
schema: schema_lib.FeatureSetSchema,
num_shards: int = 1,
verbose: bool = False,
compression: str = "snappy",
):
"""Writes a dictionary of numpy arrays to sharded Parquet files.

Expand All @@ -269,6 +271,7 @@ def write_numpy_dict_to_parquet(
schema: The schema defining the features to write.
num_shards: The number of shards to write.
verbose: If True, print progress information.
compression: The parquet compression codec to use.
"""
if not data:
raise ValueError("Input data dictionary is empty.")
Expand Down Expand Up @@ -307,7 +310,14 @@ def _write_shard_task(shard_index: int):
if verbose:
print(f"Writing shard {shard_index + 1}/{num_shards} to {shard_path}")

_write_single_shard(data, shard_path, start_row, end_row, write_specs)
_write_single_shard(
data,
shard_path,
start_row,
end_row,
write_specs,
compression=compression,
)

with futures.ThreadPoolExecutor(
max_workers=max(1, min(num_shards, 20))
Expand Down
28 changes: 16 additions & 12 deletions dgf/src/sampling/beam_semi_distributed_sampler_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import logging
import os
import threading
import time
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, Union
import apache_beam as beam
Expand Down Expand Up @@ -156,10 +157,12 @@ def sample_with_beam_semi_distributed_sampler_v2(
return samples, schema


class RawSamplerV2Cache:
class SharedSampler:
"""A wrapper to allow weak references to the shared sampler and mapper."""

def __init__(self):
self.by_path = {}
def __init__(self, sampler, mapper):
self.sampler = sampler
self.mapper = mapper


class RawSamplerV2(beam.DoFn):
Expand All @@ -186,14 +189,11 @@ def __init__(

def setup(self):
def initializer():
return RawSamplerV2Cache()

self.cache = RawSamplerV2.shared_in_memory_samplers.acquire(initializer)

if self.graph_path not in self.cache.by_path:
# TODO(gbm): Don't load and return the feature values.
start_time = time.time()
logging.info("Load graph in memory")
logging.info(
"Thread %s: Start loading shared graph",
threading.current_thread().name,
)

read_schema = self.schema

Expand Down Expand Up @@ -231,9 +231,13 @@ def initializer():
].features[KEY_ID]
mapper = {id.item(): idx for idx, id in enumerate(seed_node_ids)}

self.cache.by_path[self.graph_path] = (sampler, mapper)
return SharedSampler(sampler, mapper)

self.sampler, self.mapper = self.cache.by_path[self.graph_path]
shared_sampler = RawSamplerV2.shared_in_memory_samplers.acquire(
initializer, tag=self.graph_path
)
self.sampler = shared_sampler.sampler
self.mapper = shared_sampler.mapper

def process(
self, seeds: Sequence[distributed_graph.NodeId]
Expand Down
37 changes: 37 additions & 0 deletions dgf/src/util/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""

import concurrent.futures
import time
from typing import List, Optional, Sequence
from absl import logging
from etils import epath
Expand Down Expand Up @@ -58,6 +59,42 @@ def open_read(path: str, binary: bool = False):
Returns:
A file-like object for reading.
"""
if is_gcs_path(path):
# TODO(gbm): Check other possible fixes
# Direct GCS blob open for GCS reads (epath unreachable from GCP Dataflow runner).
# Note: blob.open() returns a streaming BlobReader which downloads in chunks,
# so it does NOT load the entire file into memory at once (safe for large files).
# We also add exponential-backoff retries to handle transient GCS network timeouts.
client = storage.Client()

# Parse "gs://bucket/path/to/blob" into bucket name and blob path.
gcs_path = path.replace("gs://", "")
bucket_name = gcs_path.split("/")[0]
blob_path = "/".join(gcs_path.split("/")[1:])
bucket = client.bucket(bucket_name)
blob = bucket.blob(blob_path)

# Open the GCS blob as a streaming file-like object.
max_retries = 5
for attempt in range(max_retries):
try:
return blob.open("rb" if binary else "r")
except Exception as e: # pylint: disable=broad-except
if attempt < max_retries - 1:
wait = 2**attempt # Exponential backoff: 1s, 2s, 4s, 8s.
logging.warning(
"GCS open attempt %d/%d failed for %s: %s. Retrying in %ds.",
attempt + 1,
max_retries,
path,
e,
wait,
)
time.sleep(wait)
else:
raise # All retries exhausted; propagate the exception.

# For non-GCS paths (local, CNS, etc.), delegate to etils.epath.
return epath.Path(path).open("rb" if binary else "r")


Expand Down