Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00

Add support for base dilation and window dilation to reduce window op… (#3803)

This commit is contained in:
parent ce14409025
commit a6e2d20b31
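For orientation, a minimal usage sketch of the extended lax.reduce_window signature this commit introduces. The array values, shapes, and dilation factors below are illustrative only and do not appear in the patch.

# Illustrative only: exercise the new base_dilation / window_dilation
# arguments of lax.reduce_window added in this commit.
import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.arange(16.0).reshape(4, 4)

# Windowed max as before this change (both dilation arguments default to 1s).
y = lax.reduce_window(x, -np.inf, lax.max,
                      window_dimensions=(2, 2), window_strides=(2, 2),
                      padding="VALID")

# Same reduction with a dilated window: window_dilation=(2, 2) spreads the
# 2x2 window over a 3x3 footprint; base_dilation dilates the operand itself.
z = lax.reduce_window(x, -np.inf, lax.max,
                      window_dimensions=(2, 2), window_strides=(1, 1),
                      padding="VALID",
                      base_dilation=(1, 1), window_dilation=(2, 2))
print(y.shape, z.shape)  # (2, 2) (2, 2)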
@@ -781,11 +781,14 @@ tf_impl[lax.cumprod_p] = tf.math.cumprod

def _reduce_window_shape(jax_f, operand, window_dimensions,
window_strides, padding, input_shape=None):
window_strides, padding, base_dilation,
window_dilation, input_shape=None):
"""Shape inference function for reduce_window_{sum,min,max}."""
params = dict(window_dimensions=window_dimensions,
window_strides=window_strides,
padding=padding,
base_dilation=base_dilation,
window_dilation=window_dilation,
input_shape=input_shape)
try:
out, = _infer_shape_jax(jax_f, operand, **params)

@@ -802,17 +805,20 @@ def _get_shape_from_tensor_or_array(x):

def _reduce_window(jax_f, reducer, init_val, operand, window_dimensions,
window_strides, padding, input_shape=None):
window_strides, padding, base_dilation, window_dilation,
input_shape=None):
"""TensorFlow implementation of reduce_window_{sum,min,max}."""
del input_shape
# TODO(tomhennigan): tf2xla should have a shape inference function.
out_shape = _reduce_window_shape(jax_f, operand, window_dimensions,
window_strides, padding)
window_strides, padding, base_dilation,
window_dilation)
a = tf.constant(0, operand.dtype)
reducer_fn = reducer.get_concrete_function(a, a)
out = tfxla.reduce_window(operand, tf.constant(init_val, operand.dtype),
reducer_fn, window_dimensions,
window_strides, padding=padding)
window_strides, base_dilations=base_dilation,
window_dilations=window_dilation, padding=padding)
out.set_shape(out_shape)
return out
# pylint: disable=protected-access
@@ -176,8 +176,9 @@ def _pooling_layer(reducer, init_val, rescaler=None):
def init_fun(rng, input_shape):
padding_vals = lax.padtype_to_pads(input_shape, window_shape,
strides, padding)
out_shape = lax.reduce_window_shape_tuple(input_shape, window_shape,
strides, padding_vals)
ones = (1,) * len(window_shape)
out_shape = lax.reduce_window_shape_tuple(
input_shape, window_shape, strides, padding_vals, ones, ones)
return out_shape, ()
def apply_fun(params, inputs, **kwargs):
out = lax.reduce_window(inputs, init_val, reducer, window_shape,
jax/lax/lax.py (221 changed lines)
@@ -1089,7 +1089,9 @@ def _reduce_and(operand: Array, axes: Sequence[int]) -> Array:

def reduce_window(operand: Array, init_value: Array, computation: Callable,
window_dimensions: Shape, window_strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]]) -> Array:
padding: Union[str, Sequence[Tuple[int, int]]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
"""Wraps XLA's `ReduceWindowWithGeneralPadding
<https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
operator.

@@ -1099,15 +1101,22 @@ def reduce_window(operand: Array, init_value: Array, computation: Callable,
window_strides, padding))
else:
padding = tuple(padding)
if base_dilation is None:
base_dilation = (1,) * len(window_dimensions)
if window_dilation is None:
window_dilation = (1,) * len(window_dimensions)
monoid_reducer = _get_monoid_window_reducer(computation, init_value)
if monoid_reducer:
return monoid_reducer(operand, window_dimensions, window_strides, padding)
return monoid_reducer(operand, window_dimensions, window_strides, padding,
base_dilation, window_dilation)
else:
jaxpr, consts = _reduction_jaxpr(computation, _abstractify(init_value))
return reduce_window_p.bind(
operand, init_value, jaxpr=jaxpr, consts=consts,
window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=padding)
window_strides=tuple(window_strides), padding=padding,
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))

def _get_monoid_window_reducer(monoid_op: Callable, x: Array) -> Optional[Callable]:
aval = core.get_aval(x)
@@ -1122,46 +1131,82 @@ def _get_monoid_window_reducer(monoid_op: Callable, x: Array) -> Optional[Callab

def _reduce_window_sum(operand: Array, window_dimensions: Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]]) -> Array:
padding: Sequence[Tuple[int, int]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
if base_dilation is None:
base_dilation = (1,) * len(window_dimensions)
if window_dilation is None:
window_dilation = (1,) * len(window_dimensions)
return reduce_window_sum_p.bind(
operand, window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=tuple(padding))
window_strides=tuple(window_strides), padding=tuple(padding),
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))

def _reduce_window_prod(operand: Array, window_dimensions: Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]]) -> Array:
padding: Sequence[Tuple[int, int]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
init_value = _const(operand, 1)
jaxpr, consts = _reduction_jaxpr(mul, _abstractify(init_value))
if base_dilation is None:
base_dilation = (1,) * len(window_dimensions)
if window_dilation is None:
window_dilation = (1,) * len(window_dimensions)
return reduce_window_p.bind(
operand, init_value, jaxpr=jaxpr, consts=consts,
window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=tuple(padding))
window_strides=tuple(window_strides), padding=tuple(padding),
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))

def _reduce_window_max(operand: Array, window_dimensions: Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]]) -> Array:
padding: Sequence[Tuple[int, int]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
if base_dilation is None:
base_dilation = (1,) * len(window_dimensions)
if window_dilation is None:
window_dilation = (1,) * len(window_dimensions)
return reduce_window_max_p.bind(
operand, window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=tuple(padding))
window_strides=tuple(window_strides), padding=tuple(padding),
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))

def _reduce_window_min(operand: Array, window_dimensions: Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]]) -> Array:
padding: Sequence[Tuple[int, int]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
if base_dilation is None:
base_dilation = (1,) * len(window_dimensions)
if window_dilation is None:
window_dilation = (1,) * len(window_dimensions)
return reduce_window_min_p.bind(
operand, window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=tuple(padding))
window_strides=tuple(window_strides), padding=tuple(padding),
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))

def _select_and_scatter(operand: Array, select: Callable,
window_dimensions: Shape, window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]], source: Array,
init_value: Array, scatter: Callable) -> Array:
init_value: Array, scatter: Callable,
base_dilation: Sequence[int],
window_dilation: Sequence[int]) -> Array:
select_jaxpr, select_consts = _reduction_jaxpr(select, _abstractify(init_value))
scatter_jaxpr, scatter_consts = _reduction_jaxpr(scatter, _abstractify(init_value))
return select_and_scatter_p.bind(
operand, source, init_value, select_jaxpr=select_jaxpr,
select_consts=select_consts, scatter_jaxpr=scatter_jaxpr,
scatter_consts=scatter_consts, window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=tuple(padding))
window_strides=tuple(window_strides), padding=tuple(padding),
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))

def _select_and_scatter_add(source: Array, operand: Array,
select_prim: core.Primitive,
@@ -1177,11 +1222,15 @@ def _select_and_gather_add(tangents: Array, operand: Array,
select_prim: core.Primitive,
window_dimensions: Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]]) -> Array:
padding: Sequence[Tuple[int, int]],
base_dilation: Sequence[int],
window_dilation: Sequence[int]) -> Array:
return select_and_gather_add_p.bind(
tangents, operand, select_prim=select_prim,
window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=tuple(padding))
window_strides=tuple(window_strides), padding=tuple(padding),
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))

def cumsum(operand: Array, axis: int) -> Array:
"""Computes a cumulative sum along `axis`."""
@@ -4490,38 +4539,43 @@ reduce_and_p = standard_primitive(_reduce_logical_shape_rule, _fixed_dtype(np.bo
batching.defreducer(reduce_and_p)

def _reduce_window_shape_rule(operand, init_value, *, jaxpr, consts,
window_dimensions, window_strides, padding):
window_dimensions, window_strides, padding,
base_dilation, window_dilation):
if operand.dtype != init_value.dtype:
msg = ("reduce_window got inconsistent dtypes for operand and init_value: "
" got operand dtype {} and init_value dtype {}.")
raise TypeError(msg.format(operand.dtype, init_value.dtype))
return _common_reduce_window_shape_rule(operand, window_dimensions,
window_strides, padding)
return _common_reduce_window_shape_rule(
operand, window_dimensions, window_strides, padding, base_dilation,
window_dilation)

def _reduce_window_translation_rule(c, operand, init_value, *, jaxpr, consts,
window_dimensions, window_strides, padding):
window_dimensions, window_strides, padding,
base_dilation, window_dilation):
xla_computation = _reduction_computation(c, jaxpr, consts, init_value)
return xops.ReduceWindowWithGeneralPadding(
operand, init_value, xla_computation, window_dimensions,
window_strides, (), (), padding)
window_strides, base_dilation, window_dilation, padding)

def _generic_reduce_window_batch_rule(
batched_args, batch_dims, *, jaxpr, consts, window_dimensions,
window_strides, padding):
window_strides, padding, base_dilation, window_dilation):
operand, init = batched_args
bdim, init_bdim = batch_dims
if init_bdim is not None:
raise NotImplementedError("reduce_window batching is not implemented for "
"initial values")

def reduce_window(x, window_dimensions, window_strides, padding):
def reduce_window(x, window_dimensions, window_strides, padding, base_dilation,
window_dilation):
return reduce_window_p.bind(
x, init, jaxpr=jaxpr, consts=consts, window_dimensions=window_dimensions,
window_strides=window_strides, padding=padding)
return _reduce_window_batch_rule(reduce_window, (operand,), (bdim,),
window_dimensions=window_dimensions,
window_strides=window_strides,
padding=padding)
window_strides=window_strides, padding=padding, base_dilation=base_dilation,
window_dilation=window_dilation)
return _reduce_window_batch_rule(
reduce_window, (operand,), (bdim,), window_dimensions=window_dimensions,
window_strides=window_strides, padding=padding, base_dilation=base_dilation,
window_dilation=window_dilation)

reduce_window_p = standard_primitive(
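A hedged sketch of the generic path that reduce_window_p (registered just above) serves: a reduction that is not one of the recognized monoid reducers (add, mul, max, min) lowers through this primitive and still accepts the new dilation parameters. The reducer and shapes below are illustrative, not from the patch.

# Illustrative only: a custom reduction takes the generic reduce_window_p
# route rather than the specialized sum/max/min primitives.
import jax.numpy as jnp
from jax import lax

def windowed_logsumexp(x):
  # jnp.logaddexp is not a recognized monoid reducer, so this traces the
  # computation into a jaxpr and binds reduce_window_p directly.
  init = jnp.array(-jnp.inf, dtype=x.dtype)
  return lax.reduce_window(x, init, jnp.logaddexp,
                           window_dimensions=(3,), window_strides=(1,),
                           padding="VALID", window_dilation=(2,))

print(windowed_logsumexp(jnp.arange(8.0)).shape)  # (4,): dilated window spans 5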
@@ -4531,40 +4585,46 @@ batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule

def _reduce_window_sum_shape_rule(operand, *, window_dimensions, window_strides,
padding):
padding, base_dilation, window_dilation):
if not dtypes.issubdtype(operand.dtype, np.number):
msg = "operand to reduce_window_sum must have a number dtype, got {}"
raise TypeError(msg.format(np.dtype(operand.dtype).name))
return _common_reduce_window_shape_rule(operand, window_dimensions,
window_strides, padding)
window_strides, padding, base_dilation,
window_dilation)

def _reduce_window_sum_translation_rule(c, operand, *, window_dimensions,
window_strides, padding):
window_strides, padding, base_dilation,
window_dilation):
dtype = c.get_shape(operand).numpy_dtype()
scalar = ShapedArray((), dtype)
return xops.ReduceWindowWithGeneralPadding(
operand, xb.constant(c, np.array(0, dtype)),
xla.primitive_subcomputation(add_p, scalar, scalar), window_dimensions,
window_strides, (), (), padding)
window_strides, base_dilation, window_dilation, padding)

def _reduce_window_sum_transpose_rule(cotangent, operand, *, window_dimensions,
window_strides, padding):
window_strides, padding, base_dilation,
window_dilation):
assert ad.is_undefined_primal(operand)
input_shape = operand.aval.shape
ones = [1] * len(input_shape)
pads = _conv_general_vjp_lhs_padding(
input_shape, window_dimensions, window_strides, cotangent.shape, padding,
ones, ones)
base_dilation, window_dilation)
ones = [1] * len(input_shape)
padding_config = [(lo, hi, stride - 1)
for (lo, hi), stride in zip(pads, window_strides)]
pad_cotangent = pad(cotangent, _zero(cotangent), padding_config)
result = _reduce_window_sum(pad_cotangent, window_dimensions, ones,
[(0, 0)] * len(input_shape))
assert result.shape == input_shape
result = _reduce_window_sum(pad_cotangent, window_dimensions, base_dilation,
[(0, 0)] * len(input_shape),
base_dilation=ones,
window_dilation=window_dilation)
assert result.shape == input_shape, (result.shape, input_shape)
return [result]

def _reduce_window_batch_rule(reduce_window, batched_args, bdims, *,
window_dimensions, window_strides, padding):
window_dimensions, window_strides, padding,
base_dilation, window_dilation):
operand, = batched_args
bdim, = bdims

@@ -4573,7 +4633,11 @@ def _reduce_window_batch_rule(reduce_window, batched_args, bdims, *,
window_dimensions[:bdim] + (1,) + window_dimensions[bdim:]
window_strides = window_strides[:bdim] + (1,) + window_strides[bdim:]
padding = padding[:bdim] + ((0, 0),) + padding[bdim:]
operand = reduce_window(operand, window_dimensions, window_strides, padding)
base_dilation = base_dilation[:bdim] + (1,) + base_dilation[bdim:]
window_dilation = window_dilation[:bdim] + (1,) + window_dilation[bdim:]

operand = reduce_window(operand, window_dimensions, window_strides, padding,
base_dilation, window_dilation)
return operand, bdim

reduce_window_sum_p = standard_primitive(
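The transpose rule above is what makes reverse-mode differentiation of dilated sum reductions possible; a small, hedged sanity check (not part of the patch, shapes and values chosen arbitrarily) might look like the following.

# Illustrative only: reverse-mode through a dilated reduce_window sum
# exercises _reduce_window_sum_transpose_rule.
import jax
import jax.numpy as jnp
from jax import lax

def summed_windows(x):
  return jnp.sum(lax.reduce_window(x, 0., lax.add,
                                   window_dimensions=(2, 2),
                                   window_strides=(1, 1),
                                   padding="VALID",
                                   window_dilation=(2, 2)))

g = jax.grad(summed_windows)(jnp.ones((5, 5)))
print(g.shape)  # (5, 5); each entry counts the dilated windows that cover it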
@@ -4584,26 +4648,32 @@ batching.primitive_batchers[reduce_window_sum_p] = partial(
_reduce_window_batch_rule, _reduce_window_sum)

def _reduce_window_chooser_translation_rule(
prim, identity, c, operand, *, window_dimensions, window_strides, padding):
prim, identity, c, operand, *, window_dimensions, window_strides, padding,
base_dilation, window_dilation):
dtype = c.get_shape(operand).numpy_dtype()
scalar = ShapedArray((), dtype)
return xops.ReduceWindowWithGeneralPadding(
operand, xb.constant(c, identity(dtype)),
xla.primitive_subcomputation(prim, scalar, scalar), window_dimensions,
window_strides, (), (), padding)
window_strides, base_dilation, window_dilation, padding)

def _reduce_window_chooser_jvp_rule(prim, g, operand, *, window_dimensions,
window_strides, padding):
window_strides, padding, base_dilation,
window_dilation):
assert prim is max_p or prim is min_p
select_prim = ge_p if prim is max_p else le_p
return _select_and_gather_add(g, operand, select_prim, window_dimensions,
window_strides, padding)
window_strides, padding, base_dilation,
window_dilation)

def _common_reduce_window_shape_rule(operand, window_dimensions,
window_strides, padding):
window_strides, padding, base_dilation,
window_dilation):
_check_shapelike("reduce_window", "window_dimensions", window_dimensions)
_check_shapelike("reduce_window", "window_strides", window_strides)
_check_shapelike("reduce_window", "base_dilation", base_dilation)
_check_shapelike("reduce_window", "window_dilation", window_dilation)
if operand.ndim != len(window_dimensions):
msg = ("reduce_window got the wrong number of window_dimensions for "
"operand: got operand shape {} with window_dimensions {}.")
@@ -4612,12 +4682,27 @@ def _common_reduce_window_shape_rule(operand, window_dimensions,
msg = ("reduce_window got inconsistent window_strides and "
"window_dimensions: got window_strides {} and window_dimensions {}.")
raise TypeError(msg.format(window_strides, window_dimensions))
if len(base_dilation) != len(window_dimensions):
msg = ("reduce_window got inconsistent base_dilation and "
"window_dimensions: got base_dilation {} and window_dimensions {}.")
raise TypeError(msg.format(base_dilation, window_dimensions))
if len(window_dilation) != len(window_dimensions):
msg = ("reduce_window got inconsistent window_dilation and "
"window_dimensions: got window_dilation {} and window_dimensions "
"{}.")
raise TypeError(msg.format(window_dilation, window_dimensions))

return reduce_window_shape_tuple(operand.shape, window_dimensions,
window_strides, padding)
window_strides, padding, base_dilation,
window_dilation)

def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
padding):
padding, base_dilation=None,
window_dilation=None):
if base_dilation is not None:
operand_shape = _dilate_shape(operand_shape, base_dilation)
if window_dilation is not None:
window_dimensions = _dilate_shape(window_dimensions, window_dilation)
operand_padded = np.add(operand_shape, np.add(*zip(*padding)))
t = np.floor_divide(
np.subtract(operand_padded, window_dimensions), window_strides) + 1
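Restating the shape rule above for reference (a hedged rederivation, not text from the patch): with base dilation b, a dimension of size n is dilated to 1 + b*(n-1); a window of size w with window dilation d covers 1 + d*(w-1) elements; the output size is floor((dilated_n + pad_lo + pad_hi - dilated_w) / stride) + 1. For example, n = 6 with b = 2 dilates to 11, a window of w = 2 with d = 3 spans 4, so with no padding and stride 1 the output has 11 - 4 + 1 = 8 elements.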
@@ -4708,8 +4793,9 @@ def _select_and_scatter_add_transpose(
t, source, operand, *, select_prim, window_dimensions, window_strides,
padding):
assert ad.is_undefined_primal(source) and not ad.is_undefined_primal(operand)
ones = (1,) * len(window_dimensions)
source_t = _select_and_gather_add(t, operand, select_prim, window_dimensions,
window_strides, padding)
window_strides, padding, ones, ones)
return [source_t, None]

def _select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs):

@@ -4753,13 +4839,14 @@ batching.primitive_batchers[select_and_scatter_add_p] = \

def _select_and_gather_add_shape_rule(
tangents, operand, *, select_prim, window_dimensions, window_strides,
padding):
padding, base_dilation, window_dilation):
if tangents.shape != operand.shape:
msg = ("select_and_gather_add tangents and operand shapes must match, "
"got {} and {}.")
raise TypeError(msg.format(tangents.shape, operand.shape))
return _common_reduce_window_shape_rule(operand, window_dimensions,
window_strides, padding)
return _common_reduce_window_shape_rule(
operand, window_dimensions, window_strides, padding, base_dilation,
window_dilation)

_UINT_DTYPES = {

@@ -4776,7 +4863,7 @@ _INT_DTYPES = {

def _select_and_gather_add_translation(
c, tangents, operand, *, select_prim, window_dimensions, window_strides,
padding, max_bits=64):
padding, base_dilation, window_dilation, max_bits=64):
shape = c.get_shape(operand)
dtype = shape.numpy_dtype()
etype = shape.xla_element_type()
@@ -4867,37 +4954,52 @@ def _select_and_gather_add_translation(
init = -np.inf if select_prim is ge_p else np.inf
out = xops.ReduceWindowWithGeneralPadding(
pack(operand, tangents), pack(const(c, dtype, init), const(c, dtype, 0)),
reducer(), window_dimensions, window_strides, (), (), padding)
reducer(), window_dimensions, window_strides, base_dilation,
window_dilation, padding)
return snd(out)

def _select_and_gather_add_jvp(
primals, tangents, *, select_prim, window_dimensions, window_strides,
padding):
padding, base_dilation, window_dilation):
source, operand = primals
g_source, g_operand = tangents
val_out = _select_and_gather_add(
source, operand, select_prim, window_dimensions, window_strides,
padding)
padding, base_dilation, window_dilation)
del g_operand
if type(g_source) is ad_util.Zero:
tangent_out = ad_util.Zero.from_value(val_out)
else:
tangent_out = _select_and_gather_add(
g_source, operand, select_prim, window_dimensions,
window_strides, padding)
window_strides, padding, base_dilation, window_dilation)
return val_out, tangent_out

def _select_and_gather_add_transpose(
t, tangents, operand, *, select_prim, window_dimensions, window_strides,
padding):
padding, base_dilation, window_dilation):
assert select_prim in (le_p, ge_p)
assert ad.is_undefined_primal(tangents) and not ad.is_undefined_primal(operand)
if any(d != 1 for d in window_dilation):
msg = ("VJP not implemented for select_and_gather (MaxPool) with window "
"dilation, got window_dilation={}.")
raise NotImplementedError(msg.format(window_dilation))
has_base_dilation = any(d != 1 for d in base_dilation)
if has_base_dilation:
select_identity = (_get_max_identity if select_prim is ge_p
else _get_min_identity)
operand = pad(operand, select_identity(operand.dtype),
tuple((0, 0, d - 1) for d in base_dilation))
result = _select_and_scatter_add(t, operand, select_prim, window_dimensions,
window_strides, padding)
if has_base_dilation:
result = slice(operand, (0,) * len(operand.shape), operand.shape,
base_dilation)
return [result, None]

def _select_and_gather_add_batching_rule(
batched_args, batch_dims, *, select_prim, window_dimensions, window_strides,
padding):
padding, base_dilation, window_dilation):
t, x = batched_args
t_bdim, x_bdim = batch_dims
size = next(a.shape[bdim] for a, bdim in zip(batched_args, batch_dims)

@@ -4907,8 +5009,11 @@ def _select_and_gather_add_batching_rule(
window_dimensions = (1,) + window_dimensions
window_strides = (1,) + window_strides
padding = ((0, 0),) + padding
base_dilation = (1,) + base_dilation
window_dilation = (1,) + window_dilation
out = _select_and_gather_add(t, x, select_prim, window_dimensions,
window_strides, padding)
window_strides, padding, base_dilation,
window_dilation)
return (out, 0)
@@ -287,13 +287,16 @@ def reduce(operand, init_value, computation, dimensions): # pylint: disable=red
return reducer(operand, tuple(dimensions)).astype(np.asarray(operand).dtype)

def reduce_window(operand, init_value, computation, window_dimensions,
window_strides, padding):
window_strides, padding, base_dilation):
op, dims, strides = operand, window_dimensions, window_strides
if isinstance(padding, str):
pads = padtype_to_pads(op.shape, dims, strides, padding)
else:
pads = padding
view = _conv_view(op.reshape((1, 1) + op.shape), (1, 1) + dims, strides, pads,
op = op.reshape((1, 1) + op.shape)
if base_dilation:
op = _dilate(op, base_dilation, init_value)
view = _conv_view(op, (1, 1) + dims, strides, pads,
pad_value=init_value)[0]
view = view.reshape(view.shape[1:1+len(dims)] + (-1,))
reducer = _make_reducer(computation, init_value)

@@ -364,13 +367,13 @@ def _pad(arr, pads, pad_value):
for (lo, hi), dim in zip(pads, np.shape(arr)))
return out[slices]

def _dilate(operand, factors):
def _dilate(operand, factors, fill_value=0):
# this logic is like lax.pad, but with two leading dimensions, no edge
# padding, and factors are at least 1 (interior padding is at least 0)
outspace = np.add(operand.shape[2:],
np.multiply(np.subtract(factors, 1),
np.subtract(operand.shape[2:], 1)))
out = np.zeros(operand.shape[:2] + tuple(outspace), operand.dtype)
out = np.full(operand.shape[:2] + tuple(outspace), fill_value, operand.dtype)
lhs_slices = tuple(_slice(None, None, step) for step in factors)
out[(_slice(None),) * 2 + lhs_slices] = operand
return out
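A hedged illustration of what the reference _dilate helper above now computes, using plain NumPy and a made-up array: with the two leading batch/feature dimensions the helper expects, dilating a length-3 row by a factor of 2 inserts one fill_value between neighbours.

# Illustrative only: interior dilation with a non-zero fill value, mirroring
# the lax_reference._dilate change above.
import numpy as np

operand = np.array([[[1., 2., 3.]]])          # shape (1, 1, 3)
factors, fill_value = (2,), -np.inf

outspace = np.add(operand.shape[2:],
                  np.multiply(np.subtract(factors, 1),
                              np.subtract(operand.shape[2:], 1)))
out = np.full(operand.shape[:2] + tuple(outspace), fill_value, operand.dtype)
out[(slice(None),) * 2 + tuple(slice(None, None, s) for s in factors)] = operand
print(out)  # [[[ 1. -inf  2. -inf  3.]]]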
@@ -181,8 +181,16 @@ def inner_prod(xs, ys):
return tree_reduce(np.add, tree_multimap(contract, xs, ys))

def _safe_subtract(x, y, *, dtype):
"""Subtraction that with `inf - inf == 0` semantics."""
with np.errstate(invalid='ignore'):
return np.where(np.equal(x, y), np.array(0, dtype),
np.subtract(x, y, dtype=dtype))

add = partial(tree_multimap, lambda x, y: np.add(x, y, dtype=_dtype(x)))
sub = partial(tree_multimap, lambda x, y: np.subtract(x, y, dtype=_dtype(x)))
safe_sub = partial(tree_multimap,
lambda x, y: _safe_subtract(x, y, dtype=_dtype(x)))
conj = partial(tree_map, lambda x: np.conj(x, dtype=_dtype(x)))

def scalar_mul(xs, a):

@@ -203,7 +211,7 @@ def numerical_jvp(f, primals, tangents, eps=EPS):
delta = scalar_mul(tangents, eps)
f_pos = f(*add(primals, delta))
f_neg = f(*sub(primals, delta))
return scalar_mul(sub(f_pos, f_neg), 0.5 / eps)
return scalar_mul(safe_sub(f_pos, f_neg), 0.5 / eps)

def _merge_tolerance(tol, default):
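A quick hedged illustration (values invented) of why numerical_jvp now uses safe_sub: when both f(x+delta) and f(x-delta) land on the -inf identity of a max reduction, plain subtraction yields NaN, while the safe version treats equal values as a zero difference.

# Illustrative only: inf - inf is NaN under plain subtraction, 0 under the
# _safe_subtract semantics used by safe_sub above.
import numpy as np

f_pos = np.array([1.0, -np.inf])
f_neg = np.array([0.5, -np.inf])

with np.errstate(invalid='ignore'):
  plain = np.subtract(f_pos, f_neg)
  safe = np.where(np.equal(f_pos, f_neg), np.array(0, f_pos.dtype),
                  np.subtract(f_pos, f_neg, dtype=f_pos.dtype))
print(plain)  # [0.5 nan]
print(safe)   # [0.5 0. ]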
@@ -671,51 +671,62 @@ class LaxAutodiffTest(jtu.JaxTestCase):
check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_dtype={}_padding={}"
.format(op.__name__, np.dtype(dtype).name, padding),
"op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
{"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
"_basedilation={}_windowdilation={}")
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
strides, padding, base_dilation, window_dilation),
"op": op, "init_val": init_val, "dtype": dtype, "shape": shape,
"dims": dims, "strides": strides, "padding": padding,
"base_dilation": base_dilation, "window_dilation": window_dilation,
"rng_factory": rng_factory}
for init_val, op, dtypes, rng_factory in [
(0, lax.add, grad_float_dtypes, jtu.rand_small),
(-np.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int),
(np.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int),
]
for dtype in dtypes
for padding in ["VALID", "SAME"]))
@jtu.skip_on_flag("jax_skip_slow_tests", True)
for shape, dims, strides, padding, base_dilation, window_dilation in (
itertools.chain(
itertools.product(
[(4, 6)],
[(2, 1), (1, 2)],
[(1, 1), (2, 1), (1, 2)],
# TODO(b/161704903): explicit paddings segfault on CPU.
["VALID", "SAME"], #, [(0, 3), (1, 2)]],
[(1, 1)] + ([(2, 3)] if op is lax.add else []),
[(1, 1)] + ([(1, 2)] if op is lax.add else [])),
itertools.product(
[(3, 2, 4, 6)],
[(1, 1, 2, 1), (2, 1, 2, 1)],
[(1, 2, 2, 1), (1, 1, 1, 1)],
# TODO(b/161704903): explicit paddings segfault on CPU.
["VALID", "SAME"], # [(0, 1), (1, 0), (2, 3), (0, 2)]],
[(1, 1, 1, 1)] + ([(2, 1, 3, 2)] if op is lax.add else []),
[(1, 1, 1, 1)] + ([(1, 2, 2, 1)] if op is lax.add else []))))
for dtype in dtypes))
@jtu.ignore_warning(category=UserWarning,
message="Using reduced precision for gradient.*")
def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory):
def testReduceWindowGrad(
self, op, init_val, dtype, shape, dims, strides,
padding, base_dilation, window_dilation, rng_factory):
rng = rng_factory(self.rng())
init_val = np.asarray(init_val, dtype=dtype)

gradient_order = 3
# We need this conditional and the corresponding loop logic to be in the
# test method, rather than at the parameterized test level, because it
# depends on FLAGS for the device under test.
# TODO(b/31565929): enable when fixed.
if jtu.device_under_test() == "tpu" and op is not lax.add:
all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))]
if len(shape) != 4 or dims != (1, 1, 2, 1):
raise SkipTest("Only R4 SelectAndScatter implemented on TPU")

# TODO(b/73062247): need variadic reduce-window for better precision.
gradient_order = 1
else:
all_configs = itertools.chain(
itertools.product(
[(4, 6)], # shapes
[(2, 1), (1, 2)], # window_dimensions
[(1, 1), (2, 1), (1, 2)] # strides
),
itertools.product(
[(3, 2, 4, 6)], # shapes
[(1, 1, 2, 1), (2, 1, 2, 1)], # window_dimensions
[(1, 2, 2, 1), (1, 1, 1, 1)]), # strides
)
gradient_order = 3

def fun(operand):
return lax.reduce_window(operand, init_val, op, dims, strides, padding)
return lax.reduce_window(operand, init_val, op, dims, strides, padding,
base_dilation, window_dilation)

for shape, dims, strides in all_configs:
operand = rng(shape, dtype)
if op is lax.add:
eps = 1.
@@ -1302,56 +1302,60 @@ class LaxTest(jtu.JaxTestCase):
self._CompileAndCheck(fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_dtype={}"
.format(op.__name__, np.dtype(dtype).name,),
"op": op, "init_val": init_val, "dtype": dtype,
"rng_factory": rng_factory}
{"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
"_basedilation={}_windowdilation={}")
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype),
dims, strides, padding, base_dilation, window_dilation),
"op": op, "init_val": init_val, "dtype": dtype, "shape": shape,
"dims": dims, "strides": strides, "padding": padding,
"base_dilation": base_dilation, "window_dilation": window_dilation}
for init_val, op, dtypes in [
(0, lax.add, [np.float32]),
(-np.inf, lax.max, [np.float32]),
(np.inf, lax.min, [np.float32]),
]
for dtype in dtypes
for rng_factory in [jtu.rand_small]))
def testReduceWindow(self, op, init_val, dtype, rng_factory):
rng = rng_factory(self.rng())
init_val = np.asarray(init_val, dtype=dtype)

all_configs = list(itertools.chain(
for shape, dims, strides, padding, base_dilation, window_dilation in (
itertools.chain(
itertools.product(
[(4, 6)],
[(2, 1), (1, 2)],
[(1, 1), (2, 1), (1, 2)],
["VALID", "SAME", [(0, 3), (1, 2)]]),
["VALID", "SAME", [(0, 3), (1, 2)]],
[(1, 1), (2, 3)],
[(1, 1), (1, 2)]),
itertools.product(
[(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
[(1, 2, 2, 1), (1, 1, 1, 1)],
["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]])))
["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
[(1, 1, 1, 1), (2, 1, 3, 2)],
[(1, 1, 1, 1), (1, 2, 2, 1)])))
for dtype in dtypes))
def testReduceWindow(self, op, init_val, dtype, shape, dims, strides, padding,
base_dilation, window_dilation):
rng = jtu.rand_small(self.rng())
init_val = np.asarray(init_val, dtype=dtype)

def fun(operand, init_val):
return lax.reduce_window(operand, init_val, op, dims, strides, padding)
return lax.reduce_window(operand, init_val, op, dims, strides, padding,
base_dilation, window_dilation)

def reference_fun(operand, init_val):
return lax_reference.reduce_window(operand, init_val, op, dims, strides,
padding)
padding, base_dilation)

# pylint: disable=cell-var-from-loop
for shape, dims, strides, padding in all_configs:
args_maker = lambda: [rng(shape, dtype), init_val]
self._CompileAndCheck(fun, args_maker)
if all(d == 1 for d in window_dilation):
self._CheckAgainstNumpy(fun, reference_fun, args_maker)
# pylint: enable=cell-var-from-loop

# we separately test the version that uses a concrete init_val because it
# can hit different code paths
def fun(operand):
return lax.reduce_window(operand, init_val, op, dims, strides, padding)
return lax.reduce_window(operand, init_val, op, dims, strides, padding,
base_dilation, window_dilation)

# pylint: disable=cell-var-from-loop
for shape, dims, strides, padding in all_configs:
args_maker = lambda: [rng(shape, dtype)]
self._CompileAndCheck(fun, args_maker)
# pylint: enable=cell-var-from-loop

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_shape={}_axis={}"
@@ -500,35 +500,43 @@ class LaxVmapTest(jtu.JaxTestCase):
self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_dtype={}_padding={}"
.format(op.__name__, np.dtype(dtype).name, padding),
"op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
"rng_factory": rng_factory}
{"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
"_basedilation={}_windowdilation={}")
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype),
dims, strides, padding, base_dilation, window_dilation),
"op": op, "init_val": init_val, "dtype": dtype, "shape": shape,
"dims": dims, "strides": strides, "padding": padding,
"base_dilation": base_dilation, "window_dilation": window_dilation}
for init_val, op, dtypes in [
(0, lax.add, [np.float32]),
(-np.inf, lax.max, [np.float32]),
(np.inf, lax.min, [np.float32]),
]
for dtype in dtypes
for padding in ["VALID", "SAME"]
for rng_factory in [jtu.rand_small]))
def testReduceWindow(self, op, init_val, dtype, padding, rng_factory):
rng = rng_factory(self.rng())
init_val = np.asarray(init_val, dtype=dtype)

all_configs = itertools.chain(
for shape, dims, strides, padding, base_dilation, window_dilation in (
itertools.chain(
itertools.product(
[(4, 6)],
[(2, 1), (1, 2)],
[(1, 1), (2, 1), (1, 2)]),
[(1, 1), (2, 1), (1, 2)],
["VALID", "SAME", [(0, 3), (1, 2)]],
[(1, 1), (2, 3)],
[(1, 1), (1, 2)]),
itertools.product(
[(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
[(1, 2, 2, 1), (1, 1, 1, 1)]))
[(1, 2, 2, 1), (1, 1, 1, 1)],
["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
[(1, 1, 1, 1), (2, 1, 3, 2)],
[(1, 1, 1, 1), (1, 2, 2, 1)])))
for dtype in dtypes))
def testReduceWindow(self, op, init_val, dtype, shape, dims, strides, padding,
base_dilation, window_dilation):
rng = jtu.rand_small(self.rng())
init_val = np.asarray(init_val, dtype=dtype)

def fun(operand):
return lax.reduce_window(operand, init_val, op, dims, strides, padding)
return lax.reduce_window(operand, init_val, op, dims, strides, padding,
base_dilation, window_dilation)

for shape, dims, strides in all_configs:
for bdims in all_bdims(shape):
self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng)

@@ -578,8 +586,9 @@ class LaxVmapTest(jtu.JaxTestCase):

def fun(operand, tangents):
pads = lax.padtype_to_pads(operand.shape, dims, strides, padding)
ones = (1,) * len(operand.shape)
return lax._select_and_gather_add(operand, tangents, lax.ge_p, dims,
strides, pads)
strides, pads, ones, ones)

for shape, dims, strides in all_configs:
for bdims in all_bdims(shape, shape):