mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add IotaConstant (untested)
This commit is contained in:
parent
f971415218
commit
dfc25a06d9
63
jax/lax.py
63
jax/lax.py
@ -432,6 +432,12 @@ def full(shape, fill_value, dtype=None):
|
||||
|
||||
return full_p.bind(fill_value, shape=shape)
|
||||
|
||||
def iota(dtype, shape, dimension):
|
||||
dtype = xla_bridge.canonicalize_dtype(dtype)
|
||||
shape = tuple(map(int(shape)))
|
||||
dimension = int(dimension)
|
||||
return IotaConstant(dtype, shape, dimension)
|
||||
|
||||
|
||||
### convenience wrappers around traceables
|
||||
|
||||
@ -2282,13 +2288,6 @@ class FilledConstant(xla.DeviceConstant):
|
||||
filled_const.shape)
|
||||
xla.register_device_constant(FilledConstant)
|
||||
|
||||
# TODO(mattjj): if we used isinstance rather than handlers here, these would all
|
||||
# be covered as subclasses of DeviceArray. alternatively, just set up these in a
|
||||
# loop after we've defined all the constant DeviceConstant subclasses.
|
||||
batching.pytype_aval_mappings[FilledConstant] = make_shaped_array
|
||||
ad_util.jaxval_adders[FilledConstant] = add
|
||||
ad_util.jaxval_zeros_likers[FilledConstant] = zeros_like_array
|
||||
|
||||
def full_batch_rule(batched_args, batch_dims, shape):
|
||||
fill_value, = batched_args
|
||||
bdim, = batch_dims
|
||||
@ -2305,6 +2304,40 @@ ad.deflinear(full_p, lambda t, shape: [_reduce_sum(t, tuple(range(len(shape))))]
|
||||
batching.primitive_batchers[full_p] = full_batch_rule
|
||||
|
||||
|
||||
class IotaConstant(xla.DeviceConstant):
|
||||
__slots__ = ["axis"]
|
||||
|
||||
def __init__(self, dtype, shape, axis):
|
||||
self.shape = shape
|
||||
self.dtype = onp.dtype(dtype)
|
||||
self.ndim = len(shape)
|
||||
self.size = prod(shape)
|
||||
self._npy_value = None
|
||||
|
||||
self.axis = axis
|
||||
|
||||
@property
|
||||
def _value(self):
|
||||
if self._npy_value is None:
|
||||
iota = onp.arange(self.shape[self.axis], dtype=self.dtype)
|
||||
iota = iota.reshape([self.shape[self.axis] if i == self.axis else 1
|
||||
for i in range(self.ndim)])
|
||||
self._npy_value = onp.broadcast_to(iota, self.shape)
|
||||
return self._npy_value
|
||||
|
||||
@staticmethod
|
||||
def constant_handler(c, iota_constant):
|
||||
return c.BroadcastedIota(iota_constant.dtype, iota_constant.shape,
|
||||
iota_constant.axis)
|
||||
xla.register_device_constant(IotaConstant)
|
||||
|
||||
|
||||
for t in [FilledConstant, IotaConstant]:
|
||||
batching.pytype_aval_mappings[t] = make_shaped_array
|
||||
ad_util.jaxval_adders[t] = add
|
||||
ad_util.jaxval_zeros_likers[t] = zeros_like_array
|
||||
|
||||
|
||||
### util
|
||||
|
||||
def _ndim(x):
|
||||
@ -2437,19 +2470,15 @@ def _dynamic_slice_indices(operand, start_indices):
|
||||
return rem(start_indices, onp.array(operand.shape, start_indices.dtype))
|
||||
|
||||
|
||||
def _ndarray_full_like(x, fill_value, dtype=None, shape=None):
|
||||
return onp.broadcast_to(onp.array(fill_value, dtype or _dtype(x)),
|
||||
onp.shape(x) if shape is None else shape)
|
||||
|
||||
def _const(example, val):
|
||||
return onp.array(val, _dtype(example))
|
||||
|
||||
_zeros = partial(_ndarray_full_like, fill_value=0)
|
||||
_zero = partial(_ndarray_full_like, shape=(), fill_value=0)
|
||||
_ones = partial(_ndarray_full_like, fill_value=1)
|
||||
_one = partial(_ndarray_full_like, shape=(), fill_value=1)
|
||||
_twos = partial(_ndarray_full_like, fill_value=2)
|
||||
_two = partial(_ndarray_full_like, shape=(), fill_value=2)
|
||||
_zeros = partial(full_like, fill_value=0)
|
||||
_zero = partial(full_like, shape=(), fill_value=0)
|
||||
_ones = partial(full_like, fill_value=1)
|
||||
_one = partial(full_like, shape=(), fill_value=1)
|
||||
_twos = partial(full_like, fill_value=2)
|
||||
_two = partial(full_like, shape=(), fill_value=2)
|
||||
|
||||
_dtype = onp.result_type
|
||||
_iscomplex = lambda x: onp.issubdtype(_dtype(x), onp.complexfloating)
|
||||
|
Loading…
x
Reference in New Issue
Block a user