mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #17182 from hawkinsp:tril
PiperOrigin-RevId: 558247039
This commit is contained in:
commit
97af33c4d1
@ -41,6 +41,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* `jax.numpy.NINF` has been deprecated. Use `-jax.numpy.inf` instead.
|
||||
* `jax.numpy.PZERO` has been deprecated. Use `0.0` instead.
|
||||
* `jax.numpy.NZERO` has been deprecated. Use `-0.0` instead.
|
||||
* `jax.scipy.linalg.tril` and `jax.scipy.linalg.triu` have been deprecated,
|
||||
following SciPy. Use `jax.numpy.tril` and `jax.numpy.triu` instead.
|
||||
|
||||
* Internal deprecations:
|
||||
* The internal utilities `jax.core.is_opaque_dtype` and `jax.core.has_opaque_dtype`
|
||||
|
@ -39,10 +39,36 @@ from jax._src.scipy.linalg import (
|
||||
solve_triangular as solve_triangular,
|
||||
svd as svd,
|
||||
toeplitz as toeplitz,
|
||||
tril as tril,
|
||||
triu as triu,
|
||||
)
|
||||
|
||||
from jax._src.third_party.scipy.linalg import (
|
||||
funm as funm,
|
||||
)
|
||||
|
||||
# Deprecations
|
||||
from jax._src.scipy.linalg import (
|
||||
tril as _deprecated_tril,
|
||||
triu as _deprecated_triu,
|
||||
)
|
||||
|
||||
_deprecations = {
|
||||
# Added August 18, 2023:
|
||||
"tril": (
|
||||
"jax.scipy.linalg.tril is deprecated. Use jax.numpy.tril instead.",
|
||||
_deprecated_tril,
|
||||
),
|
||||
"triu": (
|
||||
"jax.scipy.linalg.triu is deprecated. Use jax.numpy.triu instead.",
|
||||
_deprecated_triu,
|
||||
),
|
||||
}
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
tril = _deprecated_tril
|
||||
triu = _deprecated_triu
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del typing
|
||||
|
Loading…
x
Reference in New Issue
Block a user