-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathaccessors.py
More file actions
168 lines (136 loc) · 4.92 KB
/
accessors.py
File metadata and controls
168 lines (136 loc) · 4.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import cupy as cp
from xarray import (
DataArray,
Dataset,
register_dataarray_accessor,
register_dataset_accessor,
)
from xarray.core.pycompat import DuckArrayModule
# Array types of optional duck-array libraries, used in the isinstance checks
# below. NOTE(review): presumably an empty tuple when the library is not
# installed (so isinstance is always False) — verify against xarray's pycompat.
dask_array_type, pint_array_type = (
    DuckArrayModule(module_name).type for module_name in ("dask", "pint")
)
@register_dataarray_accessor("cupy")
class CupyDataArrayAccessor:
    """
    Access methods for DataArrays using Cupy.

    Methods and attributes can be accessed through the `.cupy` attribute.
    """

    def __init__(self, da):
        self.da = da

    def _rewrap(self, data):
        """Build a new DataArray around ``data`` reusing this array's metadata.

        Coordinates, dimension names, name and attrs are taken from the wrapped
        DataArray; only the underlying array is replaced.
        """
        return DataArray(
            data=data,
            coords=self.da.coords,
            dims=self.da.dims,
            name=self.da.name,
            attrs=self.da.attrs,
        )

    @property
    def is_cupy(self):
        """bool: The underlying data is a cupy array.

        Dask arrays are inspected through their ``_meta`` (the chunk type) and
        pint quantities through their ``magnitude``, so cupy data wrapped in
        either is detected as well.
        """
        data = self.da.data
        if isinstance(data, dask_array_type):
            return isinstance(data._meta, cp.ndarray)
        if isinstance(data, pint_array_type):
            return isinstance(data.magnitude, cp.ndarray)
        return isinstance(data, cp.ndarray)

    def as_cupy(self):
        """
        Converts the DataArray's underlying array type to cupy.

        For DataArrays which are initially backed by numpy the data
        will be immediately cast to cupy and moved to the GPU. In the case
        that the data was originally a Dask array each chunk will be moved
        to the GPU when the task graph is computed. Pint quantities keep
        their units; only the magnitude is converted.

        Returns
        -------
        cupy_da: DataArray
            DataArray with underlying data cast to cupy.

        Examples
        --------
        >>> import xarray as xr
        >>> da = xr.tutorial.load_dataset("air_temperature").air
        >>> gda = da.cupy.as_cupy()
        >>> type(gda.data)
        <class 'cupy.core.core.ndarray'>
        """
        data = self.da.data
        if isinstance(data, dask_array_type):
            # Lazy: cp.asarray is applied to each chunk when the graph runs.
            return self._rewrap(data.map_blocks(cp.asarray))
        if isinstance(data, pint_array_type):
            # Strip the units, convert the magnitude, then re-attach the units.
            return self.da.pint.dequantify().cupy.as_cupy().pint.quantify()
        return self._rewrap(cp.asarray(data))

    def as_numpy(self):
        """
        Converts the DataArray's underlying array type from cupy to numpy.

        Cupy data is copied back to host memory. Dask-backed cupy data stays
        lazy: each chunk is copied when the task graph is computed. Pint
        quantities keep their units; only the magnitude is converted. Data
        that is not cupy-backed is delegated to ``DataArray.as_numpy``.

        Returns
        -------
        da: DataArray
            DataArray with underlying data cast to numpy.
        """
        if not self.is_cupy:
            return self.da.as_numpy()
        data = self.da.data
        if isinstance(data, dask_array_type):
            # Pass the dtype explicitly rather than relying on dask's inference.
            return self._rewrap(
                data.map_blocks(cp.asnumpy, dtype=data._meta.dtype)
            )
        if isinstance(data, pint_array_type):
            # Mirror as_cupy: calling .get() on the quantity itself would be
            # forwarded to the magnitude and silently drop the units.
            return self.da.pint.dequantify().cupy.as_numpy().pint.quantify()
        return self._rewrap(data.get())

    def get(self):
        """Return the underlying data as a numpy array.

        Cupy data is copied from the GPU via ``cupy.asnumpy``. Data that
        already lives on the host is returned as a numpy array instead of
        raising. Dask-backed data is not handled here; use :meth:`as_numpy`
        to convert it lazily.
        """
        return cp.asnumpy(self.da.data)
@register_dataset_accessor("cupy")
class CupyDatasetAccessor:
    """
    Access methods for Datasets using Cupy.

    Methods and attributes can be accessed through the `.cupy` attribute.
    """

    def __init__(self, ds):
        self.ds = ds

    @property
    def is_cupy(self):
        """bool: Every data variable of the Dataset is backed by cupy.

        A Dataset with no data variables counts as cupy-backed (``all`` of an
        empty sequence is True). Coordinates are not inspected.
        """
        return all(da.cupy.is_cupy for da in self.ds.data_vars.values())

    def as_cupy(self):
        """
        Converts the underlying array type of every data variable to cupy.

        Each data variable is converted with ``DataArray.cupy.as_cupy``.
        Coordinates and attrs are passed through unchanged.

        Returns
        -------
        cupy_ds: Dataset
            Dataset whose data variables are backed by cupy.
        """
        data_vars = {
            var: da.cupy.as_cupy() for var, da in self.ds.data_vars.items()
        }
        return Dataset(data_vars=data_vars, coords=self.ds.coords, attrs=self.ds.attrs)

    def as_numpy(self):
        """
        Converts the Dataset's underlying array types from cupy to numpy.

        If every data variable is cupy-backed, each one is converted with
        ``DataArray.cupy.as_numpy``, with coordinates and attrs passed through
        unchanged. Otherwise the whole Dataset is delegated to
        ``Dataset.as_numpy``.

        Returns
        -------
        ds: Dataset
            Dataset whose data variables are backed by numpy.
        """
        if not self.is_cupy:
            # NOTE(review): mixed cupy/non-cupy Datasets take this path, so any
            # cupy variables are converted by xarray itself — confirm intended.
            return self.ds.as_numpy()
        data_vars = {
            var: da.cupy.as_numpy() for var, da in self.ds.data_vars.items()
        }
        return Dataset(
            data_vars=data_vars,
            coords=self.ds.coords,
            attrs=self.ds.attrs,
        )
# Attach `as_cupy` methods directly to the top-level `Dataset` and `DataArray` objects.
# It would be good to replace this with a less hacky API upstream at some stage, where
# libraries like this one could register new ``as_`` methods for dispatch.
@register_dataarray_accessor("as_cupy")
def _(da):
    """
    Converts the DataArray's underlying array type to cupy.

    Exposes ``da.as_cupy()`` as a shortcut for ``da.cupy.as_cupy()``.
    See :meth:`cupy_xarray.CupyDataArrayAccessor.as_cupy`.
    """

    def as_cupy(*args, **kwargs):
        """Forward all arguments to ``da.cupy.as_cupy``."""
        accessor = da.cupy
        return accessor.as_cupy(*args, **kwargs)

    return as_cupy
@register_dataset_accessor("as_cupy")
def _(ds):
    """
    Converts the underlying array type of the Dataset's data variables to cupy.

    Exposes ``ds.as_cupy()`` as a shortcut for ``ds.cupy.as_cupy()``.
    See :meth:`cupy_xarray.CupyDatasetAccessor.as_cupy`.
    """

    def as_cupy(*args, **kwargs):
        """Forward all arguments to ``ds.cupy.as_cupy``."""
        accessor = ds.cupy
        return accessor.as_cupy(*args, **kwargs)

    return as_cupy