mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix batching rule.
This commit is contained in:
parent
e4a7d30741
commit
e57a5c42c5
@ -2744,8 +2744,10 @@ def _batch_dynamic_slice_indices(indices, bdims):
|
||||
if size < 0:
|
||||
return concatenate([reshape(i, [1]) for i in indices], 0), None
|
||||
indices = concatenate(
|
||||
[broadcast_in_dim(x, (size, 1), broadcast_dimensions=(0,))
|
||||
for x in indices], dimension=1)
|
||||
[broadcast_in_dim(x, (size, 1),
|
||||
broadcast_dimensions=((0,) if i is not None else ()))
|
||||
for x, i in zip(indices, bdims)],
|
||||
dimension=1)
|
||||
return indices, 0
|
||||
|
||||
def _dynamic_slice_batching_rule(batched_args, batch_dims, slice_sizes,
|
||||
|
Loading…
x
Reference in New Issue
Block a user