[jax2tf] Fix conversion for argmin/argmax; add conversion for reduce

The previous conversion for argmin/argmax simply used tf.argmin and tf.argmax.
Those ops behave differently from JAX when the inputs contain NaN or Inf. Added
a few test cases in primitive_harness to expose the failures.
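
For context, a minimal sketch of the behavior being matched (JAX follows NumPy:
NaN wins every argmin/argmax comparison, and ties go to the lowest index); the
corresponding TF results are intentionally not shown here since they vary with
the device and with enable_xla:
```
import numpy as np
import jax.numpy as jnp

x = np.array([2., np.nan, np.inf, -2.], np.float32)
# NaN poisons the comparison, so the index of the first NaN is returned,
# exactly as in NumPy.
print(jnp.argmin(x), np.argmin(x))   # 1 1
print(jnp.argmax(x), np.argmax(x))   # 1 1
# Ties are broken towards the lowest index.
print(jnp.argmax(np.array([2., 2., 2.], np.float32)))  # 0
```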

In order to implement an accurate conversion of argmin/argmax, we need to use the
XLA Reduce op.
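
Conceptually, the new lowering pairs each element with its index and reduces the
pair with a custom comparator, as `_compute_argminmax` does in the lax diff
below. A self-contained sketch of the same idea (the helper name
`argmax_via_reduce` and the float-only identity are illustrative):
```
import numpy as np
from jax import lax

def argmax_via_reduce(operand, axis=0, index_dtype=np.int32):
  # Pair every element with its index along `axis`.
  indices = lax.broadcasted_iota(index_dtype, np.shape(operand), axis)

  def reducer(op, acc):
    op_val, op_idx = op
    acc_val, acc_idx = acc
    # Prefer the candidate if it is greater, or if it is NaN.
    pick_val = lax.bitwise_or(lax.gt(op_val, acc_val), lax.ne(op_val, op_val))
    # On exact ties, prefer the smaller index.
    pick_idx = lax.bitwise_or(
        pick_val, lax.bitwise_and(lax.eq(op_val, acc_val),
                                  lax.lt(op_idx, acc_idx)))
    return (lax.select(pick_val, op_val, acc_val),
            lax.select(pick_idx, op_idx, acc_idx))

  # Identity for a floating-point argmax: (-inf, index 0).
  _, out_idx = lax.reduce((operand, indices),
                          (np.array(-np.inf, operand.dtype),
                           np.array(0, index_dtype)),
                          reducer, (axis,))
  return out_idx

print(argmax_via_reduce(np.array([2., np.nan, np.inf, -2.], np.float32)))  # 1
```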

Also tightened the shape checks for lax.argmin and lax.argmax to ensure they are
not used with an empty reduced dimension. E.g., if axis=-1, we previously got an
internal error:
```
RuntimeError: Invalid argument: Reducing out-of-bounds dimension -1 in shape f32[2,0,3].:
This is a bug in JAX's shape-checking rules; please report it!
```
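
With the tightened shape rule, both misuses now fail up front with a
`ValueError`. A small illustration, mirroring the new tests at the end of this
commit:
```
import numpy as np
from jax import lax

try:
  lax.argmin(np.ones((0, 2), np.float32), axis=0, index_dtype=np.int32)
except ValueError as e:
  print(e)  # argmin and argmax require non-empty reduced dimension. ...

try:
  lax.argmax(np.ones((2, 3), np.float32), axis=-1, index_dtype=np.int32)
except ValueError as e:
  print(e)  # Invalid axis -1 for operand shape (2, 3)
```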
PiperOrigin-RevId: 384182794
George Necula 2021-07-12 01:11:17 -07:00 committed by jax authors
parent a27047889d
commit 0beef34d25
11 changed files with 266 additions and 65 deletions

View File

@@ -11,6 +11,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.2.18 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.17...main).
* Bug fixes:
* Tightened the checks for lax.argmin and lax.argmax to ensure they are
not used with an invalid `axis` value, or with an empty reduction dimension.
({jax-issue}`#7196`)
## jaxlib 0.1.69 (unreleased)
## jax 0.2.17 (July 9 2021)

View File

@@ -5488,6 +5488,11 @@ _masking_defreducer(reduce_min_p,
def _argminmax_shape_rule(operand, *, axes, index_dtype):
axis, = axes
if not (0 <= axis < len(operand.shape)):
raise ValueError(f"Invalid axis {axis} for operand shape {operand.shape}")
if not core.greater_equal_dim(operand.shape[axis], 1):
raise ValueError("argmin and argmax require non-empty reduced dimension. "
f"operand.shape={operand.shape} axis={axis}")
return tuple(np.delete(operand.shape, axis))
def _argminmax_dtype_rule(operand, *, axes, index_dtype):
@@ -5496,34 +5501,29 @@ def _argminmax_dtype_rule(operand, *, axes, index_dtype):
.format(np.dtype(index_dtype).name))
return index_dtype
def _argminmax_translation_rule(value_comparator, identity,
c, operand, *, axes, index_dtype):
def _compute_argminmax(value_comparator, get_identity,
operand, *, index_dtype, axes):
# value_comparator is either lax.lt (for argmin) or lax.gt
# get_identity(operand.dtype) is inf for argmin or -inf for argmax
axis, = axes
shape = c.get_shape(operand)
dtype = shape.numpy_dtype()
subc = xb.make_computation_builder("argminmax_comparator")
value_shape = xc.Shape.array_shape(shape.xla_element_type(), ())
index_shape = xc.Shape.array_shape(index_dtype, ())
x_value = xb.parameter(subc, 0, value_shape)
x_index = xb.parameter(subc, 1, index_shape)
y_value = xb.parameter(subc, 2, value_shape)
y_index = xb.parameter(subc, 3, index_shape)
which_value = xops.Or(value_comparator(x_value, y_value),
xops.Ne(x_value, x_value))
which_index = xops.Or(which_value, xops.And(xops.Eq(x_value, y_value),
xops.Lt(x_index, y_index)))
xops.Tuple(subc, [xops.Select(which_value, x_value, y_value),
xops.Select(which_index, x_index, y_index)])
comparator = subc.build()
iota_shape = xc.Shape.array_shape(index_dtype, shape.dimensions())
iota = xc.ops.Iota(c, iota_shape, axis)
out = xops.Reduce(
c, [operand, iota],
[xb.constant(c, identity(dtype)),
xb.constant(c, np.array(0, index_dtype))], comparator, [axis])
return xops.GetTupleElement(out, 1)
indices = broadcasted_iota(index_dtype, np.shape(operand), axis)
def reducer_fn(op_val_index, acc_val_index):
op_val, op_index = op_val_index
acc_val, acc_index = acc_val_index
# Pick op_val if Lt (for argmin) or if NaN
pick_op_val = bitwise_or(value_comparator(op_val, acc_val),
ne(op_val, op_val))
# If x and y are not NaN and x = y, then pick the first
pick_op_index = bitwise_or(pick_op_val,
bitwise_and(eq(op_val, acc_val),
lt(op_index, acc_index)))
return (select(pick_op_val, op_val, acc_val),
select(pick_op_index, op_index, acc_index))
res = reduce([operand, indices],
[get_identity(operand.dtype), np.array(0, index_dtype)],
reducer_fn,
axes)
return res[1]
def _argminmax_gpu_translation_rule(op, a, *, axes, index_dtype):
axis, = axes
@@ -5534,10 +5534,13 @@ def _argminmax_gpu_translation_rule(op, a, *, axes, index_dtype):
mask_idxs = select(eq(a, maxvals) | ne(a, a), idxs, maxval)
return _reduce_min(mask_idxs, (axis,))
_argmin_translation_rule = partial(_argminmax_translation_rule, xops.Lt,
_get_min_identity)
_argmax_translation_rule = partial(_argminmax_translation_rule, xops.Gt,
_get_max_identity)
_argmin_translation_rule = xla.lower_fun(
partial(_compute_argminmax, lt, _get_min_identity),
multiple_results=False)
_argmax_translation_rule = xla.lower_fun(
partial(_compute_argminmax, gt, _get_max_identity),
multiple_results=False)
argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
'argmin', _argmin_translation_rule,

View File

@@ -786,6 +786,7 @@ We use the following XLA TF ops:
* `XlaReduceWindow` (wraps XLA ReduceWindow operator). These are used
for `lax.reduce_window_sum_p`, `lax.reduce_window_min_p`,
`lax.reduce_window_max_p`, and `lax.reduce_window_p`.
* `XlaVariadicReduceV2` (for `lax.reduce`, `lax.argmin`, `lax.argmax`).
* `XlaVariadicSort` (wraps XLA Sort operator).
### Different performance characteristics

View File

@@ -36,6 +36,9 @@ Our priority is to ensure same coverage and numerical behavior with JAX
in the "compiled" mode, i.e., **when using XLA to compile the converted program**.
We are pretty close to that goal.
The converter has a mode in which it attempts to avoid special XLA TF ops
(`enable_xla=False`). In this mode, some primitives have additional limitations.
This table only shows errors for cases that are working in JAX (see [separate
list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) )
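
As an aside (illustrative, not part of this doc's table): the mode is selected
with the `enable_xla` parameter of `jax2tf.convert`; with it off, argmin/argmax
fall back to `tf.math.argmin`/`tf.math.argmax`, which is why the NaN rows appear
in the table below.
```
import numpy as np
import tensorflow as tf
from jax import numpy as jnp
from jax.experimental import jax2tf

f = lambda a: jnp.argmax(a, axis=0)
x = np.array([2., np.nan, -2.], np.float32)

print(jnp.argmax(x, axis=0))  # JAX returns the index of the first NaN: 1

# Avoid the special XLA TF ops; the result on NaN inputs is not
# guaranteed to match JAX (hence the argmax/argmin rows below).
f_tf = jax2tf.convert(f, enable_xla=False)
print(f_tf(tf.constant(x)))
```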
@@ -126,27 +129,29 @@ with jax2tf. The following table lists the cases when this does not quite hold:
| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes |
| --- | --- | --- | --- | --- |
| acosh | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph |
| argmax | Numeric comparison disabled: different results when the input contains NaN and enable_xla=False | inexact | cpu, gpu, tpu | compiled, eager, graph |
| argmin | Numeric comparison disabled: different results when the input contains NaN and enable_xla=False | inexact | cpu, gpu, tpu | compiled, eager, graph |
| asin | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph |
| asinh | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph |
| atan | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph |
| atanh | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph |
| cholesky | May return different values in the strictly upper triangular part of the result. This does not matter for correctness, because this part of the matrix is not considered in the result. | all | cpu, gpu, tpu | compiled, eager, graph |
| custom_linear_solve | Numeric comparision disabled: TODO: large numerical discrepancy | float32 | tpu | compiled, eager, graph |
| custom_linear_solve | Numeric comparison disabled: TODO: large numerical discrepancy | float32 | tpu | compiled, eager, graph |
| digamma | May return different results at singularity points 0 and -1. JAX returns nan and TF returns inf | bfloat16 | cpu, gpu, tpu | eager, graph |
| eig | May return the eigenvalues and eigenvectors in a potentially different order. The eigenvectors may also be different, but equally valid. | all | cpu, gpu, tpu | eager, graph |
| eigh | May return the eigenvalues and eigenvectors in a potentially different order. The eigenvectors may also be different, but equally valid. | all | cpu, gpu, tpu | compiled, eager, graph |
| eigh | Numeric comparision disabled: TODO: numeric discrepancies | float16 | tpu | compiled, eager, graph |
| eigh | Numeric comparison disabled: TODO: numeric discrepancies | float16 | tpu | compiled, eager, graph |
| erf_inv | May return different results at undefined points (< -1 or > 1): JAX returns `NaN` and TF returns `+inf` or `-inf`. | float32, float64 | cpu, gpu, tpu | eager, graph |
| igamma | May return different results at undefined points (both arguments 0). JAX returns `NaN` and TF returns 0 or JAX returns 1 and TF returns `NaN` | all | cpu, gpu, tpu | eager, graph |
| igammac | May return different results at undefined points (both arguments less or equal 0). JAX returns `NaN` and TF returns 0 or JAX returns 1 and TF returns `NaN` | all | cpu, gpu | eager, graph |
| integer_pow | Numeric comparision disabled: Different overflow behavior for large exponents. | bfloat16, complex, float16, float32, signed | cpu, gpu, tpu | eager, graph |
| integer_pow | Numeric comparision disabled: Different overflow behavior. | bfloat16, float16 | tpu | eager, graph |
| integer_pow | Numeric comparison disabled: Different overflow behavior for large exponents. | bfloat16, complex, float16, float32, signed | cpu, gpu, tpu | eager, graph |
| integer_pow | Numeric comparison disabled: Different overflow behavior. | bfloat16, float16 | tpu | eager, graph |
| integer_pow | custom numeric comparison | complex | cpu, gpu, tpu | eager, graph |
| lu | May return different, but also correct, results when the decomposition is not unique | all | cpu, gpu | compiled, eager, graph |
| max | May return different values when one of the values is NaN. JAX always returns NaN, while TF returns the value NaN is compared with. | all | cpu, gpu, tpu | compiled, eager, graph |
| min | May return different values when one of the values is NaN. JAX always returns NaN, while TF returns the value NaN is compared with. | all | cpu, gpu, tpu | compiled, eager, graph |
| pow | custom numeric comparison | complex | cpu, gpu, tpu | eager, graph |
| sort | Numeric comparision disabled: TODO: TF non-stable multiple-array sort | all | gpu | compiled, eager, graph |
| sort | Numeric comparison disabled: TODO: TF non-stable multiple-array sort | all | gpu | compiled, eager, graph |
| svd | custom numeric comparison when compute_uv | all | cpu, gpu | compiled, eager, graph |
| top_k | Produces different results when the array contains `inf` and `NaN` (they are sorted differently in TF vs. XLA). | floating | cpu, gpu, tpu | eager, graph |

View File

@@ -36,6 +36,9 @@ Our priority is to ensure same coverage and numerical behavior with JAX
in the "compiled" mode, i.e., **when using XLA to compile the converted program**.
We are pretty close to that goal.
The converter has a mode in which it attempts to avoid special XLA TF ops
(`enable_xla=False`). In this mode, some primitives have additional limitations.
This table only shows errors for cases that are working in JAX (see [separate
list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) )

View File

@@ -967,7 +967,6 @@ for unexpected in xla.call_translations: # Call primitives are inlined
# Primitives that are not yet implemented must be explicitly declared here.
tf_not_yet_impl = [
"reduce",
"rng_uniform",
"clz",
"igamma_grad_a",
@@ -1843,18 +1842,42 @@ tf_impl[lax.reduce_or_p] = axes_to_axis(tf.reduce_any)
tf_impl[lax.reduce_and_p] = axes_to_axis(tf.reduce_all)
def _argminmax(fn, operand, axes, index_dtype):
def _argminmax(is_min: bool, operand: TfVal, axes: Sequence[int],
index_dtype: DType,
_in_avals: Sequence[core.AbstractValue],
_out_aval: core.AbstractValue):
if _thread_local_state.enable_xla:
# Follow the JAX implementation, using a XlaReduce with a custom comparator
if is_min:
extra_name_stack = "argmin"
value_comparator = lax.lt
get_identity = lax._get_min_identity
else:
extra_name_stack = "argmax"
value_comparator = lax.gt
get_identity = lax._get_max_identity
res = _convert_jax_impl(
partial(lax._compute_argminmax, value_comparator, get_identity),
multiple_results=False, extra_name_stack=extra_name_stack)(
operand, index_dtype=index_dtype, axes=axes,
_in_avals=_in_avals, _out_aval=_out_aval)
return res
# The following is known to diverge from JAX behavior for NaN.
axis, = axes
output_type = tf.int32
if dtypes.iinfo(index_dtype).bits > 32:
output_type = tf.int64
# TODO(phawkins): handle axes larger than 2^31.
fn = tf.math.argmin if is_min else tf.math.argmax
result = fn(operand, axis=axis, output_type=output_type)
return tf.cast(result, _to_tf_dtype(index_dtype))
tf_impl[lax.argmin_p] = partial(_argminmax, tf.math.argmin)
tf_impl[lax.argmax_p] = partial(_argminmax, tf.math.argmax)
tf_impl_with_avals[lax.argmin_p] = partial(_argminmax, True)
tf_impl_with_avals[lax.argmax_p] = partial(_argminmax, False)
_add_fn = tf.function(_add, autograph=False)
_ge_fn = tf.function(tf.math.greater_equal, autograph=False)
@@ -2156,18 +2179,56 @@ tf_impl_with_avals[lax.reduce_window_max_p] = (
tf_impl_with_avals[lax.reduce_window_p] = _reduce_window
# pylint: enable=protected-access
def _reduce(*operands: TfVal,
computation: Callable,
jaxpr: core.Jaxpr,
consts: Sequence[Any],
dimensions: Sequence[int],
_in_avals: Sequence[core.AbstractValue],
_out_aval: core.AbstractValue) -> Sequence[TfVal]:
if not _thread_local_state.enable_xla:
raise _xla_disabled_error("reduce")
del computation
assert not consts
assert len(operands) % 2 == 0
# operands: op1, op2, ..., init_val1, init_val2, ...
# reducer takes op1[i], op2[i], ..., init_val1, init_val2, ...
nr_operands = len(operands) // 2
init_vals = operands[nr_operands:]
operands = operands[0:nr_operands]
reducer_arg_spec = tuple([tf.TensorSpec((), op.dtype) for op in init_vals] * 2)
def reducer_computation(*args: TfVal) -> TfVal:
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
res = _interpret_jaxpr(closed_jaxpr, *args, extra_name_stack=None)
return res
xla_reducer_computation = (
tf.function(reducer_computation,
autograph=False).get_concrete_function(*reducer_arg_spec))
out = tfxla.variadic_reduce_v2(operands, init_vals,
dimensions_to_reduce=dimensions,
reducer=xla_reducer_computation)
return out
tf_impl_with_avals[lax.reduce_p] = _reduce
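
For reference, the kind of variadic `lax.reduce` call that `_reduce` above now
converts (via `XlaVariadicReduceV2`), modeled on the `max_min` harness added
further down; the operands, init values, and reducer here are only an example:
```
import numpy as np
from jax import lax

x = np.arange(12, dtype=np.float32).reshape(4, 3)
y = np.arange(12, dtype=np.int32).reshape(4, 3)

def max_min(ops, inits):
  # One reducer producing two outputs: max of the first operand,
  # min of the second, reduced in lockstep.
  return (lax.max(ops[0], inits[0]), lax.min(ops[1], inits[1]))

out_max, out_min = lax.reduce(
    (x, y),
    (np.float32(-np.inf), np.int32(np.iinfo(np.int32).max)),
    max_min, (0,))
print(out_max)  # [ 9. 10. 11.]
print(out_min)  # [0 1 2]
```
Converted with `jax2tf.convert`, the reducer body is re-traced into the TF
`reducer_computation` built by `_reduce`.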
# We use lax_control_flow._cumred_tpu_translation_rule to convert cummax,
# cummin, cumsum and cumprod. This is efficient on TPU, but the complexity is
# O(n^2) on other backends. This may be implemented using associative_scan
# instead to favor different backends.
tf_impl_with_avals[lax_control_flow.cummin_p] = _convert_jax_impl(
partial(lax_control_flow._cumred_tpu_translation_rule,
lax._reduce_window_min),
lax._reduce_window_min),
multiple_results=False,
extra_name_stack="cummin")
tf_impl_with_avals[lax_control_flow.cummax_p] = _convert_jax_impl(
partial(lax_control_flow._cumred_tpu_translation_rule,
lax._reduce_window_max),
lax._reduce_window_max),
multiple_results=False,
extra_name_stack="cummin")
# TODO(bchetioui): cumsum and cumprod can be converted using pure TF ops for
@@ -2177,12 +2238,12 @@ tf_impl_with_avals[lax_control_flow.cummax_p] = _convert_jax_impl(
# tests will crash.
tf_impl_with_avals[lax_control_flow.cumsum_p] = _convert_jax_impl(
partial(lax_control_flow._cumred_tpu_translation_rule,
lax._reduce_window_sum),
lax._reduce_window_sum),
multiple_results=False,
extra_name_stack="cumsum")
tf_impl_with_avals[lax_control_flow.cumprod_p] = _convert_jax_impl(
partial(lax_control_flow._cumred_tpu_translation_rule,
lax._reduce_window_prod),
lax._reduce_window_prod),
multiple_results=False,
extra_name_stack="cumprod")

View File

@@ -11,17 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""See primitives_test docstring for how the Jax2TfLimitations are used"""
"""See primitives_test docstring for how the Jax2TfLimitations are used."""
import itertools
import numpy as np
from typing import Any, Callable, Optional, Sequence, Union
from jax._src import dtypes
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
from jax._src import dtypes
from jax.experimental.jax2tf.tests import primitive_harness
import numpy as np
DType = Any
@@ -83,7 +83,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
def get_max_tolerance_limitation(
self, limitations: Sequence["Jax2TfLimitation"]
) -> Optional["Jax2TfLimitation"]:
"""Pick the tolerance limitation that establishes the maximum tolerance"""
"""Pick the tolerance limitation that establishes the maximum tolerance."""
# TODO: it would be best if the limitations with tolerance are mutually exclusive
# and we don't have to compute the maximum
# TODO: we made this an instance method only so that we don't have to import
@@ -124,15 +124,17 @@ class Jax2TfLimitation(primitive_harness.Limitation):
# We keep here the explicit set of groups for which we don't have limitations
harness_groups_no_limitations = {
"abs", "add", "add_any", "and", "argmin", "argmax", "atan2",
"abs", "add", "add_any", "and", "atan2",
"bitcast_convert_type", "broadcast", "broadcast_in_dim", "ceil", "clamp",
"concatenate", "cos", "cosh", "complex", "conj", "convert_element_type",
"cummax", "cummin", "device_put", "dynamic_slice",
"dynamic_update_slice", "exp", "eq", "floor", "gather", "ge", "gt", "imag",
"dynamic_update_slice", "exp", "eq", "floor", "gather", "ge", "gt",
"imag",
"iota", "is_finite", "le", "lt", "log", "mul", "ne", "neg", "not",
"or", "pad", "population_count", "random_split",
"or", "pad", "population_count", "random_split", "reduce",
"reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
"reduce_window_add", "reduce_window_mul", "reduce_window_min", "reduce_window_max",
"reduce_window_add", "reduce_window_mul", "reduce_window_min",
"reduce_window_max",
"real", "reshape", "rev", "rsqrt", "scatter_max", "scatter_min",
"select", "select_and_scatter_add",
"shift_left", "shift_right_logical", "shift_right_arithmetic", "sign",
@@ -176,6 +178,23 @@ class Jax2TfLimitation(primitive_harness.Limitation):
cls.helper_get_trig_custom_limitation(np.cosh)
]
@classmethod
def argmax(cls, harness: primitive_harness.Harness):
return [
Jax2TfLimitation(
"different results when the input contains NaN and enable_xla=False",
dtypes=jtu.dtypes.all_inexact,
devices=("cpu", "gpu", "tpu"),
modes=("eager", "graph", "compiled"),
expect_tf_error=False,
skip_comparison=True,
enabled=("nan_" in harness.name and not harness.params["enable_xla"])),
]
@classmethod
def argmin(cls, harness: primitive_harness.Harness):
return cls.argmax(harness)
@classmethod
def asin(cls, harness: primitive_harness.Harness):
return [

View File

@@ -477,7 +477,7 @@ for dtype in jtu.dtypes.all_floating:
for rounding_method in [
lax.RoundingMethod.AWAY_FROM_ZERO, lax.RoundingMethod.TO_NEAREST_EVEN
]:
operand = np.array([[0.5, 1.5, 2.5], [-0.5, -1.5, -2.5]], dtype=np.float32)
operand = np.array([[0.5, 1.2, 1.5, 1.7, 2.5], [-0.5, -1.2, -1.5, -1.7, -2.5]], dtype=np.float32)
_make_round_harness(
"rounding_methods", operand=operand, rounding_method=rounding_method)
@@ -793,19 +793,22 @@ def _make_argminmax_harness(prim,
dtype=jnp.float32,
axes=(0,),
index_dtype=np.int32,
arr=None):
arr=None,
works_without_xla=True):
arr = arr if arr is not None else RandArg(shape, dtype)
dtype, shape = arr.dtype, arr.shape
index_dtype = dtypes.canonicalize_dtype(index_dtype)
define(
prim,
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_axes={axes}_indexdtype={index_dtype}",
lambda arg: prim.bind(arg, axes=axes, index_dtype=index_dtype), [arr],
shape=shape,
dtype=dtype,
axes=axes,
index_dtype=index_dtype,
prim=prim)
for enable_xla in [True, False]:
define(
prim,
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_axes={axes}_indexdtype={index_dtype}_enablexla={enable_xla}",
lambda arg: prim.bind(arg, axes=axes, index_dtype=index_dtype), [arr],
shape=shape,
dtype=dtype,
axes=axes,
index_dtype=index_dtype,
prim=prim,
enable_xla=enable_xla)
for prim in [lax.argmin_p, lax.argmax_p]:
@@ -820,6 +823,22 @@ for prim in [lax.argmin_p, lax.argmax_p]:
for index_dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned:
_make_argminmax_harness(prim, "index_dtype", index_dtype=index_dtype)
# Some special cases, with equal elements and NaN
for name, operand in [
("nan_0", np.array([np.nan, np.nan, 2., -2., -np.nan, -np.nan], np.float32)),
("nan_1", np.array([np.nan, -np.nan, 2., -2.], np.float32)),
("inf_0", np.array([2., np.inf, np.inf, -2.], np.float32)),
("inf_1", np.array([2., np.inf, -np.inf, -2.], np.float32)),
("inf_2", np.array([2., -np.inf, np.inf, -2.], np.float32)),
("inf_3", np.array([2., -np.inf, -np.inf, -2.], np.float32)),
("nan_inf_0", np.array([2., np.nan, np.inf, -2.], np.float32)),
("nan_inf_1", np.array([2., np.nan, -np.inf, -2.], np.float32)),
("equal", np.array([2., 2., 2.], np.int32)),
("singleton", np.array([1.], np.float32)),
]:
_make_argminmax_harness(prim, f"special_{name}", shape=operand.shape,
arr=operand)
# TODO(bchetioui): the below documents a limitation of argmin and argmax when a
# dimension of the input is too large. However, it is not categorizable as it
# seems that the converter fails before reaching the actual primitive call. This
@@ -2201,6 +2220,59 @@ for base_dilation, window_dilation in [
_make_select_and_gather_add_harness(
"dilations", base_dilation=base_dilation, window_dilation=window_dilation)
def _make_reduce_harness(name, *,
shape=(4, 6), # The shape of all operands
nr_operands=1, # How many operands
computation=lax.add, # Takes Tuple(op1, [op2,]) and Tuple(init_val1, [init_val2]). Returns Tuple(out_val1, [out_val2]).
dimensions: Sequence[int] = (0,),
init_value=0, # The init value for first operand
dtype=np.float32): # The dtype of first operand
def reducer(*args):
init_val = np.array(init_value, dtype=dtype)
init_values = [init_val]
if nr_operands == 2:
init_values.append(np.int32(0.))
return lax.reduce(args[0:nr_operands], tuple(init_values),
computation, dimensions)
define(
lax.reduce_p,
f"gen_{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_initvalue={init_value}_nr_operands={nr_operands}_dimensions={dimensions}".replace(" ", ""),
reducer,
[
RandArg(shape, dtype),
# Second operand (optional, always i32). We cannot mix multiple float
# types in XLA.
RandArg(shape, np.int32),
],
shape=shape,
dtype=dtype,
init_value=init_value,
computation=computation,
dimensions=dimensions)
for dtype in jtu.dtypes.all:
for name, nr_operands, computation, init_value in [
("add_scalar", 1,
lambda ops, inits: (lax.add(ops[0], inits[0]),), 3),
# Compute the max (starting with 3) and the min (from 0), in parallel
("max_min", 2,
lambda ops, inits: (lax.max(ops[0], inits[0]),
lax.min(ops[1], inits[1])), 3),
]:
if not (dtype == np.bool_ and name == "add_scalar"):
_make_reduce_harness(name, nr_operands=nr_operands,
computation=computation, init_value=init_value,
dtype=dtype)
# Test the dimensions, but only for int32 (to keep the # of tests small)
if dtype == np.int32:
_make_reduce_harness(name, nr_operands=nr_operands,
computation=computation, init_value=init_value,
dimensions=(1,),
dtype=dtype)
_make_reduce_harness(name, nr_operands=nr_operands,
computation=computation, init_value=init_value,
dimensions=(0, 1),
dtype=dtype)
def _make_reduce_window_harness(name,
*,

View File

@@ -182,7 +182,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
modes = ", ".join(sorted(l.modes))
description = l.description
if l.skip_comparison:
description = "Numeric comparision disabled: " + description
description = "Numeric comparison disabled: " + description
if l.expect_tf_error:
description = "TF error: " + description
if l.skip_tf_run:

View File

@@ -1116,6 +1116,20 @@ _POLY_SHAPE_TEST_HARNESSES = [
for enable_xla in [False, True]:
_POLY_SHAPE_TEST_HARNESSES.extend([
# Reduce the poly dimension
_make_harness("argmax", f"0_enable_xla={enable_xla}",
lambda op: lax.argmax(op, axis=0, index_dtype=np.int32),
[RandArg((3, 4, 5), _f32)],
poly_axes=[0],
enable_xla=enable_xla),
# Reduce the non-poly dimension
_make_harness("argmax", f"1_enable_xla={enable_xla}",
lambda op: lax.argmax(op, axis=1, index_dtype=np.int32),
[RandArg((3, 4, 5), _f32)],
poly_axes=[0],
enable_xla=enable_xla),
_make_harness("dynamic_slice", f"enable_xla={enable_xla}",
# x:shape: (b, 4)
lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)),

View File

@@ -2576,6 +2576,24 @@ class LazyConstantTest(jtu.JaxTestCase):
"index_dtype must be an integer type"):
jax_fn(np.ones((2, 2)), axis=0, index_dtype=index_dtype)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_fn={}".format(jax_fn.__name__),
"jax_fn": jax_fn}
for jax_fn in [lax.argmin, lax.argmax]))
def testArgMinMaxEmptyError(self, jax_fn):
with self.assertRaisesRegex(ValueError,
"require non-empty reduced dimension"):
jax_fn(np.ones((0, 2)), axis=0, index_dtype=np.int32)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_fn={}".format(jax_fn.__name__),
"jax_fn": jax_fn}
for jax_fn in [lax.argmin, lax.argmax]))
def testArgMinMaxInvalidAxisError(self, jax_fn):
with self.assertRaisesRegex(ValueError,
"Invalid axis -1 for operand"):
jax_fn(np.ones((2, 3)), axis=-1, index_dtype=np.int32)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_fn={}_weaktype={}".format(jax_fn.__name__, weak_type),
"jax_fn": jax_fn, "weak_type": weak_type}