Make op-by-op work with all jit-returned devicearrays.

commit f01fc35ce5 (parent 7529815614)
Author: Anselm Levskaya
Date: 2019-08-21 00:22:53 -07:00

4 changed files with 51 additions and 27 deletions

Changed file 1 of 4

@@ -105,6 +105,8 @@ def jit(fun, static_argnums=(), device_assignment=None, backend=None):
      change. Optional, an int specifying the device ordinal for which to compile the
      function. The default is inherited from XLA's DeviceAssignment logic and is
      usually to use device 0.
+    backend: This is an experimental feature and the API is likely to change.
+      Optional, a string representing the xla backend. 'cpu','gpu', or 'tpu'.
Returns:
A wrapped version of `fun`, set up for just-in-time compilation.
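Note: a minimal sketch of how the new `backend` argument to `jit` might be used; the function and values here are illustrative, and only the 'cpu' backend is guaranteed to be present:

    import jax.numpy as np
    from jax import jit

    # Compile the same Python function against an explicitly named XLA backend.
    # 'gpu' or 'tpu' would require a matching jaxlib build and hardware.
    f_cpu = jit(lambda x: x * 2.0, backend='cpu')
    print(f_cpu(np.ones(3)))  # DeviceArray backed by a CPU buffer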
@@ -210,6 +212,8 @@ def xla_computation(fun, static_argnums=(), axis_env=None, backend=None):
      functions that involve parallel communication collectives, and it
      specifies the axis name/size environment that would be set up by
      applications of ``jax.pmap``. See the examples below.
+    backend: This is an experimental feature and the API is likely to change.
+      Optional, a string representing the xla backend. 'cpu','gpu', or 'tpu'.
Returns:
A wrapped version of ``fun`` that when applied to example arguments returns a
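Note: a similar sketch for `xla_computation` with an explicit backend, assuming the xla_client Computation API of this era (where `GetHloText()` returns the HLO text); the traced function is illustrative:

    import jax.numpy as np
    from jax import xla_computation

    # Trace f into an XLA computation for the named backend without running it.
    def f(x):
      return np.sin(x) + 1.0

    c = xla_computation(f, backend='cpu')(np.zeros(4))
    print(c.GetHloText())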
@@ -649,6 +653,8 @@ def pmap(fun, axis_name=None, backend=None):
    fun: Function to be mapped over argument axes.
    axis_name: Optional, a hashable Python object used to identify the mapped
      axis so that parallel collectives can be applied.
+    backend: This is an experimental feature and the API is likely to change.
+      Optional, a string representing the xla backend. 'cpu','gpu', or 'tpu'.
Returns:
A parallelized version of ``fun`` with arguments that correspond to those of
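Note: a small sketch of `pmap` pinned to one backend; the mapped function is illustrative, and the leading axis size must not exceed the number of devices on that backend:

    import jax.numpy as np
    from jax import pmap

    # Run the mapped function only on devices of the chosen backend.
    # With a single CPU device, the mapped axis must have size 1.
    y = pmap(lambda x: x ** 2, backend='cpu')(np.arange(1.0))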

Changed file 2 of 4

@@ -598,8 +598,8 @@ xla_pmap = partial(core.call_bind, xla_pmap_p)
xla_pmap_p.def_custom_bind(xla_pmap)
xla_pmap_p.def_impl(xla_pmap_impl)
-def _xla_pmap_translation_rule(c, jaxpr, backend, axis_env, env_nodes, in_nodes,
-                               axis_name, axis_size):
+def _xla_pmap_translation_rule(c, jaxpr, axis_env, env_nodes, in_nodes,
+                               axis_name, axis_size, backend=None):
new_env = xla.extend_axis_env(axis_env, axis_name, axis_size)
in_nodes_sharded = list(map(partial(xla_shard, c, new_env.sizes), in_nodes))
subc = xla._jaxpr_computation(jaxpr, backend, new_env, (),

Changed file 3 of 4

@@ -145,6 +145,7 @@ def device_put(x, device_num=0, backend=None):
      DeviceTuple, and arraylike includes DeviceArray, DeviceConstant, and
      anything that has an '__array__' attr.
    device_num: an int representing the target physical device number.
+    backend: a string representing the xla backend. ('cpu','gpu','tpu')
Returns:
A buffer representing the input `x` placed on the appropriate device.
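Note: a sketch of the documented call, assuming the module-level helper in jax.interpreters.xla is invoked directly; the array value is illustrative:

    import numpy as onp
    from jax.interpreters import xla

    # Place a host ndarray on device 0 of an explicitly chosen backend.
    # The result is a backend-specific device buffer rather than a DeviceArray.
    buf = xla.device_put(onp.arange(4.0), device_num=0, backend='cpu')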
@@ -152,10 +153,20 @@ def device_put(x, device_num=0, backend=None):
  x = _canonicalize_pyval_dtype(x)
  t = type(x)
  if t is DeviceArray or t is DeviceTuple:
-    if x.device_buffer.device() == device_num:
-      return x.device_buffer
+    # TODO(levskaya) remove if-condition after increasing minimum Jaxlib version to
+    # 0.1.24.
+    if hasattr(x.device_buffer, 'platform'):
+      backend_match = x.device_buffer.platform() == backend
    else:
-      return x.device_buffer.copy_to_device(device_num)
+      backend_match = True
+    if backend_match:
+      if x.device_buffer.device() == device_num:
+        return x.device_buffer
+      else:
+        return x.device_buffer.copy_to_device(device_num)
+    else:
+      # Buffers from different XLA backends are passed through the host.
+      return xla_client.Buffer.from_pyval(x, device_num, backend=xb.get_backend(backend))
elif isinstance(x, DeviceConstant):
return _instantiate_device_constant(x, device_num=device_num, backend=backend)
elif isinstance(x, (DeviceArray, onp.ndarray)):
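Note: this is the hunk that lets op-by-op execution consume DeviceArrays produced by a `jit` compiled for a different backend: when the buffer's platform does not match the requested backend, the data is now copied through the host instead of raising. A sketch of the user-visible effect, assuming a multi-backend jaxlib install whose default backend is not CPU (e.g. GPU):

    import jax
    import jax.numpy as np

    # Produce a DeviceArray whose buffer lives on the CPU backend.
    x = jax.jit(lambda v: v + 1.0, backend='cpu')(np.ones(3))

    # Op-by-op dispatch on the default (e.g. GPU) backend now works: the CPU
    # buffer is round-tripped through the host rather than failing with a
    # backend mismatch.
    y = np.sin(x)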
@@ -281,7 +292,7 @@ def _jaxpr_computation(jaxpr, backend, axis_env, const_vals, freevar_shapes, *ar
      (subjaxpr, const_bindings, freevar_bindings), = eqn.bound_subjaxprs
      env_nodes = list(map(read, const_bindings + freevar_bindings))
      rule = call_translations[eqn.primitive]
-      ans = rule(c, subjaxpr, backend, axis_env, env_nodes, in_nodes, **eqn.params)
+      ans = rule(c, subjaxpr, axis_env, env_nodes, in_nodes, backend=backend, **eqn.params)
else:
msg = "XLA translation rule for primitive '{}' not found"
raise NotImplementedError(msg.format(eqn.primitive.name))
@@ -718,8 +729,8 @@ xla_call = partial(core.call_bind, xla_call_p)
xla_call_p.def_custom_bind(xla_call)
xla_call_p.def_impl(_xla_call_impl)
-def _xla_call_translation_rule(c, jaxpr, backend, axis_env, env_nodes, in_nodes,
-                               device_assignment):
+def _xla_call_translation_rule(c, jaxpr, axis_env, env_nodes, in_nodes,
+                               device_assignment=None, backend=None):
del device_assignment # Unused.
subc = _jaxpr_computation(jaxpr, backend, axis_env, (), _map(c.GetShape, env_nodes),
*map(c.GetShape, in_nodes))
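Note: call-primitive translation rules are now invoked with `backend` as a trailing keyword argument (see the `rule(...)` call site in the previous hunk). A hypothetical custom rule registered in `xla.call_translations`, mirroring `_xla_call_translation_rule` and assuming the same module-internal helpers (`_jaxpr_computation`, `_map`):

    def _my_call_translation_rule(c, jaxpr, axis_env, env_nodes, in_nodes,
                                  backend=None, **params):
      # Positional arguments come first; backend arrives as a keyword with a
      # default, so existing rules keep working when no backend is given.
      subc = _jaxpr_computation(jaxpr, backend, axis_env, (),
                                _map(c.GetShape, env_nodes),
                                *map(c.GetShape, in_nodes))
      return c.Call(subc, env_nodes + in_nodes)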

Changed file 4 of 4

@@ -1445,16 +1445,22 @@ _input_dtype = lambda *args, **_: xla_bridge.canonicalize_dtype(args[0].dtype)
_fixed_dtype = lambda dtype: lambda *args, **kwargs: xla_bridge.canonicalize_dtype(dtype)
_complex_basetype = lambda dtype: onp.abs(onp.zeros((), dtype)).dtype
-def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None, reduction=False):
+def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
  prim = Primitive(name)
  prim.def_impl(partial(xla.apply_primitive, prim))
  prim.def_abstract_eval(partial(standard_abstract_eval, shape_rule, dtype_rule))
-  if reduction:
-    xla.reduction_translations[prim] = translation_rule or partial(standard_translate, name)
-  else:
-    xla.translations[prim] = translation_rule or partial(standard_translate, name)
+  xla.translations[prim] = translation_rule or partial(standard_translate, name)
  return prim
+def standard_reduction_primitive(shape_rule, dtype_rule, name, translation_rule=None):
+  prim = Primitive(name)
+  prim.def_impl(partial(xla.apply_primitive, prim))
+  prim.def_abstract_eval(partial(standard_abstract_eval, shape_rule, dtype_rule))
+  xla.reduction_translations[prim] = translation_rule or partial(standard_translate, name)
+  return prim
def standard_abstract_eval(shape_rule, dtype_rule, *args, **kwargs):
assert all(isinstance(arg, UnshapedArray) for arg in args), args
least_specialized = _max(
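Note: after this split, reduction-style primitives are registered through `standard_reduction_primitive` instead of a `reduction=True` flag, as the scatter/reduce registrations in the hunks below show. A hypothetical registration with placeholder rule functions:

    # Ordinary primitive: translation rule goes into xla.translations.
    my_elementwise_p = standard_primitive(
        _my_shape_rule, _my_dtype_rule, 'my-elementwise')

    # Reduction-style primitive: translation rule goes into
    # xla.reduction_translations (these rules also receive the backend so they
    # can build reducer subcomputations, cf. _reduction_computation below).
    my_reduce_p = standard_reduction_primitive(
        _my_reduce_shape_rule, _input_dtype, 'my-reduce',
        _my_reduce_translation_rule)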
@@ -3144,25 +3150,25 @@ def _scatter_batching_rule(
scatter_dims_to_operand_dims=scatter_dims_to_operand_dims)
return scatter_op(operand, scatter_indices, updates, dnums), 0
-scatter_add_p = standard_primitive(
+scatter_add_p = standard_reduction_primitive(
    _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add',
-    _scatter_translation_rule, reduction=True)
+    _scatter_translation_rule)
ad.primitive_jvps[scatter_add_p] = _scatter_add_jvp
ad.primitive_transposes[scatter_add_p] = _scatter_add_transpose_rule
batching.primitive_batchers[scatter_add_p] = (
partial(_scatter_batching_rule, scatter_add))
# TODO(jlebar): Add derivatives.
-scatter_min_p = standard_primitive(
+scatter_min_p = standard_reduction_primitive(
    _scatter_shape_rule, _scatter_dtype_rule, 'scatter-min',
-    _scatter_translation_rule, reduction=True)
+    _scatter_translation_rule)
batching.primitive_batchers[scatter_min_p] = (
partial(_scatter_batching_rule, scatter_min))
# TODO(jlebar): Add derivatives.
-scatter_max_p = standard_primitive(
+scatter_max_p = standard_reduction_primitive(
    _scatter_shape_rule, _scatter_dtype_rule, 'scatter-max',
-    _scatter_translation_rule, reduction=True)
+    _scatter_translation_rule)
batching.primitive_batchers[scatter_max_p] = (
partial(_scatter_batching_rule, scatter_max))
@@ -3260,9 +3266,9 @@ def _scatter_jvp(primals, tangents, update_jaxpr, update_consts,
  return val_out, tangent_out
-scatter_p = standard_primitive(
+scatter_p = standard_reduction_primitive(
    _scatter_shape_rule, _scatter_dtype_rule, 'scatter',
-    _scatter_translation_rule, reduction=True)
+    _scatter_translation_rule)
ad.primitive_jvps[scatter_p] = _scatter_jvp
batching.primitive_batchers[scatter_p] = (
partial(_scatter_batching_rule, scatter))
@@ -3291,8 +3297,9 @@ def _reduction_computation(c, jaxpr, backend, consts, init_value):
  shape = c.GetShape(init_value)
  return xla.jaxpr_computation(jaxpr, backend, consts, (), shape, shape)
-reduce_p = standard_primitive(_reduce_shape_rule, _input_dtype, 'reduce',
-                              _reduce_translation_rule, reduction=True)
+reduce_p = standard_reduction_primitive(
+    _reduce_shape_rule, _input_dtype, 'reduce',
+    _reduce_translation_rule)
batching.primitive_batchers[reduce_p] = _reduce_batch_rule
@@ -3458,9 +3465,9 @@ def _generic_reduce_window_batch_rule(
      window_dimensions, window_strides, padding)
-reduce_window_p = standard_primitive(
+reduce_window_p = standard_reduction_primitive(
    _reduce_window_shape_rule, _input_dtype, 'reduce_window',
-    _reduce_window_translation_rule, reduction=True)
+    _reduce_window_translation_rule)
batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule
@@ -3597,9 +3604,9 @@ def _select_and_scatter_translation(
  return c.SelectAndScatter(operand, select, window_dimensions, window_strides,
                            padding, source, init_value, scatter)
-select_and_scatter_p = standard_primitive(
+select_and_scatter_p = standard_reduction_primitive(
    _select_and_scatter_shape_rule, _input_dtype, 'select_and_scatter',
-    _select_and_scatter_translation, reduction=True)
+    _select_and_scatter_translation)
def _select_and_scatter_add_shape_rule(