mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Reexpose reduce_window_shape_tuple since it has external users.
Fix accidental removal of rev() batching rule.
This commit is contained in:
parent
09201c72bc
commit
5517347cc9
13
jax/lax.py
13
jax/lax.py
@ -1721,8 +1721,15 @@ def _rev_shape_rule(operand, dimensions):
|
||||
raise TypeError(msg.format(dimensions, operand.ndim))
|
||||
return operand.shape
|
||||
|
||||
def _rev_batch_rule(batched_args, batch_dims, dimensions):
|
||||
operand, = batched_args
|
||||
bdim, = batch_dims
|
||||
new_dimensions = [i + 1 if i >= bdim else i for i in dimensions]
|
||||
return rev(operand, new_dimensions), bdim
|
||||
|
||||
rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev')
|
||||
ad.deflinear(rev_p, lambda t, dimensions: [rev(t, dimensions)])
|
||||
batching.primitive_batchers[rev_p] = _rev_batch_rule
|
||||
|
||||
|
||||
def _transpose_shape_rule(operand, permutation):
|
||||
@ -2451,11 +2458,11 @@ def _common_reduce_window_shape_rule(operand, window_dimensions, window_strides,
|
||||
"window_dimensions: got window_strides {} and window_dimensions {}.")
|
||||
raise TypeError(msg.format(window_strides, window_dimensions))
|
||||
|
||||
return _reduce_window_shape_tuple(operand.shape, window_dimensions,
|
||||
return reduce_window_shape_tuple(operand.shape, window_dimensions,
|
||||
window_strides, padding)
|
||||
|
||||
def _reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
|
||||
padding):
|
||||
def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
|
||||
padding):
|
||||
pads = padtype_to_pads(operand_shape, window_dimensions, window_strides, padding)
|
||||
operand_padded = onp.add(operand_shape, onp.add(*zip(*pads)))
|
||||
t = onp.floor_divide(
|
||||
|
Loading…
x
Reference in New Issue
Block a user