mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[jax2tf] Fix conversion for argmin/argmax; add conversion for reduce
The previous conversion for argmin/argmax simply used tf.argmin and tf.argmax. Those ops behave differently than JAX when the inputs contain NaN and Inf. Added a few test cases in primitive_harness to expose the failures. In order to implement an accurate conversion of argmin/argmax, we need to use the XLA Reduce op. Also tightened the shape checks for lax.argmin and lax.argmax, to ensure they are not used with an empty reduced dimension. E.g., if the axis=-1, previously we got an internal error: ``` RuntimeError: Invalid argument: Reducing out-of-bounds dimension -1 in shape f32[2,0,3].: This is a bug in JAX's shape-checking rules; please report it! ``` PiperOrigin-RevId: 384182794
This commit is contained in:
parent
a27047889d
commit
0beef34d25
@ -11,6 +11,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
## jax 0.2.18 (unreleased)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.17...main).
|
||||
|
||||
* Bug fixes:
|
||||
* Tightened the checks for lax.argmin and lax.argmax to ensure they are
|
||||
not used with invalid `axis` value, or with an empty reduction dimension.
|
||||
({jax-issue}`#7196`)
|
||||
|
||||
## jaxlib 0.1.69 (unreleased)
|
||||
|
||||
## jax 0.2.17 (July 9 2021)
|
||||
|
@ -5488,6 +5488,11 @@ _masking_defreducer(reduce_min_p,
|
||||
|
||||
def _argminmax_shape_rule(operand, *, axes, index_dtype):
|
||||
axis, = axes
|
||||
if not (0 <= axis < len(operand.shape)):
|
||||
raise ValueError(f"Invalid axis {axis} for operand shape {operand.shape}")
|
||||
if not core.greater_equal_dim(operand.shape[axis], 1):
|
||||
raise ValueError("argmin and argmax require non-empty reduced dimension. "
|
||||
f"operand.shape={operand.shape} axis={axis}")
|
||||
return tuple(np.delete(operand.shape, axis))
|
||||
|
||||
def _argminmax_dtype_rule(operand, *, axes, index_dtype):
|
||||
@ -5496,34 +5501,29 @@ def _argminmax_dtype_rule(operand, *, axes, index_dtype):
|
||||
.format(np.dtype(index_dtype).name))
|
||||
return index_dtype
|
||||
|
||||
def _argminmax_translation_rule(value_comparator, identity,
|
||||
c, operand, *, axes, index_dtype):
|
||||
def _compute_argminmax(value_comparator, get_identity,
|
||||
operand, *, index_dtype, axes):
|
||||
# value_comparator is either lax.lt (for argmin) or lax.gt
|
||||
# get_identity(operand.dtype) is inf for argmin or -inf for argmax
|
||||
axis, = axes
|
||||
shape = c.get_shape(operand)
|
||||
dtype = shape.numpy_dtype()
|
||||
|
||||
subc = xb.make_computation_builder("argminmax_comparator")
|
||||
value_shape = xc.Shape.array_shape(shape.xla_element_type(), ())
|
||||
index_shape = xc.Shape.array_shape(index_dtype, ())
|
||||
x_value = xb.parameter(subc, 0, value_shape)
|
||||
x_index = xb.parameter(subc, 1, index_shape)
|
||||
y_value = xb.parameter(subc, 2, value_shape)
|
||||
y_index = xb.parameter(subc, 3, index_shape)
|
||||
which_value = xops.Or(value_comparator(x_value, y_value),
|
||||
xops.Ne(x_value, x_value))
|
||||
which_index = xops.Or(which_value, xops.And(xops.Eq(x_value, y_value),
|
||||
xops.Lt(x_index, y_index)))
|
||||
xops.Tuple(subc, [xops.Select(which_value, x_value, y_value),
|
||||
xops.Select(which_index, x_index, y_index)])
|
||||
comparator = subc.build()
|
||||
|
||||
iota_shape = xc.Shape.array_shape(index_dtype, shape.dimensions())
|
||||
iota = xc.ops.Iota(c, iota_shape, axis)
|
||||
out = xops.Reduce(
|
||||
c, [operand, iota],
|
||||
[xb.constant(c, identity(dtype)),
|
||||
xb.constant(c, np.array(0, index_dtype))], comparator, [axis])
|
||||
return xops.GetTupleElement(out, 1)
|
||||
indices = broadcasted_iota(index_dtype, np.shape(operand), axis)
|
||||
def reducer_fn(op_val_index, acc_val_index):
|
||||
op_val, op_index = op_val_index
|
||||
acc_val, acc_index = acc_val_index
|
||||
# Pick op_val if Lt (for argmin) or if NaN
|
||||
pick_op_val = bitwise_or(value_comparator(op_val, acc_val),
|
||||
ne(op_val, op_val))
|
||||
# If x and y are not NaN and x = y, then pick the first
|
||||
pick_op_index = bitwise_or(pick_op_val,
|
||||
bitwise_and(eq(op_val, acc_val),
|
||||
lt(op_index, acc_index)))
|
||||
return (select(pick_op_val, op_val, acc_val),
|
||||
select(pick_op_index, op_index, acc_index))
|
||||
res = reduce([operand, indices],
|
||||
[get_identity(operand.dtype), np.array(0, index_dtype)],
|
||||
reducer_fn,
|
||||
axes)
|
||||
return res[1]
|
||||
|
||||
def _argminmax_gpu_translation_rule(op, a, *, axes, index_dtype):
|
||||
axis, = axes
|
||||
@ -5534,10 +5534,13 @@ def _argminmax_gpu_translation_rule(op, a, *, axes, index_dtype):
|
||||
mask_idxs = select(eq(a, maxvals) | ne(a, a), idxs, maxval)
|
||||
return _reduce_min(mask_idxs, (axis,))
|
||||
|
||||
_argmin_translation_rule = partial(_argminmax_translation_rule, xops.Lt,
|
||||
_get_min_identity)
|
||||
_argmax_translation_rule = partial(_argminmax_translation_rule, xops.Gt,
|
||||
_get_max_identity)
|
||||
_argmin_translation_rule = xla.lower_fun(
|
||||
partial(_compute_argminmax, lt, _get_min_identity),
|
||||
multiple_results=False)
|
||||
|
||||
_argmax_translation_rule = xla.lower_fun(
|
||||
partial(_compute_argminmax, gt, _get_max_identity),
|
||||
multiple_results=False)
|
||||
|
||||
argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
|
||||
'argmin', _argmin_translation_rule,
|
||||
|
@ -786,6 +786,7 @@ We use the following XLA TF ops:
|
||||
* `XlaReduceWindow` (wraps XLA ReduceWindow operator). These are used
|
||||
for `lax.reduce_window_sum_p`, `lax.reduce_window_min_p`,
|
||||
`lax.reduce_window_max_p`, and `lax.reduce_window_p`.
|
||||
* `XlaVariadicReduceV2` (for `lax.reduce`, `lax.argmin`, `lax.argmax`).
|
||||
* `XlaVariadicSort` (wraps XLA Sort operator).
|
||||
|
||||
### Different performance characteristics
|
||||
|
@ -36,6 +36,9 @@ Our priority is to ensure same coverage and numerical behavior with JAX
|
||||
in the "compiled" mode, i.e., **when using XLA to compile the converted program**.
|
||||
We are pretty close to that goal.
|
||||
|
||||
The converter has a mode in which it attempts to avoid special XLA TF ops
|
||||
(`enable_xla=False`). In this mode, some primitives have additional limitations.
|
||||
|
||||
This table only shows errors for cases that are working in JAX (see [separate
|
||||
list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) )
|
||||
|
||||
@ -126,27 +129,29 @@ with jax2tf. The following table lists that cases when this does not quite hold:
|
||||
| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| acosh | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph |
|
||||
| argmax | Numeric comparison disabled: different results when the input contains NaN and enable_xla=False | inexact | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| argmin | Numeric comparison disabled: different results when the input contains NaN and enable_xla=False | inexact | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| asin | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph |
|
||||
| asinh | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph |
|
||||
| atan | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph |
|
||||
| atanh | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph |
|
||||
| cholesky | May return different values in the strictly upper triangular part of the result. This does not matter for correctness, because this part of the matrix is not considered in the result. | all | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| custom_linear_solve | Numeric comparision disabled: TODO: large numerical discrepancy | float32 | tpu | compiled, eager, graph |
|
||||
| custom_linear_solve | Numeric comparison disabled: TODO: large numerical discrepancy | float32 | tpu | compiled, eager, graph |
|
||||
| digamma | May return different results at singularity points 0 and -1.JAX returns nan and TF returns inf | bfloat16 | cpu, gpu, tpu | eager, graph |
|
||||
| eig | May return the eigenvalues and eigenvectors in a potentially different order. The eigenvectors may also be different, but equally valid. | all | cpu, gpu, tpu | eager, graph |
|
||||
| eigh | May return the eigenvalues and eigenvectors in a potentially different order. The eigenvectors may also be different, but equally valid. | all | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| eigh | Numeric comparision disabled: TODO: numeric discrepancies | float16 | tpu | compiled, eager, graph |
|
||||
| eigh | Numeric comparison disabled: TODO: numeric discrepancies | float16 | tpu | compiled, eager, graph |
|
||||
| erf_inv | May return different results at undefined points (< -1 or > 1): JAX returns `NaN` and TF returns `+inf` or `-inf`. | float32, float64 | cpu, gpu, tpu | eager, graph |
|
||||
| igamma | May return different results at undefined points (both arguments 0). JAX returns `NaN` and TF returns 0 or JAX returns 1 and TF returns `NaN` | all | cpu, gpu, tpu | eager, graph |
|
||||
| igammac | May return different results at undefined points (both arguments less or equal 0). JAX returns `NaN` and TF returns 0 or JAX returns 1 and TF returns `NaN` | all | cpu, gpu | eager, graph |
|
||||
| integer_pow | Numeric comparision disabled: Different overflow behavior for large exponents. | bfloat16, complex, float16, float32, signed | cpu, gpu, tpu | eager, graph |
|
||||
| integer_pow | Numeric comparision disabled: Different overflow behavior. | bfloat16, float16 | tpu | eager, graph |
|
||||
| integer_pow | Numeric comparison disabled: Different overflow behavior for large exponents. | bfloat16, complex, float16, float32, signed | cpu, gpu, tpu | eager, graph |
|
||||
| integer_pow | Numeric comparison disabled: Different overflow behavior. | bfloat16, float16 | tpu | eager, graph |
|
||||
| integer_pow | custom numeric comparison | complex | cpu, gpu, tpu | eager, graph |
|
||||
| lu | May return different, but also correct, results when the decomposition is not unique | all | cpu, gpu | compiled, eager, graph |
|
||||
| max | May return different values when one of the values is NaN. JAX always returns NaN, while TF returns the value NaN is compared with. | all | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| min | May return different values when one of the values is NaN. JAX always returns NaN, while TF returns the value NaN is compared with. | all | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| pow | custom numeric comparison | complex | cpu, gpu, tpu | eager, graph |
|
||||
| sort | Numeric comparision disabled: TODO: TF non-stable multiple-array sort | all | gpu | compiled, eager, graph |
|
||||
| sort | Numeric comparison disabled: TODO: TF non-stable multiple-array sort | all | gpu | compiled, eager, graph |
|
||||
| svd | custom numeric comparison when compute_uv | all | cpu, gpu | compiled, eager, graph |
|
||||
| top_k | Produces different results when the array contains `inf` and `NaN` (they are sorted differently in TF vs. XLA). | floating | cpu, gpu, tpu | eager, graph |
|
||||
|
||||
|
@ -36,6 +36,9 @@ Our priority is to ensure same coverage and numerical behavior with JAX
|
||||
in the "compiled" mode, i.e., **when using XLA to compile the converted program**.
|
||||
We are pretty close to that goal.
|
||||
|
||||
The converter has a mode in which it attempts to avoid special XLA TF ops
|
||||
(`enable_xla=False`). In this mode, some primitives have additional limitations.
|
||||
|
||||
This table only shows errors for cases that are working in JAX (see [separate
|
||||
list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) )
|
||||
|
||||
|
@ -967,7 +967,6 @@ for unexpected in xla.call_translations: # Call primitives are inlined
|
||||
|
||||
# Primitives that are not yet implemented must be explicitly declared here.
|
||||
tf_not_yet_impl = [
|
||||
"reduce",
|
||||
"rng_uniform",
|
||||
"clz",
|
||||
"igamma_grad_a",
|
||||
@ -1843,18 +1842,42 @@ tf_impl[lax.reduce_or_p] = axes_to_axis(tf.reduce_any)
|
||||
tf_impl[lax.reduce_and_p] = axes_to_axis(tf.reduce_all)
|
||||
|
||||
|
||||
def _argminmax(fn, operand, axes, index_dtype):
|
||||
def _argminmax(is_min: bool, operand: TfVal, axes: Sequence[int],
|
||||
index_dtype: DType,
|
||||
_in_avals: Sequence[core.AbstractValue],
|
||||
_out_aval: core.AbstractValue):
|
||||
if _thread_local_state.enable_xla:
|
||||
# Follow the JAX implementation, using a XlaReduce with a custom comparator
|
||||
if is_min:
|
||||
extra_name_stack = "argmin"
|
||||
value_comparator = lax.lt
|
||||
get_identity = lax._get_min_identity
|
||||
else:
|
||||
extra_name_stack = "argmax"
|
||||
value_comparator = lax.gt
|
||||
get_identity = lax._get_max_identity
|
||||
|
||||
res = _convert_jax_impl(
|
||||
partial(lax._compute_argminmax, value_comparator, get_identity),
|
||||
multiple_results=False, extra_name_stack=extra_name_stack)(
|
||||
operand, index_dtype=index_dtype, axes=axes,
|
||||
_in_avals=_in_avals, _out_aval=_out_aval)
|
||||
return res
|
||||
|
||||
# The following is known to diverge from JAX behavior for NaN.
|
||||
axis, = axes
|
||||
output_type = tf.int32
|
||||
if dtypes.iinfo(index_dtype).bits > 32:
|
||||
output_type = tf.int64
|
||||
# TODO(phawkins): handle axes larger than 2^31.
|
||||
fn = tf.math.argmin if is_min else tf.math.argmax
|
||||
result = fn(operand, axis=axis, output_type=output_type)
|
||||
return tf.cast(result, _to_tf_dtype(index_dtype))
|
||||
|
||||
|
||||
tf_impl[lax.argmin_p] = partial(_argminmax, tf.math.argmin)
|
||||
tf_impl[lax.argmax_p] = partial(_argminmax, tf.math.argmax)
|
||||
tf_impl_with_avals[lax.argmin_p] = partial(_argminmax, True)
|
||||
tf_impl_with_avals[lax.argmax_p] = partial(_argminmax, False)
|
||||
|
||||
|
||||
_add_fn = tf.function(_add, autograph=False)
|
||||
_ge_fn = tf.function(tf.math.greater_equal, autograph=False)
|
||||
@ -2156,18 +2179,56 @@ tf_impl_with_avals[lax.reduce_window_max_p] = (
|
||||
tf_impl_with_avals[lax.reduce_window_p] = _reduce_window
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def _reduce(*operands: TfVal,
|
||||
computation: Callable,
|
||||
jaxpr: core.Jaxpr,
|
||||
consts: Sequence[Any],
|
||||
dimensions: Sequence[int],
|
||||
_in_avals: Sequence[core.AbstractValue],
|
||||
_out_aval: core.AbstractValue) -> Sequence[TfVal]:
|
||||
|
||||
if not _thread_local_state.enable_xla:
|
||||
raise _xla_disabled_error("reduce")
|
||||
del computation
|
||||
assert not consts
|
||||
assert len(operands) % 2 == 0
|
||||
# operands: op1, op2, ..., init_val1, init_val2, ...
|
||||
# reducer takes op1[i], op2[i], ..., init_val1, init_val2, ...
|
||||
nr_operands = len(operands) // 2
|
||||
init_vals = operands[nr_operands:]
|
||||
operands = operands[0:nr_operands]
|
||||
|
||||
reducer_arg_spec = tuple([tf.TensorSpec((), op.dtype) for op in init_vals] * 2)
|
||||
|
||||
def reducer_computation(*args: TfVal) -> TfVal:
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
res = _interpret_jaxpr(closed_jaxpr, *args, extra_name_stack=None)
|
||||
return res
|
||||
|
||||
xla_reducer_computation = (
|
||||
tf.function(reducer_computation,
|
||||
autograph=False).get_concrete_function(*reducer_arg_spec))
|
||||
|
||||
out = tfxla.variadic_reduce_v2(operands, init_vals,
|
||||
dimensions_to_reduce=dimensions,
|
||||
reducer=xla_reducer_computation)
|
||||
return out
|
||||
|
||||
tf_impl_with_avals[lax.reduce_p] = _reduce
|
||||
|
||||
|
||||
# We use lax_control_flow._cumred_tpu_translation_rule to convert cummax,
|
||||
# cummin, cumsum and cumprod. This is efficient on TPU, but the complexity is
|
||||
# O(n^2) on other backends. This may be implemented using associative_scan
|
||||
# instead to favor different backends.
|
||||
tf_impl_with_avals[lax_control_flow.cummin_p] = _convert_jax_impl(
|
||||
partial(lax_control_flow._cumred_tpu_translation_rule,
|
||||
lax._reduce_window_min),
|
||||
lax._reduce_window_min),
|
||||
multiple_results=False,
|
||||
extra_name_stack="cummin")
|
||||
tf_impl_with_avals[lax_control_flow.cummax_p] = _convert_jax_impl(
|
||||
partial(lax_control_flow._cumred_tpu_translation_rule,
|
||||
lax._reduce_window_max),
|
||||
lax._reduce_window_max),
|
||||
multiple_results=False,
|
||||
extra_name_stack="cummin")
|
||||
# TODO(bchetioui): cumsum and cumprod can be converted using pure TF ops for
|
||||
@ -2177,12 +2238,12 @@ tf_impl_with_avals[lax_control_flow.cummax_p] = _convert_jax_impl(
|
||||
# tests will crash.
|
||||
tf_impl_with_avals[lax_control_flow.cumsum_p] = _convert_jax_impl(
|
||||
partial(lax_control_flow._cumred_tpu_translation_rule,
|
||||
lax._reduce_window_sum),
|
||||
lax._reduce_window_sum),
|
||||
multiple_results=False,
|
||||
extra_name_stack="cumsum")
|
||||
tf_impl_with_avals[lax_control_flow.cumprod_p] = _convert_jax_impl(
|
||||
partial(lax_control_flow._cumred_tpu_translation_rule,
|
||||
lax._reduce_window_prod),
|
||||
lax._reduce_window_prod),
|
||||
multiple_results=False,
|
||||
extra_name_stack="cumprod")
|
||||
|
||||
|
@ -11,17 +11,17 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""See primitives_test docstring for how the Jax2TfLimitations are used"""
|
||||
"""See primitives_test docstring for how the Jax2TfLimitations are used."""
|
||||
|
||||
import itertools
|
||||
import numpy as np
|
||||
from typing import Any, Callable, Optional, Sequence, Union
|
||||
|
||||
from jax._src import dtypes
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
|
||||
from jax import test_util as jtu
|
||||
from jax._src import dtypes
|
||||
from jax.experimental.jax2tf.tests import primitive_harness
|
||||
import numpy as np
|
||||
|
||||
DType = Any
|
||||
|
||||
@ -83,7 +83,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
def get_max_tolerance_limitation(
|
||||
self, limitations: Sequence["Jax2TfLimitation"]
|
||||
) -> Optional["Jax2TfLimitation"]:
|
||||
"""Pick the tolerance limitation that establishes the maximum tolerance"""
|
||||
"""Pick the tolerance limitation that establishes the maximum tolerance."""
|
||||
# TODO: it would be best if the limitations with tolerance are mutually exclusive
|
||||
# and we don't have to compute the maximum
|
||||
# TODO: we made this an instance method only so that we don't have to import
|
||||
@ -124,15 +124,17 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
|
||||
# We keep here the explicit set of groups for which we don't have limitations
|
||||
harness_groups_no_limitations = {
|
||||
"abs", "add", "add_any", "and", "argmin", "argmax", "atan2",
|
||||
"abs", "add", "add_any", "and", "atan2",
|
||||
"bitcast_convert_type", "broadcast", "broadcast_in_dim", "ceil", "clamp",
|
||||
"concatenate", "cos", "cosh", "complex", "conj", "convert_element_type",
|
||||
"cummax", "cummin", "device_put", "dynamic_slice",
|
||||
"dynamic_update_slice", "exp", "eq", "floor", "gather", "ge", "gt", "imag",
|
||||
"dynamic_update_slice", "exp", "eq", "floor", "gather", "ge", "gt",
|
||||
"imag",
|
||||
"iota", "is_finite", "le", "lt", "log", "mul", "ne", "neg", "not",
|
||||
"or", "pad", "population_count", "random_split",
|
||||
"or", "pad", "population_count", "random_split", "reduce",
|
||||
"reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
|
||||
"reduce_window_add", "reduce_window_mul", "reduce_window_min", "reduce_window_max",
|
||||
"reduce_window_add", "reduce_window_mul", "reduce_window_min",
|
||||
"reduce_window_max",
|
||||
"real", "reshape", "rev", "rsqrt", "scatter_max", "scatter_min",
|
||||
"select", "select_and_scatter_add",
|
||||
"shift_left", "shift_right_logical", "shift_right_arithmetic", "sign",
|
||||
@ -176,6 +178,23 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
cls.helper_get_trig_custom_limitation(np.cosh)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def argmax(cls, harness: primitive_harness.Harness):
|
||||
return [
|
||||
Jax2TfLimitation(
|
||||
"different results when the input contains NaN and enable_xla=False",
|
||||
dtypes=jtu.dtypes.all_inexact,
|
||||
devices=("cpu", "gpu", "tpu"),
|
||||
modes=("eager", "graph", "compiled"),
|
||||
expect_tf_error=False,
|
||||
skip_comparison=True,
|
||||
enabled=("nan_" in harness.name and not harness.params["enable_xla"])),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def argmin(cls, harness: primitive_harness.Harness):
|
||||
return cls.argmax(harness)
|
||||
|
||||
@classmethod
|
||||
def asin(cls, harness: primitive_harness.Harness):
|
||||
return [
|
||||
|
@ -477,7 +477,7 @@ for dtype in jtu.dtypes.all_floating:
|
||||
for rounding_method in [
|
||||
lax.RoundingMethod.AWAY_FROM_ZERO, lax.RoundingMethod.TO_NEAREST_EVEN
|
||||
]:
|
||||
operand = np.array([[0.5, 1.5, 2.5], [-0.5, -1.5, -2.5]], dtype=np.float32)
|
||||
operand = np.array([[0.5, 1.2, 1.5, 1.7, 2.5], [-0.5, -1.2, -1.5, -1.7, -2.5]], dtype=np.float32)
|
||||
_make_round_harness(
|
||||
"rounding_methods", operand=operand, rounding_method=rounding_method)
|
||||
|
||||
@ -793,19 +793,22 @@ def _make_argminmax_harness(prim,
|
||||
dtype=jnp.float32,
|
||||
axes=(0,),
|
||||
index_dtype=np.int32,
|
||||
arr=None):
|
||||
arr=None,
|
||||
works_without_xla=True):
|
||||
arr = arr if arr is not None else RandArg(shape, dtype)
|
||||
dtype, shape = arr.dtype, arr.shape
|
||||
index_dtype = dtypes.canonicalize_dtype(index_dtype)
|
||||
define(
|
||||
prim,
|
||||
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_axes={axes}_indexdtype={index_dtype}",
|
||||
lambda arg: prim.bind(arg, axes=axes, index_dtype=index_dtype), [arr],
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
axes=axes,
|
||||
index_dtype=index_dtype,
|
||||
prim=prim)
|
||||
for enable_xla in [True, False]:
|
||||
define(
|
||||
prim,
|
||||
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_axes={axes}_indexdtype={index_dtype}_enablexla={enable_xla}",
|
||||
lambda arg: prim.bind(arg, axes=axes, index_dtype=index_dtype), [arr],
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
axes=axes,
|
||||
index_dtype=index_dtype,
|
||||
prim=prim,
|
||||
enable_xla=enable_xla)
|
||||
|
||||
|
||||
for prim in [lax.argmin_p, lax.argmax_p]:
|
||||
@ -820,6 +823,22 @@ for prim in [lax.argmin_p, lax.argmax_p]:
|
||||
for index_dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned:
|
||||
_make_argminmax_harness(prim, "index_dtype", index_dtype=index_dtype)
|
||||
|
||||
# Some special cases, with equal elements and NaN
|
||||
for name, operand in [
|
||||
("nan_0", np.array([np.nan, np.nan, 2., -2., -np.nan, -np.nan], np.float32)),
|
||||
("nan_1", np.array([np.nan, -np.nan, 2., -2.], np.float32)),
|
||||
("inf_0", np.array([2., np.inf, np.inf, -2.], np.float32)),
|
||||
("inf_1", np.array([2., np.inf, -np.inf, -2.], np.float32)),
|
||||
("inf_2", np.array([2., -np.inf, np.inf, -2.], np.float32)),
|
||||
("inf_3", np.array([2., -np.inf, -np.inf, -2.], np.float32)),
|
||||
("nan_inf_0", np.array([2., np.nan, np.inf, -2.], np.float32)),
|
||||
("nan_inf_1", np.array([2., np.nan, -np.inf, -2.], np.float32)),
|
||||
("equal", np.array([2., 2., 2.], np.int32)),
|
||||
("singleton", np.array([1.], np.float32)),
|
||||
]:
|
||||
_make_argminmax_harness(prim, f"special_{name}", shape=operand.shape,
|
||||
arr=operand)
|
||||
|
||||
# TODO(bchetioui): the below documents a limitation of argmin and argmax when a
|
||||
# dimension of the input is too large. However, it is not categorizable as it
|
||||
# seems that the converter fails before reaching the actual primitive call. This
|
||||
@ -2201,6 +2220,59 @@ for base_dilation, window_dilation in [
|
||||
_make_select_and_gather_add_harness(
|
||||
"dilations", base_dilation=base_dilation, window_dilation=window_dilation)
|
||||
|
||||
def _make_reduce_harness(name, *,
|
||||
shape=(4, 6), # The shape of all operands
|
||||
nr_operands=1, # How many operands
|
||||
computation=lax.add, # Takes Tuple(op1, [op2,]) and Tuple(init_val1, [init_val2]). Returns Tuple(out_val1, [out_val2]).
|
||||
dimensions: Sequence[int] = (0,),
|
||||
init_value=0, # The init value for first operand
|
||||
dtype=np.float32): # The dtype of first operand
|
||||
def reducer(*args):
|
||||
init_val = np.array(init_value, dtype=dtype)
|
||||
init_values = [init_val]
|
||||
if nr_operands == 2:
|
||||
init_values.append(np.int32(0.))
|
||||
return lax.reduce(args[0:nr_operands], tuple(init_values),
|
||||
computation, dimensions)
|
||||
define(
|
||||
lax.reduce_p,
|
||||
f"gen_{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_initvalue={init_value}_nr_operands={nr_operands}_dimensions={dimensions}".replace(" ", ""),
|
||||
reducer,
|
||||
[
|
||||
RandArg(shape, dtype),
|
||||
# Second operand (optional, always i32). We cannot mix multiple float
|
||||
# types in XLA.
|
||||
RandArg(shape, np.int32),
|
||||
],
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
init_value=init_value,
|
||||
computation=computation,
|
||||
dimensions=dimensions)
|
||||
|
||||
for dtype in jtu.dtypes.all:
|
||||
for name, nr_operands, computation, init_value in [
|
||||
("add_scalar", 1,
|
||||
lambda ops, inits: (lax.add(ops[0], inits[0]),), 3),
|
||||
# Compute the max (starting with 3) and the min (from 0), in parallel
|
||||
("max_min", 2,
|
||||
lambda ops, inits: (lax.max(ops[0], inits[0]),
|
||||
lax.min(ops[1], inits[1])), 3),
|
||||
]:
|
||||
if not (dtype == np.bool_ and name == "add_scalar"):
|
||||
_make_reduce_harness(name, nr_operands=nr_operands,
|
||||
computation=computation, init_value=init_value,
|
||||
dtype=dtype)
|
||||
# Test the dimensions, but only for int32 (to keep the # of tests small)
|
||||
if dtype == np.int32:
|
||||
_make_reduce_harness(name, nr_operands=nr_operands,
|
||||
computation=computation, init_value=init_value,
|
||||
dimensions=(1,),
|
||||
dtype=dtype)
|
||||
_make_reduce_harness(name, nr_operands=nr_operands,
|
||||
computation=computation, init_value=init_value,
|
||||
dimensions=(0, 1),
|
||||
dtype=dtype)
|
||||
|
||||
def _make_reduce_window_harness(name,
|
||||
*,
|
||||
|
@ -182,7 +182,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
modes = ", ".join(sorted(l.modes))
|
||||
description = l.description
|
||||
if l.skip_comparison:
|
||||
description = "Numeric comparision disabled: " + description
|
||||
description = "Numeric comparison disabled: " + description
|
||||
if l.expect_tf_error:
|
||||
description = "TF error: " + description
|
||||
if l.skip_tf_run:
|
||||
|
@ -1116,6 +1116,20 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
|
||||
for enable_xla in [False, True]:
|
||||
_POLY_SHAPE_TEST_HARNESSES.extend([
|
||||
# Reduce the poly dimension
|
||||
_make_harness("argmax", f"0_enable_xla={enable_xla}",
|
||||
lambda op: lax.argmax(op, axis=0, index_dtype=np.int32),
|
||||
[RandArg((3, 4, 5), _f32)],
|
||||
poly_axes=[0],
|
||||
enable_xla=enable_xla),
|
||||
|
||||
# Reduce the non-poly dimension
|
||||
_make_harness("argmax", f"1_enable_xla={enable_xla}",
|
||||
lambda op: lax.argmax(op, axis=1, index_dtype=np.int32),
|
||||
[RandArg((3, 4, 5), _f32)],
|
||||
poly_axes=[0],
|
||||
enable_xla=enable_xla),
|
||||
|
||||
_make_harness("dynamic_slice", f"enable_xla={enable_xla}",
|
||||
# x:shape: (b, 4)
|
||||
lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)),
|
||||
|
@ -2576,6 +2576,24 @@ class LazyConstantTest(jtu.JaxTestCase):
|
||||
"index_dtype must be an integer type"):
|
||||
jax_fn(np.ones((2, 2)), axis=0, index_dtype=index_dtype)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_fn={}".format(jax_fn.__name__),
|
||||
"jax_fn": jax_fn}
|
||||
for jax_fn in [lax.argmin, lax.argmax]))
|
||||
def testArgMinMaxEmptyError(self, jax_fn):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"require non-empty reduced dimension"):
|
||||
jax_fn(np.ones((0, 2)), axis=0, index_dtype=np.int32)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_fn={}".format(jax_fn.__name__),
|
||||
"jax_fn": jax_fn}
|
||||
for jax_fn in [lax.argmin, lax.argmax]))
|
||||
def testArgMinMaxInvalidAxisError(self, jax_fn):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Invalid axis -1 for operand"):
|
||||
jax_fn(np.ones((2, 3)), axis=-1, index_dtype=np.int32)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_fn={}_weaktype={}".format(jax_fn.__name__, weak_type),
|
||||
"jax_fn": jax_fn, "weak_type": weak_type}
|
||||
|
Loading…
x
Reference in New Issue
Block a user