Skip to content

Commit 8f14d67

Browse files
ilan-goldd-v-b
andauthored
fix: numpy default NaT handling (#3863)
* fix: numpy default NaT handling * fix: timedelta64 handling * fix: use isnat --------- Co-authored-by: Davis Bennett <davis.v.bennett@gmail.com>
1 parent 1bbc826 commit 8f14d67

2 files changed

Lines changed: 10 additions & 2 deletions

File tree

src/zarr/core/dtype/npy/time.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,8 @@ def cast_scalar(self, data: object) -> np.timedelta64:
545545
raise a TypeError.
546546
"""
547547
if self._check_scalar(data):
548+
if isinstance(data, np.timedelta64) and np.isnat(data):
549+
return np.timedelta64("NaT", self.unit)
548550
return self._cast_scalar_unchecked(data)
549551
msg = (
550552
f"Cannot convert object {data!r} with type {type(data)} to a scalar compatible with the "
@@ -559,7 +561,7 @@ def default_scalar(self) -> np.timedelta64:
559561
This method provides a default value for the timedelta64 scalar, which is
560562
a 'Not-a-Time' (NaT) value.
561563
"""
562-
return np.timedelta64("NaT")
564+
return np.timedelta64("NaT", self.unit)
563565

564566
def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> np.timedelta64:
565567
"""

tests/test_dtype/test_npy/test_time.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class TestTimeDelta64(_TestTimeBase):
115115

116116
cast_value_params = (
117117
(TimeDelta64(unit="ns", scale_factor=1), "1", np.timedelta64(1, "ns")),
118-
(TimeDelta64(unit="ns", scale_factor=1), "NaT", np.timedelta64("NaT")),
118+
(TimeDelta64(unit="ns", scale_factor=1), "NaT", np.timedelta64("NaT", "ns")),
119119
)
120120
invalid_scalar_params = (
121121
(TimeDelta64(unit="Y", scale_factor=1), 1.3),
@@ -148,6 +148,12 @@ def test_time_scale_factor_too_low() -> None:
148148
TimeDelta64(scale_factor=scale_factor)
149149

150150

151+
def test_default_is_NaT() -> None:
152+
np.testing.assert_equal(
153+
TimeDelta64(unit="ns", scale_factor=1).default_scalar(), np.timedelta64("NaT", "ns")
154+
)
155+
156+
151157
def test_time_scale_factor_too_high() -> None:
152158
"""
153159
Test that an invalid unit raises a ValueError.

0 commit comments

Comments
 (0)