Merge pull request #17182 from hawkinsp:tril

PiperOrigin-RevId: 558247039
This commit is contained in:
jax authors 2023-08-18 14:29:22 -07:00
commit 97af33c4d1
2 changed files with 30 additions and 2 deletions

View File

@ -41,6 +41,8 @@ Remember to align the itemized text with the first line of an item within a list
* `jax.numpy.NINF` has been deprecated. Use `-jax.numpy.inf` instead.
* `jax.numpy.PZERO` has been deprecated. Use `0.0` instead.
* `jax.numpy.NZERO` has been deprecated. Use `-0.0` instead.
* `jax.scipy.linalg.tril` and `jax.scipy.linalg.triu` have been deprecated,
following SciPy. Use `jax.numpy.tril` and `jax.numpy.triu` instead.
* Internal deprecations:
* The internal utilities `jax.core.is_opaque_dtype` and `jax.core.has_opaque_dtype`

View File

@ -39,10 +39,36 @@ from jax._src.scipy.linalg import (
solve_triangular as solve_triangular,
svd as svd,
toeplitz as toeplitz,
tril as tril,
triu as triu,
)
from jax._src.third_party.scipy.linalg import (
funm as funm,
)
# Deprecations
from jax._src.scipy.linalg import (
tril as _deprecated_tril,
triu as _deprecated_triu,
)
_deprecations = {
# Added August 18, 2023:
"tril": (
"jax.scipy.linalg.tril is deprecated. Use jax.numpy.tril instead.",
_deprecated_tril,
),
"triu": (
"jax.scipy.linalg.triu is deprecated. Use jax.numpy.triu instead.",
_deprecated_triu,
),
}
import typing
if typing.TYPE_CHECKING:
tril = _deprecated_tril
triu = _deprecated_triu
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing