Implements dynamic_update_slice when enable_xla=False and adds tests.

PiperOrigin-RevId: 380781017
Author: Marc van Zee
Date:   2021-06-22 04:42:52 -07:00 (committed by jax authors)
Parent: e50276019d
Commit: 1bba5d7f10
2 changed files with 58 additions and 25 deletions

jax/experimental/jax2tf/jax2tf.py

@@ -2387,6 +2387,46 @@ def _dynamic_slice(operand, *start_indices, slice_sizes,
tf_impl_with_avals[lax.dynamic_slice_p] = _dynamic_slice
def _dynamic_update_slice(operand, update, *start_indices,
                          _in_avals: Sequence[core.ShapedArray],
                          _out_aval: core.ShapedArray):
  start_indices = tf.stack(start_indices)

  if _thread_local_state.enable_xla:
    return tfxla.dynamic_update_slice(operand, update, start_indices)

  # enable_xla == False.
  op_shape = _eval_shape(_in_avals[0].shape)
  op_size = tf.size(operand)
  update_shape = _eval_shape(_in_avals[1].shape)
  start_indices = _clip(op_shape, start_indices, update_shape)
  end_indices = tf.add(start_indices, update_shape)
  flatten = tf.keras.backend.flatten

  # Get the cells to update in `operand` as an array of ids.
  id_tensor = tf.reshape(tf.range(op_size), op_shape)
  scattered_indices = tf.strided_slice(id_tensor, start_indices, end_indices)

  # Create an array containing updates at scattered_indices and zeros
  # otherwise.
  flat_indices = tf.expand_dims(flatten(scattered_indices), -1)
  flat_update = flatten(update)
  update = tf.scatter_nd(flat_indices, flat_update, (op_size,))
  update = tf.reshape(update, op_shape)

  # Create a bool mask that is True only where `operand` should be updated.
  update_mask = tf.ones_like(flat_update, dtype=tf.bool)
  update_mask = tf.scatter_nd(flat_indices, update_mask, (op_size,))
  update_mask = tf.reshape(update_mask, op_shape)

  # Use the mask to only update `operand` with `update`.
  return tf.where(update_mask, update, operand)


tf_impl_with_avals[lax.dynamic_update_slice_p] = _dynamic_update_slice
def _scatter_dimensions_proto(indices_shape, dimension_numbers):
  proto = xla_data_pb2.ScatterDimensionNumbers()
  proto.update_window_dims.extend(dimension_numbers.update_window_dims)
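For reference, the scatter-and-mask technique in _dynamic_update_slice above can be reproduced outside jax2tf with plain TF ops. The following is a minimal standalone sketch, not the library code: the name dynamic_update_slice_no_xla is made up, and tf.shape / tf.clip_by_value stand in for the internal _eval_shape and _clip helpers.

import tensorflow as tf

def dynamic_update_slice_no_xla(operand, update, start_indices):
  # tf.shape/tf.size stand in for _eval_shape on the input avals.
  op_shape = tf.shape(operand)
  op_size = tf.size(operand)
  update_shape = tf.shape(update)
  # Clamp the start indices so the update window fits inside the operand
  # (the role of _clip above), matching lax.dynamic_update_slice semantics.
  start_indices = tf.clip_by_value(start_indices, 0, op_shape - update_shape)
  end_indices = start_indices + update_shape
  # Flat ids of the cells the update window covers.
  id_tensor = tf.reshape(tf.range(op_size), op_shape)
  window_ids = tf.strided_slice(id_tensor, start_indices, end_indices)
  flat_ids = tf.reshape(window_ids, (-1, 1))
  flat_update = tf.reshape(update, (-1,))
  # Scatter the update values and a boolean mask into operand-shaped tensors.
  scattered = tf.reshape(
      tf.scatter_nd(flat_ids, flat_update, (op_size,)), op_shape)
  mask = tf.reshape(
      tf.scatter_nd(flat_ids, tf.ones_like(flat_update, tf.bool), (op_size,)),
      op_shape)
  return tf.where(mask, scattered, operand)

operand = tf.reshape(tf.range(12.0), (3, 4))
update = tf.constant([[100.0, 101.0]])
print(dynamic_update_slice_no_xla(operand, update, tf.constant([1, 1])))
# [[  0.   1.   2.   3.]
#  [  4. 100. 101.   7.]
#  [  8.   9.  10.  11.]]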
@@ -2438,15 +2478,6 @@ tf_impl_with_avals[lax.scatter_mul_p] = _scatter
tf_impl_with_avals[lax.scatter_add_p] = _scatter
def _dynamic_update_slice(operand, update, *start_indices):
  if not _thread_local_state.enable_xla:
    raise _xla_disabled_error("dynamic_update_slice")
  return tfxla.dynamic_update_slice(operand, update, tf.stack(start_indices))


tf_impl[lax.dynamic_update_slice_p] = _dynamic_update_slice
def _cond(index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr],
          linear: Sequence[bool]) -> Sequence[TfVal]:
  del linear
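With this change, converting a JAX function that uses lax.dynamic_update_slice no longer raises _xla_disabled_error when XLA ops are disabled. A small usage sketch (jax2tf.convert with enable_xla=False is the public entry point; the shapes and values are arbitrary):

import jax.numpy as jnp
from jax import lax
from jax.experimental import jax2tf

def f(x, u):
  # Write u into x starting at position (1, 1).
  return lax.dynamic_update_slice(x, u, (1, 1))

# Before this commit this call raised _xla_disabled_error; it now lowers
# to the scatter-and-mask implementation above.
f_tf = jax2tf.convert(f, enable_xla=False)
print(f_tf(jnp.zeros((3, 4), jnp.float32), jnp.ones((1, 2), jnp.float32)))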

jax/experimental/jax2tf/tests/primitive_harness.py

@@ -2004,22 +2004,24 @@ def _make_dynamic_update_slice_harness(name,
                                       start_indices=(1,),
                                       dtype=np.float32,
                                       update_shape=(1,)):
  define(
      lax.dynamic_update_slice_p,
      (
          f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}"  # type: ignore
          f"_update={jtu.format_shape_dtype_string(update_shape, dtype)}"
          f"_start_indices={start_indices}"),
      lax.dynamic_update_slice,
      [
          RandArg(shape, dtype),  # type: ignore
          RandArg(update_shape, dtype),  # type: ignore
          np.array(start_indices)
      ],  # type: ignore
      dtype=dtype,
      shape=shape,  # type: ignore
      start_indices=start_indices,  # type: ignore
      update_shape=update_shape)  # type: ignore
  for enable_xla in [False, True]:
    define(
        lax.dynamic_update_slice_p,
        (
            f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}"  # type: ignore
            f"_update={jtu.format_shape_dtype_string(update_shape, dtype)}"
            f"_start_indices={start_indices}_enablexla={enable_xla}"),
        lax.dynamic_update_slice,
        [
            RandArg(shape, dtype),  # type: ignore
            RandArg(update_shape, dtype),  # type: ignore
            np.array(start_indices)
        ],  # type: ignore
        dtype=dtype,
        shape=shape,  # type: ignore
        start_indices=start_indices,  # type: ignore
        update_shape=update_shape,  # type: ignore
        enable_xla=enable_xla)
# Test first all dtypes
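Note that the loop runs each harness's start_indices through both enable_xla paths, so the _clip step in the new implementation must reproduce lax.dynamic_update_slice's clamping of out-of-bounds start indices. A quick illustration of that semantics in plain JAX:

import jax.numpy as jnp
from jax import lax

x = jnp.zeros((4,), jnp.float32)
u = jnp.ones((2,), jnp.float32)
# A start index of 3 would put the update out of bounds; it is clamped
# to 2, so the last two elements are overwritten.
print(lax.dynamic_update_slice(x, u, (3,)))  # [0. 0. 1. 1.]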