mirror of https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Split jax.interpreters.xla up into three pieces:

* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).

The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)

PiperOrigin-RevId: 411565432
This commit is contained in:
parent 34855def13
commit d262bae88b
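
For orientation, a minimal sketch (not part of the commit; module paths are the ones introduced in the diff below) of how the three pieces relate after the split:

    # Orientation sketch only -- not code from this commit.
    from jax._src import device_array   # defines DeviceArray/_DeviceArray
    from jax._src import dispatch       # executes primitives and jit'd functions
    from jax.interpreters import xla    # lowers jaxprs into XLA computations

    # dispatch sits on top of both: it lowers jaxprs via xla and wraps the
    # resulting buffers via device_array.make_device_array, so an MLIR
    # lowering path can later reuse device_array and dispatch unchanged.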
jax/_src/api.py
@@ -56,6 +56,8 @@ from ..tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure,
                          treedef_is_leaf, treedef_children, Partial, PyTreeDef)
 from .util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
                    extend_name_stack, wrap_name, cache, wraps, HashableFunction)
+from jax._src import device_array
+from jax._src import dispatch
 from jax._src.lib import jax_jit
 from jax._src.lib import version
 from jax._src.lib import xla_bridge as xb
@@ -83,6 +85,7 @@ from .._src.config import (flags, config, bool_env, disable_jit as _disable_jit,
                            debug_infs as config_debug_infs,
                            _thread_local_state as config_thread_local_state)
+

 traceback_util.register_exclusion(__file__)

 AxisName = Any
@@ -127,7 +130,7 @@ def _nan_check_posthook(fun, args, kwargs, output):
       buffers.extend(da_or_sda.device_buffers)

   try:
-    xla.check_special(xla.xla_call_p, buffers)
+    dispatch.check_special(xla.xla_call_p, buffers)
   except FloatingPointError:
     # compiled_fun can only raise in this case
     assert config.jax_debug_nans or config.jax_debug_infs
@@ -426,14 +429,14 @@ def _cpp_jit(
     # inspect the argument x, we actually do need to execute it and look at the
     # outputs that could be tracers (if f is capturing `Tracer` by closure).
     execute: Optional[functools.partial] = (
-        xla._xla_callable.most_recent_entry())
+        dispatch._xla_callable.most_recent_entry())
     use_fastpath = (
         # This is if we have already executed this code-path (most-recent entry
        # has been reset to None). Thus, we do not support the fast-path.
         execute is not None and
-        execute.func is xla._execute_compiled and  # not trivial, not pmap
+        execute.func is dispatch._execute_compiled and  # not trivial, not pmap
         # Not supported: ShardedDeviceArray
-        all(xla.type_is_device_array(x) for x in out_flat))
+        all(device_array.type_is_device_array(x) for x in out_flat))
     ### If we can use the fastpath, we return required info to the caller.
     if use_fastpath:
       _, xla_executable, _, result_handlers, kept_var_idx = execute.args
@@ -493,7 +496,7 @@ class Lowered:
   in_tree: PyTreeDef
   out_tree: PyTreeDef
   donate_argnums: Tuple[int]
-  _lowering: Union[xla.XlaComputation, pxla.MeshComputation]
+  _lowering: Union[dispatch.XlaComputation, pxla.MeshComputation]
   _no_kwargs: bool

   def __init__(self, lowering, in_tree, out_tree, donate_argnums,
@@ -529,7 +532,7 @@ class Compiled:
   in_tree: PyTreeDef
   out_tree: PyTreeDef
   donate_argnums: Tuple[int]
-  _executable: Union[xla.XlaCompiledComputation, pxla.MeshExecutable]
+  _executable: Union[dispatch.XlaCompiledComputation, pxla.MeshExecutable]
   _no_kwargs: bool

   def __init__(self, executable, in_tree, out_tree, donate_argnums,
@@ -595,7 +598,7 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
   flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
   name = flat_fun.__name__
   arg_specs = unsafe_map(arg_spec, args_flat)
-  computation = xla.lower_xla_callable(
+  computation = dispatch.lower_xla_callable(
       flat_fun, device, backend, name, donated_invars, *arg_specs)
   return Lowered(computation, in_tree, out_tree(), donate_argnums)

@@ -830,8 +833,8 @@ def xla_computation(fun: Callable,
     for axis_name, size in axis_env or []:
       stack.enter_context(core.extend_axis_env(axis_name, size, None))
     jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
-    jaxpr = xla.apply_outfeed_rewriter(jaxpr)
-    axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
+    jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
+    axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr))
     if out_parts is None:
       out_parts_flat = None
     else:
@@ -1990,7 +1993,7 @@ def _cpp_pmap(
         getattr(execute[0], "func", None) is pxla.execute_replicated and
         # No tracers in the outputs. Checking for ShardedDeviceArray should be
         # sufficient, but we use the more general `DeviceArray`.
-        all(isinstance(x, xla.DeviceArray) for x in out_flat))
+        all(isinstance(x, device_array.DeviceArray) for x in out_flat))
     ### If we can use the fastpath, we return required info to the caller.
     if use_fastpath:
       xla_executable, backend_, in_handler, out_handler = execute[0].args
@@ -2552,7 +2555,7 @@ def device_put(x, device: Optional[xc.Device] = None):
   Returns:
     A copy of ``x`` that resides on ``device``.
   """
-  return tree_map(lambda y: xla.device_put_p.bind(y, device=device), x)
+  return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)


 def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):
@@ -2617,7 +2620,8 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):
       raise ValueError("the shards passed to device_put_sharded must have "
                        f"consistent shape and dtype, but got {a1} and {a2}.")
     stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
-    buffers = [buf for x, d in zip(xs, devices) for buf in xla.device_put(x, d)]
+    buffers = [buf for x, d in zip(xs, devices)
+               for buf in dispatch.device_put(x, d)]
     return pxla.make_sharded_device_array(stacked_aval, None, buffers)

   return tree_multimap(_device_put_sharded, *shards)
@@ -2660,7 +2664,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]):
         core.raise_to_shaped(core.get_aval(x)))
     assert (isinstance(aval, core.ShapedArray) and
             len(xla.aval_to_xla_shapes(aval)) == 1)
-    buf, = xla.device_put(x, devices[0])
+    buf, = dispatch.device_put(x, devices[0])
     rest_bufs = [buf.copy_to_device(d) for d in devices[1:]]
     return pxla.make_sharded_device_array(aval, None, [buf, *rest_bufs])
   return tree_map(_device_put_replicated, x)
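
A small usage sketch of the public API touched above (not from the diff; assumes a jax build of this era with at least one device available). Per the change to device_put, the call now binds dispatch.device_put_p:

    # Usage sketch only -- not code from this commit.
    import jax
    import jax.numpy as jnp

    x = jnp.ones((2, 2))
    y = jax.device_put(x, jax.devices()[0])   # binds dispatch.device_put_p
    assert (y == x).all()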
jax/_src/device_array.py (new file, 307 lines)
@@ -0,0 +1,307 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# On-device arrays.

from functools import partial, partialmethod
import operator
from typing import (Any, Optional, Union)
import weakref

import numpy as np

from jax import core
from jax._src.config import config
from jax._src import dtypes
from jax._src.lib import xla_client as xc
import jax._src.util as util

### device-persistent data

xe = xc._xla

Device = xc.Device
Buffer = xe.Buffer


def _forward_method(attrname, self, fun, *args):
  return fun(getattr(self, attrname), *args)
_forward_to_value = partial(_forward_method, "_value")


# The following is used for the type xc.Buffer or _DeviceArray.
DeviceArrayProtocol = Any
DeviceArray = xc.DeviceArrayBase


def make_device_array(
    aval: core.ShapedArray,
    device: Optional[Device],
    device_buffer: Buffer,
) -> Union[Buffer, "_DeviceArray"]:
  """Returns a DeviceArray implementation based on arguments.

  This is to be used only within JAX. It will return either a PythonDeviceArray
  or a C++ equivalent implementation.
  """
  if isinstance(device_buffer, xc.Buffer):

    if device_buffer.aval == aval and device_buffer._device == device:
      return device_buffer
    device_buffer = device_buffer.clone()
    device_buffer._device = device
    device_buffer.aval = aval
    device_buffer.weak_type = aval.weak_type
    return device_buffer

  return _DeviceArray(aval, device, device_buffer)


def type_is_device_array(x):
  """Returns `True` if `x` is a non-sharded DeviceArray.

  Use this function instead of `type(x) is Devicearray`.
  """
  type_x = type(x)
  return type_x is _DeviceArray or type_x is xc.Buffer


def device_array_supports_weakrefs():
  try:
    weakref.ref(DeviceArray())
    return True
  except TypeError:
    return False


class _DeviceArray(DeviceArray):  # type: ignore
  """A DeviceArray is an ndarray backed by a single device memory buffer."""
  # We don't subclass ndarray because that would open up a host of issues,
  # but lax_numpy.py overrides isinstance behavior and attaches ndarray methods.
  __slots__ = [
      "aval", "device_buffer", "_npy_value", "_device", "__weakref__"
  ]
  __array_priority__ = 100

  # DeviceArray has methods that are dynamically populated in lax_numpy.py,
  # and this annotation is needed to make pytype happy.
  _HAS_DYNAMIC_ATTRIBUTES = True

  def __init__(self, aval: core.ShapedArray, device: Optional[Device],
               device_buffer: Buffer):
    """Initializer.

    Args:
      aval: The abstract value associated to this array (shape+dtype+weak_type).
      device: The optional sticky device. See
        https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices
      device_buffer: The underlying buffer owning the on-device data.
    """
    DeviceArray.__init__(self)
    self.aval = aval
    self.device_buffer = device_buffer
    self._device = device

    self._npy_value = None
    if config.jax_enable_checks:
      assert type(aval) is core.ShapedArray
      npy_value = self._value
      assert npy_value.dtype == aval.dtype and npy_value.shape == aval.shape, (
          aval, npy_value.shape, npy_value.dtype)
      assert (device is None) or device is device_buffer.device()

  def _check_if_deleted(self):
    if self.device_buffer is deleted_buffer:
      raise RuntimeError("DeviceArray has been deleted.")

  def block_until_ready(self):
    """Blocks the caller until the buffer's value has been computed on device.

    This method is mostly useful for timing microbenchmarks that wish to
    time how long a computation takes, without transferring the result back
    to the host.

    Returns the buffer object (`self`).
    """
    self._check_if_deleted()
    self.device_buffer.block_host_until_ready()  # pytype: disable=attribute-error
    return self

  @property
  def _value(self):
    self._check_if_deleted()
    if self._npy_value is None:
      self._npy_value = self.device_buffer.to_py()  # pytype: disable=attribute-error  # bind-properties
      self._npy_value.flags.writeable = False
    return self._npy_value

  @property
  def shape(self):
    return self.aval.shape

  @property
  def dtype(self):
    return self.aval.dtype

  @property
  def size(self):
    return util.prod(self.aval.shape)

  @property
  def ndim(self):
    return len(self.aval.shape)

  def copy_to_host_async(self):
    """Requests a copy of the buffer to the host."""
    self._check_if_deleted()
    if self._npy_value is None:
      self.device_buffer.copy_to_host_async()  # pytype: disable=attribute-error

  def delete(self):
    """Deletes the device array and any cached copy on the host.

    It is an error to access the contents of a `DeviceArray` after it has
    been deleted.

    Use of this method is optional; device buffers will be reclaimed
    automatically by Python when a DeviceArray object is garbage collected.
    However, it is sometimes useful to have more explicit control over the
    time of deletion.
    """
    self.device_buffer.delete()  # pytype: disable=attribute-error
    self.device_buffer = deleted_buffer
    self._npy_value = None

  @property
  def __cuda_array_interface__(self):
    return self.device_buffer.__cuda_array_interface__  # pytype: disable=attribute-error  # bind-properties


# Adding methods dynamically to both _DeviceArray and xc.Buffer
# pylint: disable=protected-access
for device_array in [DeviceArray]:

  def copy(self):
    """Returns an ndarray (backed by host memory, not device memory)."""
    return np.asarray(self)
  setattr(device_array, "copy", copy)

  def __repr__(self):
    line_width = np.get_printoptions()["linewidth"]
    prefix = '{}('.format(self.__class__.__name__.lstrip('_'))
    s = np.array2string(self._value, prefix=prefix, suffix=',',
                        separator=', ', max_line_width=line_width)
    if self.aval is not None and self.aval.weak_type:
      dtype_str = f'dtype={self.dtype.name}, weak_type=True)'
    else:
      dtype_str = f'dtype={self.dtype.name})'
    last_line_len = len(s) - s.rfind('\n') + 1
    sep = ' '
    if last_line_len + len(dtype_str) + 1 > line_width:
      sep = ' ' * len(prefix)
    return "{}{},{}{}".format(prefix, s, sep, dtype_str)

  setattr(device_array, "__repr__", __repr__)

  def item(self):
    if dtypes.issubdtype(self.dtype, np.complexfloating):
      return complex(self)
    elif dtypes.issubdtype(self.dtype, np.floating):
      return float(self)
    elif dtypes.issubdtype(self.dtype, np.integer):
      return int(self)
    elif dtypes.issubdtype(self.dtype, np.bool_):
      return bool(self)
    else:
      raise TypeError(self.dtype)

  setattr(device_array, "item", item)

  def __len__(self):
    try:
      return self.aval.shape[0]
    except IndexError as err:
      raise TypeError("len() of unsized object") from err  # same as numpy error

  setattr(device_array, "__len__", __len__)

  def __iter__(self):
    if self.ndim == 0:
      raise TypeError("iteration over a 0-d array")  # same as numpy error
    else:
      return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack())

  setattr(device_array, "__iter__", __iter__)

  def __reversed__(self):
    return iter(self[::-1])

  setattr(device_array, "__reversed__", __reversed__)

  def __format__(self, format_spec):
    # Simulates behavior of https://github.com/numpy/numpy/pull/9883
    if self.ndim == 0:
      return format(self._value[()], format_spec)
    else:
      return format(self._value, format_spec)

  setattr(device_array, "__format__", __format__)

  def __array__(self, dtype=None, context=None):
    return np.asarray(self._value, dtype=dtype)

  setattr(device_array, "__array__", __array__)

  setattr(device_array, "__str__", partialmethod(_forward_to_value, str))
  setattr(device_array, "__bool__", partialmethod(_forward_to_value, bool))
  setattr(device_array, "__nonzero__", partialmethod(_forward_to_value, bool))
  setattr(device_array, "__float__", lambda self: self._value.__float__())
  setattr(device_array, "__int__", lambda self: self._value.__int__())
  setattr(device_array, "__complex__", lambda self: self._value.__complex__())
  setattr(device_array, "__hex__", partialmethod(_forward_to_value, hex))
  setattr(device_array, "__oct__", partialmethod(_forward_to_value, oct))
  setattr(device_array, "__index__", partialmethod(_forward_to_value,
                                                   operator.index))
  to_bytes = lambda self, order="C": self._value.tobytes(order)
  setattr(device_array, "tobytes", to_bytes)
  del to_bytes
  setattr(device_array, "tolist", lambda self: self._value.tolist())

  # pickle saves and loads just like an ndarray
  setattr(device_array, "__reduce__",
          partialmethod(_forward_to_value, operator.methodcaller("__reduce__")))

  # explicitly set to be unhashable.
  setattr(device_array, "__hash__", None)

  # clobbered when jax.numpy is imported, but useful in tests
  setattr(device_array, "__eq__", lambda self, other: self._value == other)

  # The following methods are dynamically overridden in lax_numpy.py.
  def raise_not_implemented():
    raise NotImplementedError

  setattr(device_array, "__getitem__", lambda self, i: raise_not_implemented())
# pylint: enable=protected-access

class DeletedBuffer(object): pass
deleted_buffer = DeletedBuffer()


device_array_types = [xc.Buffer, _DeviceArray]
for _device_array in device_array_types:
  core.literalable_types.add(_device_array)
  core.pytype_aval_mappings[_device_array] = core.ConcreteArray
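
The dynamically attached methods above all funnel through _forward_to_value. A standalone illustration of that forwarding pattern (toy Boxed class; not JAX code):

    # Standalone illustration -- not code from this commit.
    from functools import partial, partialmethod

    def _forward_method(attrname, self, fun, *args):
      return fun(getattr(self, attrname), *args)
    _forward_to_value = partial(_forward_method, "_value")

    class Boxed:
      def __init__(self, value):
        self._value = value

    # Define once, attach via setattr: each method forwards to self._value.
    setattr(Boxed, "__int__", partialmethod(_forward_to_value, int))
    setattr(Boxed, "__str__", partialmethod(_forward_to_value, str))

    assert int(Boxed(7)) == 7
    assert str(Boxed(7)) == "7"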
jax/_src/dispatch.py (new file, 671 lines)
@@ -0,0 +1,671 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Primitive dispatch and jit dispatch.

from functools import partial
import itertools
from typing import (
    Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union)
import warnings

from absl import logging
import numpy as np

from jax import core
from jax import linear_util as lu
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.masking as masking
import jax.interpreters.xla as xla
import jax.interpreters.partial_eval as pe
from jax.errors import UnexpectedTracerError
from jax._src.abstract_arrays import array_types
from jax._src.config import config
from jax._src import device_array
from jax._src import dtypes
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
import jax._src.util as util
from jax._src import traceback_util

traceback_util.register_exclusion(__file__)


xe = xc._xla

Backend = xe.Client
Device = xc.Device
Buffer = xe.Buffer

XlaExecutable = xc.Executable

map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

# This flag is set on exit; no logging should be attempted
_on_exit = False

### op-by-op execution

ArgSpec = Tuple[core.AbstractValue, Optional[Device]]

def arg_spec(x: Any) -> ArgSpec:
  aval = xla.abstractify(x)
  try:
    return aval, x._device
  except:
    return aval, None


def apply_primitive(prim, *args, **params):
  """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
  if config.jax_enable_mlir:
    import jax.interpreters.mlir
    return jax.interpreters.mlir.apply_primitive(prim, *args, **params)
  compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
                                        **params)
  return compiled_fun(*args)

# TODO(phawkins): update code referring to xla.apply_primitive to point here.
xla.apply_primitive = apply_primitive


@util.cache()
def xla_primitive_callable(prim, *arg_specs: ArgSpec, **params):
  avals, arg_devices = util.unzip2(arg_specs)
  donated_invars = (False,) * len(arg_specs)
  device = _device_from_arg_devices(arg_devices)
  def prim_fun(*args):
    out = prim.bind(*args, **params)
    if prim.multiple_results:
      return out
    else:
      return out,
  compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
                                    prim.name, donated_invars, *arg_specs)
  if not prim.multiple_results:
    return lambda *args, **kw: compiled(*args, **kw)[0]
  else:
    return compiled


def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[Device]:
  """Given devices of inputs, determine where to perform a computation.

  Args:
    devices: list where each element is a either a `Device` instance or `None`.
  Returns:
    A `Device` instance or None.
  Raises:
    ValueError if input devices are inconsistent.
  """
  try:
    device, = {d for d in devices if d is not None} or (None,)
    return device
  except ValueError as err:
    msg = "primitive arguments must be colocated on the same device, got {}"
    raise ValueError(msg.format(", ".join(map(str, devices)))) from err


# JIT execution

def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
                   donated_invars, inline):
  if config.jax_enable_mlir:
    import jax.interpreters.mlir
    return jax.interpreters.mlir._xla_call_impl_mlir(
        fun, *args, device=device, backend=backend, name=name,
        donated_invars=donated_invars, inline=inline)

  del inline  # Only used at tracing time
  compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
                               *unsafe_map(arg_spec, args))
  try:
    out = compiled_fun(*args)
  except FloatingPointError:
    assert config.jax_debug_nans or config.jax_debug_infs  # compiled_fun can only raise in this case
    print("Invalid value encountered in the output of a jit/pmap-ed function. "
          "Calling the de-optimized version.")
    # We want to run the wrapped function again (after _xla_callable already ran
    # it), but linear_util.WrappedFun instances are meant to be run only once.
    # In addition to re-executing the Python code, which is usually undesirable
    # but which config.jax_debug_nans is meant to opt into, we'll be re-executing
    # any linear_util.py-style side effects, i.e. re-populating Stores created
    # by any transformation_with_aux's applied to fun. Since this is
    # intentional here, to avoid "Store occupied" errors we clone the WrappedFun
    # with empty stores.
    stores = [lu.Store() for _ in fun.stores]
    clone = lu.WrappedFun(fun.f, fun.transforms, stores, fun.params)
    with core.new_sublevel():
      _ = clone.call_wrapped(*args)  # probably won't return
  return out

xla.xla_call_p.def_impl(_xla_call_impl)


def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
                           donated_invars, *arg_specs):
  return lower_xla_callable(fun, device, backend, name, donated_invars,
                            *arg_specs).compile().unsafe_call

_xla_callable = lu.cache(_xla_callable_uncached)


def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
                       donated_invars, *arg_specs):
  if device is not None and backend is not None:
    raise ValueError("can't specify both a device and a backend for jit, "
                     "got device={} and backend={}".format(device, backend))

  abstract_args, arg_devices = util.unzip2(arg_specs)
  jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
      fun, abstract_args, pe.debug_info_final(fun, "jit"))
  if any(isinstance(c, core.Tracer) for c in consts):
    raise UnexpectedTracerError("Encountered an unexpected tracer.")
  jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
  consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
  pruned_arg_specs = (a for i, a in enumerate(arg_specs) if i in kept_var_idx)
  abstract_args, arg_devices = util.unzip2(pruned_arg_specs)
  donated_invars = [
      x for i, x in enumerate(donated_invars) if i in kept_var_idx
  ]
  map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))
  jaxpr = apply_outfeed_rewriter(jaxpr)

  nreps = jaxpr_replicas(jaxpr)
  device = _xla_callable_device(nreps, backend, device, arg_devices)
  backend = xb.get_device_backend(device) if device else xb.get_backend(backend)

  # Computations that only produce constants and/or only rearrange their inputs,
  # which are often produced from partial evaluation, don't need compilation,
  # and don't need to evaluate their arguments.
  if not jaxpr.eqns:
    return XlaComputation(
        name, None, True, None, jaxpr, consts, device, abstract_args, out_avals,
        kept_var_idx)

  if not _on_exit:
    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    logging.log(log_priority, "Compiling %s (%s) for args %s.",
                fun.__name__, id(fun), abstract_args)

  if nreps > 1:
    warnings.warn(
        f"The jitted function {name} includes a pmap. Using "
        "jit-of-pmap can lead to inefficient data movement, as the outer jit "
        "does not preserve sharded data representations and instead collects "
        "input and output arrays onto a single device. "
        "Consider removing the outer jit unless you know what you're doing. "
        "See https://github.com/google/jax/issues/2926.")

  if nreps > xb.device_count(backend):
    raise ValueError(
        f"compiling computation `{name}` that requires {nreps} replicas, but "
        f"only {xb.device_count(backend)} XLA devices are available.")

  if xb.process_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
    raise NotImplementedError(
        "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
        "extra data movement anyway, so maybe you don't want it after all).")

  tuple_args = len(abstract_args) > 100  # pass long arg lists as tuple for TPU

  c = xc.XlaBuilder(f"jit_{fun.__name__}")
  xla_consts = xla._xla_consts(c, consts)
  xla_args, donated_invars = xla._xla_callable_args(c, abstract_args, tuple_args,
                                                    donated_invars=donated_invars)
  platform = backend.platform
  ctx = xla.TranslationContext(c, platform, xla.AxisEnv(nreps, (), ()),
                               xla.extend_name_stack(xla.wrap_name(name, 'jit')))
  out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
  # There is a non-zero cost to building an output tuple, particularly on TPU.
  # Avoid it if the output arity is 1.
  output = out_nodes[0] if len(out_nodes) == 1 else xc.ops.Tuple(c, out_nodes)
  if platform in ("gpu", "tpu"):
    donated_invars = xla.set_up_aliases(
        c, xla_args, c.GetShape(output), donated_invars, tuple_args)
  if any(donated_invars):
    # TODO(tomhennigan): At call time we should mark these buffers as deleted.
    unused_donations = [str(c.GetShape(a))
                        for a, d in zip(xla_args, donated_invars) if d]
    warnings.warn("Some donated buffers were not usable: {}".format(
        ", ".join(unused_donations)))
  built = c.build(output)
  return XlaComputation(
      name, built, False, donated_invars, nreps, device, backend, tuple_args,
      abstract_args, out_avals, kept_var_idx)


def prefetch(x):
  if isinstance(x, device_array.DeviceArray):
    x.copy_to_host_async()
  return x


def jaxpr_literals(jaxpr):
  """Generates all the literals inside a jaxpr, including nested subjaxprs."""
  for eqn in jaxpr.eqns:
    for v in eqn.invars:
      if type(v) is core.Literal:
        yield v.val
  for subjaxpr in core.subjaxprs(jaxpr):
    yield from jaxpr_literals(subjaxpr)


def jaxpr_has_pmap(jaxpr):
  """Whether there is an xla_pmap primitive anywhere inside a Jaxpr."""
  for eqn in jaxpr.eqns:
    if 'xla_pmap' in eqn.primitive.name:
      return True
  for subjaxpr in core.subjaxprs(jaxpr):
    if jaxpr_has_pmap(subjaxpr):
      return True
  return False


def _prune_unused_inputs(
    jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
  used = {v for v in jaxpr.outvars if isinstance(v, core.Var)}
  # TODO(zhangqiaorjc): Improve the DCE algorithm by also pruning primitive
  # applications that do not produce used outputs. Must handle side-effecting
  # primitives and nested jaxpr.
  used.update(
      v for eqn in jaxpr.eqns for v in eqn.invars if isinstance(v, core.Var))
  kept_const_idx, new_constvars = util.unzip2(
      (i, v) for i, v in enumerate(jaxpr.constvars) if v in used)
  kept_var_idx, new_invars = util.unzip2(
      (i, v) for i, v in enumerate(jaxpr.invars) if v in used)
  new_jaxpr = core.Jaxpr(new_constvars, new_invars, jaxpr.outvars, jaxpr.eqns)
  return new_jaxpr, set(kept_const_idx), set(kept_var_idx)


# We can optionally set a Jaxpr rewriter that can be applied just before
# compilation. This mechanism is used for compiling id_tap, we can
# remove it once we bring the id_tap implementation into the core.
outfeed_rewriter: Optional[Callable[[core.Jaxpr], core.Jaxpr]] = None
def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
  if outfeed_rewriter is not None:
    return outfeed_rewriter(jaxpr)
  else:
    return jaxpr

outfeed_primitives: Set[core.Primitive] = set()
def jaxpr_uses_outfeed(jaxpr: core.Jaxpr) -> bool:
  """Finds if there are outfeed primitives anywhere inside a Jaxpr."""
  return any(primitive_uses_outfeed(eqn.primitive, eqn.params)
             for eqn in jaxpr.eqns)

def _param_uses_outfeed(param):
  if type(param) is core.Jaxpr:
    if jaxpr_uses_outfeed(param):
      return True
  elif type(param) is core.ClosedJaxpr:
    if jaxpr_uses_outfeed(param.jaxpr):
      return True
  return False

def primitive_uses_outfeed(prim: core.Primitive, params: Dict) -> bool:
  if prim in outfeed_primitives:
    return True
  for param in params.values():
    if isinstance(param, tuple):
      if any(unsafe_map(_param_uses_outfeed, param)):
        return True
    elif _param_uses_outfeed(param):
      return True
  return False


def jaxpr_replicas(jaxpr) -> int:
  """The number of replicas needed for a jaxpr.

  For a eqn, multiply the `axis_size` with the `jaxpr_replicas` of the
  subjaxprs. For a list of eqns, take the maximum number of replicas.
  """
  if isinstance(jaxpr, core.ClosedJaxpr):
    jaxpr = jaxpr.jaxpr
  return max(unsafe_map(eqn_replicas, jaxpr.eqns), default=1)

# TODO(mattjj): this function assumes that only pmap has a parameter named
# axis_size, and that it corresponds to cross-replica mapping
def eqn_replicas(eqn):
  call_jaxpr = eqn.params.get("call_jaxpr")
  if call_jaxpr:
    return eqn.params.get('axis_size', 1) * jaxpr_replicas(call_jaxpr)
  elif eqn.primitive in xla._initial_style_primitives:
    return initial_style_primitive_replicas(eqn.params)
  else:
    return 1

def initial_style_primitive_replicas(params):
  return max(core.traverse_jaxpr_params(jaxpr_replicas, params).values(), default=1)


def _xla_callable_device(nreps, backend, device, arg_devices):
  if nreps > 1:
    if device is not None or backend is not None:
      raise ValueError(f"can't specify device or backend for jit-of-pmap, "
                       f"got device={device} and backend={backend}")
    return None
  else:
    if device is None and backend is None:
      return _device_from_arg_devices(arg_devices)
    elif device is not None and backend is None:
      return device
    elif device is None and backend is not None:
      return xb.get_backend(backend).get_default_device_assignment(1)[0]
    else:
      assert False  # Unreachable given the error check in _xla_callable


# Result handlers

def aval_to_result_handler(device: Optional[Device],
                           aval: core.AbstractValue) -> Callable:
  try:
    return xla_result_handlers[type(aval)](device, aval)
  except KeyError as err:
    raise TypeError(f"No xla_result_handler for type: {type(aval)}") from err

def array_result_handler(device: Optional[Device], aval: core.ShapedArray):
  if aval.dtype is dtypes.float0:
    return lambda _: np.zeros(aval.shape, dtypes.float0)
  return partial(device_array.make_device_array, core.raise_to_shaped(aval), device)


xla_result_handlers: Dict[Type[core.AbstractValue], Callable[..., Callable]] = {
    core.AbstractUnit: lambda _, __: lambda _: core.unit,
    core.ShapedArray: array_result_handler,
    core.ConcreteArray: array_result_handler,
}
xla_result_handlers[core.AbstractToken] = lambda _, __: lambda _: core.token


def needs_check_special():
  return config.jax_debug_infs or config.jax_debug_nans

def check_special(name, bufs):
  if needs_check_special():
    for buf in bufs:
      _check_special(name, buf.xla_shape(), buf)

def _check_special(name, xla_shape, buf):
  assert not xla_shape.is_tuple()
  if dtypes.issubdtype(xla_shape.element_type(), np.inexact):
    if config.jax_debug_nans and np.any(np.isnan(buf.to_py())):
      raise FloatingPointError(f"invalid value (nan) encountered in {name}")
    if config.jax_debug_infs and np.any(np.isinf(buf.to_py())):
      raise FloatingPointError(f"invalid value (inf) encountered in {name}")


def _execute_compiled(name: str, compiled: XlaExecutable,
                      output_buffer_counts: Optional[Sequence[int]], handlers,
                      kept_var_idx, *args):
  device, = compiled.local_devices()
  input_bufs = list(
      itertools.chain.from_iterable(
          device_put(x, device)
          for i, x in enumerate(args)
          if x is not core.token and i in kept_var_idx))
  out_bufs = compiled.execute(input_bufs)
  check_special(name, out_bufs)
  if output_buffer_counts is None:
    return (handlers[0](*out_bufs),)
  return tuple(
      handler(*bs) for handler, bs in
      unsafe_zip(handlers, xla._partition_outputs(output_buffer_counts, out_bufs)))


def _execute_replicated(name: str, compiled: XlaExecutable,
                        output_buffer_counts: Optional[Sequence[int]], handlers,
                        kept_var_idx, *args):
  input_bufs = [
      list(
          itertools.chain.from_iterable(
              device_put(x, device)
              for i, x in enumerate(args)
              if x is not core.token and i in kept_var_idx))
      for device in compiled.local_devices()
  ]
  out_bufs = [
      buf[0] for buf in compiled.execute_sharded_on_local_devices(
          list(zip(*input_bufs)))
  ]
  check_special(name, out_bufs)
  if output_buffer_counts is None:
    return (handlers[0](*out_bufs),)
  return tuple(
      handler(*bs) for handler, bs in
      unsafe_zip(handlers, xla._partition_outputs(output_buffer_counts, out_bufs)))


def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
                     kept_var_idx, *args):
  env = {core.unitvar: core.unit}
  pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
  map(env.setdefault, jaxpr.invars, pruned_args)
  map(env.setdefault, jaxpr.constvars, consts)
  outs = [xla.canonicalize_dtype(v.val) if type(v) is core.Literal else env[v]
          for v in jaxpr.outvars]
  return [_copy_device_array_to_device(x, device) if device_array.type_is_device_array(x)
          else h(*device_put(x, device)) for h, x in zip(handlers, outs)]


class XlaComputation:
  name: str
  _is_trivial: bool
  _executable: Optional['XlaCompiledComputation']
  _donated_invars: Optional[Sequence[bool]]

  def __init__(self, name: str, hlo, is_trivial: bool,
               donated_invars: Optional[Sequence[bool]], *compile_args):
    self.name = name
    self._hlo = hlo
    self._is_trivial = is_trivial
    self._donated_invars = donated_invars
    self._executable = None
    self.compile_args = compile_args

  def is_trivial(self):
    return self._is_trivial

  def hlo(self):
    if self.is_trivial():
      raise ValueError("A trivial computation has no HLO")
    return self._hlo

  def compile(self) -> 'XlaCompiledComputation':
    if self._executable is None:
      if self.is_trivial():
        self._executable = XlaCompiledComputation.from_trivial_jaxpr(
            *self.compile_args)
      else:
        self._executable = XlaCompiledComputation.from_xla_computation(
            self.name, self.hlo(), *self.compile_args)
    return self._executable


def backend_compile(backend, built_c, options):
  # we use a separate function call to ensure that XLA compilation appears
  # separately in Python profiling results
  return backend.compile(built_c, compile_options=options)

# TODO(phawkins): update users.
xla.backend_compile = backend_compile

def compile_or_get_cached(backend, computation, compile_options):
  # Avoid import cycle between jax and jax.experimental
  from jax.experimental.compilation_cache import compilation_cache as cc
  # Persistent compilation cache only implemented on TPU.
  # TODO(skye): add warning when initializing cache on unsupported default platform
  if cc.is_initialized() and backend.platform == 'tpu':
    cached_executable = cc.get_executable(computation, compile_options, backend)
    if cached_executable is not None:
      logging.info('Persistent compilation cache hit')
      return cached_executable
    else:
      compiled = backend_compile(backend, computation, compile_options)
      cc.put_executable(computation, compile_options, compiled, backend)
      return compiled
  return backend_compile(backend, computation, compile_options)


class XlaCompiledComputation:
  def __init__(self, xla_executable, in_avals, kept_var_idx, unsafe_call):
    self._xla_executable = xla_executable
    self.in_avals = in_avals
    self._kept_var_idx = kept_var_idx
    self.unsafe_call = unsafe_call

  @staticmethod
  def from_xla_computation(
      name: str,
      xla_computation,
      nreps: int,
      device,
      backend,
      tuple_args: bool,
      in_avals,
      out_avals,
      kept_var_idx) -> 'XlaCompiledComputation':
    result_handlers = map(partial(aval_to_result_handler, device), out_avals)
    options = xb.get_compile_options(
        num_replicas=nreps,
        num_partitions=1,
        device_assignment=(device.id,) if device else None)
    options.parameter_is_tupled_arguments = tuple_args
    compiled = compile_or_get_cached(backend, xla_computation, options)
    buffer_counts = (None if len(out_avals) == 1 else
                     [len(xla.aval_to_xla_shapes(aval)) for aval in out_avals])
    execute = _execute_compiled if nreps == 1 else _execute_replicated
    unsafe_call = partial(execute, name, compiled, buffer_counts,
                          result_handlers, kept_var_idx)
    return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call)

  def is_trivial(self):
    return self._xla_executable == None

  def xla_executable(self):
    if self.is_trivial():
      raise ValueError("A trivial compiled computation has no XLA executable")
    return self._xla_executable

  @staticmethod
  def from_trivial_jaxpr(jaxpr, consts, device, in_avals, out_avals,
                         kept_var_idx) -> 'XlaCompiledComputation':
    result_handlers = map(partial(aval_to_result_handler, device), out_avals)
    unsafe_call = partial(_execute_trivial, jaxpr, device, consts,
                          out_avals, result_handlers, kept_var_idx)
    return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call)

  def call(self, *args):
    arg_specs = unsafe_map(arg_spec, args)
    arg_avals = [spec[0] for i, spec in enumerate(arg_specs)
                 if i in self._kept_var_idx]
    check_arg_avals_for_call(self.in_avals, arg_avals)
    return self.unsafe_call(*args)

def check_arg_avals_for_call(ref_avals, arg_avals):
  if len(ref_avals) != len(arg_avals):
    raise TypeError(
        f"Computation compiled for {len(ref_avals)} inputs "
        f"but called with {len(arg_avals)}")
  for ref_aval, arg_aval in zip(ref_avals, arg_avals):
    if not core.typematch(ref_aval, arg_aval):
      ref_avals_fmt = ', '.join(str(a) for a in ref_avals)
      arg_avals_fmt = ', '.join(str(a) for a in arg_avals)
      raise TypeError(
          f"Computation compiled for input types:\n {ref_avals_fmt}\n"
          f"called with:\n {arg_avals_fmt}")


def device_put(x, device: Optional[Device] = None) -> Tuple[Any]:
  x = xla.canonicalize_dtype(x)
  try:
    return device_put_handlers[type(x)](x, device)
  except KeyError as err:
    raise TypeError(f"No device_put handler for type: {type(x)}") from err

# TODO(phawkins): update users.
xla.device_put = device_put

def _device_put_array(x, device: Optional[Device]):
  backend = xb.get_device_backend(device)
  if x.dtype is dtypes.float0:
    x = np.zeros(x.shape, dtype=np.dtype(bool))
  return (backend.buffer_from_pyval(x, device),)

def _device_put_scalar(x, device):
  return _device_put_array(dtypes.coerce_to_array(x), device)

def _device_put_unit(_, device):
  backend = xb.get_device_backend(device)
  return (backend.buffer_from_pyval(np.zeros((), dtype=np.dtype('bool')),
                                    device),)

_scalar_types = dtypes.python_scalar_dtypes.keys()

device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]], Tuple[Any]]] = {
    core.Unit: _device_put_unit
}
device_put_handlers.update((t, _device_put_array) for t in array_types)
device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)
device_put_handlers[core.Token] = lambda x, _: (x,)


def _device_put_device_array(x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray], device: Optional[Device]):
  x = _copy_device_array_to_device(x, device)
  return (x.device_buffer,)
for t in device_array.device_array_types:
  device_put_handlers[t] = _device_put_device_array

def _copy_device_array_to_device(x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray], device: Optional[xc.Device]) -> Union[device_array.DeviceArrayProtocol, device_array._DeviceArray]:
  if device is None:
    # no copying to be done because there's no target specified
    return x
  elif xb.get_device_backend(device).platform == x.device_buffer.platform():
    # source and target platforms are the same
    if x.device_buffer.device() == device:
      # no copying to be done because source equals target
      if x._device == device:
        return x
      else:
        moved_buf = x.device_buffer  # We need to change stickyness
    else:
      # move the buffer with a device-to-device copy
      moved_buf = x.device_buffer.copy_to_device(device)
  else:
    # buffers from different XLA backends are passed through the host.
    backend = xb.get_device_backend(device)
    moved_buf = backend.buffer_from_pyval(x.device_buffer.to_py(), device)
  return device_array.make_device_array(x.aval, device, moved_buf)


def _device_put_impl(x, device: Optional[Device] = None):
  if device_array.type_is_device_array(x):
    return _copy_device_array_to_device(x, device)

  try:
    a = xla.abstractify(x)
  except TypeError as err:
    raise TypeError(
        f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
  return aval_to_result_handler(device, a)(*device_put(x, device))

device_put_p = core.Primitive('device_put')
device_put_p.def_impl(_device_put_impl)
device_put_p.def_abstract_eval(lambda x, device=None: x)
xla.translations[device_put_p] = lambda c, x, device=None: x
ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent])
masking.defvectorized(device_put_p)
batching.defvectorized(device_put_p)
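
xla_primitive_callable and _xla_callable above implement a compile-once-then-reuse scheme keyed on the primitive/function and the abstract argument specs. A self-contained sketch of the pattern, using functools.lru_cache as a stand-in for util.cache/lu.cache:

    # Pattern sketch only -- not code from this commit.
    import functools

    @functools.lru_cache(maxsize=None)
    def callable_for(prim_name: str, arg_shapes: tuple):
      # The expensive step (lowering + XLA compilation) runs once per key.
      def compiled(*args):
        return f"ran {prim_name} on shapes {arg_shapes}"
      return compiled

    f1 = callable_for("add", ((2, 2), (2, 2)))
    f2 = callable_for("add", ((2, 2), (2, 2)))
    assert f1 is f2  # cache hit: no recompilation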
jax/_src/dlpack.py
@@ -14,7 +14,7 @@

 from jax import core
 from jax import numpy as jnp
-from jax.interpreters import xla
+from jax._src import device_array
 from jax._src.lib import xla_client
 from jax._src.lib import xla_bridge

@@ -23,7 +23,7 @@ SUPPORTED_DTYPES = set([jnp.int8, jnp.int16, jnp.int32, jnp.int64,
                         jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64])


-def to_dlpack(x: xla.DeviceArrayProtocol, take_ownership: bool = False):
+def to_dlpack(x: device_array.DeviceArrayProtocol, take_ownership: bool = False):
   """Returns a DLPack tensor that encapsulates a DeviceArray `x`.

   Takes ownership of the contents of `x`; leaves `x` in an invalid/deleted
@@ -37,7 +37,7 @@ def to_dlpack(x: xla.DeviceArrayProtocol, take_ownership: bool = False):
     undefined behavior if the DLPack consumer writes to a buffer that JAX
     owns.
   """
-  if not isinstance(x, xla.DeviceArray):
+  if not isinstance(x, device_array.DeviceArray):
     raise TypeError("Argument to to_dlpack must be a DeviceArray, got {}"
                     .format(type(x)))
   return xla_client._xla.buffer_to_dlpack_managed_tensor(
@@ -62,4 +62,4 @@ def from_dlpack(dlpack):
   xla_shape = buf.xla_shape()
   assert not xla_shape.is_tuple()
   aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
-  return xla.make_device_array(aval, buf.device(), buf)  # pytype: disable=attribute-error
+  return device_array.make_device_array(aval, buf.device(), buf)  # pytype: disable=attribute-error
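
A usage sketch of the two functions changed above (assumes a jax install of this era exposing them via jax.dlpack):

    # Usage sketch only -- not code from this commit.
    import jax.numpy as jnp
    from jax.dlpack import to_dlpack, from_dlpack

    x = jnp.arange(4.0)
    capsule = to_dlpack(x, take_ownership=True)  # x is left deleted
    y = from_dlpack(capsule)                     # DeviceArray sharing the buffer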
jax/_src/lax/lax.py
@@ -33,6 +33,8 @@ from jax import core
 from jax._src import ad_util
 from jax._src import api
 from jax._src import api_util
+from jax._src import device_array
+from jax._src import dispatch
 from jax import linear_util as lu
 from jax._src import dtypes
 from jax import tree_util
@@ -473,7 +475,7 @@ def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,
     operand = np.asarray(operand, new_dtype)

   if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type)
-      and isinstance(operand, (core.Tracer, xla.DeviceArray))):
+      and isinstance(operand, (core.Tracer, device_array.DeviceArray))):
     return operand
   else:
     return convert_element_type_p.bind(operand, new_dtype=new_dtype,
@@ -813,7 +815,7 @@ def broadcast_in_dim(operand: Array, shape: Shape,
   shape = _broadcast_in_dim_shape_rule(
       operand, shape=shape, broadcast_dimensions=broadcast_dimensions)
   if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions)
-      and isinstance(operand, (xla.DeviceArray, core.Tracer))):
+      and isinstance(operand, (device_array.DeviceArray, core.Tracer))):
     return operand
   return broadcast_in_dim_p.bind(
       operand, shape=tuple(shape),
@@ -872,7 +874,7 @@ def reshape(operand: Array, new_sizes: Shape,
     dims = api_util._ensure_index_tuple(dimensions)
     same_dims = tuple(dims) == tuple(range(np.ndim(operand)))
   if (np.shape(operand) and same_shape and same_dims
-      and isinstance(operand, (core.Tracer, xla.DeviceArray))):
+      and isinstance(operand, (core.Tracer, device_array.DeviceArray))):
     return operand
   else:
     return reshape_p.bind(
@@ -1405,7 +1407,7 @@ def transpose(operand: Array, permutation: Sequence[int]) -> Array:
   """
   permutation = tuple(operator.index(d) for d in permutation)
   if (permutation == tuple(range(np.ndim(operand)))
-      and isinstance(operand, (core.Tracer, xla.DeviceArray))):
+      and isinstance(operand, (core.Tracer, device_array.DeviceArray))):
     return operand
   else:
     return transpose_p.bind(operand, permutation=permutation)
@@ -1607,11 +1609,11 @@ def full(shape: Shape, fill_value: Array, dtype: Optional[DType] = None) -> Arra
     return broadcast(fill_value, shape)

 def _device_put_raw(x, weak_type=None):
-  if isinstance(x, xla.DeviceArray):
+  if isinstance(x, device_array.DeviceArray):
     return x
   else:
     aval = raise_to_shaped(core.get_aval(x), weak_type=weak_type)
-    return xla.array_result_handler(None, aval)(*xla.device_put(x))
+    return dispatch.array_result_handler(None, aval)(*dispatch.device_put(x))

 def zeros_like_shaped_array(aval):
   assert isinstance(aval, ShapedArray)
@@ -2107,10 +2109,11 @@ def zeros_like_array(x):

 for t in itertools.chain(
     dtypes.python_scalar_dtypes.keys(), array_types,
-    [xla._CppDeviceArray, xla._DeviceArray, pxla.ShardedDeviceArray, pxla.pmap_lib.ShardedDeviceArray]):
+    device_array.device_array_types,
+    [pxla.ShardedDeviceArray, pxla.pmap_lib.ShardedDeviceArray]):
   ad_util.jaxval_adders[t] = add
-ad_util.jaxval_zeros_likers[xla._DeviceArray] = zeros_like_array
-ad_util.jaxval_zeros_likers[xla._CppDeviceArray] = zeros_like_array
+ad_util.jaxval_zeros_likers[device_array._DeviceArray] = zeros_like_array
+ad_util.jaxval_zeros_likers[device_array.Buffer] = zeros_like_array
 ad_util.jaxval_zeros_likers[pxla.ShardedDeviceArray] = zeros_like_array
 ad_util.jaxval_zeros_likers[pxla.pmap_lib.ShardedDeviceArray] = zeros_like_array

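
The hunks above all adjust the same fast path: when an operation is a no-op on a concrete DeviceArray (or tracer), the operand itself is returned without binding a primitive. A sketch (assumes a jax build of this era):

    # Usage sketch only -- not code from this commit.
    import jax.numpy as jnp
    from jax import lax

    x = jnp.ones((2, 3))
    y = lax.transpose(x, (0, 1))  # identity permutation
    assert y is x                 # returned unchanged via the isinstance check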
jax/_src/numpy/lax_numpy.py
@@ -47,9 +47,9 @@ from jax._src.api_util import _ensure_index_tuple
 from jax import errors
 from jax.core import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_shape
 from jax.config import config
-from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray, make_device_array
 from jax.interpreters import pxla
 from jax import lax
+from jax._src import device_array
 from jax._src.lax.lax import _array_copy
 from jax._src.ops import scatter
 from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
@@ -362,8 +362,9 @@ class ndarray(metaclass=ArrayMeta):
   def weak_type(self) -> bool: ...


-ndarray.register(DeviceArray)
-ndarray.register(_CppDeviceArray)
+ndarray.register(device_array.DeviceArray)
+for t in device_array.device_array_types:
+  ndarray.register(t)
 ndarray.register(pxla._SDA_BASE_CLASS)

@@ -3593,12 +3594,12 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0, *, device=None):
     lax._check_user_dtype_supported(_inferred_dtype, "array")
     out = _np_array(object, copy=copy, dtype=dtype)
     if dtype: assert _dtype(out) == dtype
-  elif isinstance(object, (DeviceArray, core.Tracer)):
+  elif isinstance(object, (device_array.DeviceArray, core.Tracer)):
     if object.aval is None:
       # object is a raw buffer; convert to device array on its current device.
       aval = ShapedArray(object.xla_shape().dimensions(), object.dtype,
                          weak_type=bool(getattr(object, "weak_type", False)))
-      object = make_device_array(aval, object.device(), object)
+      object = device_array.make_device_array(aval, object.device(), object)
     out = _array_copy(object) if copy else object
   elif isinstance(object, (list, tuple)):
     if object:
@@ -3630,7 +3631,7 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0, *, device=None):
   return out

 def _can_call_numpy_array(x):
-  return _all(not isinstance(l, (core.Tracer, DeviceArray))
+  return _all(not isinstance(l, (core.Tracer, device_array.DeviceArray))
               for l in tree_leaves(x))


@@ -6715,7 +6716,7 @@ _NOT_IMPLEMENTED = ['argpartition']

 # Experimental support for NumPy's module dispatch with NEP-37.
 # Currently requires https://github.com/seberg/numpy-dispatch
-_JAX_ARRAY_TYPES = (DeviceArray, core.Tracer)
+_JAX_ARRAY_TYPES = (device_array.DeviceArray, core.Tracer)
 _HANDLED_ARRAY_TYPES = _JAX_ARRAY_TYPES + (np.ndarray,)

 def __array_module__(self, types):
@@ -6752,7 +6753,7 @@ def _multi_slice(arr,
 @jit
 def _unstack(x):
   return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
-setattr(DeviceArray, "_unstack", _unstack)
+setattr(device_array.DeviceArray, "_unstack", _unstack)
 def _chunk_iter(x, size):
   if size > x.shape[0]:
     yield x
@@ -6762,7 +6763,7 @@ def _chunk_iter(x, size):
     yield lax.dynamic_slice_in_dim(x, i * size, size)
   if tail:
     yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail)
-setattr(DeviceArray, "_chunk_iter", _chunk_iter)
+setattr(device_array.DeviceArray, "_chunk_iter", _chunk_iter)

 # Syntactic sugar for scatter operations.
 class _IndexUpdateHelper:
@@ -7050,7 +7051,7 @@ def _set_device_array_base_attributes(device_array):
   setattr(device_array, "nbytes", property(_nbytes))
   setattr(device_array, "clip", _clip)

-_set_device_array_base_attributes(DeviceArray)
+_set_device_array_base_attributes(device_array.DeviceArray)


 def _set_device_array_attributes(device_array):
@@ -7063,7 +7064,7 @@ def _set_device_array_attributes(device_array):
   setattr(device_array, "_multi_slice", _multi_slice)
   setattr(device_array, "at", property(_IndexUpdateHelper))

-_set_device_array_attributes(_DeviceArray)
-_set_device_array_attributes(_CppDeviceArray)
+for t in device_array.device_array_types:
+  _set_device_array_attributes(t)
 _set_device_array_attributes(pxla._ShardedDeviceArray)
 _set_device_array_attributes(pxla.pmap_lib.ShardedDeviceArray)
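
ndarray.register above relies on ABC virtual-subclass registration, so isinstance checks succeed without inheritance. A standalone illustration (toy classes; not JAX code):

    # Standalone illustration -- not code from this commit.
    import abc

    class NDArrayLike(abc.ABC):
      pass

    class ConcreteArray:
      pass

    NDArrayLike.register(ConcreteArray)
    assert isinstance(ConcreteArray(), NDArrayLike)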
@ -38,6 +38,7 @@ from jax._src.config import flags, bool_env, config
from jax._src.util import prod, unzip2
from jax.tree_util import tree_multimap, tree_all, tree_map, tree_reduce
from jax._src.lib import xla_bridge
from jax._src import dispatch
from jax.interpreters import xla
from jax.experimental.maps import mesh

@ -337,29 +338,29 @@ def check_grads(f, args, order,

@contextmanager
def count_device_put():
device_put = xla.device_put
device_put = dispatch.device_put
count = [0]

def device_put_and_count(*args, **kwargs):
count[0] += 1
return device_put(*args, **kwargs)

xla.device_put = device_put_and_count
dispatch.device_put = device_put_and_count
try:
yield count
finally:
xla.device_put = device_put
dispatch.device_put = device_put


@contextmanager
def count_primitive_compiles():
xla.xla_primitive_callable.cache_clear()
dispatch.xla_primitive_callable.cache_clear()

count = [-1]
try:
yield count
finally:
count[0] = xla.xla_primitive_callable.cache_info().misses
count[0] = dispatch.xla_primitive_callable.cache_info().misses


@contextmanager
@ -986,11 +987,11 @@ class JaxTestCase(parameterized.TestCase):
np_shapes = tree_map(lambda x: np.shape(np.asarray(x)), python_ans)
self.assertEqual(python_shapes, np_shapes)

cache_misses = xla.xla_primitive_callable.cache_info().misses
cache_misses = dispatch.xla_primitive_callable.cache_info().misses
python_ans = fun(*args)
if check_cache_misses:
self.assertEqual(
cache_misses, xla.xla_primitive_callable.cache_info().misses,
cache_misses, dispatch.xla_primitive_callable.cache_info().misses,
"Compilation detected during second call of {} in op-by-op "
"mode.".format(fun))
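
The count_device_put helper above is an instance of a generic swap-and-restore pattern; a minimal sketch (count_calls, holder, and name are illustrative, not part of the commit):

from contextlib import contextmanager

@contextmanager
def count_calls(holder, name):
    # Swap in a counting wrapper, restore the original on exit.
    original = getattr(holder, name)
    count = [0]
    def wrapper(*args, **kwargs):
        count[0] += 1
        return original(*args, **kwargs)
    setattr(holder, name, wrapper)
    try:
        yield count
    finally:
        setattr(holder, name, original)
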
@ -1245,6 +1245,11 @@ class AbstractToken(AbstractValue):

abstract_token: AbstractToken = AbstractToken()

# Concrete token object
class Token(object): pass
token = Token()
pytype_aval_mappings[Token] = lambda _: abstract_token


def raise_to_shaped(aval: AbstractValue, weak_type=None):
if weak_type is None:
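
The concrete Token introduced here is what token-returning computations now produce at the API level; the api_test.py hunk further down checks exactly this behavior:

import jax

tok = jax.jit(jax.lax.create_token)(1.0)
assert isinstance(tok, jax.core.Token)
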
@ -19,6 +19,7 @@ from typing import (Tuple, List, Sequence, Set, Dict, Any, Callable, Union,
Optional)

from jax import core
from jax._src import dispatch
from jax._src import source_info_util
from jax.core import Var, Literal, Atom, Tracer
from jax._src.util import (safe_zip, safe_map, curry, unzip2, split_list,
@ -625,12 +626,12 @@ xla.xla_shape_handlers[AbsArray] = _array_xla_shape
xla.canonicalize_dtype_handlers[Array] = identity

def _array_device_put(x, device):
return xla._device_put_array(x._data, device)
xla.device_put_handlers[Array] = _array_device_put
return dispatch._device_put_array(x._data, device)
dispatch.device_put_handlers[Array] = _array_device_put

def _bdint_device_put(x, device):
return xla._device_put_scalar(x._val, device)
xla.device_put_handlers[BoundedInt] = _bdint_device_put
return dispatch._device_put_scalar(x._val, device)
dispatch.device_put_handlers[BoundedInt] = _bdint_device_put

def _bdint_canoncalize_dtype(x):
return BoundedInt(xla.canonicalize_dtype(x._val), x._bound)
@ -677,8 +678,8 @@ def djaxpr_subcomp(c, jaxpr, dim_args, args):

def execute_compiled(compiled, partitioner, handlers, dim_vals, args):
input_bufs = list(it.chain(
(buf for x in dim_vals for buf in xla.device_put(x, None)),
(buf for x in args for buf in xla.device_put(x, None))))
(buf for x in dim_vals for buf in dispatch.device_put(x, None)),
(buf for x in args for buf in dispatch.device_put(x, None))))
out_bufs = compiled.execute(input_bufs)
dims_dict, grouped_bufs = partitioner(out_bufs)
return [handler(dims_dict, bs) for handler, bs in zip(handlers, grouped_bufs)]
@ -705,7 +706,7 @@ def result_handler(aval):
if isinstance(aval, AbsArray):
return array_result_handler(aval)
else:
handler = xla.aval_to_result_handler(None, aval)
handler = dispatch.aval_to_result_handler(None, aval)
return lambda _, bufs: handler(*bufs)

def array_result_handler(aval):
@ -721,7 +722,7 @@ def array_result_handler(aval):
else:
raise NotImplementedError # TODO
padded_aval = core.ShapedArray(tuple(padded_shape), aval._eltTy._dtype)
array_handler = xla.array_result_handler(None, padded_aval)
array_handler = dispatch.array_result_handler(None, padded_aval)
def handler(dims_dict, bufs):
shape = tuple(dims_dict[d] if isinstance(d, Var) else
DimIndexer(dims_dict[d.name], d.indices) if isinstance(d, DimIndexingExpr) else
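
The handler-table pattern used above generalizes to any custom container; a minimal sketch for a hypothetical wrapper type (MyBox and its handler are illustrative only; the private dispatch helpers are the ones this commit introduces):

import numpy as np
from jax._src import dispatch

class MyBox:
    def __init__(self, data):
        self.data = np.asarray(data)

def _mybox_device_put(x, device):
    # Delegate to the existing ndarray handler, as _array_device_put does.
    return dispatch._device_put_array(x.data, device)

dispatch.device_put_handlers[MyBox] = _mybox_device_put
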
@ -458,6 +458,7 @@ from jax import lax
from jax.experimental import pjit
from jax.interpreters import ad, xla, batching, masking, pxla
from jax.interpreters import partial_eval as pe
from jax._src import dispatch
from jax._src import pretty_printer as pp
from jax._src import source_info_util
from jax._src import util
@ -880,7 +881,7 @@ It takes the following parameters:
"""
outside_call_p = core.Primitive("outside_call")
outside_call_p.multiple_results = True
xla.outfeed_primitives.add(outside_call_p)
dispatch.outfeed_primitives.add(outside_call_p)


def _outside_call_abstract_eval(*args_a: pe.AbstractValue,
@ -918,7 +919,7 @@ def _outside_call_impl(*args, **params):
# even in eager execution some primitives, such as while, are compiled.
# It would be confusing to process a sequence "id_tap; while" in two
# different threads.
return xla.apply_primitive(outside_call_p, *args, **params)
return dispatch.apply_primitive(outside_call_p, *args, **params)


outside_call_p.def_impl(_outside_call_impl)
@ -1355,7 +1356,7 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
"""Rewrite a Jaxpr to thread the token, if needed."""
assert has_input_token or not has_output_token

if not has_input_token and not xla.jaxpr_uses_outfeed(jaxpr):
if not has_input_token and not dispatch.jaxpr_uses_outfeed(jaxpr):
return jaxpr

mk_new_var = core.gensym([jaxpr])
@ -1377,7 +1378,7 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
lax.create_token_p, {}, source_info_util.current()))

for eqn in jaxpr.eqns:
if not xla.primitive_uses_outfeed(eqn.primitive, eqn.params):
if not dispatch.primitive_uses_outfeed(eqn.primitive, eqn.params):
eqns.append(eqn)
else:
output_token_var = mk_new_var(last_token_var.aval)
@ -1415,7 +1416,7 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
cond_jaxpr, _, body_jaxpr, _ = util.split_dict(
eqn.params,
["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
if dispatch.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
_rewrite_while_outfeed_cond(eqn, eqns, input_token_var, output_token_var,
input_itoken_var, output_itoken_var,
mk_new_var)
@ -1692,7 +1693,7 @@ id_p.def_impl(lambda *args: args)
id_p.def_abstract_eval(lambda *args: args)
xla.register_translation(id_p, lambda ctx, avals_in, avals_out, *args: args)

xla.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False)
dispatch.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False)


class CallbackException(Exception):
@ -1820,7 +1821,7 @@ def _initialize_outfeed_receiver(

def exit_handler():
# Prevent logging usage during compilation, gives errors under pytest
xla._on_exit = True # type: ignore[protected-access]
dispatch._on_exit = True # type: ignore[protected-access]
if not _callback_handler_data.on_exit:
_callback_handler_data.on_exit = True
barrier_wait("at_exit")
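
The gate that decides whether a jaxpr needs the token-threading rewrite can be summarized as follows (a sketch only; the real dispatch.jaxpr_uses_outfeed also recurses into nested sub-jaxprs):

def uses_outfeed(jaxpr, outfeed_primitives):
    # True iff some equation applies an outfeed primitive.
    return any(eqn.primitive in outfeed_primitives for eqn in jaxpr.eqns)
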
@ -26,6 +26,7 @@ from jax._src import api_util
from jax import config
from jax._src import api
from jax import core, custom_derivatives
from jax._src import dispatch
from jax._src import dtypes
from jax import linear_util as lu
from jax import random, tree_util
@ -968,7 +969,7 @@ def _add(x: TfVal, y: TfVal) -> TfVal:


tf_impl[ad_util.add_jaxvals_p] = _add
tf_impl[xla.device_put_p] = lambda x, device=None: x
tf_impl[dispatch.device_put_p] = lambda x, device=None: x

def _neg(x: TfVal) -> TfVal:
if x.dtype.is_unsigned:
@ -52,7 +52,7 @@ from jax._src import test_util as jtu
from jax import lax
from jax import numpy as jnp
from jax._src.lax import control_flow as lax_control_flow
from jax.interpreters import xla
from jax._src import dispatch

from jax._src.lib import xla_client

@ -639,7 +639,7 @@ def _make_device_put_harness(name,
define(
"device_put",
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_device={device}",
lambda x: xla.device_put_p.bind(x, device=_device_fn()),
lambda x: dispatch.device_put_p.bind(x, device=_device_fn()),
[RandArg(shape, dtype)],
shape=shape,
dtype=dtype,
@ -23,6 +23,7 @@ import jax.numpy as jnp
from jax import core
from jax._src.util import unzip2
from jax._src import ad_util
from jax._src import dispatch
from jax.tree_util import (register_pytree_node, tree_structure,
treedef_is_leaf, tree_flatten, tree_unflatten)
import jax.linear_util as lu
@ -240,7 +241,7 @@ deflinear(lax.slice_p)
deflinear(lax.reduce_sum_p)
deflinear(lax.reduce_window_sum_p)
deflinear(lax.fft_p)
deflinear(xla.device_put_p)
deflinear(dispatch.device_put_p)

def _cumulative_jet_rule(primals_in, series_in, *, axis: int, reverse: bool,
combine_fn: Callable):
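
deflinear, as used above, declares that a primitive's jet (Taylor-mode) rule is to apply the primitive to the primal and to each series term independently; roughly (a sketch that ignores the symbolic-zero handling in the real rule):

def linear_jet_rule(prim, primals_in, series_in, **params):
    # A linear primitive maps each Taylor coefficient through itself.
    primal_out = prim.bind(*primals_in, **params)
    series_out = [prim.bind(*terms, **params) for terms in zip(*series_in)]
    return primal_out, series_out
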
@ -27,6 +27,7 @@ from .. import numpy as jnp
from .. import core
from .. import linear_util as lu
from .._src.api import _check_callable, _check_arg
from jax._src import dispatch
from ..tree_util import (tree_flatten, tree_unflatten, all_leaves, tree_map,
tree_leaves)
from .._src.tree_util import _replace_nones
@ -735,7 +736,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
use_spmd_lowering, in_avals,
tile_by_mesh_axes=True)
else:
return xla.lower_xla_callable(
return dispatch.lower_xla_callable(
f, None, backend, name, donated_invars, *((a, None) for a in in_avals))

class EvaluationPlan(NamedTuple):
@ -29,7 +29,6 @@ from .. import linear_util as lu
from .._src.util import (unzip2, safe_map, safe_zip, wrap_name, split_list,
canonicalize_axis, moveaxis, as_hashable_function,
curry, memoize)
from . import xla
from . import partial_eval as pe

map = safe_map
@ -629,5 +628,3 @@ def zeros_like_batched(batched_args, batch_dims):
bdim, = batch_dims
return zeros_like_jaxval(val), bdim
primitive_batchers[zeros_like_p] = zeros_like_batched

defvectorized(xla.device_put_p)
@ -31,6 +31,8 @@ from jax import linear_util as lu
from jax._src.config import config
from jax._src import ad_util
from jax._src import custom_derivatives
from jax._src import device_array
from jax._src import dispatch
from jax._src import dtypes
from jax._src.lax import lax
from jax._src.lax import control_flow
@ -234,8 +236,8 @@ for ptype, dtype in dtypes.python_scalar_dtypes.items():
def _device_array_constant_handler(val, canonicalize_types):
return _ndarray_constant_handler(val.device_buffer.to_py(),
canonicalize_types)
register_constant_handler(xla._DeviceArray, _device_array_constant_handler)
register_constant_handler(xla._CppDeviceArray, _device_array_constant_handler)
for t in device_array.device_array_types:
register_constant_handler(t, _device_array_constant_handler)


# Source locations
@ -523,7 +525,7 @@ translations[core.call_p] = partial(_named_call_lowering, name="core_call")
def _device_put_lowering(ctx, avals_in, avals_out, x, *, device):
return [x]

translations[xla.device_put_p] = _device_put_lowering
translations[dispatch.device_put_p] = _device_put_lowering


def _full_like_aval(value, aval: core.ShapedArray) -> ir.Value:
@ -629,7 +631,7 @@ def _execute_compiled(name: str, compiled: xla.XlaExecutable,
unsafe_zip(arg_handlers,
(x for i, x in enumerate(args) if i in kept_var_idx)))
out_bufs = compiled.execute(input_bufs)
xla.check_special(name, out_bufs)
dispatch.check_special(name, out_bufs)
return [handler(device, *bs) for handler, bs in
zip(result_handlers, xla._partition_outputs(buffer_counts, out_bufs))]

@ -649,7 +651,7 @@ def _execute_replicated(name: str, compiled: xla.XlaExecutable,
buf[0] for buf in compiled.execute_sharded_on_local_devices(
list(zip(*input_bufs)))
]
xla.check_special(name, out_bufs)
dispatch.check_special(name, out_bufs)
return [handler(device, *bs) for handler, bs in
zip(result_handlers, xla._partition_outputs(buffer_counts, out_bufs))]

@ -662,7 +664,7 @@ def _execute_trivial(jaxpr, device: Optional[xla.Device], consts, buffer_counts,
map(env.setdefault, jaxpr.constvars, consts)
outs = [xla.canonicalize_dtype(v.val) if type(v) is core.Literal else env[v]
for v in jaxpr.outvars]
return [xla.device_put_p.bind(x, device=device) for x in outs]
return [dispatch.device_put_p.bind(x, device=device) for x in outs]


class XlaCompiledComputation:
@ -688,7 +690,7 @@ class XlaCompiledComputation:
num_partitions=1,
device_assignment=(device.id,) if device else None)
options.parameter_is_tupled_arguments = tuple_args
compiled = xla.compile_or_get_cached(backend, xla_computation, options)
compiled = dispatch.compile_or_get_cached(backend, xla_computation, options)
buffer_counts = [aval_to_num_buffers(aval) for aval in avals_out]
if nreps == 1:
return XlaCompiledComputation(compiled, partial(
@ -771,18 +773,18 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars
if any(isinstance(c, core.Tracer) for c in consts):
raise UnexpectedTracerError("Encountered an unexpected tracer.")

jaxpr, kept_const_idx, kept_var_idx = xla._prune_unused_inputs(jaxpr)
jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr)
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
pruned_arg_specs = (a for i, a in enumerate(arg_specs) if i in kept_var_idx)
avals_in, arg_devices = util.unzip2(pruned_arg_specs)
donated_invars = [
x for i, x in enumerate(donated_invars) if i in kept_var_idx
]
map(xla.prefetch, itertools.chain(consts, xla.jaxpr_literals(jaxpr)))
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
map(dispatch.prefetch, itertools.chain(consts, dispatch.jaxpr_literals(jaxpr)))
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

nreps = xla.jaxpr_replicas(jaxpr)
device = xla._xla_callable_device(nreps, backend, device, arg_devices)
nreps = dispatch.jaxpr_replicas(jaxpr)
device = dispatch._xla_callable_device(nreps, backend, device, arg_devices)
backend = xb.get_device_backend(device) if device else xb.get_backend(backend)

# Computations that only produce constants and/or only rearrange their inputs,
@ -792,7 +794,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars
return XlaComputation(name, None, True, jaxpr, consts, device, avals_in,
avals_out, kept_var_idx)

if not xla._on_exit:
if not dispatch._on_exit:
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logging.log(log_priority, "Compiling %s (%s) for args %s.",
fun.__name__, id(fun), avals_in)
@ -810,7 +812,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars
f"compiling computation that requires {nreps} replicas, but only "
f"{xb.device_count(backend)} XLA devices are available")

if xb.process_count() > 1 and (nreps > 1 or xla.jaxpr_has_pmap(jaxpr)):
if xb.process_count() > 1 and (nreps > 1 or dispatch.jaxpr_has_pmap(jaxpr)):
raise NotImplementedError(
"jit of multi-host pmap not implemented (and jit-of-pmap can cause "
"extra data movement anyway, so maybe you don't want it after all).")
@ -855,15 +857,15 @@ def _xla_call_impl_mlir(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
del inline # Only used at tracing time
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(xla.arg_spec, args))
*unsafe_map(dispatch.arg_spec, args))
return compiled_fun(*args)


@util.cache()
def _xla_primitive_callable(prim, *arg_specs: xla.ArgSpec, **params):
def _xla_primitive_callable(prim, *arg_specs: dispatch.ArgSpec, **params):
avals, arg_devices = util.unzip2(arg_specs)
donated_invars = (False,) * len(arg_specs)
device = xla._device_from_arg_devices(arg_devices)
device = dispatch._device_from_arg_devices(arg_devices)
def prim_fun(*args):
out = prim.bind(*args, **params)
if prim.multiple_results:
@ -879,7 +881,7 @@ def _xla_primitive_callable(prim, *arg_specs: xla.ArgSpec, **params):

def apply_primitive(prim, *args, **params):
"""Impl rule that compiles and runs a single primitive 'prim' using XLA."""
compiled_fun = _xla_primitive_callable(prim, *unsafe_map(xla.arg_spec, args),
compiled_fun = _xla_primitive_callable(prim, *unsafe_map(dispatch.arg_spec, args),
**params)
return compiled_fun(*args)
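
The user-visible effect of the primitive-callable cache touched above: op-by-op execution compiles each primitive once per abstract signature. A rough check, relying on the private jax._src.dispatch module this change introduces:

import jax.numpy as jnp
from jax._src import dispatch

dispatch.xla_primitive_callable.cache_clear()
jnp.sin(1.0)  # cache miss: compiles a one-op XLA computation
jnp.sin(2.0)  # cache hit: same primitive, same abstract signature
print(dispatch.xla_primitive_callable.cache_info())
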
@ -46,11 +46,13 @@ from .. import core
from .. import linear_util as lu
from jax._src.abstract_arrays import array_types
from ..core import ConcreteArray, ShapedArray
from jax._src import device_array
from .._src import source_info_util
from .._src.util import (unzip3, prod, safe_map, safe_zip,
extend_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log)
from ..errors import JAXTypeError
from jax._src import dispatch
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
@ -321,15 +323,15 @@ def _shard_device_array(x, devices, indices):
_as_slice_indices(x, idx) for idx in indices)
shards = x._multi_slice(start_indices, limit_indices, removed_dims)
return device_put(shards, devices)
shard_arg_handlers[xla._DeviceArray] = _shard_device_array
shard_arg_handlers[xla._CppDeviceArray] = _shard_device_array
for t in device_array.device_array_types:
shard_arg_handlers[t] = _shard_device_array


# NOTE(skye): we could refactor to generate _multi_slice parameters directly
# from the input ShardingSpec, rather than the indices. However, this would
# require duplicating the ordering logic of spec_to_indices, which is more
# subtle and more likely to change than the index logic we have to support here.
def _as_slice_indices(arr: xla.DeviceArrayProtocol, idx: Index) -> Tuple[
def _as_slice_indices(arr: device_array.DeviceArrayProtocol, idx: Index) -> Tuple[
Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]:
"""Returns start_indices, limit_indices, removed_dims"""
start_indices = [0] * arr.ndim
@ -509,11 +511,11 @@ if _USE_CPP_SDA:
ShardedDeviceArrayBase = pmap_lib.ShardedDeviceArrayBase # type: ignore
# We want the C++ SDA to extend the DeviceArrayBase. We want this both to
# benefit from its methods, and to have isinstance(x, DeviceArray) return true
ShardedDeviceArrayBase.__bases__ = ((xla.DeviceArray,) + # type: ignore
ShardedDeviceArrayBase.__bases__ = ((device_array.DeviceArray,) + # type: ignore
ShardedDeviceArrayBase.__bases__)
_SDA_BASE_CLASS = pmap_lib.ShardedDeviceArrayBase # type: ignore
else:
_SDA_BASE_CLASS: Type[xla.DeviceArray] = xla.DeviceArray # type: ignore
_SDA_BASE_CLASS: Type[device_array.DeviceArray] = device_array.DeviceArray # type: ignore


class _ShardedDeviceArray(_SDA_BASE_CLASS): # type: ignore
@ -646,7 +648,7 @@ def _sda__getitem__(self, idx):
if buf_idx is not None:
buf = self.device_buffers[buf_idx]
aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
return xla.make_device_array(aval, None, buf)
return device_array.make_device_array(aval, None, buf)
return super(self.__class__, self).__getitem__(idx)


@ -726,7 +728,7 @@ def _register_handlers_for_sharded_device_array(sda):
xla.register_constant_handler(sda, _sharded_device_array_constant_handler)

core.pytype_aval_mappings[sda] = ConcreteArray
xla.device_put_handlers[sda] = xla._device_put_array
dispatch.device_put_handlers[sda] = dispatch._device_put_array
xla.pytype_aval_mappings[sda] = op.attrgetter("aval")
xla.canonicalize_dtype_handlers[sda] = identity

@ -829,7 +831,7 @@ def parallel_callable(fun: lu.WrappedFun,
with core.extend_axis_env(axis_name, global_axis_size, None): # type: ignore
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
fun, global_sharded_avals, pe.debug_info_final(fun, "pmap"))
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

out_axes = out_axes_thunk()
assert len(out_sharded_avals) == len(out_axes), (len(out_sharded_avals), len(out_axes))
@ -844,7 +846,7 @@ def parallel_callable(fun: lu.WrappedFun,
check_multihost_collective_allowlist(jaxpr)

# TODO(skyewm): replace this with a chain of pmaps and/or sharded_jits
jaxpr_replicas = xla.jaxpr_replicas(jaxpr)
jaxpr_replicas = dispatch.jaxpr_replicas(jaxpr)
num_local_replicas = axis_size * jaxpr_replicas
num_global_replicas = global_axis_size * jaxpr_replicas

@ -1024,7 +1026,7 @@ def parallel_callable(fun: lu.WrappedFun,
handle_outs)
return WeakRefList([execute_fun, None])

compiled = xla.compile_or_get_cached(backend, built, compile_options)
compiled = dispatch.compile_or_get_cached(backend, built, compile_options)
handle_args = InputsHandler(compiled.local_devices(), input_sharding_specs,
input_indices)
execute_fun = partial(execute_replicated, compiled, backend, handle_args, handle_outs)
@ -1314,9 +1316,9 @@ def partitioned_sharding_spec(num_partitions: int,
def execute_replicated(compiled, backend, in_handler, out_handler, *args):
input_bufs = in_handler(args)
out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
if xla.needs_check_special():
if dispatch.needs_check_special():
for bufs in out_bufs:
xla.check_special("parallel computation", bufs)
dispatch.check_special("parallel computation", bufs)
return out_handler(out_bufs)
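
needs_check_special and check_special, now owned by dispatch, implement the jax_debug_nans and jax_debug_infs flags; with the flag enabled, a NaN produced by a parallel computation raises instead of silently propagating. A sketch (assumes at least one local device):

import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)
try:
    jax.pmap(lambda x: 0. / x)(jnp.zeros(jax.local_device_count()))
except FloatingPointError as err:
    print("caught:", err)
jax.config.update("jax_debug_nans", False)
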
@ -1634,7 +1636,7 @@ def lower_mesh_computation(
_sanitize_mesh_jaxpr(jaxpr)
if local_mesh.shape != mesh.shape:
check_multihost_collective_allowlist(jaxpr)
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

# 3. Build up the HLO
c = xc.XlaBuilder(f"xmap_{fun.__name__}")
@ -1766,7 +1768,7 @@ class MeshExecutable:
input_indices, local_input_specs,
handle_outs)
else:
compiled = xla.compile_or_get_cached(backend, computation, compile_options)
compiled = dispatch.compile_or_get_cached(backend, computation, compile_options)
handle_args = InputsHandler(compiled.local_devices(), local_input_specs,
input_indices)
self.unsafe_call = partial(execute_replicated, compiled, backend, handle_args, handle_outs)
@ -1777,7 +1779,7 @@ class MeshExecutable:
def call(self, *args):
arg_avals = map(xla.abstractify, args)
ref_avals = self._local_in_untiled_avals
xla.check_arg_avals_for_call(ref_avals, arg_avals)
dispatch.check_arg_avals_for_call(ref_avals, arg_avals)
return self.unsafe_call(*args)


@ -1908,6 +1910,6 @@ _thread_local_state = _ThreadLocalState()
def device_put(x, devices: Sequence[xb.xla_client.Device], replicate: bool=False) -> List[xb.xla_client.Buffer]:
"""Call device_put on a sequence of devices and return a flat sequence of buffers."""
if replicate:
return list(it.chain.from_iterable(xla.device_put(x, device) for device in devices))
return list(it.chain.from_iterable(dispatch.device_put(x, device) for device in devices))
else:
return list(it.chain.from_iterable(xla.device_put(val, device) for val, device in safe_zip(x, devices)))
return list(it.chain.from_iterable(dispatch.device_put(val, device) for val, device in safe_zip(x, devices)))
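
The base-class surgery above is what makes pmap outputs instances of the shared DeviceArray type; the pmap_test.py hunk below checks the same relationship:

import jax
import jax.numpy as jnp
from jax._src import device_array
from jax.interpreters import pxla

y = jax.pmap(lambda x: 2 * x)(jnp.arange(jax.local_device_count()))
assert isinstance(y, pxla.ShardedDeviceArray)
assert isinstance(y, device_array.DeviceArray)
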
@ -25,6 +25,7 @@ from . import partial_eval as pe
from . import pxla
from . import xla
from .. import linear_util as lu
from jax._src import dispatch
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from .._src.api_util import argnums_partial, flatten_axes, flatten_fun, _ensure_index_tuple
@ -157,7 +158,7 @@ def _sharded_callable(
device_assignment = np.reshape(device_assignment, (-1, nparts))
# device_assignment = None # TODO(skye): replace with default device assignment?

compiled = xla.backend_compile(
compiled = dispatch.backend_compile(
xb.get_backend(), built,
xb.get_compile_options(nrep, nparts, device_assignment))
File diff suppressed because it is too large
@ -19,7 +19,7 @@
from . import fft as fft
from . import linalg as linalg

from jax.interpreters.xla import DeviceArray as DeviceArray
from jax._src.device_array import DeviceArray as DeviceArray

from jax._src.numpy.lax_numpy import (
ComplexWarning as ComplexWarning,
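
The public alias is unchanged by this re-export; only its home module moves:

import jax.numpy as jnp
from jax._src import device_array

assert jnp.DeviceArray is device_array.DeviceArray
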
@ -48,6 +48,7 @@ from jax.interpreters import ad
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters.sharded_jit import PartitionSpec as P
from jax._src import device_array
import jax._src.lib
from jax._src.lib import xla_client
from jax._src import test_util as jtu
@ -225,7 +226,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
def test_jit_device(self):
device = jax.devices()[-1]
x = self.jit(lambda x: x, device=device)(3.)
self.assertIsInstance(x, xla.DeviceArray)
self.assertIsInstance(x, jnp.DeviceArray)
self.assertEqual(x.device_buffer.device(), device)

def test_complex_support(self):
@ -492,7 +493,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):

jitted_f = self.jit(lambda a: a + 1)
jitted_f(1)
self.assertIsInstance(jitted_f(2), xla._CppDeviceArray)
self.assertIsInstance(jitted_f(2), device_array.Buffer)

@jtu.skip_on_devices("cpu")
def test_explicit_backend(self):
@ -986,13 +987,13 @@ class APITest(jtu.JaxTestCase):
"Transpose rule (for reverse-mode differentiation) for 'foo' not implemented")

def test_is_subclass(self):
self.assertTrue(issubclass(xla.DeviceArray, jnp.ndarray))
self.assertTrue(issubclass(xla._CppDeviceArray, jnp.ndarray))
self.assertTrue(issubclass(device_array.DeviceArray, jnp.ndarray))
self.assertTrue(issubclass(device_array.Buffer, jnp.ndarray))
self.assertTrue(issubclass(pxla.ShardedDeviceArray, jnp.ndarray))
self.assertTrue(issubclass(pxla._ShardedDeviceArray, jnp.ndarray))
self.assertFalse(issubclass(np.ndarray, jnp.ndarray))
self.assertFalse(issubclass(xla.DeviceArray, np.ndarray))
self.assertFalse(issubclass(xla._CppDeviceArray, np.ndarray))
self.assertFalse(issubclass(device_array.DeviceArray, np.ndarray))
self.assertFalse(issubclass(device_array.Buffer, np.ndarray))
self.assertFalse(issubclass(pxla.ShardedDeviceArray, np.ndarray))
self.assertFalse(issubclass(pxla._ShardedDeviceArray, np.ndarray))

@ -1007,7 +1008,7 @@ class APITest(jtu.JaxTestCase):
def test_device_put_and_get(self):
x = np.arange(12.).reshape((3, 4)).astype("float32")
dx = api.device_put(x)
self.assertIsInstance(dx, xla.DeviceArray)
self.assertIsInstance(dx, device_array.DeviceArray)
self.assertIsInstance(dx, jnp.ndarray)
self.assertNotIsInstance(dx, np.ndarray)
x2 = api.device_get(dx)
@ -1030,7 +1031,7 @@ class APITest(jtu.JaxTestCase):
def test_device_get_scalar(self):
x = np.arange(12.).reshape((3, 4)).astype("float32")
x = api.device_put(x)
self.assertIsInstance(x, xla.DeviceArray)
self.assertIsInstance(x, device_array.DeviceArray)
y = [x, 2]
y2 = api.device_get(y)
self.assertIsInstance(y2, list)
@ -1465,11 +1466,11 @@ class APITest(jtu.JaxTestCase):

def test_devicearray_repr(self):
x = device_put(jnp.zeros(3))
self.assertIsInstance(x, xla.DeviceArray)
self.assertIsInstance(x, device_array.DeviceArray)
repr(x) # doesn't crash

x = device_put(jnp.ones(3) + 1j * jnp.ones(3))
self.assertIsInstance(x, xla.DeviceArray)
self.assertIsInstance(x, device_array.DeviceArray)
repr(x) # doesn't crash

def test_devicearray_delete(self):
@ -2304,7 +2305,7 @@ class APITest(jtu.JaxTestCase):

def test_device_array_hash(self):
rep = jnp.ones((1,)) + 1.
self.assertIsInstance(rep, jax.interpreters.xla.DeviceArray)
self.assertIsInstance(rep, device_array.DeviceArray)
self.assertNotIsInstance(rep, collections.abc.Hashable)
with self.assertRaisesRegex(TypeError, 'unhashable type'):
hash(rep)
@ -2733,7 +2734,7 @@ class APITest(jtu.JaxTestCase):

def test_jit_returning_token(self):
x = jax.jit(jax.lax.create_token)(1.0)
self.assertIsInstance(x, jax.interpreters.xla.Token)
self.assertIsInstance(x, jax.core.Token)

def test_leak_checker_catches_a_jit_leak(self):
with jax.checking_leaks():
@ -19,6 +19,8 @@ import numpy as np
from jax._src import test_util as jtu
import jax.numpy as jnp
from jax import core, jit, lax, make_jaxpr
from jax._src import device_array
from jax._src import dispatch
from jax.interpreters import xla
from jax._src.lib import xla_bridge, xla_client
xops = xla_client.ops
@ -102,8 +104,8 @@ class ConcreteSparseArray(AbstractSparseArray):

def sparse_array_result_handler(device, aval):
def build_sparse_array(data_buf, indices_buf):
data = xla.make_device_array(aval.data_aval, device, data_buf)
indices = xla.make_device_array(aval.indices_aval, device, indices_buf)
data = device_array.make_device_array(aval.data_aval, device, data_buf)
indices = device_array.make_device_array(aval.indices_aval, device, indices_buf)
return SparseArray(aval, data, indices)
return build_sparse_array

@ -129,8 +131,8 @@ core.pytype_aval_mappings[SparseArray] = lambda x: x.aval
core.raise_to_shaped_mappings[AbstractSparseArray] = lambda aval, _: aval
xla.pytype_aval_mappings[SparseArray] = lambda x: x.aval
xla.canonicalize_dtype_handlers[SparseArray] = lambda x: x
xla.device_put_handlers[SparseArray] = sparse_array_device_put_handler
xla.xla_result_handlers[AbstractSparseArray] = sparse_array_result_handler
dispatch.device_put_handlers[SparseArray] = sparse_array_device_put_handler
dispatch.xla_result_handlers[AbstractSparseArray] = sparse_array_result_handler
xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler
xla.register_constant_handler(SparseArray, sparse_array_constant_handler)

@ -257,8 +259,8 @@ core.pytype_aval_mappings[Empty] = lambda x: ConcreteEmpty()
core.raise_to_shaped_mappings[AbstractEmpty] = lambda aval, _: aval
xla.pytype_aval_mappings[Empty] = lambda x: AbstractEmpty()
xla.canonicalize_dtype_handlers[Empty] = lambda x: x
xla.device_put_handlers[Empty] = lambda _, __: ()
xla.xla_result_handlers[AbstractEmpty] = lambda _, __: lambda: Empty(AbstractEmpty())
dispatch.device_put_handlers[Empty] = lambda _, __: ()
dispatch.xla_result_handlers[AbstractEmpty] = lambda _, __: lambda: Empty(AbstractEmpty())
xla.xla_shape_handlers[AbstractEmpty] = lambda _: ()
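
The registrations above follow a fixed recipe: device_put_handlers flatten a value into zero or more buffers on the way in, and xla_result_handlers rebuild a typed value from buffers on the way out. A registration-only sketch for a hypothetical buffer-less Unit type, mirroring the Empty example (Unit and AbstractUnit are illustrative, not part of the commit):

from jax import core
from jax._src import dispatch
from jax.interpreters import xla

class AbstractUnit(core.AbstractValue):
    pass

class Unit:
    pass

core.pytype_aval_mappings[Unit] = lambda x: AbstractUnit()
xla.pytype_aval_mappings[Unit] = lambda x: AbstractUnit()
xla.canonicalize_dtype_handlers[Unit] = lambda x: x
dispatch.device_put_handlers[Unit] = lambda _, __: ()  # flattens to no buffers
dispatch.xla_result_handlers[AbstractUnit] = lambda _, __: lambda: Unit()  # rebuilds from none
xla.xla_shape_handlers[AbstractUnit] = lambda _: ()
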
@ -27,7 +27,6 @@ from jax._src import dtypes
from jax import lax
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax.interpreters import xla

from jax.config import config
config.parse_flags_with_absl()
@ -166,7 +165,7 @@ class DtypesTest(jtu.JaxTestCase):
def testScalarInstantiation(self, scalar_type):
a = scalar_type(1)
self.assertEqual(a.dtype, jnp.dtype(scalar_type))
self.assertIsInstance(a, xla.DeviceArray)
self.assertIsInstance(a, jnp.DeviceArray)
self.assertEqual(0, jnp.ndim(a))
self.assertIsInstance(np.dtype(scalar_type).type(1), scalar_type)
@ -35,7 +35,6 @@ from jax._src import test_util as jtu
from jax import tree_util
from jax._src.util import unzip2
from jax.experimental import maps
from jax.interpreters import xla
import jax.numpy as jnp # scan tests use numpy
import jax.scipy as jsp

@ -2759,7 +2758,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
x = rng.randn(32, 2, 32).astype('float32') # numpy.ndarray, not DeviceArray
_, vjp_fun = jax.vjp(cumprod, x)
*_, ext_res = vjp_fun.args[0].args[0]
self.assertIsInstance(ext_res, xla.DeviceArray)
self.assertIsInstance(ext_res, jnp.DeviceArray)

def test_scan_vmap_collectives(self):
def scan_f(state, x):
@ -39,9 +39,9 @@ import jax.ops
from jax import lax
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax._src import device_array
from jax._src import dtypes
from jax import tree_util
from jax.interpreters import xla
from jax.test_util import check_grads
from jax._src.util import prod, safe_zip
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc
@ -2483,7 +2483,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
expected_np_input_after_call = np.ones((1))
expected_jnp_input_after_call = jnp.ones((1))

self.assertTrue(xla.type_is_device_array(jnp.concatenate([np_input])))
self.assertTrue(device_array.type_is_device_array(jnp.concatenate([np_input])))

attempt_sideeffect(np_input)
attempt_sideeffect(jnp_input)
@ -3678,13 +3678,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
assert not np.isscalar(jnp.array(3))

def testArrayOutputsDeviceArrays(self):
assert xla.type_is_device_array(jnp.array([]))
assert xla.type_is_device_array(jnp.array(np.array([])))
assert device_array.type_is_device_array(jnp.array([]))
assert device_array.type_is_device_array(jnp.array(np.array([])))

class NDArrayLike:
def __array__(self, dtype=None):
return np.array([], dtype=dtype)
assert xla.type_is_device_array(jnp.array(NDArrayLike()))
assert device_array.type_is_device_array(jnp.array(NDArrayLike()))

# NOTE(mattjj): disabled b/c __array__ must produce ndarrays
# class DeviceArrayLike:
@ -36,7 +36,6 @@ from jax._src import lax_reference
from jax.test_util import check_grads
import jax.util
from jax._src.util import prod
from jax import xla

from jax._src.lax.lax import _device_put_raw

@ -2566,7 +2565,7 @@ class LazyConstantTest(jtu.JaxTestCase):
if jit:
op = jax.jit(op)
result = op(input_type(value))
assert isinstance(result, xla.DeviceArray)
assert isinstance(result, jnp.DeviceArray)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype_in={}_dtype_out={}".format(
@ -23,7 +23,6 @@ import jax.numpy as jnp
from jax import lax
from jax._src import test_util as jtu
from jax._src.lib import xla_bridge
from jax.interpreters import xla

from jax.config import config
config.parse_flags_with_absl()
@ -182,7 +181,7 @@ class MultiDeviceTest(jtu.JaxTestCase):
devices = self.get_devices()

def f(): return lax.add(3., 4.)
self.assertIsInstance(f(), xla.DeviceArray)
self.assertIsInstance(f(), jnp.DeviceArray)
self.assert_uncommitted_to_device(f(), devices[0])
self.assert_uncommitted_to_device(jax.jit(f)(), devices[0])
self.assert_committed_to_device(jax.jit(f, device=devices[1])(),
@ -40,6 +40,7 @@ from jax import random
from jax.core import ShapedArray
from jax import (pmap, soft_pmap, jit, vmap, jvp, grad, make_jaxpr,
linearize, device_put)
from jax._src import device_array
import jax._src.lib
from jax._src.lib import xla_bridge
from jax._src.util import prod, safe_map
@ -540,12 +541,12 @@ class PythonPmapTest(jtu.JaxTestCase):
y = f(x)
self.assertIsInstance(y, jnp.ndarray)
self.assertIsInstance(y, pxla.ShardedDeviceArray)
self.assertIsInstance(y, jax.interpreters.xla.DeviceArray)
self.assertIsInstance(y, device_array.DeviceArray)
self.assertNotIsInstance(y, np.ndarray)
self.assertAllClose(y, 2 * x, check_dtypes=False)
z = f(y)
self.assertIsInstance(z, pxla.ShardedDeviceArray)
self.assertIsInstance(z, jax.interpreters.xla.DeviceArray)
self.assertIsInstance(z, device_array.DeviceArray)
self.assertNotIsInstance(z, np.ndarray)
self.assertAllClose(z, 2 * 2 * x, check_dtypes=False)

@ -2325,7 +2326,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
sharded_x = pmap(lambda x: x)(x)
self.assertIsNone(sharded_x._npy_value)
for i in range(8):
self.assertIsInstance(sharded_x[i], jax.interpreters.xla.DeviceArray)
self.assertIsInstance(sharded_x[i], device_array.DeviceArray)
self.assertIsNone(sharded_x._npy_value)

def test_device_put_sharded_array(self):
@ -16,7 +16,7 @@ from absl.testing import absltest

import jax
from jax._src import test_util as jtu
from jax.interpreters import xla
from jax._src import dispatch


class XlaInterpreterTest(jtu.JaxTestCase):
@ -26,7 +26,7 @@ class XlaInterpreterTest(jtu.JaxTestCase):
return args[0]

closed_jaxpr = jax.make_jaxpr(f)(*range(10))
pruned_jaxpr, kept_const_idx, kept_var_idx = xla._prune_unused_inputs(
pruned_jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(
closed_jaxpr.jaxpr)
assert len(pruned_jaxpr.invars) == 1
assert kept_const_idx == set()
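
For reference, the pruning entry point exercised by this test can be run standalone (dispatch is the private module this commit introduces; kept_var_idx reports which original argument positions survived):

import jax
from jax._src import dispatch

def f(*args):
    return args[0]

closed_jaxpr = jax.make_jaxpr(f)(*range(10))
pruned_jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(
    closed_jaxpr.jaxpr)
assert len(pruned_jaxpr.invars) == 1
assert kept_const_idx == set()
print(kept_var_idx)  # which of the ten inputs survived pruning
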