diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 658552070..fc81c809e 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -781,11 +781,14 @@ tf_impl[lax.cumprod_p] = tf.math.cumprod def _reduce_window_shape(jax_f, operand, window_dimensions, - window_strides, padding, input_shape=None): + window_strides, padding, base_dilation, + window_dilation, input_shape=None): """Shape inference function for reduce_window_{sum,min,max}.""" params = dict(window_dimensions=window_dimensions, window_strides=window_strides, padding=padding, + base_dilation=base_dilation, + window_dilation=window_dilation, input_shape=input_shape) try: out, = _infer_shape_jax(jax_f, operand, **params) @@ -802,17 +805,20 @@ def _get_shape_from_tensor_or_array(x): def _reduce_window(jax_f, reducer, init_val, operand, window_dimensions, - window_strides, padding, input_shape=None): + window_strides, padding, base_dilation, window_dilation, + input_shape=None): """TensorFlow implementation of reduce_window_{sum,min,max}.""" del input_shape # TODO(tomhennigan): tf2xla should have a shape inference function. out_shape = _reduce_window_shape(jax_f, operand, window_dimensions, - window_strides, padding) + window_strides, padding, base_dilation, + window_dilation) a = tf.constant(0, operand.dtype) reducer_fn = reducer.get_concrete_function(a, a) out = tfxla.reduce_window(operand, tf.constant(init_val, operand.dtype), reducer_fn, window_dimensions, - window_strides, padding=padding) + window_strides, base_dilations=base_dilation, + window_dilations=window_dilation, padding=padding) out.set_shape(out_shape) return out # pylint: disable=protected-access diff --git a/jax/experimental/stax.py b/jax/experimental/stax.py index b4b853ecb..46e0315a2 100644 --- a/jax/experimental/stax.py +++ b/jax/experimental/stax.py @@ -176,8 +176,9 @@ def _pooling_layer(reducer, init_val, rescaler=None): def init_fun(rng, input_shape): padding_vals = lax.padtype_to_pads(input_shape, window_shape, strides, padding) - out_shape = lax.reduce_window_shape_tuple(input_shape, window_shape, - strides, padding_vals) + ones = (1,) * len(window_shape) + out_shape = lax.reduce_window_shape_tuple( + input_shape, window_shape, strides, padding_vals, ones, ones) return out_shape, () def apply_fun(params, inputs, **kwargs): out = lax.reduce_window(inputs, init_val, reducer, window_shape, diff --git a/jax/lax/lax.py b/jax/lax/lax.py index 5c265f201..dce39b341 100644 --- a/jax/lax/lax.py +++ b/jax/lax/lax.py @@ -1089,7 +1089,9 @@ def _reduce_and(operand: Array, axes: Sequence[int]) -> Array: def reduce_window(operand: Array, init_value: Array, computation: Callable, window_dimensions: Shape, window_strides: Sequence[int], - padding: Union[str, Sequence[Tuple[int, int]]]) -> Array: + padding: Union[str, Sequence[Tuple[int, int]]], + base_dilation: Optional[Sequence[int]] = None, + window_dilation: Optional[Sequence[int]] = None) -> Array: """Wraps XLA's `ReduceWindowWithGeneralPadding `_ operator. 
@@ -1099,15 +1101,22 @@ def reduce_window(operand: Array, init_value: Array, computation: Callable, window_strides, padding)) else: padding = tuple(padding) + if base_dilation is None: + base_dilation = (1,) * len(window_dimensions) + if window_dilation is None: + window_dilation = (1,) * len(window_dimensions) monoid_reducer = _get_monoid_window_reducer(computation, init_value) if monoid_reducer: - return monoid_reducer(operand, window_dimensions, window_strides, padding) + return monoid_reducer(operand, window_dimensions, window_strides, padding, + base_dilation, window_dilation) else: jaxpr, consts = _reduction_jaxpr(computation, _abstractify(init_value)) return reduce_window_p.bind( operand, init_value, jaxpr=jaxpr, consts=consts, window_dimensions=tuple(window_dimensions), - window_strides=tuple(window_strides), padding=padding) + window_strides=tuple(window_strides), padding=padding, + base_dilation=tuple(base_dilation), + window_dilation=tuple(window_dilation)) def _get_monoid_window_reducer(monoid_op: Callable, x: Array) -> Optional[Callable]: aval = core.get_aval(x) @@ -1122,46 +1131,82 @@ def _get_monoid_window_reducer(monoid_op: Callable, x: Array) -> Optional[Callab def _reduce_window_sum(operand: Array, window_dimensions: Shape, window_strides: Sequence[int], - padding: Sequence[Tuple[int, int]]) -> Array: + padding: Sequence[Tuple[int, int]], + base_dilation: Optional[Sequence[int]] = None, + window_dilation: Optional[Sequence[int]] = None) -> Array: + if base_dilation is None: + base_dilation = (1,) * len(window_dimensions) + if window_dilation is None: + window_dilation = (1,) * len(window_dimensions) return reduce_window_sum_p.bind( operand, window_dimensions=tuple(window_dimensions), - window_strides=tuple(window_strides), padding=tuple(padding)) + window_strides=tuple(window_strides), padding=tuple(padding), + base_dilation=tuple(base_dilation), + window_dilation=tuple(window_dilation)) def _reduce_window_prod(operand: Array, window_dimensions: Shape, window_strides: Sequence[int], - padding: Sequence[Tuple[int, int]]) -> Array: + padding: Sequence[Tuple[int, int]], + base_dilation: Optional[Sequence[int]] = None, + window_dilation: Optional[Sequence[int]] = None) -> Array: init_value = _const(operand, 1) jaxpr, consts = _reduction_jaxpr(mul, _abstractify(init_value)) + if base_dilation is None: + base_dilation = (1,) * len(window_dimensions) + if window_dilation is None: + window_dilation = (1,) * len(window_dimensions) return reduce_window_p.bind( operand, init_value, jaxpr=jaxpr, consts=consts, window_dimensions=tuple(window_dimensions), - window_strides=tuple(window_strides), padding=tuple(padding)) + window_strides=tuple(window_strides), padding=tuple(padding), + base_dilation=tuple(base_dilation), + window_dilation=tuple(window_dilation)) def _reduce_window_max(operand: Array, window_dimensions: Shape, window_strides: Sequence[int], - padding: Sequence[Tuple[int, int]]) -> Array: + padding: Sequence[Tuple[int, int]], + base_dilation: Optional[Sequence[int]] = None, + window_dilation: Optional[Sequence[int]] = None) -> Array: + if base_dilation is None: + base_dilation = (1,) * len(window_dimensions) + if window_dilation is None: + window_dilation = (1,) * len(window_dimensions) return reduce_window_max_p.bind( operand, window_dimensions=tuple(window_dimensions), - window_strides=tuple(window_strides), padding=tuple(padding)) + window_strides=tuple(window_strides), padding=tuple(padding), + base_dilation=tuple(base_dilation), + 
window_dilation=tuple(window_dilation)) def _reduce_window_min(operand: Array, window_dimensions: Shape, window_strides: Sequence[int], - padding: Sequence[Tuple[int, int]]) -> Array: + padding: Sequence[Tuple[int, int]], + base_dilation: Optional[Sequence[int]] = None, + window_dilation: Optional[Sequence[int]] = None) -> Array: + if base_dilation is None: + base_dilation = (1,) * len(window_dimensions) + if window_dilation is None: + window_dilation = (1,) * len(window_dimensions) return reduce_window_min_p.bind( operand, window_dimensions=tuple(window_dimensions), - window_strides=tuple(window_strides), padding=tuple(padding)) + window_strides=tuple(window_strides), padding=tuple(padding), + base_dilation=tuple(base_dilation), + window_dilation=tuple(window_dilation)) def _select_and_scatter(operand: Array, select: Callable, window_dimensions: Shape, window_strides: Sequence[int], padding: Sequence[Tuple[int, int]], source: Array, - init_value: Array, scatter: Callable) -> Array: + init_value: Array, scatter: Callable, + base_dilation: Sequence[int], + window_dilation: Sequence[int]) -> Array: select_jaxpr, select_consts = _reduction_jaxpr(select, _abstractify(init_value)) scatter_jaxpr, scatter_consts = _reduction_jaxpr(scatter, _abstractify(init_value)) return select_and_scatter_p.bind( operand, source, init_value, select_jaxpr=select_jaxpr, select_consts=select_consts, scatter_jaxpr=scatter_jaxpr, scatter_consts=scatter_consts, window_dimensions=tuple(window_dimensions), - window_strides=tuple(window_strides), padding=tuple(padding)) + window_strides=tuple(window_strides), padding=tuple(padding), + base_dilation=tuple(base_dilation), + window_dilation=tuple(window_dilation)) def _select_and_scatter_add(source: Array, operand: Array, select_prim: core.Primitive, @@ -1177,11 +1222,15 @@ def _select_and_gather_add(tangents: Array, operand: Array, select_prim: core.Primitive, window_dimensions: Shape, window_strides: Sequence[int], - padding: Sequence[Tuple[int, int]]) -> Array: + padding: Sequence[Tuple[int, int]], + base_dilation: Sequence[int], + window_dilation: Sequence[int]) -> Array: return select_and_gather_add_p.bind( tangents, operand, select_prim=select_prim, window_dimensions=tuple(window_dimensions), - window_strides=tuple(window_strides), padding=tuple(padding)) + window_strides=tuple(window_strides), padding=tuple(padding), + base_dilation=tuple(base_dilation), + window_dilation=tuple(window_dilation)) def cumsum(operand: Array, axis: int) -> Array: """Computes a cumulative sum along `axis`.""" @@ -4490,38 +4539,43 @@ reduce_and_p = standard_primitive(_reduce_logical_shape_rule, _fixed_dtype(np.bo batching.defreducer(reduce_and_p) def _reduce_window_shape_rule(operand, init_value, *, jaxpr, consts, - window_dimensions, window_strides, padding): + window_dimensions, window_strides, padding, + base_dilation, window_dilation): if operand.dtype != init_value.dtype: msg = ("reduce_window got inconsistent dtypes for operand and init_value: " " got operand dtype {} and init_value dtype {}.") raise TypeError(msg.format(operand.dtype, init_value.dtype)) - return _common_reduce_window_shape_rule(operand, window_dimensions, - window_strides, padding) + return _common_reduce_window_shape_rule( + operand, window_dimensions, window_strides, padding, base_dilation, + window_dilation) def _reduce_window_translation_rule(c, operand, init_value, *, jaxpr, consts, - window_dimensions, window_strides, padding): + window_dimensions, window_strides, padding, + base_dilation, window_dilation): 
xla_computation = _reduction_computation(c, jaxpr, consts, init_value) return xops.ReduceWindowWithGeneralPadding( operand, init_value, xla_computation, window_dimensions, - window_strides, (), (), padding) + window_strides, base_dilation, window_dilation, padding) def _generic_reduce_window_batch_rule( batched_args, batch_dims, *, jaxpr, consts, window_dimensions, - window_strides, padding): + window_strides, padding, base_dilation, window_dilation): operand, init = batched_args bdim, init_bdim = batch_dims if init_bdim is not None: raise NotImplementedError("reduce_window batching is not implemented for " "initial values") - def reduce_window(x, window_dimensions, window_strides, padding): + def reduce_window(x, window_dimensions, window_strides, padding, base_dilation, + window_dilation): return reduce_window_p.bind( x, init, jaxpr=jaxpr, consts=consts, window_dimensions=window_dimensions, - window_strides=window_strides, padding=padding) - return _reduce_window_batch_rule(reduce_window, (operand,), (bdim,), - window_dimensions=window_dimensions, - window_strides=window_strides, - padding=padding) + window_strides=window_strides, padding=padding, base_dilation=base_dilation, + window_dilation=window_dilation) + return _reduce_window_batch_rule( + reduce_window, (operand,), (bdim,), window_dimensions=window_dimensions, + window_strides=window_strides, padding=padding, base_dilation=base_dilation, + window_dilation=window_dilation) reduce_window_p = standard_primitive( @@ -4531,40 +4585,46 @@ batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule def _reduce_window_sum_shape_rule(operand, *, window_dimensions, window_strides, - padding): + padding, base_dilation, window_dilation): if not dtypes.issubdtype(operand.dtype, np.number): msg = "operand to reduce_window_sum must have a number dtype, got {}" raise TypeError(msg.format(np.dtype(operand.dtype).name)) return _common_reduce_window_shape_rule(operand, window_dimensions, - window_strides, padding) + window_strides, padding, base_dilation, + window_dilation) def _reduce_window_sum_translation_rule(c, operand, *, window_dimensions, - window_strides, padding): + window_strides, padding, base_dilation, + window_dilation): dtype = c.get_shape(operand).numpy_dtype() scalar = ShapedArray((), dtype) return xops.ReduceWindowWithGeneralPadding( operand, xb.constant(c, np.array(0, dtype)), xla.primitive_subcomputation(add_p, scalar, scalar), window_dimensions, - window_strides, (), (), padding) + window_strides, base_dilation, window_dilation, padding) def _reduce_window_sum_transpose_rule(cotangent, operand, *, window_dimensions, - window_strides, padding): + window_strides, padding, base_dilation, + window_dilation): assert ad.is_undefined_primal(operand) input_shape = operand.aval.shape - ones = [1] * len(input_shape) pads = _conv_general_vjp_lhs_padding( input_shape, window_dimensions, window_strides, cotangent.shape, padding, - ones, ones) + base_dilation, window_dilation) + ones = [1] * len(input_shape) padding_config = [(lo, hi, stride - 1) for (lo, hi), stride in zip(pads, window_strides)] pad_cotangent = pad(cotangent, _zero(cotangent), padding_config) - result = _reduce_window_sum(pad_cotangent, window_dimensions, ones, - [(0, 0)] * len(input_shape)) - assert result.shape == input_shape + result = _reduce_window_sum(pad_cotangent, window_dimensions, base_dilation, + [(0, 0)] * len(input_shape), + base_dilation=ones, + window_dilation=window_dilation) + assert result.shape == input_shape, (result.shape, 
input_shape) return [result] def _reduce_window_batch_rule(reduce_window, batched_args, bdims, *, - window_dimensions, window_strides, padding): + window_dimensions, window_strides, padding, + base_dilation, window_dilation): operand, = batched_args bdim, = bdims @@ -4573,7 +4633,11 @@ def _reduce_window_batch_rule(reduce_window, batched_args, bdims, *, window_dimensions[:bdim] + (1,) + window_dimensions[bdim:] window_strides = window_strides[:bdim] + (1,) + window_strides[bdim:] padding = padding[:bdim] + ((0, 0),) + padding[bdim:] - operand = reduce_window(operand, window_dimensions, window_strides, padding) + base_dilation = base_dilation[:bdim] + (1,) + base_dilation[bdim:] + window_dilation = window_dilation[:bdim] + (1,) + window_dilation[bdim:] + + operand = reduce_window(operand, window_dimensions, window_strides, padding, + base_dilation, window_dilation) return operand, bdim reduce_window_sum_p = standard_primitive( @@ -4584,26 +4648,32 @@ batching.primitive_batchers[reduce_window_sum_p] = partial( _reduce_window_batch_rule, _reduce_window_sum) def _reduce_window_chooser_translation_rule( - prim, identity, c, operand, *, window_dimensions, window_strides, padding): + prim, identity, c, operand, *, window_dimensions, window_strides, padding, + base_dilation, window_dilation): dtype = c.get_shape(operand).numpy_dtype() scalar = ShapedArray((), dtype) return xops.ReduceWindowWithGeneralPadding( operand, xb.constant(c, identity(dtype)), xla.primitive_subcomputation(prim, scalar, scalar), window_dimensions, - window_strides, (), (), padding) + window_strides, base_dilation, window_dilation, padding) def _reduce_window_chooser_jvp_rule(prim, g, operand, *, window_dimensions, - window_strides, padding): + window_strides, padding, base_dilation, + window_dilation): assert prim is max_p or prim is min_p select_prim = ge_p if prim is max_p else le_p return _select_and_gather_add(g, operand, select_prim, window_dimensions, - window_strides, padding) + window_strides, padding, base_dilation, + window_dilation) def _common_reduce_window_shape_rule(operand, window_dimensions, - window_strides, padding): + window_strides, padding, base_dilation, + window_dilation): _check_shapelike("reduce_window", "window_dimensions", window_dimensions) _check_shapelike("reduce_window", "window_strides", window_strides) + _check_shapelike("reduce_window", "base_dilation", base_dilation) + _check_shapelike("reduce_window", "window_dilation", window_dilation) if operand.ndim != len(window_dimensions): msg = ("reduce_window got the wrong number of window_dimensions for " "operand: got operand shape {} with window_dimensions {}.") @@ -4612,12 +4682,27 @@ def _common_reduce_window_shape_rule(operand, window_dimensions, msg = ("reduce_window got inconsistent window_strides and " "window_dimensions: got window_strides {} and window_dimensions {}.") raise TypeError(msg.format(window_strides, window_dimensions)) + if len(base_dilation) != len(window_dimensions): + msg = ("reduce_window got inconsistent base_dilation and " + "window_dimensions: got base_dilation {} and window_dimensions {}.") + raise TypeError(msg.format(base_dilation, window_dimensions)) + if len(window_dilation) != len(window_dimensions): + msg = ("reduce_window got inconsistent window_dilation and " + "window_dimensions: got window_dilation {} and window_dimensions " + "{}.") + raise TypeError(msg.format(window_dilation, window_dimensions)) return reduce_window_shape_tuple(operand.shape, window_dimensions, - window_strides, padding) + window_strides, 
padding, base_dilation, + window_dilation) def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, - padding): + padding, base_dilation=None, + window_dilation=None): + if base_dilation is not None: + operand_shape = _dilate_shape(operand_shape, base_dilation) + if window_dilation is not None: + window_dimensions = _dilate_shape(window_dimensions, window_dilation) operand_padded = np.add(operand_shape, np.add(*zip(*padding))) t = np.floor_divide( np.subtract(operand_padded, window_dimensions), window_strides) + 1 @@ -4708,8 +4793,9 @@ def _select_and_scatter_add_transpose( t, source, operand, *, select_prim, window_dimensions, window_strides, padding): assert ad.is_undefined_primal(source) and not ad.is_undefined_primal(operand) + ones = (1,) * len(window_dimensions) source_t = _select_and_gather_add(t, operand, select_prim, window_dimensions, - window_strides, padding) + window_strides, padding, ones, ones) return [source_t, None] def _select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs): @@ -4753,13 +4839,14 @@ batching.primitive_batchers[select_and_scatter_add_p] = \ def _select_and_gather_add_shape_rule( tangents, operand, *, select_prim, window_dimensions, window_strides, - padding): + padding, base_dilation, window_dilation): if tangents.shape != operand.shape: msg = ("select_and_gather_add tangents and operand shapes must match, " "got {} and {}.") raise TypeError(msg.format(tangents.shape, operand.shape)) - return _common_reduce_window_shape_rule(operand, window_dimensions, - window_strides, padding) + return _common_reduce_window_shape_rule( + operand, window_dimensions, window_strides, padding, base_dilation, + window_dilation) _UINT_DTYPES = { @@ -4776,7 +4863,7 @@ _INT_DTYPES = { def _select_and_gather_add_translation( c, tangents, operand, *, select_prim, window_dimensions, window_strides, - padding, max_bits=64): + padding, base_dilation, window_dilation, max_bits=64): shape = c.get_shape(operand) dtype = shape.numpy_dtype() etype = shape.xla_element_type() @@ -4867,37 +4954,52 @@ def _select_and_gather_add_translation( init = -np.inf if select_prim is ge_p else np.inf out = xops.ReduceWindowWithGeneralPadding( pack(operand, tangents), pack(const(c, dtype, init), const(c, dtype, 0)), - reducer(), window_dimensions, window_strides, (), (), padding) + reducer(), window_dimensions, window_strides, base_dilation, + window_dilation, padding) return snd(out) def _select_and_gather_add_jvp( primals, tangents, *, select_prim, window_dimensions, window_strides, - padding): + padding, base_dilation, window_dilation): source, operand = primals g_source, g_operand = tangents val_out = _select_and_gather_add( source, operand, select_prim, window_dimensions, window_strides, - padding) + padding, base_dilation, window_dilation) del g_operand if type(g_source) is ad_util.Zero: tangent_out = ad_util.Zero.from_value(val_out) else: tangent_out = _select_and_gather_add( g_source, operand, select_prim, window_dimensions, - window_strides, padding) + window_strides, padding, base_dilation, window_dilation) return val_out, tangent_out def _select_and_gather_add_transpose( t, tangents, operand, *, select_prim, window_dimensions, window_strides, - padding): + padding, base_dilation, window_dilation): + assert select_prim in (le_p, ge_p) assert ad.is_undefined_primal(tangents) and not ad.is_undefined_primal(operand) + if any(d != 1 for d in window_dilation): + msg = ("VJP not implemented for select_and_gather (MaxPool) with window " + "dilation, got 
window_dilation={}.") + raise NotImplementedError(msg.format(window_dilation)) + has_base_dilation = any(d != 1 for d in base_dilation) + if has_base_dilation: + select_identity = (_get_max_identity if select_prim is ge_p + else _get_min_identity) + operand = pad(operand, select_identity(operand.dtype), + tuple((0, 0, d - 1) for d in base_dilation)) result = _select_and_scatter_add(t, operand, select_prim, window_dimensions, window_strides, padding) + if has_base_dilation: + result = slice(result, (0,) * len(operand.shape), operand.shape, + base_dilation) return [result, None] def _select_and_gather_add_batching_rule( batched_args, batch_dims, *, select_prim, window_dimensions, window_strides, - padding): + padding, base_dilation, window_dilation): t, x = batched_args t_bdim, x_bdim = batch_dims size = next(a.shape[bdim] for a, bdim in zip(batched_args, batch_dims) @@ -4907,8 +5009,11 @@ def _select_and_gather_add_batching_rule( window_dimensions = (1,) + window_dimensions window_strides = (1,) + window_strides padding = ((0, 0),) + padding + base_dilation = (1,) + base_dilation + window_dilation = (1,) + window_dilation out = _select_and_gather_add(t, x, select_prim, window_dimensions, - window_strides, padding) + window_strides, padding, base_dilation, + window_dilation) return (out, 0) diff --git a/jax/lax_reference.py b/jax/lax_reference.py index a877b24c6..c9616c28c 100644 --- a/jax/lax_reference.py +++ b/jax/lax_reference.py @@ -287,13 +287,16 @@ def reduce(operand, init_value, computation, dimensions): # pylint: disable=red return reducer(operand, tuple(dimensions)).astype(np.asarray(operand).dtype) def reduce_window(operand, init_value, computation, window_dimensions, - window_strides, padding): + window_strides, padding, base_dilation): op, dims, strides = operand, window_dimensions, window_strides if isinstance(padding, str): pads = padtype_to_pads(op.shape, dims, strides, padding) else: pads = padding - view = _conv_view(op.reshape((1, 1) + op.shape), (1, 1) + dims, strides, pads, + op = op.reshape((1, 1) + op.shape) + if base_dilation: + op = _dilate(op, base_dilation, init_value) + view = _conv_view(op, (1, 1) + dims, strides, pads, pad_value=init_value)[0] view = view.reshape(view.shape[1:1+len(dims)] + (-1,)) reducer = _make_reducer(computation, init_value) @@ -364,13 +367,13 @@ def _pad(arr, pads, pad_value): for (lo, hi), dim in zip(pads, np.shape(arr))) return out[slices] -def _dilate(operand, factors): +def _dilate(operand, factors, fill_value=0): # this logic is like lax.pad, but with two leading dimensions, no edge # padding, and factors are at least 1 (interior padding is at least 0) outspace = np.add(operand.shape[2:], np.multiply(np.subtract(factors, 1), np.subtract(operand.shape[2:], 1))) - out = np.zeros(operand.shape[:2] + tuple(outspace), operand.dtype) + out = np.full(operand.shape[:2] + tuple(outspace), fill_value, operand.dtype) lhs_slices = tuple(_slice(None, None, step) for step in factors) out[(_slice(None),) * 2 + lhs_slices] = operand return out diff --git a/jax/test_util.py b/jax/test_util.py index 99dea1dae..03a41bf88 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -181,8 +181,16 @@ def inner_prod(xs, ys): return tree_reduce(np.add, tree_multimap(contract, xs, ys)) +def _safe_subtract(x, y, *, dtype): + """Subtraction with `inf - inf == 0` semantics.""" + with np.errstate(invalid='ignore'): + return np.where(np.equal(x, y), np.array(0, dtype), + np.subtract(x, y, dtype=dtype)) + add = partial(tree_multimap, lambda x, y: np.add(x, y,
dtype=_dtype(x))) sub = partial(tree_multimap, lambda x, y: np.subtract(x, y, dtype=_dtype(x))) +safe_sub = partial(tree_multimap, + lambda x, y: _safe_subtract(x, y, dtype=_dtype(x))) conj = partial(tree_map, lambda x: np.conj(x, dtype=_dtype(x))) def scalar_mul(xs, a): @@ -203,7 +211,7 @@ def numerical_jvp(f, primals, tangents, eps=EPS): delta = scalar_mul(tangents, eps) f_pos = f(*add(primals, delta)) f_neg = f(*sub(primals, delta)) - return scalar_mul(sub(f_pos, f_neg), 0.5 / eps) + return scalar_mul(safe_sub(f_pos, f_neg), 0.5 / eps) def _merge_tolerance(tol, default): diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index 68aa59da4..0b14a53db 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -671,63 +671,74 @@ class LaxAutodiffTest(jtu.JaxTestCase): check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps) @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_op={}_dtype={}_padding={}" - .format(op.__name__, np.dtype(dtype).name, padding), - "op": op, "init_val": init_val, "dtype": dtype, "padding": padding, + {"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}" + "_basedilation={}_windowdilation={}") + .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims, + strides, padding, base_dilation, window_dilation), + "op": op, "init_val": init_val, "dtype": dtype, "shape": shape, + "dims": dims, "strides": strides, "padding": padding, + "base_dilation": base_dilation, "window_dilation": window_dilation, "rng_factory": rng_factory} for init_val, op, dtypes, rng_factory in [ (0, lax.add, grad_float_dtypes, jtu.rand_small), (-np.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int), (np.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int), ] - for dtype in dtypes - for padding in ["VALID", "SAME"])) - @jtu.skip_on_flag("jax_skip_slow_tests", True) + for shape, dims, strides, padding, base_dilation, window_dilation in ( + itertools.chain( + itertools.product( + [(4, 6)], + [(2, 1), (1, 2)], + [(1, 1), (2, 1), (1, 2)], + # TODO(b/161704903): explicit paddings segfault on CPU. + ["VALID", "SAME"], #, [(0, 3), (1, 2)]], + [(1, 1)] + ([(2, 3)] if op is lax.add else []), + [(1, 1)] + ([(1, 2)] if op is lax.add else [])), + itertools.product( + [(3, 2, 4, 6)], + [(1, 1, 2, 1), (2, 1, 2, 1)], + [(1, 2, 2, 1), (1, 1, 1, 1)], + # TODO(b/161704903): explicit paddings segfault on CPU. + ["VALID", "SAME"], # [(0, 1), (1, 0), (2, 3), (0, 2)]], + [(1, 1, 1, 1)] + ([(2, 1, 3, 2)] if op is lax.add else []), + [(1, 1, 1, 1)] + ([(1, 2, 2, 1)] if op is lax.add else [])))) + for dtype in dtypes)) @jtu.ignore_warning(category=UserWarning, message="Using reduced precision for gradient.*") - def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory): + def testReduceWindowGrad( + self, op, init_val, dtype, shape, dims, strides, + padding, base_dilation, window_dilation, rng_factory): rng = rng_factory(self.rng()) init_val = np.asarray(init_val, dtype=dtype) + gradient_order = 3 # We need this conditional and the corresponding loop logic to be in the # test method, rather than at the parameterized test level, because it # depends on FLAGS for the device under test. # TODO(b/31565929): enable when fixed. 
if jtu.device_under_test() == "tpu" and op is not lax.add: - all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))] + if len(shape) != 4 or dims != (1, 1, 2, 1): + raise SkipTest("Only R4 SelectAndScatter implemented on TPU") # TODO(b/73062247): need variadic reduce-window for better precision. gradient_order = 1 - else: - all_configs = itertools.chain( - itertools.product( - [(4, 6)], # shapes - [(2, 1), (1, 2)], # window_dimensions - [(1, 1), (2, 1), (1, 2)] # strides - ), - itertools.product( - [(3, 2, 4, 6)], # shapes - [(1, 1, 2, 1), (2, 1, 2, 1)], # window_dimensions - [(1, 2, 2, 1), (1, 1, 1, 1)]), # strides - ) - gradient_order = 3 def fun(operand): - return lax.reduce_window(operand, init_val, op, dims, strides, padding) + return lax.reduce_window(operand, init_val, op, dims, strides, padding, + base_dilation, window_dilation) - for shape, dims, strides in all_configs: - operand = rng(shape, dtype) - if op is lax.add: - eps = 1. - tol = None - else: - # this test can fail if there are duplicates in operand - self.assertEqual(np.unique(operand).size, operand.size, - msg="test requires operand elements to be unique.") - eps = 1e-2 - tol = {np.float16: 1e-1, np.float32: 6e-2, np.float64: 6e-2} - check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol, - eps) + operand = rng(shape, dtype) + if op is lax.add: + eps = 1. + tol = None + else: + # this test can fail if there are duplicates in operand + self.assertEqual(np.unique(operand).size, operand.size, + msg="test requires operand elements to be unique.") + eps = 1e-2 + tol = {np.float16: 1e-1, np.float32: 6e-2, np.float64: 6e-2} + check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol, + eps) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_shape={}_axis={}" diff --git a/tests/lax_test.py b/tests/lax_test.py index f381e25eb..9c329be54 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -1302,56 +1302,60 @@ class LaxTest(jtu.JaxTestCase): self._CompileAndCheck(fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_op={}_dtype={}" - .format(op.__name__, np.dtype(dtype).name,), - "op": op, "init_val": init_val, "dtype": dtype, - "rng_factory": rng_factory} + {"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}" + "_basedilation={}_windowdilation={}") + .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), + dims, strides, padding, base_dilation, window_dilation), + "op": op, "init_val": init_val, "dtype": dtype, "shape": shape, + "dims": dims, "strides": strides, "padding": padding, + "base_dilation": base_dilation, "window_dilation": window_dilation} for init_val, op, dtypes in [ (0, lax.add, [np.float32]), (-np.inf, lax.max, [np.float32]), (np.inf, lax.min, [np.float32]), ] - for dtype in dtypes - for rng_factory in [jtu.rand_small])) - def testReduceWindow(self, op, init_val, dtype, rng_factory): - rng = rng_factory(self.rng()) - init_val = np.asarray(init_val, dtype=dtype) - - all_configs = list(itertools.chain( - itertools.product( + for shape, dims, strides, padding, base_dilation, window_dilation in ( + itertools.chain( + itertools.product( [(4, 6)], [(2, 1), (1, 2)], [(1, 1), (2, 1), (1, 2)], - ["VALID", "SAME", [(0, 3), (1, 2)]]), - itertools.product( + ["VALID", "SAME", [(0, 3), (1, 2)]], + [(1, 1), (2, 3)], + [(1, 1), (1, 2)]), + itertools.product( [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)], [(1, 2, 2, 1), (1, 1, 1, 1)], - ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]]))) + ["VALID", 
"SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]], + [(1, 1, 1, 1), (2, 1, 3, 2)], + [(1, 1, 1, 1), (1, 2, 2, 1)]))) + for dtype in dtypes)) + def testReduceWindow(self, op, init_val, dtype, shape, dims, strides, padding, + base_dilation, window_dilation): + rng = jtu.rand_small(self.rng()) + init_val = np.asarray(init_val, dtype=dtype) def fun(operand, init_val): - return lax.reduce_window(operand, init_val, op, dims, strides, padding) + return lax.reduce_window(operand, init_val, op, dims, strides, padding, + base_dilation, window_dilation) def reference_fun(operand, init_val): return lax_reference.reduce_window(operand, init_val, op, dims, strides, - padding) + padding, base_dilation) - # pylint: disable=cell-var-from-loop - for shape, dims, strides, padding in all_configs: - args_maker = lambda: [rng(shape, dtype), init_val] - self._CompileAndCheck(fun, args_maker) + args_maker = lambda: [rng(shape, dtype), init_val] + self._CompileAndCheck(fun, args_maker) + if all(d == 1 for d in window_dilation): self._CheckAgainstNumpy(fun, reference_fun, args_maker) - # pylint: enable=cell-var-from-loop # we separately test the version that uses a concrete init_val because it # can hit different code paths def fun(operand): - return lax.reduce_window(operand, init_val, op, dims, strides, padding) + return lax.reduce_window(operand, init_val, op, dims, strides, padding, + base_dilation, window_dilation) - # pylint: disable=cell-var-from-loop - for shape, dims, strides, padding in all_configs: - args_maker = lambda: [rng(shape, dtype)] - self._CompileAndCheck(fun, args_maker) - # pylint: enable=cell-var-from-loop + args_maker = lambda: [rng(shape, dtype)] + self._CompileAndCheck(fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_shape={}_axis={}" diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 91eb0ce90..97b245f03 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -500,37 +500,45 @@ class LaxVmapTest(jtu.JaxTestCase): self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_op={}_dtype={}_padding={}" - .format(op.__name__, np.dtype(dtype).name, padding), - "op": op, "init_val": init_val, "dtype": dtype, "padding": padding, - "rng_factory": rng_factory} + {"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}" + "_basedilation={}_windowdilation={}") + .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), + dims, strides, padding, base_dilation, window_dilation), + "op": op, "init_val": init_val, "dtype": dtype, "shape": shape, + "dims": dims, "strides": strides, "padding": padding, + "base_dilation": base_dilation, "window_dilation": window_dilation} for init_val, op, dtypes in [ (0, lax.add, [np.float32]), (-np.inf, lax.max, [np.float32]), (np.inf, lax.min, [np.float32]), ] - for dtype in dtypes - for padding in ["VALID", "SAME"] - for rng_factory in [jtu.rand_small])) - def testReduceWindow(self, op, init_val, dtype, padding, rng_factory): - rng = rng_factory(self.rng()) - init_val = np.asarray(init_val, dtype=dtype) - - all_configs = itertools.chain( - itertools.product( + for shape, dims, strides, padding, base_dilation, window_dilation in ( + itertools.chain( + itertools.product( [(4, 6)], [(2, 1), (1, 2)], - [(1, 1), (2, 1), (1, 2)]), - itertools.product( + [(1, 1), (2, 1), (1, 2)], + ["VALID", "SAME", [(0, 3), (1, 2)]], + [(1, 1), (2, 3)], + [(1, 1), (1, 2)]), + itertools.product( [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 
1)], - [(1, 2, 2, 1), (1, 1, 1, 1)])) + [(1, 2, 2, 1), (1, 1, 1, 1)], + ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]], + [(1, 1, 1, 1), (2, 1, 3, 2)], + [(1, 1, 1, 1), (1, 2, 2, 1)]))) + for dtype in dtypes)) + def testReduceWindow(self, op, init_val, dtype, shape, dims, strides, padding, + base_dilation, window_dilation): + rng = jtu.rand_small(self.rng()) + init_val = np.asarray(init_val, dtype=dtype) def fun(operand): - return lax.reduce_window(operand, init_val, op, dims, strides, padding) + return lax.reduce_window(operand, init_val, op, dims, strides, padding, + base_dilation, window_dilation) - for shape, dims, strides in all_configs: - for bdims in all_bdims(shape): - self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng) + for bdims in all_bdims(shape): + self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_shape={}_axis={}_bdims={}" @@ -578,8 +586,9 @@ class LaxVmapTest(jtu.JaxTestCase): def fun(operand, tangents): pads = lax.padtype_to_pads(operand.shape, dims, strides, padding) + ones = (1,) * len(operand.shape) return lax._select_and_gather_add(operand, tangents, lax.ge_p, dims, - strides, pads) + strides, pads, ones, ones) for shape, dims, strides in all_configs: for bdims in all_bdims(shape, shape):
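
Usage note, not part of the patch: a minimal sketch of the new arguments, assuming a JAX checkout that includes this change; the array names and values below are illustrative only. `base_dilation` dilates the operand before the windows are applied and `window_dilation` spreads the window taps apart; both default to all-ones, so existing callers are unaffected.

    import numpy as np
    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(24, dtype=np.float32).reshape(4, 6)

    # 2x2 max pooling with stride 1 and no dilation -- the pre-existing behavior.
    y0 = lax.reduce_window(x, -np.inf, lax.max, (2, 2), (1, 1), "VALID")

    # window_dilation=(2, 2): each window reduces over x[i, j], x[i, j+2],
    # x[i+2, j] and x[i+2, j+2], i.e. an effective 3x3 footprint.
    y1 = lax.reduce_window(x, -np.inf, lax.max, (2, 2), (1, 1), "VALID",
                           window_dilation=(2, 2))

    # base_dilation=(2, 2): the operand is dilated to shape (7, 11) before the
    # 2x2 sum windows are applied.
    y2 = lax.reduce_window(x, 0., lax.add, (2, 2), (1, 1), "VALID",
                           base_dilation=(2, 2))

    print(y0.shape, y1.shape, y2.shape)  # (3, 5) (2, 4) (6, 10)

The shapes follow the updated reduce_window_shape_tuple: the operand shape is dilated by base_dilation, the window dimensions by window_dilation, and the usual floor((padded - window) / stride) + 1 rule is applied to the result.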