Fix a bug caused by the input_shape kwarg not being updated in the batching rule for reducers. Fixes b/120595235

This commit is contained in:
Dougal Maclaurin 2018-12-06 22:45:49 -05:00
parent 61aa47a1a8
commit ae7df43e9b
2 changed files with 4 additions and 0 deletions

View File

@ -224,6 +224,8 @@ def reducer_batcher(prim, batched_args, batch_dims, axes, **kwargs):
bdim, = batch_dims
axes = tuple(onp.where(onp.less(axes, bdim), axes, onp.add(axes, 1)))
bdim_out = list(onp.delete(onp.arange(operand.ndim), axes)).index(bdim)
if 'input_shape' in kwargs:
kwargs['input_shape'] = operand.shape
return prim.bind(operand, axes=axes, **kwargs), bdim_out
def add_batched(batched_args, batch_dims):

View File

@ -1790,6 +1790,8 @@ batching.defreducer(reduce_p)
def reduce_sum_shape_rule(operand, axes, input_shape):
  """Shape rule for reduce-sum: drop the reduced axes from the operand shape.

  Args:
    operand: abstract value being reduced; only its `.shape` is read.
    axes: tuple of axis indices that the reduction removes.
    input_shape: expected shape of `operand`, carried as a kwarg by the
      primitive; must agree with `operand.shape`.

  Returns:
    A tuple: `operand.shape` with the entries at `axes` deleted.
  """
  # input_shape is redundant with operand.shape; check the two stay in sync
  # (this is the invariant the batching-rule fix above restores).
  mismatch_msg = '{} != {}'.format(operand.shape, input_shape)
  assert operand.shape == input_shape, mismatch_msg
  result_shape = onp.delete(operand.shape, axes)
  return tuple(result_shape)
def reduce_sum_translation_rule(c, operand, axes, input_shape):