From f6476f7a03f8390627c1a8e2a2ec8702d8a320e5 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 23 Jun 2022 14:48:53 -0700 Subject: [PATCH] jnp.roots: better support for computation under JIT --- CHANGELOG.md | 6 +- jax/_src/numpy/polynomial.py | 105 +++++++++++------------ tests/polynomial_test.py | 159 ++++++++++++++++------------------- 3 files changed, 127 insertions(+), 143 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f34188a2..b1c6620ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,11 +9,15 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. --> ## jax 0.3.15 (Unreleased) +* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.14...main). +* Changes + * {func}`jax.numpy.roots` is now better behaved with `strip_zeros=False` when + coefficients have leading zeros ({jax-issue}`#11215`). ## jaxlib 0.3.15 (Unreleased) ## jax 0.3.14 (June 21, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.13...main). +* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.13...jax-v0.3.14). * Breaking changes * {func}`jax.experimental.compilation_cache.initialize_cache` does not support `max_cache_size_ bytes` anymore and will not get that as an input. 
diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 9279a8908..1931c85d7 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -19,86 +19,81 @@ import operator from jax import core from jax import jit from jax import lax +from jax._src import dtypes from jax._src.numpy.lax_numpy import ( - all, arange, argmin, array, asarray, atleast_1d, concatenate, convolve, diag, dot, finfo, - full, hstack, maximum, ones, outer, sqrt, trim_zeros, trim_zeros_tol, true_divide, vander, zeros) + all, arange, argmin, array, asarray, atleast_1d, concatenate, convolve, diag, dot, + finfo, full, maximum, ones, outer, roll, sqrt, trim_zeros, trim_zeros_tol, true_divide, + vander, zeros) from jax._src.numpy import linalg -from jax._src.numpy.util import _check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _wraps +from jax._src.numpy.util import _check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _where, _wraps import numpy as np @jit def _roots_no_zeros(p): - # assume: p does not have leading zeros and has length > 1 - p, = _promote_dtypes_inexact(p) - # build companion matrix and find its eigenvalues (the roots) + if p.size < 2: + return array([], dtype=dtypes._to_complex_dtype(p.dtype)) A = diag(ones((p.size - 2,), p.dtype), -1) A = A.at[0, :].set(-p[1:] / p[0]) - roots = linalg.eigvals(A) - return roots + return linalg.eigvals(A) @jit -def _nonzero_range(arr): - # return start and end s.t. arr[:start] = 0 = arr[end:] padding zeros - is_zero = arr == 0 - start = argmin(is_zero) - end = is_zero.size - argmin(is_zero[::-1]) - return start, end +def _roots_with_zeros(p, num_leading_zeros): + # Avoid lapack errors when p is all zero + p = _where(len(p) == num_leading_zeros, 1.0, p) + # Roll any leading zeros to the end & compute the roots + roots = _roots_no_zeros(roll(p, -num_leading_zeros)) + # Sort zero roots to the end. 
+ roots = lax.sort_key_val(roots == 0, roots)[1] + # Set roots associated with num_leading_zeros to NaN + return _where(arange(roots.size) < roots.size - num_leading_zeros, roots, complex(np.nan, np.nan)) @_wraps(np.roots, lax_description="""\ -If the input polynomial coefficients of length n do not start with zero, -the polynomial is of degree n - 1 leading to n - 1 roots. -If the coefficients do have leading zeros, the polynomial they define -has a smaller degree and the number of roots (and thus the output shape) -is value dependent. +Unlike the numpy version of this function, the JAX version returns the roots in +a complex array regardless of the values of the roots. Additionally, the jax +version of this function adds the ``strip_zeros`` parameter which must be set to +False for the function to be compatible with JIT and other JAX transformations. +With ``strip_zeros=False``, if your coefficients have leading zeros, the +roots will be padded with NaN values: -The general implementation can therefore not be transformed with jit. -If the coefficients are guaranteed to have no leading zeros, use the -keyword argument `strip_zeros=False` to get a jit-compatible variant: +>>> coeffs = jnp.array([0, 1, 2]) ->>> from functools import partial ->>> roots_unsafe = jax.jit(partial(jnp.roots, strip_zeros=False)) ->>> roots_unsafe([1, 2]) # ok -DeviceArray([-2.+0.j], dtype=complex64) ->>> roots_unsafe([0, 1, 2]) # problem -DeviceArray([nan+nanj, nan+nanj], dtype=complex64) ->>> jnp.roots([0, 1, 2]) # use the no-jit version instead +# The default behavior matches numpy and strips leading zeros: +>>> jnp.roots(coeffs) DeviceArray([-2.+0.j], dtype=complex64) + +# With strip_zeros=False, extra roots are set to NaN: +>>> jnp.roots(coeffs, strip_zeros=False) +DeviceArray([-2. +0.j, nan+nanj], dtype=complex64) +""", +extra_params=""" +strip_zeros : bool, default=True + If set to True, then leading zeros in the coefficients will be stripped, similar + to :func:`numpy.roots`. 
If set to False, leading zeros will not be stripped, and + undefined roots will be represented by NaN values in the function output. + ``strip_zeros`` must be set to ``False`` for the function to be compatible with + :func:`jax.jit` and other JAX transformations. """) def roots(p, *, strip_zeros=True): - # ported from https://github.com/numpy/numpy/blob/v1.17.0/numpy/lib/polynomial.py#L168-L251 - p = atleast_1d(p) + _check_arraylike("roots", p) + p = atleast_1d(*_promote_dtypes_inexact(p)) if p.ndim != 1: raise ValueError("Input must be a rank-1 array.") - - # strip_zeros=False is unsafe because leading zeros aren't removed - if not strip_zeros: - if p.size > 1: - return _roots_no_zeros(p) - else: - return array([]) - - if all(p == 0): - return array([]) - - # factor out trivial roots - start, end = _nonzero_range(p) - # number of trailing zeros = number of roots at 0 - trailing_zeros = p.size - end - - # strip leading and trailing zeros - p = p[start:end] - if p.size < 2: - return zeros(trailing_zeros, p.dtype) + return array([], dtype=dtypes._to_complex_dtype(p.dtype)) + num_leading_zeros = _where(all(p == 0), len(p), argmin(p == 0)) + + if strip_zeros: + num_leading_zeros = core.concrete_or_error(int, num_leading_zeros, + "The error occurred in the jnp.roots() function. To use this within a " + "JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros " + "will result in some returned roots being set to NaN.") + return _roots_no_zeros(p[num_leading_zeros:]) else: - roots = _roots_no_zeros(p) - # combine roots and zero roots - roots = hstack((roots, zeros(trailing_zeros, roots.dtype))) - return roots + return _roots_with_zeros(p, num_leading_zeros) _POLYFIT_DOC = """\ diff --git a/tests/polynomial_test.py b/tests/polynomial_test.py index 7d9807e11..ba0eec230 100644 --- a/tests/polynomial_test.py +++ b/tests/polynomial_test.py @@ -13,13 +13,14 @@ # limitations under the License. 
from functools import partial + import numpy as np -import unittest +from scipy.sparse import csgraph, csr_matrix from absl.testing import absltest from absl.testing import parameterized -from jax import jit +from jax._src import dtypes from jax import numpy as jnp from jax._src import test_util as jtu @@ -35,114 +36,98 @@ all_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex class TestPolynomial(jtu.JaxTestCase): + def assertSetsAllClose(self, x, y, rtol=None, atol=None, check_dtypes=True): + """Assert that x and y contain permutations of the same approximate set of values. + + For non-complex inputs, this is accomplished by comparing the sorted inputs. + For complex, such an approach can be confounded by numerical errors. In this case, + we compute the structural rank of the pairwise comparison matrix: if the structural + rank is full, it implies that the matrix can be permuted so that the diagonal is + non-zero, which implies a one-to-one approximate match between the permuted sets. 
+ """ + x = np.asarray(x).ravel() + y = np.asarray(y).ravel() + + atol = max(jtu.tolerance(x.dtype, atol), jtu.tolerance(y.dtype, atol)) + rtol = max(jtu.tolerance(x.dtype, rtol), jtu.tolerance(y.dtype, rtol)) + + if not (np.issubdtype(x.dtype, np.complexfloating) or + np.issubdtype(y.dtype, np.complexfloating)): + return self.assertAllClose(np.sort(x), np.sort(y), atol=atol, rtol=rtol, + check_dtypes=check_dtypes) + + if check_dtypes: + self.assertEqual(x.dtype, y.dtype) + self.assertEqual(x.size, y.size) + + pairwise = np.isclose(x[:, None], y[None, :], + atol=atol, rtol=rtol, equal_nan=True) + rank = csgraph.structural_rank(csr_matrix(pairwise)) + self.assertEqual(rank, x.size) + + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_dtype={}_leading={}_trailing={}".format( jtu.format_shape_dtype_string((length+leading+trailing,), dtype), leading, trailing), "dtype": dtype, "length": length, "leading": leading, "trailing": trailing} for dtype in all_dtypes - for length in [0, 3, 9, 10, 17] - for leading in [0, 1, 2, 3, 5, 7, 10] - for trailing in [0, 1, 2, 3, 5, 7, 10])) + for length in [0, 3, 5] + for leading in [0, 2] + for trailing in [0, 2])) # TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU. @jtu.skip_on_devices("gpu", "tpu") def testRoots(self, dtype, length, leading, trailing): - # rng = jtu.rand_default(self.rng()) - # This test is very fragile and breaks unless a "good" random seed is chosen. 
- rng = jtu.rand_default(self.rng()) + rng = jtu.rand_some_zero(self.rng()) def args_maker(): p = rng((length,), dtype) - return jnp.concatenate( - [jnp.zeros(leading, p.dtype), p, jnp.zeros(trailing, p.dtype)]), + return [jnp.concatenate( + [jnp.zeros(leading, p.dtype), p, jnp.zeros(trailing, p.dtype)])] - jnp_fn = lambda arg: jnp.sort(jnp.roots(arg)) - np_fn = lambda arg: np.sort(np.roots(arg)) - self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, - tol=3e-6) + jnp_fun = jnp.roots + def np_fun(arg): + return np.roots(arg).astype(dtypes._to_complex_dtype(arg.dtype)) + + # Note: outputs have no defined order, so we need to use a special comparator. + args = args_maker() + np_roots = np_fun(*args) + jnp_roots = jnp_fun(*args) + self.assertSetsAllClose(np_roots, jnp_roots) @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_dtype={}_trailing={}".format( - jtu.format_shape_dtype_string((length+trailing,), dtype), trailing), - "dtype": dtype, "length": length, "trailing": trailing} + {"testcase_name": "_dtype={}_leading={}_trailing={}".format( + jtu.format_shape_dtype_string((length+leading+trailing,), dtype), + leading, trailing), + "dtype": dtype, "length": length, "leading": leading, "trailing": trailing} for dtype in all_dtypes - for length in [0, 1, 3, 10] - for trailing in [0, 1, 3, 7])) + for length in [0, 3, 5] + for leading in [0, 2] + for trailing in [0, 2])) # TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU. @jtu.skip_on_devices("gpu", "tpu") - def testRootsNostrip(self, length, dtype, trailing): - # rng = jtu.rand_default(self.rng()) - # This test is very fragile and breaks unless a "good" random seed is chosen. 
- rng = jtu.rand_default(np.random.RandomState(0)) + def testRootsNoStrip(self, dtype, length, leading, trailing): + rng = jtu.rand_some_zero(self.rng()) def args_maker(): p = rng((length,), dtype) - if length != 0: - return jnp.concatenate([p, jnp.zeros(trailing, p.dtype)]), - else: - # adding trailing would make input invalid (start with zeros) - return p, + return [jnp.concatenate( + [jnp.zeros(leading, p.dtype), p, jnp.zeros(trailing, p.dtype)])] - jnp_fn = lambda arg: jnp.sort(jnp.roots(arg, strip_zeros=False)) - np_fn = lambda arg: np.sort(np.roots(arg)) - self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, - check_dtypes=False, tol=1e-6) + jnp_fun = partial(jnp.roots, strip_zeros=False) + def np_fun(arg): + roots = np.roots(arg).astype(dtypes._to_complex_dtype(arg.dtype)) + if len(roots) < len(arg) - 1: + roots = np.pad(roots, (0, len(arg) - len(roots) - 1), + constant_values=complex(np.nan, np.nan)) + return roots - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_dtype={}_trailing={}".format( - jtu.format_shape_dtype_string((length + trailing,), dtype), trailing), - "dtype": dtype, "length": length, "trailing": trailing} - for dtype in all_dtypes - for length in [0, 1, 3, 10] - for trailing in [0, 1, 3, 7])) - # TODO: enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.skip_on_devices("gpu", "tpu") - def testRootsJit(self, length, dtype, trailing): - # rng = jtu.rand_default(self.rng()) - # This test is very fragile and breaks unless a "good" random seed is chosen. 
- rng = jtu.rand_default(np.random.RandomState(0)) - - def args_maker(): - p = rng((length,), dtype) - if length != 0: - return jnp.concatenate([p, jnp.zeros(trailing, p.dtype)]), - else: - # adding trailing would make input invalid (start with zeros) - return p, - - roots_compiled = jit(partial(jnp.roots, strip_zeros=False)) - jnp_fn = lambda arg: jnp.sort(roots_compiled(arg)) - np_fn = lambda arg: np.sort(np.roots(arg)) - # Using strip_zeros=False makes the algorithm less efficient - # and leads to slightly different values compared ot numpy - self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, - check_dtypes=False, tol=1e-6) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_dtype={}_zeros={}_nonzeros={}".format( - jtu.format_shape_dtype_string((zeros+nonzeros,), dtype), - zeros, nonzeros), - "zeros": zeros, "nonzeros": nonzeros, "dtype": dtype} - for dtype in all_dtypes - for zeros in [1, 2, 5] - for nonzeros in [0, 3])) - @jtu.skip_on_devices("gpu") - @unittest.skip("getting segfaults on MKL") # TODO(#3711) - def testRootsInvalid(self, zeros, nonzeros, dtype): - rng = jtu.rand_default(self.rng()) - - # The polynomial coefficients here start with zero and would have to - # be stripped before computing eigenvalues of the companion matrix. - # Setting strip_zeros=False skips this check, - # allowing jit transformation but yielding nan's for these inputs. - p = jnp.concatenate([jnp.zeros(zeros, dtype), rng((nonzeros,), dtype)]) - - if p.size == 1: - # polynomial = const has no roots - self.assertTrue(jnp.roots(p, strip_zeros=False).size == 0) - else: - self.assertTrue(jnp.any(jnp.isnan(jnp.roots(p, strip_zeros=False)))) + # Note: outputs have no defined order, so we need to use a special comparator. + args = args_maker() + np_roots = np_fun(*args) + jnp_roots = jnp_fun(*args) + self.assertSetsAllClose(np_roots, jnp_roots) + self._CompileAndCheck(jnp_fun, args_maker) if __name__ == "__main__":