mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 05:16:08 +00:00
114 lines
3.5 KiB
Python
114 lines
3.5 KiB
Python
# Copyright 2018 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
import numpy as np
|
|
from .. import lax
|
|
from . import lax_numpy as jnp
|
|
|
|
from jax import jit
|
|
from ._util import _wraps
|
|
from .lax_numpy import _not_implemented
|
|
from .linalg import eigvals as _eigvals
|
|
from .. import ops as jaxops
|
|
from ..util import get_module_functions
|
|
|
|
|
|
def _to_inexact_type(type):
|
|
return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_
|
|
|
|
|
|
def _promote_inexact(arr):
  """Convert `arr` to the dtype chosen by `_to_inexact_type` for `arr.dtype`."""
  target_dtype = _to_inexact_type(arr.dtype)
  return lax.convert_element_type(arr, target_dtype)
|
|
|
|
|
|
@jit
def _roots_no_zeros(p):
  """Return the roots of `p` as the eigenvalues of its companion matrix.

  Assumes `p` does not have leading zeros and has length > 1.
  """
  coeffs = _promote_inexact(p)

  # Companion matrix: ones on the first subdiagonal ...
  subdiagonal = jnp.ones((coeffs.size - 2,), coeffs.dtype)
  companion = jnp.diag(subdiagonal, -1)

  # ... and the negated coefficients, normalized by the leading one, in the
  # first row.
  first_row = -coeffs[1:] / coeffs[0]
  companion = jaxops.index_update(companion, jaxops.index[0, :], first_row)

  # The eigenvalues of the companion matrix are the roots.
  return _eigvals(companion)
|
|
|
|
|
|
@jit
|
|
def _nonzero_range(arr):
|
|
# return start and end s.t. arr[:start] = 0 = arr[end:] padding zeros
|
|
is_zero = arr == 0
|
|
start = jnp.argmin(is_zero)
|
|
end = is_zero.size - jnp.argmin(is_zero[::-1])
|
|
return start, end
|
|
|
|
|
|
@_wraps(np.roots, lax_description="""\
If the input polynomial coefficients of length n do not start with zero,
the polynomial is of degree n - 1 leading to n - 1 roots.
If the coefficients do have leading zeros, the polynomial they define
has a smaller degree and the number of roots (and thus the output shape)
is value dependent.

The general implementation can therefore not be transformed with jit.
If the coefficients are guaranteed to have no leading zeros, use the
keyword argument `strip_zeros=False` to get a jit-compatible variant::

    >>> roots_unsafe = jax.jit(functools.partial(jnp.roots, strip_zeros=False))
    >>> roots_unsafe([1, 2]) # ok
    DeviceArray([-2.+0.j], dtype=complex64)
    >>> roots_unsafe([0, 1, 2]) # problem
    DeviceArray([nan+nanj, nan+nanj], dtype=complex64)
    >>> jnp.roots([0, 1, 2]) # use the no-jit version instead
    DeviceArray([-2.+0.j], dtype=complex64)
""")
def roots(p, *, strip_zeros=True):
  # ported from https://github.com/numpy/numpy/blob/v1.17.0/numpy/lib/polynomial.py#L168-L251
  p = jnp.atleast_1d(p)
  if p.ndim != 1:
    raise ValueError("Input must be a rank-1 array.")

  if not strip_zeros:
    # Unsafe path: leading zeros are not removed, so the coefficients are
    # passed to the companion-matrix solver exactly as given.
    return _roots_no_zeros(p) if p.size > 1 else jnp.array([])

  # An all-zero coefficient vector yields an empty result.
  if jnp.all(p == 0):
    return jnp.array([])

  # Factor out the trivial roots: every trailing zero coefficient
  # contributes one root at 0.
  start, end = _nonzero_range(p)
  num_zero_roots = p.size - end

  # Drop the leading and trailing zero coefficients.
  stripped = p[start:end]
  zero_roots = jnp.zeros(num_zero_roots, stripped.dtype)

  if stripped.size < 2:
    # Only a single coefficient is left: the roots at 0 are the only ones.
    return zero_roots

  # Non-trivial roots first, followed by the roots at 0.
  return jnp.hstack((_roots_no_zeros(stripped), zero_roots))
|
|
|
|
|
|
# For each entry returned by get_module_functions(np.polynomial) whose name is
# not already a module-level name here, record the name and bind it to the
# wrapper produced by _not_implemented(func).
_NOT_IMPLEMENTED = []
for name, func in get_module_functions(np.polynomial).items():
  if name in globals():
    # Already defined in this module: keep the existing definition.
    continue
  _NOT_IMPLEMENTED.append(name)
  globals()[name] = _not_implemented(func)
|