jnp.roots: better support for computation under JIT

This commit is contained in:
Jake VanderPlas 2022-06-23 14:48:53 -07:00
parent 2744404809
commit f6476f7a03
3 changed files with 127 additions and 143 deletions

View File

@ -9,11 +9,15 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
-->
## jax 0.3.15 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.14...main).
* Changes
* {func}`jax.numpy.roots` is now better behaved with `strip_zeros=False` when the
coefficients have leading zeros ({jax-issue}`#11215`).
## jaxlib 0.3.15 (Unreleased)
## jax 0.3.14 (June 21, 2022)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.13...main).
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.13...jax-v0.3.14).
* Breaking changes
* {func}`jax.experimental.compilation_cache.initialize_cache` no longer supports
`max_cache_size_bytes` and will not accept it as an input.

View File

@ -19,86 +19,81 @@ import operator
from jax import core
from jax import jit
from jax import lax
from jax._src import dtypes
from jax._src.numpy.lax_numpy import (
all, arange, argmin, array, asarray, atleast_1d, concatenate, convolve, diag, dot, finfo,
full, hstack, maximum, ones, outer, sqrt, trim_zeros, trim_zeros_tol, true_divide, vander, zeros)
all, arange, argmin, array, asarray, atleast_1d, concatenate, convolve, diag, dot,
finfo, full, maximum, ones, outer, roll, sqrt, trim_zeros, trim_zeros_tol, true_divide,
vander, zeros)
from jax._src.numpy import linalg
from jax._src.numpy.util import _check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _wraps
from jax._src.numpy.util import _check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _where, _wraps
import numpy as np
# Computes polynomial roots as the eigenvalues of a companion matrix built from `p`.
# NOTE(review): this span appears to interleave lines from both sides of the diff
# (two consecutive `return` statements; a "length > 1" assumption next to a
# `p.size < 2` guard) and has lost its indentation -- confirm which lines belong
# to the committed version before relying on it.
@jit
def _roots_no_zeros(p):
# assume: p does not have leading zeros and has length > 1
p, = _promote_dtypes_inexact(p)
# build companion matrix and find its eigenvalues (the roots)
if p.size < 2:
return array([], dtype=dtypes._to_complex_dtype(p.dtype))
# A has ones on the first sub-diagonal; its first row is then set to -p[1:] / p[0].
A = diag(ones((p.size - 2,), p.dtype), -1)
A = A.at[0, :].set(-p[1:] / p[0])
roots = linalg.eigvals(A)
return roots
return linalg.eigvals(A)
@jit
def _nonzero_range(arr):
# return start and end s.t. arr[:start] = 0 = arr[end:] padding zeros
is_zero = arr == 0
start = argmin(is_zero)
end = is_zero.size - argmin(is_zero[::-1])
return start, end
def _roots_with_zeros(p, num_leading_zeros):
  """Compute the roots of ``p`` when it may start with ``num_leading_zeros`` zeros.

  Entries of the result at index ``>= size - num_leading_zeros`` are replaced by
  ``nan+nanj``.
  """
  # An all-zero p leads to lapack errors, so swap in ones for that case.
  coeffs = _where(len(p) == num_leading_zeros, 1.0, p)
  # Rotate the leading zeros to the tail, then solve.
  shifted = roll(coeffs, -num_leading_zeros)
  raw_roots = _roots_no_zeros(shifted)
  # Order the roots so that exact zeros come last.
  _, ordered = lax.sort_key_val(raw_roots == 0, raw_roots)
  # Overwrite the slots that correspond to the leading zeros with NaN.
  n_valid = ordered.size - num_leading_zeros
  nan_root = complex(np.nan, np.nan)
  return _where(arange(ordered.size) < n_valid, ordered, nan_root)
# Public jnp.roots wrapper: validates `p`, then dispatches on `strip_zeros`.
# NOTE(review): this span appears to interleave lines from both sides of the diff
# (two sets of doctest examples in lax_description, two different early-return
# paths, and an `else:` followed by two `return` statements) and has lost its
# indentation -- confirm which lines belong to the committed version.
# NOTE(review): the lax_description text calls ``strip_zeros`` a "function"; it is a
# keyword-only argument of roots() -- consider rewording in the source string.
@_wraps(np.roots, lax_description="""\
If the input polynomial coefficients of length n do not start with zero,
the polynomial is of degree n - 1 leading to n - 1 roots.
If the coefficients do have leading zeros, the polynomial they define
has a smaller degree and the number of roots (and thus the output shape)
is value dependent.
Unlike the numpy version of this function, the JAX version returns the roots in
a complex array regardless of the values of the roots. Additionally, the jax
version of this function adds the ``strip_zeros`` function which must be set to
False for the function to be compatible with JIT and other JAX transformations.
With ``strip_zeros=False``, if your coefficients have leading zeros, the
roots will be padded with NaN values:
The general implementation can therefore not be transformed with jit.
If the coefficients are guaranteed to have no leading zeros, use the
keyword argument `strip_zeros=False` to get a jit-compatible variant:
>>> coeffs = jnp.array([0, 1, 2])
>>> from functools import partial
>>> roots_unsafe = jax.jit(partial(jnp.roots, strip_zeros=False))
>>> roots_unsafe([1, 2]) # ok
DeviceArray([-2.+0.j], dtype=complex64)
>>> roots_unsafe([0, 1, 2]) # problem
DeviceArray([nan+nanj, nan+nanj], dtype=complex64)
>>> jnp.roots([0, 1, 2]) # use the no-jit version instead
# The default behavior matches numpy and strips leading zeros:
>>> jnp.roots(coeffs)
DeviceArray([-2.+0.j], dtype=complex64)
# With strip_zeros=False, extra roots are set to NaN:
>>> jnp.roots(coeffs, strip_zeros=False)
DeviceArray([-2. +0.j, nan+nanj], dtype=complex64)
""",
extra_params="""
strip_zeros : bool, default=True
If set to True, then leading zeros in the coefficients will be stripped, similar
to :func:`numpy.roots`. If set to False, leading zeros will not be stripped, and
undefined roots will be represented by NaN values in the function output.
``strip_zeros`` must be set to ``False`` for the function to be compatible with
:func:`jax.jit` and other JAX transformations.
""")
def roots(p, *, strip_zeros=True):
# ported from https://github.com/numpy/numpy/blob/v1.17.0/numpy/lib/polynomial.py#L168-L251
p = atleast_1d(p)
_check_arraylike("roots", p)
p = atleast_1d(*_promote_dtypes_inexact(p))
if p.ndim != 1:
raise ValueError("Input must be a rank-1 array.")
# strip_zeros=False is unsafe because leading zeros aren't removed
if not strip_zeros:
if p.size > 1:
return _roots_no_zeros(p)
else:
return array([])
if all(p == 0):
return array([])
# factor out trivial roots
start, end = _nonzero_range(p)
# number of trailing zeros = number of roots at 0
trailing_zeros = p.size - end
# strip leading and trailing zeros
p = p[start:end]
if p.size < 2:
return zeros(trailing_zeros, p.dtype)
return array([], dtype=dtypes._to_complex_dtype(p.dtype))
# Number of leading zeros: len(p) when every coefficient is zero, otherwise the
# index of the first non-zero coefficient.
num_leading_zeros = _where(all(p == 0), len(p), argmin(p == 0))
if strip_zeros:
# Stripping slices by num_leading_zeros, so it is converted to a concrete int here.
# NOTE(review): the message below reads "will be result in"; likely meant
# "will result in". Left unchanged here because it is a runtime string.
num_leading_zeros = core.concrete_or_error(int, num_leading_zeros,
"The error occurred in the jnp.roots() function. To use this within a "
"JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
"will be result in some returned roots being set to NaN.")
return _roots_no_zeros(p[num_leading_zeros:])
else:
roots = _roots_no_zeros(p)
# combine roots and zero roots
roots = hstack((roots, zeros(trailing_zeros, roots.dtype)))
return roots
return _roots_with_zeros(p, num_leading_zeros)
_POLYFIT_DOC = """\

View File

@ -13,13 +13,14 @@
# limitations under the License.
from functools import partial
import numpy as np
import unittest
from scipy.sparse import csgraph, csr_matrix
from absl.testing import absltest
from absl.testing import parameterized
from jax import jit
from jax._src import dtypes
from jax import numpy as jnp
from jax._src import test_util as jtu
@ -35,114 +36,98 @@ all_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex
class TestPolynomial(jtu.JaxTestCase):
def assertSetsAllClose(self, x, y, rtol=None, atol=None, check_dtypes=True):
  """Assert that x and y contain permutations of the same approximate set of values.

  For non-complex inputs, this is accomplished by comparing the sorted inputs.
  For complex, such an approach can be confounded by numerical errors. In this case,
  we compute the structural rank of the pairwise comparison matrix: if the structural
  rank is full, it implies that the matrix can be permuted so that the diagonal is
  non-zero, which implies a one-to-one approximate match between the permuted sets.

  Args:
    x: array-like of values; flattened before comparison.
    y: array-like of values; flattened before comparison.
    rtol: relative tolerance, resolved through ``jtu.tolerance`` for both dtypes.
    atol: absolute tolerance, resolved through ``jtu.tolerance`` for both dtypes.
    check_dtypes: if True, also require that ``x`` and ``y`` have the same dtype.
  """
  x = np.asarray(x).ravel()
  y = np.asarray(y).ravel()

  # Use the looser of the two per-dtype tolerances.
  atol = max(jtu.tolerance(x.dtype, atol), jtu.tolerance(y.dtype, atol))
  rtol = max(jtu.tolerance(x.dtype, rtol), jtu.tolerance(y.dtype, rtol))

  if not (np.issubdtype(x.dtype, np.complexfloating) or
          np.issubdtype(y.dtype, np.complexfloating)):
    return self.assertAllClose(np.sort(x), np.sort(y), atol=atol, rtol=rtol,
                               check_dtypes=check_dtypes)

  if check_dtypes:
    self.assertEqual(x.dtype, y.dtype)
  self.assertEqual(x.size, y.size)

  # pairwise[i, j] is True when x[i] approximately equals y[j]. The comparison must
  # be x against y: comparing x against itself always yields a full-rank matrix
  # (True diagonal), which would make this assertion pass for any y of equal size.
  pairwise = np.isclose(x[:, None], y[None, :],
                        atol=atol, rtol=rtol, equal_nan=True)
  rank = csgraph.structural_rank(csr_matrix(pairwise))
  self.assertEqual(rank, x.size)
# Checks jnp.roots against np.roots on coefficient arrays built by padding a random
# length-`length` vector with `leading` and `trailing` zeros.
# NOTE(review): this span appears to interleave lines from both sides of the diff
# (two parameter grids, two `rng = ...` assignments, two `return` statements in
# args_maker, and both a _CheckAgainstNumpy call and an assertSetsAllClose call)
# and has lost its indentation -- confirm which lines belong to the committed version.
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype={}_leading={}_trailing={}".format(
jtu.format_shape_dtype_string((length+leading+trailing,), dtype),
leading, trailing),
"dtype": dtype, "length": length, "leading": leading, "trailing": trailing}
for dtype in all_dtypes
for length in [0, 3, 9, 10, 17]
for leading in [0, 1, 2, 3, 5, 7, 10]
for trailing in [0, 1, 2, 3, 5, 7, 10]))
for length in [0, 3, 5]
for leading in [0, 2]
for trailing in [0, 2]))
# TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
@jtu.skip_on_devices("gpu", "tpu")
def testRoots(self, dtype, length, leading, trailing):
# rng = jtu.rand_default(self.rng())
# This test is very fragile and breaks unless a "good" random seed is chosen.
rng = jtu.rand_default(self.rng())
rng = jtu.rand_some_zero(self.rng())
def args_maker():
p = rng((length,), dtype)
return jnp.concatenate(
[jnp.zeros(leading, p.dtype), p, jnp.zeros(trailing, p.dtype)]),
return [jnp.concatenate(
[jnp.zeros(leading, p.dtype), p, jnp.zeros(trailing, p.dtype)])]
jnp_fn = lambda arg: jnp.sort(jnp.roots(arg))
np_fn = lambda arg: np.sort(np.roots(arg))
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
tol=3e-6)
jnp_fun = jnp.roots
# The numpy reference is cast to the complex dtype matching the input dtype.
def np_fun(arg):
return np.roots(arg).astype(dtypes._to_complex_dtype(arg.dtype))
# Note: outputs have no defined order, so we need to use a special comparator.
args = args_maker()
np_roots = np_fun(*args)
jnp_roots = jnp_fun(*args)
self.assertSetsAllClose(np_roots, jnp_roots)
# Checks jnp.roots(..., strip_zeros=False) against a numpy reference.
# NOTE(review): this span appears to interleave lines from both sides of the diff
# (two test-case dict templates, two parameter grids, two `def` lines --
# testRootsNostrip and testRootsNoStrip -- and two args_maker bodies) and has lost
# its indentation -- confirm which lines belong to the committed version.
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype={}_trailing={}".format(
jtu.format_shape_dtype_string((length+trailing,), dtype), trailing),
"dtype": dtype, "length": length, "trailing": trailing}
{"testcase_name": "_dtype={}_leading={}_trailing={}".format(
jtu.format_shape_dtype_string((length+leading+trailing,), dtype),
leading, trailing),
"dtype": dtype, "length": length, "leading": leading, "trailing": trailing}
for dtype in all_dtypes
for length in [0, 1, 3, 10]
for trailing in [0, 1, 3, 7]))
for length in [0, 3, 5]
for leading in [0, 2]
for trailing in [0, 2]))
# TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
@jtu.skip_on_devices("gpu", "tpu")
def testRootsNostrip(self, length, dtype, trailing):
# rng = jtu.rand_default(self.rng())
# This test is very fragile and breaks unless a "good" random seed is chosen.
rng = jtu.rand_default(np.random.RandomState(0))
def testRootsNoStrip(self, dtype, length, leading, trailing):
rng = jtu.rand_some_zero(self.rng())
def args_maker():
p = rng((length,), dtype)
if length != 0:
return jnp.concatenate([p, jnp.zeros(trailing, p.dtype)]),
else:
# adding trailing would make input invalid (start with zeros)
return p,
return [jnp.concatenate(
[jnp.zeros(leading, p.dtype), p, jnp.zeros(trailing, p.dtype)])]
jnp_fn = lambda arg: jnp.sort(jnp.roots(arg, strip_zeros=False))
np_fn = lambda arg: np.sort(np.roots(arg))
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker,
check_dtypes=False, tol=1e-6)
jnp_fun = partial(jnp.roots, strip_zeros=False)
# The numpy reference is cast to complex and, when it has fewer than len(arg) - 1
# roots, padded at the end with nan+nanj up to length len(arg) - 1.
def np_fun(arg):
roots = np.roots(arg).astype(dtypes._to_complex_dtype(arg.dtype))
if len(roots) < len(arg) - 1:
roots = np.pad(roots, (0, len(arg) - len(roots) - 1),
constant_values=complex(np.nan, np.nan))
return roots
# Jit-compiles jnp.roots with strip_zeros=False and compares its sorted output with
# the sorted output of np.roots.
# NOTE(review): the indentation of this span was lost in the diff rendering; the
# nesting of args_maker's if/else has to be inferred -- confirm against the original file.
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype={}_trailing={}".format(
jtu.format_shape_dtype_string((length + trailing,), dtype), trailing),
"dtype": dtype, "length": length, "trailing": trailing}
for dtype in all_dtypes
for length in [0, 1, 3, 10]
for trailing in [0, 1, 3, 7]))
# TODO: enable when there is an eigendecomposition implementation
# for GPU/TPU.
@jtu.skip_on_devices("gpu", "tpu")
def testRootsJit(self, length, dtype, trailing):
# rng = jtu.rand_default(self.rng())
# This test is very fragile and breaks unless a "good" random seed is chosen.
rng = jtu.rand_default(np.random.RandomState(0))
def args_maker():
p = rng((length,), dtype)
if length != 0:
return jnp.concatenate([p, jnp.zeros(trailing, p.dtype)]),
else:
# adding trailing would make input invalid (start with zeros)
return p,
roots_compiled = jit(partial(jnp.roots, strip_zeros=False))
jnp_fn = lambda arg: jnp.sort(roots_compiled(arg))
np_fn = lambda arg: np.sort(np.roots(arg))
# Using strip_zeros=False makes the algorithm less efficient
# and leads to slightly different values compared to numpy
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker,
check_dtypes=False, tol=1e-6)
# Builds coefficients that start with `zeros` zeros followed by `nonzeros` random
# values and asserts that jnp.roots(p, strip_zeros=False) is empty when p.size == 1
# and contains at least one NaN otherwise. Unconditionally skipped via unittest.skip.
# NOTE(review): the indentation of this span was lost in the diff rendering, and the
# last six lines (from the "outputs have no defined order" comment onward) use
# args_maker / np_fun / jnp_fun, none of which this test defines; they look like the
# tail of testRootsNoStrip displaced by the diff rendering -- confirm before relying on it.
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype={}_zeros={}_nonzeros={}".format(
jtu.format_shape_dtype_string((zeros+nonzeros,), dtype),
zeros, nonzeros),
"zeros": zeros, "nonzeros": nonzeros, "dtype": dtype}
for dtype in all_dtypes
for zeros in [1, 2, 5]
for nonzeros in [0, 3]))
@jtu.skip_on_devices("gpu")
@unittest.skip("getting segfaults on MKL") # TODO(#3711)
def testRootsInvalid(self, zeros, nonzeros, dtype):
rng = jtu.rand_default(self.rng())
# The polynomial coefficients here start with zero and would have to
# be stripped before computing eigenvalues of the companion matrix.
# Setting strip_zeros=False skips this check,
# allowing jit transformation but yielding nan's for these inputs.
p = jnp.concatenate([jnp.zeros(zeros, dtype), rng((nonzeros,), dtype)])
if p.size == 1:
# polynomial = const has no roots
self.assertTrue(jnp.roots(p, strip_zeros=False).size == 0)
else:
self.assertTrue(jnp.any(jnp.isnan(jnp.roots(p, strip_zeros=False))))
# Note: outputs have no defined order, so we need to use a special comparator.
args = args_maker()
np_roots = np_fun(*args)
jnp_roots = jnp_fun(*args)
self.assertSetsAllClose(np_roots, jnp_roots)
self._CompileAndCheck(jnp_fun, args_maker)
if __name__ == "__main__":