mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 01:16:05 +00:00
Merge branch 'master' into use-raise-from
This commit is contained in:
commit
d707ae17e5
10
docs/faq.rst
10
docs/faq.rst
@ -67,7 +67,10 @@ and 2) whether it is **committed** to the device or not (the data is sometimes
|
||||
referred to as being *sticky* to the device).
|
||||
|
||||
By default, JAX arrays are placed uncommitted on the default device
|
||||
(``jax.devices()[0]``).
|
||||
(``jax.devices()[0]``), which is the first GPU by default. If no GPU is
|
||||
present, ``jax.devices()[0]`` is the first CPU. The default device can
|
||||
be set to "cpu" or "gpu" manually by setting the environment variable
|
||||
``JAX_PLATFORM_NAME`` or the absl flag ``--jax_platform_name``.
|
||||
|
||||
>>> from jax import numpy as jnp
|
||||
>>> print(jnp.ones(3).device_buffer.device()) # doctest: +SKIP
|
||||
@ -97,6 +100,11 @@ device.
|
||||
Jitted functions behave like any other primitive operation—they will follow the
|
||||
data and will show errors if invoked on data committed on more than one device.
|
||||
|
||||
``jax.device_put(jnp.zeros(...), jax.devices()[1])`` or similar will actually create the
|
||||
array of zeros on ``jax.devices()[1]``, instead of creating the array on the default
|
||||
device then moving it. This is thanks to some laziness in array creation, which holds
|
||||
for all the constant creation operations (``ones``, ``full``, ``eye``, etc.).
|
||||
|
||||
(As of April 2020, :func:`jax.jit` has a ``device`` parameter that affects the device
|
||||
placement. That parameter is experimental, is likely to be removed or changed,
|
||||
and its use is not recommended.)
|
||||
|
@ -2093,7 +2093,7 @@ def device_put_sharded(x: Sequence[Any], devices: Sequence[xc.Device]):
|
||||
f"abstract values not compatible: {avals}"
|
||||
x_aval = core.raise_to_shaped(avals[0])
|
||||
aval = ShapedArray((len(devices),) + x_aval.shape, x_aval.dtype)
|
||||
buffers = [xla.device_put(x, d) for x, d in zip(xs, devices)]
|
||||
buffers = list(it.chain.from_iterable(xla.device_put(x, d) for x, d in zip(xs, devices)))
|
||||
return pxla.ShardedDeviceArray(aval, buffers)
|
||||
return tree_multimap(_device_put_sharded, *x)
|
||||
|
||||
|
@ -747,6 +747,7 @@ def find_top_trace(xs) -> Trace:
|
||||
|
||||
class AbstractValue:
|
||||
__slots__: List[str] = []
|
||||
_num_buffers: int = 1 # number of buffers used to represent the value.
|
||||
|
||||
def at_least_vspace(self):
|
||||
assert False
|
||||
@ -769,6 +770,8 @@ class Bot(AbstractValue): pass
|
||||
bot = Bot()
|
||||
|
||||
class AbstractUnit(AbstractValue):
|
||||
# TODO(jakevdp): make it possible to set zero buffers
|
||||
# _num_buffers = 0
|
||||
def join(self, other):
|
||||
if not skip_checks:
|
||||
assert other is abstract_unit, other
|
||||
|
@ -66,6 +66,15 @@ unsafe_map, map = map, safe_map
|
||||
Index = Union[int, slice, Tuple[Union[int, slice], ...]]
|
||||
|
||||
|
||||
def device_put(x, devices: Sequence[xb.xla_client.Device], replicate: bool=False) -> List[xb.xla_client._xla.PyLocalBuffer]:
|
||||
"""Call device_put on a sequence of devices and return a flat sequence of buffers."""
|
||||
if replicate:
|
||||
return list(it.chain.from_iterable(xla.device_put(x, device) for device in devices))
|
||||
else:
|
||||
return list(it.chain.from_iterable(xla.device_put(val, device) for val, device in safe_zip(x, devices)))
|
||||
|
||||
|
||||
|
||||
# TODO(skye): make this a namedtuple. This may allow us to use ShardingSpecs in
|
||||
# performance-sensitive code, e.g. shard_args.
|
||||
class ShardingSpec:
|
||||
@ -237,9 +246,9 @@ def shard_args(devices: Sequence[xb.xla_client.Device],
|
||||
|
||||
shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any], Sequence[Any]]] = {}
|
||||
shard_arg_handlers[core.Unit] = \
|
||||
lambda x, devices, _: [xla.device_put(core.unit, d) for d in devices]
|
||||
lambda x, devices, _: device_put(core.unit, devices, replicate=True)
|
||||
def _shard_array(x, devices, indices):
|
||||
return [xla.device_put(x[i], d) for (i, d) in zip(indices, devices)]
|
||||
return device_put([x[i] for i in indices], devices)
|
||||
for _t in array_types:
|
||||
shard_arg_handlers[_t] = _shard_array
|
||||
|
||||
@ -247,7 +256,7 @@ def _shard_device_array(x, devices, indices):
|
||||
start_indices, limit_indices, removed_dims = map(tuple, unzip3(
|
||||
_as_slice_indices(x, idx) for idx in indices))
|
||||
shards = x._multi_slice(start_indices, limit_indices, removed_dims)
|
||||
return [xla.device_put(s, d) for s, d in zip(shards, devices)]
|
||||
return device_put(shards, devices)
|
||||
shard_arg_handlers[xla.DeviceArray] = _shard_device_array
|
||||
|
||||
# NOTE(skye): we could refactor to generate _multi_slice parameters directly
|
||||
@ -862,7 +871,7 @@ def replicate(val, axis_size, nrep, devices=None, backend=None):
|
||||
replicated_aval = ShapedArray((axis_size,) + aval.shape, aval.dtype)
|
||||
# TODO(skye): figure out how partitioning should work here
|
||||
sharding_spec = _pmap_sharding_spec(nrep, axis_size, 1, None, aval, True)
|
||||
device_buffers = [xla.device_put(val, d) for d in devices]
|
||||
device_buffers = device_put(val, devices, replicate=True)
|
||||
return ShardedDeviceArray(replicated_aval, sharding_spec, device_buffers)
|
||||
|
||||
def _pmap_sharding_spec(nrep, axis_size, npart, parts, sharded_aval, mapped):
|
||||
|
@ -180,7 +180,7 @@ def _xla_sharded_args(c, avals, in_parts):
|
||||
xla_args = []
|
||||
for i, (sharding, aval) in enumerate(safe_zip(in_parts, avals)):
|
||||
param = xb.with_sharding(c, sharding, xb.parameter, c, i,
|
||||
xla.aval_to_xla_shape(aval))
|
||||
*xla.aval_to_xla_shapes(aval))
|
||||
xla_args.append(param)
|
||||
return xla_args
|
||||
|
||||
|
@ -75,19 +75,19 @@ _scalar_types = dtypes.python_scalar_dtypes.keys()
|
||||
|
||||
# unit representation
|
||||
def _make_unit(c): return xb.constant(c, np.zeros((), dtype=np.dtype('bool')))
|
||||
def _make_abstract_unit(_): return xc.Shape.array_shape(np.dtype('bool'), ())
|
||||
def _make_abstract_unit(_): return (xc.Shape.array_shape(np.dtype('bool'), ()),)
|
||||
def _device_put_unit(_, device):
|
||||
backend = xb.get_device_backend(device)
|
||||
return backend.buffer_from_pyval(np.zeros((), dtype=np.dtype('bool')),
|
||||
device)
|
||||
return (backend.buffer_from_pyval(np.zeros((), dtype=np.dtype('bool')),
|
||||
device),)
|
||||
def _make_array_shape(a):
|
||||
return xc.Shape.array_shape(a.dtype, a.shape)
|
||||
return (xc.Shape.array_shape(a.dtype, a.shape),)
|
||||
|
||||
### handlers
|
||||
|
||||
xb.register_constant_handler(core.Unit, lambda c, *_: _make_unit(c))
|
||||
|
||||
def aval_to_xla_shape(aval):
|
||||
def aval_to_xla_shapes(aval):
|
||||
try:
|
||||
return xla_shape_handlers[type(aval)](aval)
|
||||
except KeyError as err:
|
||||
@ -99,7 +99,7 @@ xla_shape_handlers: Dict[Type[core.AbstractValue], Callable] = {
|
||||
ConcreteArray: _make_array_shape,
|
||||
}
|
||||
|
||||
def aval_to_result_handler(device: Optional[Device], aval: core.ShapedArray):
|
||||
def aval_to_result_handler(device: Optional[Device], aval: core.AbstractValue) -> Callable:
|
||||
try:
|
||||
return xla_result_handlers[type(aval)](device, aval)
|
||||
except KeyError as err:
|
||||
@ -114,7 +114,7 @@ xla_result_handlers: Dict[Type[core.AbstractValue], Callable[..., Callable]] = {
|
||||
ConcreteArray: array_result_handler,
|
||||
}
|
||||
|
||||
def device_put(x, device: Optional[Device] = None):
|
||||
def device_put(x, device: Optional[Device] = None) -> Tuple[Any]:
|
||||
x = canonicalize_dtype(x)
|
||||
try:
|
||||
return device_put_handlers[type(x)](x, device)
|
||||
@ -123,12 +123,14 @@ def device_put(x, device: Optional[Device] = None):
|
||||
|
||||
def _device_put_array(x, device: Optional[Device]):
|
||||
backend = xb.get_device_backend(device)
|
||||
return backend.buffer_from_pyval(x, device)
|
||||
return (backend.buffer_from_pyval(x, device),)
|
||||
|
||||
def _device_put_scalar(x, device):
|
||||
return _device_put_array(dtypes.coerce_to_array(x), device)
|
||||
|
||||
device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]], Any]] = {core.Unit: _device_put_unit}
|
||||
device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]], Tuple[Any]]] = {
|
||||
core.Unit: _device_put_unit
|
||||
}
|
||||
device_put_handlers.update((t, _device_put_array) for t in array_types)
|
||||
device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)
|
||||
|
||||
@ -224,6 +226,15 @@ def apply_primitive(prim, *args, **params):
|
||||
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
|
||||
return compiled_fun(*args)
|
||||
|
||||
|
||||
def _partition_outputs(avals, outs):
|
||||
nouts = [aval._num_buffers for aval in avals]
|
||||
if not core.skip_checks:
|
||||
assert sum(nouts) == len(outs), f"Internal error: sum(nouts)={sum(nouts)} should equal len(outs)={len(outs)}."
|
||||
outs = iter(outs)
|
||||
return [[next(outs) for _ in range(nout)] for nout in nouts]
|
||||
|
||||
|
||||
@cache()
|
||||
def xla_primitive_callable(prim, *arg_specs: Tuple[core.AbstractValue,
|
||||
Optional[Device]], **params):
|
||||
@ -242,7 +253,8 @@ def xla_primitive_callable(prim, *arg_specs: Tuple[core.AbstractValue,
|
||||
handle_result = aval_to_result_handler(device, aval_out)
|
||||
else:
|
||||
handlers = map(partial(aval_to_result_handler, device), aval_out)
|
||||
handle_result = lambda xs: tuple(h(x) for h, x in zip(handlers, xs))
|
||||
handle_result = lambda *bufs:\
|
||||
tuple(handler(*bs) for handler, bs in zip(handlers, _partition_outputs(aval_out, bufs)))
|
||||
tuple_args = len(avals) > 100
|
||||
if prim in initial_style_translations:
|
||||
nreps = initial_style_primitive_replicas(params)
|
||||
@ -326,20 +338,18 @@ def backend_compile(backend, built_c, options):
|
||||
|
||||
def _execute_compiled_primitive(prim, compiled, result_handler, *args):
|
||||
device, = compiled.local_devices()
|
||||
input_bufs = [device_put(x, device) for x in args if x is not token]
|
||||
input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
|
||||
out_bufs = compiled.execute(input_bufs)
|
||||
if FLAGS.jax_debug_nans:
|
||||
check_nans(prim, out_bufs)
|
||||
return result_handler(out_bufs if prim.multiple_results else out_bufs[0])
|
||||
if FLAGS.jax_debug_nans: check_nans(prim, out_bufs)
|
||||
return result_handler(*out_bufs)
|
||||
|
||||
def _execute_replicated_primitive(prim, compiled, result_handler, *args):
|
||||
input_bufs = [
|
||||
[device_put(x, device) for x in args if x is not token]
|
||||
list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
|
||||
for device in compiled.local_devices()]
|
||||
out_buf = compiled.execute_on_local_devices(input_bufs)[0]
|
||||
if not prim.multiple_results:
|
||||
out_buf, = out_buf
|
||||
return result_handler(out_buf)
|
||||
out_bufs = compiled.execute_on_local_devices(input_bufs)[0]
|
||||
return result_handler(*out_bufs)
|
||||
|
||||
|
||||
def check_nans(prim, bufs):
|
||||
for buf in bufs:
|
||||
@ -368,6 +378,12 @@ def jaxpr_literals(jaxpr):
|
||||
yield from jaxpr_literals(subjaxpr)
|
||||
|
||||
|
||||
def _flatmap(func: Callable, vars: Sequence):
|
||||
return list(it.chain.from_iterable(map(func, vars)))
|
||||
|
||||
def _partitionmap(func: Callable, vars: Sequence, nodes: Sequence):
|
||||
return map(func, vars, _partition_outputs([v.aval for v in vars], nodes))
|
||||
|
||||
def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
|
||||
if backend not in ('cpu', 'gpu', 'tpu'):
|
||||
platform = xb.get_backend(backend).platform # canonicalize
|
||||
@ -376,7 +392,7 @@ def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
|
||||
|
||||
def read(v):
|
||||
if type(v) is Literal:
|
||||
return xb.constant(c, canonicalize_dtype(v.val))
|
||||
return [xb.constant(c, canonicalize_dtype(v.val))]
|
||||
else:
|
||||
return env[v]
|
||||
|
||||
@ -391,9 +407,9 @@ def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
|
||||
env[v] = node
|
||||
|
||||
env = {}
|
||||
write(core.unitvar, _make_unit(c))
|
||||
map(write, jaxpr.constvars, consts)
|
||||
map(write, jaxpr.invars, args)
|
||||
_partitionmap(write, [core.unitvar], [_make_unit(c)])
|
||||
_partitionmap(write, jaxpr.constvars, consts)
|
||||
_partitionmap(write, jaxpr.invars, args)
|
||||
for eqn in jaxpr.eqns:
|
||||
frame = source_info_util.user_frame(eqn.source_info)
|
||||
c.set_op_metadata(xc.OpMetadata(
|
||||
@ -402,7 +418,7 @@ def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
|
||||
eqn.primitive.name, eqn.params)),
|
||||
source_file=frame.file_name if frame else None,
|
||||
source_line=frame.line_num if frame else None))
|
||||
in_nodes = map(read, eqn.invars)
|
||||
in_nodes = _flatmap(read, eqn.invars)
|
||||
if eqn.primitive in backend_specific_translations[platform]:
|
||||
rule = backend_specific_translations[platform][eqn.primitive]
|
||||
ans = rule(c, *in_nodes, **eqn.params)
|
||||
@ -427,10 +443,14 @@ def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
|
||||
|
||||
assert isinstance(ans, xe.XlaOp)
|
||||
c.get_shape(ans) # force xla to do shape error checking
|
||||
out_nodes = xla_destructure(c, ans) if eqn.primitive.multiple_results else [ans]
|
||||
if eqn.primitive.multiple_results or any(v.aval._num_buffers > 1 for v in eqn.outvars):
|
||||
out_nodes = xla_destructure(c, ans)
|
||||
else:
|
||||
out_nodes = [ans]
|
||||
c.clear_op_metadata()
|
||||
map(write, eqn.outvars, out_nodes)
|
||||
return map(read, jaxpr.outvars)
|
||||
_partitionmap(write, eqn.outvars, out_nodes)
|
||||
return _flatmap(read, jaxpr.outvars)
|
||||
|
||||
|
||||
def xla_destructure(c, ans):
|
||||
num_elements = len(c.get_shape(ans).tuple_shapes())
|
||||
@ -606,15 +626,16 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
|
||||
device = _xla_callable_device(nreps, backend, device, arg_devices)
|
||||
backend = device.platform if device else backend
|
||||
if config.omnistaging_enabled:
|
||||
result_handlers = tuple(aval_to_result_handler(device, a) for a in out_avals)
|
||||
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
|
||||
else:
|
||||
result_handlers = tuple(map(partial(_pval_to_result_handler, device), pvals)) # type: ignore
|
||||
out_avals = [pval.get_aval() for pval in pvals]
|
||||
result_handlers = map(partial(_pval_to_result_handler, device), pvals) # type: ignore
|
||||
|
||||
# Computations that only produce constants and/or only rearrange their inputs,
|
||||
# which are often produced from partial evaluation, don't need compilation,
|
||||
# and don't need to force their (potentially lazy) arguments.
|
||||
if not jaxpr.eqns:
|
||||
return partial(_execute_trivial, jaxpr, device, consts, result_handlers)
|
||||
return partial(_execute_trivial, jaxpr, device, consts, out_avals, result_handlers)
|
||||
|
||||
if not _on_exit:
|
||||
log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
|
||||
@ -664,9 +685,9 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
|
||||
options.parameter_is_tupled_arguments = tuple_args
|
||||
compiled = backend_compile(backend, built, options)
|
||||
if nreps == 1:
|
||||
return partial(_execute_compiled, compiled, result_handlers)
|
||||
return partial(_execute_compiled, compiled, out_avals, result_handlers)
|
||||
else:
|
||||
return partial(_execute_replicated, compiled, result_handlers)
|
||||
return partial(_execute_replicated, compiled, out_avals, result_handlers)
|
||||
|
||||
def set_up_aliases(c, xla_args, out_tuple, donated_args, tuple_args):
|
||||
"""Configures input/output "must" aliasing based on `donated_args`."""
|
||||
@ -728,17 +749,18 @@ def _xla_callable_args(
|
||||
else:
|
||||
parts = [_replicated_param if part is None else part
|
||||
for part in partitions]
|
||||
return [_xla_param(c, i, aval_to_xla_shape(a), r, p)
|
||||
counts = it.count()
|
||||
return [_xla_param(c, next(counts), xla_shape, r, p)
|
||||
if a is not abstract_token else xops.CreateToken(c)
|
||||
for i, (a, r, p)
|
||||
in enumerate(safe_zip(avals, replicated, parts))]
|
||||
for (a, r, p) in safe_zip(avals, replicated, parts)
|
||||
for xla_shape in aval_to_xla_shapes(a)]
|
||||
else:
|
||||
if replicated is not None:
|
||||
replicated = [r for a, r in zip(avals, replicated)
|
||||
if a is not abstract_token]
|
||||
tuple_parts = tuple(partitions) if partitions is not None else None
|
||||
tuple_shape = xc.Shape.tuple_shape(
|
||||
[aval_to_xla_shape(a) for a in avals if a is not abstract_token])
|
||||
[shape for a in avals for shape in aval_to_xla_shapes(a) if a is not abstract_token])
|
||||
tuple_param = _xla_param(c, 0, tuple_shape, replicated, tuple_parts)
|
||||
xla_inputs = iter(xla_destructure(c, tuple_param))
|
||||
xla_args = [next(xla_inputs) if a is not abstract_token else
|
||||
@ -756,29 +778,29 @@ def _xla_param(builder, param_num, xla_shape, replicated, partitions):
|
||||
else:
|
||||
return xb.with_sharding(builder, partitions, make_param)
|
||||
|
||||
def _execute_compiled(compiled: XlaExecutable, handlers, *args):
|
||||
def _execute_compiled(compiled: XlaExecutable, avals, handlers, *args):
|
||||
device, = compiled.local_devices()
|
||||
input_bufs = [device_put(x, device) for x in args if x is not token]
|
||||
input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
|
||||
out_bufs = compiled.execute(input_bufs)
|
||||
if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
|
||||
return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)]
|
||||
return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]
|
||||
|
||||
def _execute_replicated(compiled: XlaExecutable, handlers, *args):
|
||||
def _execute_replicated(compiled: XlaExecutable, avals, handlers, *args):
|
||||
input_bufs = [
|
||||
[device_put(x, device) for x in args if x is not token]
|
||||
list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
|
||||
for device in compiled.local_devices()]
|
||||
out_bufs = compiled.execute_on_local_devices(input_bufs)[0]
|
||||
if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
|
||||
return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)]
|
||||
return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]
|
||||
|
||||
def _execute_trivial(jaxpr, device: Optional[Device], consts, handlers, *args):
|
||||
def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers, *args):
|
||||
env = {core.unitvar: core.unit}
|
||||
map(env.setdefault, jaxpr.invars, args)
|
||||
map(env.setdefault, jaxpr.constvars, consts)
|
||||
outs = [canonicalize_dtype(v.val) if type(v) is Literal else env[v]
|
||||
for v in jaxpr.outvars]
|
||||
return [_copy_device_array_to_device(x, device) if type(x) is DeviceArray
|
||||
else h(device_put(x, device)) for h, x in zip(handlers, outs)]
|
||||
else h(*device_put(x, device)) for h, x in zip(handlers, outs)]
|
||||
|
||||
xla_call_p = core.CallPrimitive('xla_call')
|
||||
xla_call = xla_call_p.bind
|
||||
@ -924,7 +946,7 @@ token = Token()
|
||||
|
||||
pytype_aval_mappings[Token] = lambda _: abstract_token
|
||||
core.pytype_aval_mappings[Token] = lambda _: abstract_token
|
||||
xla_shape_handlers[AbstractToken] = lambda _: xc.Shape.token_shape()
|
||||
xla_shape_handlers[AbstractToken] = lambda _: (xc.Shape.token_shape(),)
|
||||
xla_result_handlers[AbstractToken] = lambda _, __: lambda _: token
|
||||
canonicalize_dtype_handlers[Token] = identity
|
||||
|
||||
@ -949,7 +971,8 @@ class DeviceArray:
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
|
||||
def __init__(self, aval: core.ShapedArray, device: Optional[Device],
|
||||
lazy_expr: lazy.LazyExpr, device_buffer: PyLocalBuffer):
|
||||
lazy_expr: lazy.LazyExpr,
|
||||
device_buffer: PyLocalBuffer):
|
||||
self.aval = aval
|
||||
self.device_buffer = device_buffer
|
||||
self._device = device
|
||||
@ -1137,7 +1160,7 @@ xb.register_constant_handler(DeviceArray, _device_array_constant_handler)
|
||||
|
||||
def _device_put_device_array(x: DeviceArray, device: Optional[Device]):
|
||||
x = _copy_device_array_to_device(x, device)
|
||||
return _force(x).device_buffer
|
||||
return (_force(x).device_buffer,)
|
||||
device_put_handlers[DeviceArray] = _device_put_device_array
|
||||
|
||||
def _copy_device_array_to_device(x: DeviceArray, device: Optional[xc.Device]) -> DeviceArray:
|
||||
@ -1219,8 +1242,7 @@ def _device_put_impl(x, device: Optional[Device] = None):
|
||||
except TypeError as err:
|
||||
raise TypeError(
|
||||
f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
|
||||
handler = aval_to_result_handler(device, a) # type: ignore[arg-type]
|
||||
return handler(device_put(x, device))
|
||||
return aval_to_result_handler(device, a)(*device_put(x, device))
|
||||
|
||||
device_put_p = core.Primitive('device_put')
|
||||
device_put_p.def_impl(_device_put_impl)
|
||||
|
@ -1402,7 +1402,7 @@ def _device_put_raw(x):
|
||||
return x
|
||||
else:
|
||||
aval = raise_to_shaped(core.get_aval(x))
|
||||
return xla.array_result_handler(None, aval)(xla.device_put(x))
|
||||
return xla.array_result_handler(None, aval)(*xla.device_put(x))
|
||||
|
||||
def iota(dtype: DType, size: int) -> Array:
|
||||
"""Wraps XLA's `Iota
|
||||
@ -5745,8 +5745,8 @@ def _infeed_abstract_eval(token, *, shapes, partitions):
|
||||
|
||||
|
||||
def _infeed_translation_rule(c, token, *, shapes, partitions):
|
||||
shape = tuple(xla.aval_to_xla_shape(x).with_major_to_minor_layout_if_absent()
|
||||
for x in shapes)
|
||||
shape = tuple(shape.with_major_to_minor_layout_if_absent()
|
||||
for x in shapes for shape in xla.aval_to_xla_shapes(x))
|
||||
build_infeed = partial(xops.InfeedWithToken, token,
|
||||
xla_client.Shape.tuple_shape(shape))
|
||||
if partitions:
|
||||
|
@ -309,6 +309,12 @@ class CoreTest(jtu.JaxTestCase):
|
||||
syms = {c: d, a: b}
|
||||
assert 'bd' == ''.join(map(str, tree_leaves(syms)))
|
||||
|
||||
def test_device_put_unit(self):
|
||||
def f(x, y):
|
||||
return x, 2 * y
|
||||
args_maker = lambda: (core.unit, 1)
|
||||
self._CompileAndCheck(f, args_maker)
|
||||
|
||||
|
||||
class JaxprTypeChecks(jtu.JaxTestCase):
|
||||
|
||||
|
282
tests/custom_object_test.py
Normal file
282
tests/custom_object_test.py
Normal file
@ -0,0 +1,282 @@
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import absltest, parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
from jax import core, jit, lax, lazy, make_jaxpr
|
||||
from jax.interpreters import xla
|
||||
from jax.lib import xla_client
|
||||
xops = xla_client.ops
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
# TODO(jakevdp): use a setup/teardown method to populate and unpopulate all the
|
||||
# dictionaries associated with the following objects.
|
||||
|
||||
# Define a sparse array data structure. The important feature here is that
|
||||
# it is an object usable in jaxprs that is backed by two device buffers.
|
||||
class SparseArray:
|
||||
"""Simple sparse COO array data structure."""
|
||||
def __init__(self, aval, data, indices):
|
||||
self.aval = aval
|
||||
self.shape = aval.shape
|
||||
self.data = data
|
||||
self.indices = indices
|
||||
|
||||
@property
|
||||
def index_dtype(self):
|
||||
return self.indices.dtype
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.data.dtype
|
||||
|
||||
@property
|
||||
def nnz(self):
|
||||
return self.data.shape[0]
|
||||
|
||||
def __repr__(self):
|
||||
return repr(list((tuple(ind), d) for ind, d in zip(self.indices, self.data)))
|
||||
|
||||
|
||||
class AbstractSparseArray(core.ShapedArray):
|
||||
__slots__ = ['index_dtype', 'nnz', 'data_aval', 'indices_aval']
|
||||
_num_buffers = 2
|
||||
|
||||
def __init__(self, shape, dtype, index_dtype, nnz):
|
||||
super(AbstractSparseArray, self).__init__(shape, dtype)
|
||||
self.index_dtype = index_dtype
|
||||
self.nnz = nnz
|
||||
self.data_aval = core.ShapedArray((nnz,), dtype)
|
||||
self.indices_aval = core.ShapedArray((nnz, len(shape)), index_dtype)
|
||||
|
||||
@core.aval_property
|
||||
def data(self):
|
||||
return sp_data_p.bind(self)
|
||||
|
||||
@core.aval_property
|
||||
def indices(self):
|
||||
return sp_indices_p.bind(self)
|
||||
|
||||
def abstract_sparse_array(arr):
|
||||
return AbstractSparseArray(arr.shape, arr.dtype, arr.index_dtype, arr.nnz)
|
||||
|
||||
def sparse_array_result_handler(device, aval):
|
||||
def build_sparse_array(data_buf, indices_buf):
|
||||
data = xla.DeviceArray(aval.data_aval, device, lazy.array(aval.data_aval.shape), data_buf)
|
||||
indices = xla.DeviceArray(aval.indices_aval, device, lazy.array(aval.indices_aval.shape), indices_buf)
|
||||
return SparseArray(aval, data, indices)
|
||||
return build_sparse_array
|
||||
|
||||
def sparse_array_shape_handler(a):
|
||||
return (
|
||||
xla.xc.Shape.array_shape(a.data_aval.dtype, a.data_aval.shape),
|
||||
xla.xc.Shape.array_shape(a.indices_aval.dtype, a.indices_aval.shape),
|
||||
)
|
||||
|
||||
def sparse_array_device_put_handler(a, device):
|
||||
return (
|
||||
xla.xb.get_device_backend(device).buffer_from_pyval(a.data, device),
|
||||
xla.xb.get_device_backend(device).buffer_from_pyval(a.indices, device)
|
||||
)
|
||||
|
||||
core.pytype_aval_mappings[SparseArray] = abstract_sparse_array
|
||||
core.raise_to_shaped_mappings[AbstractSparseArray] = lambda aval, _: aval
|
||||
xla.pytype_aval_mappings[SparseArray] = abstract_sparse_array
|
||||
xla.canonicalize_dtype_handlers[SparseArray] = lambda x: x
|
||||
xla.device_put_handlers[SparseArray] = sparse_array_device_put_handler
|
||||
xla.xla_result_handlers[AbstractSparseArray] = sparse_array_result_handler
|
||||
xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler
|
||||
|
||||
|
||||
sp_indices_p = core.Primitive('sp_indices')
|
||||
|
||||
@sp_indices_p.def_impl
|
||||
def _sp_indices_impl(mat):
|
||||
return mat.indices
|
||||
|
||||
@sp_indices_p.def_abstract_eval
|
||||
def _sp_indices_abstract_eval(mat):
|
||||
return mat.indices_aval
|
||||
|
||||
def _sp_indices_translation_rule(c, data, indices):
|
||||
return indices
|
||||
|
||||
xla.translations[sp_indices_p] = _sp_indices_translation_rule
|
||||
|
||||
sp_data_p = core.Primitive('sp_data')
|
||||
|
||||
@sp_data_p.def_impl
|
||||
def _sp_data_impl(mat):
|
||||
return mat.data
|
||||
|
||||
@sp_data_p.def_abstract_eval
|
||||
def _sp_data_abstract_eval(mat):
|
||||
return mat.data_aval
|
||||
|
||||
def _sp_data_translation_rule(c, data, indices):
|
||||
return data
|
||||
|
||||
xla.translations[sp_data_p] = _sp_data_translation_rule
|
||||
|
||||
def identity(x):
|
||||
return identity_p.bind(x)
|
||||
|
||||
identity_p = core.Primitive('identity')
|
||||
|
||||
@identity_p.def_impl
|
||||
def _identity_impl(mat):
|
||||
return SparseArray(mat.aval, mat.data, mat.indices)
|
||||
|
||||
@identity_p.def_abstract_eval
|
||||
def _identity_abstract_eval(mat):
|
||||
return mat
|
||||
|
||||
def _identity_translation_rule(c, data, indices):
|
||||
return xops.Tuple(c, (data, indices))
|
||||
|
||||
xla.translations[identity_p] = _identity_translation_rule
|
||||
|
||||
def make_sparse_array(rng, shape, dtype, nnz=0.2):
|
||||
mat = rng(shape, dtype)
|
||||
size = int(np.prod(shape))
|
||||
if 0 < nnz < 1:
|
||||
nnz = nnz * size
|
||||
nnz = int(nnz)
|
||||
if nnz == 0:
|
||||
mat = np.zeros_like(mat)
|
||||
elif nnz < size:
|
||||
# TODO(jakevdp): do we care about duplicates?
|
||||
cutoff = np.sort(mat.ravel())[nnz]
|
||||
mat[mat >= cutoff] = 0
|
||||
nz = (mat != 0)
|
||||
data = jnp.array(mat[nz])
|
||||
indices = jnp.array(np.where(nz)).T
|
||||
aval = AbstractSparseArray(shape, data.dtype, indices.dtype, len(indices))
|
||||
return SparseArray(aval, data, indices)
|
||||
|
||||
def matvec(mat, v):
|
||||
v = jnp.asarray(v)
|
||||
assert v.ndim == 1
|
||||
assert len(mat.shape) == 2
|
||||
assert v.shape[0] == mat.shape[1]
|
||||
rows = mat.indices[:, 0]
|
||||
cols = mat.indices[:, 1]
|
||||
dv = mat.data * v[cols]
|
||||
return jnp.zeros(mat.shape[0], dtype=dv.dtype).at[rows].add(dv)
|
||||
|
||||
|
||||
class Empty:
|
||||
def __init__(self, aval):
|
||||
self.aval = aval
|
||||
|
||||
class AbstractEmpty(core.AbstractValue):
|
||||
_num_buffers = 0
|
||||
|
||||
def join(self, other):
|
||||
assert isinstance(other, self.__class__), other
|
||||
return self
|
||||
|
||||
def __hash__(self):
|
||||
return hash(())
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, AbstractEmpty)
|
||||
|
||||
|
||||
def abstract_empty(e):
|
||||
return AbstractEmpty()
|
||||
|
||||
core.pytype_aval_mappings[Empty] = abstract_empty
|
||||
core.raise_to_shaped_mappings[AbstractEmpty] = lambda aval, _: aval
|
||||
xla.pytype_aval_mappings[Empty] = abstract_empty
|
||||
xla.canonicalize_dtype_handlers[Empty] = lambda x: x
|
||||
xla.device_put_handlers[Empty] = lambda _, __: ()
|
||||
xla.xla_result_handlers[AbstractEmpty] = lambda _, __: lambda: Empty(AbstractEmpty())
|
||||
xla.xla_shape_handlers[AbstractEmpty] = lambda _: ()
|
||||
|
||||
|
||||
class CustomObjectTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_compile={}_primitive={}".format(compile, primitive),
|
||||
"compile": compile, "primitive": primitive}
|
||||
for primitive in [True, False]
|
||||
for compile in [True, False]))
|
||||
def testSparseIdentity(self, compile, primitive):
|
||||
f = identity if primitive else (lambda x: x)
|
||||
f = jit(f) if compile else f
|
||||
rng = jtu.rand_default(self.rng())
|
||||
M = make_sparse_array(rng, (10,), jnp.float32)
|
||||
M2 = f(M)
|
||||
|
||||
jaxpr = make_jaxpr(f)(M).jaxpr
|
||||
core.check_jaxpr(jaxpr)
|
||||
|
||||
self.assertEqual(M.dtype, M2.dtype)
|
||||
self.assertEqual(M.index_dtype, M2.index_dtype)
|
||||
self.assertAllClose(M.data, M2.data)
|
||||
self.assertAllClose(M.indices, M2.indices)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_compile={}_primitive={}".format(compile, primitive),
|
||||
"compile": compile, "primitive": primitive}
|
||||
for primitive in [True, False]
|
||||
for compile in [True, False]))
|
||||
def testSparseLaxLoop(self, compile, primitive):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
f = identity if primitive else (lambda x: x)
|
||||
f = jit(f) if compile else f
|
||||
body_fun = lambda _, A: f(A)
|
||||
M = make_sparse_array(rng, (10,), jnp.float32)
|
||||
lax.fori_loop(0, 10, body_fun, M)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_attr={}".format(attr), "attr": attr}
|
||||
for attr in ["data", "indices"]))
|
||||
def testSparseAttrAccess(self, attr):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [make_sparse_array(rng, (10,), jnp.float32)]
|
||||
f = lambda x: getattr(x, attr)
|
||||
self._CompileAndCheck(f, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype)),
|
||||
"shape": shape, "dtype": dtype}
|
||||
for shape in [(3, 3), (2, 6), (6, 2)]
|
||||
for dtype in jtu.dtypes.floating))
|
||||
def testSparseMatvec(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [make_sparse_array(rng, shape, dtype), rng(shape[-1:], dtype)]
|
||||
self._CompileAndCheck(matvec, args_maker)
|
||||
|
||||
def testLowerToNothing(self):
|
||||
empty = Empty(AbstractEmpty())
|
||||
jaxpr = make_jaxpr(jit(lambda e: e))(empty).jaxpr
|
||||
core.check_jaxpr(jaxpr)
|
||||
|
||||
# cannot return a unit, because CompileAndCheck assumes array output.
|
||||
testfunc = lambda e: None
|
||||
args_maker = lambda: [empty]
|
||||
self._CompileAndCheck(testfunc, args_maker)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
@ -1156,7 +1156,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
# subsequent pmap
|
||||
shard_shape = (3,2)
|
||||
shard = jnp.arange(prod(shard_shape)).reshape(shard_shape)
|
||||
bufs = [xla.device_put(shard, d) for d in xla_bridge.devices()[:4]]
|
||||
bufs = pxla.device_put(shard, xla_bridge.devices()[:4], replicate=True)
|
||||
aval = ShapedArray((6,4), shard.dtype)
|
||||
sharding_spec = pxla.ShardingSpec(
|
||||
shards_per_axis=(2, 2),
|
||||
|
Loading…
x
Reference in New Issue
Block a user