diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b023a229..f050f5409 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,8 @@ Remember to align the itemized text with the first line of an item within a list * `jax.numpy.in1d` has been deprecated. Use `jax.numpy.isin` instead. * `jax.scipy.linalg.tril` and `jax.scipy.linalg.triu` have been deprecated, following SciPy. Use `jax.numpy.tril` and `jax.numpy.triu` instead. + * `jax.lax.prod` has been removed after being deprecated in JAX v0.4.11. + Use the built-in `math.prod` instead. * Internal deprecations: * The internal utilities `jax.core.is_opaque_dtype` and `jax.core.has_opaque_dtype` diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 1c6b24c43..111233d72 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -377,17 +377,3 @@ from jax.lax import linalg as linalg from jax._src.pjit import with_sharding_constraint as with_sharding_constraint from jax._src.pjit import sharding_constraint_p as sharding_constraint_p from jax._src.dispatch import device_put_p as device_put_p - -from math import prod as _prod - -_deprecations = { - # Added May 23, 2023: - "prod": ( - "jax.lax.prod is deprecated. Use math.prod instead.", - _prod, - ), -} - -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr, _prod