Fix batching rule.

This commit is contained in:
Peter Hawkins 2019-08-15 12:24:38 -04:00
parent e4a7d30741
commit e57a5c42c5

View File

@ -2744,8 +2744,10 @@ def _batch_dynamic_slice_indices(indices, bdims):
if size < 0:
return concatenate([reshape(i, [1]) for i in indices], 0), None
indices = concatenate(
[broadcast_in_dim(x, (size, 1), broadcast_dimensions=(0,))
for x in indices], dimension=1)
[broadcast_in_dim(x, (size, 1),
broadcast_dimensions=((0,) if i is not None else ()))
for x, i in zip(indices, bdims)],
dimension=1)
return indices, 0
def _dynamic_slice_batching_rule(batched_args, batch_dims, slice_sizes,