Skip to content

Commit 1393d57

Browse files
committed
Allow passing device_id=None to determine CUDA device id automatically
Make `device_id` optional (default to None), so that it can be inferred automatically by `cupy.cuda.runtime.getDevice`. Also added a unit test using xr.open_mfdataset.
1 parent 2012579 commit 1393d57

2 files changed

Lines changed: 33 additions & 4 deletions

File tree

cupy_xarray/cog3pio.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class Cog3pioBackendEntrypoint(BackendEntrypoint):
1717
"""
1818
Xarray backend to read GeoTIFF files using 'cog3pio' engine.
1919
20-
When using :py:func:`xarray.open_dataarray` with ``engine="cog3pio"``, the
20+
When using :py:func:`xarray.open_dataarray` with ``engine="cog3pio"``, the optional
2121
``device_id`` parameter can be set to the CUDA GPU id to do the decoding on.
2222
2323
Examples
@@ -47,7 +47,7 @@ def open_dataset( # type: ignore[override]
4747
filename_or_obj: str,
4848
*,
4949
drop_variables: str | Iterable[str] | None = None,
50-
device_id: int,
50+
device_id: int | None = None,
5151
# other backend specific keyword arguments
5252
# `chunks` and `cache` DO NOT go here, they are handled by xarray
5353
mask_and_scale=None,
@@ -60,14 +60,18 @@ def open_dataset( # type: ignore[override]
6060
filename_or_obj : str
6161
File path or url to a TIFF (.tif) image file that can be read by the
6262
nvTIFF or image-tiff backend library.
63-
device_id : int
64-
CUDA device ID on which to place the created cupy array.
63+
device_id : int | None
64+
CUDA device ID on which to place the created cupy array. Default is None,
65+
which means device_id will be inferred via
66+
:py:func:`cupy.cuda.runtime.getDevice`.
6567
6668
Returns
6769
-------
6870
xarray.Dataset
6971
7072
"""
73+
if device_id is None:
74+
device_id: int = cp.cuda.runtime.getDevice()
7175

7276
with cp.cuda.Stream(ptds=True):
7377
cog = CudaCogReader(path=filename_or_obj, device_id=device_id)

cupy_xarray/tests/test_cog3pio.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
cog3pio = pytest.importorskip("cog3pio")
1212

1313

14+
# %%
1415
def test_entrypoint():
1516
assert "cog3pio" in xr.backends.list_engines()
1617

@@ -32,3 +33,27 @@ def test_xarray_backend_open_dataarray():
3233
assert da.y.min() == 3490250.0
3334
assert da.y.max() == 3599950.0
3435
assert da.dtype == "uint8"
36+
37+
38+
def test_xarray_backend_open_mfdataset():
39+
"""
40+
Ensure that passing engine='cog3pio' to xarray.open_mfdataset works to read multiple
41+
Cloud-optimized GeoTIFF files from http urls. Also testing that `device_id=None`
42+
works.
43+
"""
44+
ds: xr.Dataset = xr.open_mfdataset(
45+
paths=[
46+
"https://github.com/developmentseed/titiler/raw/1.2.0/src/titiler/mosaic/tests/fixtures/B01.tif",
47+
"https://github.com/developmentseed/titiler/raw/1.2.0/src/titiler/mosaic/tests/fixtures/B09.tif",
48+
],
49+
engine=Cog3pioBackendEntrypoint,
50+
concat_dim="band",
51+
combine="nested",
52+
device_id=None,
53+
)
54+
assert ds.sizes == {"band": 2, "y": 183, "x": 183}
55+
assert ds.x.min() == 700260.0
56+
assert ds.x.max() == 809460.0
57+
assert ds.y.min() == 3490500.0
58+
assert ds.y.max() == 3599700.0
59+
assert ds.raster.dtype == "uint16"

0 commit comments

Comments
 (0)