mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Merge pull request #963 from hawkinsp/deviceput
Implement device_put as a primitive.
This commit is contained in:
commit
c8266cdbab
18
jax/api.py
18
jax/api.py
@ -989,24 +989,8 @@ def make_jaxpr(fun):
|
||||
tree_to_pval_tuples = partial(process_pytree, pe.pack_pvals)
|
||||
|
||||
|
||||
|
||||
_traceable_device_put = jit(lambda x: x)
|
||||
|
||||
def device_put(x, device_num=0):
|
||||
def _device_put(x):
|
||||
if isinstance(x, core.Tracer):
|
||||
return _traceable_device_put(x)
|
||||
|
||||
try:
|
||||
a = xla.abstractify(x)
|
||||
except TypeError:
|
||||
raise TypeError("Argument '{}' of type {} is not a valid JAX type"
|
||||
.format(x, type(x)))
|
||||
|
||||
result_shape = xla.xla_shape_to_result_shape(xla.xla_shape(a))
|
||||
handler = xla.device_persistent_result_handler(result_shape)
|
||||
return handler(xla.device_put(x, device_num))
|
||||
return tree_map(_device_put, x)
|
||||
return tree_map(lambda y: xla.device_put_p.bind(y, device_num=device_num), x)
|
||||
|
||||
|
||||
device_get = _jit(lambda x: x, (), device_values=False)
|
||||
|
@ -238,7 +238,7 @@ def defvectorized(prim):
|
||||
primitive_batchers[prim] = partial(vectorized_batcher, prim)
|
||||
|
||||
def vectorized_batcher(prim, batched_args, batch_dims, **params):
|
||||
assert all(batch_dims[0] == bd for bd in batch_dims[1:])
|
||||
assert all(batch_dims[0] == bd for bd in batch_dims[1:]), batch_dims
|
||||
return prim.bind(*batched_args, **params), batch_dims[0]
|
||||
|
||||
def defbroadcasting(prim):
|
||||
@ -262,7 +262,7 @@ def reducer_batcher(prim, batched_args, batch_dims, axes, **params):
|
||||
params = dict(params, input_shape=operand.shape)
|
||||
return prim.bind(operand, axes=axes, **params), bdim_out
|
||||
|
||||
# set up primitive batches for ad_util primitives
|
||||
# sets up primitive batchers for ad_util and xla primitives
|
||||
|
||||
def add_batched(batched_args, batch_dims):
|
||||
bdx, bdy = batch_dims
|
||||
@ -284,6 +284,7 @@ def zeros_like_batched(batched_args, batch_dims):
|
||||
return zeros_like_jaxval(val), bdim
|
||||
primitive_batchers[zeros_like_p] = zeros_like_batched
|
||||
|
||||
defvectorized(xla.device_put_p)
|
||||
|
||||
### util
|
||||
|
||||
|
@ -179,11 +179,11 @@ class _ResultArray(tuple): pass
|
||||
|
||||
def result_handler(result_shape):
|
||||
if FLAGS.jax_device_values:
|
||||
return device_persistent_result_handler(result_shape)
|
||||
return _device_persistent_result_handler(result_shape)
|
||||
else:
|
||||
return _pyval_result_handler(result_shape)
|
||||
|
||||
def device_persistent_result_handler(result_shape):
|
||||
def _device_persistent_result_handler(result_shape):
|
||||
t = type(result_shape)
|
||||
if t is _ResultArray:
|
||||
return partial(DeviceArray, result_shape)
|
||||
@ -416,7 +416,7 @@ class DeviceTuple(DeviceValue):
|
||||
|
||||
def __iter__(self):
|
||||
bufs = self.device_buffer.destructure()
|
||||
handlers = map(device_persistent_result_handler, self.result_shapes)
|
||||
handlers = map(_device_persistent_result_handler, self.result_shapes)
|
||||
elts = [handler(buf) for handler, buf in zip(handlers, bufs)]
|
||||
return iter(elts)
|
||||
|
||||
@ -632,7 +632,7 @@ def _xla_callable(fun, device_values, *abstract_args):
|
||||
compiled, result_shape = _compile_jaxpr(jaxpr, consts, *abstract_args)
|
||||
del master, consts, jaxpr, env
|
||||
if device_values:
|
||||
handle_result = device_persistent_result_handler(result_shape)
|
||||
handle_result = _device_persistent_result_handler(result_shape)
|
||||
else:
|
||||
handle_result = _pyval_result_handler(result_shape)
|
||||
return partial(_execute_compiled, compiled, pval, handle_result)
|
||||
@ -655,3 +655,21 @@ xla_call_p.def_impl(_xla_call_impl)
|
||||
|
||||
translations[xla_call_p] = xla_call_translation_rule
|
||||
ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
|
||||
|
||||
|
||||
def _device_put_impl(x, device_num=0):
|
||||
try:
|
||||
a = abstractify(x)
|
||||
except TypeError:
|
||||
raise TypeError("Argument '{}' of type {} is not a valid JAX type"
|
||||
.format(x, type(x)))
|
||||
|
||||
result_shape = xla_shape_to_result_shape(xla_shape(a))
|
||||
handler = _device_persistent_result_handler(result_shape)
|
||||
return handler(device_put(x, device_num))
|
||||
|
||||
device_put_p = core.Primitive('device_put')
|
||||
device_put_p.def_impl(_device_put_impl)
|
||||
device_put_p.def_abstract_eval(lambda x, **kwargs: x)
|
||||
translations[device_put_p] = lambda c, x, **kwargs: x
|
||||
ad.deflinear(device_put_p, lambda cotangent, **kwargs: [cotangent])
|
||||
|
Loading…
x
Reference in New Issue
Block a user