diff --git a/jax_galsim/angle.py b/jax_galsim/angle.py index fad56976..aec36d90 100644 --- a/jax_galsim/angle.py +++ b/jax_galsim/angle.py @@ -21,9 +21,30 @@ # SOFTWARE. import galsim as _galsim import jax.numpy as jnp +import numpy as np from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import cast_to_float, ensure_hashable, implements +from jax_galsim.core.utils import ( + cast_to_float, + ensure_hashable, + has_tracers, + implements, +) + +NON_COMPLEX_TYPES = ( + float, + int, + np.int16, + np.int32, + np.int64, + np.float32, + np.float64, + jnp.int16, + jnp.int32, + jnp.int64, + jnp.float32, + jnp.float64, +) @implements(_galsim.AngleUnit) @@ -178,6 +199,10 @@ def __sub__(self, other): return _Angle(self._rad - other._rad) def __mul__(self, other): + if not (has_tracers(other) or isinstance(other, NON_COMPLEX_TYPES)): + raise TypeError( + "Cannot multiply Angle by %s of type %s" % (other, type(other)) + ) return _Angle(self._rad * other) __rmul__ = __mul__ @@ -185,8 +210,12 @@ def __mul__(self, other): def __div__(self, other): if isinstance(other, AngleUnit): return self._rad / other.value - else: + elif has_tracers(other) or isinstance(other, NON_COMPLEX_TYPES): return _Angle(self._rad / other) + else: + raise TypeError( + "Cannot divide Angle by %s of type %s" % (other, type(other)) + ) __truediv__ = __div__ diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 10442ad4..21b6ba8a 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -1,34 +1,24 @@ import galsim as _galsim import jax import jax.numpy as jnp -import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import ( + CONST_TYPES, cast_to_float, cast_to_int, + cast_to_python_float, + check_is_int_then_cast, ensure_hashable, has_tracers, implements, ) from jax_galsim.position import Position, PositionD, PositionI -CONST_TYPES = (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64) -CONST_TYPES_WITH_JAX = CONST_TYPES + ( - jax.Array, - jnp.array, - jnp.int32, - jnp.int64, - jnp.float32, - jnp.float64, -) - -# TODO: write extra docs for JAX changes BOUNDS_LAX_DESCR = """\ The JAX implementation - will not always test whether the bounds are valid -- will not always test whether BoundsI is initialized with integers Further, the JAX implementation adds a new method, ``isStatic`` to the ``BoundsI`` class. If JAX-GalSim detects that the ``BoundsI`` instance @@ -525,31 +515,27 @@ def __init__(self, *args, **kwargs): f"Got deltax,deltay = {self.deltax!r},{self.deltay!r}." ) + self.deltax = cast_to_python_float(self.deltax) + self.deltay = cast_to_python_float(self.deltay) + if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): + raise TypeError("BoundsI must be initialized with integer values") self.deltax = int(cast_to_int(self.deltax)) self.deltay = int(cast_to_int(self.deltay)) - if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): - raise TypeError("BoundsI must be initialized with integer values") + if has_tracers(self._xmin) or has_tracers(self._ymin): + self._isstatic = False + + # validate inputs are ints + self._xmin = check_is_int_then_cast( + self._xmin, "BoundsI must be initialized with integer values" + ) + self._ymin = check_is_int_then_cast( + self._ymin, "BoundsI must be initialized with integer values" + ) if self.deltax < 1 and self.deltay < 1: self._isdefined = False - # for simple inputs, we can check if the bounds are valid ints - if isinstance(self._xmin, CONST_TYPES) and self._xmin != int(self._xmin): - raise TypeError("BoundsI must be initialized with integer values") - - if isinstance(self._ymin, CONST_TYPES) and self._ymin != int(self._ymin): - raise TypeError("BoundsI must be initialized with integer values") - - if not has_tracers(self._xmin) and not has_tracers(self._ymin): - self._isstatic = True - self._xmin = int(np.trunc(self._xmin)) - self._ymin = int(np.trunc(self._ymin)) - else: - self._isstatic = False - self._xmin = cast_to_float(jnp.trunc(self._xmin)) - self._ymin = cast_to_float(jnp.trunc(self._ymin)) - if force_static and not self._isstatic: raise RuntimeError( "BoundsI initialized with non-static " diff --git a/jax_galsim/celestial.py b/jax_galsim/celestial.py index 1b6e992f..5645eda3 100644 --- a/jax_galsim/celestial.py +++ b/jax_galsim/celestial.py @@ -23,9 +23,11 @@ from functools import partial import coord as _coord +import equinox import galsim as _galsim import jax import jax.numpy as jnp +import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.angle import Angle, _Angle, arcsec, degrees, radians @@ -74,6 +76,16 @@ def __init__(self, ra, dec=None): elif not isinstance(dec, Angle): raise TypeError("dec must be a galsim.Angle") else: + if isinstance(dec._rad, (float, int)): + if dec._rad < -np.pi / 2 or dec._rad > np.pi / 2: + raise ValueError("dec must be between -90 deg and +90 deg.") + else: + dec._rad = equinox.error_if( + jnp.array(dec._rad), + jnp.any((dec._rad < -jnp.pi / 2) | (dec._rad > jnp.pi / 2)), + "dec must be between -90 deg and +90 deg.", + ) + # Normal case self._ra = ra self._dec = dec @@ -121,15 +133,14 @@ def get_xyz(self): @staticmethod @jax.jit - @implements( - _galsim.celestial.CelestialCoord.from_xyz, - lax_description=( - "The JAX version of this static method does not check that the norm of the input " - "vector is non-zero." - ), - ) + @implements(_galsim.celestial.CelestialCoord.from_xyz) def from_xyz(x, y, z): norm = jnp.sqrt(x * x + y * y + z * z) + norm = equinox.error_if( + norm, + jnp.any(norm == 0), + "CelestialCoord for position (0,0,0) is undefined.", + ) ret = CelestialCoord.__new__(CelestialCoord) ret._x = x / norm ret._y = y / norm @@ -236,13 +247,7 @@ def distanceTo(self, coord2): return _Angle(theta) - @implements( - _galsim.celestial.CelestialCoord.greatCirclePoint, - lax_description=( - "The JAX version of this method does not check that coord2 defines a unique great " - "circle with the current coord at angle theta." - ), - ) + @implements(_galsim.celestial.CelestialCoord.greatCirclePoint) @jax.jit def greatCirclePoint(self, coord2, theta): aux = self._get_aux() @@ -280,8 +285,11 @@ def greatCirclePoint(self, coord2, theta): # Normalize wr = (wx**2 + wy**2 + wz**2) ** 0.5 - # if wr == 0.: - # raise ValueError("coord2 does not define a unique great circle with self.") + wr = equinox.error_if( + wr, + jnp.any(wr == 0), + "coord2 does not define a unique great circle with self.", + ) wx /= wr wy /= wr wz /= wr diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 3fcbf46d..f6785ceb 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -3,11 +3,60 @@ from functools import partial from typing import NamedTuple +import equinox import jax import jax.numpy as jnp import numpy as np from jax.tree_util import tree_flatten +CONST_TYPES = ( + float, + int, + np.ndarray, + np.int8, + np.int16, + np.int32, + np.int64, + np.float16, + np.float32, + np.float64, + np.complex64, + np.complex128, +) +CONST_TYPES_WITH_JAX = CONST_TYPES + ( + jax.Array, + jnp.ndarray, + jnp.int8, + jnp.int16, + jnp.int32, + jnp.int64, + jnp.float32, + jnp.float64, + jnp.complex64, + jnp.complex128, +) + + +def check_is_int_then_cast(val, msg): + """Check if `val` is an integer, raise if not, otherwise cast to int.""" + # for simple inputs, we can check direct in python + if isinstance(val, CONST_TYPES) and not has_tracers(val): + val = cast_to_python_float(val) + if val != int(val): + raise TypeError(msg) + val = int(val) + else: + # otherwise we use more opaque checking upon jit via equinox + val = jnp.array(val) + val = equinox.error_if( + val, + np.any(val != jnp.trunc(val)), + msg, + ) + val = val.astype(int) + + return val + def cast_numpy_array_to_native_byte_order(arr): """Cast an array to native byte order.""" diff --git a/jax_galsim/fitswcs.py b/jax_galsim/fitswcs.py index 232801d6..d64564ab 100644 --- a/jax_galsim/fitswcs.py +++ b/jax_galsim/fitswcs.py @@ -1,6 +1,7 @@ import copy import os +import equinox import galsim as _galsim import jax import jax.numpy as jnp @@ -1094,12 +1095,10 @@ def _step(i, args): unroll=True, )[0:4] - x, y = jax.lax.cond( - jnp.maximum(jnp.max(jnp.abs(dx)), jnp.max(jnp.abs(dy))) > 2e-12, - lambda x, y: (x * jnp.nan, y * jnp.nan), - lambda x, y: (x, y), - x, - y, + x, y = equinox.error_if( + (x, y), + jnp.any(jnp.maximum(jnp.max(jnp.abs(dx)), jnp.max(jnp.abs(dy))) > 2e-12), + "Unable to solve for image_pos (max iter reached).", ) return x, y diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 3d72dbab..a5cf51b7 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1,6 +1,7 @@ from collections import namedtuple from functools import partial +import equinox import galsim as _galsim import jax import jax.numpy as jnp @@ -574,12 +575,13 @@ def _determine_wcs(self, scale, wcs, image, default_wcs=None): lax_description="""\ The JAX-GalSim version of ``drawImage`` -- does not do extensive (any?) checking of the input settings. - uses a default of ``n_photons=None`` instead of ``n_photons=0`` to indicate that the number of photons should be determined from the flux and gain +- uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0`` - requires that the ``maxN`` option be a constant since PhotonArrays are allocated with ``maxN`` photons when this option is used and arrays in JAX must have static sizes. +- raises a generic ``Exception`` instead of a more specific exception for some invalid inputs """, ) def drawImage( @@ -601,7 +603,7 @@ def drawImage( offset=None, n_photons=None, rng=None, - max_extra_noise=0.0, + max_extra_noise=None, poisson_flux=None, sensor=None, photon_ops=(), @@ -626,6 +628,16 @@ def drawImage( if image is not None and not isinstance(image, Image): raise TypeError("image is not an Image instance", image) + # Make sure (gain, area, exptime) have valid values: + gain = jnp.array(gain) + gain = equinox.error_if(gain, jnp.any(gain <= 0.0), "Invalid gain <= 0.") + area = jnp.array(area) + area = equinox.error_if(area, jnp.any(area <= 0.0), "Invalid area <= 0.") + exptime = jnp.array(exptime) + exptime = equinox.error_if( + exptime, jnp.any(exptime <= 0.0), "Invalid exptime <= 0." + ) + if method == "phot" and save_photons and maxN is not None: raise GalSimIncompatibleValuesError( "Setting maxN is incompatible with save_photons=True" @@ -659,6 +671,13 @@ def drawImage( sensor=sensor, n_photons=n_photons, ) + if max_extra_noise is not None: + raise GalSimIncompatibleValuesError( + "max_extra_noise is only relevant for method='phot'", + method=method, + sensor=sensor, + max_extra_noise=max_extra_noise, + ) if poisson_flux is not None: raise GalSimIncompatibleValuesError( "poisson_flux is only relevant for method='phot'", @@ -1078,6 +1097,8 @@ def _drawKImage( @implements(_galsim.GSObject._calculate_nphotons) def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): + if max_extra_noise is None: + max_extra_noise = 0.0 n_photons, g, _rng = calculate_n_photons( self.flux, self._flux_per_photon, @@ -1096,17 +1117,17 @@ def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): lax_description="""\ The JAX-GalSim version of ``makePhot`` -- does little to no error checking on the inputs - uses a default of ``n_photons=None`` instead of ``n_photons=0`` - to indicate that the number of photons should be determined - from the flux and gain +- uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0`` + to indicate no limit on the extra noise +- raises a generic ``Exception`` instead of a more specific exception for some invalid inputs """, ) def makePhot( self, n_photons=None, rng=None, - max_extra_noise=0.0, + max_extra_noise=None, poisson_flux=None, photon_ops=(), local_wcs=None, @@ -1168,6 +1189,8 @@ def makePhot( - uses a default of ``n_photons=None`` instead of ``n_photons=0`` to indicate that the number of photons should be determined from the flux and gain +- uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0`` +- raises a generic ``Exception`` instead of a more specific exception for some invalid inputs - requires that the ``maxN`` option must be a constant """, ) @@ -1178,7 +1201,7 @@ def drawPhot( add_to_image=False, n_photons=None, rng=None, - max_extra_noise=0.0, + max_extra_noise=None, poisson_flux=None, sensor=None, photon_ops=(), @@ -1208,6 +1231,9 @@ def drawPhot( elif not isinstance(sensor, Sensor): raise TypeError("The sensor provided is not a Sensor instance") + gain = jnp.array(gain) + gain = equinox.error_if(gain, jnp.any(gain <= 0.0), "Invalid gain <= 0.") + if n_photons is not None: # n_photons is the length of an array so it is a python int and # and thus a constant wrt to JIT diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 3e8c0e69..d7eac192 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -103,6 +103,13 @@ def __init__(self, *args, **kwargs): else: if "array" in kwargs: array = kwargs.pop("array") + if has_tracers(array) or isinstance(array, jnp.ndarray): + pass + elif isinstance(array, np.ndarray): + array = jnp.array(cast_numpy_array_to_native_byte_order(array)) + else: + raise TypeError("Unable to parse %s as an array." % array) + array, xmin, ymin = self._get_xmin_ymin( array, kwargs, check_bounds=_check_bounds ) @@ -326,7 +333,7 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): def __repr__(self): s = "galsim.Image(bounds=%r" % self.bounds - if self.bounds.isDefined(): + if self.bounds.isDefined() and not has_tracers(self.array): s += ", array=\n%r" % (ensure_hashable(np.array(self.array)),) s += ", wcs=%r" % self.wcs if self.isconst: diff --git a/jax_galsim/integ.py b/jax_galsim/integ.py index 19ad5c4b..20397be2 100644 --- a/jax_galsim/integ.py +++ b/jax_galsim/integ.py @@ -1,5 +1,6 @@ from functools import partial +import equinox import galsim as _galsim import jax.lax import jax.numpy as jnp @@ -17,8 +18,8 @@ - This implementation is different than the one in GalSim and lacks some features that greatly enhance galsim's accuracy. -- The JAX-GalSim implementation returns NaN on error/non-convergence instead of - rasing an exception. +- The JAX-GalSim implementation raises a generic ``Exception`` on error/non-convergence + instead of rasing a ``galsim.GalSimError`` exception. """ ), ) @@ -72,8 +73,8 @@ def _base_integration(): _base_integration, ) - return jax.lax.cond( - status == 0, - lambda: val, - lambda: jnp.nan, + return equinox.error_if( + val, + jnp.any(status != 0), + "`jax_galsim.int1d` failed to converge!", ) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 0424e9d8..9564bd36 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -2,6 +2,7 @@ import math from functools import partial +import equinox import galsim as _galsim import jax import jax.numpy as jnp @@ -59,6 +60,8 @@ def __dir__(cls): - the pad_image options - depixelize - most of the bounds checks, type checks, and dtype casts done by galsim +- raises a generic ``Exception`` instead of a more specific one for some + initialization errors """ @@ -118,6 +121,11 @@ def __init__( elif not isinstance(image, Image): raise TypeError("Supplied image must be an Image or file name") + if not (image.dtype == jnp.float32 or image.dtype == jnp.float64): + raise GalSimValueError( + "Interpolated images must use a float-type image.", image.dtype + ) + self._jax_children = ( image, dict( @@ -506,6 +514,15 @@ def __init__( image=self._jax_children[0], ) + if calculate_stepk or calculate_maxk or flux is not None: + image.array = jnp.array(image.array) + image.array = equinox.error_if( + image.array, + jnp.any(image.array.sum() == 0.0), + "This input image has zero total flux. It does not define a " + "valid surface brightness profile.", + ) + @doc_inherit def withGSParams(self, gsparams=None, **kwargs): if gsparams == self.gsparams: diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 2a9b312b..932c48b8 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -1,5 +1,6 @@ from functools import partial +import equinox import galsim as _galsim import jax import jax.numpy as jnp @@ -31,7 +32,7 @@ def _Knu(nu, x): lax_description="""\ The JAX-GalSim version of the Moffat profile -- does not support truncation or beta < 1.1 +- does not support truncation or beta <= 1.1 - does not support gsparams.maxk_thresholds > 0.1 - does not support autodiff with respect to the `beta` parameter for Fourier-space evaluations @@ -67,6 +68,19 @@ def __init__( f"(got trunc={repr(trunc)}, always pass the constant 0.0)!" ) + if isinstance(beta, (float, int)): + if beta <= self._beta_thr: + raise ValueError( + f"JAX-GalSim does not support Moffat beta values <= {self._beta_thr}." + ) + else: + beta = jnp.array(beta) + beta = equinox.error_if( + beta, + jnp.any(beta <= self._beta_thr), + f"JAX-GalSim does not support Moffat beta values <= {self._beta_thr}.", + ) + # Parse the radius options if half_light_radius is not None: if scale_radius is not None or fwhm is not None: diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 336c28a7..967f55c3 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -561,6 +561,24 @@ def copyFrom( do_flux=True, do_other=True, ): + # jax naturally checks the other error cases in the test suite with the `.at` + # syntax, but it does not check out of bounds inds like ints so we do that here + if isinstance(target_indices, int) and ( + target_indices < -self._nokeep.shape[0] + or target_indices >= self._nokeep.shape[0] + ): + raise ValueError( + f"target_indices is invalid for the target PhotonArray. Got {target_indices!r}" + ) + + if isinstance(source_indices, int) and ( + source_indices < -rhs._nokeep.shape[0] + or source_indices >= rhs._nokeep.shape[0] + ): + raise ValueError( + f"source_indices is invalid for the source PhotonArray. Got {source_indices!r}" + ) + return self._copyFrom( rhs, target_indices, source_indices, do_xy, do_flux, do_other ) diff --git a/jax_galsim/position.py b/jax_galsim/position.py index 822797b8..cf36dba8 100644 --- a/jax_galsim/position.py +++ b/jax_galsim/position.py @@ -5,7 +5,7 @@ from jax_galsim.core.utils import ( cast_to_float, - cast_to_int, + check_is_int_then_cast, ensure_hashable, implements, ) @@ -208,15 +208,26 @@ def _check_scalar(self, other, op): raise TypeError("Can only %s a PositionD by float values" % op) -@implements(_galsim.PositionI) +@implements( + _galsim.PositionI, + lax_description=( + "The ``jax_galsim.PositionI`` class will raise generic " + "``Exception``s instead of a more specific exception for invalid " + "inputs." + ), +) @register_pytree_node_class class PositionI(Position): def __init__(self, *args, **kwargs): self._parse_args(*args, **kwargs) - # inputs must be ints - self.x = cast_to_int(self.x) - self.y = cast_to_int(self.y) + # validate input is int + self.x = check_is_int_then_cast( + self.x, "PositionI must be initialized with integer values" + ) + self.y = check_is_int_then_cast( + self.y, "PositionI must be initialized with integer values" + ) def _check_scalar(self, other, op): try: diff --git a/jax_galsim/random.py b/jax_galsim/random.py index db027ed7..789ec692 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -1,10 +1,12 @@ import secrets from functools import partial +import equinox import galsim as _galsim import jax import jax.numpy as jnp import jax.random as jrandom +import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import implements @@ -122,9 +124,19 @@ def reset(self, seed=None): self._state = _DeviateState( wrap_key_data(jnp.array(seed, dtype=jnp.uint32)) ) - else: + elif ( + isinstance( + seed, (int, jnp.ndarray, jax.Array, np.ndarray, np.integer, jnp.integer) + ) + or seed is None + ): _initial_seed = seed or secrets.randbelow(2**31) self._state = _DeviateState(jrandom.key(_initial_seed)) + else: + raise TypeError( + "Seeds for BaseDeviate must be an int-like, str, tuple, or another BaseDeviate. " + f"Got seed {seed!r}." + ) @property def _key(self): @@ -295,6 +307,20 @@ def __str__(self): class GaussianDeviate(BaseDeviate): def __init__(self, seed=None, mean=0.0, sigma=1.0): super().__init__(seed=seed) + + if isinstance(sigma, (int, float)): + if sigma < 0: + raise ValueError( + f"Gaussian deviates must have a non-negative sigma. Got {sigma!r}." + ) + else: + sigma = jnp.array(sigma) + sigma = equinox.error_if( + sigma, + jnp.any(sigma < 0), + f"Gaussian deviates must have a non-negative sigma. Got {sigma!r}.", + ) + self._params["mean"] = mean self._params["sigma"] = sigma @@ -435,6 +461,20 @@ def __str__(self): class PoissonDeviate(BaseDeviate): def __init__(self, seed=None, mean=1.0): super().__init__(seed=seed) + + if isinstance(mean, (int, float)): + if mean < 0: + raise ValueError( + f"Poisson deviates must have a non-negative mean. Got {mean!r}." + ) + else: + mean = jnp.array(mean) + mean = equinox.error_if( + mean, + jnp.any(mean < 0), + f"Poisson deviates must have a non-negative mean. Got {mean!r}.", + ) + self._params["mean"] = mean @property @@ -484,6 +524,11 @@ def _generate_one(key, mean): @implements(_galsim.PoissonDeviate.generate_from_expectation) def generate_from_expectation(self, array): + array = equinox.error_if( + jnp.array(array), + jnp.any(jnp.array(array) < 0), + "Poission deviates must have a non-negative mean.", + ) self._key, _array = self.__class__._generate_from_exp(self._key, array) return _array @@ -706,6 +751,10 @@ def __str__(self): ) def permute(rng, *args): rng = BaseDeviate(rng) + if len(args) == 0: + raise TypeError( + f"`galsim.random.permute` must be called with at least one array. Got {args!r}" + ) arrs = [] for arr in args: arrs.append(jrandom.permutation(rng._key, arr)) diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index 074a762e..b0663890 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -1,3 +1,4 @@ +import equinox import galsim as _galsim import jax.numpy as jnp from galsim.errors import GalSimIncompatibleValuesError @@ -10,9 +11,11 @@ @register_pytree_node_class @implements( _galsim.Shear, - lax_description="""\ -The jax_galsim implementation of ``Shear`` does not perform range checking of the \ -shear (e.g., ``|g| <= 1``) upon construction.""", + lax_description=( + "While the JAX-GalSim implementation of ``Shear`` will raise exceptions for " + "invalid shear values (e.g., |g| > 1), it raises a generic ``Exception`` " + "instead of a ``galsim.GalSimRangeError`` exception." + ), ) class Shear(object): def __init__(self, *args, **kwargs): @@ -45,15 +48,25 @@ def __init__(self, *args, **kwargs): # g1,g2 elif "g1" in kwargs or "g2" in kwargs: - g1 = kwargs.pop("g1", 0.0) - g2 = kwargs.pop("g2", 0.0) + g1 = jnp.array(kwargs.pop("g1", 0.0)) + g2 = jnp.array(kwargs.pop("g2", 0.0)) self._g = g1 + 1j * g2 + self._g = equinox.error_if( + self._g, + jnp.any(jnp.abs(self._g) > 1.0), + "Requested shear exceeds 1.", + ) # e1,e2 elif "e1" in kwargs or "e2" in kwargs: - e1 = kwargs.pop("e1", 0.0) - e2 = kwargs.pop("e2", 0.0) + e1 = jnp.array(kwargs.pop("e1", 0.0)) + e2 = jnp.array(kwargs.pop("e2", 0.0)) absesq = e1**2 + e2**2 + absesq = equinox.error_if( + absesq, + jnp.any(absesq > 1.0), + "Requested distortion exceeds 1.", + ) self._g = (e1 + 1j * e2) * self._e2g(absesq) # eta1,eta2 @@ -75,7 +88,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - g = kwargs.pop("g") + g = jnp.array(kwargs.pop("g")) + g = equinox.error_if( + g, + jnp.any((g > 1) | (g < 0)), + "Requested |shear| is outside [0,1].", + ) self._g = g * jnp.exp(2j * beta.rad) # e,beta @@ -89,7 +107,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - e = kwargs.pop("e") + e = jnp.array(kwargs.pop("e")) + e = equinox.error_if( + e, + jnp.any((e > 1) | (e < 0)), + "Requested distortion is outside [0,1].", + ) self._g = self._e2g(e**2) * e * jnp.exp(2j * beta.rad) # eta,beta @@ -103,7 +126,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - eta = kwargs.pop("eta") + eta = jnp.array(kwargs.pop("eta")) + eta = equinox.error_if( + eta, + jnp.any(eta < 0), + "Requested eta is below 0.", + ) self._g = self._eta2g(eta) * eta * jnp.exp(2j * beta.rad) # q,beta @@ -117,7 +145,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - q = kwargs.pop("q") + q = jnp.array(kwargs.pop("q")) + q = equinox.error_if( + q, + jnp.any((q <= 0) | (q > 1)), + "Cannot use requested axis ratio.", + ) eta = -jnp.log(q) self._g = self._eta2g(eta) * eta * jnp.exp(2j * beta.rad) diff --git a/pyproject.toml b/pyproject.toml index 2ad2d9aa..f8ca4441 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ description = "The modular galaxy image simulation toolkit, but in JAX" dynamic = ["version"] license = { file = "LICENSE" } readme = "README.md" -dependencies = ["numpy >=1.18.0", "galsim >=2.7.0", "jax >=0.6.0", "astropy >=2.0", "quadax"] +dependencies = ["numpy >=1.18.0", "galsim >=2.7.0", "jax >=0.6.0", "astropy >=2.0", "quadax", "equinox"] [project.optional-dependencies] dev = ["pytest", "pytest-codspeed"] diff --git a/tests/GalSim b/tests/GalSim index a5afbf51..549616e8 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit a5afbf510dc747f5667f61c742b9dd3630643988 +Subproject commit 549616e8ca4bb84142fae6cdb0a006669f92454b diff --git a/tests/conftest.py b/tests/conftest.py index 447c1678..53d2e1c4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -98,6 +98,9 @@ def pytest_collection_modifyitems(config, items): ): item.add_marker(skip) + if any([t in item.nodeid for t in test_config["skipped_tests"]["coord"]]): + item.add_marker(skip) + @lru_cache(maxsize=128) def _infile(val, fname): diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 46a64c18..428572a4 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -28,6 +28,13 @@ enabled_tests: - test_astropy.py - test_celestial.py +# tests to explicitly skip +# applied on top of the enabled set above +skipped_tests: + coord: + - "tests/Coord/tests/test_celestial.py::test_xyz_raises" + - "tests/Coord/tests/test_celestial.py::test_greatcircle_raises" + # This documents which error messages will be allowed # without being reported as an error. These typically # correspond to features that are not implemented yet @@ -87,10 +94,6 @@ allowed_failures: - "module 'jax_galsim' has no attribute 'InterpolatedKImage'" - "module 'jax_galsim' has no attribute 'CorrelatedNoise'" - "CelestialCoord.precess is too slow" # cannot get jax to warmup but once it does it passes - - "ValueError not raised by from_xyz" - - "ValueError not raised by greatCirclePoint" - - "TypeError not raised by __mul__" - - "ValueError not raised by CelestialCoord" - "module 'jax_galsim' has no attribute 'BaseCorrelatedNoise'" - "module 'jax_galsim' has no attribute 'fft'" - "Transform does not support callable arguments." diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index 303631bd..5c2f7dc6 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -280,9 +280,9 @@ def _run_benchmark_invert_ab_noraise(u, v, ab): @pytest.mark.parametrize("kind", ["run"]) def test_benchmark_invert_ab_noraise(benchmark, kind): - u = jnp.arange(1000).astype(jnp.float64) - v = jnp.arange(1000).astype(jnp.float64) - ab = jnp.array([[[-0.5, 0.3], [-0.1, 2.0]], [[-1.0, 0.3], [-0.1, 4.0]]]) + u = jnp.arange(1000).astype(jnp.float64) / 1000.0 + v = jnp.arange(1000).astype(jnp.float64) / 1000.0 + ab = jnp.array([[[0.6, 0.04], [-0.03, 0.5]], [[0.4, -0.02], [0.01, 0.7]]]) dt = _run_benchmarks( benchmark, kind, diff --git a/tests/jax/test_celestial_jax.py b/tests/jax/test_celestial_jax.py index d4210ac9..8a9cc63e 100644 --- a/tests/jax/test_celestial_jax.py +++ b/tests/jax/test_celestial_jax.py @@ -1,4 +1,5 @@ import galsim as _galsim +import jax.numpy as jnp import numpy as np import pytest @@ -118,3 +119,22 @@ def test_celestial_jax_ecliptic_obliquity(): ecliptic_obliquity(epoch).rad, _ecliptic_obliquity(epoch).rad, ) + + +def test_celestial_jax_xyz_raises(): + np.testing.assert_raises( + Exception, jax_galsim.CelestialCoord.from_xyz, 0.0, 0.0, 0.0 + ) + + +def test_celestial_jax_greatcircle_raises(): + theta = 50 * jax_galsim.radians + eq1 = jax_galsim.CelestialCoord( + 0 * jax_galsim.radians, 0 * jax_galsim.radians + ) # point on the equator + eq2 = jax_galsim.CelestialCoord( + jnp.array(1) * jax_galsim.radians, 0 * jax_galsim.radians + ) # 1 radian along equator + + np.testing.assert_raises(Exception, eq1.greatCirclePoint, eq1, theta) + np.testing.assert_raises(Exception, eq2.greatCirclePoint, eq2, theta) diff --git a/tests/jax/test_moffat_jax.py b/tests/jax/test_moffat_jax.py new file mode 100644 index 00000000..4810a105 --- /dev/null +++ b/tests/jax/test_moffat_jax.py @@ -0,0 +1,18 @@ +import jax +import jax.numpy as jnp +import pytest + +import jax_galsim + + +def test_moffat_jax_beta_raises(): + + @jax.jit + def make_moffat(beta): + return jax_galsim.Moffat(beta, fwhm=1.0) + + with pytest.raises(Exception): + make_moffat(jnp.array(1.1)) + + with pytest.raises(Exception): + make_moffat(0.9) diff --git a/tests/jax/test_position_jax.py b/tests/jax/test_position_jax.py new file mode 100644 index 00000000..e73937e7 --- /dev/null +++ b/tests/jax/test_position_jax.py @@ -0,0 +1,21 @@ +import jax +import pytest + +import jax_galsim + + +def test_position_jax_int_raises_in_jit(): + + @jax.jit + def _make_pos(x, y): + return jax_galsim.PositionI(x, y) + + with pytest.raises(Exception): + _make_pos(1.2, 23) + + with pytest.raises(Exception): + _make_pos(12, 2.3) + + pos = _make_pos(1, 2) + assert pos.x == 1 + assert pos.y == 2 diff --git a/tests/jax/test_random_jax.py b/tests/jax/test_random_jax.py new file mode 100644 index 00000000..c9309eb6 --- /dev/null +++ b/tests/jax/test_random_jax.py @@ -0,0 +1,30 @@ +import jax +import pytest + +import jax_galsim + + +def test_random_jax_gaussian_pos_sigma_jit(): + @jax.jit + def _make_gauss(sigma): + return jax_galsim.GaussianDeviate(seed=10, sigma=sigma) + + with pytest.raises(Exception): + _make_gauss(-1.0) + + @jax.jit + def _make_gauss_again(sigma): + return jax_galsim.GaussianDeviate(seed=10, sigma=sigma) + + _make_gauss_again(1.0) + + with pytest.raises(Exception): + _make_gauss_again(-1) + + def _make_gauss_again_again(sigma): + return jax_galsim.GaussianDeviate(seed=10, sigma=sigma) + + _make_gauss_again_again(1.0) + + with pytest.raises(Exception): + jax.jit(_make_gauss_again_again)(-1)