temporarily un-deprecate several jax.core APIs.

These were causing excessive log-spam for some users; I'll work to migrate
them to jax.extend before re-deprecating these.
This commit is contained in:
Jake VanderPlas 2024-12-12 13:15:58 -08:00
parent 97459ba9aa
commit d3406768f0
2 changed files with 40 additions and 34 deletions

View File

@ -13,12 +13,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
## jax 0.4.38
* Deprecations
* a number of APIs in the internal `jax.core` namespace have been deprecated, including
`ClosedJaxpr`, `full_lower`, `Jaxpr`, `JaxprEqn`, `jaxpr_as_fun`, `lattice_join`,
`Literal`, `Primitive`, `raise_to_shaped`, `Token`, `Var`. Most can be replaced by
APIs of the same name in {mod}`jax.extend.core`; see the documentation for
{mod}`jax.extend` for information on the compatibility guarantees of these
semi-public extensions.
* a number of APIs in the internal `jax.core` namespace have been deprecated.
Most were no-ops, were little-used, or can be replaced by APIs of the same
name in {mod}`jax.extend.core`; see the documentation for {mod}`jax.extend`
for information on the compatibility guarantees of these semi-public extensions.
* Several previously-deprecated APIs have been removed, including:
* from {mod}`jax.core`: `check_eqn`, `check_type`, `check_valid_jaxtype`, and
`non_negative_dim`.

View File

@ -20,22 +20,29 @@ from jax._src.core import (
AbstractValue as AbstractValue,
Atom as Atom,
CallPrimitive as CallPrimitive,
ClosedJaxpr as ClosedJaxpr,
DShapedArray as DShapedArray,
DropVar as DropVar,
Effect as Effect,
Effects as Effects,
get_opaque_trace_state as get_opaque_trace_state,
InconclusiveDimensionOperation as InconclusiveDimensionOperation,
Jaxpr as Jaxpr,
JaxprDebugInfo as JaxprDebugInfo,
JaxprEqn as JaxprEqn,
JaxprPpContext as JaxprPpContext,
JaxprPpSettings as JaxprPpSettings,
JaxprTypeError as JaxprTypeError,
nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401
Literal as Literal,
OutputType as OutputType,
ParamDict as ParamDict,
Primitive as Primitive,
ShapedArray as ShapedArray,
Token as Token,
Trace as Trace,
Tracer as Tracer,
Var as Var,
unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401
unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE, # noqa: F401
unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE, # noqa: F401
@ -81,6 +88,28 @@ from jax._src.core import (
from jax._src import core as _src_core
_deprecations = {
# TODO(jakevdp): re-deprecate these after migrating some downstream uses.
# "ClosedJaxpr": ("jax.core.ClosedJaxpr is deprecated. Use jax.extend.core.ClosedJaxpr instead, "
# "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
# _src_core.ClosedJaxpr),
# "Jaxpr": ("jax.core.Jaxpr is deprecated. Use jax.extend.core.Jaxpr instead, "
# "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
# _src_core.Jaxpr),
# "JaxprEqn": ("jax.core.JaxprEqn is deprecated. Use jax.extend.core.JaxprEqn instead, "
# "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
# _src_core.JaxprEqn),
# "Literal": ("jax.core.Literal is deprecated. Use jax.extend.core.Literal instead, "
# "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
# _src_core.Literal),
# "Primitive": ("jax.core.Primitive is deprecated. Use jax.extend.core.Primitive instead, "
# "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
# _src_core.Primitive),
# "Token": ("jax.core.Token is deprecated. Use jax.extend.core.Token instead, "
# "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
# _src_core.Token),
# "Var": ("jax.core.Var is deprecated. Use jax.extend.core.Var instead, "
# "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
# _src_core.Var),
# Added 2024-12-11
"axis_frame": ("jax.core.axis_frame is deprecated.", _src_core.axis_frame),
"AxisName": ("jax.core.AxisName is deprecated.", _src_core.AxisName),
@ -129,36 +158,15 @@ _deprecations = {
"used_axis_names_jaxpr": ("jax.core.used_axis_names_jaxpr is deprecated.",
_src_core.used_axis_names_jaxpr),
# Added 2024-12-10
"ClosedJaxpr": ("jax.core.ClosedJaxpr is deprecated. Use jax.extend.core.ClosedJaxpr instead, "
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
_src_core.ClosedJaxpr),
"full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.",
_src_core.full_lower),
"Jaxpr": ("jax.core.Jaxpr is deprecated. Use jax.extend.core.Jaxpr instead, "
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
_src_core.Jaxpr),
"JaxprEqn": ("jax.core.JaxprEqn is deprecated. Use jax.extend.core.JaxprEqn instead, "
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
_src_core.JaxprEqn),
"jaxpr_as_fun": ("jax.core.jaxpr_as_fun is deprecated. Use jax.extend.core.jaxpr_as_fun instead, "
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
_src_core.jaxpr_as_fun),
"lattice_join": ("jax.core.lattice_join is deprecated. It is a no-op as of JAX v0.4.36.",
_src_core.lattice_join),
"Literal": ("jax.core.Literal is deprecated. Use jax.extend.core.Literal instead, "
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
_src_core.Literal),
"Primitive": ("jax.core.Primitive is deprecated. Use jax.extend.core.Primitive instead, "
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
_src_core.Primitive),
"raise_to_shaped": ("jax.core.raise_to_shaped is deprecated. It is a no-op as of JAX v0.4.36.",
_src_core.raise_to_shaped),
"Token": ("jax.core.Token is deprecated. Use jax.extend.core.Token instead, "
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
_src_core.Token),
"Var": ("jax.core.Var is deprecated. Use jax.extend.core.Var instead, "
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
_src_core.Var),
# Finalized 2024-12-11; remove after 2025-3-11
"check_eqn": ("jax.core.check_eqn was removed in JAX v0.4.38.", None),
"check_type": ("jax.core.check_type was removed in JAX v0.4.38.", None),
@ -188,21 +196,21 @@ import typing
if typing.TYPE_CHECKING:
AxisName = _src_core.AxisName
AxisSize = _src_core.AxisSize
ClosedJaxpr = _src_core.ClosedJaxpr
# ClosedJaxpr = _src_core.ClosedJaxpr
ConcretizationTypeError = _src_core.ConcretizationTypeError
EvalTrace = _src_core.EvalTrace
InDBIdx = _src_core.InDBIdx
InputType = _src_core.InputType
Jaxpr = _src_core.Jaxpr
JaxprEqn = _src_core.JaxprEqn
Literal = _src_core.Literal
# Jaxpr = _src_core.Jaxpr
# JaxprEqn = _src_core.JaxprEqn
# Literal = _src_core.Literal
MapPrimitive = _src_core.MapPrimitive
OpaqueTraceState = _src_core.OpaqueTraceState
OutDBIdx = _src_core.OutDBIdx
Primitive = _src_core.Primitive
Token = _src_core.Token
# Primitive = _src_core.Primitive
# Token = _src_core.Token
TRACER_LEAK_DEBUGGER_WARNING = _src_core.TRACER_LEAK_DEBUGGER_WARNING
Var = _src_core.Var
# Var = _src_core.Var
axis_frame = _src_core.axis_frame
call_p = _src_core.call_p
closed_call_p = _src_core.closed_call_p