mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
c1aeb8b3fe
commit
0a6b715cd4
@ -58,7 +58,7 @@ from .lax_numpy import (
|
||||
tan, tanh, tensordot, tile, trace, trapz, transpose, tri, tril, tril_indices, tril_indices_from,
|
||||
triu, triu_indices, triu_indices_from, true_divide, trunc, uint16, uint32, uint64, uint8, unique,
|
||||
unpackbits, unravel_index, unsignedinteger, unwrap, vander, var, vdot, vsplit,
|
||||
vstack, where, zeros, zeros_like)
|
||||
vstack, where, zeros, zeros_like, _NOT_IMPLEMENTED)
|
||||
|
||||
from .polynomial import roots
|
||||
from .vectorize import vectorize
|
||||
@ -73,6 +73,7 @@ def _init():
|
||||
# Builds a set of all unimplemented NumPy functions.
|
||||
for name, func in util.get_module_functions(np).items():
|
||||
if name not in globals():
|
||||
_NOT_IMPLEMENTED.append(name)
|
||||
globals()[name] = lax_numpy._not_implemented(func)
|
||||
|
||||
_init()
|
||||
|
@ -236,6 +236,8 @@ def ifftshift(x, axes=None):
|
||||
return jnp.roll(x, shift, axes)
|
||||
|
||||
|
||||
_NOT_IMPLEMENTED = []
|
||||
for name, func in get_module_functions(np.fft).items():
|
||||
if name not in globals():
|
||||
_NOT_IMPLEMENTED.append(name)
|
||||
globals()[name] = _not_implemented(func)
|
||||
|
@ -4326,6 +4326,7 @@ _diff_methods = ["clip", "conj", "conjugate", "cumprod", "cumsum",
|
||||
# _not_implemented implementations of them here rather than in __init__.py.
|
||||
# TODO(phawkins): implement these.
|
||||
argpartition = _not_implemented(np.argpartition)
|
||||
_NOT_IMPLEMENTED = ['argpartition']
|
||||
|
||||
# Set up operator, method, and property forwarding on Tracer instances containing
|
||||
# ShapedArray avals by following the forwarding conventions for Tracer.
|
||||
|
@ -479,11 +479,6 @@ def solve(a, b):
|
||||
return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
|
||||
|
||||
|
||||
for name, func in get_module_functions(np.linalg).items():
|
||||
if name not in globals():
|
||||
globals()[name] = _not_implemented(func)
|
||||
|
||||
|
||||
@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
|
||||
It has two important differences:
|
||||
|
||||
@ -535,3 +530,10 @@ def lstsq(a, b, rcond=None, *, numpy_resid=False):
|
||||
if b_orig_ndim == 1:
|
||||
x = x.ravel()
|
||||
return x, resid, rank, s
|
||||
|
||||
|
||||
_NOT_IMPLEMENTED = []
|
||||
for name, func in get_module_functions(np.linalg).items():
|
||||
if name not in globals():
|
||||
_NOT_IMPLEMENTED.append(name)
|
||||
globals()[name] = _not_implemented(func)
|
||||
|
@ -19,8 +19,10 @@ from . import lax_numpy as jnp
|
||||
|
||||
from jax import jit
|
||||
from ._util import _wraps
|
||||
from .lax_numpy import _not_implemented
|
||||
from .linalg import eigvals as _eigvals
|
||||
from .. import ops as jaxops
|
||||
from ..util import get_module_functions
|
||||
|
||||
|
||||
def _to_inexact_type(type):
|
||||
@ -102,3 +104,10 @@ def roots(p, *, strip_zeros=True):
|
||||
# combine roots and zero roots
|
||||
roots = jnp.hstack((roots, jnp.zeros(trailing_zeros, p.dtype)))
|
||||
return roots
|
||||
|
||||
|
||||
_NOT_IMPLEMENTED = []
|
||||
for name, func in get_module_functions(np.polynomial).items():
|
||||
if name not in globals():
|
||||
_NOT_IMPLEMENTED.append(name)
|
||||
globals()[name] = _not_implemented(func)
|
||||
|
@ -85,6 +85,12 @@ def _zero_for_irfft(z, axes):
|
||||
|
||||
class FftTest(jtu.JaxTestCase):
|
||||
|
||||
def testNotImplemented(self):
|
||||
for name in jnp.fft._NOT_IMPLEMENTED:
|
||||
func = getattr(jnp.fft, name)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inverse={}_real={}_shape={}_axes={}".format(
|
||||
inverse, real, jtu.format_shape_dtype_string(shape, dtype), axes),
|
||||
|
@ -445,6 +445,12 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
for a in out]
|
||||
return f
|
||||
|
||||
def testNotImplemented(self):
|
||||
for name in jnp._NOT_IMPLEMENTED:
|
||||
func = getattr(jnp, name)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
|
||||
|
@ -51,6 +51,12 @@ def _skip_if_unsupported_type(dtype):
|
||||
|
||||
class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
def testNotImplemented(self):
|
||||
for name in jnp.linalg._NOT_IMPLEMENTED:
|
||||
func = getattr(jnp.linalg, name)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
|
||||
|
@ -33,6 +33,12 @@ all_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex
|
||||
|
||||
class TestPolynomial(jtu.JaxTestCase):
|
||||
|
||||
def testNotImplemented(self):
|
||||
for name in jnp.polynomial._NOT_IMPLEMENTED:
|
||||
func = getattr(jnp.polynomial, name)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_dtype={}_leading={}_trailing={}".format(
|
||||
jtu.format_shape_dtype_string((length+leading+trailing,), dtype),
|
||||
|
Loading…
x
Reference in New Issue
Block a user