Add _NOT_IMPLEMENTED attribute to jax.numpy (fixes #3689) (#3698)

This commit is contained in:
Jake Vanderplas 2020-07-09 16:31:08 -07:00 committed by GitHub
parent c1aeb8b3fe
commit 0a6b715cd4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 45 additions and 6 deletions

View File

@ -58,7 +58,7 @@ from .lax_numpy import (
tan, tanh, tensordot, tile, trace, trapz, transpose, tri, tril, tril_indices, tril_indices_from,
triu, triu_indices, triu_indices_from, true_divide, trunc, uint16, uint32, uint64, uint8, unique,
unpackbits, unravel_index, unsignedinteger, unwrap, vander, var, vdot, vsplit,
vstack, where, zeros, zeros_like)
vstack, where, zeros, zeros_like, _NOT_IMPLEMENTED)
from .polynomial import roots
from .vectorize import vectorize
@ -73,6 +73,7 @@ def _init():
# Builds a set of all unimplemented NumPy functions.
for name, func in util.get_module_functions(np).items():
if name not in globals():
_NOT_IMPLEMENTED.append(name)
globals()[name] = lax_numpy._not_implemented(func)
_init()

View File

@ -236,6 +236,8 @@ def ifftshift(x, axes=None):
return jnp.roll(x, shift, axes)
_NOT_IMPLEMENTED = []
for name, func in get_module_functions(np.fft).items():
if name not in globals():
_NOT_IMPLEMENTED.append(name)
globals()[name] = _not_implemented(func)

View File

@ -4326,6 +4326,7 @@ _diff_methods = ["clip", "conj", "conjugate", "cumprod", "cumsum",
# _not_implemented implementations of them here rather than in __init__.py.
# TODO(phawkins): implement these.
argpartition = _not_implemented(np.argpartition)
_NOT_IMPLEMENTED = ['argpartition']
# Set up operator, method, and property forwarding on Tracer instances containing
# ShapedArray avals by following the forwarding conventions for Tracer.

View File

@ -479,11 +479,6 @@ def solve(a, b):
return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
for name, func in get_module_functions(np.linalg).items():
if name not in globals():
globals()[name] = _not_implemented(func)
@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
It has two important differences:
@ -535,3 +530,10 @@ def lstsq(a, b, rcond=None, *, numpy_resid=False):
if b_orig_ndim == 1:
x = x.ravel()
return x, resid, rank, s
_NOT_IMPLEMENTED = []
for name, func in get_module_functions(np.linalg).items():
if name not in globals():
_NOT_IMPLEMENTED.append(name)
globals()[name] = _not_implemented(func)

View File

@ -19,8 +19,10 @@ from . import lax_numpy as jnp
from jax import jit
from ._util import _wraps
from .lax_numpy import _not_implemented
from .linalg import eigvals as _eigvals
from .. import ops as jaxops
from ..util import get_module_functions
def _to_inexact_type(type):
@ -102,3 +104,10 @@ def roots(p, *, strip_zeros=True):
# combine roots and zero roots
roots = jnp.hstack((roots, jnp.zeros(trailing_zeros, p.dtype)))
return roots
_NOT_IMPLEMENTED = []
for name, func in get_module_functions(np.polynomial).items():
if name not in globals():
_NOT_IMPLEMENTED.append(name)
globals()[name] = _not_implemented(func)

View File

@ -85,6 +85,12 @@ def _zero_for_irfft(z, axes):
class FftTest(jtu.JaxTestCase):
def testNotImplemented(self):
for name in jnp.fft._NOT_IMPLEMENTED:
func = getattr(jnp.fft, name)
with self.assertRaises(NotImplementedError):
func()
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inverse={}_real={}_shape={}_axes={}".format(
inverse, real, jtu.format_shape_dtype_string(shape, dtype), axes),

View File

@ -445,6 +445,12 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for a in out]
return f
def testNotImplemented(self):
for name in jnp._NOT_IMPLEMENTED:
func = getattr(jnp, name)
with self.assertRaises(NotImplementedError):
func()
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,

View File

@ -51,6 +51,12 @@ def _skip_if_unsupported_type(dtype):
class NumpyLinalgTest(jtu.JaxTestCase):
def testNotImplemented(self):
for name in jnp.linalg._NOT_IMPLEMENTED:
func = getattr(jnp.linalg, name)
with self.assertRaises(NotImplementedError):
func()
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),

View File

@ -33,6 +33,12 @@ all_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex
class TestPolynomial(jtu.JaxTestCase):
def testNotImplemented(self):
for name in jnp.polynomial._NOT_IMPLEMENTED:
func = getattr(jnp.polynomial, name)
with self.assertRaises(NotImplementedError):
func()
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype={}_leading={}_trailing={}".format(
jtu.format_shape_dtype_string((length+leading+trailing,), dtype),