mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
remove deprecated jax.lax.prod function
PiperOrigin-RevId: 559787522
This commit is contained in:
parent
26643aa96e
commit
665b176c2c
@ -46,6 +46,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* `jax.numpy.in1d` has been deprecated. Use `jax.numpy.isin` instead.
|
||||
* `jax.scipy.linalg.tril` and `jax.scipy.linalg.triu` have been deprecated,
|
||||
following SciPy. Use `jax.numpy.tril` and `jax.numpy.triu` instead.
|
||||
* `jax.lax.prod` has been removed after being deprecated in JAX v0.4.11.
|
||||
Use the built-in `math.prod` instead.
|
||||
|
||||
* Internal deprecations:
|
||||
* The internal utilities `jax.core.is_opaque_dtype` and `jax.core.has_opaque_dtype`
|
||||
|
@ -377,17 +377,3 @@ from jax.lax import linalg as linalg
|
||||
from jax._src.pjit import with_sharding_constraint as with_sharding_constraint
|
||||
from jax._src.pjit import sharding_constraint_p as sharding_constraint_p
|
||||
from jax._src.dispatch import device_put_p as device_put_p
|
||||
|
||||
from math import prod as _prod
|
||||
|
||||
_deprecations = {
|
||||
# Added May 23, 2023:
|
||||
"prod": (
|
||||
"jax.lax.prod is deprecated. Use math.prod instead.",
|
||||
_prod,
|
||||
),
|
||||
}
|
||||
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr, _prod
|
||||
|
Loading…
x
Reference in New Issue
Block a user