mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add tie_in and full primitives (constant creation)
This commit is contained in:
parent
25cf9358d1
commit
f971415218
@ -249,6 +249,8 @@ class Tracer(object):
|
||||
def __hex__(self): return self.aval._hex(self)
|
||||
def __oct__(self): return self.aval._oct(self)
|
||||
|
||||
def __setitem__(self, idx, val):
|
||||
raise TypeError("JAX 'Tracer' objects do not support item assignment")
|
||||
|
||||
def __getattr__(self, name):
|
||||
# if the aval property raises an AttributeError, gets caught here
|
||||
|
@ -74,6 +74,8 @@ def execute_compiled_primitive(compiled, result_handler, *args):
|
||||
def device_put(x):
|
||||
if type(x) is DeviceArray:
|
||||
return x.device_buffer
|
||||
elif isinstance(x, DeviceConstant):
|
||||
return instantiate_device_constant(x)
|
||||
else:
|
||||
return xb.device_put(x) # can round-trip elements of tuples here
|
||||
|
||||
@ -237,6 +239,7 @@ class DeviceArray(DeviceValue):
|
||||
self.size = size
|
||||
self._npy_value = None
|
||||
|
||||
# TODO make device_buffer a property, make the _npy_value writeable, invalidate
|
||||
@property
|
||||
def _value(self):
|
||||
if self._npy_value is None:
|
||||
@ -319,6 +322,31 @@ xb.register_constant_handler(DeviceArray,
|
||||
lambda c, val: c.Constant(onp.asarray(val)))
|
||||
|
||||
|
||||
class DeviceConstant(DeviceArray):
|
||||
@staticmethod
|
||||
def constant_handler(c, constant_instance):
|
||||
assert False
|
||||
|
||||
# TODO(mattjj): tune cutoff
|
||||
def instantiate_device_constant(const, cutoff=1000000):
|
||||
# dispatch an XLA Computation to build the constant on the device if it's
|
||||
# large, or alternatively build it on the host and transfer it if it's small
|
||||
assert isinstance(const, DeviceConstant)
|
||||
if const.size > cutoff:
|
||||
c = xb.make_computation_builder("constant_instantiating_computation")
|
||||
xla_const = const.constant_handler(c, const)
|
||||
compiled = c.Build(xla_const).Compile((), xb.get_compile_options())
|
||||
return compiled.Execute(())
|
||||
else:
|
||||
return xb.device_put(onp.asarray(const))
|
||||
|
||||
def register_device_constant(cls):
|
||||
pytype_aval_mappings[cls] = pytype_aval_mappings[DeviceArray]
|
||||
canonicalize_dtype_handlers[cls] = identity
|
||||
core.pytype_aval_mappings[cls] = ConcreteArray
|
||||
xb.register_constant_handler(cls, cls.constant_handler)
|
||||
|
||||
|
||||
def xla_shape(x):
|
||||
try:
|
||||
return xb.Shape.array_shape(x.dtype, x.shape)
|
||||
|
123
jax/lax.py
123
jax/lax.py
@ -17,7 +17,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
from .util import partial
|
||||
from .util import partial, prod
|
||||
import itertools
|
||||
import operator
|
||||
import six
|
||||
@ -43,6 +43,9 @@ from .lib import xla_bridge
|
||||
|
||||
_max = builtins.max
|
||||
_min = builtins.max
|
||||
_reduce = six.moves.reduce
|
||||
|
||||
def identity(x): return x
|
||||
|
||||
### traceables
|
||||
|
||||
@ -411,6 +414,25 @@ class OpaqueParam(object):
|
||||
opaque_param_ids = itertools.count()
|
||||
|
||||
|
||||
def tie_in(x, y):
|
||||
return tie_in_p.bind(x, y)
|
||||
|
||||
def full(shape, fill_value, dtype=None):
|
||||
if onp.shape(fill_value):
|
||||
msg = "full must be called with scalar fill_value, got fill_value.shape {}."
|
||||
raise TypeError(msg.format(onp.shape(fill_value)))
|
||||
|
||||
dtype = dtype and xla_bridge.canonicalize_dtype(dtype)
|
||||
if dtype is not None and _dtype(fill_value) != dtype:
|
||||
# for Python scalars and raw ndarrays, we keep fill_value as a cpu ndarray
|
||||
if onp.isscalar(fill_value) or type(fill_value) is onp.ndarray:
|
||||
fill_value = onp.array(fill_value, dtype)
|
||||
else:
|
||||
fill_value = convert_element_type(fill_value, dtype)
|
||||
|
||||
return full_p.bind(fill_value, shape=shape)
|
||||
|
||||
|
||||
### convenience wrappers around traceables
|
||||
|
||||
|
||||
@ -439,7 +461,8 @@ def full_like(x, fill_value, dtype=None, shape=None):
|
||||
`fill_value`, similar to the output of np.full.
|
||||
"""
|
||||
shape = onp.shape(x) if shape is None else shape
|
||||
return broadcast(onp.array(fill_value, dtype or _dtype(x)), shape)
|
||||
out = full(shape, fill_value, dtype or _dtype(x))
|
||||
return tie_in(x, out)
|
||||
|
||||
|
||||
def collapse(operand, start_dimension, stop_dimension):
|
||||
@ -631,8 +654,7 @@ ShapedArray._iter = staticmethod(_iter)
|
||||
# Add some ad handlers that use (or could use) lax primitives
|
||||
|
||||
def zeros_like_array(x):
|
||||
dtype = xla_bridge.canonicalize_dtype(_dtype(x))
|
||||
return onp.broadcast_to(onp.zeros((), dtype), onp.shape(x))
|
||||
return full_like(x, 0)
|
||||
|
||||
for t in itertools.chain(array_types, [xla.DeviceArray]):
|
||||
ad_util.jaxval_adders[t] = add
|
||||
@ -648,8 +670,6 @@ _input_dtype = lambda *args, **_: xla_bridge.canonicalize_dtype(args[0].dtype)
|
||||
_fixed_dtype = lambda dtype: lambda *args, **kwargs: xla_bridge.canonicalize_dtype(dtype)
|
||||
_complex_basetype = lambda dtype: onp.abs(onp.zeros((), dtype)).dtype
|
||||
|
||||
def identity(x): return x
|
||||
|
||||
|
||||
def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
|
||||
prim = Primitive(name)
|
||||
@ -1561,9 +1581,11 @@ def select_dtype_rule(pred, on_true, on_false):
|
||||
return on_true.dtype
|
||||
|
||||
def select_transpose_rule(t, pred, on_true, on_false):
|
||||
assert pred is not None
|
||||
zeros = full_like(t, 0)
|
||||
return [None,
|
||||
select(pred, t, _zeros(on_false)) if on_true is None else None,
|
||||
select(pred, _zeros(on_true), t) if on_false is None else None]
|
||||
select(pred, t, zeros) if on_true is None else None,
|
||||
select(pred, zeros, t) if on_false is None else None]
|
||||
|
||||
def select_batch_rule(batched_args, batch_dims, **unused_kwargs):
|
||||
oprand, on_true, on_false, = batched_args
|
||||
@ -2218,6 +2240,71 @@ while_p.def_abstract_eval(while_loop_abstract_eval)
|
||||
xla.translations[while_p] = while_loop_translation_rule
|
||||
|
||||
|
||||
### primitives for handling constants
|
||||
|
||||
|
||||
def tie_in_transpose_rule(t):
|
||||
return [ad_util.zero, t]
|
||||
|
||||
def tie_in_batch_rule(batched_args, batch_dims):
|
||||
y = tie_in(*batched_args)
|
||||
_, bdim_y = batch_dims
|
||||
return y, bdim_y
|
||||
|
||||
tie_in_p = Primitive('tie_in')
|
||||
tie_in_p.def_impl(lambda x, y: y)
|
||||
tie_in_p.def_abstract_eval(lambda x, y: y)
|
||||
xla.translations[tie_in_p] = lambda c, x, y: y
|
||||
ad.deflinear(tie_in_p, tie_in_transpose_rule)
|
||||
batching.primitive_batchers[tie_in_p] = tie_in_batch_rule
|
||||
|
||||
|
||||
class FilledConstant(xla.DeviceConstant):
|
||||
__slots__ = ["fill_value"]
|
||||
|
||||
def __init__(self, fill_value, shape):
|
||||
assert type(fill_value) is onp.ndarray
|
||||
self.shape = shape
|
||||
self.dtype = _dtype(fill_value)
|
||||
self.ndim = len(shape)
|
||||
self.size = prod(shape)
|
||||
self._npy_value = None
|
||||
|
||||
self.fill_value = fill_value
|
||||
|
||||
@property
|
||||
def _value(self):
|
||||
return onp.full(self.shape, self.fill_value)
|
||||
|
||||
@staticmethod
|
||||
def constant_handler(c, filled_const):
|
||||
return c.Broadcast(c.NumpyArrayConstant(filled_const.fill_value),
|
||||
filled_const.shape)
|
||||
xla.register_device_constant(FilledConstant)
|
||||
|
||||
# TODO(mattjj): if we used isinstance rather than handlers here, these would all
|
||||
# be covered as subclasses of DeviceArray. alternatively, just set up these in a
|
||||
# loop after we've defined all the constant DeviceConstant subclasses.
|
||||
batching.pytype_aval_mappings[FilledConstant] = make_shaped_array
|
||||
ad_util.jaxval_adders[FilledConstant] = add
|
||||
ad_util.jaxval_zeros_likers[FilledConstant] = zeros_like_array
|
||||
|
||||
def full_batch_rule(batched_args, batch_dims, shape):
|
||||
fill_value, = batched_args
|
||||
bdim, = batch_dims
|
||||
assert bdim == 0
|
||||
return broadcast_in_dim(fill_value, fill_value.shape + shape, [bdim])
|
||||
|
||||
full_p = Primitive('full_p')
|
||||
full_p.def_impl(FilledConstant)
|
||||
full_p.def_abstract_eval(
|
||||
lambda fill_value, shape: ShapedArray(shape, _dtype(fill_value)))
|
||||
xla.translations[full_p] = \
|
||||
lambda c, fill_value, shape: c.Broadcast(fill_value, shape)
|
||||
ad.deflinear(full_p, lambda t, shape: [_reduce_sum(t, tuple(range(len(shape))))])
|
||||
batching.primitive_batchers[full_p] = full_batch_rule
|
||||
|
||||
|
||||
### util
|
||||
|
||||
def _ndim(x):
|
||||
@ -2350,13 +2437,19 @@ def _dynamic_slice_indices(operand, start_indices):
|
||||
return rem(start_indices, onp.array(operand.shape, start_indices.dtype))
|
||||
|
||||
|
||||
_const = lambda example, val: onp.array(val, _dtype(example))
|
||||
_zeros = partial(full_like, fill_value=0)
|
||||
_zero = partial(full_like, shape=(), fill_value=0)
|
||||
_ones = partial(full_like, fill_value=1)
|
||||
_one = partial(full_like, shape=(), fill_value=1)
|
||||
_twos = partial(full_like, fill_value=2)
|
||||
_two = partial(full_like, shape=(), fill_value=2)
|
||||
def _ndarray_full_like(x, fill_value, dtype=None, shape=None):
|
||||
return onp.broadcast_to(onp.array(fill_value, dtype or _dtype(x)),
|
||||
onp.shape(x) if shape is None else shape)
|
||||
|
||||
def _const(example, val):
|
||||
return onp.array(val, _dtype(example))
|
||||
|
||||
_zeros = partial(_ndarray_full_like, fill_value=0)
|
||||
_zero = partial(_ndarray_full_like, shape=(), fill_value=0)
|
||||
_ones = partial(_ndarray_full_like, fill_value=1)
|
||||
_one = partial(_ndarray_full_like, shape=(), fill_value=1)
|
||||
_twos = partial(_ndarray_full_like, fill_value=2)
|
||||
_two = partial(_ndarray_full_like, shape=(), fill_value=2)
|
||||
|
||||
_dtype = onp.result_type
|
||||
_iscomplex = lambda x: onp.issubdtype(_dtype(x), onp.complexfloating)
|
||||
|
@ -355,6 +355,7 @@ def _ndarray_constant_handler(c, val):
|
||||
An XLA ComputationDataHandle / XlaOp representing the constant ndarray
|
||||
staged into the XLA Computation.
|
||||
"""
|
||||
# TODO(mattjj): revise this to use c.BroadcastInDim rather than Transpose
|
||||
if onp.any(onp.equal(0, val.strides)) and val.size > 0:
|
||||
zero_stride_axes, = onp.where(onp.equal(0, val.strides))
|
||||
other_axes, = onp.where(onp.not_equal(0, val.strides))
|
||||
|
@ -794,33 +794,27 @@ asarray = array
|
||||
|
||||
@_wraps(onp.zeros_like)
|
||||
def zeros_like(x, dtype=None):
|
||||
return zeros(_shape(x), dtype or _dtype(x))
|
||||
return lax.full_like(x, 0, dtype)
|
||||
|
||||
|
||||
@_wraps(onp.ones_like)
|
||||
def ones_like(x, dtype=None):
|
||||
return ones(_shape(x), dtype or _dtype(x))
|
||||
return lax.full_like(x, 1, dtype)
|
||||
|
||||
|
||||
@_wraps(onp.full)
|
||||
def full(shape, fill_value, dtype=None):
|
||||
if dtype:
|
||||
fill_value = lax.convert_element_type(fill_value, dtype)
|
||||
return lax.broadcast(fill_value, tuple(shape))
|
||||
full = _wraps(onp.full)(lax.full)
|
||||
|
||||
|
||||
@_wraps(onp.zeros)
|
||||
def zeros(shape, dtype=onp.dtype("float64")):
|
||||
shape = (shape,) if onp.isscalar(shape) else shape
|
||||
dtype = xla_bridge.canonicalize_dtype(dtype)
|
||||
return onp.broadcast_to(onp.zeros((), dtype), tuple(shape))
|
||||
return lax.full(shape, 0, dtype)
|
||||
|
||||
|
||||
@_wraps(onp.ones)
|
||||
def ones(shape, dtype=onp.dtype("float64")):
|
||||
shape = (shape,) if onp.isscalar(shape) else shape
|
||||
dtype = xla_bridge.canonicalize_dtype(dtype)
|
||||
return onp.broadcast_to(onp.ones((), dtype), tuple(shape))
|
||||
return lax.full(shape, 1, dtype)
|
||||
|
||||
|
||||
@_wraps(onp.repeat)
|
||||
|
Loading…
x
Reference in New Issue
Block a user