From 3d11630787dac29abd91c4d45c8dd58c8254831e Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 05:57:52 -0500 Subject: [PATCH 1/6] fix: raise errors for invalid shears and PixelScale WCS inits --- jax_galsim/shear.py | 46 +++++++++++++++++++++++++++++++++++++-------- pyproject.toml | 2 +- tests/GalSim | 2 +- 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index 074a762e..55b639a5 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -1,3 +1,4 @@ +import equinox import galsim as _galsim import jax.numpy as jnp from galsim.errors import GalSimIncompatibleValuesError @@ -45,15 +46,24 @@ def __init__(self, *args, **kwargs): # g1,g2 elif "g1" in kwargs or "g2" in kwargs: - g1 = kwargs.pop("g1", 0.0) - g2 = kwargs.pop("g2", 0.0) + g1 = jnp.array(kwargs.pop("g1", 0.0)) + g2 = jnp.array(kwargs.pop("g2", 0.0)) self._g = g1 + 1j * g2 + self._g = equinox.error_if( + self._g, jnp.abs(self._g) > 1., + "Requested shear exceeds 1.", + ) # e1,e2 elif "e1" in kwargs or "e2" in kwargs: - e1 = kwargs.pop("e1", 0.0) - e2 = kwargs.pop("e2", 0.0) + e1 = jnp.array(kwargs.pop("e1", 0.0)) + e2 = jnp.array(kwargs.pop("e2", 0.0)) absesq = e1**2 + e2**2 + absesq = equinox.error_if( + absesq, + absesq > 1., + "Requested distortion exceeds 1.", + ) self._g = (e1 + 1j * e2) * self._e2g(absesq) # eta1,eta2 @@ -75,7 +85,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - g = kwargs.pop("g") + g = jnp.array(kwargs.pop("g")) + g = equinox.error_if( + g, + g > 1 or g < 0, + "Requested |shear| is outside [0,1].", + ) self._g = g * jnp.exp(2j * beta.rad) # e,beta @@ -89,7 +104,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - e = kwargs.pop("e") + e = jnp.array(kwargs.pop("e")) + e = equinox.error_if( + e, + (e > 1) | (e < 0), + "Requested distortion is outside [0,1].", + ) self._g = self._e2g(e**2) * e * jnp.exp(2j * beta.rad) # eta,beta @@ -103,7 +123,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - eta = kwargs.pop("eta") + eta = jnp.array(kwargs.pop("eta")) + eta = equinox.error_if( + eta, + eta < 0, + "Requested eta is below 0.", + ) self._g = self._eta2g(eta) * eta * jnp.exp(2j * beta.rad) # q,beta @@ -117,7 +142,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - q = kwargs.pop("q") + q = jnp.array(kwargs.pop("q")) + q = equinox.error_if( + q, + (q <= 0) | (q > 1), + "Cannot use requested axis ratio.", + ) eta = -jnp.log(q) self._g = self._eta2g(eta) * eta * jnp.exp(2j * beta.rad) diff --git a/pyproject.toml b/pyproject.toml index 2ad2d9aa..f8ca4441 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ description = "The modular galaxy image simulation toolkit, but in JAX" dynamic = ["version"] license = { file = "LICENSE" } readme = "README.md" -dependencies = ["numpy >=1.18.0", "galsim >=2.7.0", "jax >=0.6.0", "astropy >=2.0", "quadax"] +dependencies = ["numpy >=1.18.0", "galsim >=2.7.0", "jax >=0.6.0", "astropy >=2.0", "quadax", "equinox"] [project.optional-dependencies] dev = ["pytest", "pytest-codspeed"] diff --git a/tests/GalSim b/tests/GalSim index a5afbf51..11c473b4 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit a5afbf510dc747f5667f61c742b9dd3630643988 +Subproject commit 11c473b4fde8b8b730af654a47b96e7894862d57 From bd0e282c71dd6de8bfc3e502bf0477099c83224e Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 05:58:40 -0500 Subject: [PATCH 2/6] please the dog --- jax_galsim/shear.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index 55b639a5..92d57693 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -50,7 +50,8 @@ def __init__(self, *args, **kwargs): g2 = jnp.array(kwargs.pop("g2", 0.0)) self._g = g1 + 1j * g2 self._g = equinox.error_if( - self._g, jnp.abs(self._g) > 1., + self._g, + jnp.abs(self._g) > 1.0, "Requested shear exceeds 1.", ) @@ -61,7 +62,7 @@ def __init__(self, *args, **kwargs): absesq = e1**2 + e2**2 absesq = equinox.error_if( absesq, - absesq > 1., + absesq > 1.0, "Requested distortion exceeds 1.", ) self._g = (e1 + 1j * e2) * self._e2g(absesq) From a3b7ba4c68c4e806cfc95c5420e3af844de1642a Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 06:31:39 -0500 Subject: [PATCH 3/6] fix: mock up equinox --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 11c473b4..062c9ed0 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 11c473b4fde8b8b730af654a47b96e7894862d57 +Subproject commit 062c9ed06ae309b1a47885ee8abee3b7860760ac From 5a43c922a80dd1d234e44c9e55055e1a7262991b Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 06:35:42 -0500 Subject: [PATCH 4/6] test: more array equals --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 062c9ed0..e5ee4016 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 062c9ed06ae309b1a47885ee8abee3b7860760ac +Subproject commit e5ee401606efcc43b6a8f6ca5a204f5d95befc94 From ff189007dfc0e6bd0a4872eeb89f882d134a7714 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 06:49:07 -0500 Subject: [PATCH 5/6] doc: update docs for shears --- jax_galsim/shear.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index 92d57693..60e5b306 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -11,9 +11,11 @@ @register_pytree_node_class @implements( _galsim.Shear, - lax_description="""\ -The jax_galsim implementation of ``Shear`` does not perform range checking of the \ -shear (e.g., ``|g| <= 1``) upon construction.""", + lax_description=( + "While the JAX-GalSim implementation of ``Shear`` will do range checking of " + "the shear upon construction, it raises ``equinox.EquinoxRuntimeError`` exceptions " + "instead of ``galsim.GalSimRangeError`` exceptions." + ), ) class Shear(object): def __init__(self, *args, **kwargs): From 85569db75d31cb70ad11b1d4290b4ddabcdc91ff Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 06:51:39 -0500 Subject: [PATCH 6/6] fix: clarify docs --- jax_galsim/shear.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index 60e5b306..59ef2cca 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -12,9 +12,9 @@ @implements( _galsim.Shear, lax_description=( - "While the JAX-GalSim implementation of ``Shear`` will do range checking of " - "the shear upon construction, it raises ``equinox.EquinoxRuntimeError`` exceptions " - "instead of ``galsim.GalSimRangeError`` exceptions." + "While the JAX-GalSim implementation of ``Shear`` will raise exceptions for " + "invalid shear values (e.g., |g| > 1), it raises ``equinox.EquinoxRuntimeError`` " + "exceptions instead of ``galsim.GalSimRangeError`` exceptions." ), ) class Shear(object):