Avoid imports from the public jax.* namespace in more places internally.

This change is in preparation for more cycle breaking in the Bazel dependency graph.

PiperOrigin-RevId: 521822756
commit c1f65fc8b2
parent 3c1f3abba2
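The pattern applied throughout the diff below is to stop importing through the public jax.* namespace inside internal modules and to import the defining jax._src.* module directly instead. A minimal sketch of the style change, not code from this commit (the double_old/double_new helpers are illustrative):

# Old style: an internal module reaches through the public package, which
# pulls in jax/__init__.py and can create import cycles in the dependency graph.
import jax

def double_old(x):
    return jax.vmap(lambda v: 2 * v)(x)

# New style: depend directly on the private module that defines vmap.
from jax._src import api

def double_new(x):
    return api.vmap(lambda v: 2 * v)(x)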
@@ -76,7 +76,7 @@ del _xc
from jax._src.api import effects_barrier as effects_barrier
from jax._src.api import block_until_ready as block_until_ready
-from jax._src.api import checkpoint as checkpoint
+from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
from jax._src.api import clear_backends as clear_backends
from jax._src.custom_derivatives import closure_convert as closure_convert
@@ -116,8 +116,8 @@ from jax._src.api import named_scope as named_scope
from jax._src.api import pmap as pmap
from jax._src.xla_bridge import process_count as process_count
from jax._src.xla_bridge import process_index as process_index
-from jax._src.api import pure_callback as pure_callback
-from jax._src.api import remat as remat
+from jax._src.callback import pure_callback_api as pure_callback
+from jax._src.ad_checkpoint import checkpoint_wrapper as remat
from jax._src.core import ShapedArray as _deprecated_ShapedArray
from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
from jax._src.api import value_and_grad as value_and_grad
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import functools
from functools import partial
import logging
from typing import (Any, Callable, FrozenSet, List, Optional, Sequence, Tuple,
@@ -20,9 +21,8 @@ import types
import numpy as np

-import jax
-from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr
from jax._src import ad_util
+from jax._src import api
from jax._src import core
from jax._src import dispatch
from jax._src import linear_util as lu
@@ -31,6 +31,7 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.api_util import flatten_fun, shaped_abstractify
+from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@@ -39,6 +40,7 @@ from jax._src.lax import lax as lax_internal
from jax._src.lax import convolution as lax_convolution
from jax._src.lib.mlir.dialects import hlo
from jax._src.traceback_util import api_boundary
+from jax._src.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
safe_zip, merge_lists, weakref_lru_cache)
@@ -389,7 +391,7 @@ def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
args, kwargs = tree_unflatten(in_tree, args)
return f(*args, **kwargs)
-out = jax.make_jaxpr(lambda *args: jax.linearize(f_, *args)[1],
+out = api.make_jaxpr(lambda *args: api.linearize(f_, *args)[1],
return_shape=True)(*in_leaves)
assert isinstance(out, tuple)
jaxpr_, out_shape = out
@@ -522,7 +524,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:])
res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:]
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None)
-logger.log(logging.WARNING if jax.config.jax_log_checkpoint_residuals
+logger.log(logging.WARNING if config.jax_log_checkpoint_residuals
else logging.DEBUG,
'remat-decorated function ' +
'saving inputs with shapes:\n' * bool(res_invars) +
@@ -652,7 +654,7 @@ def remat_lowering(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
assert not jaxpr.constvars

if differentiated and prevent_cse:
-if jax.config.jax_remat_opt_barrier:
+if config.jax_remat_opt_barrier:
translation_rule = _remat_translation_using_opt_barrier
elif is_gpu_platform:
translation_rule = _remat_translation_using_while
@@ -661,7 +663,7 @@ def remat_lowering(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
else:
translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args)

-return jax.named_call(translation_rule, name="remat")(*args, jaxpr=jaxpr)
+return api.named_call(translation_rule, name="remat")(*args, jaxpr=jaxpr)

def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
args = _optimization_barrier(args)
@@ -670,9 +672,9 @@ def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
# TODO(mattjj): add core utility for 'create dummy value for this type'?
def _dummy_like(aval: core.AbstractValue) -> Any:
if aval is core.abstract_token:
-return jax.lax.create_token()
+return lax_internal.create_token()
elif isinstance(aval, (core.ShapedArray, core.DShapedArray)):
-return jax.lax.broadcast(lax_internal.empty(aval.dtype), aval.shape)  # type: ignore
+return lax_internal.broadcast(lax_internal.empty(aval.dtype), aval.shape)  # type: ignore
else:
raise ValueError(aval)
@@ -682,11 +684,13 @@ def _remat_translation_using_while(*args, jaxpr: core.Jaxpr):
# result = eval_jaxpr(*args)
# }
# The loop carry is a tuple: (counter, result, args)
+from jax._src.lax import control_flow as lax_control_flow

avals_out = tuple(v.aval for v in jaxpr.outvars)
carry_init = (np.int32(0), tuple(map(_dummy_like, avals_out)), args)
def cond(carry):
counter, _, _ = carry
-unif = jax.lax.rng_uniform(np.int32(1), np.int32(2), shape=())
+unif = lax_internal.rng_uniform(np.int32(1), np.int32(2), shape=())
return counter < unif

def body(carry):
@@ -694,7 +698,7 @@ def _remat_translation_using_while(*args, jaxpr: core.Jaxpr):
results = core.eval_jaxpr(jaxpr, (), *args)
return (counter + 1, tuple(results), args)

-carry_res = jax.lax.while_loop(cond, body, carry_init)
+carry_res = lax_control_flow.while_loop(cond, body, carry_init)
return carry_res[1]

def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr):
@@ -703,6 +707,8 @@ def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr):
# return eval_jaxpr(*args)
# else:
# return 0
+from jax._src.lax import control_flow as lax_control_flow

avals_out = tuple(v.aval for v in jaxpr.outvars)

def remat_comp(*args):
@@ -710,8 +716,8 @@ def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr):
def dummy_comp(*args):
return tuple(map(_dummy_like, avals_out))

-unif = jax.lax.rng_uniform(np.float32(0), np.float32(1), shape=())
-return jax.lax.cond(unif < np.float32(2), remat_comp, dummy_comp, *args)
+unif = lax_internal.rng_uniform(np.float32(0), np.float32(1), shape=())
+return lax_control_flow.cond(unif < np.float32(2), remat_comp, dummy_comp, *args)

mlir.register_lowering(
remat_p, mlir.lower_fun(remat_lowering, multiple_results=True))
@@ -760,3 +766,63 @@ def name_batcher(args, dims, *, name):
(x,), (d,) = args, dims
return name_p.bind(x, name=name), d
batching.primitive_batchers[name_p] = name_batcher
+
+
+@functools.wraps(checkpoint)
+def checkpoint_wrapper(
+fun: Callable,
+*,
+concrete: bool = False,
+prevent_cse: bool = True,
+static_argnums: Union[int, Tuple[int, ...]] = (),
+policy: Optional[Callable[..., bool]] = None,
+) -> Callable:
+if concrete:
+msg = ("The 'concrete' option to jax.checkpoint / jax.remat is deprecated; "
+"in its place, you can use its `static_argnums` option, and if "
+"necessary the `jax.ensure_compile_time_eval()` context manager.\n"
+"\n"
+"For example, if using `concrete=True` for an `is_training` flag:\n"
+"\n"
+" from functools import partial\n"
+"\n"
+" @partial(jax.checkpoint, concrete=True)\n"
+" def foo(x, is_training):\n"
+" if is_training:\n"
+" return f(x)\n"
+" else:\n"
+" return g(x)\n"
+"\n"
+"replace it with a use of `static_argnums`:\n"
+"\n"
+" @partial(jax.checkpoint, static_argnums=(1,))\n"
+" def foo(x, is_training):\n"
+" ...\n"
+"\n"
+"If jax.numpy operations need to be performed on static arguments, "
+"we can use the `jax.ensure_compile_time_eval()` context manager. "
+"For example, we can replace this use of `concrete=True`\n:"
+"\n"
+" @partial(jax.checkpoint, concrete=True)\n"
+" def foo(x, y):\n"
+" if y > 0:\n"
+" return f(x)\n"
+" else:\n"
+" return g(x)\n"
+"\n"
+"with this combination of `static_argnums` and "
+"`jax.ensure_compile_time_eval()`:\n"
+"\n"
+" @partial(jax.checkpoint, static_argnums=(1,))\n"
+" def foo(x, y):\n"
+" with jax.ensure_compile_time_eval():\n"
+" y_pos = y > 0\n"
+" if y_pos:\n"
+" return f(x)\n"
+" else:\n"
+" return g(x)\n"
+"\n"
+"See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html\n")
+raise NotImplementedError(msg)
+return checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
+static_argnums=static_argnums)
156 jax/_src/api.py
@@ -23,7 +23,6 @@ arrays.
from __future__ import annotations

import collections
import functools
from functools import partial
import inspect
import math
@@ -35,14 +34,12 @@ from typing import (Any, Callable, Generator, Hashable, Iterable, List, Literal,
import numpy as np
from contextlib import contextmanager, ExitStack

-import jax
from jax._src import linear_util as lu
-from jax import stages
+from jax._src import stages
from jax._src.tree_util import (
tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose,
tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix,
-prefix_errors, generate_key_paths, _replace_nones)
-from jax._src import callback as jcb
+prefix_errors, generate_key_paths)
from jax._src import core
from jax._src import dispatch
from jax._src import effects
@@ -57,26 +54,18 @@ from jax._src.api_util import (
flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
-shaped_abstractify, _ensure_str_tuple, argnames_partial_except,
-validate_argnames, validate_argnums, check_callable, resolve_argnums,
-debug_info, result_paths, flat_out_axes, debug_info_final, FLAGS)
+shaped_abstractify, _ensure_str_tuple,
+check_callable, debug_info, result_paths, flat_out_axes, debug_info_final, FLAGS)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.sharding import Sharding
from jax._src.sharding_impls import PmapSharding
from jax._src.traceback_util import api_boundary
-from jax._src.util import (unzip2, curry, safe_map, safe_zip, split_list,
-wrap_name, cache, wraps, HashableFunction,
-weakref_lru_cache)
+from jax._src.util import (unzip2, safe_map, safe_zip, wrap_name, wraps)

# Unused imports to be exported
-from jax.ad_checkpoint import checkpoint as new_checkpoint
-from jax.custom_batching import custom_vmap
-from jax.custom_derivatives import (custom_gradient, custom_jvp,
-custom_vjp, linear_call)
-from jax.custom_transpose import custom_transpose
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
@@ -89,7 +78,7 @@ from jax._src.config import (
_thread_local_state as config_thread_local_state,
explicit_device_put_scope as config_explicit_device_put_scope,
explicit_device_get_scope as config_explicit_device_get_scope)
-from jax._src.core import ShapedArray, raise_to_shaped
+from jax._src.core import ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import pxla
@@ -1022,10 +1011,11 @@ def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
argnums, has_aux=has_aux, holomorphic=holomorphic)

def _std_basis(pytree):
+import jax.numpy as jnp
leaves, _ = tree_flatten(pytree)
ndim = sum(map(np.size, leaves))
dtype = dtypes.result_type(*leaves)
-flat_basis = jax.numpy.eye(ndim, dtype=dtype)
+flat_basis = jnp.eye(ndim, dtype=dtype)
return _unravel_array_into_pytree(pytree, 1, None, flat_basis)

def _jacfwd_unravel(input_pytree, output_pytree_leaf, arr):
@@ -2487,8 +2477,8 @@ def _infer_src_sharding(src, x):

def device_put(
x,
-device: Union[None, xc.Device, jax.sharding.Sharding, Any] = None,
-*, src: Union[None, xc.Device, jax.sharding.Sharding, Any] = None):
+device: Union[None, xc.Device, Sharding, Any] = None,
+*, src: Union[None, xc.Device, Sharding, Any] = None):
"""Transfers ``x`` to ``device``.

Args:
@@ -2512,8 +2502,8 @@ def device_put(
blocking the calling Python thread until any transfers are completed.
"""
with config_explicit_device_put_scope():
-if ((device is None or isinstance(device, (xc.Device, jax.sharding.Sharding))) and
-(src is None or isinstance(src, (xc.Device, jax.sharding.Sharding)))):
+if ((device is None or isinstance(device, (xc.Device, Sharding))) and
+(src is None or isinstance(src, (xc.Device, Sharding)))):
return tree_map(
lambda y: dispatch.device_put_p.bind(
y, device=device, src=_infer_src_sharding(src, y)), x)
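The device_put hunks above only swap the jax.sharding.Sharding annotations for the directly imported Sharding class; the public behaviour is unchanged. A small usage sketch of the public API (the array values and device choice are illustrative):

import jax
import jax.numpy as jnp

x = jnp.arange(4)
# device may be None (default placement), a Device, or a Sharding instance.
y = jax.device_put(x, jax.devices()[0])
z = jax.device_put(x)  # default placement when no device is given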
@@ -2641,7 +2631,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]):  # noqa: F811
assert (isinstance(aval, ShapedArray) and
len(xla.aval_to_xla_shapes(aval)) == 1)
sharding_spec = pxla._create_pmap_sharding_spec(aval)
-buf = jax.device_put(x, devices[0])
+buf = device_put(x, devices[0])
return pxla.batched_device_put(
aval, PmapSharding(np.array(devices), sharding_spec),
[buf] * len(devices), devices)
@@ -2719,7 +2709,7 @@ class ShapeDtypeStruct:
raise ValueError("ShapeDtypeStruct: dtype must be specified.")
self.dtype = dtype if core.is_opaque_dtype(dtype) else np.dtype(dtype)
if sharding is not None:
-if not isinstance(sharding, jax.sharding.Sharding):
+if not isinstance(sharding, Sharding):
raise ValueError(
"sharding should be an instance of `jax.sharding.Sharding`. "
f"Got {sharding} of type {type(sharding)}.")
@@ -2821,65 +2811,6 @@ def eval_shape(fun: Callable, *args, **kwargs):
return tree_unflatten(out_tree(), out)
-
-
-@functools.wraps(new_checkpoint)  # config.jax_new_checkpoint is True by default
-def checkpoint(fun: Callable, *,
-concrete: bool = False,
-prevent_cse: bool = True,
-static_argnums: Union[int, Tuple[int, ...]] = (),
-policy: Optional[Callable[..., bool]] = None,
-) -> Callable:
-if concrete:
-msg = ("The 'concrete' option to jax.checkpoint / jax.remat is deprecated; "
-"in its place, you can use its `static_argnums` option, and if "
-"necessary the `jax.ensure_compile_time_eval()` context manager.\n"
-"\n"
-"For example, if using `concrete=True` for an `is_training` flag:\n"
-"\n"
-" from functools import partial\n"
-"\n"
-" @partial(jax.checkpoint, concrete=True)\n"
-" def foo(x, is_training):\n"
-" if is_training:\n"
-" return f(x)\n"
-" else:\n"
-" return g(x)\n"
-"\n"
-"replace it with a use of `static_argnums`:\n"
-"\n"
-" @partial(jax.checkpoint, static_argnums=(1,))\n"
-" def foo(x, is_training):\n"
-" ...\n"
-"\n"
-"If jax.numpy operations need to be performed on static arguments, "
-"we can use the `jax.ensure_compile_time_eval()` context manager. "
-"For example, we can replace this use of `concrete=True`\n:"
-"\n"
-" @partial(jax.checkpoint, concrete=True)\n"
-" def foo(x, y):\n"
-" if y > 0:\n"
-" return f(x)\n"
-" else:\n"
-" return g(x)\n"
-"\n"
-"with this combination of `static_argnums` and "
-"`jax.ensure_compile_time_eval()`:\n"
-"\n"
-" @partial(jax.checkpoint, static_argnums=(1,))\n"
-" def foo(x, y):\n"
-" with jax.ensure_compile_time_eval():\n"
-" y_pos = y > 0\n"
-" if y_pos:\n"
-" return f(x)\n"
-" else:\n"
-" return g(x)\n"
-"\n"
-"See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html\n")
-raise NotImplementedError(msg)
-return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
-static_argnums=static_argnums)
-remat = checkpoint  # type: ignore

def named_call(
fun: Callable[..., Any],
*,
@@ -2986,68 +2917,15 @@ def block_until_ready(x):
return x.block_until_ready()
except AttributeError:
return x
-return jax.tree_util.tree_map(try_to_block, x)
+return tree_map(try_to_block, x)
-
-def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
-*args: Any, vectorized: bool = False, **kwargs: Any):
-"""Applies a functionally pure Python callable. Works under :func:`jit`/:func:`~pmap`/etc.
-
-``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
-The input ``callback`` will be passed NumPy arrays in place of JAX arrays and
-should also return NumPy arrays. Execution takes place on CPU, like any
-Python+NumPy function.
-
-The callback is treated as functionally pure, meaning it has no side-effects
-and its output value depends only on its argument values. As a consequence, it
-is safe to be called multiple times (e.g. when transformed by :func:`~vmap` or
-:func:`~pmap`), or not to be called at all when e.g. the output of a
-`jit`-decorated function has no data dependence on its value. Pure callbacks
-may also be reordered if data-dependence allows.
-
-When :func:`~pmap`-ed, the pure callback will be called several times (one on each
-axis of the map). When `vmap`-ed the behavior will depend on the value of the
-``vectorized`` keyword argument. When ``vectorized`` is ``True``, the callback
-is assumed to obey
-``jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])``.
-Therefore, the callback will be called directly on batched inputs (where the
-batch axes are the leading dimensions). Additionally, the callbacks should
-return outputs that have corresponding leading batch axes. If not vectorized
-``callback`` will be mapped sequentially across the batched axis.
-For example, if ``callback = lambda x, y: np.matmul(x, y)``, then we are free
-to set ``vectorized=True`` because the ``np.matmul`` function handles
-arbitrary leading batch dimensions.
-
-Args:
-callback: A Python callable. The callable will be passed PyTrees of NumPy
-arrays as arguments, and should return a PyTree of NumPy arrays that
-matches ``result_shape_dtypes``.
-result_shape_dtypes: A PyTree with leaves that are objects with ``shape``
-and ``dtype`` attributes which represent to the shapes and dtypes of the
-value of ``callback`` applied to ``args`` and ``kwargs``.
-*args: The positional arguments to the callback. Must be PyTrees of JAX
-types.
-vectorized: A boolean that indicates whether or not ``callback`` is
-vectorized, meaning it can handle arrays with additional leading
-dimensions. If ``vectorized`` is `True`, when the callback is mapped
-via `jax.vmap`, it will be called directly on inputs with leading batch
-dimensions instead of executing ``callback`` on each mapped input
-individually. The callback should also return outputs batched across the
-leading axis. By default, ``vectorized`` is ``False``.
-**kwargs: The keyword arguments to the callback. Must be PyTrees of JAX
-types.
-
-Returns:
-The value of ``callback(*args, **kwargs)``.
-"""
-return jcb.pure_callback(callback, result_shape_dtypes, *args,
-vectorized=vectorized, **kwargs)

def clear_backends():
"""
Clear all backend clients so that new backend clients can be created later.
"""
xb._clear_backends()
-jax.lib.xla_bridge._backends = {}
+xb._backends = {}
dispatch.xla_primitive_callable.cache_clear()
pjit._pjit_lower_cached.cache_clear()
pjit._create_pjit_jaxpr.cache_clear()  # pytype: disable=attribute-error
@@ -21,19 +21,17 @@ import functools
from typing import (Sequence, Tuple, Callable, Optional, List, cast, Set,
TYPE_CHECKING)

-import jax
from jax._src import abstract_arrays
+from jax._src import api
from jax._src import api_util
from jax._src import basearray
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import profiler
+from jax._src import xla_bridge
from jax._src.config import config
-from jax._src.util import use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
-from jax._src import api
-from jax._src.typing import ArrayLike
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
@@ -41,6 +39,8 @@ from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
device_replica_id_map, hashed_index)
+from jax._src.typing import ArrayLike
+from jax._src.util import use_cpp_class, use_cpp_method

Shape = Tuple[int, ...]
Device = xc.Device
@@ -133,7 +133,7 @@ def _create_copy_plan(arrays, s: Sharding, shape: Shape):
@functools.lru_cache(maxsize=4096)
def _process_has_full_value_in_mcjax(s, shape):
# Return False for single host as a fast path.
-if jax.process_count() == 1:
+if xla_bridge.process_count() == 1:
return False

num_unique_indices = len(
@@ -359,7 +359,7 @@ class ArrayImpl(basearray.Array):
return np.asarray(self._value, dtype=dtype)

def __dlpack__(self):
-from jax.dlpack import to_dlpack  # pylint: disable=g-import-not-at-top
+from jax._src.dlpack import to_dlpack  # pylint: disable=g-import-not-at-top
return to_dlpack(self)

def __reduce__(self):
@@ -19,17 +19,17 @@ from typing import Any, Callable, Sequence

import numpy as np

-from jax import tree_util

from jax._src import core
+from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
+from jax._src import tree_util
from jax._src import util
-from jax._src import dispatch
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lib import xla_client as xc
+from jax._src.lax.control_flow.loops import map as lax_map

# `pure_callback_p` is the main primitive for staging out Python pure callbacks.
pure_callback_p = core.Primitive("pure_callback")
@@ -93,7 +93,6 @@ def pure_callback_batching_rule(args, dims, *, callback, vectorized: bool,
return pure_callback_p.bind(
*merged_args, callback=callback, result_avals=result_avals,
vectorized=vectorized)
-from jax._src.lax.control_flow import map as lax_map
outvals = lax_map(_batch_fun, batched_args)
return tuple(outvals), (0,) * len(outvals)

@@ -154,6 +153,62 @@ def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
return tree_util.tree_unflatten(out_tree, out_flat)
+
+
+def pure_callback_api(callback: Callable[..., Any], result_shape_dtypes: Any,
+*args: Any, vectorized: bool = False, **kwargs: Any):
+"""Applies a functionally pure Python callable. Works under :func:`jit`/:func:`~pmap`/etc.
+
+``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
+The input ``callback`` will be passed NumPy arrays in place of JAX arrays and
+should also return NumPy arrays. Execution takes place on CPU, like any
+Python+NumPy function.
+
+The callback is treated as functionally pure, meaning it has no side-effects
+and its output value depends only on its argument values. As a consequence, it
+is safe to be called multiple times (e.g. when transformed by :func:`~vmap` or
+:func:`~pmap`), or not to be called at all when e.g. the output of a
+`jit`-decorated function has no data dependence on its value. Pure callbacks
+may also be reordered if data-dependence allows.
+
+When :func:`~pmap`-ed, the pure callback will be called several times (one on each
+axis of the map). When `vmap`-ed the behavior will depend on the value of the
+``vectorized`` keyword argument. When ``vectorized`` is ``True``, the callback
+is assumed to obey
+``jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])``.
+Therefore, the callback will be called directly on batched inputs (where the
+batch axes are the leading dimensions). Additionally, the callbacks should
+return outputs that have corresponding leading batch axes. If not vectorized
+``callback`` will be mapped sequentially across the batched axis.
+For example, if ``callback = lambda x, y: np.matmul(x, y)``, then we are free
+to set ``vectorized=True`` because the ``np.matmul`` function handles
+arbitrary leading batch dimensions.
+
+Args:
+callback: A Python callable. The callable will be passed PyTrees of NumPy
+arrays as arguments, and should return a PyTree of NumPy arrays that
+matches ``result_shape_dtypes``.
+result_shape_dtypes: A PyTree with leaves that are objects with ``shape``
+and ``dtype`` attributes which represent to the shapes and dtypes of the
+value of ``callback`` applied to ``args`` and ``kwargs``.
+*args: The positional arguments to the callback. Must be PyTrees of JAX
+types.
+vectorized: A boolean that indicates whether or not ``callback`` is
+vectorized, meaning it can handle arrays with additional leading
+dimensions. If ``vectorized`` is `True`, when the callback is mapped
+via `jax.vmap`, it will be called directly on inputs with leading batch
+dimensions instead of executing ``callback`` on each mapped input
+individually. The callback should also return outputs batched across the
+leading axis. By default, ``vectorized`` is ``False``.
+**kwargs: The keyword arguments to the callback. Must be PyTrees of JAX
+types.
+
+Returns:
+The value of ``callback(*args, **kwargs)``.
+"""
+return pure_callback(callback, result_shape_dtypes, *args,
+vectorized=vectorized, **kwargs)

# IO Callback

io_callback_p = core.Primitive("io_callback")
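The hunk above moves the user-facing wrapper into jax/_src/callback.py as pure_callback_api, re-exported as jax.pure_callback in jax/__init__.py, with its documented behaviour unchanged. A small usage sketch of the vectorized=True contract described in the docstring (the host_matmul helper and the shapes are illustrative):

import jax
import jax.numpy as jnp
import numpy as np

def host_matmul(x, y):
    # Runs on the host and receives/returns NumPy arrays.
    return np.matmul(x, y)

@jax.jit
def f(x, y):
    out = jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype)
    # np.matmul already handles leading batch dimensions, so the callback
    # can be declared vectorized and will see whole batches under vmap.
    return jax.pure_callback(host_matmul, out, x, y, vectorized=True)

x = jnp.ones((3, 4), jnp.float32)
y = jnp.ones((4, 5), jnp.float32)
print(f(x, y).shape)  # (3, 5)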
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations

import dataclasses
import functools
@@ -20,28 +21,28 @@ from typing import (Union, Optional, Callable, Dict, Tuple, TypeVar,

import numpy as np

-import jax
import jax.numpy as jnp
-import jax.tree_util as jtu
from jax import lax
-from jax.api_util import flatten_fun
-from jax.experimental import pjit
-from jax.tree_util import tree_flatten
-from jax.tree_util import tree_map
-from jax.tree_util import tree_unflatten

+from jax._src import api
from jax._src import linear_util as lu
from jax._src import core
from jax._src import custom_derivatives
from jax._src import effects
+from jax._src import pjit
from jax._src import prng
from jax._src import source_info_util
from jax._src import traceback_util
+from jax._src import tree_util as jtu
+from jax._src.api_util import flatten_fun
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
+from jax._src.tree_util import tree_flatten
+from jax._src.tree_util import tree_map
+from jax._src.tree_util import tree_unflatten
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
unzip3, weakref_lru_cache)
@@ -92,7 +93,7 @@ class JaxException(Exception):
del payload
return cls(metadata)

-def get_effect_type(self) -> core.Effect:
+def get_effect_type(self) -> ErrorEffect:
raise NotImplementedError

@@ -100,7 +101,7 @@ class JaxException(Exception):
@dataclasses.dataclass(eq=True, frozen=True)
class ErrorEffect(effects.Effect):
error_type: Type[JaxException]
-shape_dtypes: Tuple[jax.ShapeDtypeStruct, ...]
+shape_dtypes: Tuple[api.ShapeDtypeStruct, ...]

def __lt__(self, other: 'ErrorEffect'):
shape_dtypes = lambda x: tuple((sd.shape, str(sd.dtype))  # dtype is not comparable
@@ -161,7 +162,7 @@ class OOBError(JaxException):
f'Failed at {self.traceback_info}')

def get_effect_type(self):
-return ErrorEffect(OOBError, (jax.ShapeDtypeStruct((3,), jnp.int32),))
+return ErrorEffect(OOBError, (api.ShapeDtypeStruct((3,), jnp.int32),))

class FailedCheckError(JaxException):

@@ -188,7 +189,7 @@ class FailedCheckError(JaxException):
vals = jtu.tree_leaves((self.args, self.kwargs))
return ErrorEffect(
FailedCheckError,
-tuple(jax.ShapeDtypeStruct(x.shape, x.dtype) for x in vals))
+tuple(api.ShapeDtypeStruct(x.shape, x.dtype) for x in vals))

@dataclasses.dataclass
class BatchedError(JaxException):
@@ -1112,7 +1113,7 @@ def _check(pred, msg, debug, *fmt_args, **fmt_kwargs):
prim_name = 'debug_check' if debug else 'check'
raise TypeError(f'{prim_name} takes a scalar pred as argument, got {pred}')
for arg in jtu.tree_leaves((fmt_args, fmt_kwargs)):
-if not isinstance(arg, (jax.Array, np.ndarray)):
+if not isinstance(arg, (Array, np.ndarray)):
raise TypeError('Formatting arguments to checkify.check need to be '
'PyTrees of arrays, but got '
f'{repr(arg)} of type {type(arg)}.')
@@ -1130,7 +1131,7 @@ def _check_error(error, *, debug=False):

def is_scalar_pred(pred) -> bool:
return (isinstance(pred, bool) or
-isinstance(pred, jax.Array) and pred.shape == () and
+isinstance(pred, Array) and pred.shape == () and
pred.dtype == jnp.dtype('bool'))
@@ -16,15 +16,14 @@ import functools
import operator
from typing import Callable, Optional

-import jax
-from jax import tree_util
-from jax.tree_util import (tree_flatten, tree_map, tree_structure,
-tree_unflatten, treedef_tuple)
+from jax import lax
+from jax._src import api
from jax._src import core
from jax._src import custom_api_util
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import traceback_util
+from jax._src import tree_util
from jax._src import util
from jax._src.api_util import flatten_fun_nokwargs
from jax._src.interpreters import ad
@@ -33,6 +32,8 @@ from jax._src.interpreters.batching import not_mapped
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
+from jax._src.tree_util import (tree_flatten, tree_map, tree_structure,
+tree_unflatten, treedef_tuple)

source_info_util.register_exclusion(__file__)
@@ -194,7 +195,7 @@ def custom_vmap_jvp(primals, tangents, *, call, rule, in_tree, out_tree):
return out

def to_vmap_over_extra_batched_dims(primals, tangents):
-return jax.jvp(to_jvp, primals, tangents)
+return api.jvp(to_jvp, primals, tangents)

to_vmap_over_extra_batched_dims_flat, out_tree2 = flatten_fun_nokwargs(
lu.wrap_init(to_vmap_over_extra_batched_dims),
@@ -274,7 +275,7 @@ def sequential_vmap(f):
return f(*args)

mapped_args, bcast_args = tree_split(in_batched, list(args))
-out = jax.lax.map(to_map, mapped_args)
+out = lax.map(to_map, mapped_args)
out_batched = tree_map(lambda _: True, out)
return out, out_batched
@@ -16,15 +16,9 @@ from functools import update_wrapper, reduce, partial
import inspect
from typing import (Callable, Generic, List, Optional, Sequence, Tuple, TypeVar, Any)

-from jax.custom_transpose import custom_transpose
-from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
-treedef_is_leaf, treedef_tuple,
-register_pytree_node_class, tree_leaves)
-from jax.errors import UnexpectedTracerError
-from jax.config import config

from jax._src import core
from jax._src import custom_api_util
+from jax._src.custom_transpose import custom_transpose
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
@@ -32,7 +26,9 @@ from jax._src import traceback_util
from jax._src.ad_util import (Zero, SymbolicZero, zeros_like_aval,
stop_gradient_p)
from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
+from jax._src.config import config
from jax._src.core import raise_to_shaped
+from jax._src.errors import UnexpectedTracerError
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@@ -40,6 +36,9 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.interpreters.batching import not_mapped
from jax._src.lax import lax
+from jax._src.tree_util import (tree_flatten, tree_unflatten, tree_map,
+treedef_is_leaf, treedef_tuple,
+register_pytree_node_class, tree_leaves)
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
|
@ -15,8 +15,6 @@
|
||||
import functools
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
from jax.tree_util import (tree_flatten, tree_leaves, tree_map,
|
||||
tree_structure, treedef_tuple, tree_unflatten)
|
||||
from jax._src import ad_util
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
@ -29,6 +27,8 @@ from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.tree_util import (tree_flatten, tree_leaves, tree_map,
|
||||
tree_structure, treedef_tuple, tree_unflatten)
|
||||
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
|
@@ -22,23 +22,24 @@ import weakref
import numpy as np

import jax.numpy as jnp
-from jax import tree_util
from jax import lax

from jax._src import core
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import pjit
+from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
+from jax._src.interpreters import partial_eval as pe
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding import Sharding
from jax._src.sharding_impls import GSPMDSharding, NamedSharding
-from jax._src.interpreters import partial_eval as pe

# pytype: disable=import-error
try:
@@ -30,10 +30,6 @@ import warnings

import numpy as np

-import jax
-from jax.monitoring import record_event_duration_secs

from jax._src import array
from jax._src import core
from jax._src import dtypes
from jax._src import linear_util as lu
@@ -53,6 +49,7 @@ from jax._src.interpreters import xla
from jax._src.interpreters import pxla
from jax._src.lib.mlir import ir
from jax._src.lib import xla_client as xc
+from jax._src.monitoring import record_event_duration_secs
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding, NamedSharding,
@@ -145,7 +142,7 @@ class RuntimeTokenSet(threading.local):
self.output_runtime_tokens = {}

def get_token(self, eff: core.Effect, device: Device) -> RuntimeToken:
-s = jax.sharding.SingleDeviceSharding(device)
+s = SingleDeviceSharding(device)
if eff not in self.tokens:
inp = np.zeros(0, np.bool_)
indices = tuple(
@@ -302,7 +299,7 @@ class SourceInfo(NamedTuple):

def jaxpr_shardings(
-jaxpr) -> Iterator[Tuple[jax.sharding.XLACompatibleSharding, SourceInfo]]:
+jaxpr) -> Iterator[Tuple[XLACompatibleSharding, SourceInfo]]:
from jax._src import pjit
from jax.experimental import shard_map

@@ -570,8 +567,9 @@ def _put_x(x, s: Sharding, aval: core.AbstractValue, committed: bool):

def _device_put_impl(
x,
-device: Optional[Union[Device, jax.sharding.Sharding]] = None,
-src: Optional[Union[Device, jax.sharding.Sharding]] = None):
+device: Optional[Union[Device, Sharding]] = None,
+src: Optional[Union[Device, Sharding]] = None):
from jax._src import array
try:
aval = xla.abstractify(x)
except TypeError as err:
@@ -16,13 +16,13 @@ import warnings

import numpy as np

+from jax import lax
+import jax.numpy as jnp
+from jax._src import dtypes
from jax._src.tree_util import tree_flatten, tree_unflatten
from jax._src.util import safe_zip, unzip2, HashablePartial
-import jax.numpy as jnp
-from jax._src import dtypes
-from jax import lax

zip = safe_zip
@@ -17,7 +17,6 @@ from functools import partial
import operator

-import jax
from jax import lax
from jax.tree_util import (tree_flatten, treedef_children, tree_leaves,
tree_unflatten, treedef_tuple)
from jax._src import ad_util
@@ -28,6 +27,7 @@ from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
+from jax._src.lax import lax
from jax._src.traceback_util import api_boundary
from jax._src.util import split_list, safe_map
import numpy as np
@@ -20,35 +20,42 @@ from typing import (Callable, Iterable, Tuple, Optional, Dict, Any, Set,
NamedTuple, Union, Sequence)
from functools import wraps, partial, partialmethod, lru_cache

+from jax import lax
from jax import numpy as jnp

from jax._src import core
from jax._src import mesh
from jax._src import linear_util as lu
-from jax import stages
from jax._src import dispatch
from jax._src import effects
-from jax.tree_util import (tree_flatten, tree_unflatten, all_leaves, tree_map,
-treedef_tuple)
from jax._src import mesh
from jax._src import linear_util as lu
from jax._src import source_info_util
+from jax._src import stages
from jax._src import traceback_util
from jax._src.api_util import (flatten_fun_nokwargs, flatten_axes,
_ensure_index_tuple, donation_vector,
shaped_abstractify, check_callable)
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.config import config
-from jax.errors import JAXTypeError
from jax._src.array import ArrayImpl
from jax._src.sharding_impls import NamedSharding
from jax._src.config import config
+from jax._src.errors import JAXTypeError
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
+from jax._src.interpreters.partial_eval import (
+trace_to_subjaxpr_dynamic, DynamicJaxprTracer,
+convert_constvars_jaxpr, new_jaxpr_eqn)
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
+from jax._src.pjit import (
+sharding_constraint_p, ParsedPartitionSpec, get_unconstrained_dims,
+GSPMDSharding)
from jax._src.sharding_impls import NamedSharding
+from jax._src.tree_util import (tree_flatten, tree_unflatten, all_leaves,
+tree_map, treedef_tuple)
from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3,
as_hashable_function, distributed_debug_log,
tuple_insert, moveaxis, split_list, wrap_name,
merge_lists, partition_list)
-from jax import lax

source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)
@@ -965,9 +972,6 @@ pxla.custom_resource_typing_rules[xmap_p] = _resource_typing_xmap

# This is DynamicJaxprTrace.process_map with some very minor modifications
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
-from jax._src.interpreters.partial_eval import (
-trace_to_subjaxpr_dynamic, DynamicJaxprTracer,
-convert_constvars_jaxpr, new_jaxpr_eqn)
assert primitive is xmap_p
in_avals = [t.aval for t in tracers]
global_axis_sizes = params['global_axis_sizes']
@@ -1775,10 +1779,6 @@ def _check_no_loop_collectives(jaxpr, loop_axis_resources):

def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
-from jax._src.pjit import (
-sharding_constraint_p, ParsedPartitionSpec, get_unconstrained_dims,
-GSPMDSharding)

rec = lambda jaxpr: _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name)
if isinstance(jaxpr, core.ClosedJaxpr):
return jaxpr.map_jaxpr(rec)
@@ -25,7 +25,8 @@ import numpy as np

from jax._src import core
from jax._src import dtypes
-from jax._src.api import jit, custom_jvp
+from jax._src.api import jit
+from jax._src.custom_derivatives import custom_jvp
from jax._src.lax import lax
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import (
@@ -24,23 +24,10 @@ from functools import partial, lru_cache
import threading
import warnings

-import jax
from jax._src import core
-from jax import stages
-from jax.errors import JAXTypeError
-from jax._src.interpreters import partial_eval as pe
-from jax._src.interpreters.pxla import PartitionSpec
-from jax._src.interpreters import xla
-from jax._src.tree_util import (
-tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
-treedef_tuple, broadcast_prefix, all_leaves)
-
-from jax._src.sharding import Sharding
-from jax._src.sharding_impls import (
-NamedSharding, XLACompatibleSharding, GSPMDSharding,
-XLADeviceAssignment, SingleDeviceSharding, PmapSharding)
+from jax._src import stages
from jax._src import dispatch
-from jax._src import mesh
+from jax._src import mesh as mesh_lib
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import traceback_util
@@ -50,6 +37,11 @@ from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
donation_vector, shaped_abstractify, check_callable, resolve_argnums,
argnames_partial_except, debug_info, result_paths, jaxpr_debug_info, FLAGS)
+from jax._src.errors import JAXTypeError
+from jax._src.interpreters import partial_eval as pe
+from jax._src.interpreters.pxla import PartitionSpec
+from jax._src.interpreters import xla
+
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@@ -58,8 +50,15 @@ from jax._src.interpreters import pxla
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import xla_client as xc
+from jax._src.sharding import Sharding
+from jax._src.sharding_impls import (
+NamedSharding, XLACompatibleSharding, GSPMDSharding,
+XLADeviceAssignment, SingleDeviceSharding, PmapSharding)
from jax._src.traceback_util import api_boundary
-from jax._src.tree_util import prefix_errors, generate_key_paths
+from jax._src.tree_util import (
+tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
+treedef_tuple, broadcast_prefix, all_leaves,
+prefix_errors, generate_key_paths)
from jax._src.util import (
HashableFunction, safe_map, safe_zip, wraps,
distributed_debug_log, split_list, tuple_insert, weakref_lru_cache,
@@ -311,7 +310,7 @@ def _resolve_axis_resources_and_shardings_arg(
def pre_infer_params(fun, in_shardings, out_shardings,
donate_argnums, static_argnums, static_argnames, device,
backend, abstracted_axes):
-if abstracted_axes and not jax.config.jax_dynamic_shapes:
+if abstracted_axes and not config.jax_dynamic_shapes:
raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")

check_callable(fun)
@@ -455,7 +454,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
dyn_kwargs = ()
del kwargs

-if donate_argnums and not jax.config.jax_debug_nans:
+if donate_argnums and not config.jax_debug_nans:
donated_invars = donation_vector(donate_argnums, dyn_args, dyn_kwargs)
else:
donated_invars = (False,) * len(explicit_args)
@@ -724,7 +723,7 @@ def pjit(

def infer_params(*args, **kwargs):
# Putting this outside of wrapped would make resources lexically scoped
-resource_env = mesh.thread_resources.env
+resource_env = mesh_lib.thread_resources.env
pjit_info_args = PjitInfo(
fun=fun, in_shardings=in_shardings,
out_shardings=out_shardings, static_argnums=static_argnums,
@@ -1125,7 +1124,7 @@ def _check_unique_resources(axis_resources, arg_name):
if multiple_uses:
raise ValueError(f"A single {arg_name} specification can map every mesh axis "
f"to at most one positional dimension, but {arg_axis_resources.user_spec} "
-f"has duplicate entries for {mesh.show_axes(multiple_uses)}")
+f"has duplicate entries for {mesh_lib.show_axes(multiple_uses)}")

# -------------------- pjit rules --------------------

@@ -1812,7 +1811,7 @@ def _check_resources_against_named_axes(what, aval, pos_axis_resources, named_ax
f"{pos_axis_resources.unsynced_user_spec(SpecSync.DIM_PERMUTE)} "
f"that uses one or more mesh axes already used by xmap to partition "
f"a named axis appearing in its named_shape (both use mesh axes "
-f"{mesh.show_axes(overlap)})")
+f"{mesh_lib.show_axes(overlap)})")

def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_resources):
jaxpr = params["jaxpr"]
@@ -1918,7 +1917,7 @@ def with_sharding_constraint(x, shardings=_UNSPECIFIED,
flatten_axes("with_sharding_constraint shardings", tree, user_shardings))
del user_shardings

-resource_env = jax._src.mesh.thread_resources.env
+resource_env = mesh_lib.thread_resources.env
mesh = resource_env.physical_mesh

shardings_flat = [_create_sharding_for_array(mesh, a)
@ -21,18 +21,20 @@ from typing import Any, Callable, Hashable, Iterator, NamedTuple, Sequence, Unio
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax.config import config
|
||||
from jax.dtypes import float0
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import basearray
|
||||
from jax._src import config as config_lib
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src import typing
|
||||
from jax._src.api import jit, vmap
|
||||
from jax._src.config import config
|
||||
from jax._src.dtypes import float0
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
@ -97,7 +99,7 @@ class PRNGImpl(NamedTuple):
|
||||
|
||||
# -- PRNG key arrays
|
||||
|
||||
def _check_prng_key_data(impl, key_data: jax.Array):
|
||||
def _check_prng_key_data(impl, key_data: typing.Array):
|
||||
ndim = len(impl.key_shape)
|
||||
if not all(hasattr(key_data, attr) for attr in ['ndim', 'shape', 'dtype']):
|
||||
raise TypeError("JAX encountered invalid PRNG key data: expected key_data "
|
||||
@ -139,7 +141,7 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
|
||||
"""
|
||||
|
||||
impl: PRNGImpl
|
||||
_base_array: jax.Array
|
||||
_base_array: typing.Array
|
||||
|
||||
def __init__(self, impl, key_data: Any):
|
||||
assert not isinstance(key_data, core.Tracer)
|
||||
@ -512,7 +514,7 @@ mlir.register_constant_handler(PRNGKeyArray, key_array_constant_handler)
|
||||
|
||||
def iterated_vmap_unary(n, f):
|
||||
for _ in range(n):
|
||||
f = jax.vmap(f)
|
||||
f = api.vmap(f)
|
||||
return f
|
||||
|
||||
# TODO(frostig): Revise the following two functions? These basically
|
||||
@ -528,7 +530,7 @@ def squeeze_vmap(f, left):
|
||||
else:
|
||||
y = jnp.squeeze(y, axis=0)
|
||||
axes = (0, None)
|
||||
return jax.vmap(f, in_axes=axes, out_axes=0)(x, y)
|
||||
return api.vmap(f, in_axes=axes, out_axes=0)(x, y)
|
||||
return squeeze_vmap_f
|
||||
|
||||
def iterated_vmap_binary_bcast(shape1, shape2, f):
|
||||
@ -543,7 +545,7 @@ def iterated_vmap_binary_bcast(shape1, shape2, f):
|
||||
assert len(shape1) == len(shape2)
|
||||
for sz1, sz2 in reversed(zip(shape1, shape2)):
|
||||
if sz1 == sz2:
|
||||
f = jax.vmap(f, out_axes=0)
|
||||
f = api.vmap(f, out_axes=0)
|
||||
else:
|
||||
assert sz1 == 1 or sz2 == 1, (sz1, sz2)
|
||||
f = squeeze_vmap(f, sz1 == 1)
|
||||
@ -785,14 +787,14 @@ mlir.register_lowering(random_unwrap_p, random_unwrap_lowering)
|
||||
# -- threefry2x32 PRNG implementation
|
||||
|
||||
|
||||
def _is_threefry_prng_key(key: jax.Array) -> bool:
|
||||
def _is_threefry_prng_key(key: typing.Array) -> bool:
|
||||
try:
|
||||
return key.shape == (2,) and key.dtype == np.uint32
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
|
||||
def threefry_seed(seed: jax.Array) -> jax.Array:
|
||||
def threefry_seed(seed: typing.Array) -> typing.Array:
|
||||
"""Create a single raw threefry PRNG key from an integer seed.
|
||||
|
||||
Args:
|
||||
@ -811,7 +813,7 @@ def threefry_seed(seed: jax.Array) -> jax.Array:
|
||||
convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
|
||||
k1 = convert(
|
||||
lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
|
||||
with jax.numpy_dtype_promotion('standard'):
|
||||
with config_lib.numpy_dtype_promotion('standard'):
|
||||
# TODO(jakevdp): in X64 mode, this can generate 64-bit computations for 32-bit
|
||||
# inputs. We should avoid this.
|
||||
k2 = convert(jnp.bitwise_and(seed, np.uint32(0xFFFFFFFF)))
|
||||
@ -1090,26 +1092,26 @@ def threefry_2x32(keypair, count):
|
||||
return lax.reshape(out[:-1] if odd_size else out, count.shape)
|
||||
|
||||
|
||||
def threefry_split(key: jax.Array, num: int) -> jax.Array:
|
||||
def threefry_split(key: typing.Array, num: int) -> typing.Array:
|
||||
if config.jax_threefry_partitionable:
|
||||
return _threefry_split_foldlike(key, int(num)) # type: ignore
|
||||
else:
|
||||
return _threefry_split_original(key, int(num)) # type: ignore
|
||||
|
||||
@partial(jit, static_argnums=(1,), inline=True)
|
||||
def _threefry_split_original(key, num) -> jax.Array:
|
||||
def _threefry_split_original(key, num) -> typing.Array:
|
||||
counts = lax.iota(np.uint32, num * 2)
|
||||
return lax.reshape(threefry_2x32(key, counts), (num, 2))
|
||||
|
||||
@partial(jit, static_argnums=(1,), inline=True)
|
||||
def _threefry_split_foldlike(key, num) -> jax.Array:
|
||||
def _threefry_split_foldlike(key, num) -> typing.Array:
|
||||
k1, k2 = key
|
||||
counts1, counts2 = iota_2x32_shape((num,))
|
||||
bits1, bits2 = threefry2x32_p.bind(k1, k2, counts1, counts2)
|
||||
return jnp.stack([bits1, bits2], axis=1)
|
||||
|
||||
|
||||
def threefry_fold_in(key: jax.Array, data: jax.Array) -> jax.Array:
|
||||
def threefry_fold_in(key: typing.Array, data: typing.Array) -> typing.Array:
|
||||
assert not data.shape
|
||||
return _threefry_fold_in(key, jnp.uint32(data))
|
||||
|

@ -1118,7 +1120,7 @@ def _threefry_fold_in(key, data):
return threefry_2x32(key, threefry_seed(data))


def threefry_random_bits(key: jax.Array, bit_width, shape):
def threefry_random_bits(key: typing.Array, bit_width, shape):
"""Sample uniform random bits of given width and shape using PRNG key."""
if not _is_threefry_prng_key(key):
raise TypeError("threefry_random_bits got invalid prng key.")

@ -1131,7 +1133,7 @@ def threefry_random_bits(key: jax.Array, bit_width, shape):
else:
return _threefry_random_bits_original(key, bit_width, shape)

def _threefry_random_bits_partitionable(key: jax.Array, bit_width, shape):
def _threefry_random_bits_partitionable(key: typing.Array, bit_width, shape):
if all(core.is_constant_dim(d) for d in shape) and math.prod(shape) > 2 ** 64:
raise NotImplementedError('random bits array of size exceeding 2 ** 64')

@ -1150,7 +1152,7 @@ def _threefry_random_bits_partitionable(key: jax.Array, bit_width, shape):
return lax.convert_element_type(bits1 ^ bits2, dtype)

@partial(jit, static_argnums=(1, 2), inline=True)
def _threefry_random_bits_original(key: jax.Array, bit_width, shape):
def _threefry_random_bits_original(key: typing.Array, bit_width, shape):
size = math.prod(shape)
# Compute ceil(bit_width * size / 32) in a way that is friendly to shape
# polymorphism
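
The branches above choose between the original and the fold-like ("partitionable") bit-generation paths based on the jax_threefry_partitionable flag, which can be toggled through the public config. A sketch, not part of the diff:

    import jax

    # Opt in to the partitionable threefry path (off by default at the time of
    # this commit); generated random values may differ from the original path.
    jax.config.update("jax_threefry_partitionable", True)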

@ -1210,12 +1212,12 @@ threefry_prng_impl = PRNGImpl(
# stable/deterministic across backends or compiler versions. Correspondingly, we
# reserve the right to change any of these implementations at any time!

def _rbg_seed(seed: jax.Array) -> jax.Array:
def _rbg_seed(seed: typing.Array) -> typing.Array:
assert not seed.shape
halfkey = threefry_seed(seed)
return jnp.concatenate([halfkey, halfkey])

def _rbg_split(key: jax.Array, num: int) -> jax.Array:
def _rbg_split(key: typing.Array, num: int) -> typing.Array:
if config.jax_threefry_partitionable:
_threefry_split = _threefry_split_foldlike
else:

@ -1223,12 +1225,12 @@ def _rbg_split(key: jax.Array, num: int) -> jax.Array:
return vmap(
_threefry_split, (0, None), 1)(key.reshape(2, 2), num).reshape(num, 4)

def _rbg_fold_in(key: jax.Array, data: jax.Array) -> jax.Array:
def _rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array:
assert not data.shape
return vmap(_threefry_fold_in, (0, None), 0)(key.reshape(2, 2), data).reshape(4)

def _rbg_random_bits(key: jax.Array, bit_width: int, shape: Sequence[int]
) -> jax.Array:
def _rbg_random_bits(key: typing.Array, bit_width: int, shape: Sequence[int]
) -> typing.Array:
if not key.shape == (4,) and key.dtype == jnp.dtype('uint32'):
raise TypeError("_rbg_random_bits got invalid prng key.")
if bit_width not in (8, 16, 32, 64):

@ -1244,12 +1246,12 @@ rbg_prng_impl = PRNGImpl(
fold_in=_rbg_fold_in,
tag='rbg')

def _unsafe_rbg_split(key: jax.Array, num: int) -> jax.Array:
def _unsafe_rbg_split(key: typing.Array, num: int) -> typing.Array:
# treat 10 iterations of random bits as a 'hash function'
_, keys = lax.rng_bit_generator(key, (10 * num, 4), dtype='uint32')
return keys[::10]

def _unsafe_rbg_fold_in(key: jax.Array, data: jax.Array) -> jax.Array:
def _unsafe_rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array:
assert not data.shape
_, random_bits = lax.rng_bit_generator(_rbg_seed(data), (10, 4), dtype='uint32')
return key ^ random_bits[-1]
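
For context, the rbg implementation defined here can be selected through the jax_default_prng_impl config option; as the key check above shows, its raw keys are four uint32 words rather than threefry's two. A quick sketch, not part of the diff:

    import jax

    jax.config.update("jax_default_prng_impl", "rbg")
    key = jax.random.PRNGKey(0)
    print(key.shape, key.dtype)  # (4,) uint32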

@ -15,13 +15,13 @@
from functools import partial
import operator

from jax import config
from jax.tree_util import tree_map, tree_reduce
from jax._src import api
from jax._src import dtypes as _dtypes
from jax._src import xla_bridge
from jax._src.config import flags
from jax._src.config import config, flags
from jax._src.lib import xla_client
from jax._src.tree_util import tree_map, tree_reduce

import numpy as np
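
This hunk shows the pattern the whole commit applies: internal modules import symbols from their defining jax._src module rather than from the public jax.* namespace, so importing them no longer pulls in jax/__init__.py. Sketched on one of the symbols moved above (user code should keep using the public imports):

    # Before: routed through the public package, which imports all of jax.
    # from jax.tree_util import tree_map, tree_reduce

    # After: taken directly from the private implementation module.
    from jax._src.tree_util import tree_map, tree_reduce

    doubled = tree_map(lambda x: x * 2, {"a": 1, "b": 2})
    print(tree_reduce(lambda a, b: a + b, doubled))  # 6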

@ -21,17 +21,17 @@ import warnings

import numpy as np

import jax
import jax.numpy as jnp
from jax import lax
from jax.config import config
from jax.numpy.linalg import cholesky, svd, eigh

from jax._src import config as config_lib
from jax._src import core
from jax._src import dtypes
from jax._src import prng
from jax._src import xla_bridge
from jax._src.api import jit, vmap
from jax._src.config import config
from jax._src.core import NamedShape
from jax._src.interpreters import ad
from jax._src.interpreters import batching

@ -677,7 +677,7 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array:
else: # 'cholesky'
factor = cholesky(cov)
normal_samples = normal(key, shape + mean.shape[-1:], dtype)
with jax.numpy_rank_promotion('allow'):
with config_lib.numpy_rank_promotion('allow'):
result = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
return result
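
Only the import of the context manager changes here; the behavior it controls is the rank-promotion policy, which the public API also exposes. An illustrative sketch of what 'allow' permits:

    import jax
    import jax.numpy as jnp

    x = jnp.ones((3, 4))
    y = jnp.ones((4,))
    with jax.numpy_rank_promotion("allow"):
        print((x + y).shape)  # (3, 4): the rank-1 operand is promoted silently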

@ -24,8 +24,8 @@ from jax import jit
from jax import vmap
from jax import lax

from jax._src import api
from jax._src import core
from jax._src import custom_derivatives
from jax._src import dtypes
from jax._src.interpreters import ad
from jax._src.lax.lax import _const as _lax_const

@ -96,7 +96,7 @@ def erfinv(x: ArrayLike) -> Array:
return lax.erf_inv(x)


@api.custom_jvp
@custom_derivatives.custom_jvp
@_wraps(osp_special.logit, module='scipy.special', update_doc=False)
def logit(x: ArrayLike) -> Array:
x, = promote_args_inexact("logit", x)

@ -214,7 +214,7 @@ def polygamma(n: ArrayLike, x: ArrayLike) -> Array:
return _polygamma(jnp.broadcast_to(n_arr, shape), jnp.broadcast_to(x_arr, shape))


@api.custom_jvp
@custom_derivatives.custom_jvp
def _polygamma(n: ArrayLike, x: ArrayLike) -> Array:
dtype = lax.dtype(n).type
n_plus = n + dtype(1)

@ -481,7 +481,7 @@ def _ndtri(p: ArrayLike) -> Array:
return x_nan_replaced


@partial(api.custom_jvp, nondiff_argnums=(1,))
@partial(custom_derivatives.custom_jvp, nondiff_argnums=(1,))
def log_ndtr(x: ArrayLike, series_order: int = 3) -> Array:
r"""Log Normal distribution function.

@ -1361,7 +1361,7 @@ def _expi_neg(x: Array) -> Array:
# x < 0
return -exp1(-x)

@api.custom_jvp
@custom_derivatives.custom_jvp
@jit
@_wraps(osp_special.expi, module='scipy.special')
def expi(x: ArrayLike) -> Array:

@ -1479,7 +1479,7 @@ def _expn3(n: int, x: Array) -> Array:
return (ans + one) * jnp.exp(-x) / xk


@partial(api.custom_jvp, nondiff_argnums=(0,))
@partial(custom_derivatives.custom_jvp, nondiff_argnums=(0,))
@jnp.vectorize
@_wraps(osp_special.expn, module='scipy.special')
@jit
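
api.custom_jvp and custom_derivatives.custom_jvp here appear to be two internal spellings of the decorator exposed publicly as jax.custom_jvp; the hunks only change which module it is taken from. A small usage sketch of the decorator itself:

    import jax
    import jax.numpy as jnp

    @jax.custom_jvp
    def log1pexp(x):
        return jnp.log1p(jnp.exp(x))

    @log1pexp.defjvp
    def log1pexp_jvp(primals, tangents):
        (x,), (t,) = primals, tangents
        # Numerically stable derivative: sigmoid(x) * t
        return log1pexp(x), t / (1.0 + jnp.exp(-x))

    print(jax.grad(log1pexp)(10.0))  # ~0.99995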

@ -20,10 +20,10 @@ import operator as op
from typing import (Any, Sequence, List, Tuple, Optional, Mapping, Dict, Set,
FrozenSet, Union, cast)

import jax
from jax._src import core
from jax._src import mesh as mesh_lib
from jax._src import sharding
from jax._src import xla_bridge
from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
from jax._src.interpreters import mlir

@ -238,7 +238,7 @@ class NamedSharding(XLACompatibleSharding):
# TODO(yaskatariya): Remove this and replace this with a normalized
# representation of Parsed Pspec
if self._parsed_pspec is None:
from jax.experimental import pjit
from jax._src import pjit
self._parsed_pspec, _, _ = pjit._prepare_axis_resources(
self.spec, "NamedSharding spec")

@ -287,7 +287,7 @@ class NamedSharding(XLACompatibleSharding):
num_dimensions: int,
axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None
) -> xc.OpSharding:
from jax.experimental.pjit import get_array_mapping
from jax._src.pjit import get_array_mapping
assert self._parsed_pspec is not None
array_mapping = get_array_mapping(self._parsed_pspec)
# TODO(yashkatariya): Move away from sharding spec in NamedSharding

@ -429,7 +429,8 @@ class PmapSharding(XLACompatibleSharding):
'`None` to sharded_dim is not supported. Please file a jax '
'issue if you need this feature.')

pmap_devices: np.ndarray = np.array(jax.local_devices()[:num_ways_sharded])
pmap_devices: np.ndarray = np.array(
xla_bridge.local_devices()[:num_ways_sharded])
return cls(pmap_devices, sharding_spec)

@functools.cached_property
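
The classes touched above back the public jax.sharding API; the edits only swap jax.experimental.pjit and jax.local_devices for their jax._src counterparts. For orientation, a single-device usage sketch of NamedSharding:

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("x",))
    arr = jax.device_put(jnp.arange(8), NamedSharding(mesh, P("x")))
    print(arr.sharding)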

@ -21,7 +21,6 @@ from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Tupl

import numpy as np

from jax import lax

from jax._src import api_util
from jax._src import ad_util

@ -32,6 +31,8 @@ from jax._src import source_info_util
from jax._src import tree_util
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.lax import lax
from jax._src.lax import slicing as lax_slicing
from jax._src.state.types import AbstractRef, RefEffect
from jax._src.state.primitives import get_p, swap_p, addupdate_p
from jax._src.state.utils import hoist_consts_to_refs

@ -222,7 +223,7 @@ def _dynamic_index(x, idx, indexed_dims):
starts = [next(idx_) if b else np.int32(0) for b in indexed_dims]
assert next(idx_, None) is None
sizes = [1 if b else size for b, size in zip(indexed_dims, x.shape)]
out = lax.dynamic_slice(x, starts, sizes)
out = lax_slicing.dynamic_slice(x, starts, sizes)
return lax.squeeze(out, [i for i, b in enumerate(indexed_dims) if b])

def _dynamic_update_index(x, idx, val, indexed_dims):

@ -231,7 +232,7 @@ def _dynamic_update_index(x, idx, val, indexed_dims):
starts = [next(idx_) if b else np.int32(0) for b in indexed_dims]
assert next(idx_, None) is None
sizes = [1 if b else size for b, size in zip(indexed_dims, x.shape)]
return lax.dynamic_update_slice(x, val.reshape(sizes), starts)
return lax_slicing.dynamic_update_slice(x, val.reshape(sizes), starts)

@register_discharge_rule(core.closed_call_p)
def _closed_call_discharge_rule(
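
The two replaced calls only change modules (lax to lax_slicing); the underlying operations are the ordinary lax slicing primitives. A standalone sketch of what they compute:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(12).reshape(3, 4)
    row = lax.dynamic_slice(x, start_indices=(1, 0), slice_sizes=(1, 4))     # row 1, kept as rank 2
    new_x = lax.dynamic_update_slice(x, jnp.zeros((1, 4), x.dtype), (1, 0))  # overwrite row 1
    print(row.shape, new_x[1])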

@ -18,13 +18,14 @@ from typing import Any, List, Tuple, Union

import numpy as np

import jax

from jax._src import ad_util
from jax._src import core
from jax._src import pretty_printer as pp
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax
from jax._src.typing import Array
from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect,
AccumEffect)

@ -420,7 +421,7 @@ def _get_vmap(batched_args, batched_dims, *, indexed_dims):
# `idxs` doesn't include the non indexed dims.
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
iota = jax.lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
else:
bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)

@ -453,7 +454,7 @@ def _swap_vmap(batched_args, batched_dims, *, indexed_dims):
indexed_dims = tuple_insert(indexed_dims, ref_dim, True)
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
iota = jax.lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
val = batching.moveaxis(val, val_dim, 0)
bdim_out = 0

@ -486,7 +487,7 @@ def _addupdate_vmap(batched_args, batched_dims, *, indexed_dims):
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
idxs_shape, = {i.shape for i in idxs} or [()]
iota = jax.lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
val = batching.moveaxis(val, val_dim, 0)
return addupdate_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), []
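
The only change in these vmap rules is calling broadcasted_iota through the internal lax module rather than jax.lax. As a reminder of what it produces, a tiny sketch:

    import numpy as np
    from jax import lax

    iota = lax.broadcasted_iota(np.dtype('int32'), (2, 3), 0)
    print(iota)
    # [[0 0 0]
    #  [1 1 1]]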