Add support for base dilation and window dilation to reduce window op… (#3803)

Peter Hawkins 2020-07-20 17:27:24 -04:00 committed by GitHub
parent ce14409025
commit a6e2d20b31
8 changed files with 300 additions and 153 deletions


@@ -781,11 +781,14 @@ tf_impl[lax.cumprod_p] = tf.math.cumprod
def _reduce_window_shape(jax_f, operand, window_dimensions,
window_strides, padding, input_shape=None):
window_strides, padding, base_dilation,
window_dilation, input_shape=None):
"""Shape inference function for reduce_window_{sum,min,max}."""
params = dict(window_dimensions=window_dimensions,
window_strides=window_strides,
padding=padding,
base_dilation=base_dilation,
window_dilation=window_dilation,
input_shape=input_shape)
try:
out, = _infer_shape_jax(jax_f, operand, **params)
@@ -802,17 +805,20 @@ def _get_shape_from_tensor_or_array(x):
def _reduce_window(jax_f, reducer, init_val, operand, window_dimensions,
window_strides, padding, input_shape=None):
window_strides, padding, base_dilation, window_dilation,
input_shape=None):
"""TensorFlow implementation of reduce_window_{sum,min,max}."""
del input_shape
# TODO(tomhennigan): tf2xla should have a shape inference function.
out_shape = _reduce_window_shape(jax_f, operand, window_dimensions,
window_strides, padding)
window_strides, padding, base_dilation,
window_dilation)
a = tf.constant(0, operand.dtype)
reducer_fn = reducer.get_concrete_function(a, a)
out = tfxla.reduce_window(operand, tf.constant(init_val, operand.dtype),
reducer_fn, window_dimensions,
window_strides, padding=padding)
window_strides, base_dilations=base_dilation,
window_dilations=window_dilation, padding=padding)
out.set_shape(out_shape)
return out
# pylint: disable=protected-access
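The jax2tf rule above forwards both new dilation parameters straight to tfxla.reduce_window. As a minimal sketch of the semantics being lowered, using the lax API this commit extends (values worked out by hand):

import jax.numpy as jnp
from jax import lax

x = jnp.arange(6, dtype=jnp.float32)  # [0, 1, 2, 3, 4, 5]
# window_dilation=(2,) makes a size-2 window read elements i and i+2.
y = lax.reduce_window(x, 0., lax.add, window_dimensions=(2,),
                      window_strides=(1,), padding='VALID',
                      window_dilation=(2,))
# y == [2., 4., 6., 8.]  (i.e. 0+2, 1+3, 2+4, 3+5)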


@@ -176,8 +176,9 @@ def _pooling_layer(reducer, init_val, rescaler=None):
def init_fun(rng, input_shape):
padding_vals = lax.padtype_to_pads(input_shape, window_shape,
strides, padding)
out_shape = lax.reduce_window_shape_tuple(input_shape, window_shape,
strides, padding_vals)
ones = (1,) * len(window_shape)
out_shape = lax.reduce_window_shape_tuple(
input_shape, window_shape, strides, padding_vals, ones, ones)
return out_shape, ()
def apply_fun(params, inputs, **kwargs):
out = lax.reduce_window(inputs, init_val, reducer, window_shape,
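reduce_window_shape_tuple gained two dilation arguments, so the stax pooling layers now pass explicit all-ones tuples. A quick sanity check of the shape math under unit dilations (the shapes here are just an example):

from jax import lax

input_shape = (1, 28, 28, 1)
window_shape = (1, 2, 2, 1)
strides = (1, 2, 2, 1)
pads = lax.padtype_to_pads(input_shape, window_shape, strides, 'VALID')
ones = (1,) * len(window_shape)
out_shape = lax.reduce_window_shape_tuple(input_shape, window_shape,
                                          strides, pads, ones, ones)
# out_shape == (1, 14, 14, 1)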


@@ -1089,7 +1089,9 @@ def _reduce_and(operand: Array, axes: Sequence[int]) -> Array:
def reduce_window(operand: Array, init_value: Array, computation: Callable,
window_dimensions: Shape, window_strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]]) -> Array:
padding: Union[str, Sequence[Tuple[int, int]]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
"""Wraps XLA's `ReduceWindowWithGeneralPadding
<https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
operator.
@@ -1099,15 +1101,22 @@ def reduce_window(operand: Array, init_value: Array, computation: Callable,
window_strides, padding))
else:
padding = tuple(padding)
if base_dilation is None:
base_dilation = (1,) * len(window_dimensions)
if window_dilation is None:
window_dilation = (1,) * len(window_dimensions)
monoid_reducer = _get_monoid_window_reducer(computation, init_value)
if monoid_reducer:
return monoid_reducer(operand, window_dimensions, window_strides, padding)
return monoid_reducer(operand, window_dimensions, window_strides, padding,
base_dilation, window_dilation)
else:
jaxpr, consts = _reduction_jaxpr(computation, _abstractify(init_value))
return reduce_window_p.bind(
operand, init_value, jaxpr=jaxpr, consts=consts,
window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=padding)
window_strides=tuple(window_strides), padding=padding,
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))
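Both new parameters are optional and default to all-ones tuples, so existing call sites are unchanged. base_dilation dilates the operand before the window slides over it; window_dilation spreads the taps of the window itself. A small sketch:

import jax.numpy as jnp
from jax import lax

x = jnp.ones((4, 6), jnp.float32)
a = lax.reduce_window(x, -jnp.inf, lax.max, (2, 2), (1, 1), 'VALID')
b = lax.reduce_window(x, -jnp.inf, lax.max, (2, 2), (1, 1), 'VALID',
                      base_dilation=(2, 2))
assert a.shape == (3, 5)   # old default behavior
assert b.shape == (6, 10)  # operand is dilated to (7, 11) first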
def _get_monoid_window_reducer(monoid_op: Callable, x: Array) -> Optional[Callable]:
aval = core.get_aval(x)
@@ -1122,46 +1131,82 @@ def _get_monoid_window_reducer(monoid_op: Callable, x: Array) -> Optional[Callab
def _reduce_window_sum(operand: Array, window_dimensions: Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]]) -> Array:
padding: Sequence[Tuple[int, int]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
if base_dilation is None:
base_dilation = (1,) * len(window_dimensions)
if window_dilation is None:
window_dilation = (1,) * len(window_dimensions)
return reduce_window_sum_p.bind(
operand, window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=tuple(padding))
window_strides=tuple(window_strides), padding=tuple(padding),
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))
def _reduce_window_prod(operand: Array, window_dimensions: Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]]) -> Array:
padding: Sequence[Tuple[int, int]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
init_value = _const(operand, 1)
jaxpr, consts = _reduction_jaxpr(mul, _abstractify(init_value))
if base_dilation is None:
base_dilation = (1,) * len(window_dimensions)
if window_dilation is None:
window_dilation = (1,) * len(window_dimensions)
return reduce_window_p.bind(
operand, init_value, jaxpr=jaxpr, consts=consts,
window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=tuple(padding))
window_strides=tuple(window_strides), padding=tuple(padding),
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))
def _reduce_window_max(operand: Array, window_dimensions: Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]]) -> Array:
padding: Sequence[Tuple[int, int]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
if base_dilation is None:
base_dilation = (1,) * len(window_dimensions)
if window_dilation is None:
window_dilation = (1,) * len(window_dimensions)
return reduce_window_max_p.bind(
operand, window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=tuple(padding))
window_strides=tuple(window_strides), padding=tuple(padding),
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))
def _reduce_window_min(operand: Array, window_dimensions: Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]]) -> Array:
padding: Sequence[Tuple[int, int]],
base_dilation: Optional[Sequence[int]] = None,
window_dilation: Optional[Sequence[int]] = None) -> Array:
if base_dilation is None:
base_dilation = (1,) * len(window_dimensions)
if window_dilation is None:
window_dilation = (1,) * len(window_dimensions)
return reduce_window_min_p.bind(
operand, window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=tuple(padding))
window_strides=tuple(window_strides), padding=tuple(padding),
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))
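The four monoid wrappers repeat the same None-to-ones defaulting. A hypothetical helper that would collapse that boilerplate (not part of the commit, just a sketch):

def _canonicalize_dilations(window_dimensions, base_dilation, window_dilation):
  # Treat None as "no dilation" on either side.
  ones = (1,) * len(window_dimensions)
  return (tuple(base_dilation) if base_dilation is not None else ones,
          tuple(window_dilation) if window_dilation is not None else ones)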
def _select_and_scatter(operand: Array, select: Callable,
window_dimensions: Shape, window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]], source: Array,
init_value: Array, scatter: Callable) -> Array:
init_value: Array, scatter: Callable,
base_dilation: Sequence[int],
window_dilation: Sequence[int]) -> Array:
select_jaxpr, select_consts = _reduction_jaxpr(select, _abstractify(init_value))
scatter_jaxpr, scatter_consts = _reduction_jaxpr(scatter, _abstractify(init_value))
return select_and_scatter_p.bind(
operand, source, init_value, select_jaxpr=select_jaxpr,
select_consts=select_consts, scatter_jaxpr=scatter_jaxpr,
scatter_consts=scatter_consts, window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=tuple(padding))
window_strides=tuple(window_strides), padding=tuple(padding),
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))
def _select_and_scatter_add(source: Array, operand: Array,
select_prim: core.Primitive,
@@ -1177,11 +1222,15 @@ def _select_and_gather_add(tangents: Array, operand: Array,
select_prim: core.Primitive,
window_dimensions: Shape,
window_strides: Sequence[int],
padding: Sequence[Tuple[int, int]]) -> Array:
padding: Sequence[Tuple[int, int]],
base_dilation: Sequence[int],
window_dilation: Sequence[int]) -> Array:
return select_and_gather_add_p.bind(
tangents, operand, select_prim=select_prim,
window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=tuple(padding))
window_strides=tuple(window_strides), padding=tuple(padding),
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))
def cumsum(operand: Array, axis: int) -> Array:
"""Computes a cumulative sum along `axis`."""
@@ -4490,38 +4539,43 @@ reduce_and_p = standard_primitive(_reduce_logical_shape_rule, _fixed_dtype(np.bo
batching.defreducer(reduce_and_p)
def _reduce_window_shape_rule(operand, init_value, *, jaxpr, consts,
window_dimensions, window_strides, padding):
window_dimensions, window_strides, padding,
base_dilation, window_dilation):
if operand.dtype != init_value.dtype:
msg = ("reduce_window got inconsistent dtypes for operand and init_value: "
" got operand dtype {} and init_value dtype {}.")
raise TypeError(msg.format(operand.dtype, init_value.dtype))
return _common_reduce_window_shape_rule(operand, window_dimensions,
window_strides, padding)
return _common_reduce_window_shape_rule(
operand, window_dimensions, window_strides, padding, base_dilation,
window_dilation)
def _reduce_window_translation_rule(c, operand, init_value, *, jaxpr, consts,
window_dimensions, window_strides, padding):
window_dimensions, window_strides, padding,
base_dilation, window_dilation):
xla_computation = _reduction_computation(c, jaxpr, consts, init_value)
return xops.ReduceWindowWithGeneralPadding(
operand, init_value, xla_computation, window_dimensions,
window_strides, (), (), padding)
window_strides, base_dilation, window_dilation, padding)
def _generic_reduce_window_batch_rule(
batched_args, batch_dims, *, jaxpr, consts, window_dimensions,
window_strides, padding):
window_strides, padding, base_dilation, window_dilation):
operand, init = batched_args
bdim, init_bdim = batch_dims
if init_bdim is not None:
raise NotImplementedError("reduce_window batching is not implemented for "
"initial values")
def reduce_window(x, window_dimensions, window_strides, padding):
def reduce_window(x, window_dimensions, window_strides, padding, base_dilation,
window_dilation):
return reduce_window_p.bind(
x, init, jaxpr=jaxpr, consts=consts, window_dimensions=window_dimensions,
window_strides=window_strides, padding=padding)
return _reduce_window_batch_rule(reduce_window, (operand,), (bdim,),
window_dimensions=window_dimensions,
window_strides=window_strides,
padding=padding)
window_strides=window_strides, padding=padding, base_dilation=base_dilation,
window_dilation=window_dilation)
return _reduce_window_batch_rule(
reduce_window, (operand,), (bdim,), window_dimensions=window_dimensions,
window_strides=window_strides, padding=padding, base_dilation=base_dilation,
window_dilation=window_dilation)
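Batching simply grows every per-dimension parameter with a trivial entry for the batch axis: window 1, stride 1, padding (0, 0), and dilation 1 on both sides. Observable through vmap:

import jax
import jax.numpy as jnp
from jax import lax

pool = lambda x: lax.reduce_window(x, 0., lax.add, (2, 2), (1, 1), 'VALID',
                                   base_dilation=(2, 2))
ys = jax.vmap(pool)(jnp.ones((8, 4, 6)))
# Per example: (4, 6) dilates to (7, 11), then reduces to (6, 10).
# ys.shape == (8, 6, 10)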
reduce_window_p = standard_primitive(
@@ -4531,40 +4585,46 @@ batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule
def _reduce_window_sum_shape_rule(operand, *, window_dimensions, window_strides,
padding):
padding, base_dilation, window_dilation):
if not dtypes.issubdtype(operand.dtype, np.number):
msg = "operand to reduce_window_sum must have a number dtype, got {}"
raise TypeError(msg.format(np.dtype(operand.dtype).name))
return _common_reduce_window_shape_rule(operand, window_dimensions,
window_strides, padding)
window_strides, padding, base_dilation,
window_dilation)
def _reduce_window_sum_translation_rule(c, operand, *, window_dimensions,
window_strides, padding):
window_strides, padding, base_dilation,
window_dilation):
dtype = c.get_shape(operand).numpy_dtype()
scalar = ShapedArray((), dtype)
return xops.ReduceWindowWithGeneralPadding(
operand, xb.constant(c, np.array(0, dtype)),
xla.primitive_subcomputation(add_p, scalar, scalar), window_dimensions,
window_strides, (), (), padding)
window_strides, base_dilation, window_dilation, padding)
def _reduce_window_sum_transpose_rule(cotangent, operand, *, window_dimensions,
window_strides, padding):
window_strides, padding, base_dilation,
window_dilation):
assert ad.is_undefined_primal(operand)
input_shape = operand.aval.shape
ones = [1] * len(input_shape)
pads = _conv_general_vjp_lhs_padding(
input_shape, window_dimensions, window_strides, cotangent.shape, padding,
ones, ones)
base_dilation, window_dilation)
ones = [1] * len(input_shape)
padding_config = [(lo, hi, stride - 1)
for (lo, hi), stride in zip(pads, window_strides)]
pad_cotangent = pad(cotangent, _zero(cotangent), padding_config)
result = _reduce_window_sum(pad_cotangent, window_dimensions, ones,
[(0, 0)] * len(input_shape))
assert result.shape == input_shape
result = _reduce_window_sum(pad_cotangent, window_dimensions, base_dilation,
[(0, 0)] * len(input_shape),
base_dilation=ones,
window_dilation=window_dilation)
assert result.shape == input_shape, (result.shape, input_shape)
return [result]
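The transpose rule mirrors the convolution VJP: the cotangent is interior-padded by stride - 1, and the reduction is re-run with base_dilation taking the role of the strides. A hand-checked example of the gradient it produces:

import jax
import jax.numpy as jnp
from jax import lax

def f(x):
  return lax.reduce_window(x, 0., lax.add, (2,), (1,), 'VALID',
                           base_dilation=(2,)).sum()

g = jax.grad(f)(jnp.arange(4.0))
# x[k] sits at dilated position 2k and is covered by the windows starting
# at 2k-1 and 2k (where in range), so g == [1., 2., 2., 1.]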
def _reduce_window_batch_rule(reduce_window, batched_args, bdims, *,
window_dimensions, window_strides, padding):
window_dimensions, window_strides, padding,
base_dilation, window_dilation):
operand, = batched_args
bdim, = bdims
@@ -4573,7 +4633,11 @@ def _reduce_window_batch_rule(reduce_window, batched_args, bdims, *,
window_dimensions[:bdim] + (1,) + window_dimensions[bdim:]
window_strides = window_strides[:bdim] + (1,) + window_strides[bdim:]
padding = padding[:bdim] + ((0, 0),) + padding[bdim:]
operand = reduce_window(operand, window_dimensions, window_strides, padding)
base_dilation = base_dilation[:bdim] + (1,) + base_dilation[bdim:]
window_dilation = window_dilation[:bdim] + (1,) + window_dilation[bdim:]
operand = reduce_window(operand, window_dimensions, window_strides, padding,
base_dilation, window_dilation)
return operand, bdim
reduce_window_sum_p = standard_primitive(
@@ -4584,26 +4648,32 @@ batching.primitive_batchers[reduce_window_sum_p] = partial(
_reduce_window_batch_rule, _reduce_window_sum)
def _reduce_window_chooser_translation_rule(
prim, identity, c, operand, *, window_dimensions, window_strides, padding):
prim, identity, c, operand, *, window_dimensions, window_strides, padding,
base_dilation, window_dilation):
dtype = c.get_shape(operand).numpy_dtype()
scalar = ShapedArray((), dtype)
return xops.ReduceWindowWithGeneralPadding(
operand, xb.constant(c, identity(dtype)),
xla.primitive_subcomputation(prim, scalar, scalar), window_dimensions,
window_strides, (), (), padding)
window_strides, base_dilation, window_dilation, padding)
def _reduce_window_chooser_jvp_rule(prim, g, operand, *, window_dimensions,
window_strides, padding):
window_strides, padding, base_dilation,
window_dilation):
assert prim is max_p or prim is min_p
select_prim = ge_p if prim is max_p else le_p
return _select_and_gather_add(g, operand, select_prim, window_dimensions,
window_strides, padding)
window_strides, padding, base_dilation,
window_dilation)
def _common_reduce_window_shape_rule(operand, window_dimensions,
window_strides, padding):
window_strides, padding, base_dilation,
window_dilation):
_check_shapelike("reduce_window", "window_dimensions", window_dimensions)
_check_shapelike("reduce_window", "window_strides", window_strides)
_check_shapelike("reduce_window", "base_dilation", base_dilation)
_check_shapelike("reduce_window", "window_dilation", window_dilation)
if operand.ndim != len(window_dimensions):
msg = ("reduce_window got the wrong number of window_dimensions for "
"operand: got operand shape {} with window_dimensions {}.")
@@ -4612,12 +4682,27 @@ def _common_reduce_window_shape_rule(operand, window_dimensions,
msg = ("reduce_window got inconsistent window_strides and "
"window_dimensions: got window_strides {} and window_dimensions {}.")
raise TypeError(msg.format(window_strides, window_dimensions))
if len(base_dilation) != len(window_dimensions):
msg = ("reduce_window got inconsistent base_dilation and "
"window_dimensions: got base_dilation {} and window_dimensions {}.")
raise TypeError(msg.format(base_dilation, window_dimensions))
if len(window_dilation) != len(window_dimensions):
msg = ("reduce_window got inconsistent window_dilation and "
"window_dimensions: got window_dilation {} and window_dimensions "
"{}.")
raise TypeError(msg.format(window_dilation, window_dimensions))
return reduce_window_shape_tuple(operand.shape, window_dimensions,
window_strides, padding)
window_strides, padding, base_dilation,
window_dilation)
def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
padding):
padding, base_dilation=None,
window_dilation=None):
if base_dilation is not None:
operand_shape = _dilate_shape(operand_shape, base_dilation)
if window_dilation is not None:
window_dimensions = _dilate_shape(window_dimensions, window_dilation)
operand_padded = np.add(operand_shape, np.add(*zip(*padding)))
t = np.floor_divide(
np.subtract(operand_padded, window_dimensions), window_strides) + 1
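Worked example of this shape rule. The helper below mirrors what lax's internal _dilate_shape computes for nonempty dimensions; the standalone name is illustrative:

import numpy as np

def dilate_shape(shape, dilation):
  # A dimension of size n dilated by d has size d*(n-1) + 1.
  return tuple(np.multiply(dilation, np.subtract(shape, 1)) + 1)

# operand (4, 6) with base_dilation (2, 3)   -> dilated operand (7, 16)
# window  (2, 2) with window_dilation (1, 2) -> effective window (2, 3)
# VALID padding, strides (1, 1): output = (7-2+1, 16-3+1) == (6, 14)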
@@ -4708,8 +4793,9 @@ def _select_and_scatter_add_transpose(
t, source, operand, *, select_prim, window_dimensions, window_strides,
padding):
assert ad.is_undefined_primal(source) and not ad.is_undefined_primal(operand)
ones = (1,) * len(window_dimensions)
source_t = _select_and_gather_add(t, operand, select_prim, window_dimensions,
window_strides, padding)
window_strides, padding, ones, ones)
return [source_t, None]
def _select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs):
@@ -4753,13 +4839,14 @@ batching.primitive_batchers[select_and_scatter_add_p] = \
def _select_and_gather_add_shape_rule(
tangents, operand, *, select_prim, window_dimensions, window_strides,
padding):
padding, base_dilation, window_dilation):
if tangents.shape != operand.shape:
msg = ("select_and_gather_add tangents and operand shapes must match, "
"got {} and {}.")
raise TypeError(msg.format(tangents.shape, operand.shape))
return _common_reduce_window_shape_rule(operand, window_dimensions,
window_strides, padding)
return _common_reduce_window_shape_rule(
operand, window_dimensions, window_strides, padding, base_dilation,
window_dilation)
_UINT_DTYPES = {
@@ -4776,7 +4863,7 @@ _INT_DTYPES = {
def _select_and_gather_add_translation(
c, tangents, operand, *, select_prim, window_dimensions, window_strides,
padding, max_bits=64):
padding, base_dilation, window_dilation, max_bits=64):
shape = c.get_shape(operand)
dtype = shape.numpy_dtype()
etype = shape.xla_element_type()
@@ -4867,37 +4954,52 @@ def _select_and_gather_add_translation(
init = -np.inf if select_prim is ge_p else np.inf
out = xops.ReduceWindowWithGeneralPadding(
pack(operand, tangents), pack(const(c, dtype, init), const(c, dtype, 0)),
reducer(), window_dimensions, window_strides, (), (), padding)
reducer(), window_dimensions, window_strides, base_dilation,
window_dilation, padding)
return snd(out)
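For dtypes with a wide-enough integer type available, this translation packs each (value, tangent) pair into one wider integer so a single ReduceWindow both selects by value and carries the tangent along. The packing idea, sketched in NumPy; the real lowering additionally fixes up sign bits so that packed comparisons order negative floats correctly:

import numpy as np

def pack(v, t):  # two float32 halves -> one uint64, value in the high bits
  hi = np.asarray(v, np.float32).view(np.uint32).astype(np.uint64)
  lo = np.asarray(t, np.float32).view(np.uint32).astype(np.uint64)
  return (hi << np.uint64(32)) | lo

def snd(p):  # recover the tangent half after the reduction
  return (p & np.uint64(0xffffffff)).astype(np.uint32).view(np.float32)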
def _select_and_gather_add_jvp(
primals, tangents, *, select_prim, window_dimensions, window_strides,
padding):
padding, base_dilation, window_dilation):
source, operand = primals
g_source, g_operand = tangents
val_out = _select_and_gather_add(
source, operand, select_prim, window_dimensions, window_strides,
padding)
padding, base_dilation, window_dilation)
del g_operand
if type(g_source) is ad_util.Zero:
tangent_out = ad_util.Zero.from_value(val_out)
else:
tangent_out = _select_and_gather_add(
g_source, operand, select_prim, window_dimensions,
window_strides, padding)
window_strides, padding, base_dilation, window_dilation)
return val_out, tangent_out
def _select_and_gather_add_transpose(
t, tangents, operand, *, select_prim, window_dimensions, window_strides,
padding):
padding, base_dilation, window_dilation):
assert select_prim in (le_p, ge_p)
assert ad.is_undefined_primal(tangents) and not ad.is_undefined_primal(operand)
if any(d != 1 for d in window_dilation):
msg = ("VJP not implemented for select_and_gather (MaxPool) with window "
"dilation, got window_dilation={}.")
raise NotImplementedError(msg.format(window_dilation))
has_base_dilation = any(d != 1 for d in base_dilation)
if has_base_dilation:
select_identity = (_get_max_identity if select_prim is ge_p
else _get_min_identity)
operand = pad(operand, select_identity(operand.dtype),
tuple((0, 0, d - 1) for d in base_dilation))
result = _select_and_scatter_add(t, operand, select_prim, window_dimensions,
window_strides, padding)
if has_base_dilation:
result = slice(result, (0,) * len(result.shape), result.shape,
base_dilation)
return [result, None]
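Base dilation is undone here in two steps: the operand is padded with the select identity, so the inserted positions can never win the max/min, and those positions are then stripped from the gradient with a strided slice. The slice step in isolation:

import jax.numpy as jnp
from jax import lax

g = jnp.arange(7.0)  # stands in for a gradient over the base-dilated shape
undilated = lax.slice(g, (0,), g.shape, strides=(2,))
# undilated == [0., 2., 4., 6.], keeping only the original positions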
def _select_and_gather_add_batching_rule(
batched_args, batch_dims, *, select_prim, window_dimensions, window_strides,
padding):
padding, base_dilation, window_dilation):
t, x = batched_args
t_bdim, x_bdim = batch_dims
size = next(a.shape[bdim] for a, bdim in zip(batched_args, batch_dims)
@@ -4907,8 +5009,11 @@ def _select_and_gather_add_batching_rule(
window_dimensions = (1,) + window_dimensions
window_strides = (1,) + window_strides
padding = ((0, 0),) + padding
base_dilation = (1,) + base_dilation
window_dilation = (1,) + window_dilation
out = _select_and_gather_add(t, x, select_prim, window_dimensions,
window_strides, padding)
window_strides, padding, base_dilation,
window_dilation)
return (out, 0)


@@ -287,13 +287,16 @@ def reduce(operand, init_value, computation, dimensions): # pylint: disable=red
return reducer(operand, tuple(dimensions)).astype(np.asarray(operand).dtype)
def reduce_window(operand, init_value, computation, window_dimensions,
window_strides, padding):
window_strides, padding, base_dilation):
op, dims, strides = operand, window_dimensions, window_strides
if isinstance(padding, str):
pads = padtype_to_pads(op.shape, dims, strides, padding)
else:
pads = padding
view = _conv_view(op.reshape((1, 1) + op.shape), (1, 1) + dims, strides, pads,
op = op.reshape((1, 1) + op.shape)
if base_dilation:
op = _dilate(op, base_dilation, init_value)
view = _conv_view(op, (1, 1) + dims, strides, pads,
pad_value=init_value)[0]
view = view.reshape(view.shape[1:1+len(dims)] + (-1,))
reducer = _make_reducer(computation, init_value)
@@ -364,13 +367,13 @@ def _pad(arr, pads, pad_value):
for (lo, hi), dim in zip(pads, np.shape(arr)))
return out[slices]
def _dilate(operand, factors):
def _dilate(operand, factors, fill_value=0):
# this logic is like lax.pad, but with two leading dimensions, no edge
# padding, and factors are at least 1 (interior padding is at least 0)
outspace = np.add(operand.shape[2:],
np.multiply(np.subtract(factors, 1),
np.subtract(operand.shape[2:], 1)))
out = np.zeros(operand.shape[:2] + tuple(outspace), operand.dtype)
out = np.full(operand.shape[:2] + tuple(outspace), fill_value, operand.dtype)
lhs_slices = tuple(_slice(None, None, step) for step in factors)
out[(_slice(None),) * 2 + lhs_slices] = operand
return out
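The reference _dilate now writes the operand into a canvas of fill_value rather than zeros, so min/max reductions see their identity at the inserted positions. The same idea in plain NumPy, without the two leading singleton dims the helper carries:

import numpy as np

x = np.array([[1., 2.], [3., 4.]])
factors, fill = (2, 3), -np.inf
out_shape = tuple(f * (n - 1) + 1 for f, n in zip(factors, x.shape))
out = np.full(out_shape, fill, x.dtype)
out[::2, ::3] = x  # interior positions keep fill_value
# out.shape == (3, 4)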


@@ -181,8 +181,16 @@ def inner_prod(xs, ys):
return tree_reduce(np.add, tree_multimap(contract, xs, ys))
def _safe_subtract(x, y, *, dtype):
"""Subtraction that with `inf - inf == 0` semantics."""
with np.errstate(invalid='ignore'):
return np.where(np.equal(x, y), np.array(0, dtype),
np.subtract(x, y, dtype=dtype))
add = partial(tree_multimap, lambda x, y: np.add(x, y, dtype=_dtype(x)))
sub = partial(tree_multimap, lambda x, y: np.subtract(x, y, dtype=_dtype(x)))
safe_sub = partial(tree_multimap,
lambda x, y: _safe_subtract(x, y, dtype=_dtype(x)))
conj = partial(tree_map, lambda x: np.conj(x, dtype=_dtype(x)))
def scalar_mul(xs, a):
@@ -203,7 +211,7 @@ def numerical_jvp(f, primals, tangents, eps=EPS):
delta = scalar_mul(tangents, eps)
f_pos = f(*add(primals, delta))
f_neg = f(*sub(primals, delta))
return scalar_mul(sub(f_pos, f_neg), 0.5 / eps)
return scalar_mul(safe_sub(f_pos, f_neg), 0.5 / eps)
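safe_sub matters for cases like the pooling gradients tested below, where both finite-difference evaluations can return the same infinite identity value; plain subtraction would turn that into nan rather than a zero derivative:

import numpy as np

with np.errstate(invalid='ignore'):
  plain = np.subtract(np.inf, np.inf)           # nan
  safe = np.where(np.equal(np.inf, np.inf), 0.0,
                  np.subtract(np.inf, np.inf))  # 0.0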
def _merge_tolerance(tol, default):


@@ -671,51 +671,62 @@ class LaxAutodiffTest(jtu.JaxTestCase):
check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_dtype={}_padding={}"
.format(op.__name__, np.dtype(dtype).name, padding),
"op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
{"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
"_basedilation={}_windowdilation={}")
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
strides, padding, base_dilation, window_dilation),
"op": op, "init_val": init_val, "dtype": dtype, "shape": shape,
"dims": dims, "strides": strides, "padding": padding,
"base_dilation": base_dilation, "window_dilation": window_dilation,
"rng_factory": rng_factory}
for init_val, op, dtypes, rng_factory in [
(0, lax.add, grad_float_dtypes, jtu.rand_small),
(-np.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int),
(np.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int),
]
for dtype in dtypes
for padding in ["VALID", "SAME"]))
@jtu.skip_on_flag("jax_skip_slow_tests", True)
for shape, dims, strides, padding, base_dilation, window_dilation in (
itertools.chain(
itertools.product(
[(4, 6)],
[(2, 1), (1, 2)],
[(1, 1), (2, 1), (1, 2)],
# TODO(b/161704903): explicit paddings segfault on CPU.
["VALID", "SAME"], #, [(0, 3), (1, 2)]],
[(1, 1)] + ([(2, 3)] if op is lax.add else []),
[(1, 1)] + ([(1, 2)] if op is lax.add else [])),
itertools.product(
[(3, 2, 4, 6)],
[(1, 1, 2, 1), (2, 1, 2, 1)],
[(1, 2, 2, 1), (1, 1, 1, 1)],
# TODO(b/161704903): explicit paddings segfault on CPU.
["VALID", "SAME"], # [(0, 1), (1, 0), (2, 3), (0, 2)]],
[(1, 1, 1, 1)] + ([(2, 1, 3, 2)] if op is lax.add else []),
[(1, 1, 1, 1)] + ([(1, 2, 2, 1)] if op is lax.add else []))))
for dtype in dtypes))
@jtu.ignore_warning(category=UserWarning,
message="Using reduced precision for gradient.*")
def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory):
def testReduceWindowGrad(
self, op, init_val, dtype, shape, dims, strides,
padding, base_dilation, window_dilation, rng_factory):
rng = rng_factory(self.rng())
init_val = np.asarray(init_val, dtype=dtype)
gradient_order = 3
# We need this conditional and the corresponding loop logic to be in the
# test method, rather than at the parameterized test level, because it
# depends on FLAGS for the device under test.
# TODO(b/31565929): enable when fixed.
if jtu.device_under_test() == "tpu" and op is not lax.add:
all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))]
if len(shape) != 4 or dims != (1, 1, 2, 1):
raise SkipTest("Only R4 SelectAndScatter implemented on TPU")
# TODO(b/73062247): need variadic reduce-window for better precision.
gradient_order = 1
else:
all_configs = itertools.chain(
itertools.product(
[(4, 6)], # shapes
[(2, 1), (1, 2)], # window_dimensions
[(1, 1), (2, 1), (1, 2)] # strides
),
itertools.product(
[(3, 2, 4, 6)], # shapes
[(1, 1, 2, 1), (2, 1, 2, 1)], # window_dimensions
[(1, 2, 2, 1), (1, 1, 1, 1)]), # strides
)
gradient_order = 3
def fun(operand):
return lax.reduce_window(operand, init_val, op, dims, strides, padding)
return lax.reduce_window(operand, init_val, op, dims, strides, padding,
base_dilation, window_dilation)
for shape, dims, strides in all_configs:
operand = rng(shape, dtype)
if op is lax.add:
eps = 1.


@@ -1302,56 +1302,60 @@ class LaxTest(jtu.JaxTestCase):
self._CompileAndCheck(fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_dtype={}"
.format(op.__name__, np.dtype(dtype).name,),
"op": op, "init_val": init_val, "dtype": dtype,
"rng_factory": rng_factory}
{"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
"_basedilation={}_windowdilation={}")
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype),
dims, strides, padding, base_dilation, window_dilation),
"op": op, "init_val": init_val, "dtype": dtype, "shape": shape,
"dims": dims, "strides": strides, "padding": padding,
"base_dilation": base_dilation, "window_dilation": window_dilation}
for init_val, op, dtypes in [
(0, lax.add, [np.float32]),
(-np.inf, lax.max, [np.float32]),
(np.inf, lax.min, [np.float32]),
]
for dtype in dtypes
for rng_factory in [jtu.rand_small]))
def testReduceWindow(self, op, init_val, dtype, rng_factory):
rng = rng_factory(self.rng())
init_val = np.asarray(init_val, dtype=dtype)
all_configs = list(itertools.chain(
for shape, dims, strides, padding, base_dilation, window_dilation in (
itertools.chain(
itertools.product(
[(4, 6)],
[(2, 1), (1, 2)],
[(1, 1), (2, 1), (1, 2)],
["VALID", "SAME", [(0, 3), (1, 2)]]),
["VALID", "SAME", [(0, 3), (1, 2)]],
[(1, 1), (2, 3)],
[(1, 1), (1, 2)]),
itertools.product(
[(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
[(1, 2, 2, 1), (1, 1, 1, 1)],
["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]])))
["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
[(1, 1, 1, 1), (2, 1, 3, 2)],
[(1, 1, 1, 1), (1, 2, 2, 1)])))
for dtype in dtypes))
def testReduceWindow(self, op, init_val, dtype, shape, dims, strides, padding,
base_dilation, window_dilation):
rng = jtu.rand_small(self.rng())
init_val = np.asarray(init_val, dtype=dtype)
def fun(operand, init_val):
return lax.reduce_window(operand, init_val, op, dims, strides, padding)
return lax.reduce_window(operand, init_val, op, dims, strides, padding,
base_dilation, window_dilation)
def reference_fun(operand, init_val):
return lax_reference.reduce_window(operand, init_val, op, dims, strides,
padding)
padding, base_dilation)
# pylint: disable=cell-var-from-loop
for shape, dims, strides, padding in all_configs:
args_maker = lambda: [rng(shape, dtype), init_val]
self._CompileAndCheck(fun, args_maker)
if all(d == 1 for d in window_dilation):
self._CheckAgainstNumpy(fun, reference_fun, args_maker)
# pylint: enable=cell-var-from-loop
# we separately test the version that uses a concrete init_val because it
# can hit different code paths
def fun(operand):
return lax.reduce_window(operand, init_val, op, dims, strides, padding)
return lax.reduce_window(operand, init_val, op, dims, strides, padding,
base_dilation, window_dilation)
# pylint: disable=cell-var-from-loop
for shape, dims, strides, padding in all_configs:
args_maker = lambda: [rng(shape, dtype)]
self._CompileAndCheck(fun, args_maker)
# pylint: enable=cell-var-from-loop
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_shape={}_axis={}"


@@ -500,35 +500,43 @@ class LaxVmapTest(jtu.JaxTestCase):
self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_dtype={}_padding={}"
.format(op.__name__, np.dtype(dtype).name, padding),
"op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
"rng_factory": rng_factory}
{"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
"_basedilation={}_windowdilation={}")
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype),
dims, strides, padding, base_dilation, window_dilation),
"op": op, "init_val": init_val, "dtype": dtype, "shape": shape,
"dims": dims, "strides": strides, "padding": padding,
"base_dilation": base_dilation, "window_dilation": window_dilation}
for init_val, op, dtypes in [
(0, lax.add, [np.float32]),
(-np.inf, lax.max, [np.float32]),
(np.inf, lax.min, [np.float32]),
]
for dtype in dtypes
for padding in ["VALID", "SAME"]
for rng_factory in [jtu.rand_small]))
def testReduceWindow(self, op, init_val, dtype, padding, rng_factory):
rng = rng_factory(self.rng())
init_val = np.asarray(init_val, dtype=dtype)
all_configs = itertools.chain(
for shape, dims, strides, padding, base_dilation, window_dilation in (
itertools.chain(
itertools.product(
[(4, 6)],
[(2, 1), (1, 2)],
[(1, 1), (2, 1), (1, 2)]),
[(1, 1), (2, 1), (1, 2)],
["VALID", "SAME", [(0, 3), (1, 2)]],
[(1, 1), (2, 3)],
[(1, 1), (1, 2)]),
itertools.product(
[(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
[(1, 2, 2, 1), (1, 1, 1, 1)]))
[(1, 2, 2, 1), (1, 1, 1, 1)],
["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
[(1, 1, 1, 1), (2, 1, 3, 2)],
[(1, 1, 1, 1), (1, 2, 2, 1)])))
for dtype in dtypes))
def testReduceWindow(self, op, init_val, dtype, shape, dims, strides, padding,
base_dilation, window_dilation):
rng = jtu.rand_small(self.rng())
init_val = np.asarray(init_val, dtype=dtype)
def fun(operand):
return lax.reduce_window(operand, init_val, op, dims, strides, padding)
return lax.reduce_window(operand, init_val, op, dims, strides, padding,
base_dilation, window_dilation)
for shape, dims, strides in all_configs:
for bdims in all_bdims(shape):
self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng)
@@ -578,8 +586,9 @@ class LaxVmapTest(jtu.JaxTestCase):
def fun(operand, tangents):
pads = lax.padtype_to_pads(operand.shape, dims, strides, padding)
ones = (1,) * len(operand.shape)
return lax._select_and_gather_add(operand, tangents, lax.ge_p, dims,
strides, pads)
strides, pads, ones, ones)
for shape, dims, strides in all_configs:
for bdims in all_bdims(shape, shape):