Skip to content

Commit

Permalink
Merge pull request #244 from astro-informatics/mmg/healpix-gradient-fix
Browse files Browse the repository at this point in the history
Correct `healpix_forward` derivatives and add support for forward and higher order autodiff
  • Loading branch information
matt-graham authored Nov 26, 2024
2 parents 5210481 + fe82eaa commit dc9b2bc
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 80 deletions.
242 changes: 176 additions & 66 deletions s2fft/transforms/c_backend_spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

# C backend functions for which to provide JAX frontend.
import pyssht
from jax import custom_vjp
from jax import core, custom_vjp
from jax.interpreters import ad

from s2fft.sampling import reindex
from s2fft.utils import quadrature_jax
Expand Down Expand Up @@ -241,83 +242,181 @@ def _ssht_forward_bwd(res, flm):
return f, None, None, None, None, None


@custom_vjp
def healpy_inverse(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
r"""
Compute the inverse scalar real spherical harmonic transform (HEALPix JAX).
# Link JAX gradients for C backend functions
ssht_inverse.defvjp(_ssht_inverse_fwd, _ssht_inverse_bwd)
ssht_forward.defvjp(_ssht_forward_fwd, _ssht_forward_bwd)

HEALPix is a C++ library which implements the scalar spherical harmonic transform
outlined in [1]. We make use of their healpy python bindings for which we provide
custom JAX frontends, hence providing support for automatic differentiation. Currently
these transforms can only be deployed on CPU, which is a limitation of the C++ library.

Args:
flm (jnp.ndarray): Spherical harmonic coefficients.
def _complex_dtype(real_dtype):
"""
Get complex datatype corresponding to a given real datatype.
L (int): Harmonic band-limit.
Derived from https://github.com/jax-ml/jax/blob/1471702adc28/jax/_src/lax/fft.py#L92
nside (int, optional): HEALPix Nside resolution parameter. Only required
if sampling="healpix". Defaults to None.
Original license:
Returns:
jnp.ndarray: Signal on the sphere.
Copyright 2019 The JAX Authors.
Note:
[1] Gorski, Krzysztof M., et al. "HEALPix: A framework for high-resolution
discretization and fast analysis of data distributed on the sphere." The
Astrophysical Journal 622.2 (2005): 759
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
flm = reindex.flm_2d_to_hp_fast(flm, L)
f = jnp.array(healpy.alm2map(np.array(flm), lmax=L - 1, nside=nside))
return f
return (np.zeros((), real_dtype) + np.zeros((), np.complex64)).dtype


def _real_dtype(complex_dtype):
"""
Get real datatype corresponding to a given complex datatype.
Derived from https://github.com/jax-ml/jax/blob/1471702adc28/jax/_src/lax/fft.py#L93
Original license:
Copyright 2019 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
return np.finfo(complex_dtype).dtype


def _healpy_map2alm_impl(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
return jnp.array(healpy.map2alm(np.array(f), lmax=L - 1, iter=0))

def _healpy_inverse_fwd(flm: jnp.ndarray, L: int, nside: int):
"""Private function which implements the forward pass for inverse jax_healpy."""
res = ([], L, nside)
return healpy_inverse(flm, L, nside), res

def _healpy_map2alm_abstract_eval(
f: core.ShapedArray, L: int, nside: int
) -> core.ShapedArray:
return core.ShapedArray(shape=(L * (L + 1) // 2,), dtype=_complex_dtype(f.dtype))

def _healpy_inverse_bwd(res, f):
"""Private function which implements the backward pass for inverse jax_healpy."""
_, L, nside = res
f_new = f * (12 * nside**2) / (4 * jnp.pi)
flm_out = jnp.array(
np.conj(healpy.map2alm(np.conj(np.array(f_new)), lmax=L - 1, iter=0))

def _healpy_map2alm_transpose(dflm: jnp.ndarray, L: int, nside: int):
scale_factors = (
jnp.concatenate((jnp.ones(L), 2 * jnp.ones(L * (L - 1) // 2)))
* (3 * nside**2)
/ jnp.pi
)
# iter MUST be zero otherwise gradient propagation is incorrect (JDM).
flm_out = reindex.flm_hp_to_2d_fast(flm_out, L)
m_conj = (-1) ** (jnp.arange(1, L) % 2)
flm_out = flm_out.at[..., L:].add(
jnp.flip(m_conj * jnp.conj(flm_out[..., : L - 1]), axis=-1)
return (jnp.conj(healpy_alm2map(jnp.conj(dflm) / scale_factors, L, nside)),)


_healpy_map2alm_p = core.Primitive("healpy_map2alm")
_healpy_map2alm_p.def_impl(_healpy_map2alm_impl)
_healpy_map2alm_p.def_abstract_eval(_healpy_map2alm_abstract_eval)
ad.deflinear(_healpy_map2alm_p, _healpy_map2alm_transpose)


def healpy_map2alm(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
"""
JAX wrapper for healpy map2alm function (forward spherical harmonic transform).
This wrapper will return the spherical harmonic coefficients as a one dimensional
array using HEALPix (ring-ordered) indexing. To instead return a two-dimensional
array of harmonic coefficients use :py:func:`healpy_forward`.
Args:
f (jnp.ndarray): Signal on the sphere.
L (int): Harmonic band-limit. Equivalent to `lmax + 1` in healpy.
nside (int): HEALPix Nside resolution parameter.
Returns:
jnp.ndarray: Harmonic coefficients of signal f.
"""
return _healpy_map2alm_p.bind(f, L=L, nside=nside)


def _healpy_alm2map_impl(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
return jnp.array(healpy.alm2map(np.array(flm), lmax=L - 1, nside=nside))


def _healpy_alm2map_abstract_eval(
flm: core.ShapedArray, L: int, nside: int
) -> core.ShapedArray:
return core.ShapedArray(shape=(12 * nside**2,), dtype=_real_dtype(flm.dtype))


def _healpy_alm2map_transpose(df: jnp.ndarray, L: int, nside: int) -> tuple:
scale_factors = (
jnp.concatenate((jnp.ones(L), 2 * jnp.ones(L * (L - 1) // 2)))
* (3 * nside**2)
/ jnp.pi
)
flm_out = flm_out.at[..., : L - 1].set(0)
# Scale factor above includes the inverse quadrature weight given by
# (12 * nside**2) / (4 * jnp.pi) = (3 * nside**2) / jnp.pi
# and also a factor of 2 for m>0 to account for the negative m.
# See explanation in this issue comment:
# https://github.com/astro-informatics/s2fft/issues/243#issuecomment-2500951488
return (scale_factors * jnp.conj(healpy_map2alm(jnp.conj(df), L, nside)),)

return flm_out, None, None

_healpy_alm2map_p = core.Primitive("healpy_alm2map")
_healpy_alm2map_p.def_impl(_healpy_alm2map_impl)
_healpy_alm2map_p.def_abstract_eval(_healpy_alm2map_abstract_eval)
ad.deflinear(_healpy_alm2map_p, _healpy_alm2map_transpose)


def healpy_alm2map(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
"""
JAX wrapper for healpy alm2map function (inverse spherical harmonic transform).
This wrapper assumes the passed spherical harmonic coefficients are a one
dimensional array using HEALPix (ring-ordered) indexing. To instead pass a
two-dimensional array of harmonic coefficients use :py:func:`healpy_inverse`.
Args:
flm (jnp.ndarray): Spherical harmonic coefficients.
L (int): Harmonic band-limit. Equivalent to `lmax + 1` in healpy.
nside (int): HEALPix Nside resolution parameter.
Returns:
jnp.ndarray: Signal on the sphere.
"""
return _healpy_alm2map_p.bind(flm, L=L, nside=nside)


@custom_vjp
def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.ndarray:
r"""
Compute the forward scalar spherical harmonic transform (healpy JAX).
HEALPix is a C++ library which implements the scalar spherical harmonic transform
outlined in [1]. We make use of their healpy python bindings for which we provide
custom JAX frontends, hence providing support for automatic differentiation. Currently
these transforms can only be deployed on CPU, which is a limitation of the C++ library.
custom JAX frontends, hence providing support for automatic differentiation.
Currently these transforms can only be deployed on CPU, which is a limitation of the
C++ library.
Args:
f (jnp.ndarray): Signal on the sphere.
L (int): Harmonic band-limit.
nside (int, optional): HEALPix Nside resolution parameter. Only required
if sampling="healpix". Defaults to None.
nside (int): HEALPix Nside resolution parameter.
iter (int, optional): Number of subiterations for healpy. Note that iterations
increase the precision of the forward transform, but reduce the accuracy of
the gradient pass. Between 2 and 3 iterations is a good compromise.
iter (int, optional): Number of subiterations (iterative refinement steps) for
healpy. Note that iterations increase the precision of the forward transform
as an inverse of inverse transform, but with a linear increase in
computational cost. Between 2 and 3 iterations is a good compromise.
Returns:
jnp.ndarray: Harmonic coefficients of signal f.
Expand All @@ -328,28 +427,39 @@ def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.nda
Astrophysical Journal 622.2 (2005): 759
"""
flm = jnp.array(healpy.map2alm(np.array(f), lmax=L - 1, iter=iter))
flm = healpy_map2alm(f, L, nside)
for _ in range(iter):
f_recov = healpy_alm2map(flm, L, nside)
f_error = f - f_recov
flm += healpy_map2alm(f_error, L, nside)
return reindex.flm_hp_to_2d_fast(flm, L)


def _healpy_forward_fwd(f: jnp.ndarray, L: int, nside: int, iter: int = 3):
"""Private function which implements the forward pass for forward jax_healpy."""
res = ([], L, nside, iter)
return healpy_forward(f, L, nside, iter), res
def healpy_inverse(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
r"""
Compute the inverse scalar real spherical harmonic transform (HEALPix JAX).
HEALPix is a C++ library which implements the scalar spherical harmonic transform
outlined in [1]. We make use of their healpy python bindings for which we provide
custom JAX frontends, hence providing support for automatic differentiation.
Currently these transforms can only be deployed on CPU, which is a limitation of the
C++ library.
def _healpy_forward_bwd(res, flm):
"""Private function which implements the backward pass for forward jax_healpy."""
_, L, nside, _ = res
flm_new = reindex.flm_2d_to_hp_fast(flm, L)
f = jnp.array(
np.conj(healpy.alm2map(np.conj(np.array(flm_new)), lmax=L - 1, nside=nside))
)
return f * (4 * jnp.pi) / (12 * nside**2), None, None, None
Args:
flm (jnp.ndarray): Spherical harmonic coefficients.
L (int): Harmonic band-limit.
# Link JAX gradients for C backend functions
ssht_inverse.defvjp(_ssht_inverse_fwd, _ssht_inverse_bwd)
ssht_forward.defvjp(_ssht_forward_fwd, _ssht_forward_bwd)
healpy_inverse.defvjp(_healpy_inverse_fwd, _healpy_inverse_bwd)
healpy_forward.defvjp(_healpy_forward_fwd, _healpy_forward_bwd)
nside (int): HEALPix Nside resolution parameter.
Returns:
jnp.ndarray: Signal on the sphere.
Note:
[1] Gorski, Krzysztof M., et al. "HEALPix: A framework for high-resolution
discretization and fast analysis of data distributed on the sphere." The
Astrophysical Journal 622.2 (2005): 759
"""
flm = reindex.flm_2d_to_hp_fast(flm, L)
return healpy_alm2map(flm, L, nside)
18 changes: 4 additions & 14 deletions tests/test_spherical_custom_grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,22 +307,16 @@ def func(f):
@pytest.mark.parametrize("nside", nside_to_test)
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_healpix_c_backend_inverse_custom_gradients(flm_generator, nside: int):
sampling = "healpix"
L = 2 * nside
reality = True
flm = flm_generator(L=L, reality=reality)
flm_target = flm_generator(L=L, reality=reality)
f_target = spherical.inverse_jax(
flm_target, L, nside=nside, sampling=sampling, reality=reality
)

def func(flm):
f = spherical.inverse(
return spherical.inverse(
flm, L, 0, nside, sampling="healpix", method="jax_healpy", reality=True
)
return jnp.sum(jnp.abs(f - f_target) ** 2)

check_grads(func, (flm,), order=1, modes=("rev"))
check_grads(func, (flm,), order=2, modes=("fwd", "rev"))


@pytest.mark.parametrize("nside", nside_to_test)
Expand All @@ -334,16 +328,12 @@ def test_healpix_c_backend_forward_custom_gradients(
sampling = "healpix"
L = 2 * nside
reality = True
flm_target = flm_generator(L=L, reality=reality)
flm = flm_generator(L=L, reality=reality)
f = spherical.inverse_jax(flm, L, nside=nside, sampling=sampling, reality=reality)

def func(f):
flm = spherical.forward(
return spherical.forward(
f, L, nside=nside, sampling="healpix", method="jax_healpy", iter=iter
)
return jnp.sum(jnp.abs(flm - flm_target) ** 2)

rtol = [1e-6, 1e-2, 5e-2, 1e-2][iter]

check_grads(func, (f,), order=1, modes=("rev"), rtol=rtol)
check_grads(func, (f,), order=2, modes=("fwd", "rev"))

0 comments on commit dc9b2bc

Please sign in to comment.