mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
jax.core: deprecate a number of APIs
This commit is contained in:
parent
263d4d1462
commit
6541a62099
@ -12,6 +12,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
|
||||
## jax 0.4.38
|
||||
|
||||
* Deprecations
|
||||
* a number of APIs in the internal `jax.core` namespace have been deprecated, including
|
||||
`ClosedJaxpr`, `full_lower`, `Jaxpr`, `JaxprEqn`, `jaxpr_as_fun`, `lattice_join`,
|
||||
`Literal`, `Primitive`, `raise_to_shaped`, `Token`, `Var`. Most can be replaced by
|
||||
APIs of the same name in {mod}`jax.extend.core`; see the documentation for
|
||||
{mod}`jax.extend` for information on the compatibility guarantees of these
|
||||
semi-public extensions.
|
||||
|
||||
## jax 0.4.37 (Dec 9, 2024)
|
||||
|
||||
This is a patch release of jax 0.4.36. Only "jax" was released at this version.
|
||||
|
@ -25,4 +25,3 @@ some of JAX's (extensible) internals.
|
||||
|
||||
autodidax
|
||||
jep/index
|
||||
jax_internal_api
|
||||
|
18
docs/jax.extend.core.rst
Normal file
18
docs/jax.extend.core.rst
Normal file
@ -0,0 +1,18 @@
|
||||
``jax.extend.core`` module
|
||||
==========================
|
||||
|
||||
.. automodule:: jax.extend.core
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
ClosedJaxpr
|
||||
Jaxpr
|
||||
JaxprEqn
|
||||
Literal
|
||||
Primitive
|
||||
Token
|
||||
Var
|
||||
array_types
|
||||
jaxpr_as_fun
|
||||
primitives
|
@ -11,6 +11,7 @@ Modules
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
jax.extend.core
|
||||
jax.extend.ffi
|
||||
jax.extend.linear_util
|
||||
jax.extend.mlir
|
||||
|
@ -1,14 +0,0 @@
|
||||
Internal API reference
|
||||
======================
|
||||
|
||||
core
|
||||
----
|
||||
|
||||
.. currentmodule:: jax.core
|
||||
.. automodule:: jax.core
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
Jaxpr
|
||||
ClosedJaxpr
|
@ -18,8 +18,8 @@ import json
|
||||
import math
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import dtypes
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src.custom_partitioning import custom_partitioning
|
||||
from jax._src.interpreters import batching
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
import functools
|
||||
import jax
|
||||
from jax import core as jax_core
|
||||
from jax._src import core as jax_core
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters.mlir import hlo
|
||||
from jax.interpreters.mlir import ir
|
||||
|
@ -21,10 +21,9 @@ import tempfile
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax import core as jax_core
|
||||
from jax import dtypes
|
||||
from jax._src import config
|
||||
from jax._src import core as jax_src_core
|
||||
from jax._src import core as jax_core
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import tpu_custom_call
|
||||
from jax._src.interpreters import mlir
|
||||
@ -189,7 +188,7 @@ def pallas_call_tpu_lowering_rule(
|
||||
# Replace in_avals to physical avals.
|
||||
# This step is required for mapping logical types to physical types.
|
||||
# (e.g. PRNG key -> uint32[2])
|
||||
physical_avals = [jax_src_core.physical_aval(aval) for aval in ctx.avals_in]
|
||||
physical_avals = [jax_core.physical_aval(aval) for aval in ctx.avals_in]
|
||||
ctx = ctx.replace(avals_in=physical_avals)
|
||||
|
||||
# Booleans are loaded into the kernel as integers.
|
||||
|
@ -44,6 +44,7 @@ pytype_strict_library(
|
||||
deps = [
|
||||
":lowering",
|
||||
"//jax",
|
||||
"//jax:core",
|
||||
"//jax:mlir",
|
||||
"//jax:mosaic_gpu",
|
||||
"//jax/_src/pallas",
|
||||
|
@ -23,7 +23,7 @@ from typing import Any
|
||||
import warnings
|
||||
|
||||
import jax
|
||||
from jax import core as jax_core
|
||||
from jax._src import core as jax_core
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas.mosaic_gpu import lowering
|
||||
|
@ -76,6 +76,7 @@ pytype_strict_library(
|
||||
":lowering",
|
||||
"//jax",
|
||||
"//jax:config",
|
||||
"//jax:core",
|
||||
"//jax:mlir",
|
||||
"//jax:util",
|
||||
"//jax/_src/lib",
|
||||
|
@ -19,7 +19,7 @@ from __future__ import annotations
|
||||
import io
|
||||
from typing import Any
|
||||
|
||||
from jax import core as jax_core
|
||||
import jax._src.core as jax_core
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.pallas import core as pallas_core
|
||||
|
@ -19,7 +19,7 @@ from __future__ import annotations
|
||||
from collections.abc import Sequence
|
||||
|
||||
import jax
|
||||
from jax import core as jax_core
|
||||
from jax._src import core as jax_core
|
||||
from jax._src.lib.mlir.dialects import gpu as gpu_dialect
|
||||
from jax._src.lib.triton import dialect as tt_dialect
|
||||
from jax._src.pallas.triton import lowering
|
||||
|
@ -29,7 +29,7 @@ import time
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax._src import config
|
||||
from jax._src import sharding_impls
|
||||
from jax._src.interpreters import mlir
|
||||
|
53
jax/core.py
53
jax/core.py
@ -23,7 +23,6 @@ from jax._src.core import (
|
||||
AxisSize as AxisSize,
|
||||
AxisName as AxisName,
|
||||
CallPrimitive as CallPrimitive,
|
||||
ClosedJaxpr as ClosedJaxpr,
|
||||
ConcretizationTypeError as ConcretizationTypeError,
|
||||
DShapedArray as DShapedArray,
|
||||
DropVar as DropVar,
|
||||
@ -34,23 +33,18 @@ from jax._src.core import (
|
||||
InDBIdx as InDBIdx,
|
||||
InconclusiveDimensionOperation as InconclusiveDimensionOperation,
|
||||
InputType as InputType,
|
||||
Jaxpr as Jaxpr,
|
||||
JaxprDebugInfo as JaxprDebugInfo,
|
||||
JaxprEqn as JaxprEqn,
|
||||
JaxprPpContext as JaxprPpContext,
|
||||
JaxprPpSettings as JaxprPpSettings,
|
||||
JaxprTypeError as JaxprTypeError,
|
||||
Literal as Literal,
|
||||
MapPrimitive as MapPrimitive,
|
||||
nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401
|
||||
OpaqueTraceState as OpaqueTraceState,
|
||||
OutDBIdx as OutDBIdx,
|
||||
OutputType as OutputType,
|
||||
ParamDict as ParamDict,
|
||||
Primitive as Primitive,
|
||||
ShapedArray as ShapedArray,
|
||||
TRACER_LEAK_DEBUGGER_WARNING as TRACER_LEAK_DEBUGGER_WARNING,
|
||||
Token as Token,
|
||||
Trace as Trace,
|
||||
Tracer as Tracer,
|
||||
unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401
|
||||
@ -59,7 +53,6 @@ from jax._src.core import (
|
||||
unsafe_get_current_trace as unsafe_get_current_trace_DO_NOT_USE, # noqa: F401
|
||||
UnshapedArray as UnshapedArray,
|
||||
Value as Value,
|
||||
Var as Var,
|
||||
abstract_token as abstract_token,
|
||||
aval_mapping_handlers as aval_mapping_handlers,
|
||||
call as call,
|
||||
@ -78,7 +71,6 @@ from jax._src.core import (
|
||||
eval_jaxpr as eval_jaxpr,
|
||||
extend_axis_env_nd as extend_axis_env_nd,
|
||||
find_top_trace as find_top_trace,
|
||||
full_lower as full_lower,
|
||||
gensym as gensym,
|
||||
get_aval as get_aval,
|
||||
get_type as get_type,
|
||||
@ -86,10 +78,8 @@ from jax._src.core import (
|
||||
is_concrete as is_concrete,
|
||||
is_constant_dim as is_constant_dim,
|
||||
is_constant_shape as is_constant_shape,
|
||||
jaxpr_as_fun as jaxpr_as_fun,
|
||||
jaxprs_in_params as jaxprs_in_params,
|
||||
join_effects as join_effects,
|
||||
lattice_join as lattice_join,
|
||||
leaked_tracer_error as leaked_tracer_error,
|
||||
literalable_types as literalable_types,
|
||||
mapped_aval as mapped_aval,
|
||||
@ -101,7 +91,6 @@ from jax._src.core import (
|
||||
no_effects as no_effects,
|
||||
primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype,
|
||||
pytype_aval_mappings as pytype_aval_mappings,
|
||||
raise_to_shaped as raise_to_shaped,
|
||||
raise_to_shaped_mappings as raise_to_shaped_mappings,
|
||||
reset_trace_state as reset_trace_state,
|
||||
set_current_trace as set_current_trace,
|
||||
@ -124,6 +113,37 @@ from jax._src.core import (
|
||||
|
||||
from jax._src import core as _src_core
|
||||
_deprecations = {
|
||||
# Added 2024-12-10
|
||||
"ClosedJaxpr": ("jax.core.ClosedJaxpr is deprecated. Use jax.extend.core.ClosedJaxpr instead, "
|
||||
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
|
||||
_src_core.ClosedJaxpr),
|
||||
"full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.",
|
||||
_src_core.full_lower),
|
||||
"Jaxpr": ("jax.core.Jaxpr is deprecated. Use jax.extend.core.Jaxpr instead, "
|
||||
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
|
||||
_src_core.Jaxpr),
|
||||
"JaxprEqn": ("jax.core.JaxprEqn is deprecated. Use jax.extend.core.JaxprEqn instead, "
|
||||
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
|
||||
_src_core.JaxprEqn),
|
||||
"jaxpr_as_fun": ("jax.core.jaxpr_as_fun is deprecated. Use jax.extend.core.jaxpr_as_fun instead, "
|
||||
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
|
||||
_src_core.jaxpr_as_fun),
|
||||
"lattice_join": ("jax.core.lattice_join is deprecated. It is a no-op as of JAX v0.4.36.",
|
||||
_src_core.lattice_join),
|
||||
"Literal": ("jax.core.Literal is deprecated. Use jax.extend.core.Literal instead, "
|
||||
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
|
||||
_src_core.Literal),
|
||||
"Primitive": ("jax.core.Primitive is deprecated. Use jax.extend.core.Primitive instead, "
|
||||
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
|
||||
_src_core.Primitive),
|
||||
"raise_to_shaped": ("jax.core.raise_to_shaped is deprecated. It is a no-op as of JAX v0.4.36.",
|
||||
_src_core.raise_to_shaped),
|
||||
"Token": ("jax.core.Token is deprecated. Use jax.extend.core.Token instead, "
|
||||
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
|
||||
_src_core.Token),
|
||||
"Var": ("jax.core.Var is deprecated. Use jax.extend.core.Var instead, "
|
||||
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
|
||||
_src_core.Var),
|
||||
# Added 2024-08-14
|
||||
"check_eqn": ("jax.core.check_eqn is deprecated.", _src_core.check_eqn),
|
||||
"check_type": ("jax.core.check_type is deprecated.", _src_core.check_type),
|
||||
@ -152,10 +172,21 @@ _deprecations = {
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
ClosedJaxpr = _src_core.ClosedJaxpr
|
||||
Jaxpr = _src_core.Jaxpr
|
||||
JaxprEqn = _src_core.JaxprEqn
|
||||
Literal = _src_core.Literal
|
||||
Primitive = _src_core.Primitive
|
||||
Token = _src_core.Token
|
||||
Var = _src_core.Var
|
||||
check_eqn = _src_core.check_eqn
|
||||
check_type = _src_core.check_type
|
||||
check_valid_jaxtype = _src_core.check_valid_jaxtype
|
||||
full_lower = _src_core.full_lower
|
||||
jaxpr_as_fun = _src_core.jaxpr_as_fun
|
||||
lattice_join = _src_core.lattice_join
|
||||
non_negative_dim = _src_core.non_negative_dim
|
||||
raise_to_shaped = _src_core.raise_to_shaped
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
|
@ -77,7 +77,7 @@ if RUNTIME_PATH and RUNTIME_PATH.exists():
|
||||
os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH)
|
||||
|
||||
|
||||
mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p")
|
||||
mosaic_gpu_p = jax._src.core.Primitive("mosaic_gpu_p")
|
||||
mosaic_gpu_p.multiple_results = True
|
||||
|
||||
|
||||
|
@ -19,7 +19,7 @@ are used internally in GPU translation rules of higher-level primitives.
|
||||
|
||||
from functools import partial
|
||||
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib import gpu_sparse
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
"""N:M-sparsity associated primitives."""
|
||||
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src.lax.lax import DotDimensionNumbers
|
||||
from jax._src.lib import gpu_sparse
|
||||
|
@ -18,8 +18,8 @@ import operator
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import core
|
||||
import jax.numpy as jnp
|
||||
from jax._src import core
|
||||
from jax._src import prng
|
||||
from jax._src import random
|
||||
from jax._src import test_util as jtu
|
||||
|
@ -3837,7 +3837,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
constant_values= ((0.0, 0.0), (0.0, 0.0)))
|
||||
|
||||
jaxpr = jax.make_jaxpr(trace_to_jaxpr)(x)
|
||||
jax.core.jaxpr_as_fun(jaxpr)(x)
|
||||
jax._src.core.jaxpr_as_fun(jaxpr)(x)
|
||||
|
||||
jnp.pad(x, [(0, 1), (0, 0)], mode= 'wrap')
|
||||
jnp.pad(x, [(0, 1), (0, 0)], mode= 'wrap') # doesn't crash
|
||||
|
Loading…
x
Reference in New Issue
Block a user