Mirror of https://github.com/ROCm/jax.git
Make op-by-op work with all jit-returned devicearrays.

commit f01fc35ce5 (parent 7529815614)
@@ -105,6 +105,8 @@ def jit(fun, static_argnums=(), device_assignment=None, backend=None):
       change. Optional, an int specifying the device ordinal for which to compile the
       function. The default is inherited from XLA's DeviceAssignment logic and is
       usually to use device 0.
+    backend: This is an experimental feature and the API is likely to change.
+      Optional, a string representing the xla backend. 'cpu', 'gpu', or 'tpu'.
 
   Returns:
     A wrapped version of `fun`, set up for just-in-time compilation.
@@ -210,6 +212,8 @@ def xla_computation(fun, static_argnums=(), axis_env=None, backend=None):
       functions that involve parallel communication collectives, and it
       specifies the axis name/size environment that would be set up by
       applications of ``jax.pmap``. See the examples below.
+    backend: This is an experimental feature and the API is likely to change.
+      Optional, a string representing the xla backend. 'cpu', 'gpu', or 'tpu'.
 
   Returns:
     A wrapped version of ``fun`` that when applied to example arguments returns a
@@ -649,6 +653,8 @@ def pmap(fun, axis_name=None, backend=None):
     fun: Function to be mapped over argument axes.
     axis_name: Optional, a hashable Python object used to identify the mapped
       axis so that parallel collectives can be applied.
+    backend: This is an experimental feature and the API is likely to change.
+      Optional, a string representing the xla backend. 'cpu', 'gpu', or 'tpu'.
 
   Returns:
     A parallelized version of ``fun`` with arguments that correspond to those of
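For context, a minimal sketch of how the backend argument documented in these three docstrings can be used (the function and values are illustrative, and a CPU backend is assumed to be available):

    import jax
    import jax.numpy as jnp

    def f(x):
      return x * 2.0 + 1.0

    # Pin compilation to the CPU backend; 'gpu' or 'tpu' work the same way
    # on machines where those backends are available.
    f_cpu = jax.jit(f, backend='cpu')
    print(f_cpu(jnp.arange(4.0)))  # [1. 3. 5. 7.]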
@@ -598,8 +598,8 @@ xla_pmap = partial(core.call_bind, xla_pmap_p)
 xla_pmap_p.def_custom_bind(xla_pmap)
 xla_pmap_p.def_impl(xla_pmap_impl)
 
-def _xla_pmap_translation_rule(c, jaxpr, backend, axis_env, env_nodes, in_nodes,
-                               axis_name, axis_size):
+def _xla_pmap_translation_rule(c, jaxpr, axis_env, env_nodes, in_nodes,
+                               axis_name, axis_size, backend=None):
   new_env = xla.extend_axis_env(axis_env, axis_name, axis_size)
   in_nodes_sharded = list(map(partial(xla_shard, c, new_env.sizes), in_nodes))
   subc = xla._jaxpr_computation(jaxpr, backend, new_env, (),
@@ -145,6 +145,7 @@ def device_put(x, device_num=0, backend=None):
       DeviceTuple, and arraylike includes DeviceArray, DeviceConstant, and
       anything that has an '__array__' attr.
     device_num: an int representing the target physical device number.
+    backend: a string representing the xla backend ('cpu', 'gpu', or 'tpu').
 
   Returns:
     A buffer representing the input `x` placed on the appropriate device.
@@ -152,10 +153,20 @@ def device_put(x, device_num=0, backend=None):
   x = _canonicalize_pyval_dtype(x)
   t = type(x)
   if t is DeviceArray or t is DeviceTuple:
-    if x.device_buffer.device() == device_num:
-      return x.device_buffer
+    # TODO(levskaya) remove if-condition after increasing minimum Jaxlib version to
+    # 0.1.24.
+    if hasattr(x.device_buffer, 'platform'):
+      backend_match = x.device_buffer.platform() == backend
     else:
-      return x.device_buffer.copy_to_device(device_num)
+      backend_match = True
+    if backend_match:
+      if x.device_buffer.device() == device_num:
+        return x.device_buffer
+      else:
+        return x.device_buffer.copy_to_device(device_num)
+    else:
+      # Buffers from different XLA backends are passed through the host.
+      return xla_client.Buffer.from_pyval(x, device_num, backend=xb.get_backend(backend))
   elif isinstance(x, DeviceConstant):
     return _instantiate_device_constant(x, device_num=device_num, backend=backend)
   elif isinstance(x, (DeviceArray, onp.ndarray)):
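The branch above is easier to follow pulled out as a self-contained model; route_device_put and its string results are illustrative stand-ins, not part of the diff. Two details drive the shape of the code: jaxlib builds before 0.1.24 have no Buffer.platform(), so the backend check is skipped for them, and a buffer owned by a different XLA backend cannot be copied device-to-device, so it is rebuilt from a host-side value instead.

    def route_device_put(buffer_platform, buffer_device, target_device, backend):
      # buffer_platform is None when jaxlib predates 0.1.24 and
      # Buffer.platform() does not exist; backends are then assumed to match.
      backend_match = True if buffer_platform is None else buffer_platform == backend
      if backend_match:
        if buffer_device == target_device:
          return 'reuse buffer'    # already in the right place
        return 'copy_to_device'    # same backend, different device
      # Buffers from different XLA backends are passed through the host.
      return 'from_pyval via host'

    assert route_device_put('gpu', 0, 0, 'gpu') == 'reuse buffer'
    assert route_device_put('gpu', 0, 1, 'gpu') == 'copy_to_device'
    assert route_device_put('gpu', 0, 0, 'cpu') == 'from_pyval via host'
    assert route_device_put(None, 0, 1, 'cpu') == 'copy_to_device'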
@@ -281,7 +292,7 @@ def _jaxpr_computation(jaxpr, backend, axis_env, const_vals, freevar_shapes, *ar
     (subjaxpr, const_bindings, freevar_bindings), = eqn.bound_subjaxprs
     env_nodes = list(map(read, const_bindings + freevar_bindings))
     rule = call_translations[eqn.primitive]
-    ans = rule(c, subjaxpr, backend, axis_env, env_nodes, in_nodes, **eqn.params)
+    ans = rule(c, subjaxpr, axis_env, env_nodes, in_nodes, backend=backend, **eqn.params)
   else:
     msg = "XLA translation rule for primitive '{}' not found"
     raise NotImplementedError(msg.format(eqn.primitive.name))
@@ -718,8 +729,8 @@ xla_call = partial(core.call_bind, xla_call_p)
 xla_call_p.def_custom_bind(xla_call)
 xla_call_p.def_impl(_xla_call_impl)
 
-def _xla_call_translation_rule(c, jaxpr, backend, axis_env, env_nodes, in_nodes,
-                               device_assignment):
+def _xla_call_translation_rule(c, jaxpr, axis_env, env_nodes, in_nodes,
+                               device_assignment=None, backend=None):
   del device_assignment  # Unused.
   subc = _jaxpr_computation(jaxpr, backend, axis_env, (), _map(c.GetShape, env_nodes),
                             *map(c.GetShape, in_nodes))
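The two signature changes above match the new call site in _jaxpr_computation, which invokes every registered rule as rule(c, subjaxpr, axis_env, env_nodes, in_nodes, backend=backend, **eqn.params). A hedged toy (fake_call_rule and the literal arguments are made up) of why the trailing backend=None default matters: primitive-specific parameters arrive by name from eqn.params, so backend has to be a keyword argument to avoid positional collisions.

    def fake_call_rule(c, subjaxpr, axis_env, env_nodes, in_nodes,
                       device_assignment=None, backend=None):
      # A real rule would build and call a subcomputation; just echo the backend.
      return ('compiled-for', backend)

    call_translations = {'xla_call': fake_call_rule}
    params = {'device_assignment': None}  # stands in for eqn.params
    rule = call_translations['xla_call']
    print(rule('c', 'jaxpr', 'env', [], [], backend='cpu', **params))
    # ('compiled-for', 'cpu')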
@@ -1445,16 +1445,22 @@ _input_dtype = lambda *args, **_: xla_bridge.canonicalize_dtype(args[0].dtype)
 _fixed_dtype = lambda dtype: lambda *args, **kwargs: xla_bridge.canonicalize_dtype(dtype)
 _complex_basetype = lambda dtype: onp.abs(onp.zeros((), dtype)).dtype
 
-def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None, reduction=False):
+def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
   prim = Primitive(name)
   prim.def_impl(partial(xla.apply_primitive, prim))
   prim.def_abstract_eval(partial(standard_abstract_eval, shape_rule, dtype_rule))
-  if reduction:
-    xla.reduction_translations[prim] = translation_rule or partial(standard_translate, name)
-  else:
-    xla.translations[prim] = translation_rule or partial(standard_translate, name)
+  xla.translations[prim] = translation_rule or partial(standard_translate, name)
   return prim
 
 
+def standard_reduction_primitive(shape_rule, dtype_rule, name, translation_rule=None):
+  prim = Primitive(name)
+  prim.def_impl(partial(xla.apply_primitive, prim))
+  prim.def_abstract_eval(partial(standard_abstract_eval, shape_rule, dtype_rule))
+  xla.reduction_translations[prim] = translation_rule or partial(standard_translate, name)
+  return prim
+
+
 def standard_abstract_eval(shape_rule, dtype_rule, *args, **kwargs):
   assert all(isinstance(arg, UnshapedArray) for arg in args), args
   least_specialized = _max(
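A self-contained toy of this refactor (the two dicts stand in for xla.translations and xla.reduction_translations): the reduction=... flag on a single factory becomes two factories that write into separate registries, so each call site states the kind of primitive directly instead of threading a boolean.

    translations = {}            # stands in for xla.translations
    reduction_translations = {}  # stands in for xla.reduction_translations

    def standard_primitive(name, translation_rule):
      translations[name] = translation_rule
      return name

    def standard_reduction_primitive(name, translation_rule):
      reduction_translations[name] = translation_rule
      return name

    add_p = standard_primitive('add', lambda c, *args: 'Add')
    scatter_add_p = standard_reduction_primitive(
        'scatter-add', lambda c, *args: 'ScatterAdd')
    assert add_p in translations and add_p not in reduction_translations
    assert scatter_add_p in reduction_translations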
@@ -3144,25 +3150,25 @@ def _scatter_batching_rule(
       scatter_dims_to_operand_dims=scatter_dims_to_operand_dims)
   return scatter_op(operand, scatter_indices, updates, dnums), 0
 
-scatter_add_p = standard_primitive(
+scatter_add_p = standard_reduction_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add',
-    _scatter_translation_rule, reduction=True)
+    _scatter_translation_rule)
 ad.primitive_jvps[scatter_add_p] = _scatter_add_jvp
 ad.primitive_transposes[scatter_add_p] = _scatter_add_transpose_rule
 batching.primitive_batchers[scatter_add_p] = (
     partial(_scatter_batching_rule, scatter_add))
 
 # TODO(jlebar): Add derivatives.
-scatter_min_p = standard_primitive(
+scatter_min_p = standard_reduction_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-min',
-    _scatter_translation_rule, reduction=True)
+    _scatter_translation_rule)
 batching.primitive_batchers[scatter_min_p] = (
     partial(_scatter_batching_rule, scatter_min))
 
 # TODO(jlebar): Add derivatives.
-scatter_max_p = standard_primitive(
+scatter_max_p = standard_reduction_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-max',
-    _scatter_translation_rule, reduction=True)
+    _scatter_translation_rule)
 batching.primitive_batchers[scatter_max_p] = (
     partial(_scatter_batching_rule, scatter_max))
@@ -3260,9 +3266,9 @@ def _scatter_jvp(primals, tangents, update_jaxpr, update_consts,
   return val_out, tangent_out
 
 
-scatter_p = standard_primitive(
+scatter_p = standard_reduction_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter',
-    _scatter_translation_rule, reduction=True)
+    _scatter_translation_rule)
 ad.primitive_jvps[scatter_p] = _scatter_jvp
 batching.primitive_batchers[scatter_p] = (
     partial(_scatter_batching_rule, scatter))
@@ -3291,8 +3297,9 @@ def _reduction_computation(c, jaxpr, backend, consts, init_value):
   shape = c.GetShape(init_value)
   return xla.jaxpr_computation(jaxpr, backend, consts, (), shape, shape)
 
-reduce_p = standard_primitive(_reduce_shape_rule, _input_dtype, 'reduce',
-                              _reduce_translation_rule, reduction=True)
+reduce_p = standard_reduction_primitive(
+    _reduce_shape_rule, _input_dtype, 'reduce',
+    _reduce_translation_rule)
 batching.primitive_batchers[reduce_p] = _reduce_batch_rule
@@ -3458,9 +3465,9 @@ def _generic_reduce_window_batch_rule(
       window_dimensions, window_strides, padding)
 
 
-reduce_window_p = standard_primitive(
+reduce_window_p = standard_reduction_primitive(
     _reduce_window_shape_rule, _input_dtype, 'reduce_window',
-    _reduce_window_translation_rule, reduction=True)
+    _reduce_window_translation_rule)
 batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule
@@ -3597,9 +3604,9 @@ def _select_and_scatter_translation(
   return c.SelectAndScatter(operand, select, window_dimensions, window_strides,
                             padding, source, init_value, scatter)
 
-select_and_scatter_p = standard_primitive(
+select_and_scatter_p = standard_reduction_primitive(
     _select_and_scatter_shape_rule, _input_dtype, 'select_and_scatter',
-    _select_and_scatter_translation, reduction=True)
+    _select_and_scatter_translation)
 
 
 def _select_and_scatter_add_shape_rule(
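Taken together, these changes produce the behavior in the commit title: a DeviceArray returned by a backend-pinned jit call can flow directly into op-by-op execution, with mismatched-backend buffers round-tripped through the host by device_put. A minimal sketch (assumes a CPU-only install; values illustrative):

    import jax
    import jax.numpy as jnp

    f = jax.jit(lambda x: x + 1.0, backend='cpu')
    y = f(jnp.ones(3))  # a DeviceArray produced by the jitted call
    z = y * 2.0         # op-by-op execution consumes it directly
    print(z)            # [4. 4. 4.]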