Reexpose reduce_window_shape_tuple since it has external users.

Fix accidental removal of rev() batching rule.
This commit is contained in:
Peter Hawkins 2019-02-01 16:29:53 -05:00
parent 09201c72bc
commit 5517347cc9

View File

@ -1721,8 +1721,15 @@ def _rev_shape_rule(operand, dimensions):
raise TypeError(msg.format(dimensions, operand.ndim))
return operand.shape
def _rev_batch_rule(batched_args, batch_dims, dimensions):
operand, = batched_args
bdim, = batch_dims
new_dimensions = [i + 1 if i >= bdim else i for i in dimensions]
return rev(operand, new_dimensions), bdim
rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev')
ad.deflinear(rev_p, lambda t, dimensions: [rev(t, dimensions)])
batching.primitive_batchers[rev_p] = _rev_batch_rule
def _transpose_shape_rule(operand, permutation):
@ -2451,11 +2458,11 @@ def _common_reduce_window_shape_rule(operand, window_dimensions, window_strides,
"window_dimensions: got window_strides {} and window_dimensions {}.")
raise TypeError(msg.format(window_strides, window_dimensions))
return _reduce_window_shape_tuple(operand.shape, window_dimensions,
return reduce_window_shape_tuple(operand.shape, window_dimensions,
window_strides, padding)
def _reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
padding):
def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
padding):
pads = padtype_to_pads(operand_shape, window_dimensions, window_strides, padding)
operand_padded = onp.add(operand_shape, onp.add(*zip(*pads)))
t = onp.floor_divide(