-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy path test_accessors.py
More file actions
63 lines (49 loc) · 1.75 KB
/
test_accessors.py
File metadata and controls
63 lines (49 loc) · 1.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Tests for cupy-xarray accessors"""
import cupy as cp
import numpy as np
import pytest
import xarray as xr
from xarray.tests import requires_dask, requires_pint
import cupy_xarray # noqa: F401 pylint:disable=unused-import
# Shared parametrization fixtures: a small 2x3 float DataArray carrying a
# ``units`` attribute (consumed when the pint test quantifies it), and a
# Dataset that holds that DataArray as the data variable "a".
da = xr.DataArray(np.random.random((2, 3)), attrs=dict(units="candle"))
ds = xr.Dataset(data_vars={"a": da})
@pytest.mark.parametrize("obj", [da, ds])
def test_numpy(obj):
    """Round-trip ``obj`` NumPy -> CuPy -> NumPy, checking ``is_cupy`` at each step.

    Parametrized over both the DataArray and the Dataset fixture, so every
    check goes through ``obj`` (not the module-level ``da``) to actually
    exercise both accessor types.
    """
    assert not obj.cupy.is_cupy
    cpobj = obj.cupy.as_cupy()
    assert cpobj.cupy.is_cupy

    as_numpy = cpobj.as_numpy()
    assert not as_numpy.cupy.is_cupy

    # A Dataset has no single ``.data``; check each data variable instead.
    if isinstance(as_numpy, xr.DataArray):
        arrays = [as_numpy]
    else:
        arrays = list(as_numpy.data_vars.values())
    for arr in arrays:
        assert isinstance(arr.data, np.ndarray)
@requires_dask
@pytest.mark.parametrize("obj", [da, ds])
def test_dask(obj):
    """Check that converting a dask-chunked ``obj`` to CuPy flips ``is_cupy``.

    For the DataArray case, also verify that the resulting dask array's
    ``_meta`` (its block prototype) is a CuPy array.
    """
    chunked = obj.chunk()
    assert not chunked.cupy.is_cupy

    on_gpu = chunked.cupy.as_cupy()
    assert on_gpu.cupy.is_cupy

    # Datasets expose no single ``.data`` attribute; nothing more to check.
    if not isinstance(on_gpu, xr.DataArray):
        return
    assert isinstance(on_gpu.data._meta, cp.ndarray)
@requires_dask
@requires_pint
@pytest.mark.parametrize("obj", [da, ds])
def test_pint(obj):
    """Check CuPy conversion of a pint-quantified ``obj``, eager and dask-chunked.

    ``obj`` is quantified from its ``units`` attribute and converted to CuPy.
    The quantified object is then chunked with dask and converted again,
    checking ``is_cupy`` and the wrapped array types along the way.
    """
    import pint

    # ``requires_pint`` only guards ``pint`` itself; skip (instead of erroring)
    # when the ``pint_xarray`` package, which registers ``.pint``, is missing.
    pytest.importorskip("pint_xarray")

    as_pint = obj.pint.quantify()
    assert not as_pint.cupy.is_cupy
    cpobj = as_pint.cupy.as_cupy()
    if isinstance(cpobj, xr.DataArray):
        assert isinstance(cpobj.data, pint.Quantity)
    assert cpobj.cupy.is_cupy

    # ``.chunk()`` needs dask, which is why this test carries ``requires_dask``.
    as_dask = as_pint.chunk()
    if isinstance(as_dask, xr.DataArray):
        assert isinstance(as_dask.data, pint.Quantity)
        assert isinstance(as_dask.data.magnitude._meta, np.ndarray)
    assert not as_dask.cupy.is_cupy
    cpdask = as_dask.cupy.as_cupy()
    assert cpdask.cupy.is_cupy
    if isinstance(cpdask, xr.DataArray):
        assert isinstance(cpdask.data._meta, cp.ndarray)