Merge branch 'master' into use-raise-from

This commit is contained in:
Akihiro Nitta 2020-10-01 00:27:03 +09:00
commit d707ae17e5
No known key found for this signature in database
GPG Key ID: 0870EC9AD2D54289
10 changed files with 390 additions and 60 deletions

View File

@ -67,7 +67,10 @@ and 2) whether it is **committed** to the device or not (the data is sometimes
referred to as being *sticky* to the device).
By default, JAX arrays are placed uncommitted on the default device
(``jax.devices()[0]``).
(``jax.devices()[0]``), which is the first GPU by default. If no GPU is
present, ``jax.devices()[0]`` is the first CPU. The default device can
be set to "cpu" or "gpu" manually by setting the environment variable
``JAX_PLATFORM_NAME`` or the absl flag ``--jax_platform_name``.
>>> from jax import numpy as jnp
>>> print(jnp.ones(3).device_buffer.device()) # doctest: +SKIP
@ -97,6 +100,11 @@ device.
Jitted functions behave like any other primitive operations—they will follow the
data and will show errors if invoked on data committed on more than one device.
``jnp.device_put(jnp.zeros(...), jax.devices()[1])`` or similar will actually create the
array of zeros on ``jax.devices()[1]``, instead of creating the array on the default
device then moving it. This is thanks to some laziness in array creation, which holds
for all the constant creation operations (``ones``, ``full``, ``eye``, etc).
(As of April 2020, :func:`jax.jit` has a `device` parameter that affects the device
placement. That parameter is experimental, is likely to be removed or changed,
and its use is not recommended.)

View File

@ -2093,7 +2093,7 @@ def device_put_sharded(x: Sequence[Any], devices: Sequence[xc.Device]):
f"abstract values not compatible: {avals}"
x_aval = core.raise_to_shaped(avals[0])
aval = ShapedArray((len(devices),) + x_aval.shape, x_aval.dtype)
buffers = [xla.device_put(x, d) for x, d in zip(xs, devices)]
buffers = list(it.chain.from_iterable(xla.device_put(x, d) for x, d in zip(xs, devices)))
return pxla.ShardedDeviceArray(aval, buffers)
return tree_multimap(_device_put_sharded, *x)

View File

@ -747,6 +747,7 @@ def find_top_trace(xs) -> Trace:
class AbstractValue:
__slots__: List[str] = []
_num_buffers: int = 1 # number of buffers used to represent the value.
def at_least_vspace(self):
assert False
@ -769,6 +770,8 @@ class Bot(AbstractValue): pass
bot = Bot()
class AbstractUnit(AbstractValue):
# TODO(jakevdp): make it possible to set zero buffers
# _num_buffers = 0
def join(self, other):
if not skip_checks:
assert other is abstract_unit, other

View File

@ -66,6 +66,15 @@ unsafe_map, map = map, safe_map
Index = Union[int, slice, Tuple[Union[int, slice], ...]]
def device_put(x, devices: Sequence[xb.xla_client.Device], replicate: bool=False) -> List[xb.xla_client._xla.PyLocalBuffer]:
"""Call device_put on a sequence of devices and return a flat sequence of buffers."""
if replicate:
return list(it.chain.from_iterable(xla.device_put(x, device) for device in devices))
else:
return list(it.chain.from_iterable(xla.device_put(val, device) for val, device in safe_zip(x, devices)))
# TODO(skye): make this a namedtuple. This may allow us to use ShardingSpecs in
# performance-sensitive code, e.g. shard_args.
class ShardingSpec:
@ -237,9 +246,9 @@ def shard_args(devices: Sequence[xb.xla_client.Device],
shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any], Sequence[Any]]] = {}
shard_arg_handlers[core.Unit] = \
lambda x, devices, _: [xla.device_put(core.unit, d) for d in devices]
lambda x, devices, _: device_put(core.unit, devices, replicate=True)
def _shard_array(x, devices, indices):
return [xla.device_put(x[i], d) for (i, d) in zip(indices, devices)]
return device_put([x[i] for i in indices], devices)
for _t in array_types:
shard_arg_handlers[_t] = _shard_array
@ -247,7 +256,7 @@ def _shard_device_array(x, devices, indices):
start_indices, limit_indices, removed_dims = map(tuple, unzip3(
_as_slice_indices(x, idx) for idx in indices))
shards = x._multi_slice(start_indices, limit_indices, removed_dims)
return [xla.device_put(s, d) for s, d in zip(shards, devices)]
return device_put(shards, devices)
shard_arg_handlers[xla.DeviceArray] = _shard_device_array
# NOTE(skye): we could refactor to generate _multi_slice parameters directly
@ -862,7 +871,7 @@ def replicate(val, axis_size, nrep, devices=None, backend=None):
replicated_aval = ShapedArray((axis_size,) + aval.shape, aval.dtype)
# TODO(skye): figure out how partitioning should work here
sharding_spec = _pmap_sharding_spec(nrep, axis_size, 1, None, aval, True)
device_buffers = [xla.device_put(val, d) for d in devices]
device_buffers = device_put(val, devices, replicate=True)
return ShardedDeviceArray(replicated_aval, sharding_spec, device_buffers)
def _pmap_sharding_spec(nrep, axis_size, npart, parts, sharded_aval, mapped):

View File

@ -180,7 +180,7 @@ def _xla_sharded_args(c, avals, in_parts):
xla_args = []
for i, (sharding, aval) in enumerate(safe_zip(in_parts, avals)):
param = xb.with_sharding(c, sharding, xb.parameter, c, i,
xla.aval_to_xla_shape(aval))
*xla.aval_to_xla_shapes(aval))
xla_args.append(param)
return xla_args

View File

@ -75,19 +75,19 @@ _scalar_types = dtypes.python_scalar_dtypes.keys()
# unit representation
def _make_unit(c): return xb.constant(c, np.zeros((), dtype=np.dtype('bool')))
def _make_abstract_unit(_): return xc.Shape.array_shape(np.dtype('bool'), ())
def _make_abstract_unit(_): return (xc.Shape.array_shape(np.dtype('bool'), ()),)
def _device_put_unit(_, device):
backend = xb.get_device_backend(device)
return backend.buffer_from_pyval(np.zeros((), dtype=np.dtype('bool')),
device)
return (backend.buffer_from_pyval(np.zeros((), dtype=np.dtype('bool')),
device),)
def _make_array_shape(a):
return xc.Shape.array_shape(a.dtype, a.shape)
return (xc.Shape.array_shape(a.dtype, a.shape),)
### handlers
xb.register_constant_handler(core.Unit, lambda c, *_: _make_unit(c))
def aval_to_xla_shape(aval):
def aval_to_xla_shapes(aval):
try:
return xla_shape_handlers[type(aval)](aval)
except KeyError as err:
@ -99,7 +99,7 @@ xla_shape_handlers: Dict[Type[core.AbstractValue], Callable] = {
ConcreteArray: _make_array_shape,
}
def aval_to_result_handler(device: Optional[Device], aval: core.ShapedArray):
def aval_to_result_handler(device: Optional[Device], aval: core.AbstractValue) -> Callable:
try:
return xla_result_handlers[type(aval)](device, aval)
except KeyError as err:
@ -114,7 +114,7 @@ xla_result_handlers: Dict[Type[core.AbstractValue], Callable[..., Callable]] = {
ConcreteArray: array_result_handler,
}
def device_put(x, device: Optional[Device] = None):
def device_put(x, device: Optional[Device] = None) -> Tuple[Any]:
x = canonicalize_dtype(x)
try:
return device_put_handlers[type(x)](x, device)
@ -123,12 +123,14 @@ def device_put(x, device: Optional[Device] = None):
def _device_put_array(x, device: Optional[Device]):
backend = xb.get_device_backend(device)
return backend.buffer_from_pyval(x, device)
return (backend.buffer_from_pyval(x, device),)
def _device_put_scalar(x, device):
return _device_put_array(dtypes.coerce_to_array(x), device)
device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]], Any]] = {core.Unit: _device_put_unit}
device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]], Tuple[Any]]] = {
core.Unit: _device_put_unit
}
device_put_handlers.update((t, _device_put_array) for t in array_types)
device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)
@ -224,6 +226,15 @@ def apply_primitive(prim, *args, **params):
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
return compiled_fun(*args)
def _partition_outputs(avals, outs):
nouts = [aval._num_buffers for aval in avals]
if not core.skip_checks:
assert sum(nouts) == len(outs), f"Internal error: sum(nouts)={sum(nouts)} should equal len(outs)={len(outs)}."
outs = iter(outs)
return [[next(outs) for _ in range(nout)] for nout in nouts]
@cache()
def xla_primitive_callable(prim, *arg_specs: Tuple[core.AbstractValue,
Optional[Device]], **params):
@ -242,7 +253,8 @@ def xla_primitive_callable(prim, *arg_specs: Tuple[core.AbstractValue,
handle_result = aval_to_result_handler(device, aval_out)
else:
handlers = map(partial(aval_to_result_handler, device), aval_out)
handle_result = lambda xs: tuple(h(x) for h, x in zip(handlers, xs))
handle_result = lambda *bufs:\
tuple(handler(*bs) for handler, bs in zip(handlers, _partition_outputs(aval_out, bufs)))
tuple_args = len(avals) > 100
if prim in initial_style_translations:
nreps = initial_style_primitive_replicas(params)
@ -326,20 +338,18 @@ def backend_compile(backend, built_c, options):
def _execute_compiled_primitive(prim, compiled, result_handler, *args):
device, = compiled.local_devices()
input_bufs = [device_put(x, device) for x in args if x is not token]
input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
out_bufs = compiled.execute(input_bufs)
if FLAGS.jax_debug_nans:
check_nans(prim, out_bufs)
return result_handler(out_bufs if prim.multiple_results else out_bufs[0])
if FLAGS.jax_debug_nans: check_nans(prim, out_bufs)
return result_handler(*out_bufs)
def _execute_replicated_primitive(prim, compiled, result_handler, *args):
input_bufs = [
[device_put(x, device) for x in args if x is not token]
list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
for device in compiled.local_devices()]
out_buf = compiled.execute_on_local_devices(input_bufs)[0]
if not prim.multiple_results:
out_buf, = out_buf
return result_handler(out_buf)
out_bufs = compiled.execute_on_local_devices(input_bufs)[0]
return result_handler(*out_bufs)
def check_nans(prim, bufs):
for buf in bufs:
@ -368,6 +378,12 @@ def jaxpr_literals(jaxpr):
yield from jaxpr_literals(subjaxpr)
def _flatmap(func: Callable, vars: Sequence):
return list(it.chain.from_iterable(map(func, vars)))
def _partitionmap(func: Callable, vars: Sequence, nodes: Sequence):
return map(func, vars, _partition_outputs([v.aval for v in vars], nodes))
def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
if backend not in ('cpu', 'gpu', 'tpu'):
platform = xb.get_backend(backend).platform # canonicalize
@ -376,7 +392,7 @@ def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
def read(v):
if type(v) is Literal:
return xb.constant(c, canonicalize_dtype(v.val))
return [xb.constant(c, canonicalize_dtype(v.val))]
else:
return env[v]
@ -391,9 +407,9 @@ def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
env[v] = node
env = {}
write(core.unitvar, _make_unit(c))
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
_partitionmap(write, [core.unitvar], [_make_unit(c)])
_partitionmap(write, jaxpr.constvars, consts)
_partitionmap(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
frame = source_info_util.user_frame(eqn.source_info)
c.set_op_metadata(xc.OpMetadata(
@ -402,7 +418,7 @@ def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
eqn.primitive.name, eqn.params)),
source_file=frame.file_name if frame else None,
source_line=frame.line_num if frame else None))
in_nodes = map(read, eqn.invars)
in_nodes = _flatmap(read, eqn.invars)
if eqn.primitive in backend_specific_translations[platform]:
rule = backend_specific_translations[platform][eqn.primitive]
ans = rule(c, *in_nodes, **eqn.params)
@ -427,10 +443,14 @@ def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
assert isinstance(ans, xe.XlaOp)
c.get_shape(ans) # force xla to do shape error checking
out_nodes = xla_destructure(c, ans) if eqn.primitive.multiple_results else [ans]
if eqn.primitive.multiple_results or any(v.aval._num_buffers > 1 for v in eqn.outvars):
out_nodes = xla_destructure(c, ans)
else:
out_nodes = [ans]
c.clear_op_metadata()
map(write, eqn.outvars, out_nodes)
return map(read, jaxpr.outvars)
_partitionmap(write, eqn.outvars, out_nodes)
return _flatmap(read, jaxpr.outvars)
def xla_destructure(c, ans):
num_elements = len(c.get_shape(ans).tuple_shapes())
@ -606,15 +626,16 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
device = _xla_callable_device(nreps, backend, device, arg_devices)
backend = device.platform if device else backend
if config.omnistaging_enabled:
result_handlers = tuple(aval_to_result_handler(device, a) for a in out_avals)
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
else:
result_handlers = tuple(map(partial(_pval_to_result_handler, device), pvals)) # type: ignore
out_avals = [pval.get_aval() for pval in pvals]
result_handlers = map(partial(_pval_to_result_handler, device), pvals) # type: ignore
# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
# and don't need to force their (potentially lazy) arguments.
if not jaxpr.eqns:
return partial(_execute_trivial, jaxpr, device, consts, result_handlers)
return partial(_execute_trivial, jaxpr, device, consts, out_avals, result_handlers)
if not _on_exit:
log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
@ -664,9 +685,9 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
options.parameter_is_tupled_arguments = tuple_args
compiled = backend_compile(backend, built, options)
if nreps == 1:
return partial(_execute_compiled, compiled, result_handlers)
return partial(_execute_compiled, compiled, out_avals, result_handlers)
else:
return partial(_execute_replicated, compiled, result_handlers)
return partial(_execute_replicated, compiled, out_avals, result_handlers)
def set_up_aliases(c, xla_args, out_tuple, donated_args, tuple_args):
"""Configures input/output "must" aliasing based on `donated_args`."""
@ -728,17 +749,18 @@ def _xla_callable_args(
else:
parts = [_replicated_param if part is None else part
for part in partitions]
return [_xla_param(c, i, aval_to_xla_shape(a), r, p)
counts = it.count()
return [_xla_param(c, next(counts), xla_shape, r, p)
if a is not abstract_token else xops.CreateToken(c)
for i, (a, r, p)
in enumerate(safe_zip(avals, replicated, parts))]
for (a, r, p) in safe_zip(avals, replicated, parts)
for xla_shape in aval_to_xla_shapes(a)]
else:
if replicated is not None:
replicated = [r for a, r in zip(avals, replicated)
if a is not abstract_token]
tuple_parts = tuple(partitions) if partitions is not None else None
tuple_shape = xc.Shape.tuple_shape(
[aval_to_xla_shape(a) for a in avals if a is not abstract_token])
[shape for a in avals for shape in aval_to_xla_shapes(a) if a is not abstract_token])
tuple_param = _xla_param(c, 0, tuple_shape, replicated, tuple_parts)
xla_inputs = iter(xla_destructure(c, tuple_param))
xla_args = [next(xla_inputs) if a is not abstract_token else
@ -756,29 +778,29 @@ def _xla_param(builder, param_num, xla_shape, replicated, partitions):
else:
return xb.with_sharding(builder, partitions, make_param)
def _execute_compiled(compiled: XlaExecutable, handlers, *args):
def _execute_compiled(compiled: XlaExecutable, avals, handlers, *args):
device, = compiled.local_devices()
input_bufs = [device_put(x, device) for x in args if x is not token]
input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
out_bufs = compiled.execute(input_bufs)
if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)]
return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]
def _execute_replicated(compiled: XlaExecutable, handlers, *args):
def _execute_replicated(compiled: XlaExecutable, avals, handlers, *args):
input_bufs = [
[device_put(x, device) for x in args if x is not token]
list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
for device in compiled.local_devices()]
out_bufs = compiled.execute_on_local_devices(input_bufs)[0]
if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)]
return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]
def _execute_trivial(jaxpr, device: Optional[Device], consts, handlers, *args):
def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers, *args):
env = {core.unitvar: core.unit}
map(env.setdefault, jaxpr.invars, args)
map(env.setdefault, jaxpr.constvars, consts)
outs = [canonicalize_dtype(v.val) if type(v) is Literal else env[v]
for v in jaxpr.outvars]
return [_copy_device_array_to_device(x, device) if type(x) is DeviceArray
else h(device_put(x, device)) for h, x in zip(handlers, outs)]
else h(*device_put(x, device)) for h, x in zip(handlers, outs)]
xla_call_p = core.CallPrimitive('xla_call')
xla_call = xla_call_p.bind
@ -924,7 +946,7 @@ token = Token()
pytype_aval_mappings[Token] = lambda _: abstract_token
core.pytype_aval_mappings[Token] = lambda _: abstract_token
xla_shape_handlers[AbstractToken] = lambda _: xc.Shape.token_shape()
xla_shape_handlers[AbstractToken] = lambda _: (xc.Shape.token_shape(),)
xla_result_handlers[AbstractToken] = lambda _, __: lambda _: token
canonicalize_dtype_handlers[Token] = identity
@ -949,7 +971,8 @@ class DeviceArray:
_HAS_DYNAMIC_ATTRIBUTES = True
def __init__(self, aval: core.ShapedArray, device: Optional[Device],
lazy_expr: lazy.LazyExpr, device_buffer: PyLocalBuffer):
lazy_expr: lazy.LazyExpr,
device_buffer: PyLocalBuffer):
self.aval = aval
self.device_buffer = device_buffer
self._device = device
@ -1137,7 +1160,7 @@ xb.register_constant_handler(DeviceArray, _device_array_constant_handler)
def _device_put_device_array(x: DeviceArray, device: Optional[Device]):
x = _copy_device_array_to_device(x, device)
return _force(x).device_buffer
return (_force(x).device_buffer,)
device_put_handlers[DeviceArray] = _device_put_device_array
def _copy_device_array_to_device(x: DeviceArray, device: Optional[xc.Device]) -> DeviceArray:
@ -1219,8 +1242,7 @@ def _device_put_impl(x, device: Optional[Device] = None):
except TypeError as err:
raise TypeError(
f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
handler = aval_to_result_handler(device, a) # type: ignore[arg-type]
return handler(device_put(x, device))
return aval_to_result_handler(device, a)(*device_put(x, device))
device_put_p = core.Primitive('device_put')
device_put_p.def_impl(_device_put_impl)

View File

@ -1402,7 +1402,7 @@ def _device_put_raw(x):
return x
else:
aval = raise_to_shaped(core.get_aval(x))
return xla.array_result_handler(None, aval)(xla.device_put(x))
return xla.array_result_handler(None, aval)(*xla.device_put(x))
def iota(dtype: DType, size: int) -> Array:
"""Wraps XLA's `Iota
@ -5745,8 +5745,8 @@ def _infeed_abstract_eval(token, *, shapes, partitions):
def _infeed_translation_rule(c, token, *, shapes, partitions):
shape = tuple(xla.aval_to_xla_shape(x).with_major_to_minor_layout_if_absent()
for x in shapes)
shape = tuple(shape.with_major_to_minor_layout_if_absent()
for x in shapes for shape in xla.aval_to_xla_shapes(x))
build_infeed = partial(xops.InfeedWithToken, token,
xla_client.Shape.tuple_shape(shape))
if partitions:

View File

@ -309,6 +309,12 @@ class CoreTest(jtu.JaxTestCase):
syms = {c: d, a: b}
assert 'bd' == ''.join(map(str, tree_leaves(syms)))
def test_device_put_unit(self):
def f(x, y):
return x, 2 * y
args_maker = lambda: (core.unit, 1)
self._CompileAndCheck(f, args_maker)
class JaxprTypeChecks(jtu.JaxTestCase):

282
tests/custom_object_test.py Normal file
View File

@ -0,0 +1,282 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest, parameterized
import numpy as np
from jax import test_util as jtu
import jax.numpy as jnp
from jax import core, jit, lax, lazy, make_jaxpr
from jax.interpreters import xla
from jax.lib import xla_client
xops = xla_client.ops
from jax.config import config
config.parse_flags_with_absl()
# TODO(jakevdp): use a setup/teardown method to populate and unpopulate all the
# dictionaries associated with the following objects.
# Define a sparse array data structure. The important feature here is that
# it is a jaxpr object that is backed by two device buffers.
class SparseArray:
"""Simple sparse COO array data structure."""
def __init__(self, aval, data, indices):
self.aval = aval
self.shape = aval.shape
self.data = data
self.indices = indices
@property
def index_dtype(self):
return self.indices.dtype
@property
def dtype(self):
return self.data.dtype
@property
def nnz(self):
return self.data.shape[0]
def __repr__(self):
return repr(list((tuple(ind), d) for ind, d in zip(self.indices, self.data)))
class AbstractSparseArray(core.ShapedArray):
__slots__ = ['index_dtype', 'nnz', 'data_aval', 'indices_aval']
_num_buffers = 2
def __init__(self, shape, dtype, index_dtype, nnz):
super(AbstractSparseArray, self).__init__(shape, dtype)
self.index_dtype = index_dtype
self.nnz = nnz
self.data_aval = core.ShapedArray((nnz,), dtype)
self.indices_aval = core.ShapedArray((nnz, len(shape)), index_dtype)
@core.aval_property
def data(self):
return sp_data_p.bind(self)
@core.aval_property
def indices(self):
return sp_indices_p.bind(self)
def abstract_sparse_array(arr):
return AbstractSparseArray(arr.shape, arr.dtype, arr.index_dtype, arr.nnz)
def sparse_array_result_handler(device, aval):
def build_sparse_array(data_buf, indices_buf):
data = xla.DeviceArray(aval.data_aval, device, lazy.array(aval.data_aval.shape), data_buf)
indices = xla.DeviceArray(aval.indices_aval, device, lazy.array(aval.indices_aval.shape), indices_buf)
return SparseArray(aval, data, indices)
return build_sparse_array
def sparse_array_shape_handler(a):
return (
xla.xc.Shape.array_shape(a.data_aval.dtype, a.data_aval.shape),
xla.xc.Shape.array_shape(a.indices_aval.dtype, a.indices_aval.shape),
)
def sparse_array_device_put_handler(a, device):
return (
xla.xb.get_device_backend(device).buffer_from_pyval(a.data, device),
xla.xb.get_device_backend(device).buffer_from_pyval(a.indices, device)
)
core.pytype_aval_mappings[SparseArray] = abstract_sparse_array
core.raise_to_shaped_mappings[AbstractSparseArray] = lambda aval, _: aval
xla.pytype_aval_mappings[SparseArray] = abstract_sparse_array
xla.canonicalize_dtype_handlers[SparseArray] = lambda x: x
xla.device_put_handlers[SparseArray] = sparse_array_device_put_handler
xla.xla_result_handlers[AbstractSparseArray] = sparse_array_result_handler
xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler
sp_indices_p = core.Primitive('sp_indices')
@sp_indices_p.def_impl
def _sp_indices_impl(mat):
return mat.indices
@sp_indices_p.def_abstract_eval
def _sp_indices_abstract_eval(mat):
return mat.indices_aval
def _sp_indices_translation_rule(c, data, indices):
return indices
xla.translations[sp_indices_p] = _sp_indices_translation_rule
sp_data_p = core.Primitive('sp_data')
@sp_data_p.def_impl
def _sp_data_impl(mat):
return mat.data
@sp_data_p.def_abstract_eval
def _sp_data_abstract_eval(mat):
return mat.data_aval
def _sp_data_translation_rule(c, data, indices):
return data
xla.translations[sp_data_p] = _sp_data_translation_rule
def identity(x):
return identity_p.bind(x)
identity_p = core.Primitive('identity')
@identity_p.def_impl
def _identity_impl(mat):
return SparseArray(mat.aval, mat.data, mat.indices)
@identity_p.def_abstract_eval
def _identity_abstract_eval(mat):
return mat
def _identity_translation_rule(c, data, indices):
return xops.Tuple(c, (data, indices))
xla.translations[identity_p] = _identity_translation_rule
def make_sparse_array(rng, shape, dtype, nnz=0.2):
mat = rng(shape, dtype)
size = int(np.prod(shape))
if 0 < nnz < 1:
nnz = nnz * size
nnz = int(nnz)
if nnz == 0:
mat = np.zeros_like(mat)
elif nnz < size:
# TODO(jakevdp): do we care about duplicates?
cutoff = np.sort(mat.ravel())[nnz]
mat[mat >= cutoff] = 0
nz = (mat != 0)
data = jnp.array(mat[nz])
indices = jnp.array(np.where(nz)).T
aval = AbstractSparseArray(shape, data.dtype, indices.dtype, len(indices))
return SparseArray(aval, data, indices)
def matvec(mat, v):
v = jnp.asarray(v)
assert v.ndim == 1
assert len(mat.shape) == 2
assert v.shape[0] == mat.shape[1]
rows = mat.indices[:, 0]
cols = mat.indices[:, 1]
dv = mat.data * v[cols]
return jnp.zeros(mat.shape[0], dtype=dv.dtype).at[rows].add(dv)
class Empty:
def __init__(self, aval):
self.aval = aval
class AbstractEmpty(core.AbstractValue):
_num_buffers = 0
def join(self, other):
assert isinstance(other, self.__class__), other
return self
def __hash__(self):
return hash(())
def __eq__(self, other):
return isinstance(other, AbstractEmpty)
def abstract_empty(e):
return AbstractEmpty()
core.pytype_aval_mappings[Empty] = abstract_empty
core.raise_to_shaped_mappings[AbstractEmpty] = lambda aval, _: aval
xla.pytype_aval_mappings[Empty] = abstract_empty
xla.canonicalize_dtype_handlers[Empty] = lambda x: x
xla.device_put_handlers[Empty] = lambda _, __: ()
xla.xla_result_handlers[AbstractEmpty] = lambda _, __: lambda: Empty(AbstractEmpty())
xla.xla_shape_handlers[AbstractEmpty] = lambda _: ()
class CustomObjectTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_compile={}_primitive={}".format(compile, primitive),
"compile": compile, "primitive": primitive}
for primitive in [True, False]
for compile in [True, False]))
def testSparseIdentity(self, compile, primitive):
f = identity if primitive else (lambda x: x)
f = jit(f) if compile else f
rng = jtu.rand_default(self.rng())
M = make_sparse_array(rng, (10,), jnp.float32)
M2 = f(M)
jaxpr = make_jaxpr(f)(M).jaxpr
core.check_jaxpr(jaxpr)
self.assertEqual(M.dtype, M2.dtype)
self.assertEqual(M.index_dtype, M2.index_dtype)
self.assertAllClose(M.data, M2.data)
self.assertAllClose(M.indices, M2.indices)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_compile={}_primitive={}".format(compile, primitive),
"compile": compile, "primitive": primitive}
for primitive in [True, False]
for compile in [True, False]))
def testSparseLaxLoop(self, compile, primitive):
rng = jtu.rand_default(self.rng())
f = identity if primitive else (lambda x: x)
f = jit(f) if compile else f
body_fun = lambda _, A: f(A)
M = make_sparse_array(rng, (10,), jnp.float32)
lax.fori_loop(0, 10, body_fun, M)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_attr={}".format(attr), "attr": attr}
for attr in ["data", "indices"]))
def testSparseAttrAccess(self, attr):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [make_sparse_array(rng, (10,), jnp.float32)]
f = lambda x: getattr(x, attr)
self._CompileAndCheck(f, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in [(3, 3), (2, 6), (6, 2)]
for dtype in jtu.dtypes.floating))
def testSparseMatvec(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [make_sparse_array(rng, shape, dtype), rng(shape[-1:], dtype)]
self._CompileAndCheck(matvec, args_maker)
def testLowerToNothing(self):
empty = Empty(AbstractEmpty())
jaxpr = make_jaxpr(jit(lambda e: e))(empty).jaxpr
core.check_jaxpr(jaxpr)
# cannot return a unit, because CompileAndCheck assumes array output.
testfunc = lambda e: None
args_maker = lambda: [empty]
self._CompileAndCheck(testfunc, args_maker)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -1156,7 +1156,7 @@ class PmapTest(jtu.JaxTestCase):
# subsequent pmap
shard_shape = (3,2)
shard = jnp.arange(prod(shard_shape)).reshape(shard_shape)
bufs = [xla.device_put(shard, d) for d in xla_bridge.devices()[:4]]
bufs = pxla.device_put(shard, xla_bridge.devices()[:4], replicate=True)
aval = ShapedArray((6,4), shard.dtype)
sharding_spec = pxla.ShardingSpec(
shards_per_axis=(2, 2),