mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Fix test breakage at head.
Add new numpy/scipy functions to documentation.
This commit is contained in:
parent
1dfdd8dafe
commit
b45d1ec1dd
@ -162,6 +162,7 @@ Not every function in NumPy is implemented; contributions are welcome!
|
||||
max
|
||||
maximum
|
||||
mean
|
||||
median
|
||||
meshgrid
|
||||
min
|
||||
minimum
|
||||
@ -228,7 +229,9 @@ Not every function in NumPy is implemented; contributions are welcome!
|
||||
transpose
|
||||
tri
|
||||
tril
|
||||
tril_indices
|
||||
triu
|
||||
triu_indices
|
||||
true_divide
|
||||
vander
|
||||
var
|
||||
|
@ -34,6 +34,7 @@ jax.scipy.special
|
||||
:toctree: _autosummary
|
||||
|
||||
digamma
|
||||
entr
|
||||
erf
|
||||
erfc
|
||||
erfinv
|
||||
@ -42,6 +43,7 @@ jax.scipy.special
|
||||
log_ndtr
|
||||
logit
|
||||
logsumexp
|
||||
multigammaln
|
||||
ndtr
|
||||
ndtri
|
||||
xlog1py
|
||||
|
@ -1745,8 +1745,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_dtype={}_axis={}_ddof={}_keepdims={}"
|
||||
.format(shape, dtype, axis, ddof, keepdims),
|
||||
{"testcase_name":
|
||||
"_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}"
|
||||
.format(shape, dtype, out_dtype, axis, ddof, keepdims),
|
||||
"shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis,
|
||||
"ddof": ddof, "keepdims": keepdims, "rng": rng}
|
||||
for shape in [(5,), (10, 5)]
|
||||
|
Loading…
x
Reference in New Issue
Block a user