jax.core: deprecate a number of APIs

This commit is contained in:
Jake VanderPlas 2024-12-10 11:11:32 -08:00
parent 263d4d1462
commit 6541a62099
20 changed files with 84 additions and 40 deletions

View File

@ -12,6 +12,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
## jax 0.4.38
* Deprecations
* a number of APIs in the internal `jax.core` namespace have been deprecated, including
`ClosedJaxpr`, `full_lower`, `Jaxpr`, `JaxprEqn`, `jaxpr_as_fun`, `lattice_join`,
`Literal`, `Primitive`, `raise_to_shaped`, `Token`, `Var`. Most can be replaced by
APIs of the same name in {mod}`jax.extend.core`; see the documentation for
{mod}`jax.extend` for information on the compatibility guarantees of these
semi-public extensions.
## jax 0.4.37 (Dec 9, 2024)
This is a patch release of jax 0.4.36. Only "jax" was released at this version.

View File

@ -25,4 +25,3 @@ some of JAX's (extensible) internals.
autodidax
jep/index
jax_internal_api

18
docs/jax.extend.core.rst Normal file
View File

@ -0,0 +1,18 @@
``jax.extend.core`` module
==========================
.. automodule:: jax.extend.core
.. autosummary::
:toctree: _autosummary
ClosedJaxpr
Jaxpr
JaxprEqn
Literal
Primitive
Token
Var
array_types
jaxpr_as_fun
primitives

View File

@ -11,6 +11,7 @@ Modules
.. toctree::
:maxdepth: 1
jax.extend.core
jax.extend.ffi
jax.extend.linear_util
jax.extend.mlir

View File

@ -1,14 +0,0 @@
Internal API reference
======================
core
----
.. currentmodule:: jax.core
.. automodule:: jax.core
.. autosummary::
:toctree: _autosummary
Jaxpr
ClosedJaxpr

View File

@ -18,8 +18,8 @@ import json
import math
import jax
from jax import core
from jax import dtypes
from jax._src import core
from jax._src import dispatch
from jax._src.custom_partitioning import custom_partitioning
from jax._src.interpreters import batching

View File

@ -14,7 +14,7 @@
import functools
import jax
from jax import core as jax_core
from jax._src import core as jax_core
from jax.interpreters import mlir
from jax.interpreters.mlir import hlo
from jax.interpreters.mlir import ir

View File

@ -21,10 +21,9 @@ import tempfile
from typing import Any
import jax
from jax import core as jax_core
from jax import dtypes
from jax._src import config
from jax._src import core as jax_src_core
from jax._src import core as jax_core
from jax._src import sharding_impls
from jax._src import tpu_custom_call
from jax._src.interpreters import mlir
@ -189,7 +188,7 @@ def pallas_call_tpu_lowering_rule(
# Replace in_avals to physical avals.
# This step is required for mapping logical types to physical types.
# (e.g. PRNG key -> uint32[2])
physical_avals = [jax_src_core.physical_aval(aval) for aval in ctx.avals_in]
physical_avals = [jax_core.physical_aval(aval) for aval in ctx.avals_in]
ctx = ctx.replace(avals_in=physical_avals)
# Booleans are loaded into the kernel as integers.

View File

@ -44,6 +44,7 @@ pytype_strict_library(
deps = [
":lowering",
"//jax",
"//jax:core",
"//jax:mlir",
"//jax:mosaic_gpu",
"//jax/_src/pallas",

View File

@ -23,7 +23,7 @@ from typing import Any
import warnings
import jax
from jax import core as jax_core
from jax._src import core as jax_core
from jax._src.interpreters import mlir
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic_gpu import lowering

View File

@ -76,6 +76,7 @@ pytype_strict_library(
":lowering",
"//jax",
"//jax:config",
"//jax:core",
"//jax:mlir",
"//jax:util",
"//jax/_src/lib",

View File

@ -19,7 +19,7 @@ from __future__ import annotations
import io
from typing import Any
from jax import core as jax_core
import jax._src.core as jax_core
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.pallas import core as pallas_core

View File

@ -19,7 +19,7 @@ from __future__ import annotations
from collections.abc import Sequence
import jax
from jax import core as jax_core
from jax._src import core as jax_core
from jax._src.lib.mlir.dialects import gpu as gpu_dialect
from jax._src.lib.triton import dialect as tt_dialect
from jax._src.pallas.triton import lowering

View File

@ -29,7 +29,7 @@ import time
from typing import Any
import jax
from jax import core
from jax._src import core
from jax._src import config
from jax._src import sharding_impls
from jax._src.interpreters import mlir

View File

@ -23,7 +23,6 @@ from jax._src.core import (
AxisSize as AxisSize,
AxisName as AxisName,
CallPrimitive as CallPrimitive,
ClosedJaxpr as ClosedJaxpr,
ConcretizationTypeError as ConcretizationTypeError,
DShapedArray as DShapedArray,
DropVar as DropVar,
@ -34,23 +33,18 @@ from jax._src.core import (
InDBIdx as InDBIdx,
InconclusiveDimensionOperation as InconclusiveDimensionOperation,
InputType as InputType,
Jaxpr as Jaxpr,
JaxprDebugInfo as JaxprDebugInfo,
JaxprEqn as JaxprEqn,
JaxprPpContext as JaxprPpContext,
JaxprPpSettings as JaxprPpSettings,
JaxprTypeError as JaxprTypeError,
Literal as Literal,
MapPrimitive as MapPrimitive,
nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401
OpaqueTraceState as OpaqueTraceState,
OutDBIdx as OutDBIdx,
OutputType as OutputType,
ParamDict as ParamDict,
Primitive as Primitive,
ShapedArray as ShapedArray,
TRACER_LEAK_DEBUGGER_WARNING as TRACER_LEAK_DEBUGGER_WARNING,
Token as Token,
Trace as Trace,
Tracer as Tracer,
unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401
@ -59,7 +53,6 @@ from jax._src.core import (
unsafe_get_current_trace as unsafe_get_current_trace_DO_NOT_USE, # noqa: F401
UnshapedArray as UnshapedArray,
Value as Value,
Var as Var,
abstract_token as abstract_token,
aval_mapping_handlers as aval_mapping_handlers,
call as call,
@ -78,7 +71,6 @@ from jax._src.core import (
eval_jaxpr as eval_jaxpr,
extend_axis_env_nd as extend_axis_env_nd,
find_top_trace as find_top_trace,
full_lower as full_lower,
gensym as gensym,
get_aval as get_aval,
get_type as get_type,
@ -86,10 +78,8 @@ from jax._src.core import (
is_concrete as is_concrete,
is_constant_dim as is_constant_dim,
is_constant_shape as is_constant_shape,
jaxpr_as_fun as jaxpr_as_fun,
jaxprs_in_params as jaxprs_in_params,
join_effects as join_effects,
lattice_join as lattice_join,
leaked_tracer_error as leaked_tracer_error,
literalable_types as literalable_types,
mapped_aval as mapped_aval,
@ -101,7 +91,6 @@ from jax._src.core import (
no_effects as no_effects,
primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype,
pytype_aval_mappings as pytype_aval_mappings,
raise_to_shaped as raise_to_shaped,
raise_to_shaped_mappings as raise_to_shaped_mappings,
reset_trace_state as reset_trace_state,
set_current_trace as set_current_trace,
@ -124,6 +113,37 @@ from jax._src.core import (
from jax._src import core as _src_core
_deprecations = {
# Added 2024-12-10
"ClosedJaxpr": ("jax.core.ClosedJaxpr is deprecated. Use jax.extend.core.ClosedJaxpr instead, "
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
_src_core.ClosedJaxpr),
"full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.",
_src_core.full_lower),
"Jaxpr": ("jax.core.Jaxpr is deprecated. Use jax.extend.core.Jaxpr instead, "
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
_src_core.Jaxpr),
"JaxprEqn": ("jax.core.JaxprEqn is deprecated. Use jax.extend.core.JaxprEqn instead, "
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
_src_core.JaxprEqn),
"jaxpr_as_fun": ("jax.core.jaxpr_as_fun is deprecated. Use jax.extend.core.jaxpr_as_fun instead, "
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
_src_core.jaxpr_as_fun),
"lattice_join": ("jax.core.lattice_join is deprecated. It is a no-op as of JAX v0.4.36.",
_src_core.lattice_join),
"Literal": ("jax.core.Literal is deprecated. Use jax.extend.core.Literal instead, "
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
_src_core.Literal),
"Primitive": ("jax.core.Primitive is deprecated. Use jax.extend.core.Primitive instead, "
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
_src_core.Primitive),
"raise_to_shaped": ("jax.core.raise_to_shaped is deprecated. It is a no-op as of JAX v0.4.36.",
_src_core.raise_to_shaped),
"Token": ("jax.core.Token is deprecated. Use jax.extend.core.Token instead, "
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
_src_core.Token),
"Var": ("jax.core.Var is deprecated. Use jax.extend.core.Var instead, "
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
_src_core.Var),
# Added 2024-08-14
"check_eqn": ("jax.core.check_eqn is deprecated.", _src_core.check_eqn),
"check_type": ("jax.core.check_type is deprecated.", _src_core.check_type),
@ -152,10 +172,21 @@ _deprecations = {
import typing
if typing.TYPE_CHECKING:
ClosedJaxpr = _src_core.ClosedJaxpr
Jaxpr = _src_core.Jaxpr
JaxprEqn = _src_core.JaxprEqn
Literal = _src_core.Literal
Primitive = _src_core.Primitive
Token = _src_core.Token
Var = _src_core.Var
check_eqn = _src_core.check_eqn
check_type = _src_core.check_type
check_valid_jaxtype = _src_core.check_valid_jaxtype
full_lower = _src_core.full_lower
jaxpr_as_fun = _src_core.jaxpr_as_fun
lattice_join = _src_core.lattice_join
non_negative_dim = _src_core.non_negative_dim
raise_to_shaped = _src_core.raise_to_shaped
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)

View File

@ -77,7 +77,7 @@ if RUNTIME_PATH and RUNTIME_PATH.exists():
os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH)
mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p")
mosaic_gpu_p = jax._src.core.Primitive("mosaic_gpu_p")
mosaic_gpu_p.multiple_results = True

View File

@ -19,7 +19,7 @@ are used internally in GPU translation rules of higher-level primitives.
from functools import partial
from jax import core
from jax._src import core
from jax._src import dispatch
from jax._src.interpreters import mlir
from jax._src.lib import gpu_sparse

View File

@ -14,7 +14,7 @@
"""N:M-sparsity associated primitives."""
from jax import core
from jax._src import core
from jax._src import dispatch
from jax._src.lax.lax import DotDimensionNumbers
from jax._src.lib import gpu_sparse

View File

@ -18,8 +18,8 @@ import operator
import numpy as np
import jax
from jax import core
import jax.numpy as jnp
from jax._src import core
from jax._src import prng
from jax._src import random
from jax._src import test_util as jtu

View File

@ -3837,7 +3837,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
constant_values= ((0.0, 0.0), (0.0, 0.0)))
jaxpr = jax.make_jaxpr(trace_to_jaxpr)(x)
jax.core.jaxpr_as_fun(jaxpr)(x)
jax._src.core.jaxpr_as_fun(jaxpr)(x)
jnp.pad(x, [(0, 1), (0, 0)], mode= 'wrap')
jnp.pad(x, [(0, 1), (0, 0)], mode= 'wrap') # doesn't crash