add IotaConstant (untested)

This commit is contained in:
Matthew Johnson 2018-12-13 11:12:11 -08:00
parent f971415218
commit dfc25a06d9

View File

@ -432,6 +432,12 @@ def full(shape, fill_value, dtype=None):
return full_p.bind(fill_value, shape=shape)
def iota(dtype, shape, dimension):
dtype = xla_bridge.canonicalize_dtype(dtype)
shape = tuple(map(int(shape)))
dimension = int(dimension)
return IotaConstant(dtype, shape, dimension)
### convenience wrappers around traceables
@ -2282,13 +2288,6 @@ class FilledConstant(xla.DeviceConstant):
filled_const.shape)
xla.register_device_constant(FilledConstant)
# TODO(mattjj): if we used isinstance rather than handlers here, these would all
# be covered as subclasses of DeviceArray. alternatively, just set up these in a
# loop after we've defined all the constant DeviceConstant subclasses.
batching.pytype_aval_mappings[FilledConstant] = make_shaped_array
ad_util.jaxval_adders[FilledConstant] = add
ad_util.jaxval_zeros_likers[FilledConstant] = zeros_like_array
def full_batch_rule(batched_args, batch_dims, shape):
fill_value, = batched_args
bdim, = batch_dims
@ -2305,6 +2304,40 @@ ad.deflinear(full_p, lambda t, shape: [_reduce_sum(t, tuple(range(len(shape))))]
batching.primitive_batchers[full_p] = full_batch_rule
class IotaConstant(xla.DeviceConstant):
__slots__ = ["axis"]
def __init__(self, dtype, shape, axis):
self.shape = shape
self.dtype = onp.dtype(dtype)
self.ndim = len(shape)
self.size = prod(shape)
self._npy_value = None
self.axis = axis
@property
def _value(self):
if self._npy_value is None:
iota = onp.arange(self.shape[self.axis], dtype=self.dtype)
iota = iota.reshape([self.shape[self.axis] if i == self.axis else 1
for i in range(self.ndim)])
self._npy_value = onp.broadcast_to(iota, self.shape)
return self._npy_value
@staticmethod
def constant_handler(c, iota_constant):
return c.BroadcastedIota(iota_constant.dtype, iota_constant.shape,
iota_constant.axis)
xla.register_device_constant(IotaConstant)
for t in [FilledConstant, IotaConstant]:
batching.pytype_aval_mappings[t] = make_shaped_array
ad_util.jaxval_adders[t] = add
ad_util.jaxval_zeros_likers[t] = zeros_like_array
### util
def _ndim(x):
@ -2437,19 +2470,15 @@ def _dynamic_slice_indices(operand, start_indices):
return rem(start_indices, onp.array(operand.shape, start_indices.dtype))
def _ndarray_full_like(x, fill_value, dtype=None, shape=None):
return onp.broadcast_to(onp.array(fill_value, dtype or _dtype(x)),
onp.shape(x) if shape is None else shape)
def _const(example, val):
return onp.array(val, _dtype(example))
_zeros = partial(_ndarray_full_like, fill_value=0)
_zero = partial(_ndarray_full_like, shape=(), fill_value=0)
_ones = partial(_ndarray_full_like, fill_value=1)
_one = partial(_ndarray_full_like, shape=(), fill_value=1)
_twos = partial(_ndarray_full_like, fill_value=2)
_two = partial(_ndarray_full_like, shape=(), fill_value=2)
_zeros = partial(full_like, fill_value=0)
_zero = partial(full_like, shape=(), fill_value=0)
_ones = partial(full_like, fill_value=1)
_one = partial(full_like, shape=(), fill_value=1)
_twos = partial(full_like, fill_value=2)
_two = partial(full_like, shape=(), fill_value=2)
_dtype = onp.result_type
_iscomplex = lambda x: onp.issubdtype(_dtype(x), onp.complexfloating)