diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index 4609e0134..4b480a5c7 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -162,6 +162,7 @@ Not every function in NumPy is implemented; contributions are welcome! max maximum mean + median meshgrid min minimum @@ -228,7 +229,9 @@ Not every function in NumPy is implemented; contributions are welcome! transpose tri tril + tril_indices triu + triu_indices true_divide vander var diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index d296e3230..215dcc071 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -34,6 +34,7 @@ jax.scipy.special :toctree: _autosummary digamma + entr erf erfc erfinv @@ -42,6 +43,7 @@ jax.scipy.special log_ndtr logit logsumexp + multigammaln ndtr ndtri xlog1py diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 3fc2d795b..3f1c9face 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1745,8 +1745,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): @parameterized.named_parameters( jtu.cases_from_list( - {"testcase_name": "_shape={}_dtype={}_axis={}_ddof={}_keepdims={}" - .format(shape, dtype, axis, ddof, keepdims), + {"testcase_name": + "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}" + .format(shape, dtype, out_dtype, axis, ddof, keepdims), "shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis, "ddof": ddof, "keepdims": keepdims, "rng": rng} for shape in [(5,), (10, 5)]