Keep ShapedArray avals on xla.DeviceArray values

Makes abstractification of DeviceArray values cheaper, which is on the critical path for executing a compiled function.
Peter Hawkins 2019-08-10 16:04:43 -04:00
parent 4474b11181
commit 3e78a0e290
3 changed files with 42 additions and 52 deletions
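In short: a DeviceArray now carries its ShapedArray abstract value (aval), shape and dtype are read off that aval, and abstractifying a DeviceArray becomes a single attribute lookup instead of constructing a fresh ShapedArray each time. A simplified, self-contained sketch of the pattern (stand-in classes for illustration, not the actual jax.interpreters.xla definitions; the real ones appear in the diff below):

```python
import numpy as onp

class ShapedArray:
    """Stand-in for JAX's ShapedArray: just a (shape, dtype) record."""
    def __init__(self, shape, dtype):
        self.shape = tuple(shape)
        self.dtype = onp.dtype(dtype)

class DeviceArray:
    """Sketch: a device-backed array that carries its abstract value."""
    __slots__ = ["aval", "device_buffer", "_npy_value"]

    def __init__(self, aval, device_buffer):
        self.aval = aval                    # stored once, at construction time
        self.device_buffer = device_buffer
        self._npy_value = None

    # shape and dtype are now derived from the stored aval rather than kept
    # as duplicate instance attributes.
    @property
    def shape(self):
        return self.aval.shape

    @property
    def dtype(self):
        return self.aval.dtype

# Before: pytype_aval_mappings[DeviceArray] = make_shaped_array, which built a
# new ShapedArray from x.shape and x.dtype on every abstractification.
# After: a constant-time attribute read.
pytype_aval_mappings = {DeviceArray: lambda x: x.aval}

x = DeviceArray(ShapedArray((2, 3), "float32"), device_buffer=None)
assert pytype_aval_mappings[DeviceArray](x) is x.aval
```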

jax/interpreters/pxla.py

@@ -31,7 +31,7 @@ from .. import linear_util as lu
from ..abstract_arrays import ConcreteArray, ShapedArray
from ..util import partial, unzip2, concatenate, prod
from ..lib import xla_bridge as xb
from .xla import xla_shape, xla_destructure, xla_shape_to_result_shape
from .xla import xla_shape, xla_destructure, aval_from_xla_shape
from .partial_eval import trace_to_subjaxpr, merge_pvals, JaxprTrace, PartialVal
from .batching import dimsize, broadcast
from . import batching
@@ -240,7 +240,7 @@ def compile_replicated(jaxpr, axis_name, axis_size, consts, *abstract_args):
axis_env = xla.AxisEnv(num_replicas, [axis_name], [axis_size])
arg_shapes = list(map(xla_shape, abstract_args))
built_c = xla._jaxpr_computation(jaxpr, axis_env, consts, (), *arg_shapes)
result_shape = xla_shape_to_result_shape(built_c.GetReturnValueShape())
result_shape = aval_from_xla_shape(built_c.GetReturnValueShape())
compiled = built_c.Compile(arg_shapes, xb.get_compile_options(num_replicas),
backend=xb.get_backend())
return compiled, num_replicas, result_shape
@@ -435,7 +435,7 @@ class ShardedDeviceArray(ShardedDeviceValue, xla.DeviceArray):
represent distinct logical shards. The correspondence can be computed with
the assign_shards_to_replicas function.
"""
__slots__ = ["device_buffers", "axis_size", "aval"]
__slots__ = ["device_buffers", "axis_size"]
_collect = staticmethod(onp.stack)
def __init__(self, aval, device_buffers):
@@ -443,8 +443,6 @@ class ShardedDeviceArray(ShardedDeviceValue, xla.DeviceArray):
# return it unmodified.
self.aval = aval
self.device_buffers = device_buffers
self.shape = aval.shape
self.dtype = aval.dtype
self.axis_size = aval.shape[0]
self._npy_value = None
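Why the shape/dtype assignments and the "aval" slot can disappear here: the xla.DeviceArray base class now derives shape and dtype from aval (see the xla.py hunks below), and the "aval" slot itself is now declared on the DeviceValue base class. A slot declared on a base class is inherited by subclasses, and re-declaring it in a subclass only adds a duplicate, shadowing slot. A hypothetical minimal demo (single-base simplification; the real ShardedDeviceArray mixes in ShardedDeviceValue and xla.DeviceArray):

```python
class DeviceValue:
    # Base class now owns the "aval" slot, as in the xla.py change below.
    __slots__ = ["aval", "device_buffer"]

class ShardedDeviceArraySketch(DeviceValue):
    # The subclass lists only its own extra slots; "aval" is inherited.
    __slots__ = ["device_buffers", "axis_size"]

x = ShardedDeviceArraySketch()
x.aval = "stored via the slot inherited from DeviceValue"
assert not hasattr(x, "__dict__")   # still a fixed-layout, __slots__-only object
```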
@@ -477,7 +475,7 @@ class ShardedDeviceArray(ShardedDeviceValue, xla.DeviceArray):
if self._npy_value is None and type(idx) is int:
ids = self._ids()
device_buffer = self.device_buffers[ids[idx]]
result_shape = xla_shape_to_result_shape(device_buffer.shape())
result_shape = aval_from_xla_shape(device_buffer.shape())
handler = xla.result_handler(result_shape)
return handler(device_buffer)
else:

jax/interpreters/xla.py

@@ -55,7 +55,7 @@ def apply_primitive(prim, *args, **params):
def _xla_primitive_callable(prim, *abstract_args, **params):
shapes = tuple(map(xla_shape, abstract_args))
built_c = primitive_computation(prim, *shapes, **params)
result_shape = xla_shape_to_result_shape(built_c.GetReturnValueShape())
result_shape = aval_from_xla_shape(built_c.GetReturnValueShape())
handle_result = result_handler(result_shape)
compiled = built_c.Compile(shapes, xb.get_compile_options(),
backend=xb.get_backend())
@@ -91,12 +91,12 @@ def primitive_computation(prim, *shapes, **params):
return c.Build()
except RuntimeError as e:
# try for a better error message by using the abstract_eval checks
prim.abstract_eval(*map(_aval_from_xla_shape, shapes), **params)
prim.abstract_eval(*map(aval_from_xla_shape, shapes), **params)
raise e
def _aval_from_xla_shape(shape):
def aval_from_xla_shape(shape):
if shape.is_tuple():
return AbstractTuple(map(_aval_from_xla_shape, shape.tuple_shapes()))
return AbstractTuple(map(aval_from_xla_shape, shape.tuple_shapes()))
else:
return ShapedArray(shape.dimensions(), shape.element_type())
@@ -179,26 +179,12 @@ def device_put(x, device_num=0):
# JaxType, i.e. that the mapping is bijective. That assumption could be relaxed,
but it would mean we need to do a bit more bookkeeping on the Python side to
# track abstract values of outputs.
def xla_shape_to_result_shape(xla_shape):
if xla_shape.is_tuple():
aval = _aval_from_xla_shape(xla_shape)
result_shapes = tuple(map(xla_shape_to_result_shape, xla_shape.tuple_shapes()))
return _ResultTuple((aval, result_shapes))
else:
shape, dtype = xla_shape.dimensions(), xla_shape.element_type()
return _ResultArray((shape, dtype))
class _ResultTuple(tuple): pass
class _ResultArray(tuple): pass
def result_handler(result_shape):
t = type(result_shape)
if t is _ResultArray:
return partial(DeviceArray, result_shape)
elif t is _ResultTuple:
return partial(DeviceTuple, result_shape)
def result_handler(aval):
if isinstance(aval, core.AbstractTuple):
return partial(DeviceTuple, aval)
else:
raise TypeError(t)
return partial(DeviceArray, aval)
def _compile_jaxpr(jaxpr, device_assignment, axis_env, const_vals, *abstract_args):
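Since result_handler (in the hunk above) now receives an abstract value directly, the _ResultTuple/_ResultArray wrapper tuples are no longer needed: the handler dispatches on the aval's type and forwards the aval into the DeviceTuple or DeviceArray it constructs. A self-contained sketch of that dispatch, using stand-in classes rather than the real core.AbstractTuple / xla.DeviceTuple / xla.DeviceArray:

```python
from functools import partial

class AbstractTuple(tuple):
    """Stand-in for core.AbstractTuple."""

class ShapedArray:
    def __init__(self, shape, dtype):
        self.shape = tuple(shape)
        self.dtype = dtype

class DeviceArray:
    def __init__(self, aval, device_buffer):
        self.aval = aval
        self.device_buffer = device_buffer

class DeviceTuple:
    def __init__(self, aval, device_buffer):
        self.aval = aval
        self.device_buffer = device_buffer

def result_handler(aval):
    # Tuple-shaped results become DeviceTuples, everything else a DeviceArray;
    # the aval is forwarded so the wrapped result carries it too.
    if isinstance(aval, AbstractTuple):
        return partial(DeviceTuple, aval)
    return partial(DeviceArray, aval)

handler = result_handler(ShapedArray((4,), "float32"))
out = handler("raw-device-buffer")   # stands in for a raw device buffer object
assert out.aval.shape == (4,)
```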
@@ -208,7 +194,7 @@ def _compile_jaxpr(jaxpr, device_assignment, axis_env, const_vals, *abstract_args):
raise ValueError(msg.format(axis_env.nreps, xb.device_count()))
arg_shapes = list(map(xla_shape, abstract_args))
built_c = _jaxpr_computation(jaxpr, axis_env, const_vals, (), *arg_shapes)
result_shape = xla_shape_to_result_shape(built_c.GetReturnValueShape())
result_shape = aval_from_xla_shape(built_c.GetReturnValueShape())
compile_opts = xb.get_compile_options(num_replicas=axis_env.nreps,
device_assignment=device_assignment)
compiled_c = built_c.Compile(arg_shapes, compile_opts, backend=xb.get_backend())
@@ -364,7 +350,7 @@ def lower_fun(fun, instantiate=False, initial_style=False):
else:
axis_env, xla_args = AxisEnv(1, [], []), args
xla_shapes = tuple(map(c.GetShape, xla_args))
avals = map(_aval_from_xla_shape, xla_shapes)
avals = map(aval_from_xla_shape, xla_shapes)
pvals = [pe.PartialVal((a, core.unit)) for a in avals]
jaxpr, _, consts = pe.trace_unwrapped_to_jaxpr(fun, pvals, instantiate,
**params)
@@ -448,8 +434,10 @@ for _t in array_types:
class DeviceValue(object):
"""A DeviceValue represents a value backed by device memory."""
__slots__ = ["device_buffer"]
def __init__(self, device_buffer):
__slots__ = ["aval", "device_buffer"]
def __init__(self, aval, device_buffer):
self.aval = aval
self.device_buffer = device_buffer
def _check_if_deleted(self):
@@ -469,15 +457,15 @@ class DeviceValue(object):
class DeviceTuple(DeviceValue):
"""A DeviceTuple is a JaxTuple backed by a single device memory buffer."""
__slots__ = ["aval", "result_shapes"]
__slots__ = []
def __init__(self, result_shape, device_buffer):
def __init__(self, aval, device_buffer):
self.aval = aval
self.device_buffer = device_buffer
self.aval, self.result_shapes = result_shape
def __iter__(self):
bufs = self.device_buffer.destructure()
handlers = map(result_handler, self.result_shapes)
handlers = map(result_handler, self.aval)
elts = [handler(buf) for handler, buf in zip(handlers, bufs)]
return iter(elts)
@@ -514,12 +502,12 @@ class DeviceArray(DeviceValue):
"""A DeviceArray is an ndarray backed by a single device memory buffer."""
# We don't subclass ndarray because that would open up a host of issues,
# but lax_numpy.py overrides isinstance behavior and attaches ndarray methods.
__slots__ = ["shape", "dtype", "_npy_value"]
__array_priority__ = 100.
__slots__ = ["_npy_value"]
__array_priority__ = 100
def __init__(self, result_shape, device_buffer):
def __init__(self, aval, device_buffer):
self.aval = aval
self.device_buffer = device_buffer
self.shape, self.dtype = result_shape
self._npy_value = None
@property
@@ -530,13 +518,21 @@ class DeviceArray(DeviceValue):
self._npy_value.flags.writeable = False
return self._npy_value
@property
def shape(self):
return self.aval.shape
@property
def dtype(self):
return self.aval.dtype
@property
def size(self):
return prod(self.shape)
return prod(self.aval.shape)
@property
def ndim(self):
return len(self.shape)
return len(self.aval.shape)
def copy(self):
"""Returns an ndarray (backed by host memory, not device memory)."""
@@ -580,7 +576,7 @@ class DeviceArray(DeviceValue):
def __len__(self):
try:
return self.shape[0]
return self.aval.shape[0]
except IndexError:
raise TypeError("len() of unsized object") # same as numpy error
@@ -632,7 +628,7 @@ core.literalable_types.add(DeviceArray)
# DeviceValues don't need to be canonicalized because we assume values on the
# device have already been canonicalized.
core.pytype_aval_mappings[DeviceArray] = ConcreteArray
pytype_aval_mappings[DeviceArray] = make_shaped_array
pytype_aval_mappings[DeviceArray] = lambda x: x.aval
canonicalize_dtype_handlers[DeviceArray] = _identity
def _device_array_constant_handler(c, val, canonicalize_types=True):
@@ -728,8 +724,7 @@ def _device_put_impl(x, device_num=0):
raise TypeError("Argument '{}' of type {} is not a valid JAX type"
.format(x, type(x)))
result_shape = xla_shape_to_result_shape(xla_shape(a))
handler = result_handler(result_shape)
handler = result_handler(a)
return handler(device_put(x, device_num))
device_put_p = core.Primitive('device_put')

jax/numpy/lax_numpy.py

@@ -3924,8 +3924,7 @@ class _FilledConstant(xla.DeviceConstant):
def __init__(self, fill_value, shape):
assert type(fill_value) is onp.ndarray
self.shape = shape
self.dtype = _dtype(fill_value)
self.aval = ShapedArray(shape, _dtype(fill_value))
self._npy_value = None
self.fill_value = fill_value
@@ -3945,8 +3944,7 @@ class _IotaConstant(xla.DeviceConstant):
__slots__ = ["axis"]
def __init__(self, dtype, shape, axis):
self.shape = shape
self.dtype = onp.dtype(dtype)
self.aval = ShapedArray(shape, onp.dtype(dtype))
self._npy_value = None
self.axis = axis
@@ -3972,8 +3970,7 @@ class _EyeConstant(xla.DeviceConstant):
__slots__ = ["axes"]
def __init__(self, shape, axes, dtype):
self.shape = shape
self.dtype = onp.dtype(dtype)
self.aval = ShapedArray(shape, onp.dtype(dtype))
self._npy_value = None
self.axes = axes
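The lax_numpy lazy constants above follow the same pattern: a single ShapedArray aval replaces the separate shape/dtype attributes, so they keep matching the DeviceArray interface where shape and dtype are derived from aval. A minimal sketch with hypothetical stand-in classes, modeled loosely on _IotaConstant (the real class also broadcasts the iota along the other dimensions, which is omitted here):

```python
import numpy as onp

class ShapedArray:                      # stand-in for JAX's ShapedArray
    def __init__(self, shape, dtype):
        self.shape = tuple(shape)
        self.dtype = onp.dtype(dtype)

class IotaConstantSketch:
    """Sketch of a lazy device constant that materializes only on demand."""
    def __init__(self, dtype, shape, axis):
        # One aval replaces the old separate self.shape / self.dtype fields.
        self.aval = ShapedArray(shape, dtype)
        self._npy_value = None
        self.axis = axis

    @property
    def shape(self):
        return self.aval.shape          # derived, as in the new DeviceArray

    @property
    def _value(self):
        if self._npy_value is None:     # materialize lazily on the host
            size = self.aval.shape[self.axis]
            self._npy_value = onp.arange(size, dtype=self.aval.dtype)
        return self._npy_value

c = IotaConstantSketch("int32", (3, 4), axis=1)
assert c.shape == (3, 4) and c._value.tolist() == [0, 1, 2, 3]
```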