mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fixed bug due to input_shape kwarg not being modified in batching rule for reducers. Fixes b/120595235
This commit is contained in:
parent
61aa47a1a8
commit
ae7df43e9b
@ -224,6 +224,8 @@ def reducer_batcher(prim, batched_args, batch_dims, axes, **kwargs):
|
||||
bdim, = batch_dims
|
||||
axes = tuple(onp.where(onp.less(axes, bdim), axes, onp.add(axes, 1)))
|
||||
bdim_out = list(onp.delete(onp.arange(operand.ndim), axes)).index(bdim)
|
||||
if 'input_shape' in kwargs:
|
||||
kwargs['input_shape'] = operand.shape
|
||||
return prim.bind(operand, axes=axes, **kwargs), bdim_out
|
||||
|
||||
def add_batched(batched_args, batch_dims):
|
||||
|
@ -1790,6 +1790,8 @@ batching.defreducer(reduce_p)
|
||||
|
||||
|
||||
def reduce_sum_shape_rule(operand, axes, input_shape):
|
||||
assert operand.shape == input_shape, ('{} != {}'
|
||||
.format(operand.shape, input_shape))
|
||||
return tuple(onp.delete(operand.shape, axes))
|
||||
|
||||
def reduce_sum_translation_rule(c, operand, axes, input_shape):
|
||||
|
Loading…
x
Reference in New Issue
Block a user