Merge pull request #963 from hawkinsp/deviceput

Implement device_put as a primitive.
This commit is contained in:
Peter Hawkins 2019-07-02 14:40:59 -04:00 committed by GitHub
commit c8266cdbab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 23 deletions

View File

@@ -989,24 +989,8 @@ def make_jaxpr(fun):
tree_to_pval_tuples = partial(process_pytree, pe.pack_pvals)
_traceable_device_put = jit(lambda x: x)
def device_put(x, device_num=0):
  """Transfers `x` (a pytree of JAX values) to device `device_num`.

  Implemented by binding the `device_put` primitive on each leaf, so the
  operation works both eagerly and under tracing (e.g. inside `jit`); the
  primitive's impl/abstract-eval/translation rules are defined in `xla`.

  Args:
    x: a pytree whose leaves are valid JAX types.
    device_num: integer index of the target device (default 0).

  Returns:
    A pytree of the same structure with each leaf placed on the device.
  """
  # NOTE(review): the diff view showed both the pre-commit closure-based body
  # and this primitive-based replacement; only the replacement is kept here.
  return tree_map(lambda y: xla.device_put_p.bind(y, device_num=device_num), x)
device_get = _jit(lambda x: x, (), device_values=False)

View File

@@ -238,7 +238,7 @@ def defvectorized(prim):
primitive_batchers[prim] = partial(vectorized_batcher, prim)
def vectorized_batcher(prim, batched_args, batch_dims, **params):
  """Batching rule for elementwise (vectorized) primitives.

  All arguments must be batched along the same dimension, in which case the
  primitive can simply be applied to the batched arguments unchanged.

  Args:
    prim: the primitive being batched.
    batched_args: the batch-tracer-unwrapped argument values.
    batch_dims: the batch dimension index of each argument; must all agree.
    **params: primitive parameters, forwarded unchanged.

  Returns:
    A pair of (result of binding `prim` on the batched args, output batch dim).
  """
  # Include batch_dims in the assertion message to aid debugging on mismatch
  # (the duplicated message-less assert from the diff view is dropped).
  assert all(batch_dims[0] == bd for bd in batch_dims[1:]), batch_dims
  return prim.bind(*batched_args, **params), batch_dims[0]
def defbroadcasting(prim):
@@ -262,7 +262,7 @@ def reducer_batcher(prim, batched_args, batch_dims, axes, **params):
params = dict(params, input_shape=operand.shape)
return prim.bind(operand, axes=axes, **params), bdim_out
# sets up primitive batchers for ad_util and xla primitives
def add_batched(batched_args, batch_dims):
bdx, bdy = batch_dims
@@ -284,6 +284,7 @@ def zeros_like_batched(batched_args, batch_dims):
return zeros_like_jaxval(val), bdim
primitive_batchers[zeros_like_p] = zeros_like_batched
defvectorized(xla.device_put_p)
### util

View File

@@ -179,11 +179,11 @@ class _ResultArray(tuple): pass
def result_handler(result_shape):
  """Returns a handler that converts a raw device buffer into a Python value.

  Dispatches on the --jax_device_values flag: when set, results stay resident
  on the device (wrapped as DeviceValues); otherwise they are converted to
  plain Python values.

  NOTE(review): the diff view interleaved the old public-name call and the
  renamed `_device_persistent_result_handler` call; only the renamed (post-
  commit) call is kept here.
  """
  if FLAGS.jax_device_values:
    return _device_persistent_result_handler(result_shape)
  else:
    return _pyval_result_handler(result_shape)
def device_persistent_result_handler(result_shape):
def _device_persistent_result_handler(result_shape):
t = type(result_shape)
if t is _ResultArray:
return partial(DeviceArray, result_shape)
@@ -416,7 +416,7 @@ class DeviceTuple(DeviceValue):
def __iter__(self):
  """Iterates over the tuple's elements as host-level values.

  Destructures the underlying tuple buffer into per-element buffers and wraps
  each one with the handler for its result shape (self.result_shapes is
  presumably parallel to the destructured buffers — confirm against the class
  definition, which is outside this view).

  The diff view interleaved the old public-name handler call with the renamed
  `_device_persistent_result_handler`; only the renamed call is kept, since
  the old name no longer exists after this commit.
  """
  bufs = self.device_buffer.destructure()
  handlers = map(_device_persistent_result_handler, self.result_shapes)
  elts = [handler(buf) for handler, buf in zip(handlers, bufs)]
  return iter(elts)
@@ -632,7 +632,7 @@ def _xla_callable(fun, device_values, *abstract_args):
compiled, result_shape = _compile_jaxpr(jaxpr, consts, *abstract_args)
del master, consts, jaxpr, env
if device_values:
handle_result = device_persistent_result_handler(result_shape)
handle_result = _device_persistent_result_handler(result_shape)
else:
handle_result = _pyval_result_handler(result_shape)
return partial(_execute_compiled, compiled, pval, handle_result)
@@ -655,3 +655,21 @@ xla_call_p.def_impl(_xla_call_impl)
translations[xla_call_p] = xla_call_translation_rule
ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
def _device_put_impl(x, device_num=0):
  """Eager implementation of the `device_put` primitive.

  Abstractifies `x` to validate it is a JAX type, then transfers it to device
  `device_num` and wraps the resulting buffer in a device-persistent handler.

  Raises:
    TypeError: if `x` is not a valid JAX type.
  """
  try:
    aval = abstractify(x)
  except TypeError:
    raise TypeError("Argument '{}' of type {} is not a valid JAX type"
                    .format(x, type(x)))
  result_shape = xla_shape_to_result_shape(xla_shape(aval))
  wrap = _device_persistent_result_handler(result_shape)
  buf = device_put(x, device_num)
  return wrap(buf)
# device_put is exposed as a JAX primitive so it can participate in tracing
# (jit/grad/vmap) instead of only working eagerly.
device_put_p = core.Primitive('device_put')
device_put_p.def_impl(_device_put_impl)
# Abstractly, device_put is the identity: the abstract value is unchanged.
device_put_p.def_abstract_eval(lambda x, **kwargs: x)
# Under XLA compilation the op is a no-op: the operand is returned unchanged
# (placement is handled by the compiled computation itself).
translations[device_put_p] = lambda c, x, **kwargs: x
# device_put is linear; its transpose passes the cotangent through unchanged.
ad.deflinear(device_put_p, lambda cotangent, **kwargs: [cotangent])