Avoid imports from the public jax.* namespace in more places internally.

This change prepares for further cycle breaking in the Bazel dependency graph.

PiperOrigin-RevId: 521822756
Peter Hawkins, 2023-04-04 11:41:00 -07:00 (committed by jax authors)
parent 3c1f3abba2
commit c1f65fc8b2
24 changed files with 486 additions and 480 deletions
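
The change follows one mechanical pattern across all of the files below; as a hedged illustration (the function `f` is hypothetical, and the `broadcast` rewrite mirrors one of the hunks further down), internal modules stop importing through the public package and import the implementing `jax._src` module directly:

    # Before: an internal module reaches through the public package, creating
    # an import cycle jax -> jax._src.<module> -> jax.
    import jax

    def f(x):
        return jax.lax.broadcast(x, (2,))

    # After: import the implementing module under jax._src directly, so the
    # module's Bazel target no longer depends on the top-level jax package.
    from jax._src.lax import lax as lax_internal

    def f(x):
        return lax_internal.broadcast(x, (2,))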

View File

@ -76,7 +76,7 @@ del _xc
from jax._src.api import effects_barrier as effects_barrier
from jax._src.api import block_until_ready as block_until_ready
from jax._src.api import checkpoint as checkpoint
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
from jax._src.api import clear_backends as clear_backends
from jax._src.custom_derivatives import closure_convert as closure_convert
@ -116,8 +116,8 @@ from jax._src.api import named_scope as named_scope
from jax._src.api import pmap as pmap
from jax._src.xla_bridge import process_count as process_count
from jax._src.xla_bridge import process_index as process_index
from jax._src.api import pure_callback as pure_callback
from jax._src.api import remat as remat
from jax._src.callback import pure_callback_api as pure_callback
from jax._src.ad_checkpoint import checkpoint_wrapper as remat
from jax._src.core import ShapedArray as _deprecated_ShapedArray
from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
from jax._src.api import value_and_grad as value_and_grad

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from functools import partial
import logging
from typing import (Any, Callable, FrozenSet, List, Optional, Sequence, Tuple,
@ -20,9 +21,8 @@ import types
import numpy as np
import jax
from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr
from jax._src import ad_util
from jax._src import api
from jax._src import core
from jax._src import dispatch
from jax._src import linear_util as lu
@ -31,6 +31,7 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.api_util import flatten_fun, shaped_abstractify
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -39,6 +40,7 @@ from jax._src.lax import lax as lax_internal
from jax._src.lax import convolution as lax_convolution
from jax._src.lib.mlir.dialects import hlo
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
safe_zip, merge_lists, weakref_lru_cache)
@ -389,7 +391,7 @@ def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
args, kwargs = tree_unflatten(in_tree, args)
return f(*args, **kwargs)
out = jax.make_jaxpr(lambda *args: jax.linearize(f_, *args)[1],
out = api.make_jaxpr(lambda *args: api.linearize(f_, *args)[1],
return_shape=True)(*in_leaves)
assert isinstance(out, tuple)
jaxpr_, out_shape = out
@ -522,7 +524,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:])
res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:]
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None)
logger.log(logging.WARNING if jax.config.jax_log_checkpoint_residuals
logger.log(logging.WARNING if config.jax_log_checkpoint_residuals
else logging.DEBUG,
'remat-decorated function ' +
'saving inputs with shapes:\n' * bool(res_invars) +
@ -652,7 +654,7 @@ def remat_lowering(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
assert not jaxpr.constvars
if differentiated and prevent_cse:
if jax.config.jax_remat_opt_barrier:
if config.jax_remat_opt_barrier:
translation_rule = _remat_translation_using_opt_barrier
elif is_gpu_platform:
translation_rule = _remat_translation_using_while
@ -661,7 +663,7 @@ def remat_lowering(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
else:
translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args)
return jax.named_call(translation_rule, name="remat")(*args, jaxpr=jaxpr)
return api.named_call(translation_rule, name="remat")(*args, jaxpr=jaxpr)
def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
args = _optimization_barrier(args)
@ -670,9 +672,9 @@ def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
# TODO(mattjj): add core utility for 'create dummy value for this type'?
def _dummy_like(aval: core.AbstractValue) -> Any:
if aval is core.abstract_token:
return jax.lax.create_token()
return lax_internal.create_token()
elif isinstance(aval, (core.ShapedArray, core.DShapedArray)):
return jax.lax.broadcast(lax_internal.empty(aval.dtype), aval.shape) # type: ignore
return lax_internal.broadcast(lax_internal.empty(aval.dtype), aval.shape) # type: ignore
else:
raise ValueError(aval)
@ -682,11 +684,13 @@ def _remat_translation_using_while(*args, jaxpr: core.Jaxpr):
# result = eval_jaxpr(*args)
# }
# The loop carry is a tuple: (counter, result, args)
from jax._src.lax import control_flow as lax_control_flow
avals_out = tuple(v.aval for v in jaxpr.outvars)
carry_init = (np.int32(0), tuple(map(_dummy_like, avals_out)), args)
def cond(carry):
counter, _, _ = carry
unif = jax.lax.rng_uniform(np.int32(1), np.int32(2), shape=())
unif = lax_internal.rng_uniform(np.int32(1), np.int32(2), shape=())
return counter < unif
def body(carry):
@ -694,7 +698,7 @@ def _remat_translation_using_while(*args, jaxpr: core.Jaxpr):
results = core.eval_jaxpr(jaxpr, (), *args)
return (counter + 1, tuple(results), args)
carry_res = jax.lax.while_loop(cond, body, carry_init)
carry_res = lax_control_flow.while_loop(cond, body, carry_init)
return carry_res[1]
def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr):
@ -703,6 +707,8 @@ def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr):
# return eval_jaxpr(*args)
# else:
# return 0
from jax._src.lax import control_flow as lax_control_flow
avals_out = tuple(v.aval for v in jaxpr.outvars)
def remat_comp(*args):
@ -710,8 +716,8 @@ def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr):
def dummy_comp(*args):
return tuple(map(_dummy_like, avals_out))
unif = jax.lax.rng_uniform(np.float32(0), np.float32(1), shape=())
return jax.lax.cond(unif < np.float32(2), remat_comp, dummy_comp, *args)
unif = lax_internal.rng_uniform(np.float32(0), np.float32(1), shape=())
return lax_control_flow.cond(unif < np.float32(2), remat_comp, dummy_comp, *args)
mlir.register_lowering(
remat_p, mlir.lower_fun(remat_lowering, multiple_results=True))
@ -760,3 +766,63 @@ def name_batcher(args, dims, *, name):
(x,), (d,) = args, dims
return name_p.bind(x, name=name), d
batching.primitive_batchers[name_p] = name_batcher
@functools.wraps(checkpoint)
def checkpoint_wrapper(
fun: Callable,
*,
concrete: bool = False,
prevent_cse: bool = True,
static_argnums: Union[int, Tuple[int, ...]] = (),
policy: Optional[Callable[..., bool]] = None,
) -> Callable:
if concrete:
msg = ("The 'concrete' option to jax.checkpoint / jax.remat is deprecated; "
"in its place, you can use its `static_argnums` option, and if "
"necessary the `jax.ensure_compile_time_eval()` context manager.\n"
"\n"
"For example, if using `concrete=True` for an `is_training` flag:\n"
"\n"
" from functools import partial\n"
"\n"
" @partial(jax.checkpoint, concrete=True)\n"
" def foo(x, is_training):\n"
" if is_training:\n"
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n"
"replace it with a use of `static_argnums`:\n"
"\n"
" @partial(jax.checkpoint, static_argnums=(1,))\n"
" def foo(x, is_training):\n"
" ...\n"
"\n"
"If jax.numpy operations need to be performed on static arguments, "
"we can use the `jax.ensure_compile_time_eval()` context manager. "
"For example, we can replace this use of `concrete=True`\n:"
"\n"
" @partial(jax.checkpoint, concrete=True)\n"
" def foo(x, y):\n"
" if y > 0:\n"
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n"
"with this combination of `static_argnums` and "
"`jax.ensure_compile_time_eval()`:\n"
"\n"
" @partial(jax.checkpoint, static_argnums=(1,))\n"
" def foo(x, y):\n"
" with jax.ensure_compile_time_eval():\n"
" y_pos = y > 0\n"
" if y_pos:\n"
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n"
"See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html\n")
raise NotImplementedError(msg)
return checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
static_argnums=static_argnums)
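
For reference, a minimal runnable sketch of the migration this message prescribes; `f` and `g` are placeholder branch bodies, not part of the change:

    from functools import partial
    import jax
    import jax.numpy as jnp

    def f(x):          # placeholder
        return jnp.sin(x)

    def g(x):          # placeholder
        return jnp.cos(x)

    # Instead of concrete=True, mark the flag argument as static.
    @partial(jax.checkpoint, static_argnums=(1,))
    def foo(x, is_training):
        return f(x) if is_training else g(x)

    # If the branch depends on data, resolve it at trace time.
    @partial(jax.checkpoint, static_argnums=(1,))
    def bar(x, y):
        with jax.ensure_compile_time_eval():
            y_pos = y > 0
        return f(x) if y_pos else g(x)

    print(jax.grad(foo)(1.0, True))  # cos(1.0)
    print(jax.grad(bar)(1.0, 2.0))   # cos(1.0)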

View File

@ -23,7 +23,6 @@ arrays.
from __future__ import annotations
import collections
import functools
from functools import partial
import inspect
import math
@ -35,14 +34,12 @@ from typing import (Any, Callable, Generator, Hashable, Iterable, List, Literal,
import numpy as np
from contextlib import contextmanager, ExitStack
import jax
from jax._src import linear_util as lu
from jax import stages
from jax._src import stages
from jax._src.tree_util import (
tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose,
tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix,
prefix_errors, generate_key_paths, _replace_nones)
from jax._src import callback as jcb
prefix_errors, generate_key_paths)
from jax._src import core
from jax._src import dispatch
from jax._src import effects
@ -57,26 +54,18 @@ from jax._src.api_util import (
flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
shaped_abstractify, _ensure_str_tuple, argnames_partial_except,
validate_argnames, validate_argnums, check_callable, resolve_argnums,
debug_info, result_paths, flat_out_axes, debug_info_final, FLAGS)
shaped_abstractify, _ensure_str_tuple,
check_callable, debug_info, result_paths, flat_out_axes, debug_info_final, FLAGS)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.sharding import Sharding
from jax._src.sharding_impls import PmapSharding
from jax._src.traceback_util import api_boundary
from jax._src.util import (unzip2, curry, safe_map, safe_zip, split_list,
wrap_name, cache, wraps, HashableFunction,
weakref_lru_cache)
from jax._src.util import (unzip2, safe_map, safe_zip, wrap_name, wraps)
# Unused imports to be exported
from jax.ad_checkpoint import checkpoint as new_checkpoint
from jax.custom_batching import custom_vmap
from jax.custom_derivatives import (custom_gradient, custom_jvp,
custom_vjp, linear_call)
from jax.custom_transpose import custom_transpose
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
@ -89,7 +78,7 @@ from jax._src.config import (
_thread_local_state as config_thread_local_state,
explicit_device_put_scope as config_explicit_device_put_scope,
explicit_device_get_scope as config_explicit_device_get_scope)
from jax._src.core import ShapedArray, raise_to_shaped
from jax._src.core import ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import pxla
@ -1022,10 +1011,11 @@ def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
argnums, has_aux=has_aux, holomorphic=holomorphic)
def _std_basis(pytree):
import jax.numpy as jnp
leaves, _ = tree_flatten(pytree)
ndim = sum(map(np.size, leaves))
dtype = dtypes.result_type(*leaves)
flat_basis = jax.numpy.eye(ndim, dtype=dtype)
flat_basis = jnp.eye(ndim, dtype=dtype)
return _unravel_array_into_pytree(pytree, 1, None, flat_basis)
def _jacfwd_unravel(input_pytree, output_pytree_leaf, arr):
@ -2487,8 +2477,8 @@ def _infer_src_sharding(src, x):
def device_put(
x,
device: Union[None, xc.Device, jax.sharding.Sharding, Any] = None,
*, src: Union[None, xc.Device, jax.sharding.Sharding, Any] = None):
device: Union[None, xc.Device, Sharding, Any] = None,
*, src: Union[None, xc.Device, Sharding, Any] = None):
"""Transfers ``x`` to ``device``.
Args:
@ -2512,8 +2502,8 @@ def device_put(
blocking the calling Python thread until any transfers are completed.
"""
with config_explicit_device_put_scope():
if ((device is None or isinstance(device, (xc.Device, jax.sharding.Sharding))) and
(src is None or isinstance(src, (xc.Device, jax.sharding.Sharding)))):
if ((device is None or isinstance(device, (xc.Device, Sharding))) and
(src is None or isinstance(src, (xc.Device, Sharding)))):
return tree_map(
lambda y: dispatch.device_put_p.bind(
y, device=device, src=_infer_src_sharding(src, y)), x)
@ -2641,7 +2631,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
assert (isinstance(aval, ShapedArray) and
len(xla.aval_to_xla_shapes(aval)) == 1)
sharding_spec = pxla._create_pmap_sharding_spec(aval)
buf = jax.device_put(x, devices[0])
buf = device_put(x, devices[0])
return pxla.batched_device_put(
aval, PmapSharding(np.array(devices), sharding_spec),
[buf] * len(devices), devices)
@ -2719,7 +2709,7 @@ class ShapeDtypeStruct:
raise ValueError("ShapeDtypeStruct: dtype must be specified.")
self.dtype = dtype if core.is_opaque_dtype(dtype) else np.dtype(dtype)
if sharding is not None:
if not isinstance(sharding, jax.sharding.Sharding):
if not isinstance(sharding, Sharding):
raise ValueError(
"sharding should be an instance of `jax.sharding.Sharding`. "
f"Got {sharding} of type {type(sharding)}.")
@ -2821,65 +2811,6 @@ def eval_shape(fun: Callable, *args, **kwargs):
return tree_unflatten(out_tree(), out)
@functools.wraps(new_checkpoint) # config.jax_new_checkpoint is True by default
def checkpoint(fun: Callable, *,
concrete: bool = False,
prevent_cse: bool = True,
static_argnums: Union[int, Tuple[int, ...]] = (),
policy: Optional[Callable[..., bool]] = None,
) -> Callable:
if concrete:
msg = ("The 'concrete' option to jax.checkpoint / jax.remat is deprecated; "
"in its place, you can use its `static_argnums` option, and if "
"necessary the `jax.ensure_compile_time_eval()` context manager.\n"
"\n"
"For example, if using `concrete=True` for an `is_training` flag:\n"
"\n"
" from functools import partial\n"
"\n"
" @partial(jax.checkpoint, concrete=True)\n"
" def foo(x, is_training):\n"
" if is_training:\n"
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n"
"replace it with a use of `static_argnums`:\n"
"\n"
" @partial(jax.checkpoint, static_argnums=(1,))\n"
" def foo(x, is_training):\n"
" ...\n"
"\n"
"If jax.numpy operations need to be performed on static arguments, "
"we can use the `jax.ensure_compile_time_eval()` context manager. "
"For example, we can replace this use of `concrete=True`\n:"
"\n"
" @partial(jax.checkpoint, concrete=True)\n"
" def foo(x, y):\n"
" if y > 0:\n"
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n"
"with this combination of `static_argnums` and "
"`jax.ensure_compile_time_eval()`:\n"
"\n"
" @partial(jax.checkpoint, static_argnums=(1,))\n"
" def foo(x, y):\n"
" with jax.ensure_compile_time_eval():\n"
" y_pos = y > 0\n"
" if y_pos:\n"
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n"
"See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html\n")
raise NotImplementedError(msg)
return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
static_argnums=static_argnums)
remat = checkpoint # type: ignore
def named_call(
fun: Callable[..., Any],
*,
@ -2986,68 +2917,15 @@ def block_until_ready(x):
return x.block_until_ready()
except AttributeError:
return x
return jax.tree_util.tree_map(try_to_block, x)
return tree_map(try_to_block, x)
def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
*args: Any, vectorized: bool = False, **kwargs: Any):
"""Applies a functionally pure Python callable. Works under :func:`jit`/:func:`~pmap`/etc.
``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
The input ``callback`` will be passed NumPy arrays in place of JAX arrays and
should also return NumPy arrays. Execution takes place on CPU, like any
Python+NumPy function.
The callback is treated as functionally pure, meaning it has no side-effects
and its output value depends only on its argument values. As a consequence, it
is safe to be called multiple times (e.g. when transformed by :func:`~vmap` or
:func:`~pmap`), or not to be called at all when e.g. the output of a
`jit`-decorated function has no data dependence on its value. Pure callbacks
may also be reordered if data-dependence allows.
When :func:`~pmap`-ed, the pure callback will be called several times (one on each
axis of the map). When `vmap`-ed the behavior will depend on the value of the
``vectorized`` keyword argument. When ``vectorized`` is ``True``, the callback
is assumed to obey
``jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])``.
Therefore, the callback will be called directly on batched inputs (where the
batch axes are the leading dimensions). Additionally, the callbacks should
return outputs that have corresponding leading batch axes. If not vectorized
``callback`` will be mapped sequentially across the batched axis.
For example, if ``callback = lambda x, y: np.matmul(x, y)``, then we are free
to set ``vectorized=True`` because the ``np.matmul`` function handles
arbitrary leading batch dimensions.
Args:
callback: A Python callable. The callable will be passed PyTrees of NumPy
arrays as arguments, and should return a PyTree of NumPy arrays that
matches ``result_shape_dtypes``.
result_shape_dtypes: A PyTree with leaves that are objects with ``shape``
and ``dtype`` attributes which represent the shapes and dtypes of the
value of ``callback`` applied to ``args`` and ``kwargs``.
*args: The positional arguments to the callback. Must be PyTrees of JAX
types.
vectorized: A boolean that indicates whether or not ``callback`` is
vectorized, meaning it can handle arrays with additional leading
dimensions. If ``vectorized`` is `True`, when the callback is mapped
via `jax.vmap`, it will be called directly on inputs with leading batch
dimensions instead of executing ``callback`` on each mapped input
individually. The callback should also return outputs batched across the
leading axis. By default, ``vectorized`` is ``False``.
**kwargs: The keyword arguments to the callback. Must be PyTrees of JAX
types.
Returns:
The value of ``callback(*args, **kwargs)``.
"""
return jcb.pure_callback(callback, result_shape_dtypes, *args,
vectorized=vectorized, **kwargs)
def clear_backends():
"""
Clear all backend clients so that new backend clients can be created later.
"""
xb._clear_backends()
jax.lib.xla_bridge._backends = {}
xb._backends = {}
dispatch.xla_primitive_callable.cache_clear()
pjit._pjit_lower_cached.cache_clear()
pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error
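
A short hedged sketch of how `clear_backends` above is typically used (the effect on later calls is an expectation under the usual caching behavior, not something shown in this diff):

    import jax
    import jax.numpy as jnp

    jnp.zeros(3)          # first use initializes a backend client
    jax.clear_backends()  # drop cached backend clients and compilation caches
    jnp.zeros(3)          # a fresh backend client is created on next use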

View File

@ -21,19 +21,17 @@ import functools
from typing import (Sequence, Tuple, Callable, Optional, List, cast, Set,
TYPE_CHECKING)
import jax
from jax._src import abstract_arrays
from jax._src import api
from jax._src import api_util
from jax._src import basearray
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import profiler
from jax._src import xla_bridge
from jax._src.config import config
from jax._src.util import use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
from jax._src import api
from jax._src.typing import ArrayLike
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
@ -41,6 +39,8 @@ from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
device_replica_id_map, hashed_index)
from jax._src.typing import ArrayLike
from jax._src.util import use_cpp_class, use_cpp_method
Shape = Tuple[int, ...]
Device = xc.Device
@ -133,7 +133,7 @@ def _create_copy_plan(arrays, s: Sharding, shape: Shape):
@functools.lru_cache(maxsize=4096)
def _process_has_full_value_in_mcjax(s, shape):
# Return False for single host as a fast path.
if jax.process_count() == 1:
if xla_bridge.process_count() == 1:
return False
num_unique_indices = len(
@ -359,7 +359,7 @@ class ArrayImpl(basearray.Array):
return np.asarray(self._value, dtype=dtype)
def __dlpack__(self):
from jax.dlpack import to_dlpack # pylint: disable=g-import-not-at-top
from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top
return to_dlpack(self)
def __reduce__(self):

View File

@ -19,17 +19,17 @@ from typing import Any, Callable, Sequence
import numpy as np
from jax import tree_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import tree_util
from jax._src import util
from jax._src import dispatch
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lib import xla_client as xc
from jax._src.lax.control_flow.loops import map as lax_map
# `pure_callback_p` is the main primitive for staging out Python pure callbacks.
pure_callback_p = core.Primitive("pure_callback")
@ -93,7 +93,6 @@ def pure_callback_batching_rule(args, dims, *, callback, vectorized: bool,
return pure_callback_p.bind(
*merged_args, callback=callback, result_avals=result_avals,
vectorized=vectorized)
from jax._src.lax.control_flow import map as lax_map
outvals = lax_map(_batch_fun, batched_args)
return tuple(outvals), (0,) * len(outvals)
@ -154,6 +153,62 @@ def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
return tree_util.tree_unflatten(out_tree, out_flat)
def pure_callback_api(callback: Callable[..., Any], result_shape_dtypes: Any,
*args: Any, vectorized: bool = False, **kwargs: Any):
"""Applies a functionally pure Python callable. Works under :func:`jit`/:func:`~pmap`/etc.
``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
The input ``callback`` will be passed NumPy arrays in place of JAX arrays and
should also return NumPy arrays. Execution takes place on CPU, like any
Python+NumPy function.
The callback is treated as functionally pure, meaning it has no side-effects
and its output value depends only on its argument values. As a consequence, it
is safe to be called multiple times (e.g. when transformed by :func:`~vmap` or
:func:`~pmap`), or not to be called at all when e.g. the output of a
`jit`-decorated function has no data dependence on its value. Pure callbacks
may also be reordered if data-dependence allows.
When :func:`~pmap`-ed, the pure callback will be called several times (one on each
axis of the map). When `vmap`-ed the behavior will depend on the value of the
``vectorized`` keyword argument. When ``vectorized`` is ``True``, the callback
is assumed to obey
``jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])``.
Therefore, the callback will be called directly on batched inputs (where the
batch axes are the leading dimensions). Additionally, the callbacks should
return outputs that have corresponding leading batch axes. If not vectorized
``callback`` will be mapped sequentially across the batched axis.
For example, if ``callback = lambda x, y: np.matmul(x, y)``, then we are free
to set ``vectorized=True`` because the ``np.matmul`` function handles
arbitrary leading batch dimensions.
Args:
callback: A Python callable. The callable will be passed PyTrees of NumPy
arrays as arguments, and should return a PyTree of NumPy arrays that
matches ``result_shape_dtypes``.
result_shape_dtypes: A PyTree with leaves that are objects with ``shape``
and ``dtype`` attributes which represent the shapes and dtypes of the
value of ``callback`` applied to ``args`` and ``kwargs``.
*args: The positional arguments to the callback. Must be PyTrees of JAX
types.
vectorized: A boolean that indicates whether or not ``callback`` is
vectorized, meaning it can handle arrays with additional leading
dimensions. If ``vectorized`` is `True`, when the callback is mapped
via `jax.vmap`, it will be called directly on inputs with leading batch
dimensions instead of executing ``callback`` on each mapped input
individually. The callback should also return outputs batched across the
leading axis. By default, ``vectorized`` is ``False``.
**kwargs: The keyword arguments to the callback. Must be PyTrees of JAX
types.
Returns:
The value of ``callback(*args, **kwargs)``.
"""
return pure_callback(callback, result_shape_dtypes, *args,
vectorized=vectorized, **kwargs)
# IO Callback
io_callback_p = core.Primitive("io_callback")
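
The docstring above defines the `vectorized` contract; a hedged usage sketch through the public `jax.pure_callback` entry point, which this change rebinds to `pure_callback_api` (`host_matmul` and `f` are illustrative names):

    import jax
    import jax.numpy as jnp
    import numpy as np

    def host_matmul(x, y):
        # Runs on the host with NumPy arrays; np.matmul already handles
        # arbitrary leading batch dimensions.
        return np.matmul(x, y)

    @jax.jit
    def f(x, y):
        out = jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype)
        # vectorized=True: under vmap the callback is invoked once on the
        # whole batch instead of once per element.
        return jax.pure_callback(host_matmul, out, x, y, vectorized=True)

    x = jnp.ones((3, 4), jnp.float32)
    y = jnp.ones((4, 5), jnp.float32)
    print(f(x, y).shape)              # (3, 5)
    xb = jnp.ones((2, 3, 4), jnp.float32)
    yb = jnp.ones((2, 4, 5), jnp.float32)
    print(jax.vmap(f)(xb, yb).shape)  # (2, 3, 5)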

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import dataclasses
import functools
@ -20,28 +21,28 @@ from typing import (Union, Optional, Callable, Dict, Tuple, TypeVar,
import numpy as np
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import lax
from jax.api_util import flatten_fun
from jax.experimental import pjit
from jax.tree_util import tree_flatten
from jax.tree_util import tree_map
from jax.tree_util import tree_unflatten
from jax._src import api
from jax._src import linear_util as lu
from jax._src import core
from jax._src import custom_derivatives
from jax._src import effects
from jax._src import pjit
from jax._src import prng
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util as jtu
from jax._src.api_util import flatten_fun
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import tree_flatten
from jax._src.tree_util import tree_map
from jax._src.tree_util import tree_unflatten
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
unzip3, weakref_lru_cache)
@ -92,7 +93,7 @@ class JaxException(Exception):
del payload
return cls(metadata)
def get_effect_type(self) -> core.Effect:
def get_effect_type(self) -> ErrorEffect:
raise NotImplementedError
@ -100,7 +101,7 @@ class JaxException(Exception):
@dataclasses.dataclass(eq=True, frozen=True)
class ErrorEffect(effects.Effect):
error_type: Type[JaxException]
shape_dtypes: Tuple[jax.ShapeDtypeStruct, ...]
shape_dtypes: Tuple[api.ShapeDtypeStruct, ...]
def __lt__(self, other: 'ErrorEffect'):
shape_dtypes = lambda x: tuple((sd.shape, str(sd.dtype)) # dtype is not comparable
@ -161,7 +162,7 @@ class OOBError(JaxException):
f'Failed at {self.traceback_info}')
def get_effect_type(self):
return ErrorEffect(OOBError, (jax.ShapeDtypeStruct((3,), jnp.int32),))
return ErrorEffect(OOBError, (api.ShapeDtypeStruct((3,), jnp.int32),))
class FailedCheckError(JaxException):
@ -188,7 +189,7 @@ class FailedCheckError(JaxException):
vals = jtu.tree_leaves((self.args, self.kwargs))
return ErrorEffect(
FailedCheckError,
tuple(jax.ShapeDtypeStruct(x.shape, x.dtype) for x in vals))
tuple(api.ShapeDtypeStruct(x.shape, x.dtype) for x in vals))
@dataclasses.dataclass
class BatchedError(JaxException):
@ -1112,7 +1113,7 @@ def _check(pred, msg, debug, *fmt_args, **fmt_kwargs):
prim_name = 'debug_check' if debug else 'check'
raise TypeError(f'{prim_name} takes a scalar pred as argument, got {pred}')
for arg in jtu.tree_leaves((fmt_args, fmt_kwargs)):
if not isinstance(arg, (jax.Array, np.ndarray)):
if not isinstance(arg, (Array, np.ndarray)):
raise TypeError('Formatting arguments to checkify.check need to be '
'PyTrees of arrays, but got '
f'{repr(arg)} of type {type(arg)}.')
@ -1130,7 +1131,7 @@ def _check_error(error, *, debug=False):
def is_scalar_pred(pred) -> bool:
return (isinstance(pred, bool) or
isinstance(pred, jax.Array) and pred.shape == () and
isinstance(pred, Array) and pred.shape == () and
pred.dtype == jnp.dtype('bool'))
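
The checks above require a scalar boolean predicate and array-valued format arguments; a hedged usage sketch of the public API (`safe_log` is an illustrative name):

    import jax.numpy as jnp
    from jax.experimental import checkify

    def safe_log(x):
        # pred must be a scalar bool; format args must be arrays.
        checkify.check(jnp.all(x > 0), "expected positive values, got {x}", x=x)
        return jnp.log(x)

    checked_log = checkify.checkify(safe_log)
    err, out = checked_log(jnp.array([1.0, 2.0]))
    err.throw()                                   # no error
    err, _ = checked_log(jnp.array([-1.0, 2.0]))
    print(err.get())                              # reports the failed check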

View File

@ -16,15 +16,14 @@ import functools
import operator
from typing import Callable, Optional
import jax
from jax import tree_util
from jax.tree_util import (tree_flatten, tree_map, tree_structure,
tree_unflatten, treedef_tuple)
from jax import lax
from jax._src import api
from jax._src import core
from jax._src import custom_api_util
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util
from jax._src import util
from jax._src.api_util import flatten_fun_nokwargs
from jax._src.interpreters import ad
@ -33,6 +32,8 @@ from jax._src.interpreters.batching import not_mapped
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.tree_util import (tree_flatten, tree_map, tree_structure,
tree_unflatten, treedef_tuple)
source_info_util.register_exclusion(__file__)
@ -194,7 +195,7 @@ def custom_vmap_jvp(primals, tangents, *, call, rule, in_tree, out_tree):
return out
def to_vmap_over_extra_batched_dims(primals, tangents):
return jax.jvp(to_jvp, primals, tangents)
return api.jvp(to_jvp, primals, tangents)
to_vmap_over_extra_batched_dims_flat, out_tree2 = flatten_fun_nokwargs(
lu.wrap_init(to_vmap_over_extra_batched_dims),
@ -274,7 +275,7 @@ def sequential_vmap(f):
return f(*args)
mapped_args, bcast_args = tree_split(in_batched, list(args))
out = jax.lax.map(to_map, mapped_args)
out = lax.map(to_map, mapped_args)
out_batched = tree_map(lambda _: True, out)
return out, out_batched

View File

@ -16,15 +16,9 @@ from functools import update_wrapper, reduce, partial
import inspect
from typing import (Callable, Generic, List, Optional, Sequence, Tuple, TypeVar, Any)
from jax.custom_transpose import custom_transpose
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
treedef_is_leaf, treedef_tuple,
register_pytree_node_class, tree_leaves)
from jax.errors import UnexpectedTracerError
from jax.config import config
from jax._src import core
from jax._src import custom_api_util
from jax._src.custom_transpose import custom_transpose
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
@ -32,7 +26,9 @@ from jax._src import traceback_util
from jax._src.ad_util import (Zero, SymbolicZero, zeros_like_aval,
stop_gradient_p)
from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
from jax._src.config import config
from jax._src.core import raise_to_shaped
from jax._src.errors import UnexpectedTracerError
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -40,6 +36,9 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.interpreters.batching import not_mapped
from jax._src.lax import lax
from jax._src.tree_util import (tree_flatten, tree_unflatten, tree_map,
treedef_is_leaf, treedef_tuple,
register_pytree_node_class, tree_leaves)
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable

View File

@ -15,8 +15,6 @@
import functools
from typing import Any, Callable, Optional, Tuple
from jax.tree_util import (tree_flatten, tree_leaves, tree_map,
tree_structure, treedef_tuple, tree_unflatten)
from jax._src import ad_util
from jax._src import api_util
from jax._src import core
@ -29,6 +27,8 @@ from jax._src.interpreters import ad
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.tree_util import (tree_flatten, tree_leaves, tree_map,
tree_structure, treedef_tuple, tree_unflatten)
source_info_util.register_exclusion(__file__)

View File

@ -22,23 +22,24 @@ import weakref
import numpy as np
import jax.numpy as jnp
from jax import tree_util
from jax import lax
from jax._src import core
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import pjit
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding import Sharding
from jax._src.sharding_impls import GSPMDSharding, NamedSharding
from jax._src.interpreters import partial_eval as pe
# pytype: disable=import-error
try:

View File

@ -30,10 +30,6 @@ import warnings
import numpy as np
import jax
from jax.monitoring import record_event_duration_secs
from jax._src import array
from jax._src import core
from jax._src import dtypes
from jax._src import linear_util as lu
@ -53,6 +49,7 @@ from jax._src.interpreters import xla
from jax._src.interpreters import pxla
from jax._src.lib.mlir import ir
from jax._src.lib import xla_client as xc
from jax._src.monitoring import record_event_duration_secs
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding, NamedSharding,
@ -145,7 +142,7 @@ class RuntimeTokenSet(threading.local):
self.output_runtime_tokens = {}
def get_token(self, eff: core.Effect, device: Device) -> RuntimeToken:
s = jax.sharding.SingleDeviceSharding(device)
s = SingleDeviceSharding(device)
if eff not in self.tokens:
inp = np.zeros(0, np.bool_)
indices = tuple(
@ -302,7 +299,7 @@ class SourceInfo(NamedTuple):
def jaxpr_shardings(
jaxpr) -> Iterator[Tuple[jax.sharding.XLACompatibleSharding, SourceInfo]]:
jaxpr) -> Iterator[Tuple[XLACompatibleSharding, SourceInfo]]:
from jax._src import pjit
from jax.experimental import shard_map
@ -570,8 +567,9 @@ def _put_x(x, s: Sharding, aval: core.AbstractValue, committed: bool):
def _device_put_impl(
x,
device: Optional[Union[Device, jax.sharding.Sharding]] = None,
src: Optional[Union[Device, jax.sharding.Sharding]] = None):
device: Optional[Union[Device, Sharding]] = None,
src: Optional[Union[Device, Sharding]] = None):
from jax._src import array
try:
aval = xla.abstractify(x)
except TypeError as err:

View File

@ -16,13 +16,13 @@ import warnings
import numpy as np
from jax import lax
import jax.numpy as jnp
from jax._src import dtypes
from jax._src.tree_util import tree_flatten, tree_unflatten
from jax._src.util import safe_zip, unzip2, HashablePartial
import jax.numpy as jnp
from jax._src import dtypes
from jax import lax
zip = safe_zip

View File

@ -17,7 +17,6 @@ from functools import partial
import operator
import jax
from jax import lax
from jax.tree_util import (tree_flatten, treedef_children, tree_leaves,
tree_unflatten, treedef_tuple)
from jax._src import ad_util
@ -28,6 +27,7 @@ from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.lax import lax
from jax._src.traceback_util import api_boundary
from jax._src.util import split_list, safe_map
import numpy as np

View File

@ -20,35 +20,42 @@ from typing import (Callable, Iterable, Tuple, Optional, Dict, Any, Set,
NamedTuple, Union, Sequence)
from functools import wraps, partial, partialmethod, lru_cache
from jax import lax
from jax import numpy as jnp
from jax._src import core
from jax._src import mesh
from jax._src import linear_util as lu
from jax import stages
from jax._src import dispatch
from jax._src import effects
from jax.tree_util import (tree_flatten, tree_unflatten, all_leaves, tree_map,
treedef_tuple)
from jax._src import mesh
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import stages
from jax._src import traceback_util
from jax._src.api_util import (flatten_fun_nokwargs, flatten_axes,
_ensure_index_tuple, donation_vector,
shaped_abstractify, check_callable)
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.config import config
from jax.errors import JAXTypeError
from jax._src.array import ArrayImpl
from jax._src.sharding_impls import NamedSharding
from jax._src.config import config
from jax._src.errors import JAXTypeError
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters.partial_eval import (
trace_to_subjaxpr_dynamic, DynamicJaxprTracer,
convert_constvars_jaxpr, new_jaxpr_eqn)
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.pjit import (
sharding_constraint_p, ParsedPartitionSpec, get_unconstrained_dims,
GSPMDSharding)
from jax._src.sharding_impls import NamedSharding
from jax._src.tree_util import (tree_flatten, tree_unflatten, all_leaves,
tree_map, treedef_tuple)
from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3,
as_hashable_function, distributed_debug_log,
tuple_insert, moveaxis, split_list, wrap_name,
merge_lists, partition_list)
from jax import lax
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)
@ -965,9 +972,6 @@ pxla.custom_resource_typing_rules[xmap_p] = _resource_typing_xmap
# This is DynamicJaxprTrace.process_map with some very minor modifications
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
from jax._src.interpreters.partial_eval import (
trace_to_subjaxpr_dynamic, DynamicJaxprTracer,
convert_constvars_jaxpr, new_jaxpr_eqn)
assert primitive is xmap_p
in_avals = [t.aval for t in tracers]
global_axis_sizes = params['global_axis_sizes']
@ -1775,10 +1779,6 @@ def _check_no_loop_collectives(jaxpr, loop_axis_resources):
def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
from jax._src.pjit import (
sharding_constraint_p, ParsedPartitionSpec, get_unconstrained_dims,
GSPMDSharding)
rec = lambda jaxpr: _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name)
if isinstance(jaxpr, core.ClosedJaxpr):
return jaxpr.map_jaxpr(rec)

View File

@ -25,7 +25,8 @@ import numpy as np
from jax._src import core
from jax._src import dtypes
from jax._src.api import jit, custom_jvp
from jax._src.api import jit
from jax._src.custom_derivatives import custom_jvp
from jax._src.lax import lax
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import (

View File

@ -24,23 +24,10 @@ from functools import partial, lru_cache
import threading
import warnings
import jax
from jax._src import core
from jax import stages
from jax.errors import JAXTypeError
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters.pxla import PartitionSpec
from jax._src.interpreters import xla
from jax._src.tree_util import (
tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
treedef_tuple, broadcast_prefix, all_leaves)
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
NamedSharding, XLACompatibleSharding, GSPMDSharding,
XLADeviceAssignment, SingleDeviceSharding, PmapSharding)
from jax._src import stages
from jax._src import dispatch
from jax._src import mesh
from jax._src import mesh as mesh_lib
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import traceback_util
@ -50,6 +37,11 @@ from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
donation_vector, shaped_abstractify, check_callable, resolve_argnums,
argnames_partial_except, debug_info, result_paths, jaxpr_debug_info, FLAGS)
from jax._src.errors import JAXTypeError
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters.pxla import PartitionSpec
from jax._src.interpreters import xla
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@ -58,8 +50,15 @@ from jax._src.interpreters import pxla
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import xla_client as xc
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
NamedSharding, XLACompatibleSharding, GSPMDSharding,
XLADeviceAssignment, SingleDeviceSharding, PmapSharding)
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import prefix_errors, generate_key_paths
from jax._src.tree_util import (
tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
treedef_tuple, broadcast_prefix, all_leaves,
prefix_errors, generate_key_paths)
from jax._src.util import (
HashableFunction, safe_map, safe_zip, wraps,
distributed_debug_log, split_list, tuple_insert, weakref_lru_cache,
@ -311,7 +310,7 @@ def _resolve_axis_resources_and_shardings_arg(
def pre_infer_params(fun, in_shardings, out_shardings,
donate_argnums, static_argnums, static_argnames, device,
backend, abstracted_axes):
if abstracted_axes and not jax.config.jax_dynamic_shapes:
if abstracted_axes and not config.jax_dynamic_shapes:
raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
check_callable(fun)
@ -455,7 +454,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
dyn_kwargs = ()
del kwargs
if donate_argnums and not jax.config.jax_debug_nans:
if donate_argnums and not config.jax_debug_nans:
donated_invars = donation_vector(donate_argnums, dyn_args, dyn_kwargs)
else:
donated_invars = (False,) * len(explicit_args)
@ -724,7 +723,7 @@ def pjit(
def infer_params(*args, **kwargs):
# Putting this outside of wrapped would make resources lexically scoped
resource_env = mesh.thread_resources.env
resource_env = mesh_lib.thread_resources.env
pjit_info_args = PjitInfo(
fun=fun, in_shardings=in_shardings,
out_shardings=out_shardings, static_argnums=static_argnums,
@ -1125,7 +1124,7 @@ def _check_unique_resources(axis_resources, arg_name):
if multiple_uses:
raise ValueError(f"A single {arg_name} specification can map every mesh axis "
f"to at most one positional dimension, but {arg_axis_resources.user_spec} "
f"has duplicate entries for {mesh.show_axes(multiple_uses)}")
f"has duplicate entries for {mesh_lib.show_axes(multiple_uses)}")
# -------------------- pjit rules --------------------
@ -1812,7 +1811,7 @@ def _check_resources_against_named_axes(what, aval, pos_axis_resources, named_ax
f"{pos_axis_resources.unsynced_user_spec(SpecSync.DIM_PERMUTE)} "
f"that uses one or more mesh axes already used by xmap to partition "
f"a named axis appearing in its named_shape (both use mesh axes "
f"{mesh.show_axes(overlap)})")
f"{mesh_lib.show_axes(overlap)})")
def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_resources):
jaxpr = params["jaxpr"]
@ -1918,7 +1917,7 @@ def with_sharding_constraint(x, shardings=_UNSPECIFIED,
flatten_axes("with_sharding_constraint shardings", tree, user_shardings))
del user_shardings
resource_env = jax._src.mesh.thread_resources.env
resource_env = mesh_lib.thread_resources.env
mesh = resource_env.physical_mesh
shardings_flat = [_create_sharding_for_array(mesh, a)

View File

@ -21,18 +21,20 @@ from typing import Any, Callable, Hashable, Iterator, NamedTuple, Sequence, Unio
import numpy as np
import jax
from jax import lax
from jax import numpy as jnp
from jax.config import config
from jax.dtypes import float0
from jax._src import api
from jax._src import basearray
from jax._src import config as config_lib
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import pretty_printer as pp
from jax._src import typing
from jax._src.api import jit, vmap
from jax._src.config import config
from jax._src.dtypes import float0
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -97,7 +99,7 @@ class PRNGImpl(NamedTuple):
# -- PRNG key arrays
def _check_prng_key_data(impl, key_data: jax.Array):
def _check_prng_key_data(impl, key_data: typing.Array):
ndim = len(impl.key_shape)
if not all(hasattr(key_data, attr) for attr in ['ndim', 'shape', 'dtype']):
raise TypeError("JAX encountered invalid PRNG key data: expected key_data "
@ -139,7 +141,7 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
"""
impl: PRNGImpl
_base_array: jax.Array
_base_array: typing.Array
def __init__(self, impl, key_data: Any):
assert not isinstance(key_data, core.Tracer)
@ -512,7 +514,7 @@ mlir.register_constant_handler(PRNGKeyArray, key_array_constant_handler)
def iterated_vmap_unary(n, f):
for _ in range(n):
f = jax.vmap(f)
f = api.vmap(f)
return f
# TODO(frostig): Revise the following two functions? These basically
@ -528,7 +530,7 @@ def squeeze_vmap(f, left):
else:
y = jnp.squeeze(y, axis=0)
axes = (0, None)
return jax.vmap(f, in_axes=axes, out_axes=0)(x, y)
return api.vmap(f, in_axes=axes, out_axes=0)(x, y)
return squeeze_vmap_f
def iterated_vmap_binary_bcast(shape1, shape2, f):
@ -543,7 +545,7 @@ def iterated_vmap_binary_bcast(shape1, shape2, f):
assert len(shape1) == len(shape2)
for sz1, sz2 in reversed(zip(shape1, shape2)):
if sz1 == sz2:
f = jax.vmap(f, out_axes=0)
f = api.vmap(f, out_axes=0)
else:
assert sz1 == 1 or sz2 == 1, (sz1, sz2)
f = squeeze_vmap(f, sz1 == 1)
@ -785,14 +787,14 @@ mlir.register_lowering(random_unwrap_p, random_unwrap_lowering)
# -- threefry2x32 PRNG implementation
def _is_threefry_prng_key(key: jax.Array) -> bool:
def _is_threefry_prng_key(key: typing.Array) -> bool:
try:
return key.shape == (2,) and key.dtype == np.uint32
except AttributeError:
return False
def threefry_seed(seed: jax.Array) -> jax.Array:
def threefry_seed(seed: typing.Array) -> typing.Array:
"""Create a single raw threefry PRNG key from an integer seed.
Args:
@ -811,7 +813,7 @@ def threefry_seed(seed: jax.Array) -> jax.Array:
convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
k1 = convert(
lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
with jax.numpy_dtype_promotion('standard'):
with config_lib.numpy_dtype_promotion('standard'):
# TODO(jakevdp): in X64 mode, this can generate 64-bit computations for 32-bit
# inputs. We should avoid this.
k2 = convert(jnp.bitwise_and(seed, np.uint32(0xFFFFFFFF)))
@ -1090,26 +1092,26 @@ def threefry_2x32(keypair, count):
return lax.reshape(out[:-1] if odd_size else out, count.shape)
def threefry_split(key: jax.Array, num: int) -> jax.Array:
def threefry_split(key: typing.Array, num: int) -> typing.Array:
if config.jax_threefry_partitionable:
return _threefry_split_foldlike(key, int(num)) # type: ignore
else:
return _threefry_split_original(key, int(num)) # type: ignore
@partial(jit, static_argnums=(1,), inline=True)
def _threefry_split_original(key, num) -> jax.Array:
def _threefry_split_original(key, num) -> typing.Array:
counts = lax.iota(np.uint32, num * 2)
return lax.reshape(threefry_2x32(key, counts), (num, 2))
@partial(jit, static_argnums=(1,), inline=True)
def _threefry_split_foldlike(key, num) -> jax.Array:
def _threefry_split_foldlike(key, num) -> typing.Array:
k1, k2 = key
counts1, counts2 = iota_2x32_shape((num,))
bits1, bits2 = threefry2x32_p.bind(k1, k2, counts1, counts2)
return jnp.stack([bits1, bits2], axis=1)
def threefry_fold_in(key: jax.Array, data: jax.Array) -> jax.Array:
def threefry_fold_in(key: typing.Array, data: typing.Array) -> typing.Array:
assert not data.shape
return _threefry_fold_in(key, jnp.uint32(data))
@ -1118,7 +1120,7 @@ def _threefry_fold_in(key, data):
return threefry_2x32(key, threefry_seed(data))
def threefry_random_bits(key: jax.Array, bit_width, shape):
def threefry_random_bits(key: typing.Array, bit_width, shape):
"""Sample uniform random bits of given width and shape using PRNG key."""
if not _is_threefry_prng_key(key):
raise TypeError("threefry_random_bits got invalid prng key.")
@ -1131,7 +1133,7 @@ def threefry_random_bits(key: jax.Array, bit_width, shape):
else:
return _threefry_random_bits_original(key, bit_width, shape)
def _threefry_random_bits_partitionable(key: jax.Array, bit_width, shape):
def _threefry_random_bits_partitionable(key: typing.Array, bit_width, shape):
if all(core.is_constant_dim(d) for d in shape) and math.prod(shape) > 2 ** 64:
raise NotImplementedError('random bits array of size exceeding 2 ** 64')
@ -1150,7 +1152,7 @@ def _threefry_random_bits_partitionable(key: jax.Array, bit_width, shape):
return lax.convert_element_type(bits1 ^ bits2, dtype)
@partial(jit, static_argnums=(1, 2), inline=True)
def _threefry_random_bits_original(key: jax.Array, bit_width, shape):
def _threefry_random_bits_original(key: typing.Array, bit_width, shape):
size = math.prod(shape)
# Compute ceil(bit_width * size / 32) in a way that is friendly to shape
# polymorphism
@ -1210,12 +1212,12 @@ threefry_prng_impl = PRNGImpl(
# stable/deterministic across backends or compiler versions. Correspondingly, we
# reserve the right to change any of these implementations at any time!
def _rbg_seed(seed: jax.Array) -> jax.Array:
def _rbg_seed(seed: typing.Array) -> typing.Array:
assert not seed.shape
halfkey = threefry_seed(seed)
return jnp.concatenate([halfkey, halfkey])
def _rbg_split(key: jax.Array, num: int) -> jax.Array:
def _rbg_split(key: typing.Array, num: int) -> typing.Array:
if config.jax_threefry_partitionable:
_threefry_split = _threefry_split_foldlike
else:
@ -1223,12 +1225,12 @@ def _rbg_split(key: jax.Array, num: int) -> jax.Array:
return vmap(
_threefry_split, (0, None), 1)(key.reshape(2, 2), num).reshape(num, 4)
def _rbg_fold_in(key: jax.Array, data: jax.Array) -> jax.Array:
def _rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array:
assert not data.shape
return vmap(_threefry_fold_in, (0, None), 0)(key.reshape(2, 2), data).reshape(4)
def _rbg_random_bits(key: jax.Array, bit_width: int, shape: Sequence[int]
) -> jax.Array:
def _rbg_random_bits(key: typing.Array, bit_width: int, shape: Sequence[int]
) -> typing.Array:
if not key.shape == (4,) and key.dtype == jnp.dtype('uint32'):
raise TypeError("_rbg_random_bits got invalid prng key.")
if bit_width not in (8, 16, 32, 64):
@ -1244,12 +1246,12 @@ rbg_prng_impl = PRNGImpl(
fold_in=_rbg_fold_in,
tag='rbg')
def _unsafe_rbg_split(key: jax.Array, num: int) -> jax.Array:
def _unsafe_rbg_split(key: typing.Array, num: int) -> typing.Array:
# treat 10 iterations of random bits as a 'hash function'
_, keys = lax.rng_bit_generator(key, (10 * num, 4), dtype='uint32')
return keys[::10]
def _unsafe_rbg_fold_in(key: jax.Array, data: jax.Array) -> jax.Array:
def _unsafe_rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array:
assert not data.shape
_, random_bits = lax.rng_bit_generator(_rbg_seed(data), (10, 4), dtype='uint32')
return key ^ random_bits[-1]
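
The threefry_* and rbg_* helpers above back the public `jax.random` API; a small hedged sketch of that public surface (default threefry implementation assumed):

    import jax

    key = jax.random.PRNGKey(0)        # threefry key: uint32 with shape (2,)
    k1, k2 = jax.random.split(key)     # handled by threefry_split above
    x = jax.random.uniform(k1, (3,))   # draws bits via threefry_random_bits
    print(key.shape, x.shape)          # (2,) (3,)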

View File

@ -15,13 +15,13 @@
from functools import partial
import operator
from jax import config
from jax.tree_util import tree_map, tree_reduce
from jax._src import api
from jax._src import dtypes as _dtypes
from jax._src import xla_bridge
from jax._src.config import flags
from jax._src.config import config, flags
from jax._src.lib import xla_client
from jax._src.tree_util import tree_map, tree_reduce
import numpy as np

View File

@ -21,17 +21,17 @@ import warnings
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.config import config
from jax.numpy.linalg import cholesky, svd, eigh
from jax._src import config as config_lib
from jax._src import core
from jax._src import dtypes
from jax._src import prng
from jax._src import xla_bridge
from jax._src.api import jit, vmap
from jax._src.config import config
from jax._src.core import NamedShape
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@ -677,7 +677,7 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array:
else: # 'cholesky'
factor = cholesky(cov)
normal_samples = normal(key, shape + mean.shape[-1:], dtype)
with jax.numpy_rank_promotion('allow'):
with config_lib.numpy_rank_promotion('allow'):
result = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
return result
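
The hunk above touches `_multivariate_normal`, which implements the public sampler; a hedged usage sketch:

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(42)
    mean = jnp.zeros(2)
    cov = jnp.eye(2)
    # 'cholesky' is the default factorization, as in the code above.
    samples = jax.random.multivariate_normal(key, mean, cov, shape=(5,))
    print(samples.shape)  # (5, 2)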

View File

@ -24,8 +24,8 @@ from jax import jit
from jax import vmap
from jax import lax
from jax._src import api
from jax._src import core
from jax._src import custom_derivatives
from jax._src import dtypes
from jax._src.interpreters import ad
from jax._src.lax.lax import _const as _lax_const
@ -96,7 +96,7 @@ def erfinv(x: ArrayLike) -> Array:
return lax.erf_inv(x)
@api.custom_jvp
@custom_derivatives.custom_jvp
@_wraps(osp_special.logit, module='scipy.special', update_doc=False)
def logit(x: ArrayLike) -> Array:
x, = promote_args_inexact("logit", x)
@ -214,7 +214,7 @@ def polygamma(n: ArrayLike, x: ArrayLike) -> Array:
return _polygamma(jnp.broadcast_to(n_arr, shape), jnp.broadcast_to(x_arr, shape))
@api.custom_jvp
@custom_derivatives.custom_jvp
def _polygamma(n: ArrayLike, x: ArrayLike) -> Array:
dtype = lax.dtype(n).type
n_plus = n + dtype(1)
@ -481,7 +481,7 @@ def _ndtri(p: ArrayLike) -> Array:
return x_nan_replaced
@partial(api.custom_jvp, nondiff_argnums=(1,))
@partial(custom_derivatives.custom_jvp, nondiff_argnums=(1,))
def log_ndtr(x: ArrayLike, series_order: int = 3) -> Array:
r"""Log Normal distribution function.
@ -1361,7 +1361,7 @@ def _expi_neg(x: Array) -> Array:
# x < 0
return -exp1(-x)
@api.custom_jvp
@custom_derivatives.custom_jvp
@jit
@_wraps(osp_special.expi, module='scipy.special')
def expi(x: ArrayLike) -> Array:
@ -1479,7 +1479,7 @@ def _expn3(n: int, x: Array) -> Array:
return (ans + one) * jnp.exp(-x) / xk
@partial(api.custom_jvp, nondiff_argnums=(0,))
@partial(custom_derivatives.custom_jvp, nondiff_argnums=(0,))
@jnp.vectorize
@_wraps(osp_special.expn, module='scipy.special')
@jit

View File

@ -20,10 +20,10 @@ import operator as op
from typing import (Any, Sequence, List, Tuple, Optional, Mapping, Dict, Set,
FrozenSet, Union, cast)
import jax
from jax._src import core
from jax._src import mesh as mesh_lib
from jax._src import sharding
from jax._src import xla_bridge
from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
from jax._src.interpreters import mlir
@ -238,7 +238,7 @@ class NamedSharding(XLACompatibleSharding):
# TODO(yaskatariya): Remove this and replace this with a normalized
# representation of Parsed Pspec
if self._parsed_pspec is None:
from jax.experimental import pjit
from jax._src import pjit
self._parsed_pspec, _, _ = pjit._prepare_axis_resources(
self.spec, "NamedSharding spec")
@ -287,7 +287,7 @@ class NamedSharding(XLACompatibleSharding):
num_dimensions: int,
axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None
) -> xc.OpSharding:
from jax.experimental.pjit import get_array_mapping
from jax._src.pjit import get_array_mapping
assert self._parsed_pspec is not None
array_mapping = get_array_mapping(self._parsed_pspec)
# TODO(yashkatariya): Move away from sharding spec in NamedSharding
@ -429,7 +429,8 @@ class PmapSharding(XLACompatibleSharding):
'`None` to sharded_dim is not supported. Please file a jax '
'issue if you need this feature.')
pmap_devices: np.ndarray = np.array(jax.local_devices()[:num_ways_sharded])
pmap_devices: np.ndarray = np.array(
xla_bridge.local_devices()[:num_ways_sharded])
return cls(pmap_devices, sharding_spec)
@functools.cached_property

View File

@ -21,7 +21,6 @@ from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Tupl
import numpy as np
from jax import lax
from jax._src import api_util
from jax._src import ad_util
@ -32,6 +31,8 @@ from jax._src import source_info_util
from jax._src import tree_util
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.lax import lax
from jax._src.lax import slicing as lax_slicing
from jax._src.state.types import AbstractRef, RefEffect
from jax._src.state.primitives import get_p, swap_p, addupdate_p
from jax._src.state.utils import hoist_consts_to_refs
@ -222,7 +223,7 @@ def _dynamic_index(x, idx, indexed_dims):
starts = [next(idx_) if b else np.int32(0) for b in indexed_dims]
assert next(idx_, None) is None
sizes = [1 if b else size for b, size in zip(indexed_dims, x.shape)]
out = lax.dynamic_slice(x, starts, sizes)
out = lax_slicing.dynamic_slice(x, starts, sizes)
return lax.squeeze(out, [i for i, b in enumerate(indexed_dims) if b])
def _dynamic_update_index(x, idx, val, indexed_dims):
@ -231,7 +232,7 @@ def _dynamic_update_index(x, idx, val, indexed_dims):
starts = [next(idx_) if b else np.int32(0) for b in indexed_dims]
assert next(idx_, None) is None
sizes = [1 if b else size for b, size in zip(indexed_dims, x.shape)]
return lax.dynamic_update_slice(x, val.reshape(sizes), starts)
return lax_slicing.dynamic_update_slice(x, val.reshape(sizes), starts)
@register_discharge_rule(core.closed_call_p)
def _closed_call_discharge_rule(

View File

@ -18,13 +18,14 @@ from typing import Any, List, Tuple, Union
import numpy as np
import jax
from jax._src import ad_util
from jax._src import core
from jax._src import pretty_printer as pp
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax
from jax._src.typing import Array
from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect,
AccumEffect)
@ -420,7 +421,7 @@ def _get_vmap(batched_args, batched_dims, *, indexed_dims):
# `idxs` doesn't include the non indexed dims.
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
iota = jax.lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
else:
bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
@ -453,7 +454,7 @@ def _swap_vmap(batched_args, batched_dims, *, indexed_dims):
indexed_dims = tuple_insert(indexed_dims, ref_dim, True)
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
iota = jax.lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
val = batching.moveaxis(val, val_dim, 0)
bdim_out = 0
@ -486,7 +487,7 @@ def _addupdate_vmap(batched_args, batched_dims, *, indexed_dims):
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
idxs_shape, = {i.shape for i in idxs} or [()]
iota = jax.lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
val = batching.moveaxis(val, val_dim, 0)
return addupdate_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), []

File diff suppressed because it is too large.