Fix batching rule.

This commit is contained in:
Peter Hawkins 2019-08-15 11:42:08 -04:00
parent d09924f71c
commit e4a7d30741

View File

@ -2740,12 +2740,12 @@ def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes=None,
[ad_util.zero] * len(start_indices))
def _batch_dynamic_slice_indices(indices, bdims):
size = next((x.shape[i] for x, i in zip(indices, bdims)), -1)
if size >= 0:
return concatenate([reshape(i, [1]) for i in start_indices], 0), None
size = next((x.shape[i] for x, i in zip(indices, bdims) if i is not None), -1)
if size < 0:
return concatenate([reshape(i, [1]) for i in indices], 0), None
indices = concatenate(
[broadcast_to(x, (size, 1), broadcast_dims=(0,)) for x in indices],
dimension=1)
[broadcast_in_dim(x, (size, 1), broadcast_dimensions=(0,))
for x in indices], dimension=1)
return indices, 0
def _dynamic_slice_batching_rule(batched_args, batch_dims, slice_sizes,