Add jax.errors.JaxRuntimeError as a public alias for the XlaRuntimeError class.

Deprecate jax.lib.xla_client.XlaRuntimeError, which is not a public API.

PiperOrigin-RevId: 679163106
This commit is contained in:
Peter Hawkins 2024-09-26 08:38:46 -07:00 committed by jax authors
parent 5cef547eab
commit 7b53c2f39d
6 changed files with 43 additions and 15 deletions

View File

@ -15,6 +15,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* New Functionality
* This release includes wheels for Python 3.13. Free-threading mode is not yet
supported.
* `jax.errors.JaxRuntimeError` has been added as a public alias for the
formerly private `XlaRuntimeError` type.
* Breaking changes
* `jax_pmap_no_rank_reduction` flag is set to `True` by default.
@ -32,6 +34,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
in an error.
* Internal pretty-printing tools `jax.core.pp_*` have been removed, after
being deprecated in JAX v0.4.30.
* `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use
`jax.errors.JaxRuntimeError` instead.
* Deletion:
* `jax.xla_computation` is deleted. It's been 3 months since it's deprecation

View File

@ -9,6 +9,7 @@ along with representative examples of how one might fix them.
.. currentmodule:: jax.errors
.. autoclass:: ConcretizationTypeError
.. autoclass:: KeyReuseError
.. autoclass:: JaxRuntimeError
.. autoclass:: NonConcreteBooleanIndexError
.. autoclass:: TracerArrayConversionError
.. autoclass:: TracerBoolConversionError

View File

@ -26,4 +26,9 @@ from jax._src.errors import (
UnexpectedTracerError as UnexpectedTracerError,
KeyReuseError as KeyReuseError,
)
from jax._src.lib import xla_client as _xc
JaxRuntimeError = _xc.XlaRuntimeError
del _xc
from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback

View File

@ -37,26 +37,36 @@ Shape = _xc.Shape
Traceback = _xc.Traceback
XlaBuilder = _xc.XlaBuilder
XlaComputation = _xc.XlaComputation
XlaRuntimeError = _xc.XlaRuntimeError
_deprecations = {
# Added Aug 5 2024
"_xla" : (
"jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.",
_xc._xla
),
"bfloat16" : (
"jax.lib.xla_client.bfloat16 is deprecated; use ml_dtypes.bfloat16.",
_xc.bfloat16
),
# Added Aug 5 2024
"_xla": (
"jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.",
_xc._xla,
),
"bfloat16": (
"jax.lib.xla_client.bfloat16 is deprecated; use ml_dtypes.bfloat16.",
_xc.bfloat16,
),
# Added Sep 26 2024
"XlaRuntimeError": (
(
"jax.lib.xla_client.XlaRuntimeError is deprecated; use"
" jax.errors.JaxRuntimeError."
),
_xc.XlaRuntimeError,
),
}
import typing as _typing
if _typing.TYPE_CHECKING:
_xla = _xc._xla
bfloat16 = _xc.bfloat16
XlaRuntimeError = _xc.XlaRuntimeError
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing

View File

@ -394,11 +394,19 @@ class UserContextTracebackTest(jtu.JaxTestCase):
class CustomErrorsTest(jtu.JaxTestCase):
@jtu.sample_product(
errorclass=[
errorclass for errorclass in dir(jax.errors)
if errorclass.endswith('Error') and errorclass not in ['JaxIndexError', 'JAXTypeError']
],
errorclass=[
errorclass
for errorclass in dir(jax.errors)
if errorclass.endswith('Error')
and errorclass
not in [
'JaxIndexError',
'JAXTypeError',
'JaxRuntimeError',
]
],
)
def testErrorsURL(self, errorclass):
class FakeTracer(core.Tracer):

View File

@ -31,7 +31,7 @@ class PackageStructureTest(jtu.JaxTestCase):
@parameterized.parameters([
# TODO(jakevdp): expand test to other public modules.
_mod("jax.errors"),
_mod("jax.errors", exclude=["JaxRuntimeError"]),
_mod("jax.nn.initializers"),
_mod(
"jax.tree_util",