mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add jax.errors.JaxRuntimeError as a public alias for the XlaRuntimeError class.
Deprecate jax.lib.xla_client.XlaRuntimeError, which is not a public API. PiperOrigin-RevId: 679163106
This commit is contained in:
parent
5cef547eab
commit
7b53c2f39d
@ -15,6 +15,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* New Functionality
|
||||
* This release includes wheels for Python 3.13. Free-threading mode is not yet
|
||||
supported.
|
||||
* `jax.errors.JaxRuntimeError` has been added as a public alias for the
|
||||
formerly private `XlaRuntimeError` type.
|
||||
|
||||
* Breaking changes
|
||||
* `jax_pmap_no_rank_reduction` flag is set to `True` by default.
|
||||
@ -32,6 +34,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
in an error.
|
||||
* Internal pretty-printing tools `jax.core.pp_*` have been removed, after
|
||||
being deprecated in JAX v0.4.30.
|
||||
* `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use
|
||||
`jax.errors.JaxRuntimeError` instead.
|
||||
|
||||
* Deletion:
|
||||
* `jax.xla_computation` is deleted. It's been 3 months since it's deprecation
|
||||
|
@ -9,6 +9,7 @@ along with representative examples of how one might fix them.
|
||||
.. currentmodule:: jax.errors
|
||||
.. autoclass:: ConcretizationTypeError
|
||||
.. autoclass:: KeyReuseError
|
||||
.. autoclass:: JaxRuntimeError
|
||||
.. autoclass:: NonConcreteBooleanIndexError
|
||||
.. autoclass:: TracerArrayConversionError
|
||||
.. autoclass:: TracerBoolConversionError
|
||||
|
@ -26,4 +26,9 @@ from jax._src.errors import (
|
||||
UnexpectedTracerError as UnexpectedTracerError,
|
||||
KeyReuseError as KeyReuseError,
|
||||
)
|
||||
|
||||
from jax._src.lib import xla_client as _xc
|
||||
JaxRuntimeError = _xc.XlaRuntimeError
|
||||
del _xc
|
||||
|
||||
from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback
|
||||
|
@ -37,26 +37,36 @@ Shape = _xc.Shape
|
||||
Traceback = _xc.Traceback
|
||||
XlaBuilder = _xc.XlaBuilder
|
||||
XlaComputation = _xc.XlaComputation
|
||||
XlaRuntimeError = _xc.XlaRuntimeError
|
||||
|
||||
_deprecations = {
|
||||
# Added Aug 5 2024
|
||||
"_xla" : (
|
||||
"jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.",
|
||||
_xc._xla
|
||||
),
|
||||
"bfloat16" : (
|
||||
"jax.lib.xla_client.bfloat16 is deprecated; use ml_dtypes.bfloat16.",
|
||||
_xc.bfloat16
|
||||
),
|
||||
# Added Aug 5 2024
|
||||
"_xla": (
|
||||
"jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.",
|
||||
_xc._xla,
|
||||
),
|
||||
"bfloat16": (
|
||||
"jax.lib.xla_client.bfloat16 is deprecated; use ml_dtypes.bfloat16.",
|
||||
_xc.bfloat16,
|
||||
),
|
||||
# Added Sep 26 2024
|
||||
"XlaRuntimeError": (
|
||||
(
|
||||
"jax.lib.xla_client.XlaRuntimeError is deprecated; use"
|
||||
" jax.errors.JaxRuntimeError."
|
||||
),
|
||||
_xc.XlaRuntimeError,
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
|
||||
if _typing.TYPE_CHECKING:
|
||||
_xla = _xc._xla
|
||||
bfloat16 = _xc.bfloat16
|
||||
XlaRuntimeError = _xc.XlaRuntimeError
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del _typing
|
||||
|
@ -394,11 +394,19 @@ class UserContextTracebackTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
class CustomErrorsTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(
|
||||
errorclass=[
|
||||
errorclass for errorclass in dir(jax.errors)
|
||||
if errorclass.endswith('Error') and errorclass not in ['JaxIndexError', 'JAXTypeError']
|
||||
],
|
||||
errorclass=[
|
||||
errorclass
|
||||
for errorclass in dir(jax.errors)
|
||||
if errorclass.endswith('Error')
|
||||
and errorclass
|
||||
not in [
|
||||
'JaxIndexError',
|
||||
'JAXTypeError',
|
||||
'JaxRuntimeError',
|
||||
]
|
||||
],
|
||||
)
|
||||
def testErrorsURL(self, errorclass):
|
||||
class FakeTracer(core.Tracer):
|
||||
|
@ -31,7 +31,7 @@ class PackageStructureTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
# TODO(jakevdp): expand test to other public modules.
|
||||
_mod("jax.errors"),
|
||||
_mod("jax.errors", exclude=["JaxRuntimeError"]),
|
||||
_mod("jax.nn.initializers"),
|
||||
_mod(
|
||||
"jax.tree_util",
|
||||
|
Loading…
x
Reference in New Issue
Block a user