add tie_in and full primitives (constant creation)

Matthew Johnson 2018-12-13 07:24:14 -08:00
parent 25cf9358d1
commit f971415218
5 changed files with 144 additions and 26 deletions
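
In short, the commit introduces two traceables: lax.full, which creates constant-filled arrays through a first-class primitive backed by lazy DeviceConstant values, and lax.tie_in, which makes a value depend on a traced argument so that constants are staged into the compiled computation. A minimal usage sketch, assuming the lax and jit APIs as of this commit:

    import numpy as onp
    from jax import jit, lax

    @jit
    def f(x):
      # full builds the constant lazily instead of materializing it with numpy
      ones = lax.full(x.shape, 1., onp.float32)
      # tie_in makes the constant depend on x, so it is created on-device as
      # part of the XLA computation rather than transferred from the host
      return x + lax.tie_in(x, ones)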

View File

@@ -249,6 +249,8 @@ class Tracer(object):
  def __hex__(self): return self.aval._hex(self)
  def __oct__(self): return self.aval._oct(self)

  def __setitem__(self, idx, val):
    raise TypeError("JAX 'Tracer' objects do not support item assignment")

  def __getattr__(self, name):
    # if the aval property raises an AttributeError, it is caught here

View File

@@ -74,6 +74,8 @@ def execute_compiled_primitive(compiled, result_handler, *args):

def device_put(x):
  if type(x) is DeviceArray:
    return x.device_buffer
  elif isinstance(x, DeviceConstant):
    return instantiate_device_constant(x)
  else:
    return xb.device_put(x)  # can round-trip elements of tuples here
@@ -237,6 +239,7 @@ class DeviceArray(DeviceValue):
    self.size = size
    self._npy_value = None

  # TODO make device_buffer a property, make the _npy_value writeable, invalidate
  @property
  def _value(self):
    if self._npy_value is None:
@@ -319,6 +322,31 @@ xb.register_constant_handler(DeviceArray,
                             lambda c, val: c.Constant(onp.asarray(val)))

class DeviceConstant(DeviceArray):
  @staticmethod
  def constant_handler(c, constant_instance):
    assert False  # subclasses must override this handler

# TODO(mattjj): tune cutoff
def instantiate_device_constant(const, cutoff=1000000):
  # dispatch an XLA Computation to build the constant on the device if it's
  # large, or alternatively build it on the host and transfer it if it's small
  assert isinstance(const, DeviceConstant)
  if const.size > cutoff:
    c = xb.make_computation_builder("constant_instantiating_computation")
    xla_const = const.constant_handler(c, const)
    compiled = c.Build(xla_const).Compile((), xb.get_compile_options())
    return compiled.Execute(())
  else:
    return xb.device_put(onp.asarray(const))

def register_device_constant(cls):
  pytype_aval_mappings[cls] = pytype_aval_mappings[DeviceArray]
  canonicalize_dtype_handlers[cls] = identity
  core.pytype_aval_mappings[cls] = ConcreteArray
  xb.register_constant_handler(cls, cls.constant_handler)

def xla_shape(x):
  try:
    return xb.Shape.array_shape(x.dtype, x.shape)
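
To see the two dispatch paths concretely, here is a hedged sketch using the FilledConstant subclass defined later in this commit; the names and sizes are chosen for illustration around the default cutoff of 1e6 elements:

    import numpy as onp

    small = FilledConstant(onp.array(0., onp.float32), (1000,))      # 1e3 elements
    big = FilledConstant(onp.array(0., onp.float32), (2000, 1000))   # 2e6 elements

    instantiate_device_constant(small)  # below cutoff: onp.asarray + device_put
    instantiate_device_constant(big)    # above cutoff: compile and run a Broadcast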

View File

@@ -17,7 +17,7 @@ from __future__ import division
from __future__ import print_function

import collections
-from .util import partial
+from .util import partial, prod
import itertools
import operator
import six
@@ -43,6 +43,9 @@ from .lib import xla_bridge
_max = builtins.max
_min = builtins.min
_reduce = six.moves.reduce

def identity(x): return x

### traceables
@@ -411,6 +414,25 @@ class OpaqueParam(object):
opaque_param_ids = itertools.count()

def tie_in(x, y):
  return tie_in_p.bind(x, y)

def full(shape, fill_value, dtype=None):
  if onp.shape(fill_value):
    msg = "full must be called with scalar fill_value, got fill_value.shape {}."
    raise TypeError(msg.format(onp.shape(fill_value)))
  dtype = dtype and xla_bridge.canonicalize_dtype(dtype)
  if dtype is not None and _dtype(fill_value) != dtype:
    # for Python scalars and raw ndarrays, we keep fill_value as a cpu ndarray
    if onp.isscalar(fill_value) or type(fill_value) is onp.ndarray:
      fill_value = onp.array(fill_value, dtype)
    else:
      fill_value = convert_element_type(fill_value, dtype)
  return full_p.bind(fill_value, shape=shape)
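
Note the "dtype and canonicalize_dtype(dtype)" idiom, which leaves dtype as None when it is unspecified. A few illustrative calls, as a sketch of which branch each takes (canonicalization may map float64 down to float32 when 64-bit mode is off):

    import numpy as onp

    full((2, 2), onp.array(1., onp.float32))  # ndarray, no dtype: passes through
    full((2, 2), 2, onp.float32)              # Python scalar: onp.array(2, float32)
    full((2, 2), onp.ones(3))                 # non-scalar fill_value: raises TypeError
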
### convenience wrappers around traceables
@@ -439,7 +461,8 @@ def full_like(x, fill_value, dtype=None, shape=None):
    `fill_value`, similar to the output of np.full.
  """
  shape = onp.shape(x) if shape is None else shape
-  return broadcast(onp.array(fill_value, dtype or _dtype(x)), shape)
+  out = full(shape, fill_value, dtype or _dtype(x))
+  return tie_in(x, out)
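
The tie_in call is the behavioral change here: the constant now depends on x, so under tracing it is staged into the computation alongside x instead of being materialized as a host-side numpy array. A rough sketch of the effect, assuming jit:

    from jax import jit

    @jit
    def f(x):
      return x - full_like(x, 1.)   # the ones are created on-device, tied to x
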
def collapse(operand, start_dimension, stop_dimension):
@@ -631,8 +654,7 @@ ShapedArray._iter = staticmethod(_iter)

# Add some ad handlers that use (or could use) lax primitives

def zeros_like_array(x):
-  dtype = xla_bridge.canonicalize_dtype(_dtype(x))
-  return onp.broadcast_to(onp.zeros((), dtype), onp.shape(x))
+  return full_like(x, 0)

for t in itertools.chain(array_types, [xla.DeviceArray]):
  ad_util.jaxval_adders[t] = add
@@ -648,8 +670,6 @@ _input_dtype = lambda *args, **_: xla_bridge.canonicalize_dtype(args[0].dtype)
_fixed_dtype = lambda dtype: lambda *args, **kwargs: xla_bridge.canonicalize_dtype(dtype)
_complex_basetype = lambda dtype: onp.abs(onp.zeros((), dtype)).dtype

-def identity(x): return x

def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
  prim = Primitive(name)
@@ -1561,9 +1581,11 @@ def select_dtype_rule(pred, on_true, on_false):
  return on_true.dtype

def select_transpose_rule(t, pred, on_true, on_false):
  assert pred is not None
+  zeros = full_like(t, 0)
  return [None,
-          select(pred, t, _zeros(on_false)) if on_true is None else None,
-          select(pred, _zeros(on_true), t) if on_false is None else None]
+          select(pred, t, zeros) if on_true is None else None,
+          select(pred, zeros, t) if on_false is None else None]

def select_batch_rule(batched_args, batch_dims, **unused_kwargs):
  operand, on_true, on_false, = batched_args
@@ -2218,6 +2240,71 @@ while_p.def_abstract_eval(while_loop_abstract_eval)
xla.translations[while_p] = while_loop_translation_rule

### primitives for handling constants

def tie_in_transpose_rule(t):
  return [ad_util.zero, t]

def tie_in_batch_rule(batched_args, batch_dims):
  y = tie_in(*batched_args)
  _, bdim_y = batch_dims
  return y, bdim_y

tie_in_p = Primitive('tie_in')
tie_in_p.def_impl(lambda x, y: y)
tie_in_p.def_abstract_eval(lambda x, y: y)
xla.translations[tie_in_p] = lambda c, x, y: y
ad.deflinear(tie_in_p, tie_in_transpose_rule)
batching.primitive_batchers[tie_in_p] = tie_in_batch_rule
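
These rules say that tie_in is the identity in y and contributes a zero cotangent to x. A small sketch of that behavior, assuming the jax.grad API of the time:

    import jax
    from jax import lax

    f = lambda x, y: lax.tie_in(x, y)
    jax.grad(f, argnums=1)(1., 2.)   # 1.0: tie_in is the identity in y
    jax.grad(f, argnums=0)(1., 2.)   # 0.0: x only supplies the data dependence
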
class FilledConstant(xla.DeviceConstant):
  __slots__ = ["fill_value"]

  def __init__(self, fill_value, shape):
    assert type(fill_value) is onp.ndarray
    self.shape = shape
    self.dtype = _dtype(fill_value)
    self.ndim = len(shape)
    self.size = prod(shape)
    self._npy_value = None
    self.fill_value = fill_value

  @property
  def _value(self):
    return onp.full(self.shape, self.fill_value)

  @staticmethod
  def constant_handler(c, filled_const):
    return c.Broadcast(c.NumpyArrayConstant(filled_const.fill_value),
                       filled_const.shape)

xla.register_device_constant(FilledConstant)
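
FilledConstant carries only its scalar fill_value; the dense array exists just as shape and dtype metadata until something forces it. For example, as a sketch:

    c = FilledConstant(onp.array(3., onp.float32), (2, 2))
    c.shape, c.dtype, c.size   # (2, 2), float32, 4: no buffer materialized
    c._value                   # forces it: builds onp.full((2, 2), 3.) on demand
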
# TODO(mattjj): if we used isinstance rather than handlers here, these would
# all be covered as subclasses of DeviceArray. alternatively, just set these up
# in a loop after we've defined all the DeviceConstant subclasses.
batching.pytype_aval_mappings[FilledConstant] = make_shaped_array
ad_util.jaxval_adders[FilledConstant] = add
ad_util.jaxval_zeros_likers[FilledConstant] = zeros_like_array
def full_batch_rule(batched_args, batch_dims, shape):
  fill_value, = batched_args
  bdim, = batch_dims
  assert bdim == 0
  # batch rules return an (output, output_bdim) pair, as in tie_in_batch_rule
  return broadcast_in_dim(fill_value, fill_value.shape + shape, [bdim]), bdim

full_p = Primitive('full')
full_p.def_impl(FilledConstant)
full_p.def_abstract_eval(
    lambda fill_value, shape: ShapedArray(shape, _dtype(fill_value)))
xla.translations[full_p] = \
    lambda c, fill_value, shape: c.Broadcast(fill_value, shape)
ad.deflinear(full_p, lambda t, shape: [_reduce_sum(t, tuple(range(len(shape))))])
batching.primitive_batchers[full_p] = full_batch_rule
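
The transpose rule reflects that full broadcasts a scalar: the cotangent of fill_value is the output cotangent summed over every position. A worked sketch, assuming the jax.numpy wrappers updated in the last file of this diff:

    import jax
    import jax.numpy as np

    g = jax.grad(lambda v: np.sum(np.full((2, 3), v)))
    g(1.0)   # 6.0: the output cotangent (all ones) summed over 2*3 entries
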
### util
def _ndim(x):
@@ -2350,13 +2437,19 @@ def _dynamic_slice_indices(operand, start_indices):
  return rem(start_indices, onp.array(operand.shape, start_indices.dtype))

-_const = lambda example, val: onp.array(val, _dtype(example))
-_zeros = partial(full_like, fill_value=0)
-_zero = partial(full_like, shape=(), fill_value=0)
-_ones = partial(full_like, fill_value=1)
-_one = partial(full_like, shape=(), fill_value=1)
-_twos = partial(full_like, fill_value=2)
-_two = partial(full_like, shape=(), fill_value=2)
+def _ndarray_full_like(x, fill_value, dtype=None, shape=None):
+  return onp.broadcast_to(onp.array(fill_value, dtype or _dtype(x)),
+                          onp.shape(x) if shape is None else shape)
+
+def _const(example, val):
+  return onp.array(val, _dtype(example))
+
+_zeros = partial(_ndarray_full_like, fill_value=0)
+_zero = partial(_ndarray_full_like, shape=(), fill_value=0)
+_ones = partial(_ndarray_full_like, fill_value=1)
+_one = partial(_ndarray_full_like, shape=(), fill_value=1)
+_twos = partial(_ndarray_full_like, fill_value=2)
+_two = partial(_ndarray_full_like, shape=(), fill_value=2)

_dtype = onp.result_type
_iscomplex = lambda x: onp.issubdtype(_dtype(x), onp.complexfloating)
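
Since full_like itself now emits tie_in/full primitives, these internal helpers switch to the plain-numpy _ndarray_full_like, presumably so that rule implementations keep getting concrete host arrays. A quick sketch of what they produce:

    import numpy as onp

    _zeros(onp.arange(3.))   # array([0., 0., 0.]): a raw numpy array
    _two(onp.arange(3.))     # array(2.): 0-d, via the shape=() partial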

View File

@@ -355,6 +355,7 @@ def _ndarray_constant_handler(c, val):
    An XLA ComputationDataHandle / XlaOp representing the constant ndarray
    staged into the XLA Computation.
  """
  # TODO(mattjj): revise this to use c.BroadcastInDim rather than Transpose
  if onp.any(onp.equal(0, val.strides)) and val.size > 0:
    zero_stride_axes, = onp.where(onp.equal(0, val.strides))
    other_axes, = onp.where(onp.not_equal(0, val.strides))
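
The stride check catches ndarrays that are views of broadcasts, so the handler can stage only the underlying data and re-broadcast on the device. A quick illustration of what it detects:

    import numpy as onp

    val = onp.broadcast_to(onp.arange(3.), (4, 3))   # a view, no copy made
    val.strides                                      # (0, 8): axis 0 is broadcast
    onp.where(onp.equal(0, val.strides))             # (array([0]),)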

View File

@@ -794,33 +794,27 @@ asarray = array

@_wraps(onp.zeros_like)
def zeros_like(x, dtype=None):
-  return zeros(_shape(x), dtype or _dtype(x))
+  return lax.full_like(x, 0, dtype)

@_wraps(onp.ones_like)
def ones_like(x, dtype=None):
-  return ones(_shape(x), dtype or _dtype(x))
+  return lax.full_like(x, 1, dtype)

-@_wraps(onp.full)
-def full(shape, fill_value, dtype=None):
-  if dtype:
-    fill_value = lax.convert_element_type(fill_value, dtype)
-  return lax.broadcast(fill_value, tuple(shape))
+full = _wraps(onp.full)(lax.full)

@_wraps(onp.zeros)
def zeros(shape, dtype=onp.dtype("float64")):
  shape = (shape,) if onp.isscalar(shape) else shape
  dtype = xla_bridge.canonicalize_dtype(dtype)
-  return onp.broadcast_to(onp.zeros((), dtype), tuple(shape))
+  return lax.full(shape, 0, dtype)

@_wraps(onp.ones)
def ones(shape, dtype=onp.dtype("float64")):
  shape = (shape,) if onp.isscalar(shape) else shape
  dtype = xla_bridge.canonicalize_dtype(dtype)
-  return onp.broadcast_to(onp.ones((), dtype), tuple(shape))
+  return lax.full(shape, 1, dtype)
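
With these changes the numpy-style constructors return lazy device constants rather than host numpy arrays. A usage sketch, assuming the import conventions of the time:

    import numpy as onp
    import jax.numpy as np

    np.zeros((2, 3))                # lax.full: a lazy FilledConstant
    np.ones_like(onp.eye(2))        # lax.full_like (tie_in is an identity here)
    np.full((3,), 7, onp.float32)   # onp.full's docstring wrapped onto lax.full
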
@_wraps(onp.repeat)