mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Keep ShapedArray avals on xla.DeviceArray values
This makes abstractification of DeviceArray values cheaper; that abstractification step is on the critical path for executing a compiled function.
This commit is contained in:
parent
4474b11181
commit
3e78a0e290
@ -31,7 +31,7 @@ from .. import linear_util as lu
|
||||
from ..abstract_arrays import ConcreteArray, ShapedArray
|
||||
from ..util import partial, unzip2, concatenate, prod
|
||||
from ..lib import xla_bridge as xb
|
||||
from .xla import xla_shape, xla_destructure, xla_shape_to_result_shape
|
||||
from .xla import xla_shape, xla_destructure, aval_from_xla_shape
|
||||
from .partial_eval import trace_to_subjaxpr, merge_pvals, JaxprTrace, PartialVal
|
||||
from .batching import dimsize, broadcast
|
||||
from . import batching
|
||||
@ -240,7 +240,7 @@ def compile_replicated(jaxpr, axis_name, axis_size, consts, *abstract_args):
|
||||
axis_env = xla.AxisEnv(num_replicas, [axis_name], [axis_size])
|
||||
arg_shapes = list(map(xla_shape, abstract_args))
|
||||
built_c = xla._jaxpr_computation(jaxpr, axis_env, consts, (), *arg_shapes)
|
||||
result_shape = xla_shape_to_result_shape(built_c.GetReturnValueShape())
|
||||
result_shape = aval_from_xla_shape(built_c.GetReturnValueShape())
|
||||
compiled = built_c.Compile(arg_shapes, xb.get_compile_options(num_replicas),
|
||||
backend=xb.get_backend())
|
||||
return compiled, num_replicas, result_shape
|
||||
@ -435,7 +435,7 @@ class ShardedDeviceArray(ShardedDeviceValue, xla.DeviceArray):
|
||||
represent distinct logical shards. The correspondence can be computed with
|
||||
the assign_shards_to_replicas function.
|
||||
"""
|
||||
__slots__ = ["device_buffers", "axis_size", "aval"]
|
||||
__slots__ = ["device_buffers", "axis_size"]
|
||||
_collect = staticmethod(onp.stack)
|
||||
|
||||
def __init__(self, aval, device_buffers):
|
||||
@ -443,8 +443,6 @@ class ShardedDeviceArray(ShardedDeviceValue, xla.DeviceArray):
|
||||
# return it unmodified.
|
||||
self.aval = aval
|
||||
self.device_buffers = device_buffers
|
||||
self.shape = aval.shape
|
||||
self.dtype = aval.dtype
|
||||
self.axis_size = aval.shape[0]
|
||||
self._npy_value = None
|
||||
|
||||
@ -477,7 +475,7 @@ class ShardedDeviceArray(ShardedDeviceValue, xla.DeviceArray):
|
||||
if self._npy_value is None and type(idx) is int:
|
||||
ids = self._ids()
|
||||
device_buffer = self.device_buffers[ids[idx]]
|
||||
result_shape = xla_shape_to_result_shape(device_buffer.shape())
|
||||
result_shape = aval_from_xla_shape(device_buffer.shape())
|
||||
handler = xla.result_handler(result_shape)
|
||||
return handler(device_buffer)
|
||||
else:
|
||||
|
@ -55,7 +55,7 @@ def apply_primitive(prim, *args, **params):
|
||||
def _xla_primitive_callable(prim, *abstract_args, **params):
|
||||
shapes = tuple(map(xla_shape, abstract_args))
|
||||
built_c = primitive_computation(prim, *shapes, **params)
|
||||
result_shape = xla_shape_to_result_shape(built_c.GetReturnValueShape())
|
||||
result_shape = aval_from_xla_shape(built_c.GetReturnValueShape())
|
||||
handle_result = result_handler(result_shape)
|
||||
compiled = built_c.Compile(shapes, xb.get_compile_options(),
|
||||
backend=xb.get_backend())
|
||||
@ -91,12 +91,12 @@ def primitive_computation(prim, *shapes, **params):
|
||||
return c.Build()
|
||||
except RuntimeError as e:
|
||||
# try for a better error message by using the abstract_eval checks
|
||||
prim.abstract_eval(*map(_aval_from_xla_shape, shapes), **params)
|
||||
prim.abstract_eval(*map(aval_from_xla_shape, shapes), **params)
|
||||
raise e
|
||||
|
||||
def _aval_from_xla_shape(shape):
|
||||
def aval_from_xla_shape(shape):
|
||||
if shape.is_tuple():
|
||||
return AbstractTuple(map(_aval_from_xla_shape, shape.tuple_shapes()))
|
||||
return AbstractTuple(map(aval_from_xla_shape, shape.tuple_shapes()))
|
||||
else:
|
||||
return ShapedArray(shape.dimensions(), shape.element_type())
|
||||
|
||||
@ -179,26 +179,12 @@ def device_put(x, device_num=0):
|
||||
# JaxType, i.e. that the mapping is bijective. That assumption could be relaxed,
|
||||
# but it would mean we need to do a bit more bookkeeping on the Python side to
|
||||
# track abstract values of outputs.
|
||||
def xla_shape_to_result_shape(xla_shape):
|
||||
if xla_shape.is_tuple():
|
||||
aval = _aval_from_xla_shape(xla_shape)
|
||||
result_shapes = tuple(map(xla_shape_to_result_shape, xla_shape.tuple_shapes()))
|
||||
return _ResultTuple((aval, result_shapes))
|
||||
else:
|
||||
shape, dtype = xla_shape.dimensions(), xla_shape.element_type()
|
||||
return _ResultArray((shape, dtype))
|
||||
|
||||
class _ResultTuple(tuple): pass
|
||||
class _ResultArray(tuple): pass
|
||||
|
||||
def result_handler(result_shape):
|
||||
t = type(result_shape)
|
||||
if t is _ResultArray:
|
||||
return partial(DeviceArray, result_shape)
|
||||
elif t is _ResultTuple:
|
||||
return partial(DeviceTuple, result_shape)
|
||||
def result_handler(aval):
|
||||
if isinstance(aval, core.AbstractTuple):
|
||||
return partial(DeviceTuple, aval)
|
||||
else:
|
||||
raise TypeError(t)
|
||||
return partial(DeviceArray, aval)
|
||||
|
||||
|
||||
def _compile_jaxpr(jaxpr, device_assignment, axis_env, const_vals, *abstract_args):
|
||||
@ -208,7 +194,7 @@ def _compile_jaxpr(jaxpr, device_assignment, axis_env, const_vals, *abstract_arg
|
||||
raise ValueError(msg.format(axis_env.nreps, xb.device_count()))
|
||||
arg_shapes = list(map(xla_shape, abstract_args))
|
||||
built_c = _jaxpr_computation(jaxpr, axis_env, const_vals, (), *arg_shapes)
|
||||
result_shape = xla_shape_to_result_shape(built_c.GetReturnValueShape())
|
||||
result_shape = aval_from_xla_shape(built_c.GetReturnValueShape())
|
||||
compile_opts = xb.get_compile_options(num_replicas=axis_env.nreps,
|
||||
device_assignment=device_assignment)
|
||||
compiled_c = built_c.Compile(arg_shapes, compile_opts, backend=xb.get_backend())
|
||||
@ -364,7 +350,7 @@ def lower_fun(fun, instantiate=False, initial_style=False):
|
||||
else:
|
||||
axis_env, xla_args = AxisEnv(1, [], []), args
|
||||
xla_shapes = tuple(map(c.GetShape, xla_args))
|
||||
avals = map(_aval_from_xla_shape, xla_shapes)
|
||||
avals = map(aval_from_xla_shape, xla_shapes)
|
||||
pvals = [pe.PartialVal((a, core.unit)) for a in avals]
|
||||
jaxpr, _, consts = pe.trace_unwrapped_to_jaxpr(fun, pvals, instantiate,
|
||||
**params)
|
||||
@ -448,8 +434,10 @@ for _t in array_types:
|
||||
|
||||
class DeviceValue(object):
|
||||
"""A DeviceValue represents a value backed by device memory."""
|
||||
__slots__ = ["device_buffer"]
|
||||
def __init__(self, device_buffer):
|
||||
__slots__ = ["aval", "device_buffer"]
|
||||
|
||||
def __init__(self, aval, device_buffer):
|
||||
self.aval = aval
|
||||
self.device_buffer = device_buffer
|
||||
|
||||
def _check_if_deleted(self):
|
||||
@ -469,15 +457,15 @@ class DeviceValue(object):
|
||||
|
||||
class DeviceTuple(DeviceValue):
|
||||
"""A DeviceTuple is a JaxTuple backed by a single device memory buffer."""
|
||||
__slots__ = ["aval", "result_shapes"]
|
||||
__slots__ = []
|
||||
|
||||
def __init__(self, result_shape, device_buffer):
|
||||
def __init__(self, aval, device_buffer):
|
||||
self.aval = aval
|
||||
self.device_buffer = device_buffer
|
||||
self.aval, self.result_shapes = result_shape
|
||||
|
||||
def __iter__(self):
|
||||
bufs = self.device_buffer.destructure()
|
||||
handlers = map(result_handler, self.result_shapes)
|
||||
handlers = map(result_handler, self.aval)
|
||||
elts = [handler(buf) for handler, buf in zip(handlers, bufs)]
|
||||
return iter(elts)
|
||||
|
||||
@ -514,12 +502,12 @@ class DeviceArray(DeviceValue):
|
||||
"""A DeviceArray is an ndarray backed by a single device memory buffer."""
|
||||
# We don't subclass ndarray because that would open up a host of issues,
|
||||
# but lax_numpy.py overrides isinstance behavior and attaches ndarray methods.
|
||||
__slots__ = ["shape", "dtype", "_npy_value"]
|
||||
__array_priority__ = 100.
|
||||
__slots__ = ["_npy_value"]
|
||||
__array_priority__ = 100
|
||||
|
||||
def __init__(self, result_shape, device_buffer):
|
||||
def __init__(self, aval, device_buffer):
|
||||
self.aval = aval
|
||||
self.device_buffer = device_buffer
|
||||
self.shape, self.dtype = result_shape
|
||||
self._npy_value = None
|
||||
|
||||
@property
|
||||
@ -530,13 +518,21 @@ class DeviceArray(DeviceValue):
|
||||
self._npy_value.flags.writeable = False
|
||||
return self._npy_value
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.aval.shape
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.aval.dtype
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
return prod(self.shape)
|
||||
return prod(self.aval.shape)
|
||||
|
||||
@property
|
||||
def ndim(self):
|
||||
return len(self.shape)
|
||||
return len(self.aval.shape)
|
||||
|
||||
def copy(self):
|
||||
"""Returns an ndarray (backed by host memory, not device memory)."""
|
||||
@ -580,7 +576,7 @@ class DeviceArray(DeviceValue):
|
||||
|
||||
def __len__(self):
|
||||
try:
|
||||
return self.shape[0]
|
||||
return self.aval.shape[0]
|
||||
except IndexError:
|
||||
raise TypeError("len() of unsized object") # same as numpy error
|
||||
|
||||
@ -632,7 +628,7 @@ core.literalable_types.add(DeviceArray)
|
||||
# DeviceValues don't need to be canonicalized because we assume values on the
|
||||
# device have already been canonicalized.
|
||||
core.pytype_aval_mappings[DeviceArray] = ConcreteArray
|
||||
pytype_aval_mappings[DeviceArray] = make_shaped_array
|
||||
pytype_aval_mappings[DeviceArray] = lambda x: x.aval
|
||||
canonicalize_dtype_handlers[DeviceArray] = _identity
|
||||
|
||||
def _device_array_constant_handler(c, val, canonicalize_types=True):
|
||||
@ -728,8 +724,7 @@ def _device_put_impl(x, device_num=0):
|
||||
raise TypeError("Argument '{}' of type {} is not a valid JAX type"
|
||||
.format(x, type(x)))
|
||||
|
||||
result_shape = xla_shape_to_result_shape(xla_shape(a))
|
||||
handler = result_handler(result_shape)
|
||||
handler = result_handler(a)
|
||||
return handler(device_put(x, device_num))
|
||||
|
||||
device_put_p = core.Primitive('device_put')
|
||||
|
@ -3924,8 +3924,7 @@ class _FilledConstant(xla.DeviceConstant):
|
||||
|
||||
def __init__(self, fill_value, shape):
|
||||
assert type(fill_value) is onp.ndarray
|
||||
self.shape = shape
|
||||
self.dtype = _dtype(fill_value)
|
||||
self.aval = ShapedArray(shape, _dtype(fill_value))
|
||||
self._npy_value = None
|
||||
|
||||
self.fill_value = fill_value
|
||||
@ -3945,8 +3944,7 @@ class _IotaConstant(xla.DeviceConstant):
|
||||
__slots__ = ["axis"]
|
||||
|
||||
def __init__(self, dtype, shape, axis):
|
||||
self.shape = shape
|
||||
self.dtype = onp.dtype(dtype)
|
||||
self.aval = ShapedArray(shape, onp.dtype(dtype))
|
||||
self._npy_value = None
|
||||
|
||||
self.axis = axis
|
||||
@ -3972,8 +3970,7 @@ class _EyeConstant(xla.DeviceConstant):
|
||||
__slots__ = ["axes"]
|
||||
|
||||
def __init__(self, shape, axes, dtype):
|
||||
self.shape = shape
|
||||
self.dtype = onp.dtype(dtype)
|
||||
self.aval = ShapedArray(shape, onp.dtype(dtype))
|
||||
self._npy_value = None
|
||||
|
||||
self.axes = axes
|
||||
|
Loading…
x
Reference in New Issue
Block a user