From 0beef34d2523013c03fec6b97d98dbe3b3b69f87 Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 12 Jul 2021 01:11:17 -0700 Subject: [PATCH] [jax2tf] Fix conversion for argmin/argmax; add conversion for reduce The previous conversion for argmin/argmax simply used tf.argmin and tf.argmax. Those ops behave differently than JAX when the inputs contain NaN and Inf. Added a few test cases in primitive_harness to expose the failures. In order to implement an accurate conversion of argmin/argmax, we need to use the XLA Reduce op. Also tightened the shape checks for lax.argmin and lax.argmax, to ensure they are not used with an empty reduced dimension. E.g., if the axis=-1, previously we got an internal error: ``` RuntimeError: Invalid argument: Reducing out-of-bounds dimension -1 in shape f32[2,0,3].: This is a bug in JAX's shape-checking rules; please report it! ``` PiperOrigin-RevId: 384182794 --- CHANGELOG.md | 5 + jax/_src/lax/lax.py | 65 +++++++------ jax/experimental/jax2tf/README.md | 1 + .../g3doc/primitives_with_limited_support.md | 15 ++- ...rimitives_with_limited_support.md.template | 3 + jax/experimental/jax2tf/jax2tf.py | 77 +++++++++++++-- .../jax2tf/tests/jax2tf_limitations.py | 37 ++++++-- .../jax2tf/tests/primitive_harness.py | 94 ++++++++++++++++--- .../jax2tf/tests/primitives_test.py | 2 +- .../jax2tf/tests/shape_poly_test.py | 14 +++ tests/lax_test.py | 18 ++++ 11 files changed, 266 insertions(+), 65 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d16866c95..36359c9ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. ## jax 0.2.18 (unreleased) * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.17...main). +* Bug fixes: + * Tightened the checks for lax.argmin and lax.argmax to ensure they are + not used with invalid `axis` value, or with an empty reduction dimension. + ({jax-issue}`#7196`) + ## jaxlib 0.1.69 (unreleased) ## jax 0.2.17 (July 9 2021) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 97f29904e..8b750f655 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -5488,6 +5488,11 @@ _masking_defreducer(reduce_min_p, def _argminmax_shape_rule(operand, *, axes, index_dtype): axis, = axes + if not (0 <= axis < len(operand.shape)): + raise ValueError(f"Invalid axis {axis} for operand shape {operand.shape}") + if not core.greater_equal_dim(operand.shape[axis], 1): + raise ValueError("argmin and argmax require non-empty reduced dimension. " + f"operand.shape={operand.shape} axis={axis}") return tuple(np.delete(operand.shape, axis)) def _argminmax_dtype_rule(operand, *, axes, index_dtype): @@ -5496,34 +5501,29 @@ def _argminmax_dtype_rule(operand, *, axes, index_dtype): .format(np.dtype(index_dtype).name)) return index_dtype -def _argminmax_translation_rule(value_comparator, identity, - c, operand, *, axes, index_dtype): +def _compute_argminmax(value_comparator, get_identity, + operand, *, index_dtype, axes): + # value_comparator is either lax.lt (for argmin) or lax.gt + # get_identity(operand.dtype) is inf for argmin or -inf for argmax axis, = axes - shape = c.get_shape(operand) - dtype = shape.numpy_dtype() - - subc = xb.make_computation_builder("argminmax_comparator") - value_shape = xc.Shape.array_shape(shape.xla_element_type(), ()) - index_shape = xc.Shape.array_shape(index_dtype, ()) - x_value = xb.parameter(subc, 0, value_shape) - x_index = xb.parameter(subc, 1, index_shape) - y_value = xb.parameter(subc, 2, value_shape) - y_index = xb.parameter(subc, 3, index_shape) - which_value = xops.Or(value_comparator(x_value, y_value), - xops.Ne(x_value, x_value)) - which_index = xops.Or(which_value, xops.And(xops.Eq(x_value, y_value), - xops.Lt(x_index, y_index))) - xops.Tuple(subc, [xops.Select(which_value, x_value, y_value), - xops.Select(which_index, x_index, y_index)]) - comparator = subc.build() - - iota_shape = xc.Shape.array_shape(index_dtype, shape.dimensions()) - iota = xc.ops.Iota(c, iota_shape, axis) - out = xops.Reduce( - c, [operand, iota], - [xb.constant(c, identity(dtype)), - xb.constant(c, np.array(0, index_dtype))], comparator, [axis]) - return xops.GetTupleElement(out, 1) + indices = broadcasted_iota(index_dtype, np.shape(operand), axis) + def reducer_fn(op_val_index, acc_val_index): + op_val, op_index = op_val_index + acc_val, acc_index = acc_val_index + # Pick op_val if Lt (for argmin) or if NaN + pick_op_val = bitwise_or(value_comparator(op_val, acc_val), + ne(op_val, op_val)) + # If x and y are not NaN and x = y, then pick the first + pick_op_index = bitwise_or(pick_op_val, + bitwise_and(eq(op_val, acc_val), + lt(op_index, acc_index))) + return (select(pick_op_val, op_val, acc_val), + select(pick_op_index, op_index, acc_index)) + res = reduce([operand, indices], + [get_identity(operand.dtype), np.array(0, index_dtype)], + reducer_fn, + axes) + return res[1] def _argminmax_gpu_translation_rule(op, a, *, axes, index_dtype): axis, = axes @@ -5534,10 +5534,13 @@ def _argminmax_gpu_translation_rule(op, a, *, axes, index_dtype): mask_idxs = select(eq(a, maxvals) | ne(a, a), idxs, maxval) return _reduce_min(mask_idxs, (axis,)) -_argmin_translation_rule = partial(_argminmax_translation_rule, xops.Lt, - _get_min_identity) -_argmax_translation_rule = partial(_argminmax_translation_rule, xops.Gt, - _get_max_identity) +_argmin_translation_rule = xla.lower_fun( + partial(_compute_argminmax, lt, _get_min_identity), + multiple_results=False) + +_argmax_translation_rule = xla.lower_fun( + partial(_compute_argminmax, gt, _get_max_identity), + multiple_results=False) argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, 'argmin', _argmin_translation_rule, diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index 481f891ae..aa4f7de8c 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -786,6 +786,7 @@ We use the following XLA TF ops: * `XlaReduceWindow` (wraps XLA ReduceWindow operator). These are used for `lax.reduce_window_sum_p`, `lax.reduce_window_min_p`, `lax.reduce_window_max_p`, and `lax.reduce_window_p`. + * `XlaVariadicReduceV2` (for `lax.reduce`, `lax.argmin`, `lax.argmax`). * `XlaVariadicSort` (wraps XLA Sort operator). ### Different performance characteristics diff --git a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md index 2f9c491ad..80b700bd2 100644 --- a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md +++ b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md @@ -36,6 +36,9 @@ Our priority is to ensure same coverage and numerical behavior with JAX in the "compiled" mode, i.e., **when using XLA to compile the converted program**. We are pretty close to that goal. +The converter has a mode in which it attempts to avoid special XLA TF ops +(`enable_xla=False`). In this mode, some primitives have additional limitations. + This table only shows errors for cases that are working in JAX (see [separate list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) ) @@ -126,27 +129,29 @@ with jax2tf. The following table lists that cases when this does not quite hold: | Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes | | --- | --- | --- | --- | --- | | acosh | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph | +| argmax | Numeric comparison disabled: different results when the input contains NaN and enable_xla=False | inexact | cpu, gpu, tpu | compiled, eager, graph | +| argmin | Numeric comparison disabled: different results when the input contains NaN and enable_xla=False | inexact | cpu, gpu, tpu | compiled, eager, graph | | asin | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph | | asinh | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph | | atan | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph | | atanh | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph | | cholesky | May return different values in the strictly upper triangular part of the result. This does not matter for correctness, because this part of the matrix is not considered in the result. | all | cpu, gpu, tpu | compiled, eager, graph | -| custom_linear_solve | Numeric comparision disabled: TODO: large numerical discrepancy | float32 | tpu | compiled, eager, graph | +| custom_linear_solve | Numeric comparison disabled: TODO: large numerical discrepancy | float32 | tpu | compiled, eager, graph | | digamma | May return different results at singularity points 0 and -1.JAX returns nan and TF returns inf | bfloat16 | cpu, gpu, tpu | eager, graph | | eig | May return the eigenvalues and eigenvectors in a potentially different order. The eigenvectors may also be different, but equally valid. | all | cpu, gpu, tpu | eager, graph | | eigh | May return the eigenvalues and eigenvectors in a potentially different order. The eigenvectors may also be different, but equally valid. | all | cpu, gpu, tpu | compiled, eager, graph | -| eigh | Numeric comparision disabled: TODO: numeric discrepancies | float16 | tpu | compiled, eager, graph | +| eigh | Numeric comparison disabled: TODO: numeric discrepancies | float16 | tpu | compiled, eager, graph | | erf_inv | May return different results at undefined points (< -1 or > 1): JAX returns `NaN` and TF returns `+inf` or `-inf`. | float32, float64 | cpu, gpu, tpu | eager, graph | | igamma | May return different results at undefined points (both arguments 0). JAX returns `NaN` and TF returns 0 or JAX returns 1 and TF returns `NaN` | all | cpu, gpu, tpu | eager, graph | | igammac | May return different results at undefined points (both arguments less or equal 0). JAX returns `NaN` and TF returns 0 or JAX returns 1 and TF returns `NaN` | all | cpu, gpu | eager, graph | -| integer_pow | Numeric comparision disabled: Different overflow behavior for large exponents. | bfloat16, complex, float16, float32, signed | cpu, gpu, tpu | eager, graph | -| integer_pow | Numeric comparision disabled: Different overflow behavior. | bfloat16, float16 | tpu | eager, graph | +| integer_pow | Numeric comparison disabled: Different overflow behavior for large exponents. | bfloat16, complex, float16, float32, signed | cpu, gpu, tpu | eager, graph | +| integer_pow | Numeric comparison disabled: Different overflow behavior. | bfloat16, float16 | tpu | eager, graph | | integer_pow | custom numeric comparison | complex | cpu, gpu, tpu | eager, graph | | lu | May return different, but also correct, results when the decomposition is not unique | all | cpu, gpu | compiled, eager, graph | | max | May return different values when one of the values is NaN. JAX always returns NaN, while TF returns the value NaN is compared with. | all | cpu, gpu, tpu | compiled, eager, graph | | min | May return different values when one of the values is NaN. JAX always returns NaN, while TF returns the value NaN is compared with. | all | cpu, gpu, tpu | compiled, eager, graph | | pow | custom numeric comparison | complex | cpu, gpu, tpu | eager, graph | -| sort | Numeric comparision disabled: TODO: TF non-stable multiple-array sort | all | gpu | compiled, eager, graph | +| sort | Numeric comparison disabled: TODO: TF non-stable multiple-array sort | all | gpu | compiled, eager, graph | | svd | custom numeric comparison when compute_uv | all | cpu, gpu | compiled, eager, graph | | top_k | Produces different results when the array contains `inf` and `NaN` (they are sorted differently in TF vs. XLA). | floating | cpu, gpu, tpu | eager, graph | diff --git a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template index ca4251930..bf5dc41d8 100644 --- a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template +++ b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template @@ -36,6 +36,9 @@ Our priority is to ensure same coverage and numerical behavior with JAX in the "compiled" mode, i.e., **when using XLA to compile the converted program**. We are pretty close to that goal. +The converter has a mode in which it attempts to avoid special XLA TF ops +(`enable_xla=False`). In this mode, some primitives have additional limitations. + This table only shows errors for cases that are working in JAX (see [separate list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) ) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 8844665c2..d7cd201ab 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -967,7 +967,6 @@ for unexpected in xla.call_translations: # Call primitives are inlined # Primitives that are not yet implemented must be explicitly declared here. tf_not_yet_impl = [ - "reduce", "rng_uniform", "clz", "igamma_grad_a", @@ -1843,18 +1842,42 @@ tf_impl[lax.reduce_or_p] = axes_to_axis(tf.reduce_any) tf_impl[lax.reduce_and_p] = axes_to_axis(tf.reduce_all) -def _argminmax(fn, operand, axes, index_dtype): +def _argminmax(is_min: bool, operand: TfVal, axes: Sequence[int], + index_dtype: DType, + _in_avals: Sequence[core.AbstractValue], + _out_aval: core.AbstractValue): + if _thread_local_state.enable_xla: + # Follow the JAX implementation, using a XlaReduce with a custom comparator + if is_min: + extra_name_stack = "argmin" + value_comparator = lax.lt + get_identity = lax._get_min_identity + else: + extra_name_stack = "argmax" + value_comparator = lax.gt + get_identity = lax._get_max_identity + + res = _convert_jax_impl( + partial(lax._compute_argminmax, value_comparator, get_identity), + multiple_results=False, extra_name_stack=extra_name_stack)( + operand, index_dtype=index_dtype, axes=axes, + _in_avals=_in_avals, _out_aval=_out_aval) + return res + + # The following is known to diverge from JAX behavior for NaN. axis, = axes output_type = tf.int32 if dtypes.iinfo(index_dtype).bits > 32: output_type = tf.int64 # TODO(phawkins): handle axes larger than 2^31. + fn = tf.math.argmin if is_min else tf.math.argmax result = fn(operand, axis=axis, output_type=output_type) return tf.cast(result, _to_tf_dtype(index_dtype)) -tf_impl[lax.argmin_p] = partial(_argminmax, tf.math.argmin) -tf_impl[lax.argmax_p] = partial(_argminmax, tf.math.argmax) +tf_impl_with_avals[lax.argmin_p] = partial(_argminmax, True) +tf_impl_with_avals[lax.argmax_p] = partial(_argminmax, False) + _add_fn = tf.function(_add, autograph=False) _ge_fn = tf.function(tf.math.greater_equal, autograph=False) @@ -2156,18 +2179,56 @@ tf_impl_with_avals[lax.reduce_window_max_p] = ( tf_impl_with_avals[lax.reduce_window_p] = _reduce_window # pylint: enable=protected-access +def _reduce(*operands: TfVal, + computation: Callable, + jaxpr: core.Jaxpr, + consts: Sequence[Any], + dimensions: Sequence[int], + _in_avals: Sequence[core.AbstractValue], + _out_aval: core.AbstractValue) -> Sequence[TfVal]: + + if not _thread_local_state.enable_xla: + raise _xla_disabled_error("reduce") + del computation + assert not consts + assert len(operands) % 2 == 0 + # operands: op1, op2, ..., init_val1, init_val2, ... + # reducer takes op1[i], op2[i], ..., init_val1, init_val2, ... + nr_operands = len(operands) // 2 + init_vals = operands[nr_operands:] + operands = operands[0:nr_operands] + + reducer_arg_spec = tuple([tf.TensorSpec((), op.dtype) for op in init_vals] * 2) + + def reducer_computation(*args: TfVal) -> TfVal: + closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) + res = _interpret_jaxpr(closed_jaxpr, *args, extra_name_stack=None) + return res + + xla_reducer_computation = ( + tf.function(reducer_computation, + autograph=False).get_concrete_function(*reducer_arg_spec)) + + out = tfxla.variadic_reduce_v2(operands, init_vals, + dimensions_to_reduce=dimensions, + reducer=xla_reducer_computation) + return out + +tf_impl_with_avals[lax.reduce_p] = _reduce + + # We use lax_control_flow._cumred_tpu_translation_rule to convert cummax, # cummin, cumsum and cumprod. This is efficient on TPU, but the complexity is # O(n^2) on other backends. This may be implemented using associative_scan # instead to favor different backends. tf_impl_with_avals[lax_control_flow.cummin_p] = _convert_jax_impl( partial(lax_control_flow._cumred_tpu_translation_rule, - lax._reduce_window_min), + lax._reduce_window_min), multiple_results=False, extra_name_stack="cummin") tf_impl_with_avals[lax_control_flow.cummax_p] = _convert_jax_impl( partial(lax_control_flow._cumred_tpu_translation_rule, - lax._reduce_window_max), + lax._reduce_window_max), multiple_results=False, extra_name_stack="cummin") # TODO(bchetioui): cumsum and cumprod can be converted using pure TF ops for @@ -2177,12 +2238,12 @@ tf_impl_with_avals[lax_control_flow.cummax_p] = _convert_jax_impl( # tests will crash. tf_impl_with_avals[lax_control_flow.cumsum_p] = _convert_jax_impl( partial(lax_control_flow._cumred_tpu_translation_rule, - lax._reduce_window_sum), + lax._reduce_window_sum), multiple_results=False, extra_name_stack="cumsum") tf_impl_with_avals[lax_control_flow.cumprod_p] = _convert_jax_impl( partial(lax_control_flow._cumred_tpu_translation_rule, - lax._reduce_window_prod), + lax._reduce_window_prod), multiple_results=False, extra_name_stack="cumprod") diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py index b7f2de3ee..964e5b40c 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py +++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py @@ -11,17 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""See primitives_test docstring for how the Jax2TfLimitations are used""" +"""See primitives_test docstring for how the Jax2TfLimitations are used.""" import itertools -import numpy as np from typing import Any, Callable, Optional, Sequence, Union -from jax._src import dtypes from jax import lax from jax import numpy as jnp - +from jax import test_util as jtu +from jax._src import dtypes from jax.experimental.jax2tf.tests import primitive_harness +import numpy as np DType = Any @@ -83,7 +83,7 @@ class Jax2TfLimitation(primitive_harness.Limitation): def get_max_tolerance_limitation( self, limitations: Sequence["Jax2TfLimitation"] ) -> Optional["Jax2TfLimitation"]: - """Pick the tolerance limitation that establishes the maximum tolerance""" + """Pick the tolerance limitation that establishes the maximum tolerance.""" # TODO: it would be best if the limitations with tolerance are mutually exclusive # and we don't have to compute the maximum # TODO: we made this an instance method only so that we don't have to import @@ -124,15 +124,17 @@ class Jax2TfLimitation(primitive_harness.Limitation): # We keep here the explicit set of groups for which we don't have limitations harness_groups_no_limitations = { - "abs", "add", "add_any", "and", "argmin", "argmax", "atan2", + "abs", "add", "add_any", "and", "atan2", "bitcast_convert_type", "broadcast", "broadcast_in_dim", "ceil", "clamp", "concatenate", "cos", "cosh", "complex", "conj", "convert_element_type", "cummax", "cummin", "device_put", "dynamic_slice", - "dynamic_update_slice", "exp", "eq", "floor", "gather", "ge", "gt", "imag", + "dynamic_update_slice", "exp", "eq", "floor", "gather", "ge", "gt", + "imag", "iota", "is_finite", "le", "lt", "log", "mul", "ne", "neg", "not", - "or", "pad", "population_count", "random_split", + "or", "pad", "population_count", "random_split", "reduce", "reduce_and", "reduce_prod", "reduce_or", "reduce_sum", - "reduce_window_add", "reduce_window_mul", "reduce_window_min", "reduce_window_max", + "reduce_window_add", "reduce_window_mul", "reduce_window_min", + "reduce_window_max", "real", "reshape", "rev", "rsqrt", "scatter_max", "scatter_min", "select", "select_and_scatter_add", "shift_left", "shift_right_logical", "shift_right_arithmetic", "sign", @@ -176,6 +178,23 @@ class Jax2TfLimitation(primitive_harness.Limitation): cls.helper_get_trig_custom_limitation(np.cosh) ] + @classmethod + def argmax(cls, harness: primitive_harness.Harness): + return [ + Jax2TfLimitation( + "different results when the input contains NaN and enable_xla=False", + dtypes=jtu.dtypes.all_inexact, + devices=("cpu", "gpu", "tpu"), + modes=("eager", "graph", "compiled"), + expect_tf_error=False, + skip_comparison=True, + enabled=("nan_" in harness.name and not harness.params["enable_xla"])), + ] + + @classmethod + def argmin(cls, harness: primitive_harness.Harness): + return cls.argmax(harness) + @classmethod def asin(cls, harness: primitive_harness.Harness): return [ diff --git a/jax/experimental/jax2tf/tests/primitive_harness.py b/jax/experimental/jax2tf/tests/primitive_harness.py index f308ebc24..9cd0826e2 100644 --- a/jax/experimental/jax2tf/tests/primitive_harness.py +++ b/jax/experimental/jax2tf/tests/primitive_harness.py @@ -477,7 +477,7 @@ for dtype in jtu.dtypes.all_floating: for rounding_method in [ lax.RoundingMethod.AWAY_FROM_ZERO, lax.RoundingMethod.TO_NEAREST_EVEN ]: - operand = np.array([[0.5, 1.5, 2.5], [-0.5, -1.5, -2.5]], dtype=np.float32) + operand = np.array([[0.5, 1.2, 1.5, 1.7, 2.5], [-0.5, -1.2, -1.5, -1.7, -2.5]], dtype=np.float32) _make_round_harness( "rounding_methods", operand=operand, rounding_method=rounding_method) @@ -793,19 +793,22 @@ def _make_argminmax_harness(prim, dtype=jnp.float32, axes=(0,), index_dtype=np.int32, - arr=None): + arr=None, + works_without_xla=True): arr = arr if arr is not None else RandArg(shape, dtype) dtype, shape = arr.dtype, arr.shape index_dtype = dtypes.canonicalize_dtype(index_dtype) - define( - prim, - f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_axes={axes}_indexdtype={index_dtype}", - lambda arg: prim.bind(arg, axes=axes, index_dtype=index_dtype), [arr], - shape=shape, - dtype=dtype, - axes=axes, - index_dtype=index_dtype, - prim=prim) + for enable_xla in [True, False]: + define( + prim, + f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_axes={axes}_indexdtype={index_dtype}_enablexla={enable_xla}", + lambda arg: prim.bind(arg, axes=axes, index_dtype=index_dtype), [arr], + shape=shape, + dtype=dtype, + axes=axes, + index_dtype=index_dtype, + prim=prim, + enable_xla=enable_xla) for prim in [lax.argmin_p, lax.argmax_p]: @@ -820,6 +823,22 @@ for prim in [lax.argmin_p, lax.argmax_p]: for index_dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned: _make_argminmax_harness(prim, "index_dtype", index_dtype=index_dtype) + # Some special cases, with equal elements and NaN + for name, operand in [ + ("nan_0", np.array([np.nan, np.nan, 2., -2., -np.nan, -np.nan], np.float32)), + ("nan_1", np.array([np.nan, -np.nan, 2., -2.], np.float32)), + ("inf_0", np.array([2., np.inf, np.inf, -2.], np.float32)), + ("inf_1", np.array([2., np.inf, -np.inf, -2.], np.float32)), + ("inf_2", np.array([2., -np.inf, np.inf, -2.], np.float32)), + ("inf_3", np.array([2., -np.inf, -np.inf, -2.], np.float32)), + ("nan_inf_0", np.array([2., np.nan, np.inf, -2.], np.float32)), + ("nan_inf_1", np.array([2., np.nan, -np.inf, -2.], np.float32)), + ("equal", np.array([2., 2., 2.], np.int32)), + ("singleton", np.array([1.], np.float32)), + ]: + _make_argminmax_harness(prim, f"special_{name}", shape=operand.shape, + arr=operand) + # TODO(bchetioui): the below documents a limitation of argmin and argmax when a # dimension of the input is too large. However, it is not categorizable as it # seems that the converter fails before reaching the actual primitive call. This @@ -2201,6 +2220,59 @@ for base_dilation, window_dilation in [ _make_select_and_gather_add_harness( "dilations", base_dilation=base_dilation, window_dilation=window_dilation) +def _make_reduce_harness(name, *, + shape=(4, 6), # The shape of all operands + nr_operands=1, # How many operands + computation=lax.add, # Takes Tuple(op1, [op2,]) and Tuple(init_val1, [init_val2]). Returns Tuple(out_val1, [out_val2]). + dimensions: Sequence[int] = (0,), + init_value=0, # The init value for first operand + dtype=np.float32): # The dtype of first operand + def reducer(*args): + init_val = np.array(init_value, dtype=dtype) + init_values = [init_val] + if nr_operands == 2: + init_values.append(np.int32(0.)) + return lax.reduce(args[0:nr_operands], tuple(init_values), + computation, dimensions) + define( + lax.reduce_p, + f"gen_{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_initvalue={init_value}_nr_operands={nr_operands}_dimensions={dimensions}".replace(" ", ""), + reducer, + [ + RandArg(shape, dtype), + # Second operand (optional, always i32). We cannot mix multiple float + # types in XLA. + RandArg(shape, np.int32), + ], + shape=shape, + dtype=dtype, + init_value=init_value, + computation=computation, + dimensions=dimensions) + +for dtype in jtu.dtypes.all: + for name, nr_operands, computation, init_value in [ + ("add_scalar", 1, + lambda ops, inits: (lax.add(ops[0], inits[0]),), 3), + # Compute the max (starting with 3) and the min (from 0), in parallel + ("max_min", 2, + lambda ops, inits: (lax.max(ops[0], inits[0]), + lax.min(ops[1], inits[1])), 3), + ]: + if not (dtype == np.bool_ and name == "add_scalar"): + _make_reduce_harness(name, nr_operands=nr_operands, + computation=computation, init_value=init_value, + dtype=dtype) + # Test the dimensions, but only for int32 (to keep the # of tests small) + if dtype == np.int32: + _make_reduce_harness(name, nr_operands=nr_operands, + computation=computation, init_value=init_value, + dimensions=(1,), + dtype=dtype) + _make_reduce_harness(name, nr_operands=nr_operands, + computation=computation, init_value=init_value, + dimensions=(0, 1), + dtype=dtype) def _make_reduce_window_harness(name, *, diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 711033bbb..2a31db582 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -182,7 +182,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase): modes = ", ".join(sorted(l.modes)) description = l.description if l.skip_comparison: - description = "Numeric comparision disabled: " + description + description = "Numeric comparison disabled: " + description if l.expect_tf_error: description = "TF error: " + description if l.skip_tf_run: diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 5648a8f8e..8cc3b8c95 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -1116,6 +1116,20 @@ _POLY_SHAPE_TEST_HARNESSES = [ for enable_xla in [False, True]: _POLY_SHAPE_TEST_HARNESSES.extend([ + # Reduce the poly dimension + _make_harness("argmax", f"0_enable_xla={enable_xla}", + lambda op: lax.argmax(op, axis=0, index_dtype=np.int32), + [RandArg((3, 4, 5), _f32)], + poly_axes=[0], + enable_xla=enable_xla), + + # Reduce the non-poly dimension + _make_harness("argmax", f"1_enable_xla={enable_xla}", + lambda op: lax.argmax(op, axis=1, index_dtype=np.int32), + [RandArg((3, 4, 5), _f32)], + poly_axes=[0], + enable_xla=enable_xla), + _make_harness("dynamic_slice", f"enable_xla={enable_xla}", # x:shape: (b, 4) lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)), diff --git a/tests/lax_test.py b/tests/lax_test.py index a254e6590..58414f041 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -2576,6 +2576,24 @@ class LazyConstantTest(jtu.JaxTestCase): "index_dtype must be an integer type"): jax_fn(np.ones((2, 2)), axis=0, index_dtype=index_dtype) + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_fn={}".format(jax_fn.__name__), + "jax_fn": jax_fn} + for jax_fn in [lax.argmin, lax.argmax])) + def testArgMinMaxEmptyError(self, jax_fn): + with self.assertRaisesRegex(ValueError, + "require non-empty reduced dimension"): + jax_fn(np.ones((0, 2)), axis=0, index_dtype=np.int32) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_fn={}".format(jax_fn.__name__), + "jax_fn": jax_fn} + for jax_fn in [lax.argmin, lax.argmax])) + def testArgMinMaxInvalidAxisError(self, jax_fn): + with self.assertRaisesRegex(ValueError, + "Invalid axis -1 for operand"): + jax_fn(np.ones((2, 3)), axis=-1, index_dtype=np.int32) + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_fn={}_weaktype={}".format(jax_fn.__name__, weak_type), "jax_fn": jax_fn, "weak_type": weak_type}