mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Implements dynamic_update_slice when enable_xla=False and adds tests.
PiperOrigin-RevId: 380781017
This commit is contained in:
parent
e50276019d
commit
1bba5d7f10
@ -2387,6 +2387,46 @@ def _dynamic_slice(operand, *start_indices, slice_sizes,
|
||||
tf_impl_with_avals[lax.dynamic_slice_p] = _dynamic_slice
|
||||
|
||||
|
||||
def _dynamic_update_slice(operand, update, *start_indices,
|
||||
_in_avals: Sequence[core.ShapedArray],
|
||||
_out_aval: core.ShapedArray):
|
||||
start_indices = tf.stack(start_indices)
|
||||
|
||||
if _thread_local_state.enable_xla:
|
||||
return tfxla.dynamic_update_slice(operand, update, start_indices)
|
||||
|
||||
# enable_xla==False.
|
||||
|
||||
op_shape = _eval_shape(_in_avals[0].shape)
|
||||
op_size = tf.size(operand)
|
||||
update_shape = _eval_shape(_in_avals[1].shape)
|
||||
|
||||
start_indices = _clip(op_shape, start_indices, update_shape)
|
||||
end_indices = tf.add(start_indices, update_shape)
|
||||
flatten = tf.keras.backend.flatten
|
||||
|
||||
# Get the cells to update in `operand` as an array of ids.
|
||||
id_tensor = tf.reshape(tf.range(op_size), op_shape)
|
||||
scattered_indices = tf.strided_slice(id_tensor, start_indices, end_indices)
|
||||
|
||||
# Create an array containing updates at scattered_indices and zeros otherwise.
|
||||
flat_indices = tf.expand_dims(flatten(scattered_indices), -1)
|
||||
flat_update = flatten(update)
|
||||
update = tf.scatter_nd(flat_indices, flat_update, (op_size,))
|
||||
update = tf.reshape(update, op_shape)
|
||||
|
||||
# Create a bool mask that is True only where `operand` should be updated.
|
||||
update_mask = tf.ones_like(flat_update, dtype=tf.bool)
|
||||
update_mask = tf.scatter_nd(flat_indices, update_mask, (op_size,))
|
||||
update_mask = tf.reshape(update_mask, op_shape)
|
||||
|
||||
# Use the mask to only update `operand` with `update`.
|
||||
return tf.where(update_mask, update, operand)
|
||||
|
||||
|
||||
tf_impl_with_avals[lax.dynamic_update_slice_p] = _dynamic_update_slice
|
||||
|
||||
|
||||
def _scatter_dimensions_proto(indices_shape, dimension_numbers):
|
||||
proto = xla_data_pb2.ScatterDimensionNumbers()
|
||||
proto.update_window_dims.extend(dimension_numbers.update_window_dims)
|
||||
@ -2438,15 +2478,6 @@ tf_impl_with_avals[lax.scatter_mul_p] = _scatter
|
||||
tf_impl_with_avals[lax.scatter_add_p] = _scatter
|
||||
|
||||
|
||||
def _dynamic_update_slice(operand, update, *start_indices):
|
||||
if not _thread_local_state.enable_xla:
|
||||
raise _xla_disabled_error("dynamic_update_slice")
|
||||
return tfxla.dynamic_update_slice(operand, update, tf.stack(start_indices))
|
||||
|
||||
|
||||
tf_impl[lax.dynamic_update_slice_p] = _dynamic_update_slice
|
||||
|
||||
|
||||
def _cond(index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr],
|
||||
linear: Sequence[bool]) -> Sequence[TfVal]:
|
||||
del linear
|
||||
|
@ -2004,22 +2004,24 @@ def _make_dynamic_update_slice_harness(name,
|
||||
start_indices=(1,),
|
||||
dtype=np.float32,
|
||||
update_shape=(1,)):
|
||||
define(
|
||||
lax.dynamic_update_slice_p,
|
||||
(
|
||||
f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}" # type: ignore
|
||||
f"_update={jtu.format_shape_dtype_string(update_shape, dtype)}"
|
||||
f"_start_indices={start_indices}"),
|
||||
lax.dynamic_update_slice,
|
||||
[
|
||||
RandArg(shape, dtype), # type: ignore
|
||||
RandArg(update_shape, dtype), # type: ignore
|
||||
np.array(start_indices)
|
||||
], # type: ignore
|
||||
dtype=dtype,
|
||||
shape=shape, # type: ignore
|
||||
start_indices=start_indices, # type: ignore
|
||||
update_shape=update_shape) # type: ignore
|
||||
for enable_xla in [False, True]:
|
||||
define(
|
||||
lax.dynamic_update_slice_p,
|
||||
(
|
||||
f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}" # type: ignore
|
||||
f"_update={jtu.format_shape_dtype_string(update_shape, dtype)}"
|
||||
f"_start_indices={start_indices}_enablexla={enable_xla}"),
|
||||
lax.dynamic_update_slice,
|
||||
[
|
||||
RandArg(shape, dtype), # type: ignore
|
||||
RandArg(update_shape, dtype), # type: ignore
|
||||
np.array(start_indices)
|
||||
], # type: ignore
|
||||
dtype=dtype,
|
||||
shape=shape, # type: ignore
|
||||
start_indices=start_indices, # type: ignore
|
||||
update_shape=update_shape, # type: ignore
|
||||
enable_xla=enable_xla)
|
||||
|
||||
|
||||
# Test first all dtypes
|
||||
|
Loading…
x
Reference in New Issue
Block a user