mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
jnp.roots: better support for computation under JIT
This commit is contained in:
parent
2744404809
commit
f6476f7a03
@ -9,11 +9,15 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
-->
|
||||
|
||||
## jax 0.3.15 (Unreleased)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.14...main).
|
||||
* Changes
|
||||
* {func}`jax.numpy.roots` is now better behaved with `strip_zeros=False` when
|
||||
coefficients have leading zeros ({jax-issue}`#11215`).
|
||||
|
||||
## jaxlib 0.3.15 (Unreleased)
|
||||
|
||||
## jax 0.3.14 (June 21, 2022)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.13...main).
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.13...jax-v0.3.14).
|
||||
* Breaking changes
|
||||
* {func}`jax.experimental.compilation_cache.initialize_cache` does not support
|
||||
`max_cache_size_bytes` anymore and will not get that as an input.
|
||||
|
@ -19,86 +19,81 @@ import operator
|
||||
from jax import core
|
||||
from jax import jit
|
||||
from jax import lax
|
||||
from jax._src import dtypes
|
||||
from jax._src.numpy.lax_numpy import (
|
||||
all, arange, argmin, array, asarray, atleast_1d, concatenate, convolve, diag, dot, finfo,
|
||||
full, hstack, maximum, ones, outer, sqrt, trim_zeros, trim_zeros_tol, true_divide, vander, zeros)
|
||||
all, arange, argmin, array, asarray, atleast_1d, concatenate, convolve, diag, dot,
|
||||
finfo, full, maximum, ones, outer, roll, sqrt, trim_zeros, trim_zeros_tol, true_divide,
|
||||
vander, zeros)
|
||||
from jax._src.numpy import linalg
|
||||
from jax._src.numpy.util import _check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _wraps
|
||||
from jax._src.numpy.util import _check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _where, _wraps
|
||||
import numpy as np
|
||||
|
||||
|
||||
@jit
def _roots_no_zeros(p):
  """Return the roots of the polynomial whose coefficients are ``p``.

  The roots are obtained as the eigenvalues of the companion matrix built
  from ``p``. The coefficients are normalized by ``p[0]``, so a zero leading
  coefficient produces a division by zero; callers are expected to remove
  (or otherwise deal with) leading zeros before calling this helper.

  Args:
    p: array of polynomial coefficients (indexed as a rank-1 array here).

  Returns:
    The eigenvalues of the ``(p.size - 1) x (p.size - 1)`` companion matrix,
    or an empty complex array when ``p.size < 2``.
  """
  p, = _promote_dtypes_inexact(p)

  # A constant (or empty) polynomial has no roots; p.size is static, so this
  # branch is resolved at trace time.
  if p.size < 2:
    return array([], dtype=dtypes._to_complex_dtype(p.dtype))

  # build companion matrix and find its eigenvalues (the roots)
  A = diag(ones((p.size - 2,), p.dtype), -1)
  A = A.at[0, :].set(-p[1:] / p[0])
  return linalg.eigvals(A)
|
||||
|
||||
|
||||
@jit
|
||||
def _nonzero_range(arr):
|
||||
# return start and end s.t. arr[:start] = 0 = arr[end:] padding zeros
|
||||
is_zero = arr == 0
|
||||
start = argmin(is_zero)
|
||||
end = is_zero.size - argmin(is_zero[::-1])
|
||||
return start, end
|
||||
def _roots_with_zeros(p, num_leading_zeros):
  """Compute roots of ``p`` given that it starts with ``num_leading_zeros`` zeros.

  The leading zeros are rolled to the end of the coefficient array before the
  roots are computed, so the output has the same size as it would for an
  input without leading zeros. The last ``num_leading_zeros`` entries of the
  result are then overwritten with NaN.

  Args:
    p: array of polynomial coefficients.
    num_leading_zeros: number of zero entries at the start of ``p``.

  Returns:
    The array of roots, with the last ``num_leading_zeros`` entries set to
    ``nan+nanj``.
  """
  # Avoid lapack errors when p is all zero
  p = _where(len(p) == num_leading_zeros, 1.0, p)
  # Roll any leading zeros to the end & compute the roots
  roots = _roots_no_zeros(roll(p, -num_leading_zeros))
  # Sort zero roots to the end.
  roots = lax.sort_key_val(roots == 0, roots)[1]
  # Set roots associated with num_leading_zeros to NaN
  is_defined = arange(roots.size) < roots.size - num_leading_zeros
  return _where(is_defined, roots, complex(np.nan, np.nan))
|
||||
|
||||
|
||||
@_wraps(np.roots, lax_description="""\
|
||||
If the input polynomial coefficients of length n do not start with zero,
|
||||
the polynomial is of degree n - 1 leading to n - 1 roots.
|
||||
If the coefficients do have leading zeros, the polynomial they define
|
||||
has a smaller degree and the number of roots (and thus the output shape)
|
||||
is value dependent.
|
||||
Unlike the numpy version of this function, the JAX version returns the roots in
|
||||
a complex array regardless of the values of the roots. Additionally, the jax
|
||||
version of this function adds the ``strip_zeros`` parameter which must be set to
|
||||
False for the function to be compatible with JIT and other JAX transformations.
|
||||
With ``strip_zeros=False``, if your coefficients have leading zeros, the
|
||||
roots will be padded with NaN values:
|
||||
|
||||
The general implementation can therefore not be transformed with jit.
|
||||
If the coefficients are guaranteed to have no leading zeros, use the
|
||||
keyword argument `strip_zeros=False` to get a jit-compatible variant:
|
||||
>>> coeffs = jnp.array([0, 1, 2])
|
||||
|
||||
>>> from functools import partial
|
||||
>>> roots_unsafe = jax.jit(partial(jnp.roots, strip_zeros=False))
|
||||
>>> roots_unsafe([1, 2]) # ok
|
||||
DeviceArray([-2.+0.j], dtype=complex64)
|
||||
>>> roots_unsafe([0, 1, 2]) # problem
|
||||
DeviceArray([nan+nanj, nan+nanj], dtype=complex64)
|
||||
>>> jnp.roots([0, 1, 2]) # use the no-jit version instead
|
||||
# The default behavior matches numpy and strips leading zeros:
|
||||
>>> jnp.roots(coeffs)
|
||||
DeviceArray([-2.+0.j], dtype=complex64)
|
||||
|
||||
# With strip_zeros=False, extra roots are set to NaN:
|
||||
>>> jnp.roots(coeffs, strip_zeros=False)
|
||||
DeviceArray([-2. +0.j, nan+nanj], dtype=complex64)
|
||||
""",
|
||||
extra_params="""
|
||||
strip_zeros : bool, default=True
|
||||
If set to True, then leading zeros in the coefficients will be stripped, similar
|
||||
to :func:`numpy.roots`. If set to False, leading zeros will not be stripped, and
|
||||
undefined roots will be represented by NaN values in the function output.
|
||||
``strip_zeros`` must be set to ``False`` for the function to be compatible with
|
||||
:func:`jax.jit` and other JAX transformations.
|
||||
""")
|
||||
def roots(p, *, strip_zeros=True):
|
||||
# ported from https://github.com/numpy/numpy/blob/v1.17.0/numpy/lib/polynomial.py#L168-L251
|
||||
p = atleast_1d(p)
|
||||
_check_arraylike("roots", p)
|
||||
p = atleast_1d(*_promote_dtypes_inexact(p))
|
||||
if p.ndim != 1:
|
||||
raise ValueError("Input must be a rank-1 array.")
|
||||
|
||||
# strip_zeros=False is unsafe because leading zeros aren't removed
|
||||
if not strip_zeros:
|
||||
if p.size > 1:
|
||||
return _roots_no_zeros(p)
|
||||
else:
|
||||
return array([])
|
||||
|
||||
if all(p == 0):
|
||||
return array([])
|
||||
|
||||
# factor out trivial roots
|
||||
start, end = _nonzero_range(p)
|
||||
# number of trailing zeros = number of roots at 0
|
||||
trailing_zeros = p.size - end
|
||||
|
||||
# strip leading and trailing zeros
|
||||
p = p[start:end]
|
||||
|
||||
if p.size < 2:
|
||||
return zeros(trailing_zeros, p.dtype)
|
||||
return array([], dtype=dtypes._to_complex_dtype(p.dtype))
|
||||
num_leading_zeros = _where(all(p == 0), len(p), argmin(p == 0))
|
||||
|
||||
if strip_zeros:
|
||||
num_leading_zeros = core.concrete_or_error(int, num_leading_zeros,
|
||||
"The error occurred in the jnp.roots() function. To use this within a "
|
||||
"JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
|
||||
"will result in some returned roots being set to NaN.")
|
||||
return _roots_no_zeros(p[num_leading_zeros:])
|
||||
else:
|
||||
roots = _roots_no_zeros(p)
|
||||
# combine roots and zero roots
|
||||
roots = hstack((roots, zeros(trailing_zeros, roots.dtype)))
|
||||
return roots
|
||||
return _roots_with_zeros(p, num_leading_zeros)
|
||||
|
||||
|
||||
_POLYFIT_DOC = """\
|
||||
|
@ -13,13 +13,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
from scipy.sparse import csgraph, csr_matrix
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
from jax import jit
|
||||
from jax._src import dtypes
|
||||
from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
@ -35,114 +36,98 @@ all_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex
|
||||
|
||||
class TestPolynomial(jtu.JaxTestCase):
|
||||
|
||||
def assertSetsAllClose(self, x, y, rtol=None, atol=None, check_dtypes=True):
  """Assert that x and y contain permutations of the same approximate set of values.

  For non-complex inputs, this is accomplished by comparing the sorted inputs.
  For complex, such an approach can be confounded by numerical errors. In this case,
  we compute the structural rank of the pairwise comparison matrix: if the structural
  rank is full, it implies that the matrix can be permuted so that the diagonal is
  non-zero, which implies a one-to-one approximate match between the permuted sets.

  Args:
    x: first collection of values; flattened before comparison.
    y: second collection of values; flattened before comparison.
    rtol: relative tolerance, resolved through ``jtu.tolerance``.
    atol: absolute tolerance, resolved through ``jtu.tolerance``.
    check_dtypes: if True, also require ``x`` and ``y`` to share a dtype.
  """
  x = np.asarray(x).ravel()
  y = np.asarray(y).ravel()

  # Use the looser of the two per-dtype tolerances.
  atol = max(jtu.tolerance(x.dtype, atol), jtu.tolerance(y.dtype, atol))
  rtol = max(jtu.tolerance(x.dtype, rtol), jtu.tolerance(y.dtype, rtol))

  if not (np.issubdtype(x.dtype, np.complexfloating) or
          np.issubdtype(y.dtype, np.complexfloating)):
    return self.assertAllClose(np.sort(x), np.sort(y), atol=atol, rtol=rtol,
                               check_dtypes=check_dtypes)

  if check_dtypes:
    self.assertEqual(x.dtype, y.dtype)
  self.assertEqual(x.size, y.size)

  # pairwise[i, j] is True when x[i] approximately equals y[j]. The matrix
  # must be built from x against y: comparing x with itself always has a
  # full diagonal and would make the assertion pass for any y of equal size.
  pairwise = np.isclose(x[:, None], y[None, :],
                        atol=atol, rtol=rtol, equal_nan=True)
  rank = csgraph.structural_rank(csr_matrix(pairwise))
  self.assertEqual(rank, x.size)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_dtype={}_leading={}_trailing={}".format(
|
||||
jtu.format_shape_dtype_string((length+leading+trailing,), dtype),
|
||||
leading, trailing),
|
||||
"dtype": dtype, "length": length, "leading": leading, "trailing": trailing}
|
||||
for dtype in all_dtypes
|
||||
for length in [0, 3, 9, 10, 17]
|
||||
for leading in [0, 1, 2, 3, 5, 7, 10]
|
||||
for trailing in [0, 1, 2, 3, 5, 7, 10]))
|
||||
for length in [0, 3, 5]
|
||||
for leading in [0, 2]
|
||||
for trailing in [0, 2]))
|
||||
# TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
|
||||
@jtu.skip_on_devices("gpu", "tpu")
|
||||
def testRoots(self, dtype, length, leading, trailing):
|
||||
# rng = jtu.rand_default(self.rng())
|
||||
# This test is very fragile and breaks unless a "good" random seed is chosen.
|
||||
rng = jtu.rand_default(self.rng())
|
||||
rng = jtu.rand_some_zero(self.rng())
|
||||
|
||||
def args_maker():
|
||||
p = rng((length,), dtype)
|
||||
return jnp.concatenate(
|
||||
[jnp.zeros(leading, p.dtype), p, jnp.zeros(trailing, p.dtype)]),
|
||||
return [jnp.concatenate(
|
||||
[jnp.zeros(leading, p.dtype), p, jnp.zeros(trailing, p.dtype)])]
|
||||
|
||||
jnp_fn = lambda arg: jnp.sort(jnp.roots(arg))
|
||||
np_fn = lambda arg: np.sort(np.roots(arg))
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
|
||||
tol=3e-6)
|
||||
jnp_fun = jnp.roots
|
||||
def np_fun(arg):
|
||||
return np.roots(arg).astype(dtypes._to_complex_dtype(arg.dtype))
|
||||
|
||||
# Note: outputs have no defined order, so we need to use a special comparator.
|
||||
args = args_maker()
|
||||
np_roots = np_fun(*args)
|
||||
jnp_roots = jnp_fun(*args)
|
||||
self.assertSetsAllClose(np_roots, jnp_roots)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_dtype={}_trailing={}".format(
|
||||
jtu.format_shape_dtype_string((length+trailing,), dtype), trailing),
|
||||
"dtype": dtype, "length": length, "trailing": trailing}
|
||||
{"testcase_name": "_dtype={}_leading={}_trailing={}".format(
|
||||
jtu.format_shape_dtype_string((length+leading+trailing,), dtype),
|
||||
leading, trailing),
|
||||
"dtype": dtype, "length": length, "leading": leading, "trailing": trailing}
|
||||
for dtype in all_dtypes
|
||||
for length in [0, 1, 3, 10]
|
||||
for trailing in [0, 1, 3, 7]))
|
||||
for length in [0, 3, 5]
|
||||
for leading in [0, 2]
|
||||
for trailing in [0, 2]))
|
||||
# TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
|
||||
@jtu.skip_on_devices("gpu", "tpu")
|
||||
def testRootsNostrip(self, length, dtype, trailing):
|
||||
# rng = jtu.rand_default(self.rng())
|
||||
# This test is very fragile and breaks unless a "good" random seed is chosen.
|
||||
rng = jtu.rand_default(np.random.RandomState(0))
|
||||
def testRootsNoStrip(self, dtype, length, leading, trailing):
|
||||
rng = jtu.rand_some_zero(self.rng())
|
||||
|
||||
def args_maker():
|
||||
p = rng((length,), dtype)
|
||||
if length != 0:
|
||||
return jnp.concatenate([p, jnp.zeros(trailing, p.dtype)]),
|
||||
else:
|
||||
# adding trailing would make input invalid (start with zeros)
|
||||
return p,
|
||||
return [jnp.concatenate(
|
||||
[jnp.zeros(leading, p.dtype), p, jnp.zeros(trailing, p.dtype)])]
|
||||
|
||||
jnp_fn = lambda arg: jnp.sort(jnp.roots(arg, strip_zeros=False))
|
||||
np_fn = lambda arg: np.sort(np.roots(arg))
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker,
|
||||
check_dtypes=False, tol=1e-6)
|
||||
jnp_fun = partial(jnp.roots, strip_zeros=False)
|
||||
def np_fun(arg):
|
||||
roots = np.roots(arg).astype(dtypes._to_complex_dtype(arg.dtype))
|
||||
if len(roots) < len(arg) - 1:
|
||||
roots = np.pad(roots, (0, len(arg) - len(roots) - 1),
|
||||
constant_values=complex(np.nan, np.nan))
|
||||
return roots
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_dtype={}_trailing={}".format(
|
||||
jtu.format_shape_dtype_string((length + trailing,), dtype), trailing),
|
||||
"dtype": dtype, "length": length, "trailing": trailing}
|
||||
for dtype in all_dtypes
|
||||
for length in [0, 1, 3, 10]
|
||||
for trailing in [0, 1, 3, 7]))
|
||||
# TODO: enable when there is an eigendecomposition implementation
|
||||
# for GPU/TPU.
|
||||
@jtu.skip_on_devices("gpu", "tpu")
|
||||
def testRootsJit(self, length, dtype, trailing):
|
||||
# rng = jtu.rand_default(self.rng())
|
||||
# This test is very fragile and breaks unless a "good" random seed is chosen.
|
||||
rng = jtu.rand_default(np.random.RandomState(0))
|
||||
|
||||
def args_maker():
|
||||
p = rng((length,), dtype)
|
||||
if length != 0:
|
||||
return jnp.concatenate([p, jnp.zeros(trailing, p.dtype)]),
|
||||
else:
|
||||
# adding trailing would make input invalid (start with zeros)
|
||||
return p,
|
||||
|
||||
roots_compiled = jit(partial(jnp.roots, strip_zeros=False))
|
||||
jnp_fn = lambda arg: jnp.sort(roots_compiled(arg))
|
||||
np_fn = lambda arg: np.sort(np.roots(arg))
|
||||
# Using strip_zeros=False makes the algorithm less efficient
|
||||
# and leads to slightly different values compared to numpy
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker,
|
||||
check_dtypes=False, tol=1e-6)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_dtype={}_zeros={}_nonzeros={}".format(
|
||||
jtu.format_shape_dtype_string((zeros+nonzeros,), dtype),
|
||||
zeros, nonzeros),
|
||||
"zeros": zeros, "nonzeros": nonzeros, "dtype": dtype}
|
||||
for dtype in all_dtypes
|
||||
for zeros in [1, 2, 5]
|
||||
for nonzeros in [0, 3]))
|
||||
@jtu.skip_on_devices("gpu")
|
||||
@unittest.skip("getting segfaults on MKL") # TODO(#3711)
|
||||
def testRootsInvalid(self, zeros, nonzeros, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
|
||||
# The polynomial coefficients here start with zero and would have to
|
||||
# be stripped before computing eigenvalues of the companion matrix.
|
||||
# Setting strip_zeros=False skips this check,
|
||||
# allowing jit transformation but yielding nan's for these inputs.
|
||||
p = jnp.concatenate([jnp.zeros(zeros, dtype), rng((nonzeros,), dtype)])
|
||||
|
||||
if p.size == 1:
|
||||
# polynomial = const has no roots
|
||||
self.assertTrue(jnp.roots(p, strip_zeros=False).size == 0)
|
||||
else:
|
||||
self.assertTrue(jnp.any(jnp.isnan(jnp.roots(p, strip_zeros=False))))
|
||||
# Note: outputs have no defined order, so we need to use a special comparator.
|
||||
args = args_maker()
|
||||
np_roots = np_fun(*args)
|
||||
jnp_roots = jnp_fun(*args)
|
||||
self.assertSetsAllClose(np_roots, jnp_roots)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user