From c1f65fc8b21bd8a238af2b7c9ff4c90196277be3 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Tue, 4 Apr 2023 11:41:00 -0700
Subject: [PATCH] Avoid imports from the public jax.* namespace in more places
 internally.

This change is in preparation for more cycle breaking in the Bazel
dependency graph.

PiperOrigin-RevId: 521822756
---
 jax/__init__.py                     |   6 +-
 jax/_src/ad_checkpoint.py           |  90 ++++++-
 jax/_src/api.py                     | 156 ++----
 jax/_src/array.py                   |  12 +-
 jax/_src/callback.py                |  63 ++++-
 jax/_src/checkify.py                |  27 +-
 jax/_src/custom_batching.py         |  13 +-
 jax/_src/custom_derivatives.py      |  13 +-
 jax/_src/custom_transpose.py        |   4 +-
 jax/_src/debugging.py               |   5 +-
 jax/_src/dispatch.py                |  14 +-
 jax/_src/flatten_util.py            |   8 +-
 jax/_src/lax/control_flow/solves.py |   2 +-
 jax/_src/maps.py                    |  36 +--
 jax/_src/numpy/ufuncs.py            |   3 +-
 jax/_src/pjit.py                    |  43 ++--
 jax/_src/prng.py                    |  52 ++--
 jax/_src/public_test_util.py        |   6 +-
 jax/_src/random.py                  |   6 +-
 jax/_src/scipy/special.py           |  12 +-
 jax/_src/sharding_impls.py          |   9 +-
 jax/_src/state/discharge.py         |   7 +-
 jax/_src/state/primitives.py        |   9 +-
 tests/api_test.py                   | 370 ++++++++++++++--------------
 24 files changed, 486 insertions(+), 480 deletions(-)

diff --git a/jax/__init__.py b/jax/__init__.py
index 74b786fd0..6a363ece5 100644
--- a/jax/__init__.py
+++ b/jax/__init__.py
@@ -76,7 +76,7 @@ del _xc
 from jax._src.api import effects_barrier as effects_barrier
 from jax._src.api import block_until_ready as block_until_ready
-from jax._src.api import checkpoint as checkpoint
+from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint
 from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
 from jax._src.api import clear_backends as clear_backends
 from jax._src.custom_derivatives import closure_convert as closure_convert
@@ -116,8 +116,8 @@ from jax._src.api import named_scope as named_scope
 from jax._src.api import pmap as pmap
 from jax._src.xla_bridge import process_count as process_count
 from jax._src.xla_bridge import process_index as process_index
-from jax._src.api import pure_callback as pure_callback
-from jax._src.api import remat as remat
+from jax._src.callback import pure_callback_api as pure_callback
+from jax._src.ad_checkpoint import checkpoint_wrapper as remat
 from jax._src.core import ShapedArray as _deprecated_ShapedArray
 from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
 from jax._src.api import value_and_grad as value_and_grad
diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py
index c2405eb02..1ef301078 100644
--- a/jax/_src/ad_checkpoint.py
+++ b/jax/_src/ad_checkpoint.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
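Illustrative sketch, not part of the patch: the pattern applied throughout this change is to replace references to the public jax.* namespace inside jax/_src modules with direct imports of the private submodules that define the APIs, so internal code no longer depends on jax/__init__.py. The module name below is hypothetical.

# Hypothetical internal module, e.g. jax/_src/example_util.py (not in this patch).

# Before: importing the public package pulls in jax/__init__.py, which itself
# imports from jax._src, creating an import cycle for internal code.
import jax

def squared_grad_before(f):
  return jax.grad(lambda x: f(x) ** 2)

# After: depend directly on the private module that defines the API.
from jax._src import api

def squared_grad_after(f):
  return api.grad(lambda x: f(x) ** 2)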
+import functools from functools import partial import logging from typing import (Any, Callable, FrozenSet, List, Optional, Sequence, Tuple, @@ -20,9 +21,8 @@ import types import numpy as np -import jax -from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr from jax._src import ad_util +from jax._src import api from jax._src import core from jax._src import dispatch from jax._src import linear_util as lu @@ -31,6 +31,7 @@ from jax._src import source_info_util from jax._src import traceback_util from jax._src import util from jax._src.api_util import flatten_fun, shaped_abstractify +from jax._src.config import config from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -39,6 +40,7 @@ from jax._src.lax import lax as lax_internal from jax._src.lax import convolution as lax_convolution from jax._src.lib.mlir.dialects import hlo from jax._src.traceback_util import api_boundary +from jax._src.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map, safe_zip, merge_lists, weakref_lru_cache) @@ -389,7 +391,7 @@ def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]: args, kwargs = tree_unflatten(in_tree, args) return f(*args, **kwargs) - out = jax.make_jaxpr(lambda *args: jax.linearize(f_, *args)[1], + out = api.make_jaxpr(lambda *args: api.linearize(f_, *args)[1], return_shape=True)(*in_leaves) assert isinstance(out, tuple) jaxpr_, out_shape = out @@ -522,7 +524,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params): res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:]) res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:] body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None) - logger.log(logging.WARNING if jax.config.jax_log_checkpoint_residuals + logger.log(logging.WARNING if config.jax_log_checkpoint_residuals else logging.DEBUG, 'remat-decorated function ' + 'saving inputs with shapes:\n' * bool(res_invars) + @@ -652,7 +654,7 @@ def remat_lowering(*args, jaxpr: core.Jaxpr, prevent_cse: bool, assert not jaxpr.constvars if differentiated and prevent_cse: - if jax.config.jax_remat_opt_barrier: + if config.jax_remat_opt_barrier: translation_rule = _remat_translation_using_opt_barrier elif is_gpu_platform: translation_rule = _remat_translation_using_while @@ -661,7 +663,7 @@ def remat_lowering(*args, jaxpr: core.Jaxpr, prevent_cse: bool, else: translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args) - return jax.named_call(translation_rule, name="remat")(*args, jaxpr=jaxpr) + return api.named_call(translation_rule, name="remat")(*args, jaxpr=jaxpr) def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr): args = _optimization_barrier(args) @@ -670,9 +672,9 @@ def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr): # TODO(mattjj): add core utility for 'create dummy value for this type'? 
def _dummy_like(aval: core.AbstractValue) -> Any: if aval is core.abstract_token: - return jax.lax.create_token() + return lax_internal.create_token() elif isinstance(aval, (core.ShapedArray, core.DShapedArray)): - return jax.lax.broadcast(lax_internal.empty(aval.dtype), aval.shape) # type: ignore + return lax_internal.broadcast(lax_internal.empty(aval.dtype), aval.shape) # type: ignore else: raise ValueError(aval) @@ -682,11 +684,13 @@ def _remat_translation_using_while(*args, jaxpr: core.Jaxpr): # result = eval_jaxpr(*args) # } # The loop carry is a tuple: (counter, result, args) + from jax._src.lax import control_flow as lax_control_flow + avals_out = tuple(v.aval for v in jaxpr.outvars) carry_init = (np.int32(0), tuple(map(_dummy_like, avals_out)), args) def cond(carry): counter, _, _ = carry - unif = jax.lax.rng_uniform(np.int32(1), np.int32(2), shape=()) + unif = lax_internal.rng_uniform(np.int32(1), np.int32(2), shape=()) return counter < unif def body(carry): @@ -694,7 +698,7 @@ def _remat_translation_using_while(*args, jaxpr: core.Jaxpr): results = core.eval_jaxpr(jaxpr, (), *args) return (counter + 1, tuple(results), args) - carry_res = jax.lax.while_loop(cond, body, carry_init) + carry_res = lax_control_flow.while_loop(cond, body, carry_init) return carry_res[1] def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr): @@ -703,6 +707,8 @@ def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr): # return eval_jaxpr(*args) # else: # return 0 + from jax._src.lax import control_flow as lax_control_flow + avals_out = tuple(v.aval for v in jaxpr.outvars) def remat_comp(*args): @@ -710,8 +716,8 @@ def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr): def dummy_comp(*args): return tuple(map(_dummy_like, avals_out)) - unif = jax.lax.rng_uniform(np.float32(0), np.float32(1), shape=()) - return jax.lax.cond(unif < np.float32(2), remat_comp, dummy_comp, *args) + unif = lax_internal.rng_uniform(np.float32(0), np.float32(1), shape=()) + return lax_control_flow.cond(unif < np.float32(2), remat_comp, dummy_comp, *args) mlir.register_lowering( remat_p, mlir.lower_fun(remat_lowering, multiple_results=True)) @@ -760,3 +766,63 @@ def name_batcher(args, dims, *, name): (x,), (d,) = args, dims return name_p.bind(x, name=name), d batching.primitive_batchers[name_p] = name_batcher + + +@functools.wraps(checkpoint) +def checkpoint_wrapper( + fun: Callable, + *, + concrete: bool = False, + prevent_cse: bool = True, + static_argnums: Union[int, Tuple[int, ...]] = (), + policy: Optional[Callable[..., bool]] = None, +) -> Callable: + if concrete: + msg = ("The 'concrete' option to jax.checkpoint / jax.remat is deprecated; " + "in its place, you can use its `static_argnums` option, and if " + "necessary the `jax.ensure_compile_time_eval()` context manager.\n" + "\n" + "For example, if using `concrete=True` for an `is_training` flag:\n" + "\n" + " from functools import partial\n" + "\n" + " @partial(jax.checkpoint, concrete=True)\n" + " def foo(x, is_training):\n" + " if is_training:\n" + " return f(x)\n" + " else:\n" + " return g(x)\n" + "\n" + "replace it with a use of `static_argnums`:\n" + "\n" + " @partial(jax.checkpoint, static_argnums=(1,))\n" + " def foo(x, is_training):\n" + " ...\n" + "\n" + "If jax.numpy operations need to be performed on static arguments, " + "we can use the `jax.ensure_compile_time_eval()` context manager. 
" + "For example, we can replace this use of `concrete=True`\n:" + "\n" + " @partial(jax.checkpoint, concrete=True)\n" + " def foo(x, y):\n" + " if y > 0:\n" + " return f(x)\n" + " else:\n" + " return g(x)\n" + "\n" + "with this combination of `static_argnums` and " + "`jax.ensure_compile_time_eval()`:\n" + "\n" + " @partial(jax.checkpoint, static_argnums=(1,))\n" + " def foo(x, y):\n" + " with jax.ensure_compile_time_eval():\n" + " y_pos = y > 0\n" + " if y_pos:\n" + " return f(x)\n" + " else:\n" + " return g(x)\n" + "\n" + "See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html\n") + raise NotImplementedError(msg) + return checkpoint(fun, prevent_cse=prevent_cse, policy=policy, + static_argnums=static_argnums) diff --git a/jax/_src/api.py b/jax/_src/api.py index cb812ea14..7b9cf8be4 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -23,7 +23,6 @@ arrays. from __future__ import annotations import collections -import functools from functools import partial import inspect import math @@ -35,14 +34,12 @@ from typing import (Any, Callable, Generator, Hashable, Iterable, List, Literal, import numpy as np from contextlib import contextmanager, ExitStack -import jax from jax._src import linear_util as lu -from jax import stages +from jax._src import stages from jax._src.tree_util import ( tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose, tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix, - prefix_errors, generate_key_paths, _replace_nones) -from jax._src import callback as jcb + prefix_errors, generate_key_paths) from jax._src import core from jax._src import dispatch from jax._src import effects @@ -57,26 +54,18 @@ from jax._src.api_util import ( flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial, argnums_partial_except, flatten_axes, donation_vector, rebase_donate_argnums, _ensure_index, _ensure_index_tuple, - shaped_abstractify, _ensure_str_tuple, argnames_partial_except, - validate_argnames, validate_argnums, check_callable, resolve_argnums, - debug_info, result_paths, flat_out_axes, debug_info_final, FLAGS) + shaped_abstractify, _ensure_str_tuple, + check_callable, debug_info, result_paths, flat_out_axes, debug_info_final, FLAGS) from jax._src.lax import lax as lax_internal from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc from jax._src.lib import pmap_lib +from jax._src.sharding import Sharding from jax._src.sharding_impls import PmapSharding from jax._src.traceback_util import api_boundary -from jax._src.util import (unzip2, curry, safe_map, safe_zip, split_list, - wrap_name, cache, wraps, HashableFunction, - weakref_lru_cache) +from jax._src.util import (unzip2, safe_map, safe_zip, wrap_name, wraps) -# Unused imports to be exported -from jax.ad_checkpoint import checkpoint as new_checkpoint -from jax.custom_batching import custom_vmap -from jax.custom_derivatives import (custom_gradient, custom_jvp, - custom_vjp, linear_call) -from jax.custom_transpose import custom_transpose from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir from jax._src.interpreters import xla @@ -89,7 +78,7 @@ from jax._src.config import ( _thread_local_state as config_thread_local_state, explicit_device_put_scope as config_explicit_device_put_scope, explicit_device_get_scope as config_explicit_device_get_scope) -from jax._src.core import ShapedArray, raise_to_shaped +from jax._src.core import ShapedArray from jax._src.interpreters import ad from 
jax._src.interpreters import batching from jax._src.interpreters import pxla @@ -1022,10 +1011,11 @@ def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0, argnums, has_aux=has_aux, holomorphic=holomorphic) def _std_basis(pytree): + import jax.numpy as jnp leaves, _ = tree_flatten(pytree) ndim = sum(map(np.size, leaves)) dtype = dtypes.result_type(*leaves) - flat_basis = jax.numpy.eye(ndim, dtype=dtype) + flat_basis = jnp.eye(ndim, dtype=dtype) return _unravel_array_into_pytree(pytree, 1, None, flat_basis) def _jacfwd_unravel(input_pytree, output_pytree_leaf, arr): @@ -2487,8 +2477,8 @@ def _infer_src_sharding(src, x): def device_put( x, - device: Union[None, xc.Device, jax.sharding.Sharding, Any] = None, - *, src: Union[None, xc.Device, jax.sharding.Sharding, Any] = None): + device: Union[None, xc.Device, Sharding, Any] = None, + *, src: Union[None, xc.Device, Sharding, Any] = None): """Transfers ``x`` to ``device``. Args: @@ -2512,8 +2502,8 @@ def device_put( blocking the calling Python thread until any transfers are completed. """ with config_explicit_device_put_scope(): - if ((device is None or isinstance(device, (xc.Device, jax.sharding.Sharding))) and - (src is None or isinstance(src, (xc.Device, jax.sharding.Sharding)))): + if ((device is None or isinstance(device, (xc.Device, Sharding))) and + (src is None or isinstance(src, (xc.Device, Sharding)))): return tree_map( lambda y: dispatch.device_put_p.bind( y, device=device, src=_infer_src_sharding(src, y)), x) @@ -2641,7 +2631,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811 assert (isinstance(aval, ShapedArray) and len(xla.aval_to_xla_shapes(aval)) == 1) sharding_spec = pxla._create_pmap_sharding_spec(aval) - buf = jax.device_put(x, devices[0]) + buf = device_put(x, devices[0]) return pxla.batched_device_put( aval, PmapSharding(np.array(devices), sharding_spec), [buf] * len(devices), devices) @@ -2719,7 +2709,7 @@ class ShapeDtypeStruct: raise ValueError("ShapeDtypeStruct: dtype must be specified.") self.dtype = dtype if core.is_opaque_dtype(dtype) else np.dtype(dtype) if sharding is not None: - if not isinstance(sharding, jax.sharding.Sharding): + if not isinstance(sharding, Sharding): raise ValueError( "sharding should be an instance of `jax.sharding.Sharding`. 
" f"Got {sharding} of type {type(sharding)}.") @@ -2821,65 +2811,6 @@ def eval_shape(fun: Callable, *args, **kwargs): return tree_unflatten(out_tree(), out) -@functools.wraps(new_checkpoint) # config.jax_new_checkpoint is True by default -def checkpoint(fun: Callable, *, - concrete: bool = False, - prevent_cse: bool = True, - static_argnums: Union[int, Tuple[int, ...]] = (), - policy: Optional[Callable[..., bool]] = None, - ) -> Callable: - if concrete: - msg = ("The 'concrete' option to jax.checkpoint / jax.remat is deprecated; " - "in its place, you can use its `static_argnums` option, and if " - "necessary the `jax.ensure_compile_time_eval()` context manager.\n" - "\n" - "For example, if using `concrete=True` for an `is_training` flag:\n" - "\n" - " from functools import partial\n" - "\n" - " @partial(jax.checkpoint, concrete=True)\n" - " def foo(x, is_training):\n" - " if is_training:\n" - " return f(x)\n" - " else:\n" - " return g(x)\n" - "\n" - "replace it with a use of `static_argnums`:\n" - "\n" - " @partial(jax.checkpoint, static_argnums=(1,))\n" - " def foo(x, is_training):\n" - " ...\n" - "\n" - "If jax.numpy operations need to be performed on static arguments, " - "we can use the `jax.ensure_compile_time_eval()` context manager. " - "For example, we can replace this use of `concrete=True`\n:" - "\n" - " @partial(jax.checkpoint, concrete=True)\n" - " def foo(x, y):\n" - " if y > 0:\n" - " return f(x)\n" - " else:\n" - " return g(x)\n" - "\n" - "with this combination of `static_argnums` and " - "`jax.ensure_compile_time_eval()`:\n" - "\n" - " @partial(jax.checkpoint, static_argnums=(1,))\n" - " def foo(x, y):\n" - " with jax.ensure_compile_time_eval():\n" - " y_pos = y > 0\n" - " if y_pos:\n" - " return f(x)\n" - " else:\n" - " return g(x)\n" - "\n" - "See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html\n") - raise NotImplementedError(msg) - return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy, - static_argnums=static_argnums) -remat = checkpoint # type: ignore - - def named_call( fun: Callable[..., Any], *, @@ -2986,68 +2917,15 @@ def block_until_ready(x): return x.block_until_ready() except AttributeError: return x - return jax.tree_util.tree_map(try_to_block, x) + return tree_map(try_to_block, x) -def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any, - *args: Any, vectorized: bool = False, **kwargs: Any): - """Applies a functionally pure Python callable. Works under :func:`jit`/:func:`~pmap`/etc. - - ``pure_callback`` enables calling a Python function in JIT-ed JAX functions. - The input ``callback`` will be passed NumPy arrays in place of JAX arrays and - should also return NumPy arrays. Execution takes place on CPU, like any - Python+NumPy function. - - The callback is treated as functionally pure, meaning it has no side-effects - and its output value depends only on its argument values. As a consequence, it - is safe to be called multiple times (e.g. when transformed by :func:`~vmap` or - :func:`~pmap`), or not to be called at all when e.g. the output of a - `jit`-decorated function has no data dependence on its value. Pure callbacks - may also be reordered if data-dependence allows. - - When :func:`~pmap`-ed, the pure callback will be called several times (one on each - axis of the map). When `vmap`-ed the behavior will depend on the value of the - ``vectorized`` keyword argument. 
When ``vectorized`` is ``True``, the callback - is assumed to obey - ``jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])``. - Therefore, the callback will be called directly on batched inputs (where the - batch axes are the leading dimensions). Additionally, the callbacks should - return outputs that have corresponding leading batch axes. If not vectorized - ``callback`` will be mapped sequentially across the batched axis. - For example, if ``callback = lambda x, y: np.matmul(x, y)``, then we are free - to set ``vectorized=True`` because the ``np.matmul`` function handles - arbitrary leading batch dimensions. - - Args: - callback: A Python callable. The callable will be passed PyTrees of NumPy - arrays as arguments, and should return a PyTree of NumPy arrays that - matches ``result_shape_dtypes``. - result_shape_dtypes: A PyTree with leaves that are objects with ``shape`` - and ``dtype`` attributes which represent to the shapes and dtypes of the - value of ``callback`` applied to ``args`` and ``kwargs``. - *args: The positional arguments to the callback. Must be PyTrees of JAX - types. - vectorized: A boolean that indicates whether or not ``callback`` is - vectorized, meaning it can handle arrays with additional leading - dimensions. If ``vectorized`` is `True`, when the callback is mapped - via `jax.vmap`, it will be called directly on inputs with leading batch - dimensions instead of executing ``callback`` on each mapped input - individually. The callback should also return outputs batched across the - leading axis. By default, ``vectorized`` is ``False``. - **kwargs: The keyword arguments to the callback. Must be PyTrees of JAX - types. - - Returns: - The value of ``callback(*args, **kwargs)``. - """ - return jcb.pure_callback(callback, result_shape_dtypes, *args, - vectorized=vectorized, **kwargs) def clear_backends(): """ Clear all backend clients so that new backend clients can be created later. """ xb._clear_backends() - jax.lib.xla_bridge._backends = {} + xb._backends = {} dispatch.xla_primitive_callable.cache_clear() pjit._pjit_lower_cached.cache_clear() pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error diff --git a/jax/_src/array.py b/jax/_src/array.py index 7ccb4de07..097a8d36a 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -21,19 +21,17 @@ import functools from typing import (Sequence, Tuple, Callable, Optional, List, cast, Set, TYPE_CHECKING) -import jax from jax._src import abstract_arrays +from jax._src import api from jax._src import api_util from jax._src import basearray from jax._src import core from jax._src import dispatch from jax._src import dtypes from jax._src import profiler +from jax._src import xla_bridge from jax._src.config import config -from jax._src.util import use_cpp_class, use_cpp_method from jax._src.lib import xla_client as xc -from jax._src import api -from jax._src.typing import ArrayLike from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.interpreters import xla @@ -41,6 +39,8 @@ from jax._src.sharding import Sharding from jax._src.sharding_impls import ( SingleDeviceSharding, XLACompatibleSharding, PmapSharding, device_replica_id_map, hashed_index) +from jax._src.typing import ArrayLike +from jax._src.util import use_cpp_class, use_cpp_method Shape = Tuple[int, ...] 
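Illustrative usage sketch for the ``pure_callback`` docstring shown above; the callback, shapes, and function names here are made up for the example, not taken from the patch.

import numpy as np
import jax
import jax.numpy as jnp

def host_matmul(x, y):
  # Runs on the host; receives and returns NumPy arrays.
  return np.matmul(x, y)

@jax.jit
def f(x, y):
  out = jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype)
  # vectorized=True is safe here because np.matmul already handles leading
  # batch dimensions, as the docstring notes.
  return jax.pure_callback(host_matmul, out, x, y, vectorized=True)

print(f(jnp.ones((3, 4)), jnp.ones((4, 5))).shape)  # (3, 5)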
Device = xc.Device @@ -133,7 +133,7 @@ def _create_copy_plan(arrays, s: Sharding, shape: Shape): @functools.lru_cache(maxsize=4096) def _process_has_full_value_in_mcjax(s, shape): # Return False for single host as a fast path. - if jax.process_count() == 1: + if xla_bridge.process_count() == 1: return False num_unique_indices = len( @@ -359,7 +359,7 @@ class ArrayImpl(basearray.Array): return np.asarray(self._value, dtype=dtype) def __dlpack__(self): - from jax.dlpack import to_dlpack # pylint: disable=g-import-not-at-top + from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top return to_dlpack(self) def __reduce__(self): diff --git a/jax/_src/callback.py b/jax/_src/callback.py index ce3e68a4e..9dcff38a8 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -19,17 +19,17 @@ from typing import Any, Callable, Sequence import numpy as np -from jax import tree_util - from jax._src import core +from jax._src import dispatch from jax._src import dtypes from jax._src import effects +from jax._src import tree_util from jax._src import util -from jax._src import dispatch from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.lib import xla_client as xc +from jax._src.lax.control_flow.loops import map as lax_map # `pure_callback_p` is the main primitive for staging out Python pure callbacks. pure_callback_p = core.Primitive("pure_callback") @@ -93,7 +93,6 @@ def pure_callback_batching_rule(args, dims, *, callback, vectorized: bool, return pure_callback_p.bind( *merged_args, callback=callback, result_avals=result_avals, vectorized=vectorized) - from jax._src.lax.control_flow import map as lax_map outvals = lax_map(_batch_fun, batched_args) return tuple(outvals), (0,) * len(outvals) @@ -154,6 +153,62 @@ def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any, return tree_util.tree_unflatten(out_tree, out_flat) + +def pure_callback_api(callback: Callable[..., Any], result_shape_dtypes: Any, + *args: Any, vectorized: bool = False, **kwargs: Any): + """Applies a functionally pure Python callable. Works under :func:`jit`/:func:`~pmap`/etc. + + ``pure_callback`` enables calling a Python function in JIT-ed JAX functions. + The input ``callback`` will be passed NumPy arrays in place of JAX arrays and + should also return NumPy arrays. Execution takes place on CPU, like any + Python+NumPy function. + + The callback is treated as functionally pure, meaning it has no side-effects + and its output value depends only on its argument values. As a consequence, it + is safe to be called multiple times (e.g. when transformed by :func:`~vmap` or + :func:`~pmap`), or not to be called at all when e.g. the output of a + `jit`-decorated function has no data dependence on its value. Pure callbacks + may also be reordered if data-dependence allows. + + When :func:`~pmap`-ed, the pure callback will be called several times (one on each + axis of the map). When `vmap`-ed the behavior will depend on the value of the + ``vectorized`` keyword argument. When ``vectorized`` is ``True``, the callback + is assumed to obey + ``jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])``. + Therefore, the callback will be called directly on batched inputs (where the + batch axes are the leading dimensions). Additionally, the callbacks should + return outputs that have corresponding leading batch axes. If not vectorized + ``callback`` will be mapped sequentially across the batched axis. 
+ For example, if ``callback = lambda x, y: np.matmul(x, y)``, then we are free + to set ``vectorized=True`` because the ``np.matmul`` function handles + arbitrary leading batch dimensions. + + Args: + callback: A Python callable. The callable will be passed PyTrees of NumPy + arrays as arguments, and should return a PyTree of NumPy arrays that + matches ``result_shape_dtypes``. + result_shape_dtypes: A PyTree with leaves that are objects with ``shape`` + and ``dtype`` attributes which represent to the shapes and dtypes of the + value of ``callback`` applied to ``args`` and ``kwargs``. + *args: The positional arguments to the callback. Must be PyTrees of JAX + types. + vectorized: A boolean that indicates whether or not ``callback`` is + vectorized, meaning it can handle arrays with additional leading + dimensions. If ``vectorized`` is `True`, when the callback is mapped + via `jax.vmap`, it will be called directly on inputs with leading batch + dimensions instead of executing ``callback`` on each mapped input + individually. The callback should also return outputs batched across the + leading axis. By default, ``vectorized`` is ``False``. + **kwargs: The keyword arguments to the callback. Must be PyTrees of JAX + types. + + Returns: + The value of ``callback(*args, **kwargs)``. + """ + return pure_callback(callback, result_shape_dtypes, *args, + vectorized=vectorized, **kwargs) + + # IO Callback io_callback_p = core.Primitive("io_callback") diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 23267f6ce..167110b95 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import dataclasses import functools @@ -20,28 +21,28 @@ from typing import (Union, Optional, Callable, Dict, Tuple, TypeVar, import numpy as np -import jax import jax.numpy as jnp -import jax.tree_util as jtu from jax import lax -from jax.api_util import flatten_fun -from jax.experimental import pjit -from jax.tree_util import tree_flatten -from jax.tree_util import tree_map -from jax.tree_util import tree_unflatten +from jax._src import api from jax._src import linear_util as lu from jax._src import core from jax._src import custom_derivatives from jax._src import effects +from jax._src import pjit from jax._src import prng from jax._src import source_info_util from jax._src import traceback_util +from jax._src import tree_util as jtu +from jax._src.api_util import flatten_fun from jax._src.config import config from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe +from jax._src.tree_util import tree_flatten +from jax._src.tree_util import tree_map +from jax._src.tree_util import tree_unflatten from jax._src.typing import Array from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip, unzip3, weakref_lru_cache) @@ -92,7 +93,7 @@ class JaxException(Exception): del payload return cls(metadata) - def get_effect_type(self) -> core.Effect: + def get_effect_type(self) -> ErrorEffect: raise NotImplementedError @@ -100,7 +101,7 @@ class JaxException(Exception): @dataclasses.dataclass(eq=True, frozen=True) class ErrorEffect(effects.Effect): error_type: Type[JaxException] - shape_dtypes: Tuple[jax.ShapeDtypeStruct, ...] 
+ shape_dtypes: Tuple[api.ShapeDtypeStruct, ...] def __lt__(self, other: 'ErrorEffect'): shape_dtypes = lambda x: tuple((sd.shape, str(sd.dtype)) # dtype is not comparable @@ -161,7 +162,7 @@ class OOBError(JaxException): f'Failed at {self.traceback_info}') def get_effect_type(self): - return ErrorEffect(OOBError, (jax.ShapeDtypeStruct((3,), jnp.int32),)) + return ErrorEffect(OOBError, (api.ShapeDtypeStruct((3,), jnp.int32),)) class FailedCheckError(JaxException): @@ -188,7 +189,7 @@ class FailedCheckError(JaxException): vals = jtu.tree_leaves((self.args, self.kwargs)) return ErrorEffect( FailedCheckError, - tuple(jax.ShapeDtypeStruct(x.shape, x.dtype) for x in vals)) + tuple(api.ShapeDtypeStruct(x.shape, x.dtype) for x in vals)) @dataclasses.dataclass class BatchedError(JaxException): @@ -1112,7 +1113,7 @@ def _check(pred, msg, debug, *fmt_args, **fmt_kwargs): prim_name = 'debug_check' if debug else 'check' raise TypeError(f'{prim_name} takes a scalar pred as argument, got {pred}') for arg in jtu.tree_leaves((fmt_args, fmt_kwargs)): - if not isinstance(arg, (jax.Array, np.ndarray)): + if not isinstance(arg, (Array, np.ndarray)): raise TypeError('Formatting arguments to checkify.check need to be ' 'PyTrees of arrays, but got ' f'{repr(arg)} of type {type(arg)}.') @@ -1130,7 +1131,7 @@ def _check_error(error, *, debug=False): def is_scalar_pred(pred) -> bool: return (isinstance(pred, bool) or - isinstance(pred, jax.Array) and pred.shape == () and + isinstance(pred, Array) and pred.shape == () and pred.dtype == jnp.dtype('bool')) diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 2bea7562c..64fd57187 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -16,15 +16,14 @@ import functools import operator from typing import Callable, Optional -import jax -from jax import tree_util -from jax.tree_util import (tree_flatten, tree_map, tree_structure, - tree_unflatten, treedef_tuple) +from jax import lax +from jax._src import api from jax._src import core from jax._src import custom_api_util from jax._src import linear_util as lu from jax._src import source_info_util from jax._src import traceback_util +from jax._src import tree_util from jax._src import util from jax._src.api_util import flatten_fun_nokwargs from jax._src.interpreters import ad @@ -33,6 +32,8 @@ from jax._src.interpreters.batching import not_mapped from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import xla +from jax._src.tree_util import (tree_flatten, tree_map, tree_structure, + tree_unflatten, treedef_tuple) source_info_util.register_exclusion(__file__) @@ -194,7 +195,7 @@ def custom_vmap_jvp(primals, tangents, *, call, rule, in_tree, out_tree): return out def to_vmap_over_extra_batched_dims(primals, tangents): - return jax.jvp(to_jvp, primals, tangents) + return api.jvp(to_jvp, primals, tangents) to_vmap_over_extra_batched_dims_flat, out_tree2 = flatten_fun_nokwargs( lu.wrap_init(to_vmap_over_extra_batched_dims), @@ -274,7 +275,7 @@ def sequential_vmap(f): return f(*args) mapped_args, bcast_args = tree_split(in_batched, list(args)) - out = jax.lax.map(to_map, mapped_args) + out = lax.map(to_map, mapped_args) out_batched = tree_map(lambda _: True, out) return out, out_batched diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 8b0f827ea..78f64a798 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -16,15 +16,9 @@ from functools import 
update_wrapper, reduce, partial import inspect from typing import (Callable, Generic, List, Optional, Sequence, Tuple, TypeVar, Any) -from jax.custom_transpose import custom_transpose -from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, - treedef_is_leaf, treedef_tuple, - register_pytree_node_class, tree_leaves) -from jax.errors import UnexpectedTracerError -from jax.config import config - from jax._src import core from jax._src import custom_api_util +from jax._src.custom_transpose import custom_transpose from jax._src import dtypes from jax._src import effects from jax._src import linear_util as lu @@ -32,7 +26,9 @@ from jax._src import traceback_util from jax._src.ad_util import (Zero, SymbolicZero, zeros_like_aval, stop_gradient_p) from jax._src.api_util import argnums_partial, flatten_fun_nokwargs +from jax._src.config import config from jax._src.core import raise_to_shaped +from jax._src.errors import UnexpectedTracerError from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -40,6 +36,9 @@ from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import xla from jax._src.interpreters.batching import not_mapped from jax._src.lax import lax +from jax._src.tree_util import (tree_flatten, tree_unflatten, tree_map, + treedef_is_leaf, treedef_tuple, + register_pytree_node_class, tree_leaves) from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index 1218bdb3e..4c248f460 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -15,8 +15,6 @@ import functools from typing import Any, Callable, Optional, Tuple -from jax.tree_util import (tree_flatten, tree_leaves, tree_map, - tree_structure, treedef_tuple, tree_unflatten) from jax._src import ad_util from jax._src import api_util from jax._src import core @@ -29,6 +27,8 @@ from jax._src.interpreters import ad from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import xla +from jax._src.tree_util import (tree_flatten, tree_leaves, tree_map, + tree_structure, treedef_tuple, tree_unflatten) source_info_util.register_exclusion(__file__) diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 8da760105..e78bdade5 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -22,23 +22,24 @@ import weakref import numpy as np import jax.numpy as jnp -from jax import tree_util from jax import lax + from jax._src import core from jax._src import effects from jax._src import linear_util as lu from jax._src import mesh as mesh_lib from jax._src import pjit +from jax._src import tree_util from jax._src import util from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.sharding import Sharding from jax._src.sharding_impls import GSPMDSharding, NamedSharding -from jax._src.interpreters import partial_eval as pe # pytype: disable=import-error try: diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 519304e29..7916f78f5 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -30,10 +30,6 @@ import warnings import numpy as np -import jax -from jax.monitoring import record_event_duration_secs - 
-from jax._src import array from jax._src import core from jax._src import dtypes from jax._src import linear_util as lu @@ -53,6 +49,7 @@ from jax._src.interpreters import xla from jax._src.interpreters import pxla from jax._src.lib.mlir import ir from jax._src.lib import xla_client as xc +from jax._src.monitoring import record_event_duration_secs from jax._src.sharding import Sharding from jax._src.sharding_impls import ( PmapSharding, SingleDeviceSharding, NamedSharding, @@ -145,7 +142,7 @@ class RuntimeTokenSet(threading.local): self.output_runtime_tokens = {} def get_token(self, eff: core.Effect, device: Device) -> RuntimeToken: - s = jax.sharding.SingleDeviceSharding(device) + s = SingleDeviceSharding(device) if eff not in self.tokens: inp = np.zeros(0, np.bool_) indices = tuple( @@ -302,7 +299,7 @@ class SourceInfo(NamedTuple): def jaxpr_shardings( - jaxpr) -> Iterator[Tuple[jax.sharding.XLACompatibleSharding, SourceInfo]]: + jaxpr) -> Iterator[Tuple[XLACompatibleSharding, SourceInfo]]: from jax._src import pjit from jax.experimental import shard_map @@ -570,8 +567,9 @@ def _put_x(x, s: Sharding, aval: core.AbstractValue, committed: bool): def _device_put_impl( x, - device: Optional[Union[Device, jax.sharding.Sharding]] = None, - src: Optional[Union[Device, jax.sharding.Sharding]] = None): + device: Optional[Union[Device, Sharding]] = None, + src: Optional[Union[Device, Sharding]] = None): + from jax._src import array try: aval = xla.abstractify(x) except TypeError as err: diff --git a/jax/_src/flatten_util.py b/jax/_src/flatten_util.py index ca99a0819..d7c124af2 100644 --- a/jax/_src/flatten_util.py +++ b/jax/_src/flatten_util.py @@ -16,13 +16,13 @@ import warnings import numpy as np +from jax import lax +import jax.numpy as jnp + +from jax._src import dtypes from jax._src.tree_util import tree_flatten, tree_unflatten from jax._src.util import safe_zip, unzip2, HashablePartial -import jax.numpy as jnp -from jax._src import dtypes -from jax import lax - zip = safe_zip diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 77f9c3b2b..ed673687e 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -17,7 +17,6 @@ from functools import partial import operator import jax -from jax import lax from jax.tree_util import (tree_flatten, treedef_children, tree_leaves, tree_unflatten, treedef_tuple) from jax._src import ad_util @@ -28,6 +27,7 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import xla +from jax._src.lax import lax from jax._src.traceback_util import api_boundary from jax._src.util import split_list, safe_map import numpy as np diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 42b532bf2..2f13de53a 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -20,35 +20,42 @@ from typing import (Callable, Iterable, Tuple, Optional, Dict, Any, Set, NamedTuple, Union, Sequence) from functools import wraps, partial, partialmethod, lru_cache +from jax import lax from jax import numpy as jnp + from jax._src import core -from jax._src import mesh -from jax._src import linear_util as lu -from jax import stages from jax._src import dispatch from jax._src import effects -from jax.tree_util import (tree_flatten, tree_unflatten, all_leaves, tree_map, - treedef_tuple) +from jax._src import mesh +from jax._src import linear_util as lu +from jax._src import source_info_util +from jax._src import stages +from jax._src 
import traceback_util from jax._src.api_util import (flatten_fun_nokwargs, flatten_axes, _ensure_index_tuple, donation_vector, shaped_abstractify, check_callable) -from jax._src import source_info_util -from jax._src import traceback_util -from jax._src.config import config -from jax.errors import JAXTypeError from jax._src.array import ArrayImpl -from jax._src.sharding_impls import NamedSharding +from jax._src.config import config +from jax._src.errors import JAXTypeError from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe +from jax._src.interpreters.partial_eval import ( + trace_to_subjaxpr_dynamic, DynamicJaxprTracer, + convert_constvars_jaxpr, new_jaxpr_eqn) from jax._src.interpreters import pxla from jax._src.interpreters import xla +from jax._src.pjit import ( + sharding_constraint_p, ParsedPartitionSpec, get_unconstrained_dims, + GSPMDSharding) +from jax._src.sharding_impls import NamedSharding +from jax._src.tree_util import (tree_flatten, tree_unflatten, all_leaves, + tree_map, treedef_tuple) from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3, as_hashable_function, distributed_debug_log, tuple_insert, moveaxis, split_list, wrap_name, merge_lists, partition_list) -from jax import lax source_info_util.register_exclusion(__file__) traceback_util.register_exclusion(__file__) @@ -965,9 +972,6 @@ pxla.custom_resource_typing_rules[xmap_p] = _resource_typing_xmap # This is DynamicJaxprTrace.process_map with some very minor modifications def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params): - from jax._src.interpreters.partial_eval import ( - trace_to_subjaxpr_dynamic, DynamicJaxprTracer, - convert_constvars_jaxpr, new_jaxpr_eqn) assert primitive is xmap_p in_avals = [t.aval for t in tracers] global_axis_sizes = params['global_axis_sizes'] @@ -1775,10 +1779,6 @@ def _check_no_loop_collectives(jaxpr, loop_axis_resources): def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None): - from jax._src.pjit import ( - sharding_constraint_p, ParsedPartitionSpec, get_unconstrained_dims, - GSPMDSharding) - rec = lambda jaxpr: _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name) if isinstance(jaxpr, core.ClosedJaxpr): return jaxpr.map_jaxpr(rec) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index cc1322b94..7addd4594 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -25,7 +25,8 @@ import numpy as np from jax._src import core from jax._src import dtypes -from jax._src.api import jit, custom_jvp +from jax._src.api import jit +from jax._src.custom_derivatives import custom_jvp from jax._src.lax import lax from jax._src.typing import Array, ArrayLike from jax._src.numpy.util import ( diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 1accfe316..bca511200 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -24,23 +24,10 @@ from functools import partial, lru_cache import threading import warnings -import jax from jax._src import core -from jax import stages -from jax.errors import JAXTypeError -from jax._src.interpreters import partial_eval as pe -from jax._src.interpreters.pxla import PartitionSpec -from jax._src.interpreters import xla -from jax._src.tree_util import ( - tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, - treedef_tuple, broadcast_prefix, all_leaves) - -from jax._src.sharding import Sharding -from jax._src.sharding_impls import 
( - NamedSharding, XLACompatibleSharding, GSPMDSharding, - XLADeviceAssignment, SingleDeviceSharding, PmapSharding) +from jax._src import stages from jax._src import dispatch -from jax._src import mesh +from jax._src import mesh as mesh_lib from jax._src import linear_util as lu from jax._src import source_info_util from jax._src import traceback_util @@ -50,6 +37,11 @@ from jax._src.api_util import ( argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs, donation_vector, shaped_abstractify, check_callable, resolve_argnums, argnames_partial_except, debug_info, result_paths, jaxpr_debug_info, FLAGS) +from jax._src.errors import JAXTypeError +from jax._src.interpreters import partial_eval as pe +from jax._src.interpreters.pxla import PartitionSpec +from jax._src.interpreters import xla + from jax._src.config import config from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -58,8 +50,15 @@ from jax._src.interpreters import pxla from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib import xla_client as xc +from jax._src.sharding import Sharding +from jax._src.sharding_impls import ( + NamedSharding, XLACompatibleSharding, GSPMDSharding, + XLADeviceAssignment, SingleDeviceSharding, PmapSharding) from jax._src.traceback_util import api_boundary -from jax._src.tree_util import prefix_errors, generate_key_paths +from jax._src.tree_util import ( + tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, + treedef_tuple, broadcast_prefix, all_leaves, + prefix_errors, generate_key_paths) from jax._src.util import ( HashableFunction, safe_map, safe_zip, wraps, distributed_debug_log, split_list, tuple_insert, weakref_lru_cache, @@ -311,7 +310,7 @@ def _resolve_axis_resources_and_shardings_arg( def pre_infer_params(fun, in_shardings, out_shardings, donate_argnums, static_argnums, static_argnames, device, backend, abstracted_axes): - if abstracted_axes and not jax.config.jax_dynamic_shapes: + if abstracted_axes and not config.jax_dynamic_shapes: raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes") check_callable(fun) @@ -455,7 +454,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs): dyn_kwargs = () del kwargs - if donate_argnums and not jax.config.jax_debug_nans: + if donate_argnums and not config.jax_debug_nans: donated_invars = donation_vector(donate_argnums, dyn_args, dyn_kwargs) else: donated_invars = (False,) * len(explicit_args) @@ -724,7 +723,7 @@ def pjit( def infer_params(*args, **kwargs): # Putting this outside of wrapped would make resources lexically scoped - resource_env = mesh.thread_resources.env + resource_env = mesh_lib.thread_resources.env pjit_info_args = PjitInfo( fun=fun, in_shardings=in_shardings, out_shardings=out_shardings, static_argnums=static_argnums, @@ -1125,7 +1124,7 @@ def _check_unique_resources(axis_resources, arg_name): if multiple_uses: raise ValueError(f"A single {arg_name} specification can map every mesh axis " f"to at most one positional dimension, but {arg_axis_resources.user_spec} " - f"has duplicate entries for {mesh.show_axes(multiple_uses)}") + f"has duplicate entries for {mesh_lib.show_axes(multiple_uses)}") # -------------------- pjit rules -------------------- @@ -1812,7 +1811,7 @@ def _check_resources_against_named_axes(what, aval, pos_axis_resources, named_ax f"{pos_axis_resources.unsynced_user_spec(SpecSync.DIM_PERMUTE)} " f"that uses one or more mesh axes already used by xmap to partition " f"a named 
axis appearing in its named_shape (both use mesh axes " - f"{mesh.show_axes(overlap)})") + f"{mesh_lib.show_axes(overlap)})") def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_resources): jaxpr = params["jaxpr"] @@ -1918,7 +1917,7 @@ def with_sharding_constraint(x, shardings=_UNSPECIFIED, flatten_axes("with_sharding_constraint shardings", tree, user_shardings)) del user_shardings - resource_env = jax._src.mesh.thread_resources.env + resource_env = mesh_lib.thread_resources.env mesh = resource_env.physical_mesh shardings_flat = [_create_sharding_for_array(mesh, a) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 5be3bc4a0..9ecf50234 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -21,18 +21,20 @@ from typing import Any, Callable, Hashable, Iterator, NamedTuple, Sequence, Unio import numpy as np -import jax from jax import lax from jax import numpy as jnp -from jax.config import config -from jax.dtypes import float0 +from jax._src import api from jax._src import basearray +from jax._src import config as config_lib from jax._src import core from jax._src import dispatch from jax._src import dtypes from jax._src import pretty_printer as pp +from jax._src import typing from jax._src.api import jit, vmap +from jax._src.config import config +from jax._src.dtypes import float0 from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -97,7 +99,7 @@ class PRNGImpl(NamedTuple): # -- PRNG key arrays -def _check_prng_key_data(impl, key_data: jax.Array): +def _check_prng_key_data(impl, key_data: typing.Array): ndim = len(impl.key_shape) if not all(hasattr(key_data, attr) for attr in ['ndim', 'shape', 'dtype']): raise TypeError("JAX encountered invalid PRNG key data: expected key_data " @@ -139,7 +141,7 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta): """ impl: PRNGImpl - _base_array: jax.Array + _base_array: typing.Array def __init__(self, impl, key_data: Any): assert not isinstance(key_data, core.Tracer) @@ -512,7 +514,7 @@ mlir.register_constant_handler(PRNGKeyArray, key_array_constant_handler) def iterated_vmap_unary(n, f): for _ in range(n): - f = jax.vmap(f) + f = api.vmap(f) return f # TODO(frostig): Revise the following two functions? These basically @@ -528,7 +530,7 @@ def squeeze_vmap(f, left): else: y = jnp.squeeze(y, axis=0) axes = (0, None) - return jax.vmap(f, in_axes=axes, out_axes=0)(x, y) + return api.vmap(f, in_axes=axes, out_axes=0)(x, y) return squeeze_vmap_f def iterated_vmap_binary_bcast(shape1, shape2, f): @@ -543,7 +545,7 @@ def iterated_vmap_binary_bcast(shape1, shape2, f): assert len(shape1) == len(shape2) for sz1, sz2 in reversed(zip(shape1, shape2)): if sz1 == sz2: - f = jax.vmap(f, out_axes=0) + f = api.vmap(f, out_axes=0) else: assert sz1 == 1 or sz2 == 1, (sz1, sz2) f = squeeze_vmap(f, sz1 == 1) @@ -785,14 +787,14 @@ mlir.register_lowering(random_unwrap_p, random_unwrap_lowering) # -- threefry2x32 PRNG implementation -def _is_threefry_prng_key(key: jax.Array) -> bool: +def _is_threefry_prng_key(key: typing.Array) -> bool: try: return key.shape == (2,) and key.dtype == np.uint32 except AttributeError: return False -def threefry_seed(seed: jax.Array) -> jax.Array: +def threefry_seed(seed: typing.Array) -> typing.Array: """Create a single raw threefry PRNG key from an integer seed. 
Args: @@ -811,7 +813,7 @@ def threefry_seed(seed: jax.Array) -> jax.Array: convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1]) k1 = convert( lax.shift_right_logical(seed, lax_internal._const(seed, 32))) - with jax.numpy_dtype_promotion('standard'): + with config_lib.numpy_dtype_promotion('standard'): # TODO(jakevdp): in X64 mode, this can generate 64-bit computations for 32-bit # inputs. We should avoid this. k2 = convert(jnp.bitwise_and(seed, np.uint32(0xFFFFFFFF))) @@ -1090,26 +1092,26 @@ def threefry_2x32(keypair, count): return lax.reshape(out[:-1] if odd_size else out, count.shape) -def threefry_split(key: jax.Array, num: int) -> jax.Array: +def threefry_split(key: typing.Array, num: int) -> typing.Array: if config.jax_threefry_partitionable: return _threefry_split_foldlike(key, int(num)) # type: ignore else: return _threefry_split_original(key, int(num)) # type: ignore @partial(jit, static_argnums=(1,), inline=True) -def _threefry_split_original(key, num) -> jax.Array: +def _threefry_split_original(key, num) -> typing.Array: counts = lax.iota(np.uint32, num * 2) return lax.reshape(threefry_2x32(key, counts), (num, 2)) @partial(jit, static_argnums=(1,), inline=True) -def _threefry_split_foldlike(key, num) -> jax.Array: +def _threefry_split_foldlike(key, num) -> typing.Array: k1, k2 = key counts1, counts2 = iota_2x32_shape((num,)) bits1, bits2 = threefry2x32_p.bind(k1, k2, counts1, counts2) return jnp.stack([bits1, bits2], axis=1) -def threefry_fold_in(key: jax.Array, data: jax.Array) -> jax.Array: +def threefry_fold_in(key: typing.Array, data: typing.Array) -> typing.Array: assert not data.shape return _threefry_fold_in(key, jnp.uint32(data)) @@ -1118,7 +1120,7 @@ def _threefry_fold_in(key, data): return threefry_2x32(key, threefry_seed(data)) -def threefry_random_bits(key: jax.Array, bit_width, shape): +def threefry_random_bits(key: typing.Array, bit_width, shape): """Sample uniform random bits of given width and shape using PRNG key.""" if not _is_threefry_prng_key(key): raise TypeError("threefry_random_bits got invalid prng key.") @@ -1131,7 +1133,7 @@ def threefry_random_bits(key: jax.Array, bit_width, shape): else: return _threefry_random_bits_original(key, bit_width, shape) -def _threefry_random_bits_partitionable(key: jax.Array, bit_width, shape): +def _threefry_random_bits_partitionable(key: typing.Array, bit_width, shape): if all(core.is_constant_dim(d) for d in shape) and math.prod(shape) > 2 ** 64: raise NotImplementedError('random bits array of size exceeding 2 ** 64') @@ -1150,7 +1152,7 @@ def _threefry_random_bits_partitionable(key: jax.Array, bit_width, shape): return lax.convert_element_type(bits1 ^ bits2, dtype) @partial(jit, static_argnums=(1, 2), inline=True) -def _threefry_random_bits_original(key: jax.Array, bit_width, shape): +def _threefry_random_bits_original(key: typing.Array, bit_width, shape): size = math.prod(shape) # Compute ceil(bit_width * size / 32) in a way that is friendly to shape # polymorphism @@ -1210,12 +1212,12 @@ threefry_prng_impl = PRNGImpl( # stable/deterministic across backends or compiler versions. Correspondingly, we # reserve the right to change any of these implementations at any time! 
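For orientation only, not part of the patch: under the default threefry implementation, the helpers above back the user-facing key API roughly as follows.

import jax

key = jax.random.PRNGKey(0)                    # seeding goes through threefry_seed
subkey1, subkey2 = jax.random.split(key)       # backed by threefry_split
derived = jax.random.fold_in(key, 7)           # backed by threefry_fold_in
samples = jax.random.uniform(subkey1, (2, 3))  # draws bits via threefry_random_bits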
-def _rbg_seed(seed: jax.Array) -> jax.Array: +def _rbg_seed(seed: typing.Array) -> typing.Array: assert not seed.shape halfkey = threefry_seed(seed) return jnp.concatenate([halfkey, halfkey]) -def _rbg_split(key: jax.Array, num: int) -> jax.Array: +def _rbg_split(key: typing.Array, num: int) -> typing.Array: if config.jax_threefry_partitionable: _threefry_split = _threefry_split_foldlike else: @@ -1223,12 +1225,12 @@ def _rbg_split(key: jax.Array, num: int) -> jax.Array: return vmap( _threefry_split, (0, None), 1)(key.reshape(2, 2), num).reshape(num, 4) -def _rbg_fold_in(key: jax.Array, data: jax.Array) -> jax.Array: +def _rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array: assert not data.shape return vmap(_threefry_fold_in, (0, None), 0)(key.reshape(2, 2), data).reshape(4) -def _rbg_random_bits(key: jax.Array, bit_width: int, shape: Sequence[int] - ) -> jax.Array: +def _rbg_random_bits(key: typing.Array, bit_width: int, shape: Sequence[int] + ) -> typing.Array: if not key.shape == (4,) and key.dtype == jnp.dtype('uint32'): raise TypeError("_rbg_random_bits got invalid prng key.") if bit_width not in (8, 16, 32, 64): @@ -1244,12 +1246,12 @@ rbg_prng_impl = PRNGImpl( fold_in=_rbg_fold_in, tag='rbg') -def _unsafe_rbg_split(key: jax.Array, num: int) -> jax.Array: +def _unsafe_rbg_split(key: typing.Array, num: int) -> typing.Array: # treat 10 iterations of random bits as a 'hash function' _, keys = lax.rng_bit_generator(key, (10 * num, 4), dtype='uint32') return keys[::10] -def _unsafe_rbg_fold_in(key: jax.Array, data: jax.Array) -> jax.Array: +def _unsafe_rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array: assert not data.shape _, random_bits = lax.rng_bit_generator(_rbg_seed(data), (10, 4), dtype='uint32') return key ^ random_bits[-1] diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 59399b6e3..497de2698 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -15,13 +15,13 @@ from functools import partial import operator -from jax import config -from jax.tree_util import tree_map, tree_reduce from jax._src import api from jax._src import dtypes as _dtypes from jax._src import xla_bridge -from jax._src.config import flags +from jax._src.config import config, flags from jax._src.lib import xla_client +from jax._src.tree_util import tree_map, tree_reduce + import numpy as np diff --git a/jax/_src/random.py b/jax/_src/random.py index 449461013..42191bb7e 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -21,17 +21,17 @@ import warnings import numpy as np -import jax import jax.numpy as jnp from jax import lax -from jax.config import config from jax.numpy.linalg import cholesky, svd, eigh +from jax._src import config as config_lib from jax._src import core from jax._src import dtypes from jax._src import prng from jax._src import xla_bridge from jax._src.api import jit, vmap +from jax._src.config import config from jax._src.core import NamedShape from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -677,7 +677,7 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array: else: # 'cholesky' factor = cholesky(cov) normal_samples = normal(key, shape + mean.shape[-1:], dtype) - with jax.numpy_rank_promotion('allow'): + with config_lib.numpy_rank_promotion('allow'): result = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples) return result diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index de5f88087..f8f6d3f4b 100644 --- 
a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -24,8 +24,8 @@ from jax import jit from jax import vmap from jax import lax -from jax._src import api from jax._src import core +from jax._src import custom_derivatives from jax._src import dtypes from jax._src.interpreters import ad from jax._src.lax.lax import _const as _lax_const @@ -96,7 +96,7 @@ def erfinv(x: ArrayLike) -> Array: return lax.erf_inv(x) -@api.custom_jvp +@custom_derivatives.custom_jvp @_wraps(osp_special.logit, module='scipy.special', update_doc=False) def logit(x: ArrayLike) -> Array: x, = promote_args_inexact("logit", x) @@ -214,7 +214,7 @@ def polygamma(n: ArrayLike, x: ArrayLike) -> Array: return _polygamma(jnp.broadcast_to(n_arr, shape), jnp.broadcast_to(x_arr, shape)) -@api.custom_jvp +@custom_derivatives.custom_jvp def _polygamma(n: ArrayLike, x: ArrayLike) -> Array: dtype = lax.dtype(n).type n_plus = n + dtype(1) @@ -481,7 +481,7 @@ def _ndtri(p: ArrayLike) -> Array: return x_nan_replaced -@partial(api.custom_jvp, nondiff_argnums=(1,)) +@partial(custom_derivatives.custom_jvp, nondiff_argnums=(1,)) def log_ndtr(x: ArrayLike, series_order: int = 3) -> Array: r"""Log Normal distribution function. @@ -1361,7 +1361,7 @@ def _expi_neg(x: Array) -> Array: # x < 0 return -exp1(-x) -@api.custom_jvp +@custom_derivatives.custom_jvp @jit @_wraps(osp_special.expi, module='scipy.special') def expi(x: ArrayLike) -> Array: @@ -1479,7 +1479,7 @@ def _expn3(n: int, x: Array) -> Array: return (ans + one) * jnp.exp(-x) / xk -@partial(api.custom_jvp, nondiff_argnums=(0,)) +@partial(custom_derivatives.custom_jvp, nondiff_argnums=(0,)) @jnp.vectorize @_wraps(osp_special.expn, module='scipy.special') @jit diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 6b4f7480b..191332195 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -20,10 +20,10 @@ import operator as op from typing import (Any, Sequence, List, Tuple, Optional, Mapping, Dict, Set, FrozenSet, Union, cast) -import jax from jax._src import core from jax._src import mesh as mesh_lib from jax._src import sharding +from jax._src import xla_bridge from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method from jax._src.lib import xla_client as xc from jax._src.interpreters import mlir @@ -238,7 +238,7 @@ class NamedSharding(XLACompatibleSharding): # TODO(yaskatariya): Remove this and replace this with a normalized # representation of Parsed Pspec if self._parsed_pspec is None: - from jax.experimental import pjit + from jax._src import pjit self._parsed_pspec, _, _ = pjit._prepare_axis_resources( self.spec, "NamedSharding spec") @@ -287,7 +287,7 @@ class NamedSharding(XLACompatibleSharding): num_dimensions: int, axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None ) -> xc.OpSharding: - from jax.experimental.pjit import get_array_mapping + from jax._src.pjit import get_array_mapping assert self._parsed_pspec is not None array_mapping = get_array_mapping(self._parsed_pspec) # TODO(yashkatariya): Move away from sharding spec in NamedSharding @@ -429,7 +429,8 @@ class PmapSharding(XLACompatibleSharding): '`None` to sharded_dim is not supported. 
Please file a jax ' 'issue if you need this feature.') - pmap_devices: np.ndarray = np.array(jax.local_devices()[:num_ways_sharded]) + pmap_devices: np.ndarray = np.array( + xla_bridge.local_devices()[:num_ways_sharded]) return cls(pmap_devices, sharding_spec) @functools.cached_property diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 6480e1ecb..d641b9223 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -21,7 +21,6 @@ from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Tupl import numpy as np -from jax import lax from jax._src import api_util from jax._src import ad_util @@ -32,6 +31,8 @@ from jax._src import source_info_util from jax._src import tree_util from jax._src.config import config from jax._src.interpreters import ad +from jax._src.lax import lax +from jax._src.lax import slicing as lax_slicing from jax._src.state.types import AbstractRef, RefEffect from jax._src.state.primitives import get_p, swap_p, addupdate_p from jax._src.state.utils import hoist_consts_to_refs @@ -222,7 +223,7 @@ def _dynamic_index(x, idx, indexed_dims): starts = [next(idx_) if b else np.int32(0) for b in indexed_dims] assert next(idx_, None) is None sizes = [1 if b else size for b, size in zip(indexed_dims, x.shape)] - out = lax.dynamic_slice(x, starts, sizes) + out = lax_slicing.dynamic_slice(x, starts, sizes) return lax.squeeze(out, [i for i, b in enumerate(indexed_dims) if b]) def _dynamic_update_index(x, idx, val, indexed_dims): @@ -231,7 +232,7 @@ def _dynamic_update_index(x, idx, val, indexed_dims): starts = [next(idx_) if b else np.int32(0) for b in indexed_dims] assert next(idx_, None) is None sizes = [1 if b else size for b, size in zip(indexed_dims, x.shape)] - return lax.dynamic_update_slice(x, val.reshape(sizes), starts) + return lax_slicing.dynamic_update_slice(x, val.reshape(sizes), starts) @register_discharge_rule(core.closed_call_p) def _closed_call_discharge_rule( diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index edf1fcfac..d86d1c89c 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -18,13 +18,14 @@ from typing import Any, List, Tuple, Union import numpy as np -import jax + from jax._src import ad_util from jax._src import core from jax._src import pretty_printer as pp from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import partial_eval as pe +from jax._src.lax import lax from jax._src.typing import Array from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect, AccumEffect) @@ -420,7 +421,7 @@ def _get_vmap(batched_args, batched_dims, *, indexed_dims): # `idxs` doesn't include the non indexed dims. 
idx_place = [i for i, i_dim in enumerate(indexed_dims) if i_dim].index(ref_dim) - iota = jax.lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0) + iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0) idxs = tuple_insert(idxs, idx_place, iota) else: bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape) @@ -453,7 +454,7 @@ def _swap_vmap(batched_args, batched_dims, *, indexed_dims): indexed_dims = tuple_insert(indexed_dims, ref_dim, True) idx_place = [i for i, i_dim in enumerate(indexed_dims) if i_dim].index(ref_dim) - iota = jax.lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0) + iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0) idxs = tuple_insert(idxs, idx_place, iota) val = batching.moveaxis(val, val_dim, 0) bdim_out = 0 @@ -486,7 +487,7 @@ def _addupdate_vmap(batched_args, batched_dims, *, indexed_dims): idx_place = [i for i, i_dim in enumerate(indexed_dims) if i_dim].index(ref_dim) idxs_shape, = {i.shape for i in idxs} or [()] - iota = jax.lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0) + iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0) idxs = tuple_insert(idxs, idx_place, iota) val = batching.moveaxis(val, val_dim, 0) return addupdate_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), [] diff --git a/tests/api_test.py b/tests/api_test.py index bd71cb6c5..ee736c4b1 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -44,11 +44,13 @@ import numpy as np import concurrent.futures import jax +import jax.custom_batching +import jax.custom_derivatives +import jax.custom_transpose import jax.numpy as jnp from jax import float0, jit, grad, device_put, jacfwd, jacrev, hessian from jax._src import core from jax import lax -from jax import custom_batching from jax._src import api, dtypes, lib, api_util from jax.errors import UnexpectedTracerError from jax.interpreters import ad @@ -3687,7 +3689,7 @@ class APITest(jtu.JaxTestCase): def test_leak_checker_avoids_false_positive_custom_jvp(self): # see https://github.com/google/jax/issues/5636 with jax.checking_leaks(): - @api.custom_jvp + @jax.custom_jvp def t(y): return y @@ -3926,7 +3928,7 @@ class APITest(jtu.JaxTestCase): def test_backward_pass_ref_dropping(self): refs = [] - @api.custom_vjp + @jax.custom_vjp def f(x): return x def f_fwd(x): @@ -4218,8 +4220,8 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), - ('_policy', partial(api.remat, policy=lambda *_, **__: False)), + ('', jax.remat), + ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_basic(self, remat): @@ -4260,8 +4262,8 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), - ('_policy', partial(api.remat, policy=lambda *_, **__: False)), + ('', jax.remat), + ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_freevars(self, remat): @@ -4284,7 +4286,7 @@ class RematTest(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=False) def test_remat_concrete_error(self): - @api.remat # no static_argnums or concrete + @jax.remat # no static_argnums or concrete def g(x): if x > 0: return lax.sin(x) @@ -4294,7 +4296,7 @@ class RematTest(jtu.JaxTestCase): with self.assertRaisesRegex(core.ConcretizationTypeError, 
"static_argnums"): g(3.) - @partial(api.remat, static_argnums=(0,)) # using static_argnums but... + @partial(jax.remat, static_argnums=(0,)) # using static_argnums but... def g(x): if x > 0: # jnp operations still get staged! return lax.sin(x) @@ -4305,7 +4307,7 @@ class RematTest(jtu.JaxTestCase): g(jnp.array(3.)) # But don't raise an error mentioning static_argnums here: - @api.remat + @jax.remat def g(x): jax.jit(lambda: 0 if jnp.add(1, 1) else 0)() return lax.sin(x) @@ -4317,7 +4319,7 @@ class RematTest(jtu.JaxTestCase): self.assertNotIn('static_argnums', msg) def test_remat_grad_python_control_flow_static_argnums(self): - @partial(api.remat, static_argnums=(0,)) + @partial(jax.remat, static_argnums=(0,)) def g(x): with jax.ensure_compile_time_eval(): x_pos = x > 0 @@ -4339,7 +4341,7 @@ class RematTest(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=False) def test_remat_grad_python_control_flow_unhashable_static_argnums(self): - @partial(api.remat, static_argnums=(0,)) + @partial(jax.remat, static_argnums=(0,)) def g(x): x = x.val with jax.ensure_compile_time_eval(): @@ -4375,7 +4377,7 @@ class RematTest(jtu.JaxTestCase): # the remat-decorated function. count = 0 - @api.remat + @jax.remat def g(x): nonlocal count count += 1 @@ -4399,7 +4401,7 @@ class RematTest(jtu.JaxTestCase): # above test, which doesn't check for static_argnums. count = 0 - @partial(api.remat, static_argnums=(0,)) + @partial(jax.remat, static_argnums=(0,)) def g(x): nonlocal count count += 1 @@ -4422,8 +4424,8 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), - ('_policy', partial(api.remat, policy=lambda *_, **__: False)), + ('', jax.remat), + ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_jit(self, remat): @@ -4450,8 +4452,8 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), - ('_policy', partial(api.remat, policy=lambda *_, **__: False)), + ('', jax.remat), + ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_vmap(self, remat): @@ -4486,8 +4488,8 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), - ('_policy', partial(api.remat, policy=lambda *_, **__: False)), + ('', jax.remat), + ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_vmap_not_leading_dim(self, remat): @@ -4504,8 +4506,8 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), - ('_policy', partial(api.remat, policy=lambda *_, **__: False)), + ('', jax.remat), + ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_higher_order_autodiff(self, remat): @@ -4520,7 +4522,7 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), + ('', jax.remat), ('_new', new_checkpoint), ]) def test_remat_scan(self, remat): @@ -4553,8 +4555,8 @@ 
class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), - ('_policy', partial(api.remat, policy=lambda *_, **__: False)), + ('', jax.remat), + ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_no_redundant_flops(self, remat): @@ -4582,8 +4584,8 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), - ('_policy', partial(api.remat, policy=lambda *_, **__: False)), + ('', jax.remat), + ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_binomial_checkpointing(self, remat): @@ -4604,7 +4606,7 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), + ('', jax.remat), ('_new', new_checkpoint), ]) def test_remat_symbolic_zeros(self, remat): @@ -4637,8 +4639,8 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), - ('_policy', partial(api.remat, policy=lambda *_, **__: False)), + ('', jax.remat), + ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_jit2(self, remat): @@ -4657,7 +4659,7 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), + ('', jax.remat), ('_new', new_checkpoint), ]) def test_remat_nontrivial_env(self, remat): @@ -4690,8 +4692,8 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), - ('_policy', partial(api.remat, policy=lambda *_, **__: False)), + ('', jax.remat), + ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_jit3(self, remat): @@ -4724,7 +4726,7 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), + ('', jax.remat), ('_new', new_checkpoint), ]) def test_remat_scan2(self, remat): @@ -4760,8 +4762,8 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), - ('_policy', partial(api.remat, policy=lambda *_, **__: False)), + ('', jax.remat), + ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_eval_counter(self, remat): @@ -4821,8 +4823,8 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), - ('_policy', partial(api.remat, policy=lambda *_, **__: False)), + ('', jax.remat), + ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_escaped_tracer_remat(self, remat): @@ -4842,8 +4844,8 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": 
remat} for suffix, remat in [ - ('', api.remat), - ('_policy', partial(api.remat, policy=lambda *_, **__: False)), + ('', jax.remat), + ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_no_cse_widget_on_primals(self, remat): @@ -4868,7 +4870,7 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), + ('', jax.remat), ('_new', new_checkpoint), ]) def test_no_cse_widget_with_prevent_cse_false(self, remat): @@ -4892,7 +4894,7 @@ class RematTest(jtu.JaxTestCase): {"testcase_name": f"_{policy_name}_{remat_name}", "remat": remat, "policy": policy, "in_jaxpr2": in_jaxpr2, "not_in_jaxpr2": not_in_jaxpr2} for remat_name, remat in [ - ('old_remat', api.remat), + ('old_remat', jax.remat), ('new_remat', new_checkpoint), ] for policy_name, policy, in_jaxpr2, not_in_jaxpr2 in [ @@ -4919,7 +4921,7 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"_{remat_name}", "remat": remat} for remat_name, remat in [ - ('old_remat', api.remat), + ('old_remat', jax.remat), ('new_remat', new_checkpoint), ]) def test_remat_custom_policy_save_cos(self, remat): @@ -4935,7 +4937,7 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"_{remat_name}", "remat": remat} for remat_name, remat in [ - ('old_remat', api.remat), + ('old_remat', jax.remat), ('new_remat', new_checkpoint), ]) def test_remat_checkpoint_dots(self, remat): @@ -4958,7 +4960,7 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"_{remat_name}", "remat": remat} for remat_name, remat in [ - ('old_remat', api.remat), + ('old_remat', jax.remat), ('new_remat', new_checkpoint), ]) def test_remat_checkpoint_dots_with_no_batch_dims(self, remat): @@ -4981,7 +4983,7 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"_{remat_name}", "remat": remat} for remat_name, remat in [ - ('old_remat', api.remat), + ('old_remat', jax.remat), ('new_remat', new_checkpoint), ]) def test_remat_checkpoint_dots_with_no_batch_dims2(self, remat): @@ -5004,7 +5006,7 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"_{remat_name}", "remat": remat} for remat_name, remat in [ - ('old_remat', api.remat), + ('old_remat', jax.remat), ('new_remat', new_checkpoint), ]) def test_remat_checkpoint_dots_jit(self, remat): @@ -5029,7 +5031,7 @@ class RematTest(jtu.JaxTestCase): x = jnp.ones((5,)) def f(W): - @partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots) + @partial(jax.remat, policy=jax.checkpoint_policies.checkpoint_dots) def f(x): x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST)) x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST)) @@ -5055,7 +5057,7 @@ class RematTest(jtu.JaxTestCase): modes=['fwd', 'rev']) def test_remat_custom_jvp_policy(self): - @api.custom_jvp + @jax.custom_jvp def sin(x): return jnp.sin(x) def sin_jvp(primals, tangents): @@ -5064,7 +5066,7 @@ class RematTest(jtu.JaxTestCase): return sin(x), jnp.cos(x) * g sin.defjvp(sin_jvp) - @partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots) + @partial(jax.remat, policy=jax.checkpoint_policies.checkpoint_dots) def f(x): x = jnp.dot(x, x, precision=lax.Precision.HIGHEST) x = sin(x * 1e-3) @@ -5081,7 +5083,7 @@ class RematTest(jtu.JaxTestCase): jtu.check_grads(g, (3.,), order=2, modes=['fwd', 'rev']) def 
test_remat_custom_vjp_policy(self): - @api.custom_vjp + @jax.custom_vjp def sin(x): return jnp.sin(x) def sin_fwd(x): @@ -5090,7 +5092,7 @@ class RematTest(jtu.JaxTestCase): return (jnp.cos(x) * y_bar,) sin.defvjp(sin_fwd, sin_bwd) - @partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots) + @partial(jax.remat, policy=jax.checkpoint_policies.checkpoint_dots) def f(x): @partial(api.named_call, name="dot") def dot2(y, z): @@ -5114,7 +5116,7 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"_{remat_name}", "remat": remat} for remat_name, remat in [ - ('old_remat', api.remat), + ('old_remat', jax.remat), ('new_remat', new_checkpoint), ]) def test_remat_dropvar_policy(self, remat): @@ -5129,7 +5131,7 @@ class RematTest(jtu.JaxTestCase): api.grad(g)(3.) def test_remat_custom_jvp_linear_policy(self): - @api.custom_jvp + @jax.custom_jvp def sum(x): return jnp.sum(x, axis=0) @sum.defjvp @@ -5137,7 +5139,7 @@ class RematTest(jtu.JaxTestCase): (x,), (xdot,) = primals, tangents return sum(x), sum(xdot) - @partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots) + @partial(jax.remat, policy=jax.checkpoint_policies.checkpoint_dots) def f(x): return sum(x) jtu.check_grads(f, (jnp.ones(3),), order=2, modes=['fwd', 'rev']) @@ -5244,8 +5246,8 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), - ('_policy', partial(api.remat, policy=lambda *_, **__: False)), + ('', jax.remat), + ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_checkpoint_dropvars(self, remat): @@ -5312,7 +5314,7 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), + ('', jax.remat), ('_new', new_checkpoint), ]) def test_remat_of_scan(self, remat): @@ -5327,11 +5329,11 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), + ('', jax.remat), ('_new', new_checkpoint), ]) def test_const_in_jvp_scan(self, remat): - @api.custom_jvp + @jax.custom_jvp def f(x): return x * jnp.arange(3.) @f.defjvp @@ -5393,7 +5395,7 @@ class RematTest(jtu.JaxTestCase): y, _ = lax.scan(lambda x, _: (f(x), None), x, None, length=1) return y - @api.custom_jvp + @jax.custom_jvp def sin(x): return jnp.sin(x) def sin_jvp(primals, tangents): @@ -5451,7 +5453,7 @@ class RematTest(jtu.JaxTestCase): y, _ = lax.scan(lambda x, _: (f(x), None), x, None, length=1) return y - @api.custom_jvp + @jax.custom_jvp def sin(x): return jnp.sin(x) def sin_jvp(primals, tangents): @@ -5505,7 +5507,7 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), + ('', jax.remat), ('_new', new_checkpoint), ]) def test_remat_of_cond(self, remat): @@ -5530,11 +5532,11 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), + ('', jax.remat), ('_new', new_checkpoint), ]) def test_const_in_jvp_cond(self, remat): - @api.custom_jvp + @jax.custom_jvp def f(x): return x * jnp.arange(3.) 
@f.defjvp @@ -5553,7 +5555,7 @@ class RematTest(jtu.JaxTestCase): x = jnp.ones((5,)) def f(W): - @partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots) + @partial(jax.remat, policy=jax.checkpoint_policies.checkpoint_dots) def f(x): x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST)) x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST)) @@ -5618,7 +5620,7 @@ class RematTest(jtu.JaxTestCase): def cond_apply(f, x): return lax.cond(x.sum() > -jnp.inf, f, lambda x: x, x) - @api.custom_jvp + @jax.custom_jvp def sin(x): return jnp.sin(x) def sin_jvp(primals, tangents): @@ -5675,7 +5677,7 @@ class RematTest(jtu.JaxTestCase): def cond_apply(f, x): return lax.cond(True, f, lambda x: x, x) - @api.custom_jvp + @jax.custom_jvp def sin(x): return jnp.sin(x) def sin_jvp(primals, tangents): @@ -5729,7 +5731,7 @@ class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ - ('', api.remat), + ('', jax.remat), ('_new', new_checkpoint), ]) def test_remat_of_while_loop(self, remat): @@ -6142,7 +6144,7 @@ class DCETest(jtu.JaxTestCase): def test_dce_jaxpr_scan_overpruning(self): # This is a regression test for a specific issue. - @api.remat + @jax.remat def scanned_f(c, x): out = jnp.tanh(c * x) return out, out @@ -6159,7 +6161,7 @@ class DCETest(jtu.JaxTestCase): def test_dce_jaxpr_scan_const_in_jvp(self): # The main point of this test is to check for a crash. - @api.custom_jvp + @jax.custom_jvp def f(x): return x * np.arange(3.) @f.defjvp @@ -6260,7 +6262,7 @@ class DCETest(jtu.JaxTestCase): class CustomJVPTest(jtu.JaxTestCase): def test_basic(self): - @api.custom_jvp + @jax.custom_jvp def f(x): return jnp.sin(x) def f_jvp(primals, tangents): @@ -6276,7 +6278,7 @@ class CustomJVPTest(jtu.JaxTestCase): self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x)) def test_invariance(self): - @api.custom_jvp + @jax.custom_jvp def f(x): return jnp.cos(2 * x) / 2. 
def f_jvp(primals, tangents): @@ -6299,7 +6301,7 @@ class CustomJVPTest(jtu.JaxTestCase): check_dtypes=False) def test_python_control_flow(self): - @api.custom_jvp + @jax.custom_jvp def f(x): if x > 0: return jnp.sin(x) @@ -6326,7 +6328,7 @@ class CustomJVPTest(jtu.JaxTestCase): self.assertAllClose(api.grad(f)(-x), 3., check_dtypes=False) def test_vmap(self): - @api.custom_jvp + @jax.custom_jvp def f(x): assert jnp.ndim(x) == 0 return jnp.sin(x) @@ -6361,7 +6363,7 @@ class CustomJVPTest(jtu.JaxTestCase): (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) def test_jit(self): - @api.custom_jvp + @jax.custom_jvp def f(x): return jnp.sin(x) def f_jvp(primals, tangents): @@ -6387,7 +6389,7 @@ class CustomJVPTest(jtu.JaxTestCase): check_dtypes=False) def test_pytrees(self): - @api.custom_jvp + @jax.custom_jvp def f(x): return {'b': jnp.sin(x['a'])} def f_jvp(primals, tangents): @@ -6404,7 +6406,7 @@ class CustomJVPTest(jtu.JaxTestCase): def test_kwargs(self): # from https://github.com/google/jax/issues/1938 - @api.custom_jvp + @jax.custom_jvp def my_fun(x, y, c=1.): return c * (x + y) def my_jvp(primals, tangents): @@ -6417,7 +6419,7 @@ class CustomJVPTest(jtu.JaxTestCase): api.jvp(f, (10., 5.), (1., 1.)) # doesn't crash def test_initial_style(self): - @api.custom_jvp + @jax.custom_jvp def f(x): return 3 * x def f_jvp(primals, tangents): @@ -6459,7 +6461,7 @@ class CustomJVPTest(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=False) def test_initial_style_vmap(self): - @api.custom_jvp + @jax.custom_jvp def f(x): assert jnp.ndim(x) == 0 return 3 * x @@ -6507,7 +6509,7 @@ class CustomJVPTest(jtu.JaxTestCase): def test_initial_style_vmap_with_collective(self): - @api.custom_jvp + @jax.custom_jvp def f(x): return lax.psum(x, 'foo') @@ -6527,7 +6529,7 @@ class CustomJVPTest(jtu.JaxTestCase): def test_closed_over_tracers_error_message(self): def f(x): - @api.custom_jvp + @jax.custom_jvp def g(y): return x + y def g_jvp(primals, tangents): @@ -6539,7 +6541,7 @@ class CustomJVPTest(jtu.JaxTestCase): self.assertRaises(ad.CustomJVPException, lambda: api.grad(f)(3.)) def test_nondiff_arg(self): - @partial(api.custom_jvp, nondiff_argnums=(0,)) + @partial(jax.custom_jvp, nondiff_argnums=(0,)) def app(f, x): return f(x) def app_jvp(f, primals, tangents): @@ -6566,7 +6568,7 @@ class CustomJVPTest(jtu.JaxTestCase): # rule) or (2) static data (e.g. integers which parameterize shapes). 
raise unittest.SkipTest("behavior no longer supported") - @partial(api.custom_jvp, nondiff_argnums=(0,)) + @partial(jax.custom_jvp, nondiff_argnums=(0,)) def f(x, y): return x * y def f_jvp(x, primals, tangents): @@ -6583,7 +6585,7 @@ class CustomJVPTest(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=False) def test_nondiff_arg_vmap_tracer(self): - @partial(api.custom_jvp, nondiff_argnums=(0,)) + @partial(jax.custom_jvp, nondiff_argnums=(0,)) def f(x, y): return x * y def f_jvp(x, primals, tangents): @@ -6600,7 +6602,7 @@ class CustomJVPTest(jtu.JaxTestCase): def test_nondiff_arg_hiding_jvp_tracer(self): def f(x): - @partial(api.custom_jvp, nondiff_argnums=(0,)) + @partial(jax.custom_jvp, nondiff_argnums=(0,)) def g(h, x): return h(x) @g.defjvp @@ -6621,7 +6623,7 @@ class CustomJVPTest(jtu.JaxTestCase): raise unittest.SkipTest("TODO") # TODO(mattjj): write test def test_missing_jvp_rule_error_message(self): - @api.custom_jvp + @jax.custom_jvp def foo(x): return x ** 2 @@ -6639,7 +6641,7 @@ class CustomJVPTest(jtu.JaxTestCase): lambda: api.grad(foo)(2.)) def test_jvp_rule_inconsistent_pytree_structures_error_message(self): - @api.custom_jvp + @jax.custom_jvp def f(x): return (x**2,) @@ -6663,7 +6665,7 @@ class CustomJVPTest(jtu.JaxTestCase): lambda: api.jvp(f, (2.,), (1.,))) def test_primal_tangent_aval_disagreement_error_message(self): - @api.custom_jvp + @jax.custom_jvp def f(x): return x ** 2 @@ -6685,7 +6687,7 @@ class CustomJVPTest(jtu.JaxTestCase): def test_jvp_rule_doesnt_return_pair_error_message(self): # https://github.com/google/jax/issues/2516 - @api.custom_jvp + @jax.custom_jvp def f(x): return x ** 2 @@ -6849,7 +6851,7 @@ class CustomJVPTest(jtu.JaxTestCase): def test_jaxpr_zeros(self): # from https://github.com/google/jax/issues/2657 - @api.custom_jvp + @jax.custom_jvp def f(A, b): return A @ b @@ -6875,7 +6877,7 @@ class CustomJVPTest(jtu.JaxTestCase): grad(experiment)(1.) # doesn't crash def test_linear_in_scan(self): - @api.custom_jvp + @jax.custom_jvp def f(x): return -x @@ -6895,7 +6897,7 @@ class CustomJVPTest(jtu.JaxTestCase): def test_custom_jvps_first_rule_is_none(self): # https://github.com/google/jax/issues/3389 - @api.custom_jvp + @jax.custom_jvp def f(x, y): return x ** 2 * y @@ -6948,7 +6950,7 @@ class CustomJVPTest(jtu.JaxTestCase): def test_fun_with_nested_calls_2(self): def call(f, *args): - f = api.custom_jvp(f) + f = jax.custom_jvp(f) f.defjvp(lambda primals, tangents: (f(*primals), sum(tangents))) return f(*args) @@ -6973,7 +6975,7 @@ class CustomJVPTest(jtu.JaxTestCase): alpha = np.float32(2.) 
def sample(seed): - @api.custom_jvp + @jax.custom_jvp def f(alpha): return jax.random.gamma(seed, alpha, shape=[]) @@ -7013,7 +7015,7 @@ class CustomJVPTest(jtu.JaxTestCase): @unittest.skipIf(numpy_version == (1, 21, 0), "https://github.com/numpy/numpy/issues/19305") def test_float0(self): - @api.custom_jvp + @jax.custom_jvp def f(x, y): return x, y def f_jvp(primals, _): @@ -7030,7 +7032,7 @@ class CustomJVPTest(jtu.JaxTestCase): @unittest.skipIf(numpy_version == (1, 21, 0), "https://github.com/numpy/numpy/issues/19305") def test_float0_initial_style(self): - @api.custom_jvp + @jax.custom_jvp def f(x, y): return x, y def f_jvp(primals, _): @@ -7049,7 +7051,7 @@ class CustomJVPTest(jtu.JaxTestCase): (primals, expected_tangents)) def test_remat(self): - @api.custom_jvp + @jax.custom_jvp def f(x): return jnp.sin(x) def f_jvp(primals, tangents): @@ -7058,7 +7060,7 @@ class CustomJVPTest(jtu.JaxTestCase): return f(x), 2 * jnp.cos(x) * g f.defjvp(f_jvp) - @api.remat + @jax.remat def g(x): return f(f(x)) @@ -7071,7 +7073,7 @@ class CustomJVPTest(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=False) def test_remat_higher_order(self): - @api.custom_jvp + @jax.custom_jvp def f(x): return jnp.sin(x) def f_jvp(primals, tangents): @@ -7100,7 +7102,7 @@ class CustomJVPTest(jtu.JaxTestCase): # over an array constant. y = jnp.arange(1., 4.) - @api.custom_jvp + @jax.custom_jvp def f(x): assert jnp.ndim(x) == 0 return 3 * x * jnp.sum(y) @@ -7154,7 +7156,7 @@ class CustomJVPTest(jtu.JaxTestCase): def test_custom_jvp_vmap_broadcasting_interaction_2(self): # https://github.com/google/jax/issues/5849 - @api.custom_jvp + @jax.custom_jvp def transform(box, R): if jnp.isscalar(box) or box.size == 1: return R * box @@ -7383,7 +7385,7 @@ class CustomJVPTest(jtu.JaxTestCase): return out return _fun - f = api.custom_jvp(f) + f = jax.custom_jvp(f) @partial(f.defjvp, symbolic_zeros=True) def f_jvp(primals, tangents): @@ -7424,7 +7426,7 @@ class CustomJVPTest(jtu.JaxTestCase): self.assertArraysAllClose(tangent_out2, jnp.array([99., 100.])) def test_symbolic_zero_custom_jvp_vmap_output(self): - @api.custom_jvp + @jax.custom_jvp def f(x, y): return x * y @@ -7441,7 +7443,7 @@ class CustomJVPTest(jtu.JaxTestCase): class CustomVJPTest(jtu.JaxTestCase): def test_basic(self): - @api.custom_vjp + @jax.custom_vjp def f(x): return jnp.sin(x) def f_fwd(x): @@ -7457,7 +7459,7 @@ class CustomVJPTest(jtu.JaxTestCase): (jnp.sin(x), 2 * jnp.cos(x))) def test_invariance(self): - @api.custom_vjp + @jax.custom_vjp def f(x): return jnp.cos(2 * x) / 2. 
def f_fwd(x): @@ -7480,7 +7482,7 @@ class CustomVJPTest(jtu.JaxTestCase): check_dtypes=False) def test_python_control_flow(self): - @api.custom_vjp + @jax.custom_vjp def f(x): if x > 0: return jnp.sin(x) @@ -7506,7 +7508,7 @@ class CustomVJPTest(jtu.JaxTestCase): check_dtypes=False) def test_vmap(self): - @api.custom_vjp + @jax.custom_vjp def f(x): assert jnp.ndim(x) == 0 return jnp.sin(x) @@ -7543,7 +7545,7 @@ class CustomVJPTest(jtu.JaxTestCase): 2 * jnp.cos(xx)) def test_jit(self): - @api.custom_vjp + @jax.custom_vjp def f(x): return jnp.sin(x) def f_fwd(x): @@ -7567,7 +7569,7 @@ class CustomVJPTest(jtu.JaxTestCase): check_dtypes=False) def test_pytrees(self): - @api.custom_vjp + @jax.custom_vjp def f(x): return {'b': jnp.sin(x['a'])} def f_fwd(x): @@ -7582,7 +7584,7 @@ class CustomVJPTest(jtu.JaxTestCase): {'a': 2 * jnp.cos(x['a'])}) def test_jvp_error(self): - @api.custom_vjp + @jax.custom_vjp def f(x): return jnp.sin(x) def f_fwd(x): @@ -7606,7 +7608,7 @@ class CustomVJPTest(jtu.JaxTestCase): def test_kwargs(self): # from https://github.com/google/jax/issues/1938 - @api.custom_vjp + @jax.custom_vjp def my_fun(x, y, c=1.): return c * (x + y) my_fun.defvjp(lambda x, y, c=1.: (my_fun(c, y, c), None), @@ -7616,7 +7618,7 @@ class CustomVJPTest(jtu.JaxTestCase): api.grad(f)(10., 5.) # doesn't crash def test_initial_style(self): - @api.custom_vjp + @jax.custom_vjp def f(x): return jnp.sin(x) def f_fwd(x): @@ -7638,7 +7640,7 @@ class CustomVJPTest(jtu.JaxTestCase): self.assertAllClose(ans, expected) def test_initial_style_vmap(self): - @api.custom_vjp + @jax.custom_vjp def f(x): assert jnp.ndim(x) == 0 return 3 * x @@ -7661,7 +7663,7 @@ class CustomVJPTest(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=False) def test_nondiff_arg(self): - @partial(api.custom_vjp, nondiff_argnums=(0,)) + @partial(jax.custom_vjp, nondiff_argnums=(0,)) def app(f, x): return f(x) def app_fwd(f, x): @@ -7687,7 +7689,7 @@ class CustomVJPTest(jtu.JaxTestCase): # tracers in nondiff_argnums to greatly simplify bookkeeping while still # supporting the cases for which it is necessary. def outer(x): - @api.custom_vjp + @jax.custom_vjp def f(y): return x * y def f_fwd(y): @@ -7711,7 +7713,7 @@ class CustomVJPTest(jtu.JaxTestCase): def test_closed_over_vmap_tracer(self): def outer(x): - @api.custom_vjp + @jax.custom_vjp def f(y): return x * y def f_fwd(y): @@ -7731,7 +7733,7 @@ class CustomVJPTest(jtu.JaxTestCase): def test_closed_over_tracer3(self): def outer(x): - @api.custom_vjp + @jax.custom_vjp def f(y): return x * y def f_fwd(y): @@ -7754,7 +7756,7 @@ class CustomVJPTest(jtu.JaxTestCase): # This is similar to the old (now skipped) test_nondiff_arg_tracer, except # we're testing for the error message that that usage pattern now raises. 
- @partial(api.custom_vjp, nondiff_argnums=(0,)) + @partial(jax.custom_vjp, nondiff_argnums=(0,)) def f(x, y): return x * y def f_fwd(x, y): @@ -7779,7 +7781,7 @@ class CustomVJPTest(jtu.JaxTestCase): raise unittest.SkipTest("TODO") # TODO(mattjj): write test def test_missing_vjp_rule_error(self): - @api.custom_vjp + @jax.custom_vjp def foo(x): return x ** 2 @@ -7793,7 +7795,7 @@ class CustomVJPTest(jtu.JaxTestCase): lambda: api.grad(foo)(2.)) def test_vjp_rule_inconsistent_pytree_structures_error(self): - @api.custom_vjp + @jax.custom_vjp def f(x): return x @@ -7820,7 +7822,7 @@ class CustomVJPTest(jtu.JaxTestCase): lambda: api.grad(f)(2.)) def test_vjp_bwd_returns_non_tuple_error(self): - @api.custom_vjp + @jax.custom_vjp def f(x): return x @@ -7915,7 +7917,7 @@ class CustomVJPTest(jtu.JaxTestCase): def test_clip_gradient(self): # https://github.com/google/jax/issues/2784 - @api.custom_vjp + @jax.custom_vjp def _clip_gradient(lo, hi, x): return x # identity function when not differentiating @@ -7941,7 +7943,7 @@ class CustomVJPTest(jtu.JaxTestCase): def f(x): return x ** 2 - @api.custom_vjp + @jax.custom_vjp def g(x): return f(x) @@ -7974,7 +7976,7 @@ class CustomVJPTest(jtu.JaxTestCase): x = jnp.ones((10, 3)) # Create the custom function - @api.custom_vjp + @jax.custom_vjp def custom_fun(x): return x.sum() @@ -8005,7 +8007,7 @@ class CustomVJPTest(jtu.JaxTestCase): # over an array constant. y = jnp.arange(1., 4.) - @api.custom_vjp + @jax.custom_vjp def f(x): assert jnp.ndim(x) == 0 return 3 * x * jnp.sum(y) @@ -8029,7 +8031,7 @@ class CustomVJPTest(jtu.JaxTestCase): def test_initial_style_vmap_with_collective(self): - @api.custom_vjp + @jax.custom_vjp def f(x): return lax.psum(x, 'foo') @@ -8113,7 +8115,7 @@ class CustomVJPTest(jtu.JaxTestCase): @unittest.skipIf(numpy_version == (1, 21, 0), "https://github.com/numpy/numpy/issues/19305") def test_float0(self): - @api.custom_vjp + @jax.custom_vjp def f(x, _): return x def f_fwd(x, _): @@ -8131,7 +8133,7 @@ class CustomVJPTest(jtu.JaxTestCase): @unittest.skipIf(numpy_version == (1, 21, 0), "https://github.com/numpy/numpy/issues/19305") def test_float0_initial_style(self): - @api.custom_vjp + @jax.custom_vjp def f(x): return x def f_fwd(x): @@ -8150,7 +8152,7 @@ class CustomVJPTest(jtu.JaxTestCase): (2., np.zeros(shape=(), dtype=float0))) def test_remat(self): - @api.custom_vjp + @jax.custom_vjp def f(x): return jnp.sin(x) def f_fwd(x): @@ -8159,7 +8161,7 @@ class CustomVJPTest(jtu.JaxTestCase): return (2 * cos_x * g,) f.defvjp(f_fwd, f_rev) - @api.remat + @jax.remat def g(x): return f(f(x)) @@ -8172,7 +8174,7 @@ class CustomVJPTest(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=False) def test_remat_higher_order(self): - @api.custom_vjp + @jax.custom_vjp def f(x): return jnp.sin(x) def f_fwd(x): @@ -8184,20 +8186,20 @@ class CustomVJPTest(jtu.JaxTestCase): def g(x): return f(f(x)) - ans = api.grad(api.grad(api.remat(g)))(2.) + ans = api.grad(api.grad(jax.remat(g)))(2.) expected = api.grad(api.grad(g))(2.) self.assertAllClose(ans, expected, check_dtypes=False) - ans = api.grad(api.remat(api.grad(g)))(2.) + ans = api.grad(jax.remat(api.grad(g)))(2.) expected = api.grad(api.grad(g))(2.) self.assertAllClose(ans, expected, check_dtypes=False) - ans = api.grad(api.grad(api.grad(api.remat(g))))(2.) + ans = api.grad(api.grad(api.grad(jax.remat(g))))(2.) expected = api.grad(api.grad(api.grad(g)))(2.) 
self.assertAllClose(ans, expected, check_dtypes=False) def test_bwd_nones(self): - @api.custom_vjp + @jax.custom_vjp def f(x, y): return x * jnp.sin(y) def f_fwd(x, y): @@ -8211,7 +8213,7 @@ class CustomVJPTest(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=False) def test_bwd_nones_vmap(self): - @api.custom_vjp + @jax.custom_vjp def f(x, y): return x * jnp.sin(y) def f_fwd(x, y): @@ -8225,7 +8227,7 @@ class CustomVJPTest(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=False) def test_bwd_nones_pytree(self): - @api.custom_vjp + @jax.custom_vjp def f(xs, y): x1, x2 = xs return x1 * x2 * jnp.sin(y) @@ -8241,7 +8243,7 @@ class CustomVJPTest(jtu.JaxTestCase): def test_custom_vjp_closure_4521(self): # https://github.com/google/jax/issues/4521 - @api.custom_vjp + @jax.custom_vjp def g(x, y): return None def g_fwd(x, y): @@ -8264,7 +8266,7 @@ class CustomVJPTest(jtu.JaxTestCase): lax.scan(scan_body, jnp.ones(5), None, 100) # doesn't crash def test_float0_bwd_none(self): - @api.custom_vjp + @jax.custom_vjp def f(i, x): return jnp.sin(x) def f_fwd(i, x): @@ -8278,7 +8280,7 @@ class CustomVJPTest(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=False) def test_custom_gradient(self): - @api.custom_gradient + @jax.custom_gradient def f(x): return x ** 2, lambda g: (g * x,) @@ -8287,7 +8289,7 @@ class CustomVJPTest(jtu.JaxTestCase): self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False) def test_custom_gradient_2(self): - @api.custom_gradient + @jax.custom_gradient def f(x, y): return x * y, lambda g: (y, x) @@ -8296,7 +8298,7 @@ class CustomVJPTest(jtu.JaxTestCase): check_dtypes=False) def test_custom_gradient_3(self): - @api.custom_gradient + @jax.custom_gradient def f(x): vjp = lambda g: (jnp.cos(x) * jnp.arange(3., 6.),) return jnp.sum(jnp.sin(x)), vjp @@ -8309,7 +8311,7 @@ class CustomVJPTest(jtu.JaxTestCase): check_dtypes=False) def test_custom_gradient_can_return_singleton_value_in_vjp(self): - @api.custom_gradient + @jax.custom_gradient def f(x): return x ** 2, lambda g: g * x @@ -8323,7 +8325,7 @@ class CustomVJPTest(jtu.JaxTestCase): self.assertLessEqual(len(aux_args), 1) return _cos_after(converted_fn, x, *aux_args) - @partial(api.custom_vjp, nondiff_argnums=(0,)) + @partial(jax.custom_vjp, nondiff_argnums=(0,)) def _cos_after(fn, x, *args): return jnp.cos(fn(x, *args)) @@ -8364,7 +8366,7 @@ class CustomVJPTest(jtu.JaxTestCase): self.assertLessEqual(len(aux_args), 1) return _cos_after(converted_fn, x, *aux_args) - @partial(api.custom_vjp, nondiff_argnums=(0,)) + @partial(jax.custom_vjp, nondiff_argnums=(0,)) def _cos_after(fn, x, *args): return jnp.cos(fn(x, *args)) @@ -8502,13 +8504,13 @@ def transpose_unary(f, x_example): return transposed -# This class wraps api.custom_transpose in order to pass in a +# This class wraps jax.custom_transpose.custom_transpose in order to pass in a # particular tree of output type on each call. Otherwise it forwards # all attribute access. 
class _custom_transpose: def __init__(self, out_types, fun): self.out_types = out_types - self.fun = api.custom_transpose(fun) + self.fun = jax.custom_transpose.custom_transpose(fun) def __getattr__(self, name): return getattr(self.fun, name) @@ -8540,7 +8542,7 @@ class CustomTransposeTest(jtu.JaxTestCase): def f(x, y): def fn(r, x): return x / r def tp(r, t): return t / r - return x + api.linear_call(fn, tp, y, x) + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) def f_ref(x, y): return x + x / y @@ -8558,7 +8560,7 @@ class CustomTransposeTest(jtu.JaxTestCase): def f(x, y): def fn(r, x): return x / r def tp(r, t): return t / (2. * r) # nb: not the true transpose - return x + api.linear_call(fn, tp, y, x) + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) def f_ref(x, y): return x + x / y @@ -8576,7 +8578,7 @@ class CustomTransposeTest(jtu.JaxTestCase): def fn(r, x): return x / r def tp(r, t): return t / (2. * r) # nb: untrue transpose def f_(x, y): - return x + api.linear_call(fn, tp, y, x) + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) x = jnp.ones(2) * 6. y = jnp.ones(2) * 3. @@ -8597,7 +8599,7 @@ class CustomTransposeTest(jtu.JaxTestCase): t1, t2 = t return t1 + t2 - return api.linear_call(fn, tp, (), c * x) + return jax.custom_derivatives.linear_call(fn, tp, (), c * x) def f_ref(c, x): return [c * x, c * x] @@ -8613,7 +8615,7 @@ class CustomTransposeTest(jtu.JaxTestCase): def id_(x): def f(_, x): return x def t(_, t): return 0. - return api.linear_call(f, t, (), x) + return jax.custom_derivatives.linear_call(f, t, (), x) # identity function with an untrue transpose of 7, and where both # forward and transpose have custom transpositions that should @@ -8621,7 +8623,7 @@ class CustomTransposeTest(jtu.JaxTestCase): def f(x): def f_(_, x): return id_(x) def t_(_, t): return id_(7.) - return api.linear_call(f_, t_, (), x) + return jax.custom_derivatives.linear_call(f_, t_, (), x) x = 5. id_t = transpose_unary(id_, x) @@ -8643,7 +8645,7 @@ class CustomTransposeTest(jtu.JaxTestCase): def f(x, y): def fn(r, x): return x / r def tp(r, t): return t / r - return x + api.linear_call(fn, tp, y, x) + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) x = jnp.ones(2) * 6. y = jnp.ones(2) * 3. @@ -9009,7 +9011,7 @@ class CustomTransposeTest(jtu.JaxTestCase): class CustomVmapTest(jtu.JaxTestCase): def test_basic(self): - @api.custom_vmap + @jax.custom_batching.custom_vmap def f(x): return jnp.sin(x) @f.def_vmap @@ -9029,7 +9031,7 @@ class CustomVmapTest(jtu.JaxTestCase): def test_closure(self): z = jnp.array([2., 1., 3.]) - @api.custom_vmap + @jax.custom_batching.custom_vmap def f(x): return z + jnp.sin(x) @f.def_vmap @@ -9049,7 +9051,7 @@ class CustomVmapTest(jtu.JaxTestCase): self.assertAllClose(ys, z + jnp.cos(xs)) def test_rule_multi_output(self): - @api.custom_vmap + @jax.custom_batching.custom_vmap def f(x): return jnp.sin(x), jnp.cos(x) @f.def_vmap @@ -9065,7 +9067,7 @@ class CustomVmapTest(jtu.JaxTestCase): self.assertAllClose(ys2, jnp.sin(xs)) def test_nary(self): - @api.custom_vmap + @jax.custom_batching.custom_vmap def f(x, y): return jnp.sin(x) + y ** 2. @f.def_vmap @@ -9081,7 +9083,7 @@ class CustomVmapTest(jtu.JaxTestCase): self.assertAllClose(zs, jnp.cos(xs) + ys ** 2.) 
def test_nary_mixed_batching(self): - @api.custom_vmap + @jax.custom_batching.custom_vmap def vector_dot(u, v): self.assertEqual(u.ndim, 1) self.assertEqual(v.ndim, 1) @@ -9131,7 +9133,7 @@ class CustomVmapTest(jtu.JaxTestCase): self.assertEqual(in_batched_log[3], [True, True]) def test_rule_input_signature(self): - @api.custom_vmap + @jax.custom_batching.custom_vmap def f(x): return jnp.sin(x) rule_args = [] @@ -9149,7 +9151,7 @@ class CustomVmapTest(jtu.JaxTestCase): self.assertEqual(len(in_batched), 1) def test_rule_output_vs_batching_output_mismatch(self): - @api.custom_vmap + @jax.custom_batching.custom_vmap def f(x): return jnp.sin(x) @f.def_vmap @@ -9164,7 +9166,7 @@ class CustomVmapTest(jtu.JaxTestCase): lambda: api.vmap(f)(xs)) def test_rule_vs_call_output_mismatch(self): - @api.custom_vmap + @jax.custom_batching.custom_vmap def f(x): return jnp.sin(x) @f.def_vmap @@ -9179,7 +9181,7 @@ class CustomVmapTest(jtu.JaxTestCase): lambda: api.vmap(f)(xs)) def test_jvp_basic(self): - @api.custom_vmap + @jax.custom_batching.custom_vmap def f(x): return jnp.sin(x) @f.def_vmap @@ -9210,7 +9212,7 @@ class CustomVmapTest(jtu.JaxTestCase): z = jnp.array([2., 1., 3.]) def bcast(x): return z + x - z - @api.custom_vmap + @jax.custom_batching.custom_vmap def f(x): return z + jnp.sin(x) @f.def_vmap @@ -9237,7 +9239,7 @@ class CustomVmapTest(jtu.JaxTestCase): self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs) def test_jvp_nary(self): - @api.custom_vmap + @jax.custom_batching.custom_vmap def f(x, y): return jnp.sin(x) + y @f.def_vmap @@ -9260,7 +9262,7 @@ class CustomVmapTest(jtu.JaxTestCase): self.assertAllClose(tzs, -jnp.sin(xs) * txs + tys) def test_jvp_extra_batched_tangents(self): - @api.custom_vmap + @jax.custom_batching.custom_vmap def f(x): return jnp.sin(x) @f.def_vmap @@ -9280,7 +9282,7 @@ class CustomVmapTest(jtu.JaxTestCase): def test_jacfwd(self): # jacfwd is another way to exercise extra-batched tangents - @api.custom_vmap + @jax.custom_batching.custom_vmap def f(x): return jnp.sin(x) @f.def_vmap @@ -9294,7 +9296,7 @@ class CustomVmapTest(jtu.JaxTestCase): self.assertAllClose(j, -jnp.diag(jnp.sin(x))) def test_jvp_extra_batched_primals(self): - @api.custom_vmap + @jax.custom_batching.custom_vmap def f(x): return jnp.sin(x) @f.def_vmap @@ -9320,14 +9322,14 @@ class CustomVmapTest(jtu.JaxTestCase): # this test checks that vmapped JVPs continue to behave this way # when custom_vmap is involved and the custom vmap rule is linear. - @api.custom_vmap + @jax.custom_batching.custom_vmap def f_linear(x): return 7. * x @f_linear.def_vmap def linear_rule(axis_size, in_batched, xs): return 11. * xs, in_batched[0] - @api.custom_vmap + @jax.custom_batching.custom_vmap def f_nonlinear(x): return jnp.sin(x) @f_nonlinear.def_vmap @@ -9357,7 +9359,7 @@ class CustomVmapTest(jtu.JaxTestCase): # depend on input tangents, extra-batched input tangents can # create batched output primals, as this test checks. 
- @api.custom_jvp + @jax.custom_jvp def cos_with_invalid_dataflow_jvp(x): return jnp.cos(x) @cos_with_invalid_dataflow_jvp.defjvp @@ -9365,7 +9367,7 @@ class CustomVmapTest(jtu.JaxTestCase): [x], [tx] = x, tx return jnp.cos(x * tx), tx - @api.custom_vmap + @jax.custom_batching.custom_vmap def f(x): return jnp.sin(x) @f.def_vmap @@ -9396,7 +9398,7 @@ class CustomVmapTest(jtu.JaxTestCase): xs = (xs, [xs + 1, xs + 2], [xs + 3], xs + 4) in_batched_ref = tree_util.tree_map(lambda _: True, x) - @api.custom_vmap + @jax.custom_batching.custom_vmap def f(xs): return tree_sin(xs) @f.def_vmap @@ -9420,7 +9422,7 @@ class CustomVmapTest(jtu.JaxTestCase): xs = (xs, [xs + 1, None], [xs + 3], None) in_batched_ref = tree_util.tree_map(lambda _: True, x) - @api.custom_vmap + @jax.custom_batching.custom_vmap def f(xs): return tree_sin(xs) @f.def_vmap @@ -9436,7 +9438,7 @@ class CustomVmapTest(jtu.JaxTestCase): self.assertAllClose(ys, tree_cos(xs)) def test_jit(self): - @api.custom_vmap + @jax.custom_batching.custom_vmap def f(x): return jnp.sin(x) @f.def_vmap @@ -9451,7 +9453,7 @@ class CustomVmapTest(jtu.JaxTestCase): self.assertAllClose(api.vmap(jit(f))(xs), api.vmap(f)(xs)) def test_sequential_vmap_basic(self): - @custom_batching.sequential_vmap + @jax.custom_batching.sequential_vmap def f(x): return x + 1. @@ -9465,7 +9467,7 @@ class CustomVmapTest(jtu.JaxTestCase): self.assertEqual(str(jaxpr), str(jaxpr_ref)) def test_sequential_vmap_nary_same_batching(self): - @custom_batching.sequential_vmap + @jax.custom_batching.sequential_vmap def f(x, y): return x + y @@ -9479,7 +9481,7 @@ class CustomVmapTest(jtu.JaxTestCase): self.assertEqual(str(jaxpr), str(jaxpr_ref)) def test_sequential_vmap_nary_mixed_batching(self): - @custom_batching.sequential_vmap + @jax.custom_batching.sequential_vmap def f(x, y): return x + y @@ -9497,9 +9499,9 @@ class CustomApiTest(jtu.JaxTestCase): """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs""" def test_method_forwarding(self): - @api.custom_vmap - @api.custom_jvp - @api.custom_transpose + @jax.custom_batching.custom_vmap + @jax.custom_jvp + @jax.custom_transpose.custom_transpose def f(x): return 2. * x # none of these err: @@ -9512,7 +9514,7 @@ class CustomApiTest(jtu.JaxTestCase): def test_def_method_forwarding_all_permutations(self): for wraps in it.permutations([ - api.custom_jvp, api.custom_transpose, api.custom_vmap]): + jax.custom_jvp, jax.custom_transpose.custom_transpose, jax.custom_batching.custom_vmap]): f = lambda x: x + 1. for wrap in wraps: f = wrap(f) @@ -9521,7 +9523,7 @@ class CustomApiTest(jtu.JaxTestCase): self.assertIsInstance(getattr(f, method), Callable) for decorators in it.permutations([ - api.custom_vjp, api.custom_transpose, api.custom_vmap]): + jax.custom_vjp, jax.custom_transpose.custom_transpose, jax.custom_batching.custom_vmap]): f = lambda x: x + 1. for decorator in decorators: f = decorator(f)