Split jax.interpreters.xla up into three pieces:

* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).

The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)

PiperOrigin-RevId: 411565432
This commit is contained in:
Peter Hawkins 2021-11-22 08:22:10 -08:00 committed by jax authors
parent 34855def13
commit d262bae88b
29 changed files with 1229 additions and 1074 deletions

View File

@ -56,6 +56,8 @@ from ..tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure,
treedef_is_leaf, treedef_children, Partial, PyTreeDef)
from .util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
extend_name_stack, wrap_name, cache, wraps, HashableFunction)
from jax._src import device_array
from jax._src import dispatch
from jax._src.lib import jax_jit
from jax._src.lib import version
from jax._src.lib import xla_bridge as xb
@ -83,6 +85,7 @@ from .._src.config import (flags, config, bool_env, disable_jit as _disable_jit,
debug_infs as config_debug_infs,
_thread_local_state as config_thread_local_state)
traceback_util.register_exclusion(__file__)
AxisName = Any
@ -127,7 +130,7 @@ def _nan_check_posthook(fun, args, kwargs, output):
buffers.extend(da_or_sda.device_buffers)
try:
xla.check_special(xla.xla_call_p, buffers)
dispatch.check_special(xla.xla_call_p, buffers)
except FloatingPointError:
# compiled_fun can only raise in this case
assert config.jax_debug_nans or config.jax_debug_infs
@ -426,14 +429,14 @@ def _cpp_jit(
# inspect the argument x, we actually do need to execute it and look at the
# outputs that could be tracers (if f is capturing `Tracer` by closure).
execute: Optional[functools.partial] = (
xla._xla_callable.most_recent_entry())
dispatch._xla_callable.most_recent_entry())
use_fastpath = (
# This is if we have already executed this code-path (most-recent entry
# has been reset to None). Thus, we do not support the fast-path.
execute is not None and
execute.func is xla._execute_compiled and # not trivial, not pmap
execute.func is dispatch._execute_compiled and # not trivial, not pmap
# Not supported: ShardedDeviceArray
all(xla.type_is_device_array(x) for x in out_flat))
all(device_array.type_is_device_array(x) for x in out_flat))
### If we can use the fastpath, we return required info to the caller.
if use_fastpath:
_, xla_executable, _, result_handlers, kept_var_idx = execute.args
@ -493,7 +496,7 @@ class Lowered:
in_tree: PyTreeDef
out_tree: PyTreeDef
donate_argnums: Tuple[int]
_lowering: Union[xla.XlaComputation, pxla.MeshComputation]
_lowering: Union[dispatch.XlaComputation, pxla.MeshComputation]
_no_kwargs: bool
def __init__(self, lowering, in_tree, out_tree, donate_argnums,
@ -529,7 +532,7 @@ class Compiled:
in_tree: PyTreeDef
out_tree: PyTreeDef
donate_argnums: Tuple[int]
_executable: Union[xla.XlaCompiledComputation, pxla.MeshExecutable]
_executable: Union[dispatch.XlaCompiledComputation, pxla.MeshExecutable]
_no_kwargs: bool
def __init__(self, executable, in_tree, out_tree, donate_argnums,
@ -595,7 +598,7 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
name = flat_fun.__name__
arg_specs = unsafe_map(arg_spec, args_flat)
computation = xla.lower_xla_callable(
computation = dispatch.lower_xla_callable(
flat_fun, device, backend, name, donated_invars, *arg_specs)
return Lowered(computation, in_tree, out_tree(), donate_argnums)
@ -830,8 +833,8 @@ def xla_computation(fun: Callable,
for axis_name, size in axis_env or []:
stack.enter_context(core.extend_axis_env(axis_name, size, None))
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr))
if out_parts is None:
out_parts_flat = None
else:
@ -1990,7 +1993,7 @@ def _cpp_pmap(
getattr(execute[0], "func", None) is pxla.execute_replicated and
# No tracers in the outputs. Checking for ShardedDeviceArray should be
# sufficient, but we use the more general `DeviceArray`.
all(isinstance(x, xla.DeviceArray) for x in out_flat))
all(isinstance(x, device_array.DeviceArray) for x in out_flat))
### If we can use the fastpath, we return required info to the caller.
if use_fastpath:
xla_executable, backend_, in_handler, out_handler = execute[0].args
@ -2552,7 +2555,7 @@ def device_put(x, device: Optional[xc.Device] = None):
Returns:
A copy of ``x`` that resides on ``device``.
"""
return tree_map(lambda y: xla.device_put_p.bind(y, device=device), x)
return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):
@ -2617,7 +2620,8 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):
raise ValueError("the shards passed to device_put_sharded must have "
f"consistent shape and dtype, but got {a1} and {a2}.")
stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
buffers = [buf for x, d in zip(xs, devices) for buf in xla.device_put(x, d)]
buffers = [buf for x, d in zip(xs, devices)
for buf in dispatch.device_put(x, d)]
return pxla.make_sharded_device_array(stacked_aval, None, buffers)
return tree_multimap(_device_put_sharded, *shards)
@ -2660,7 +2664,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]):
core.raise_to_shaped(core.get_aval(x)))
assert (isinstance(aval, core.ShapedArray) and
len(xla.aval_to_xla_shapes(aval)) == 1)
buf, = xla.device_put(x, devices[0])
buf, = dispatch.device_put(x, devices[0])
rest_bufs = [buf.copy_to_device(d) for d in devices[1:]]
return pxla.make_sharded_device_array(aval, None, [buf, *rest_bufs])
return tree_map(_device_put_replicated, x)

307
jax/_src/device_array.py Normal file
View File

@ -0,0 +1,307 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# On-device arrays.
from functools import partial, partialmethod
import operator
from typing import (Any, Optional, Union)
import weakref
import numpy as np
from jax import core
from jax._src.config import config
from jax._src import dtypes
from jax._src.lib import xla_client as xc
import jax._src.util as util
### device-persistent data
xe = xc._xla
Device = xc.Device
Buffer = xe.Buffer
def _forward_method(attrname, self, fun, *args):
return fun(getattr(self, attrname), *args)
_forward_to_value = partial(_forward_method, "_value")
# The following is used for the type xc.Buffer or _DeviceArray.
DeviceArrayProtocol = Any
DeviceArray = xc.DeviceArrayBase
def make_device_array(
aval: core.ShapedArray,
device: Optional[Device],
device_buffer: Buffer,
) -> Union[Buffer, "_DeviceArray"]:
"""Returns a DeviceArray implementation based on arguments.
This is to be used only within JAX. It will return either a PythonDeviceArray
or a C++ equivalent implementation.
"""
if isinstance(device_buffer, xc.Buffer):
if device_buffer.aval == aval and device_buffer._device == device:
return device_buffer
device_buffer = device_buffer.clone()
device_buffer._device = device
device_buffer.aval = aval
device_buffer.weak_type = aval.weak_type
return device_buffer
return _DeviceArray(aval, device, device_buffer)
def type_is_device_array(x):
"""Returns `True` if `x` is a non-sharded DeviceArray.
Use this function instead of `type(x) is Devicearray`.
"""
type_x = type(x)
return type_x is _DeviceArray or type_x is xc.Buffer
def device_array_supports_weakrefs():
try:
weakref.ref(DeviceArray())
return True
except TypeError:
return False
class _DeviceArray(DeviceArray): # type: ignore
"""A DeviceArray is an ndarray backed by a single device memory buffer."""
# We don't subclass ndarray because that would open up a host of issues,
# but lax_numpy.py overrides isinstance behavior and attaches ndarray methods.
__slots__ = [
"aval", "device_buffer", "_npy_value", "_device", "__weakref__"
]
__array_priority__ = 100
# DeviceArray has methods that are dynamically populated in lax_numpy.py,
# and this annotation is needed to make pytype happy.
_HAS_DYNAMIC_ATTRIBUTES = True
def __init__(self, aval: core.ShapedArray, device: Optional[Device],
device_buffer: Buffer):
"""Initializer.
Args:
aval: The abstract value associated to this array (shape+dtype+weak_type).
device: The optional sticky device. See
https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices
device_buffer: The underlying buffer owning the on-device data.
"""
DeviceArray.__init__(self)
self.aval = aval
self.device_buffer = device_buffer
self._device = device
self._npy_value = None
if config.jax_enable_checks:
assert type(aval) is core.ShapedArray
npy_value = self._value
assert npy_value.dtype == aval.dtype and npy_value.shape == aval.shape, (
aval, npy_value.shape, npy_value.dtype)
assert (device is None) or device is device_buffer.device()
def _check_if_deleted(self):
if self.device_buffer is deleted_buffer:
raise RuntimeError("DeviceArray has been deleted.")
def block_until_ready(self):
"""Blocks the caller until the buffer's value has been computed on device.
This method is mostly useful for timing microbenchmarks that wish to
time how long a computation takes, without transferring the result back
to the host.
Returns the buffer object (`self`).
"""
self._check_if_deleted()
self.device_buffer.block_host_until_ready() # pytype: disable=attribute-error
return self
@property
def _value(self):
self._check_if_deleted()
if self._npy_value is None:
self._npy_value = self.device_buffer.to_py() # pytype: disable=attribute-error # bind-properties
self._npy_value.flags.writeable = False
return self._npy_value
@property
def shape(self):
return self.aval.shape
@property
def dtype(self):
return self.aval.dtype
@property
def size(self):
return util.prod(self.aval.shape)
@property
def ndim(self):
return len(self.aval.shape)
def copy_to_host_async(self):
"""Requests a copy of the buffer to the host."""
self._check_if_deleted()
if self._npy_value is None:
self.device_buffer.copy_to_host_async() # pytype: disable=attribute-error
def delete(self):
"""Deletes the device array and any cached copy on the host.
It is an error to access the contents of a `DeviceArray` after it has
been deleted.
Use of this method is optional; device buffers will be reclaimed
automatically by Python when a DeviceArray object is garbage collected.
However, it is sometimes useful to have more explicit control over the
time of deletion.
"""
self.device_buffer.delete() # pytype: disable=attribute-error
self.device_buffer = deleted_buffer
self._npy_value = None
@property
def __cuda_array_interface__(self):
return self.device_buffer.__cuda_array_interface__ # pytype: disable=attribute-error # bind-properties
# Adding methods dynamically to both _DeviceArray and xc.Buffer
# pylint: disable=protected-access
for device_array in [DeviceArray]:
def copy(self):
"""Returns an ndarray (backed by host memory, not device memory)."""
return np.asarray(self)
setattr(device_array, "copy", copy)
def __repr__(self):
line_width = np.get_printoptions()["linewidth"]
prefix = '{}('.format(self.__class__.__name__.lstrip('_'))
s = np.array2string(self._value, prefix=prefix, suffix=',',
separator=', ', max_line_width=line_width)
if self.aval is not None and self.aval.weak_type:
dtype_str = f'dtype={self.dtype.name}, weak_type=True)'
else:
dtype_str = f'dtype={self.dtype.name})'
last_line_len = len(s) - s.rfind('\n') + 1
sep = ' '
if last_line_len + len(dtype_str) + 1 > line_width:
sep = ' ' * len(prefix)
return "{}{},{}{}".format(prefix, s, sep, dtype_str)
setattr(device_array, "__repr__", __repr__)
def item(self):
if dtypes.issubdtype(self.dtype, np.complexfloating):
return complex(self)
elif dtypes.issubdtype(self.dtype, np.floating):
return float(self)
elif dtypes.issubdtype(self.dtype, np.integer):
return int(self)
elif dtypes.issubdtype(self.dtype, np.bool_):
return bool(self)
else:
raise TypeError(self.dtype)
setattr(device_array, "item", item)
def __len__(self):
try:
return self.aval.shape[0]
except IndexError as err:
raise TypeError("len() of unsized object") from err # same as numpy error
setattr(device_array, "__len__", __len__)
def __iter__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack())
setattr(device_array, "__iter__", __iter__)
def __reversed__(self):
return iter(self[::-1])
setattr(device_array, "__reversed__", __reversed__)
def __format__(self, format_spec):
# Simulates behavior of https://github.com/numpy/numpy/pull/9883
if self.ndim == 0:
return format(self._value[()], format_spec)
else:
return format(self._value, format_spec)
setattr(device_array, "__format__", __format__)
def __array__(self, dtype=None, context=None):
return np.asarray(self._value, dtype=dtype)
setattr(device_array, "__array__", __array__)
setattr(device_array, "__str__", partialmethod(_forward_to_value, str))
setattr(device_array, "__bool__", partialmethod(_forward_to_value, bool))
setattr(device_array, "__nonzero__", partialmethod(_forward_to_value, bool))
setattr(device_array, "__float__", lambda self: self._value.__float__())
setattr(device_array, "__int__", lambda self: self._value.__int__())
setattr(device_array, "__complex__", lambda self: self._value.__complex__())
setattr(device_array, "__hex__", partialmethod(_forward_to_value, hex))
setattr(device_array, "__oct__", partialmethod(_forward_to_value, oct))
setattr(device_array, "__index__", partialmethod(_forward_to_value,
operator.index))
to_bytes = lambda self, order="C": self._value.tobytes(order)
setattr(device_array, "tobytes", to_bytes)
del to_bytes
setattr(device_array, "tolist", lambda self: self._value.tolist())
# pickle saves and loads just like an ndarray
setattr(device_array, "__reduce__",
partialmethod(_forward_to_value, operator.methodcaller("__reduce__")))
# explicitly set to be unhashable.
setattr(device_array, "__hash__", None)
# clobbered when jax.numpy is imported, but useful in tests
setattr(device_array, "__eq__", lambda self, other: self._value == other)
# The following methods are dynamically overridden in lax_numpy.py.
def raise_not_implemented():
raise NotImplementedError
setattr(device_array, "__getitem__", lambda self, i: raise_not_implemented())
# pylint: enable=protected-access
class DeletedBuffer(object): pass
deleted_buffer = DeletedBuffer()
device_array_types = [xc.Buffer, _DeviceArray]
for _device_array in device_array_types:
core.literalable_types.add(_device_array)
core.pytype_aval_mappings[device_array] = core.ConcreteArray

671
jax/_src/dispatch.py Normal file
View File

@ -0,0 +1,671 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Primitive dispatch and jit dispatch.
from functools import partial
import itertools
from typing import (
Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union)
import warnings
from absl import logging
import numpy as np
from jax import core
from jax import linear_util as lu
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.masking as masking
import jax.interpreters.xla as xla
import jax.interpreters.partial_eval as pe
from jax.errors import UnexpectedTracerError
from jax._src.abstract_arrays import array_types
from jax._src.config import config
from jax._src import device_array
from jax._src import dtypes
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
import jax._src.util as util
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)
xe = xc._xla
Backend = xe.Client
Device = xc.Device
Buffer = xe.Buffer
XlaExecutable = xc.Executable
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
# This flag is set on exit; no logging should be attempted
_on_exit = False
### op-by-op execution
ArgSpec = Tuple[core.AbstractValue, Optional[Device]]
def arg_spec(x: Any) -> ArgSpec:
aval = xla.abstractify(x)
try:
return aval, x._device
except:
return aval, None
def apply_primitive(prim, *args, **params):
"""Impl rule that compiles and runs a single primitive 'prim' using XLA."""
if config.jax_enable_mlir:
import jax.interpreters.mlir
return jax.interpreters.mlir.apply_primitive(prim, *args, **params)
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
**params)
return compiled_fun(*args)
# TODO(phawkins): update code referring to xla.apply_primitive to point here.
xla.apply_primitive = apply_primitive
@util.cache()
def xla_primitive_callable(prim, *arg_specs: ArgSpec, **params):
avals, arg_devices = util.unzip2(arg_specs)
donated_invars = (False,) * len(arg_specs)
device = _device_from_arg_devices(arg_devices)
def prim_fun(*args):
out = prim.bind(*args, **params)
if prim.multiple_results:
return out
else:
return out,
compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
prim.name, donated_invars, *arg_specs)
if not prim.multiple_results:
return lambda *args, **kw: compiled(*args, **kw)[0]
else:
return compiled
def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[Device]:
"""Given devices of inputs, determine where to perform a computation.
Args:
devices: list where each element is a either a `Device` instance or `None`.
Returns:
A `Device` instance or None.
Raises:
ValueError if input devices are inconsistent.
"""
try:
device, = {d for d in devices if d is not None} or (None,)
return device
except ValueError as err:
msg = "primitive arguments must be colocated on the same device, got {}"
raise ValueError(msg.format(", ".join(map(str, devices)))) from err
# JIT execution
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
if config.jax_enable_mlir:
import jax.interpreters.mlir
return jax.interpreters.mlir._xla_call_impl_mlir(
fun, *args, device=device, backend=backend, name=name,
donated_invars=donated_invars, inline=inline)
del inline # Only used at tracing time
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(arg_spec, args))
try:
out = compiled_fun(*args)
except FloatingPointError:
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
print("Invalid value encountered in the output of a jit/pmap-ed function. "
"Calling the de-optimized version.")
# We want to run the wrapped function again (after _xla_callable already ran
# it), but linear_util.WrappedFun instances are meant to be run only once.
# In addition to re-executing the Python code, which is usually undesirable
# but which config.jax_debug_nans is meant to opt into, we'll be re-executing
# any linear_util.py-style side effects, i.e. re-populating Stores created
# by any transformation_with_aux's applied to fun. Since this is
# intentional here, to avoid "Store occupied" errors we clone the WrappedFun
# with empty stores.
stores = [lu.Store() for _ in fun.stores]
clone = lu.WrappedFun(fun.f, fun.transforms, stores, fun.params)
with core.new_sublevel():
_ = clone.call_wrapped(*args) # probably won't return
return out
xla.xla_call_p.def_impl(_xla_call_impl)
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
return lower_xla_callable(fun, device, backend, name, donated_invars,
*arg_specs).compile().unsafe_call
_xla_callable = lu.cache(_xla_callable_uncached)
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
abstract_args, arg_devices = util.unzip2(arg_specs)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"))
if any(isinstance(c, core.Tracer) for c in consts):
raise UnexpectedTracerError("Encountered an unexpected tracer.")
jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
pruned_arg_specs = (a for i, a in enumerate(arg_specs) if i in kept_var_idx)
abstract_args, arg_devices = util.unzip2(pruned_arg_specs)
donated_invars = [
x for i, x in enumerate(donated_invars) if i in kept_var_idx
]
map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))
jaxpr = apply_outfeed_rewriter(jaxpr)
nreps = jaxpr_replicas(jaxpr)
device = _xla_callable_device(nreps, backend, device, arg_devices)
backend = xb.get_device_backend(device) if device else xb.get_backend(backend)
# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
# and don't need to evaluate their arguments.
if not jaxpr.eqns:
return XlaComputation(
name, None, True, None, jaxpr, consts, device, abstract_args, out_avals,
kept_var_idx)
if not _on_exit:
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logging.log(log_priority, "Compiling %s (%s) for args %s.",
fun.__name__, id(fun), abstract_args)
if nreps > 1:
warnings.warn(
f"The jitted function {name} includes a pmap. Using "
"jit-of-pmap can lead to inefficient data movement, as the outer jit "
"does not preserve sharded data representations and instead collects "
"input and output arrays onto a single device. "
"Consider removing the outer jit unless you know what you're doing. "
"See https://github.com/google/jax/issues/2926.")
if nreps > xb.device_count(backend):
raise ValueError(
f"compiling computation `{name}` that requires {nreps} replicas, but "
f"only {xb.device_count(backend)} XLA devices are available.")
if xb.process_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
raise NotImplementedError(
"jit of multi-host pmap not implemented (and jit-of-pmap can cause "
"extra data movement anyway, so maybe you don't want it after all).")
tuple_args = len(abstract_args) > 100 # pass long arg lists as tuple for TPU
c = xc.XlaBuilder(f"jit_{fun.__name__}")
xla_consts = xla._xla_consts(c, consts)
xla_args, donated_invars = xla._xla_callable_args(c, abstract_args, tuple_args,
donated_invars=donated_invars)
platform = backend.platform
ctx = xla.TranslationContext(c, platform, xla.AxisEnv(nreps, (), ()),
xla.extend_name_stack(xla.wrap_name(name, 'jit')))
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
# There is a non-zero cost to building an output tuple, particularly on TPU.
# Avoid it if the output arity is 1.
output = out_nodes[0] if len(out_nodes) == 1 else xc.ops.Tuple(c, out_nodes)
if platform in ("gpu", "tpu"):
donated_invars = xla.set_up_aliases(
c, xla_args, c.GetShape(output), donated_invars, tuple_args)
if any(donated_invars):
# TODO(tomhennigan): At call time we should mark these buffers as deleted.
unused_donations = [str(c.GetShape(a))
for a, d in zip(xla_args, donated_invars) if d]
warnings.warn("Some donated buffers were not usable: {}".format(
", ".join(unused_donations)))
built = c.build(output)
return XlaComputation(
name, built, False, donated_invars, nreps, device, backend, tuple_args,
abstract_args, out_avals, kept_var_idx)
def prefetch(x):
if isinstance(x, device_array.DeviceArray):
x.copy_to_host_async()
return x
def jaxpr_literals(jaxpr):
"""Generates all the literals inside a jaxpr, including nested subjaxprs."""
for eqn in jaxpr.eqns:
for v in eqn.invars:
if type(v) is core.Literal:
yield v.val
for subjaxpr in core.subjaxprs(jaxpr):
yield from jaxpr_literals(subjaxpr)
def jaxpr_has_pmap(jaxpr):
"""Whether there is an xla_pmap primitive anywhere inside a Jaxpr."""
for eqn in jaxpr.eqns:
if 'xla_pmap' in eqn.primitive.name:
return True
for subjaxpr in core.subjaxprs(jaxpr):
if jaxpr_has_pmap(subjaxpr):
return True
return False
def _prune_unused_inputs(
jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
used = {v for v in jaxpr.outvars if isinstance(v, core.Var)}
# TODO(zhangqiaorjc): Improve the DCE algorithm by also pruning primitive
# applications that do not produce used outputs. Must handle side-effecting
# primitives and nested jaxpr.
used.update(
v for eqn in jaxpr.eqns for v in eqn.invars if isinstance(v, core.Var))
kept_const_idx, new_constvars = util.unzip2(
(i, v) for i, v in enumerate(jaxpr.constvars) if v in used)
kept_var_idx, new_invars = util.unzip2(
(i, v) for i, v in enumerate(jaxpr.invars) if v in used)
new_jaxpr = core.Jaxpr(new_constvars, new_invars, jaxpr.outvars, jaxpr.eqns)
return new_jaxpr, set(kept_const_idx), set(kept_var_idx)
# We can optionally set a Jaxpr rewriter that can be applied just before
# compilation. This mechanism is used for compiling id_tap, we can
# remove it once we bring the id_tap implementation into the core.
outfeed_rewriter: Optional[Callable[[core.Jaxpr], core.Jaxpr]] = None
def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
if outfeed_rewriter is not None:
return outfeed_rewriter(jaxpr)
else:
return jaxpr
outfeed_primitives: Set[core.Primitive] = set()
def jaxpr_uses_outfeed(jaxpr: core.Jaxpr) -> bool:
"""Finds if there are outfeed primitives anywhere inside a Jaxpr."""
return any(primitive_uses_outfeed(eqn.primitive, eqn.params)
for eqn in jaxpr.eqns)
def _param_uses_outfeed(param):
if type(param) is core.Jaxpr:
if jaxpr_uses_outfeed(param):
return True
elif type(param) is core.ClosedJaxpr:
if jaxpr_uses_outfeed(param.jaxpr):
return True
return False
def primitive_uses_outfeed(prim: core.Primitive, params: Dict) -> bool:
if prim in outfeed_primitives:
return True
for param in params.values():
if isinstance(param, tuple):
if any(unsafe_map(_param_uses_outfeed, param)):
return True
elif _param_uses_outfeed(param):
return True
return False
def jaxpr_replicas(jaxpr) -> int:
"""The number of replicas needed for a jaxpr.
For a eqn, multiply the `axis_size` with the `jaxpr_replicas` of the
subjaxprs. For a list of eqns, take the maximum number of replicas.
"""
if isinstance(jaxpr, core.ClosedJaxpr):
jaxpr = jaxpr.jaxpr
return max(unsafe_map(eqn_replicas, jaxpr.eqns), default=1)
# TODO(mattjj): this function assumes that only pmap has a parameter named
# axis_size, and that it corresponds to cross-replica mapping
def eqn_replicas(eqn):
call_jaxpr = eqn.params.get("call_jaxpr")
if call_jaxpr:
return eqn.params.get('axis_size', 1) * jaxpr_replicas(call_jaxpr)
elif eqn.primitive in xla._initial_style_primitives:
return initial_style_primitive_replicas(eqn.params)
else:
return 1
def initial_style_primitive_replicas(params):
return max(core.traverse_jaxpr_params(jaxpr_replicas, params).values(), default=1)
def _xla_callable_device(nreps, backend, device, arg_devices):
if nreps > 1:
if device is not None or backend is not None:
raise ValueError(f"can't specify device or backend for jit-of-pmap, "
f"got device={device} and backend={backend}")
return None
else:
if device is None and backend is None:
return _device_from_arg_devices(arg_devices)
elif device is not None and backend is None:
return device
elif device is None and backend is not None:
return xb.get_backend(backend).get_default_device_assignment(1)[0]
else:
assert False # Unreachable given the error check in _xla_callable
# Result handlers
def aval_to_result_handler(device: Optional[Device],
aval: core.AbstractValue) -> Callable:
try:
return xla_result_handlers[type(aval)](device, aval)
except KeyError as err:
raise TypeError(f"No xla_result_handler for type: {type(aval)}") from err
def array_result_handler(device: Optional[Device], aval: core.ShapedArray):
if aval.dtype is dtypes.float0:
return lambda _: np.zeros(aval.shape, dtypes.float0)
return partial(device_array.make_device_array, core.raise_to_shaped(aval), device)
xla_result_handlers: Dict[Type[core.AbstractValue], Callable[..., Callable]] = {
core.AbstractUnit: lambda _, __: lambda _: core.unit,
core.ShapedArray: array_result_handler,
core.ConcreteArray: array_result_handler,
}
xla_result_handlers[core.AbstractToken] = lambda _, __: lambda _: core.token
def needs_check_special():
return config.jax_debug_infs or config.jax_debug_nans
def check_special(name, bufs):
if needs_check_special():
for buf in bufs:
_check_special(name, buf.xla_shape(), buf)
def _check_special(name, xla_shape, buf):
assert not xla_shape.is_tuple()
if dtypes.issubdtype(xla_shape.element_type(), np.inexact):
if config.jax_debug_nans and np.any(np.isnan(buf.to_py())):
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
if config.jax_debug_infs and np.any(np.isinf(buf.to_py())):
raise FloatingPointError(f"invalid value (inf) encountered in {name}")
def _execute_compiled(name: str, compiled: XlaExecutable,
output_buffer_counts: Optional[Sequence[int]], handlers,
kept_var_idx, *args):
device, = compiled.local_devices()
input_bufs = list(
itertools.chain.from_iterable(
device_put(x, device)
for i, x in enumerate(args)
if x is not core.token and i in kept_var_idx))
out_bufs = compiled.execute(input_bufs)
check_special(name, out_bufs)
if output_buffer_counts is None:
return (handlers[0](*out_bufs),)
return tuple(
handler(*bs) for handler, bs in
unsafe_zip(handlers, xla._partition_outputs(output_buffer_counts, out_bufs)))
def _execute_replicated(name: str, compiled: XlaExecutable,
output_buffer_counts: Optional[Sequence[int]], handlers,
kept_var_idx, *args):
input_bufs = [
list(
itertools.chain.from_iterable(
device_put(x, device)
for i, x in enumerate(args)
if x is not core.token and i in kept_var_idx))
for device in compiled.local_devices()
]
out_bufs = [
buf[0] for buf in compiled.execute_sharded_on_local_devices(
list(zip(*input_bufs)))
]
check_special(name, out_bufs)
if output_buffer_counts is None:
return (handlers[0](*out_bufs),)
return tuple(
handler(*bs) for handler, bs in
unsafe_zip(handlers, xla._partition_outputs(output_buffer_counts, out_bufs)))
def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
kept_var_idx, *args):
env = {core.unitvar: core.unit}
pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
map(env.setdefault, jaxpr.invars, pruned_args)
map(env.setdefault, jaxpr.constvars, consts)
outs = [xla.canonicalize_dtype(v.val) if type(v) is core.Literal else env[v]
for v in jaxpr.outvars]
return [_copy_device_array_to_device(x, device) if device_array.type_is_device_array(x)
else h(*device_put(x, device)) for h, x in zip(handlers, outs)]
class XlaComputation:
name: str
_is_trivial: bool
_executable: Optional['XlaCompiledComputation']
_donated_invars: Optional[Sequence[bool]]
def __init__(self, name: str, hlo, is_trivial: bool,
donated_invars: Optional[Sequence[bool]], *compile_args):
self.name = name
self._hlo = hlo
self._is_trivial = is_trivial
self._donated_invars = donated_invars
self._executable = None
self.compile_args = compile_args
def is_trivial(self):
return self._is_trivial
def hlo(self):
if self.is_trivial():
raise ValueError("A trivial computation has no HLO")
return self._hlo
def compile(self) -> 'XlaCompiledComputation':
if self._executable is None:
if self.is_trivial():
self._executable = XlaCompiledComputation.from_trivial_jaxpr(
*self.compile_args)
else:
self._executable = XlaCompiledComputation.from_xla_computation(
self.name, self.hlo(), *self.compile_args)
return self._executable
def backend_compile(backend, built_c, options):
# we use a separate function call to ensure that XLA compilation appears
# separately in Python profiling results
return backend.compile(built_c, compile_options=options)
# TODO(phawkins): update users.
xla.backend_compile = backend_compile
def compile_or_get_cached(backend, computation, compile_options):
# Avoid import cycle between jax and jax.experimental
from jax.experimental.compilation_cache import compilation_cache as cc
# Persistent compilation cache only implemented on TPU.
# TODO(skye): add warning when initializing cache on unsupported default platform
if cc.is_initialized() and backend.platform == 'tpu':
cached_executable = cc.get_executable(computation, compile_options, backend)
if cached_executable is not None:
logging.info('Persistent compilation cache hit')
return cached_executable
else:
compiled = backend_compile(backend, computation, compile_options)
cc.put_executable(computation, compile_options, compiled, backend)
return compiled
return backend_compile(backend, computation, compile_options)
class XlaCompiledComputation:
def __init__(self, xla_executable, in_avals, kept_var_idx, unsafe_call):
self._xla_executable = xla_executable
self.in_avals = in_avals
self._kept_var_idx = kept_var_idx
self.unsafe_call = unsafe_call
@staticmethod
def from_xla_computation(
name: str,
xla_computation,
nreps: int,
device,
backend,
tuple_args: bool,
in_avals,
out_avals,
kept_var_idx) -> 'XlaCompiledComputation':
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
options = xb.get_compile_options(
num_replicas=nreps,
num_partitions=1,
device_assignment=(device.id,) if device else None)
options.parameter_is_tupled_arguments = tuple_args
compiled = compile_or_get_cached(backend, xla_computation, options)
buffer_counts = (None if len(out_avals) == 1 else
[len(xla.aval_to_xla_shapes(aval)) for aval in out_avals])
execute = _execute_compiled if nreps == 1 else _execute_replicated
unsafe_call = partial(execute, name, compiled, buffer_counts,
result_handlers, kept_var_idx)
return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call)
def is_trivial(self):
return self._xla_executable == None
def xla_executable(self):
if self.is_trivial():
raise ValueError("A trivial compiled computation has no XLA executable")
return self._xla_executable
@staticmethod
def from_trivial_jaxpr(jaxpr, consts, device, in_avals, out_avals,
kept_var_idx) -> 'XlaCompiledComputation':
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
unsafe_call = partial(_execute_trivial, jaxpr, device, consts,
out_avals, result_handlers, kept_var_idx)
return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call)
def call(self, *args):
arg_specs = unsafe_map(arg_spec, args)
arg_avals = [spec[0] for i, spec in enumerate(arg_specs)
if i in self._kept_var_idx]
check_arg_avals_for_call(self.in_avals, arg_avals)
return self.unsafe_call(*args)
def check_arg_avals_for_call(ref_avals, arg_avals):
if len(ref_avals) != len(arg_avals):
raise TypeError(
f"Computation compiled for {len(ref_avals)} inputs "
f"but called with {len(arg_avals)}")
for ref_aval, arg_aval in zip(ref_avals, arg_avals):
if not core.typematch(ref_aval, arg_aval):
ref_avals_fmt = ', '.join(str(a) for a in ref_avals)
arg_avals_fmt = ', '.join(str(a) for a in arg_avals)
raise TypeError(
f"Computation compiled for input types:\n {ref_avals_fmt}\n"
f"called with:\n {arg_avals_fmt}")
def device_put(x, device: Optional[Device] = None) -> Tuple[Any]:
x = xla.canonicalize_dtype(x)
try:
return device_put_handlers[type(x)](x, device)
except KeyError as err:
raise TypeError(f"No device_put handler for type: {type(x)}") from err
# TODO(phawkins): update users.
xla.device_put = device_put
def _device_put_array(x, device: Optional[Device]):
backend = xb.get_device_backend(device)
if x.dtype is dtypes.float0:
x = np.zeros(x.shape, dtype=np.dtype(bool))
return (backend.buffer_from_pyval(x, device),)
def _device_put_scalar(x, device):
return _device_put_array(dtypes.coerce_to_array(x), device)
def _device_put_unit(_, device):
backend = xb.get_device_backend(device)
return (backend.buffer_from_pyval(np.zeros((), dtype=np.dtype('bool')),
device),)
_scalar_types = dtypes.python_scalar_dtypes.keys()
device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]], Tuple[Any]]] = {
core.Unit: _device_put_unit
}
device_put_handlers.update((t, _device_put_array) for t in array_types)
device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)
device_put_handlers[core.Token] = lambda x, _: (x,)
def _device_put_device_array(x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray], device: Optional[Device]):
x = _copy_device_array_to_device(x, device)
return (x.device_buffer,)
for t in device_array.device_array_types:
device_put_handlers[t] = _device_put_device_array
def _copy_device_array_to_device(x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray], device: Optional[xc.Device]) -> Union[device_array.DeviceArrayProtocol, device_array._DeviceArray]:
if device is None:
# no copying to be done because there's no target specified
return x
elif xb.get_device_backend(device).platform == x.device_buffer.platform():
# source and target platforms are the same
if x.device_buffer.device() == device:
# no copying to be done because source equals target
if x._device == device:
return x
else:
moved_buf = x.device_buffer # We need to change stickyness
else:
# move the buffer with a device-to-device copy
moved_buf = x.device_buffer.copy_to_device(device)
else:
# buffers from different XLA backends are passed through the host.
backend = xb.get_device_backend(device)
moved_buf = backend.buffer_from_pyval(x.device_buffer.to_py(), device)
return device_array.make_device_array(x.aval, device, moved_buf)
def _device_put_impl(x, device: Optional[Device] = None):
if device_array.type_is_device_array(x):
return _copy_device_array_to_device(x, device)
try:
a = xla.abstractify(x)
except TypeError as err:
raise TypeError(
f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
return aval_to_result_handler(device, a)(*device_put(x, device))
device_put_p = core.Primitive('device_put')
device_put_p.def_impl(_device_put_impl)
device_put_p.def_abstract_eval(lambda x, device=None: x)
xla.translations[device_put_p] = lambda c, x, device=None: x
ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent])
masking.defvectorized(device_put_p)
batching.defvectorized(device_put_p)

View File

@ -14,7 +14,7 @@
from jax import core
from jax import numpy as jnp
from jax.interpreters import xla
from jax._src import device_array
from jax._src.lib import xla_client
from jax._src.lib import xla_bridge
@ -23,7 +23,7 @@ SUPPORTED_DTYPES = set([jnp.int8, jnp.int16, jnp.int32, jnp.int64,
jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64])
def to_dlpack(x: xla.DeviceArrayProtocol, take_ownership: bool = False):
def to_dlpack(x: device_array.DeviceArrayProtocol, take_ownership: bool = False):
"""Returns a DLPack tensor that encapsulates a DeviceArray `x`.
Takes ownership of the contents of `x`; leaves `x` in an invalid/deleted
@ -37,7 +37,7 @@ def to_dlpack(x: xla.DeviceArrayProtocol, take_ownership: bool = False):
undefined behavior if the DLPack consumer writes to a buffer that JAX
owns.
"""
if not isinstance(x, xla.DeviceArray):
if not isinstance(x, device_array.DeviceArray):
raise TypeError("Argument to to_dlpack must be a DeviceArray, got {}"
.format(type(x)))
return xla_client._xla.buffer_to_dlpack_managed_tensor(
@ -62,4 +62,4 @@ def from_dlpack(dlpack):
xla_shape = buf.xla_shape()
assert not xla_shape.is_tuple()
aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
return xla.make_device_array(aval, buf.device(), buf) # pytype: disable=attribute-error
return device_array.make_device_array(aval, buf.device(), buf) # pytype: disable=attribute-error

View File

@ -33,6 +33,8 @@ from jax import core
from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import device_array
from jax._src import dispatch
from jax import linear_util as lu
from jax._src import dtypes
from jax import tree_util
@ -473,7 +475,7 @@ def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,
operand = np.asarray(operand, new_dtype)
if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type)
and isinstance(operand, (core.Tracer, xla.DeviceArray))):
and isinstance(operand, (core.Tracer, device_array.DeviceArray))):
return operand
else:
return convert_element_type_p.bind(operand, new_dtype=new_dtype,
@ -813,7 +815,7 @@ def broadcast_in_dim(operand: Array, shape: Shape,
shape = _broadcast_in_dim_shape_rule(
operand, shape=shape, broadcast_dimensions=broadcast_dimensions)
if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions)
and isinstance(operand, (xla.DeviceArray, core.Tracer))):
and isinstance(operand, (device_array.DeviceArray, core.Tracer))):
return operand
return broadcast_in_dim_p.bind(
operand, shape=tuple(shape),
@ -872,7 +874,7 @@ def reshape(operand: Array, new_sizes: Shape,
dims = api_util._ensure_index_tuple(dimensions)
same_dims = tuple(dims) == tuple(range(np.ndim(operand)))
if (np.shape(operand) and same_shape and same_dims
and isinstance(operand, (core.Tracer, xla.DeviceArray))):
and isinstance(operand, (core.Tracer, device_array.DeviceArray))):
return operand
else:
return reshape_p.bind(
@ -1405,7 +1407,7 @@ def transpose(operand: Array, permutation: Sequence[int]) -> Array:
"""
permutation = tuple(operator.index(d) for d in permutation)
if (permutation == tuple(range(np.ndim(operand)))
and isinstance(operand, (core.Tracer, xla.DeviceArray))):
and isinstance(operand, (core.Tracer, device_array.DeviceArray))):
return operand
else:
return transpose_p.bind(operand, permutation=permutation)
@ -1607,11 +1609,11 @@ def full(shape: Shape, fill_value: Array, dtype: Optional[DType] = None) -> Arra
return broadcast(fill_value, shape)
def _device_put_raw(x, weak_type=None):
if isinstance(x, xla.DeviceArray):
if isinstance(x, device_array.DeviceArray):
return x
else:
aval = raise_to_shaped(core.get_aval(x), weak_type=weak_type)
return xla.array_result_handler(None, aval)(*xla.device_put(x))
return dispatch.array_result_handler(None, aval)(*dispatch.device_put(x))
def zeros_like_shaped_array(aval):
assert isinstance(aval, ShapedArray)
@ -2107,10 +2109,11 @@ def zeros_like_array(x):
for t in itertools.chain(
dtypes.python_scalar_dtypes.keys(), array_types,
[xla._CppDeviceArray, xla._DeviceArray, pxla.ShardedDeviceArray, pxla.pmap_lib.ShardedDeviceArray]):
device_array.device_array_types,
[pxla.ShardedDeviceArray, pxla.pmap_lib.ShardedDeviceArray]):
ad_util.jaxval_adders[t] = add
ad_util.jaxval_zeros_likers[xla._DeviceArray] = zeros_like_array
ad_util.jaxval_zeros_likers[xla._CppDeviceArray] = zeros_like_array
ad_util.jaxval_zeros_likers[device_array._DeviceArray] = zeros_like_array
ad_util.jaxval_zeros_likers[device_array.Buffer] = zeros_like_array
ad_util.jaxval_zeros_likers[pxla.ShardedDeviceArray] = zeros_like_array
ad_util.jaxval_zeros_likers[pxla.pmap_lib.ShardedDeviceArray] = zeros_like_array

View File

@ -47,9 +47,9 @@ from jax._src.api_util import _ensure_index_tuple
from jax import errors
from jax.core import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_shape
from jax.config import config
from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray, make_device_array
from jax.interpreters import pxla
from jax import lax
from jax._src import device_array
from jax._src.lax.lax import _array_copy
from jax._src.ops import scatter
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
@ -362,8 +362,9 @@ class ndarray(metaclass=ArrayMeta):
def weak_type(self) -> bool: ...
ndarray.register(DeviceArray)
ndarray.register(_CppDeviceArray)
ndarray.register(device_array.DeviceArray)
for t in device_array.device_array_types:
ndarray.register(t)
ndarray.register(pxla._SDA_BASE_CLASS)
@ -3593,12 +3594,12 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0, *, device=None):
lax._check_user_dtype_supported(_inferred_dtype, "array")
out = _np_array(object, copy=copy, dtype=dtype)
if dtype: assert _dtype(out) == dtype
elif isinstance(object, (DeviceArray, core.Tracer)):
elif isinstance(object, (device_array.DeviceArray, core.Tracer)):
if object.aval is None:
# object is a raw buffer; convert to device array on its current device.
aval = ShapedArray(object.xla_shape().dimensions(), object.dtype,
weak_type=bool(getattr(object, "weak_type", False)))
object = make_device_array(aval, object.device(), object)
object = device_array.make_device_array(aval, object.device(), object)
out = _array_copy(object) if copy else object
elif isinstance(object, (list, tuple)):
if object:
@ -3630,7 +3631,7 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0, *, device=None):
return out
def _can_call_numpy_array(x):
return _all(not isinstance(l, (core.Tracer, DeviceArray))
return _all(not isinstance(l, (core.Tracer, device_array.DeviceArray))
for l in tree_leaves(x))
@ -6715,7 +6716,7 @@ _NOT_IMPLEMENTED = ['argpartition']
# Experimental support for NumPy's module dispatch with NEP-37.
# Currently requires https://github.com/seberg/numpy-dispatch
_JAX_ARRAY_TYPES = (DeviceArray, core.Tracer)
_JAX_ARRAY_TYPES = (device_array.DeviceArray, core.Tracer)
_HANDLED_ARRAY_TYPES = _JAX_ARRAY_TYPES + (np.ndarray,)
def __array_module__(self, types):
@ -6752,7 +6753,7 @@ def _multi_slice(arr,
@jit
def _unstack(x):
return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
setattr(DeviceArray, "_unstack", _unstack)
setattr(device_array.DeviceArray, "_unstack", _unstack)
def _chunk_iter(x, size):
if size > x.shape[0]:
yield x
@ -6762,7 +6763,7 @@ def _chunk_iter(x, size):
yield lax.dynamic_slice_in_dim(x, i * size, size)
if tail:
yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail)
setattr(DeviceArray, "_chunk_iter", _chunk_iter)
setattr(device_array.DeviceArray, "_chunk_iter", _chunk_iter)
# Syntactic sugar for scatter operations.
class _IndexUpdateHelper:
@ -7050,7 +7051,7 @@ def _set_device_array_base_attributes(device_array):
setattr(device_array, "nbytes", property(_nbytes))
setattr(device_array, "clip", _clip)
_set_device_array_base_attributes(DeviceArray)
_set_device_array_base_attributes(device_array.DeviceArray)
def _set_device_array_attributes(device_array):
@ -7063,7 +7064,7 @@ def _set_device_array_attributes(device_array):
setattr(device_array, "_multi_slice", _multi_slice)
setattr(device_array, "at", property(_IndexUpdateHelper))
_set_device_array_attributes(_DeviceArray)
_set_device_array_attributes(_CppDeviceArray)
for t in device_array.device_array_types:
_set_device_array_attributes(t)
_set_device_array_attributes(pxla._ShardedDeviceArray)
_set_device_array_attributes(pxla.pmap_lib.ShardedDeviceArray)

View File

@ -38,6 +38,7 @@ from jax._src.config import flags, bool_env, config
from jax._src.util import prod, unzip2
from jax.tree_util import tree_multimap, tree_all, tree_map, tree_reduce
from jax._src.lib import xla_bridge
from jax._src import dispatch
from jax.interpreters import xla
from jax.experimental.maps import mesh
@ -337,29 +338,29 @@ def check_grads(f, args, order,
@contextmanager
def count_device_put():
device_put = xla.device_put
device_put = dispatch.device_put
count = [0]
def device_put_and_count(*args, **kwargs):
count[0] += 1
return device_put(*args, **kwargs)
xla.device_put = device_put_and_count
dispatch.device_put = device_put_and_count
try:
yield count
finally:
xla.device_put = device_put
dispatch.device_put = device_put
@contextmanager
def count_primitive_compiles():
xla.xla_primitive_callable.cache_clear()
dispatch.xla_primitive_callable.cache_clear()
count = [-1]
try:
yield count
finally:
count[0] = xla.xla_primitive_callable.cache_info().misses
count[0] = dispatch.xla_primitive_callable.cache_info().misses
@contextmanager
@ -986,11 +987,11 @@ class JaxTestCase(parameterized.TestCase):
np_shapes = tree_map(lambda x: np.shape(np.asarray(x)), python_ans)
self.assertEqual(python_shapes, np_shapes)
cache_misses = xla.xla_primitive_callable.cache_info().misses
cache_misses = dispatch.xla_primitive_callable.cache_info().misses
python_ans = fun(*args)
if check_cache_misses:
self.assertEqual(
cache_misses, xla.xla_primitive_callable.cache_info().misses,
cache_misses, dispatch.xla_primitive_callable.cache_info().misses,
"Compilation detected during second call of {} in op-by-op "
"mode.".format(fun))

View File

@ -1245,6 +1245,11 @@ class AbstractToken(AbstractValue):
abstract_token: AbstractToken = AbstractToken()
# Concrete token object
class Token(object): pass
token = Token()
pytype_aval_mappings[Token] = lambda _: abstract_token
def raise_to_shaped(aval: AbstractValue, weak_type=None):
if weak_type is None:

View File

@ -19,6 +19,7 @@ from typing import (Tuple, List, Sequence, Set, Dict, Any, Callable, Union,
Optional)
from jax import core
from jax._src import dispatch
from jax._src import source_info_util
from jax.core import Var, Literal, Atom, Tracer
from jax._src.util import (safe_zip, safe_map, curry, unzip2, split_list,
@ -625,12 +626,12 @@ xla.xla_shape_handlers[AbsArray] = _array_xla_shape
xla.canonicalize_dtype_handlers[Array] = identity
def _array_device_put(x, device):
return xla._device_put_array(x._data, device)
xla.device_put_handlers[Array] = _array_device_put
return dispatch._device_put_array(x._data, device)
dispatch.device_put_handlers[Array] = _array_device_put
def _bdint_device_put(x, device):
return xla._device_put_scalar(x._val, device)
xla.device_put_handlers[BoundedInt] = _bdint_device_put
return dispatch._device_put_scalar(x._val, device)
dispatch.device_put_handlers[BoundedInt] = _bdint_device_put
def _bdint_canoncalize_dtype(x):
return BoundedInt(xla.canonicalize_dtype(x._val), x._bound)
@ -677,8 +678,8 @@ def djaxpr_subcomp(c, jaxpr, dim_args, args):
def execute_compiled(compiled, partitioner, handlers, dim_vals, args):
input_bufs = list(it.chain(
(buf for x in dim_vals for buf in xla.device_put(x, None)),
(buf for x in args for buf in xla.device_put(x, None))))
(buf for x in dim_vals for buf in dispatch.device_put(x, None)),
(buf for x in args for buf in dispatch.device_put(x, None))))
out_bufs = compiled.execute(input_bufs)
dims_dict, grouped_bufs = partitioner(out_bufs)
return [handler(dims_dict, bs) for handler, bs in zip(handlers, grouped_bufs)]
@ -705,7 +706,7 @@ def result_handler(aval):
if isinstance(aval, AbsArray):
return array_result_handler(aval)
else:
handler = xla.aval_to_result_handler(None, aval)
handler = dispatch.aval_to_result_handler(None, aval)
return lambda _, bufs: handler(*bufs)
def array_result_handler(aval):
@ -721,7 +722,7 @@ def array_result_handler(aval):
else:
raise NotImplementedError # TODO
padded_aval = core.ShapedArray(tuple(padded_shape), aval._eltTy._dtype)
array_handler = xla.array_result_handler(None, padded_aval)
array_handler = dispatch.array_result_handler(None, padded_aval)
def handler(dims_dict, bufs):
shape = tuple(dims_dict[d] if isinstance(d, Var) else
DimIndexer(dims_dict[d.name], d.indices) if isinstance(d, DimIndexingExpr) else

View File

@ -458,6 +458,7 @@ from jax import lax
from jax.experimental import pjit
from jax.interpreters import ad, xla, batching, masking, pxla
from jax.interpreters import partial_eval as pe
from jax._src import dispatch
from jax._src import pretty_printer as pp
from jax._src import source_info_util
from jax._src import util
@ -880,7 +881,7 @@ It takes the following parameters:
"""
outside_call_p = core.Primitive("outside_call")
outside_call_p.multiple_results = True
xla.outfeed_primitives.add(outside_call_p)
dispatch.outfeed_primitives.add(outside_call_p)
def _outside_call_abstract_eval(*args_a: pe.AbstractValue,
@ -918,7 +919,7 @@ def _outside_call_impl(*args, **params):
# even in eager execution some primitives, such as while, are compiled.
# It would be confusing to process a sequence "id_tap; while" in two
# different threads.
return xla.apply_primitive(outside_call_p, *args, **params)
return dispatch.apply_primitive(outside_call_p, *args, **params)
outside_call_p.def_impl(_outside_call_impl)
@ -1355,7 +1356,7 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
"""Rewrite a Jaxpr to thread the token, if needed."""
assert has_input_token or not has_output_token
if not has_input_token and not xla.jaxpr_uses_outfeed(jaxpr):
if not has_input_token and not dispatch.jaxpr_uses_outfeed(jaxpr):
return jaxpr
mk_new_var = core.gensym([jaxpr])
@ -1377,7 +1378,7 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
lax.create_token_p, {}, source_info_util.current()))
for eqn in jaxpr.eqns:
if not xla.primitive_uses_outfeed(eqn.primitive, eqn.params):
if not dispatch.primitive_uses_outfeed(eqn.primitive, eqn.params):
eqns.append(eqn)
else:
output_token_var = mk_new_var(last_token_var.aval)
@ -1415,7 +1416,7 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
cond_jaxpr, _, body_jaxpr, _ = util.split_dict(
eqn.params,
["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
if dispatch.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
_rewrite_while_outfeed_cond(eqn, eqns, input_token_var, output_token_var,
input_itoken_var, output_itoken_var,
mk_new_var)
@ -1692,7 +1693,7 @@ id_p.def_impl(lambda *args: args)
id_p.def_abstract_eval(lambda *args: args)
xla.register_translation(id_p, lambda ctx, avals_in, avals_out, *args: args)
xla.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False)
dispatch.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False)
class CallbackException(Exception):
@ -1820,7 +1821,7 @@ def _initialize_outfeed_receiver(
def exit_handler():
# Prevent logging usage during compilation, gives errors under pytest
xla._on_exit = True # type: ignore[protected-access]
dispatch._on_exit = True # type: ignore[protected-access]
if not _callback_handler_data.on_exit:
_callback_handler_data.on_exit = True
barrier_wait("at_exit")

View File

@ -26,6 +26,7 @@ from jax._src import api_util
from jax import config
from jax._src import api
from jax import core, custom_derivatives
from jax._src import dispatch
from jax._src import dtypes
from jax import linear_util as lu
from jax import random, tree_util
@ -968,7 +969,7 @@ def _add(x: TfVal, y: TfVal) -> TfVal:
tf_impl[ad_util.add_jaxvals_p] = _add
tf_impl[xla.device_put_p] = lambda x, device=None: x
tf_impl[dispatch.device_put_p] = lambda x, device=None: x
def _neg(x: TfVal) -> TfVal:
if x.dtype.is_unsigned:

View File

@ -52,7 +52,7 @@ from jax._src import test_util as jtu
from jax import lax
from jax import numpy as jnp
from jax._src.lax import control_flow as lax_control_flow
from jax.interpreters import xla
from jax._src import dispatch
from jax._src.lib import xla_client
@ -639,7 +639,7 @@ def _make_device_put_harness(name,
define(
"device_put",
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_device={device}",
lambda x: xla.device_put_p.bind(x, device=_device_fn()),
lambda x: dispatch.device_put_p.bind(x, device=_device_fn()),
[RandArg(shape, dtype)],
shape=shape,
dtype=dtype,

View File

@ -23,6 +23,7 @@ import jax.numpy as jnp
from jax import core
from jax._src.util import unzip2
from jax._src import ad_util
from jax._src import dispatch
from jax.tree_util import (register_pytree_node, tree_structure,
treedef_is_leaf, tree_flatten, tree_unflatten)
import jax.linear_util as lu
@ -240,7 +241,7 @@ deflinear(lax.slice_p)
deflinear(lax.reduce_sum_p)
deflinear(lax.reduce_window_sum_p)
deflinear(lax.fft_p)
deflinear(xla.device_put_p)
deflinear(dispatch.device_put_p)
def _cumulative_jet_rule(primals_in, series_in, *, axis: int, reverse: bool,
combine_fn: Callable):

View File

@ -27,6 +27,7 @@ from .. import numpy as jnp
from .. import core
from .. import linear_util as lu
from .._src.api import _check_callable, _check_arg
from jax._src import dispatch
from ..tree_util import (tree_flatten, tree_unflatten, all_leaves, tree_map,
tree_leaves)
from .._src.tree_util import _replace_nones
@ -735,7 +736,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
use_spmd_lowering, in_avals,
tile_by_mesh_axes=True)
else:
return xla.lower_xla_callable(
return dispatch.lower_xla_callable(
f, None, backend, name, donated_invars, *((a, None) for a in in_avals))
class EvaluationPlan(NamedTuple):

View File

@ -29,7 +29,6 @@ from .. import linear_util as lu
from .._src.util import (unzip2, safe_map, safe_zip, wrap_name, split_list,
canonicalize_axis, moveaxis, as_hashable_function,
curry, memoize)
from . import xla
from . import partial_eval as pe
map = safe_map
@ -629,5 +628,3 @@ def zeros_like_batched(batched_args, batch_dims):
bdim, = batch_dims
return zeros_like_jaxval(val), bdim
primitive_batchers[zeros_like_p] = zeros_like_batched
defvectorized(xla.device_put_p)

View File

@ -31,6 +31,8 @@ from jax import linear_util as lu
from jax._src.config import config
from jax._src import ad_util
from jax._src import custom_derivatives
from jax._src import device_array
from jax._src import dispatch
from jax._src import dtypes
from jax._src.lax import lax
from jax._src.lax import control_flow
@ -234,8 +236,8 @@ for ptype, dtype in dtypes.python_scalar_dtypes.items():
def _device_array_constant_handler(val, canonicalize_types):
return _ndarray_constant_handler(val.device_buffer.to_py(),
canonicalize_types)
register_constant_handler(xla._DeviceArray, _device_array_constant_handler)
register_constant_handler(xla._CppDeviceArray, _device_array_constant_handler)
for t in device_array.device_array_types:
register_constant_handler(t, _device_array_constant_handler)
# Source locations
@ -523,7 +525,7 @@ translations[core.call_p] = partial(_named_call_lowering, name="core_call")
def _device_put_lowering(ctx, avals_in, avals_out, x, *, device):
return [x]
translations[xla.device_put_p] = _device_put_lowering
translations[dispatch.device_put_p] = _device_put_lowering
def _full_like_aval(value, aval: core.ShapedArray) -> ir.Value:
@ -629,7 +631,7 @@ def _execute_compiled(name: str, compiled: xla.XlaExecutable,
unsafe_zip(arg_handlers,
(x for i, x in enumerate(args) if i in kept_var_idx)))
out_bufs = compiled.execute(input_bufs)
xla.check_special(name, out_bufs)
dispatch.check_special(name, out_bufs)
return [handler(device, *bs) for handler, bs in
zip(result_handlers, xla._partition_outputs(buffer_counts, out_bufs))]
@ -649,7 +651,7 @@ def _execute_replicated(name: str, compiled: xla.XlaExecutable,
buf[0] for buf in compiled.execute_sharded_on_local_devices(
list(zip(*input_bufs)))
]
xla.check_special(name, out_bufs)
dispatch.check_special(name, out_bufs)
return [handler(device, *bs) for handler, bs in
zip(result_handlers, xla._partition_outputs(buffer_counts, out_bufs))]
@ -662,7 +664,7 @@ def _execute_trivial(jaxpr, device: Optional[xla.Device], consts, buffer_counts,
map(env.setdefault, jaxpr.constvars, consts)
outs = [xla.canonicalize_dtype(v.val) if type(v) is core.Literal else env[v]
for v in jaxpr.outvars]
return [xla.device_put_p.bind(x, device=device) for x in outs]
return [dispatch.device_put_p.bind(x, device=device) for x in outs]
class XlaCompiledComputation:
@ -688,7 +690,7 @@ class XlaCompiledComputation:
num_partitions=1,
device_assignment=(device.id,) if device else None)
options.parameter_is_tupled_arguments = tuple_args
compiled = xla.compile_or_get_cached(backend, xla_computation, options)
compiled = dispatch.compile_or_get_cached(backend, xla_computation, options)
buffer_counts = [aval_to_num_buffers(aval) for aval in avals_out]
if nreps == 1:
return XlaCompiledComputation(compiled, partial(
@ -771,18 +773,18 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars
if any(isinstance(c, core.Tracer) for c in consts):
raise UnexpectedTracerError("Encountered an unexpected tracer.")
jaxpr, kept_const_idx, kept_var_idx = xla._prune_unused_inputs(jaxpr)
jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr)
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
pruned_arg_specs = (a for i, a in enumerate(arg_specs) if i in kept_var_idx)
avals_in, arg_devices = util.unzip2(pruned_arg_specs)
donated_invars = [
x for i, x in enumerate(donated_invars) if i in kept_var_idx
]
map(xla.prefetch, itertools.chain(consts, xla.jaxpr_literals(jaxpr)))
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
map(dispatch.prefetch, itertools.chain(consts, dispatch.jaxpr_literals(jaxpr)))
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
nreps = xla.jaxpr_replicas(jaxpr)
device = xla._xla_callable_device(nreps, backend, device, arg_devices)
nreps = dispatch.jaxpr_replicas(jaxpr)
device = dispatch._xla_callable_device(nreps, backend, device, arg_devices)
backend = xb.get_device_backend(device) if device else xb.get_backend(backend)
# Computations that only produce constants and/or only rearrange their inputs,
@ -792,7 +794,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars
return XlaComputation(name, None, True, jaxpr, consts, device, avals_in,
avals_out, kept_var_idx)
if not xla._on_exit:
if not dispatch._on_exit:
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logging.log(log_priority, "Compiling %s (%s) for args %s.",
fun.__name__, id(fun), avals_in)
@ -810,7 +812,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars
f"compiling computation that requires {nreps} replicas, but only "
f"{xb.device_count(backend)} XLA devices are available")
if xb.process_count() > 1 and (nreps > 1 or xla.jaxpr_has_pmap(jaxpr)):
if xb.process_count() > 1 and (nreps > 1 or dispatch.jaxpr_has_pmap(jaxpr)):
raise NotImplementedError(
"jit of multi-host pmap not implemented (and jit-of-pmap can cause "
"extra data movement anyway, so maybe you don't want it after all).")
@ -855,15 +857,15 @@ def _xla_call_impl_mlir(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
del inline # Only used at tracing time
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(xla.arg_spec, args))
*unsafe_map(dispatch.arg_spec, args))
return compiled_fun(*args)
@util.cache()
def _xla_primitive_callable(prim, *arg_specs: xla.ArgSpec, **params):
def _xla_primitive_callable(prim, *arg_specs: dispatch.ArgSpec, **params):
avals, arg_devices = util.unzip2(arg_specs)
donated_invars = (False,) * len(arg_specs)
device = xla._device_from_arg_devices(arg_devices)
device = dispatch._device_from_arg_devices(arg_devices)
def prim_fun(*args):
out = prim.bind(*args, **params)
if prim.multiple_results:
@ -879,7 +881,7 @@ def _xla_primitive_callable(prim, *arg_specs: xla.ArgSpec, **params):
def apply_primitive(prim, *args, **params):
"""Impl rule that compiles and runs a single primitive 'prim' using XLA."""
compiled_fun = _xla_primitive_callable(prim, *unsafe_map(xla.arg_spec, args),
compiled_fun = _xla_primitive_callable(prim, *unsafe_map(dispatch.arg_spec, args),
**params)
return compiled_fun(*args)

View File

@ -46,11 +46,13 @@ from .. import core
from .. import linear_util as lu
from jax._src.abstract_arrays import array_types
from ..core import ConcreteArray, ShapedArray
from jax._src import device_array
from .._src import source_info_util
from .._src.util import (unzip3, prod, safe_map, safe_zip,
extend_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log)
from ..errors import JAXTypeError
from jax._src import dispatch
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
@ -321,15 +323,15 @@ def _shard_device_array(x, devices, indices):
_as_slice_indices(x, idx) for idx in indices)
shards = x._multi_slice(start_indices, limit_indices, removed_dims)
return device_put(shards, devices)
shard_arg_handlers[xla._DeviceArray] = _shard_device_array
shard_arg_handlers[xla._CppDeviceArray] = _shard_device_array
for t in device_array.device_array_types:
shard_arg_handlers[t] = _shard_device_array
# NOTE(skye): we could refactor to generate _multi_slice parameters directly
# from the input ShardingSpec, rather than the indices. However, this would
# require duplicating the ordering logic of spec_to_indices, which is more
# subtle and more likely to change than the index logic we have to support here.
def _as_slice_indices(arr: xla.DeviceArrayProtocol, idx: Index) -> Tuple[
def _as_slice_indices(arr: device_array.DeviceArrayProtocol, idx: Index) -> Tuple[
Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]:
"""Returns start_indices, limit_indices, removed_dims"""
start_indices = [0] * arr.ndim
@ -509,11 +511,11 @@ if _USE_CPP_SDA:
ShardedDeviceArrayBase = pmap_lib.ShardedDeviceArrayBase # type: ignore
# We want the C++ SDA to extend the DeviceArrayBase. We want this both to
# benefit from its methods, and to have isinstance(x, DeviceArray) return true
ShardedDeviceArrayBase.__bases__ = ((xla.DeviceArray,) + # type: ignore
ShardedDeviceArrayBase.__bases__ = ((device_array.DeviceArray,) + # type: ignore
ShardedDeviceArrayBase.__bases__)
_SDA_BASE_CLASS = pmap_lib.ShardedDeviceArrayBase # type: ignore
else:
_SDA_BASE_CLASS: Type[xla.DeviceArray] = xla.DeviceArray # type: ignore
_SDA_BASE_CLASS: Type[device_array.DeviceArray] = device_array.DeviceArray # type: ignore
class _ShardedDeviceArray(_SDA_BASE_CLASS): # type: ignore
@ -646,7 +648,7 @@ def _sda__getitem__(self, idx):
if buf_idx is not None:
buf = self.device_buffers[buf_idx]
aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
return xla.make_device_array(aval, None, buf)
return device_array.make_device_array(aval, None, buf)
return super(self.__class__, self).__getitem__(idx)
@ -726,7 +728,7 @@ def _register_handlers_for_sharded_device_array(sda):
xla.register_constant_handler(sda, _sharded_device_array_constant_handler)
core.pytype_aval_mappings[sda] = ConcreteArray
xla.device_put_handlers[sda] = xla._device_put_array
dispatch.device_put_handlers[sda] = dispatch._device_put_array
xla.pytype_aval_mappings[sda] = op.attrgetter("aval")
xla.canonicalize_dtype_handlers[sda] = identity
@ -829,7 +831,7 @@ def parallel_callable(fun: lu.WrappedFun,
with core.extend_axis_env(axis_name, global_axis_size, None): # type: ignore
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
fun, global_sharded_avals, pe.debug_info_final(fun, "pmap"))
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
out_axes = out_axes_thunk()
assert len(out_sharded_avals) == len(out_axes), (len(out_sharded_avals), len(out_axes))
@ -844,7 +846,7 @@ def parallel_callable(fun: lu.WrappedFun,
check_multihost_collective_allowlist(jaxpr)
# TODO(skyewm): replace this with a chain of pmaps and/or sharded_jits
jaxpr_replicas = xla.jaxpr_replicas(jaxpr)
jaxpr_replicas = dispatch.jaxpr_replicas(jaxpr)
num_local_replicas = axis_size * jaxpr_replicas
num_global_replicas = global_axis_size * jaxpr_replicas
@ -1024,7 +1026,7 @@ def parallel_callable(fun: lu.WrappedFun,
handle_outs)
return WeakRefList([execute_fun, None])
compiled = xla.compile_or_get_cached(backend, built, compile_options)
compiled = dispatch.compile_or_get_cached(backend, built, compile_options)
handle_args = InputsHandler(compiled.local_devices(), input_sharding_specs,
input_indices)
execute_fun = partial(execute_replicated, compiled, backend, handle_args, handle_outs)
@ -1314,9 +1316,9 @@ def partitioned_sharding_spec(num_partitions: int,
def execute_replicated(compiled, backend, in_handler, out_handler, *args):
input_bufs = in_handler(args)
out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
if xla.needs_check_special():
if dispatch.needs_check_special():
for bufs in out_bufs:
xla.check_special("parallel computation", bufs)
dispatch.check_special("parallel computation", bufs)
return out_handler(out_bufs)
@ -1634,7 +1636,7 @@ def lower_mesh_computation(
_sanitize_mesh_jaxpr(jaxpr)
if local_mesh.shape != mesh.shape:
check_multihost_collective_allowlist(jaxpr)
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
# 3. Build up the HLO
c = xc.XlaBuilder(f"xmap_{fun.__name__}")
@ -1766,7 +1768,7 @@ class MeshExecutable:
input_indices, local_input_specs,
handle_outs)
else:
compiled = xla.compile_or_get_cached(backend, computation, compile_options)
compiled = dispatch.compile_or_get_cached(backend, computation, compile_options)
handle_args = InputsHandler(compiled.local_devices(), local_input_specs,
input_indices)
self.unsafe_call = partial(execute_replicated, compiled, backend, handle_args, handle_outs)
@ -1777,7 +1779,7 @@ class MeshExecutable:
def call(self, *args):
arg_avals = map(xla.abstractify, args)
ref_avals = self._local_in_untiled_avals
xla.check_arg_avals_for_call(ref_avals, arg_avals)
dispatch.check_arg_avals_for_call(ref_avals, arg_avals)
return self.unsafe_call(*args)
@ -1908,6 +1910,6 @@ _thread_local_state = _ThreadLocalState()
def device_put(x, devices: Sequence[xb.xla_client.Device], replicate: bool=False) -> List[xb.xla_client.Buffer]:
"""Call device_put on a sequence of devices and return a flat sequence of buffers."""
if replicate:
return list(it.chain.from_iterable(xla.device_put(x, device) for device in devices))
return list(it.chain.from_iterable(dispatch.device_put(x, device) for device in devices))
else:
return list(it.chain.from_iterable(xla.device_put(val, device) for val, device in safe_zip(x, devices)))
return list(it.chain.from_iterable(dispatch.device_put(val, device) for val, device in safe_zip(x, devices)))

View File

@ -25,6 +25,7 @@ from . import partial_eval as pe
from . import pxla
from . import xla
from .. import linear_util as lu
from jax._src import dispatch
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from .._src.api_util import argnums_partial, flatten_axes, flatten_fun, _ensure_index_tuple
@ -157,7 +158,7 @@ def _sharded_callable(
device_assignment = np.reshape(device_assignment, (-1, nparts))
# device_assignment = None # TODO(skye): replace with default device assignment?
compiled = xla.backend_compile(
compiled = dispatch.backend_compile(
xb.get_backend(), built,
xb.get_compile_options(nrep, nparts, device_assignment))

File diff suppressed because it is too large Load Diff

View File

@ -19,7 +19,7 @@
from . import fft as fft
from . import linalg as linalg
from jax.interpreters.xla import DeviceArray as DeviceArray
from jax._src.device_array import DeviceArray as DeviceArray
from jax._src.numpy.lax_numpy import (
ComplexWarning as ComplexWarning,

View File

@ -48,6 +48,7 @@ from jax.interpreters import ad
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters.sharded_jit import PartitionSpec as P
from jax._src import device_array
import jax._src.lib
from jax._src.lib import xla_client
from jax._src import test_util as jtu
@ -225,7 +226,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
def test_jit_device(self):
device = jax.devices()[-1]
x = self.jit(lambda x: x, device=device)(3.)
self.assertIsInstance(x, xla.DeviceArray)
self.assertIsInstance(x, jnp.DeviceArray)
self.assertEqual(x.device_buffer.device(), device)
def test_complex_support(self):
@ -492,7 +493,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
jitted_f = self.jit(lambda a: a + 1)
jitted_f(1)
self.assertIsInstance(jitted_f(2), xla._CppDeviceArray)
self.assertIsInstance(jitted_f(2), device_array.Buffer)
@jtu.skip_on_devices("cpu")
def test_explicit_backend(self):
@ -986,13 +987,13 @@ class APITest(jtu.JaxTestCase):
"Transpose rule (for reverse-mode differentiation) for 'foo' not implemented")
def test_is_subclass(self):
self.assertTrue(issubclass(xla.DeviceArray, jnp.ndarray))
self.assertTrue(issubclass(xla._CppDeviceArray, jnp.ndarray))
self.assertTrue(issubclass(device_array.DeviceArray, jnp.ndarray))
self.assertTrue(issubclass(device_array.Buffer, jnp.ndarray))
self.assertTrue(issubclass(pxla.ShardedDeviceArray, jnp.ndarray))
self.assertTrue(issubclass(pxla._ShardedDeviceArray, jnp.ndarray))
self.assertFalse(issubclass(np.ndarray, jnp.ndarray))
self.assertFalse(issubclass(xla.DeviceArray, np.ndarray))
self.assertFalse(issubclass(xla._CppDeviceArray, np.ndarray))
self.assertFalse(issubclass(device_array.DeviceArray, np.ndarray))
self.assertFalse(issubclass(device_array.Buffer, np.ndarray))
self.assertFalse(issubclass(pxla.ShardedDeviceArray, np.ndarray))
self.assertFalse(issubclass(pxla._ShardedDeviceArray, np.ndarray))
@ -1007,7 +1008,7 @@ class APITest(jtu.JaxTestCase):
def test_device_put_and_get(self):
x = np.arange(12.).reshape((3, 4)).astype("float32")
dx = api.device_put(x)
self.assertIsInstance(dx, xla.DeviceArray)
self.assertIsInstance(dx, device_array.DeviceArray)
self.assertIsInstance(dx, jnp.ndarray)
self.assertNotIsInstance(dx, np.ndarray)
x2 = api.device_get(dx)
@ -1030,7 +1031,7 @@ class APITest(jtu.JaxTestCase):
def test_device_get_scalar(self):
x = np.arange(12.).reshape((3, 4)).astype("float32")
x = api.device_put(x)
self.assertIsInstance(x, xla.DeviceArray)
self.assertIsInstance(x, device_array.DeviceArray)
y = [x, 2]
y2 = api.device_get(y)
self.assertIsInstance(y2, list)
@ -1465,11 +1466,11 @@ class APITest(jtu.JaxTestCase):
def test_devicearray_repr(self):
x = device_put(jnp.zeros(3))
self.assertIsInstance(x, xla.DeviceArray)
self.assertIsInstance(x, device_array.DeviceArray)
repr(x) # doesn't crash
x = device_put(jnp.ones(3) + 1j * jnp.ones(3))
self.assertIsInstance(x, xla.DeviceArray)
self.assertIsInstance(x, device_array.DeviceArray)
repr(x) # doesn't crash
def test_devicearray_delete(self):
@ -2304,7 +2305,7 @@ class APITest(jtu.JaxTestCase):
def test_device_array_hash(self):
rep = jnp.ones((1,)) + 1.
self.assertIsInstance(rep, jax.interpreters.xla.DeviceArray)
self.assertIsInstance(rep, device_array.DeviceArray)
self.assertNotIsInstance(rep, collections.abc.Hashable)
with self.assertRaisesRegex(TypeError, 'unhashable type'):
hash(rep)
@ -2733,7 +2734,7 @@ class APITest(jtu.JaxTestCase):
def test_jit_returning_token(self):
x = jax.jit(jax.lax.create_token)(1.0)
self.assertIsInstance(x, jax.interpreters.xla.Token)
self.assertIsInstance(x, jax.core.Token)
def test_leak_checker_catches_a_jit_leak(self):
with jax.checking_leaks():

View File

@ -19,6 +19,8 @@ import numpy as np
from jax._src import test_util as jtu
import jax.numpy as jnp
from jax import core, jit, lax, make_jaxpr
from jax._src import device_array
from jax._src import dispatch
from jax.interpreters import xla
from jax._src.lib import xla_bridge, xla_client
xops = xla_client.ops
@ -102,8 +104,8 @@ class ConcreteSparseArray(AbstractSparseArray):
def sparse_array_result_handler(device, aval):
def build_sparse_array(data_buf, indices_buf):
data = xla.make_device_array(aval.data_aval, device, data_buf)
indices = xla.make_device_array(aval.indices_aval, device, indices_buf)
data = device_array.make_device_array(aval.data_aval, device, data_buf)
indices = device_array.make_device_array(aval.indices_aval, device, indices_buf)
return SparseArray(aval, data, indices)
return build_sparse_array
@ -129,8 +131,8 @@ core.pytype_aval_mappings[SparseArray] = lambda x: x.aval
core.raise_to_shaped_mappings[AbstractSparseArray] = lambda aval, _: aval
xla.pytype_aval_mappings[SparseArray] = lambda x: x.aval
xla.canonicalize_dtype_handlers[SparseArray] = lambda x: x
xla.device_put_handlers[SparseArray] = sparse_array_device_put_handler
xla.xla_result_handlers[AbstractSparseArray] = sparse_array_result_handler
dispatch.device_put_handlers[SparseArray] = sparse_array_device_put_handler
dispatch.xla_result_handlers[AbstractSparseArray] = sparse_array_result_handler
xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler
xla.register_constant_handler(SparseArray, sparse_array_constant_handler)
@ -257,8 +259,8 @@ core.pytype_aval_mappings[Empty] = lambda x: ConcreteEmpty()
core.raise_to_shaped_mappings[AbstractEmpty] = lambda aval, _: aval
xla.pytype_aval_mappings[Empty] = lambda x: AbstractEmpty()
xla.canonicalize_dtype_handlers[Empty] = lambda x: x
xla.device_put_handlers[Empty] = lambda _, __: ()
xla.xla_result_handlers[AbstractEmpty] = lambda _, __: lambda: Empty(AbstractEmpty())
dispatch.device_put_handlers[Empty] = lambda _, __: ()
dispatch.xla_result_handlers[AbstractEmpty] = lambda _, __: lambda: Empty(AbstractEmpty())
xla.xla_shape_handlers[AbstractEmpty] = lambda _: ()

View File

@ -27,7 +27,6 @@ from jax._src import dtypes
from jax import lax
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax.interpreters import xla
from jax.config import config
config.parse_flags_with_absl()
@ -166,7 +165,7 @@ class DtypesTest(jtu.JaxTestCase):
def testScalarInstantiation(self, scalar_type):
a = scalar_type(1)
self.assertEqual(a.dtype, jnp.dtype(scalar_type))
self.assertIsInstance(a, xla.DeviceArray)
self.assertIsInstance(a, jnp.DeviceArray)
self.assertEqual(0, jnp.ndim(a))
self.assertIsInstance(np.dtype(scalar_type).type(1), scalar_type)

View File

@ -35,7 +35,6 @@ from jax._src import test_util as jtu
from jax import tree_util
from jax._src.util import unzip2
from jax.experimental import maps
from jax.interpreters import xla
import jax.numpy as jnp # scan tests use numpy
import jax.scipy as jsp
@ -2759,7 +2758,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
x = rng.randn(32, 2, 32).astype('float32') # numpy.ndarray, not DeviceArray
_, vjp_fun = jax.vjp(cumprod, x)
*_, ext_res = vjp_fun.args[0].args[0]
self.assertIsInstance(ext_res, xla.DeviceArray)
self.assertIsInstance(ext_res, jnp.DeviceArray)
def test_scan_vmap_collectives(self):
def scan_f(state, x):

View File

@ -39,9 +39,9 @@ import jax.ops
from jax import lax
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax._src import device_array
from jax._src import dtypes
from jax import tree_util
from jax.interpreters import xla
from jax.test_util import check_grads
from jax._src.util import prod, safe_zip
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc
@ -2483,7 +2483,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
expected_np_input_after_call = np.ones((1))
expected_jnp_input_after_call = jnp.ones((1))
self.assertTrue(xla.type_is_device_array(jnp.concatenate([np_input])))
self.assertTrue(device_array.type_is_device_array(jnp.concatenate([np_input])))
attempt_sideeffect(np_input)
attempt_sideeffect(jnp_input)
@ -3678,13 +3678,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
assert not np.isscalar(jnp.array(3))
def testArrayOutputsDeviceArrays(self):
assert xla.type_is_device_array(jnp.array([]))
assert xla.type_is_device_array(jnp.array(np.array([])))
assert device_array.type_is_device_array(jnp.array([]))
assert device_array.type_is_device_array(jnp.array(np.array([])))
class NDArrayLike:
def __array__(self, dtype=None):
return np.array([], dtype=dtype)
assert xla.type_is_device_array(jnp.array(NDArrayLike()))
assert device_array.type_is_device_array(jnp.array(NDArrayLike()))
# NOTE(mattjj): disabled b/c __array__ must produce ndarrays
# class DeviceArrayLike:

View File

@ -36,7 +36,6 @@ from jax._src import lax_reference
from jax.test_util import check_grads
import jax.util
from jax._src.util import prod
from jax import xla
from jax._src.lax.lax import _device_put_raw
@ -2566,7 +2565,7 @@ class LazyConstantTest(jtu.JaxTestCase):
if jit:
op = jax.jit(op)
result = op(input_type(value))
assert isinstance(result, xla.DeviceArray)
assert isinstance(result, jnp.DeviceArray)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype_in={}_dtype_out={}".format(

View File

@ -23,7 +23,6 @@ import jax.numpy as jnp
from jax import lax
from jax._src import test_util as jtu
from jax._src.lib import xla_bridge
from jax.interpreters import xla
from jax.config import config
config.parse_flags_with_absl()
@ -182,7 +181,7 @@ class MultiDeviceTest(jtu.JaxTestCase):
devices = self.get_devices()
def f(): return lax.add(3., 4.)
self.assertIsInstance(f(), xla.DeviceArray)
self.assertIsInstance(f(), jnp.DeviceArray)
self.assert_uncommitted_to_device(f(), devices[0])
self.assert_uncommitted_to_device(jax.jit(f)(), devices[0])
self.assert_committed_to_device(jax.jit(f, device=devices[1])(),

View File

@ -40,6 +40,7 @@ from jax import random
from jax.core import ShapedArray
from jax import (pmap, soft_pmap, jit, vmap, jvp, grad, make_jaxpr,
linearize, device_put)
from jax._src import device_array
import jax._src.lib
from jax._src.lib import xla_bridge
from jax._src.util import prod, safe_map
@ -540,12 +541,12 @@ class PythonPmapTest(jtu.JaxTestCase):
y = f(x)
self.assertIsInstance(y, jnp.ndarray)
self.assertIsInstance(y, pxla.ShardedDeviceArray)
self.assertIsInstance(y, jax.interpreters.xla.DeviceArray)
self.assertIsInstance(y, device_array.DeviceArray)
self.assertNotIsInstance(y, np.ndarray)
self.assertAllClose(y, 2 * x, check_dtypes=False)
z = f(y)
self.assertIsInstance(z, pxla.ShardedDeviceArray)
self.assertIsInstance(z, jax.interpreters.xla.DeviceArray)
self.assertIsInstance(z, device_array.DeviceArray)
self.assertNotIsInstance(z, np.ndarray)
self.assertAllClose(z, 2 * 2 * x, check_dtypes=False)
@ -2325,7 +2326,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
sharded_x = pmap(lambda x: x)(x)
self.assertIsNone(sharded_x._npy_value)
for i in range(8):
self.assertIsInstance(sharded_x[i], jax.interpreters.xla.DeviceArray)
self.assertIsInstance(sharded_x[i], device_array.DeviceArray)
self.assertIsNone(sharded_x._npy_value)
def test_device_put_sharded_array(self):

View File

@ -16,7 +16,7 @@ from absl.testing import absltest
import jax
from jax._src import test_util as jtu
from jax.interpreters import xla
from jax._src import dispatch
class XlaInterpreterTest(jtu.JaxTestCase):
@ -26,7 +26,7 @@ class XlaInterpreterTest(jtu.JaxTestCase):
return args[0]
closed_jaxpr = jax.make_jaxpr(f)(*range(10))
pruned_jaxpr, kept_const_idx, kept_var_idx = xla._prune_unused_inputs(
pruned_jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(
closed_jaxpr.jaxpr)
assert len(pruned_jaxpr.invars) == 1
assert kept_const_idx == set()