diff --git a/CHANGELOG.md b/CHANGELOG.md index e351f64d0..d8bb1478a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.38 +* Deprecations + * a number of APIs in the internal `jax.core` namespace have been deprecated, including + `ClosedJaxpr`, `full_lower`, `Jaxpr`, `JaxprEqn`, `jaxpr_as_fun`, `lattice_join`, + `Literal`, `Primitive`, `raise_to_shaped`, `Token`, `Var`. Most can be replaced by + APIs of the same name in {mod}`jax.extend.core`; see the documentation for + {mod}`jax.extend` for information on the compatibility guarantees of these + semi-public extensions. + ## jax 0.4.37 (Dec 9, 2024) This is a patch release of jax 0.4.36. Only "jax" was released at this version. diff --git a/docs/contributor_guide.rst b/docs/contributor_guide.rst index 55094fc88..f89122f94 100644 --- a/docs/contributor_guide.rst +++ b/docs/contributor_guide.rst @@ -25,4 +25,3 @@ some of JAX's (extensible) internals. autodidax jep/index - jax_internal_api diff --git a/docs/jax.extend.core.rst b/docs/jax.extend.core.rst new file mode 100644 index 000000000..5f3ff0558 --- /dev/null +++ b/docs/jax.extend.core.rst @@ -0,0 +1,18 @@ +``jax.extend.core`` module +========================== + +.. automodule:: jax.extend.core + +.. autosummary:: + :toctree: _autosummary + + ClosedJaxpr + Jaxpr + JaxprEqn + Literal + Primitive + Token + Var + array_types + jaxpr_as_fun + primitives diff --git a/docs/jax.extend.rst b/docs/jax.extend.rst index 9cbee08e8..0d68013c9 100644 --- a/docs/jax.extend.rst +++ b/docs/jax.extend.rst @@ -11,6 +11,7 @@ Modules .. toctree:: :maxdepth: 1 + jax.extend.core jax.extend.ffi jax.extend.linear_util jax.extend.mlir diff --git a/docs/jax_internal_api.rst b/docs/jax_internal_api.rst deleted file mode 100644 index 1ece596d8..000000000 --- a/docs/jax_internal_api.rst +++ /dev/null @@ -1,14 +0,0 @@ -Internal API reference -====================== - -core ----- - -.. currentmodule:: jax.core -.. automodule:: jax.core - -.. autosummary:: - :toctree: _autosummary - - Jaxpr - ClosedJaxpr diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index ef4e33ad0..c45bb8a9e 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -18,8 +18,8 @@ import json import math import jax -from jax import core from jax import dtypes +from jax._src import core from jax._src import dispatch from jax._src.custom_partitioning import custom_partitioning from jax._src.interpreters import batching diff --git a/jax/_src/cudnn/fusion.py b/jax/_src/cudnn/fusion.py index 8a13399e3..f32067246 100644 --- a/jax/_src/cudnn/fusion.py +++ b/jax/_src/cudnn/fusion.py @@ -14,7 +14,7 @@ import functools import jax -from jax import core as jax_core +from jax._src import core as jax_core from jax.interpreters import mlir from jax.interpreters.mlir import hlo from jax.interpreters.mlir import ir diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 4382cea91..ec9500c67 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -21,10 +21,9 @@ import tempfile from typing import Any import jax -from jax import core as jax_core from jax import dtypes from jax._src import config -from jax._src import core as jax_src_core +from jax._src import core as jax_core from jax._src import sharding_impls from jax._src import tpu_custom_call from jax._src.interpreters import mlir @@ -189,7 +188,7 @@ def pallas_call_tpu_lowering_rule( # Replace in_avals to physical avals. # This step is required for mapping logical types to physical types. # (e.g. PRNG key -> uint32[2]) - physical_avals = [jax_src_core.physical_aval(aval) for aval in ctx.avals_in] + physical_avals = [jax_core.physical_aval(aval) for aval in ctx.avals_in] ctx = ctx.replace(avals_in=physical_avals) # Booleans are loaded into the kernel as integers. diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 3d6e82d44..e9461a5ce 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -44,6 +44,7 @@ pytype_strict_library( deps = [ ":lowering", "//jax", + "//jax:core", "//jax:mlir", "//jax:mosaic_gpu", "//jax/_src/pallas", diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 05785cb51..18d8baf6e 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -23,7 +23,7 @@ from typing import Any import warnings import jax -from jax import core as jax_core +from jax._src import core as jax_core from jax._src.interpreters import mlir from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import lowering diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index a9babcba0..84fae3913 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -76,6 +76,7 @@ pytype_strict_library( ":lowering", "//jax", "//jax:config", + "//jax:core", "//jax:mlir", "//jax:util", "//jax/_src/lib", diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index 67b0bd326..1805f8c09 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -19,7 +19,7 @@ from __future__ import annotations import io from typing import Any -from jax import core as jax_core +import jax._src.core as jax_core from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.pallas import core as pallas_core diff --git a/jax/_src/pallas/triton/primitives.py b/jax/_src/pallas/triton/primitives.py index 23fce50dc..b845a4079 100644 --- a/jax/_src/pallas/triton/primitives.py +++ b/jax/_src/pallas/triton/primitives.py @@ -19,7 +19,7 @@ from __future__ import annotations from collections.abc import Sequence import jax -from jax import core as jax_core +from jax._src import core as jax_core from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.triton import dialect as tt_dialect from jax._src.pallas.triton import lowering diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index ccd77af5b..9e54f62d9 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -29,7 +29,7 @@ import time from typing import Any import jax -from jax import core +from jax._src import core from jax._src import config from jax._src import sharding_impls from jax._src.interpreters import mlir diff --git a/jax/core.py b/jax/core.py index 8d7c546f0..4d1742bc2 100644 --- a/jax/core.py +++ b/jax/core.py @@ -23,7 +23,6 @@ from jax._src.core import ( AxisSize as AxisSize, AxisName as AxisName, CallPrimitive as CallPrimitive, - ClosedJaxpr as ClosedJaxpr, ConcretizationTypeError as ConcretizationTypeError, DShapedArray as DShapedArray, DropVar as DropVar, @@ -34,23 +33,18 @@ from jax._src.core import ( InDBIdx as InDBIdx, InconclusiveDimensionOperation as InconclusiveDimensionOperation, InputType as InputType, - Jaxpr as Jaxpr, JaxprDebugInfo as JaxprDebugInfo, - JaxprEqn as JaxprEqn, JaxprPpContext as JaxprPpContext, JaxprPpSettings as JaxprPpSettings, JaxprTypeError as JaxprTypeError, - Literal as Literal, MapPrimitive as MapPrimitive, nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401 OpaqueTraceState as OpaqueTraceState, OutDBIdx as OutDBIdx, OutputType as OutputType, ParamDict as ParamDict, - Primitive as Primitive, ShapedArray as ShapedArray, TRACER_LEAK_DEBUGGER_WARNING as TRACER_LEAK_DEBUGGER_WARNING, - Token as Token, Trace as Trace, Tracer as Tracer, unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401 @@ -59,7 +53,6 @@ from jax._src.core import ( unsafe_get_current_trace as unsafe_get_current_trace_DO_NOT_USE, # noqa: F401 UnshapedArray as UnshapedArray, Value as Value, - Var as Var, abstract_token as abstract_token, aval_mapping_handlers as aval_mapping_handlers, call as call, @@ -78,7 +71,6 @@ from jax._src.core import ( eval_jaxpr as eval_jaxpr, extend_axis_env_nd as extend_axis_env_nd, find_top_trace as find_top_trace, - full_lower as full_lower, gensym as gensym, get_aval as get_aval, get_type as get_type, @@ -86,10 +78,8 @@ from jax._src.core import ( is_concrete as is_concrete, is_constant_dim as is_constant_dim, is_constant_shape as is_constant_shape, - jaxpr_as_fun as jaxpr_as_fun, jaxprs_in_params as jaxprs_in_params, join_effects as join_effects, - lattice_join as lattice_join, leaked_tracer_error as leaked_tracer_error, literalable_types as literalable_types, mapped_aval as mapped_aval, @@ -101,7 +91,6 @@ from jax._src.core import ( no_effects as no_effects, primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype, pytype_aval_mappings as pytype_aval_mappings, - raise_to_shaped as raise_to_shaped, raise_to_shaped_mappings as raise_to_shaped_mappings, reset_trace_state as reset_trace_state, set_current_trace as set_current_trace, @@ -124,6 +113,37 @@ from jax._src.core import ( from jax._src import core as _src_core _deprecations = { + # Added 2024-12-10 + "ClosedJaxpr": ("jax.core.ClosedJaxpr is deprecated. Use jax.extend.core.ClosedJaxpr instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.ClosedJaxpr), + "full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.", + _src_core.full_lower), + "Jaxpr": ("jax.core.Jaxpr is deprecated. Use jax.extend.core.Jaxpr instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.Jaxpr), + "JaxprEqn": ("jax.core.JaxprEqn is deprecated. Use jax.extend.core.JaxprEqn instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.JaxprEqn), + "jaxpr_as_fun": ("jax.core.jaxpr_as_fun is deprecated. Use jax.extend.core.jaxpr_as_fun instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.jaxpr_as_fun), + "lattice_join": ("jax.core.lattice_join is deprecated. It is a no-op as of JAX v0.4.36.", + _src_core.lattice_join), + "Literal": ("jax.core.Literal is deprecated. Use jax.extend.core.Literal instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.Literal), + "Primitive": ("jax.core.Primitive is deprecated. Use jax.extend.core.Primitive instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.Primitive), + "raise_to_shaped": ("jax.core.raise_to_shaped is deprecated. It is a no-op as of JAX v0.4.36.", + _src_core.raise_to_shaped), + "Token": ("jax.core.Token is deprecated. Use jax.extend.core.Token instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.Token), + "Var": ("jax.core.Var is deprecated. Use jax.extend.core.Var instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.Var), # Added 2024-08-14 "check_eqn": ("jax.core.check_eqn is deprecated.", _src_core.check_eqn), "check_type": ("jax.core.check_type is deprecated.", _src_core.check_type), @@ -152,10 +172,21 @@ _deprecations = { import typing if typing.TYPE_CHECKING: + ClosedJaxpr = _src_core.ClosedJaxpr + Jaxpr = _src_core.Jaxpr + JaxprEqn = _src_core.JaxprEqn + Literal = _src_core.Literal + Primitive = _src_core.Primitive + Token = _src_core.Token + Var = _src_core.Var check_eqn = _src_core.check_eqn check_type = _src_core.check_type check_valid_jaxtype = _src_core.check_valid_jaxtype + full_lower = _src_core.full_lower + jaxpr_as_fun = _src_core.jaxpr_as_fun + lattice_join = _src_core.lattice_join non_negative_dim = _src_core.non_negative_dim + raise_to_shaped = _src_core.raise_to_shaped else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index d8774e932..b03c3a5b5 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -77,7 +77,7 @@ if RUNTIME_PATH and RUNTIME_PATH.exists(): os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH) -mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p") +mosaic_gpu_p = jax._src.core.Primitive("mosaic_gpu_p") mosaic_gpu_p.multiple_results = True diff --git a/jax/experimental/sparse/_lowerings.py b/jax/experimental/sparse/_lowerings.py index f4fe0b904..6962ef78b 100644 --- a/jax/experimental/sparse/_lowerings.py +++ b/jax/experimental/sparse/_lowerings.py @@ -19,7 +19,7 @@ are used internally in GPU translation rules of higher-level primitives. from functools import partial -from jax import core +from jax._src import core from jax._src import dispatch from jax._src.interpreters import mlir from jax._src.lib import gpu_sparse diff --git a/jax/experimental/sparse/nm.py b/jax/experimental/sparse/nm.py index 6c827325b..f9d28f5ff 100644 --- a/jax/experimental/sparse/nm.py +++ b/jax/experimental/sparse/nm.py @@ -14,7 +14,7 @@ """N:M-sparsity associated primitives.""" -from jax import core +from jax._src import core from jax._src import dispatch from jax._src.lax.lax import DotDimensionNumbers from jax._src.lib import gpu_sparse diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index 286088eeb..3364c9be9 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -18,8 +18,8 @@ import operator import numpy as np import jax -from jax import core import jax.numpy as jnp +from jax._src import core from jax._src import prng from jax._src import random from jax._src import test_util as jtu diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 5ca87aae6..5bb090435 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3837,7 +3837,7 @@ class ArrayPjitTest(jtu.JaxTestCase): constant_values= ((0.0, 0.0), (0.0, 0.0))) jaxpr = jax.make_jaxpr(trace_to_jaxpr)(x) - jax.core.jaxpr_as_fun(jaxpr)(x) + jax._src.core.jaxpr_as_fun(jaxpr)(x) jnp.pad(x, [(0, 1), (0, 0)], mode= 'wrap') jnp.pad(x, [(0, 1), (0, 0)], mode= 'wrap') # doesn't crash