[jax2tf] Support inequality and min/max for booleans.

For inequalities we cast the boolean operands to int8. For min/max we
rewrite the operations to the logical operations and/or.
George Necula 2021-06-10 12:42:40 +02:00
parent 5e3be94d8c
commit dd8ab85121
8 changed files with 82 additions and 193 deletions
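To make the approach concrete, here is a small standalone TensorFlow sketch of the two strategies (illustrative only, not the converter code itself; the actual changes are in the diffs below):

```python
# TF ops such as tf.math.greater and tf.math.minimum are not defined for
# tf.bool, which is what this commit works around.
import tensorflow as tf

x = tf.constant([True, True, False, False])
y = tf.constant([True, False, True, False])

# Inequalities: compare int8 casts of the boolean operands.
gt = tf.math.greater(tf.cast(x, tf.int8), tf.cast(y, tf.int8))  # [False, True, False, False]

# min/max: since False < True, min is logical AND and max is logical OR.
min_xy = tf.math.logical_and(x, y)  # [True, False, False, False]
max_xy = tf.math.logical_or(x, y)   # [True, True, True, False]
```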


@ -11,6 +11,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.2.15 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.14...master).
* New features:
  * The {func}`jax2tf.convert` now supports inequalities and min/max for
    booleans ({jax-issue}`#6956`).
* Breaking changes:
  * Support for NumPy 1.16 has been dropped, per the
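As a usage illustration of the new jax2tf changelog entry above (a hypothetical sketch; `f` and `f_tf` are illustrative names):

```python
# Converting a JAX function that uses a boolean inequality and a boolean minimum.
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def f(x, y):
  return x < y, jnp.minimum(x, y)

f_tf = jax2tf.convert(f)
lt, mn = f_tf(tf.constant([True, False]), tf.constant([False, True]))
```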


@ -2907,6 +2907,12 @@ def _bitcast_convert_type_shape_rule(operand, *, new_dtype):
  return operand.shape

def _bitcast_convert_type_dtype_rule(operand, *, new_dtype):
  old_dtype = dtypes.canonicalize_dtype(operand.dtype)
  if dtypes.issubdtype(old_dtype, np.bool_) or dtypes.issubdtype(old_dtype, np.complexfloating):
    if old_dtype != new_dtype:
      raise TypeError(f"`bitcast_convert_type` for operand type ({old_dtype}) cannot have different destination type ({new_dtype})")
  if np.dtype(old_dtype).itemsize != np.dtype(new_dtype).itemsize:
    raise TypeError(f"`bitcast_convert_type` for operand type ({old_dtype}) must have destination type ({new_dtype}) of same size.")
  return new_dtype

def _bitcast_convert_type_translation_rule(c, operand, *, new_dtype):
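A brief sketch of what the added dtype rule means in practice (illustrative, not part of the diff): bitcasting a boolean (or complex) operand to a *different* type is now rejected, while a same-type bitcast remains legal.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.array([True, False])
lax.bitcast_convert_type(x, np.bool_)    # same destination type: passes the new rule
try:
  lax.bitcast_convert_type(x, np.int8)   # different destination type for a bool operand
except TypeError as e:
  print(e)                               # "... cannot have different destination type ..."
```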


@ -1,10 +1,10 @@
# Primitives with limited JAX support
*Last generated on: 2021-06-04* (YYYY-MM-DD)
*Last generated on: 2021-06-12* (YYYY-MM-DD)
## Supported data types for primitives
We use a set of 2570 test harnesses to test
We use a set of 2604 test harnesses to test
the implementation of 121 numeric JAX primitives.
We consider a JAX primitive supported for a particular data
type if it is supported on at least one device type.
@ -77,7 +77,7 @@ be updated.
| digamma | 4 | floating | bool, complex, integer |
| div | 20 | inexact, integer | bool |
| dot_general | 245 | all | |
| dynamic_slice | 32 | all | |
| dynamic_slice | 64 | all | |
| dynamic_update_slice | 21 | all | |
| eig | 72 | inexact | bool, integer |
| eigh | 36 | inexact | bool, integer |
@ -134,10 +134,10 @@ be updated.
| rev | 19 | all | |
| round | 7 | floating | bool, complex, integer |
| rsqrt | 6 | inexact | bool, integer |
| scatter_add | 14 | inexact, integer | bool |
| scatter_add | 15 | all | |
| scatter_max | 15 | all | |
| scatter_min | 19 | all | |
| scatter_mul | 14 | inexact, integer | bool |
| scatter_mul | 15 | all | |
| select | 16 | all | |
| select_and_gather_add | 15 | floating | bool, complex, integer |
| select_and_scatter_add | 27 | bool, floating, integer | complex |
@ -197,8 +197,10 @@ and search for "limitation".
|eigh|unimplemented|bfloat16, float16|cpu, gpu|
|lu|unimplemented|bfloat16, float16|cpu, gpu, tpu|
|qr|unimplemented|bfloat16, float16|cpu, gpu|
|scatter_add|unimplemented|bool|cpu, gpu, tpu|
|scatter_max|unimplemented|complex64|tpu|
|scatter_min|unimplemented|complex64|tpu|
|scatter_mul|unimplemented|bool|cpu, gpu, tpu|
|select_and_scatter_add|works only for 2 or more inactive dimensions|all|tpu|
|svd|complex not implemented. Works in JAX for CPU and GPU with custom kernels|complex|tpu|
|svd|unimplemented|bfloat16, float16|cpu, gpu|


@ -1,6 +1,6 @@
# Primitives with limited support for jax2tf
*Last generated on (YYYY-MM-DD): 2021-06-04*
*Last generated on (YYYY-MM-DD): 2021-06-12*
This document summarizes known limitations of the jax2tf conversion.
There are several kinds of limitations.
@ -33,7 +33,7 @@ The errors apply only for certain devices and compilation modes ("eager",
On TPU only the "compiled" mode is relevant.
Our priority is to ensure same coverage and numerical behavior with JAX
in the "compiled" mode, **when using XLA to compile the converted program**.
in the "compiled" mode, i.e., **when using XLA to compile the converted program**.
We are pretty close to that goal.
This table only shows errors for cases that are working in JAX (see [separate
@ -60,20 +60,15 @@ More detailed information can be found in the
| --- | --- | --- | --- | --- |
| bessel_i0e | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| bessel_i1e | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| bitcast_convert_type | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| cholesky | TF test skipped: Not implemented in JAX: unimplemented | float16 | cpu, gpu | compiled, eager, graph |
| cholesky | TF error: function not compilable | complex | cpu, gpu | compiled |
| cholesky | TF error: op not defined for dtype | complex | tpu | compiled, graph |
| clamp | TF test skipped: Not implemented in JAX: unimplemented | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
| clamp | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
| conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type not implemented for integers | int16, int32, int8 | gpu | compiled, eager, graph |
| conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type=c128 not implemented | complex64 | tpu | compiled, eager, graph |
| conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type=f64 not implemented | bfloat16, float16, float32 | tpu | compiled, eager, graph |
| conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type=i64 not implemented | int16, int32, int8 | tpu | compiled, eager, graph |
| conv_general_dilated | TF error: jax2tf BUG: batch_group_count > 1 not yet converted | all | cpu, gpu, tpu | compiled, eager, graph |
| cummax | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
| cummin | TF error: op not defined for dtype | uint64 | cpu, gpu | eager |
| cummin | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
| digamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| div | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
| dot_general | TF error: Numeric comparision disabled: Non-deterministic NaN for dot_general with preferred_element_type on GPU (b/189287598) | bfloat16, complex64, float16, float32 | gpu | compiled |
@ -92,42 +87,31 @@ More detailed information can be found in the
| erf_inv | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
| erfc | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| fft | TF error: TF function not compileable | complex128, float64 | cpu, gpu | compiled |
| ge | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| gt | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| igamma | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
| igammac | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
| integer_pow | TF error: op not defined for dtype | int16, int8, unsigned | cpu, gpu | graph |
| le | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| lgamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| lt | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| lu | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
| lu | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
| max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| neg | TF error: op not defined for dtype | unsigned | cpu, gpu, tpu | compiled, eager, graph |
| nextafter | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
| qr | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph |
| qr | TF error: op not defined for dtype | bfloat16 | tpu | compiled, eager, graph |
| reduce_max | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
| reduce_min | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
| reduce_window_max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| reduce_window_min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| regularized_incomplete_beta | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
| rem | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
| round | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| rsqrt | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| scatter_add | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_add | TF test skipped: Not implemented in JAX: unimplemented | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_add | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
| scatter_max | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
| scatter_max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_min | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
| scatter_min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_mul | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_mul | TF test skipped: Not implemented in JAX: unimplemented | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_mul | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
| select_and_gather_add | TF error: jax2tf unimplemented for 64-bit inputs because the current implementation relies on packing two values into a single value. This can be fixed by using a variadic XlaReduceWindow, when available | float64 | cpu, gpu | compiled, eager, graph |
| select_and_scatter_add | TF test skipped: Not implemented in JAX: works only for 2 or more inactive dimensions | all | tpu | compiled, eager, graph |
| sign | TF error: sign not defined for unsigned integers | unsigned | cpu, gpu, tpu | compiled, eager, graph |
| sort | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| svd | TF test skipped: Not implemented in JAX: complex not implemented. Works in JAX for CPU and GPU with custom kernels | complex | tpu | compiled, eager, graph |
| svd | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph |
| svd | TF error: function not compilable. Implemented using `tf.linalg.svd` and `tf.linalg.adjoint` | complex | cpu, gpu | compiled |


@ -33,7 +33,7 @@ The errors apply only for certain devices and compilation modes ("eager",
On TPU only the "compiled" mode is relevant.
Our priority is to ensure same coverage and numerical behavior with JAX
in the "compiled" mode, **when using XLA to compile the converted program**.
in the "compiled" mode, i.e., **when using XLA to compile the converted program**.
We are pretty close to that goal.
This table only shows errors for cases that are working in JAX (see [separate


@ -1162,6 +1162,8 @@ def _minmax(x: TfVal, y: TfVal, *, is_min: bool,
        partial(lax._minmax_complex_lowering,
                lax_cmp_pick_x=lax.lt if is_min else lax.gt),
        multiple_results=False)(x, y, _in_avals=_in_avals, _out_aval=_out_aval)
  elif x.dtype.as_numpy_dtype == np.bool_:
    return (tf.math.logical_and if is_min else tf.math.logical_or)(x, y)
  else:
    return (tf.math.minimum if is_min else tf.math.maximum)(x, y)
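A quick illustrative check (not part of the diff) that the boolean branch above agrees with min/max under the ordering `False < True`:

```python
import itertools

for a, b in itertools.product([False, True], repeat=2):
  assert (a and b) == min(a, b)   # logical_and behaves like min
  assert (a or b) == max(a, b)    # logical_or behaves like max
```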
@ -1287,19 +1289,37 @@ def _not(x):
tf_impl[lax.not_p] = _not
def bool_to_int8(f, argnums):
  """Computes bool valued functions using int8."""
def bool_to_int8(f, argnums: Sequence[int]):
  """Computes functions with some bool args and bool results using int8.

  This is needed because some TF ops do not work for bool args, e.g.,
  inequalities, min/max.

  Args:
    f: a TF callable to wrap. It will be called with non-boolean arguments.
    argnums: the positional arguments that may be booleans.

  Returns: a TF callable that can take a mix of boolean positional arguments
    (in the positions specified by `argnums`) and some non-boolean positional
    arguments. If there are no boolean arguments, just calls `f`. Otherwise,
    casts the boolean arguments to `int8`, calls `f`, then casts the result to
    `bool`.
  """
  argnums = tf.nest.flatten(argnums)
  def wrapper(*args, **kwargs):
    if not any(args[i].dtype == tf.bool for i in argnums):
  def wrapper(*args: TfVal, **kwargs):
    argnum_types = {args[i].dtype for i in argnums}
    if tf.bool not in argnum_types:
      return f(*args, **kwargs)
    else:
      # All argnums should be boolean
      assert len(argnum_types) == 1, argnum_types
      args_cast = [(tf.cast(a, tf.int8) if i in argnums else a)
                   for i, a in enumerate(args)]
      if "_in_avals" in kwargs:
        def cast_aval(aval):
          assert aval.dtype == np.bool_
          return core.ShapedArray(aval.shape, np.int8)
        _in_avals_cast = [
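For readers who want the pattern in isolation, below is a simplified, hypothetical standalone version of the wrapper (names are illustrative; it omits the `_in_avals` bookkeeping that the real `bool_to_int8` above performs):

```python
import tensorflow as tf

def with_int8_for_bools(f, argnums):
  """Casts boolean args in `argnums` to int8, calls `f`, casts the result back to bool."""
  def wrapper(*args, **kwargs):
    if not any(args[i].dtype == tf.bool for i in argnums):
      return f(*args, **kwargs)
    args_cast = [tf.cast(a, tf.int8) if i in argnums else a
                 for i, a in enumerate(args)]
    return tf.cast(f(*args_cast, **kwargs), tf.bool)
  return wrapper

# tf.math.greater_equal is not defined for tf.bool, but the wrapped version is usable:
ge = with_int8_for_bools(tf.math.greater_equal, argnums=[0, 1])
out = ge(tf.constant([True, False]), tf.constant([False, False]))  # [True, True]
```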
@ -1321,10 +1341,11 @@ tf_impl[lax.xor_p] = bool_to_int8(tf.bitwise.bitwise_xor, argnums=(0, 1))
tf_impl[lax.eq_p] = tf.math.equal
tf_impl[lax.ne_p] = tf.math.not_equal
tf_impl[lax.ge_p] = tf.math.greater_equal
tf_impl[lax.gt_p] = tf.math.greater
tf_impl[lax.le_p] = tf.math.less_equal
tf_impl[lax.lt_p] = tf.math.less
tf_impl[lax.ge_p] = bool_to_int8(tf.math.greater_equal, argnums=(0, 1))
tf_impl[lax.gt_p] = bool_to_int8(tf.math.greater, argnums=(0, 1))
tf_impl[lax.le_p] = bool_to_int8(tf.math.less_equal, argnums=(0, 1))
tf_impl[lax.lt_p] = bool_to_int8(tf.math.less, argnums=(0, 1))
tf_impl[lax_linalg.cholesky_p] = tf.linalg.cholesky
@ -1346,6 +1367,8 @@ tf_impl[lax.convert_element_type_p] = _convert_element_type
def _bitcast_convert_type(operand, new_dtype):
  if operand.dtype == new_dtype:
    return operand
  return tf.bitcast(operand, _to_tf_dtype(new_dtype))
@ -1767,13 +1790,13 @@ tf_impl[lax.transpose_p] = _transpose
axes_to_axis = lambda func: lambda operand, axes: func(operand, axis=axes)
tf_impl[lax.reduce_sum_p] = (
bool_to_int8(axes_to_axis(tf.reduce_sum), argnums=0))
bool_to_int8(axes_to_axis(tf.reduce_sum), argnums=[0]))
tf_impl[lax.reduce_prod_p] = (
bool_to_int8(axes_to_axis(tf.reduce_prod), argnums=0))
bool_to_int8(axes_to_axis(tf.reduce_prod), argnums=[0]))
tf_impl[lax.reduce_max_p] = (
bool_to_int8(axes_to_axis(tf.reduce_max), argnums=0))
bool_to_int8(axes_to_axis(tf.reduce_max), argnums=[0]))
tf_impl[lax.reduce_min_p] = (
bool_to_int8(axes_to_axis(tf.reduce_min), argnums=0))
bool_to_int8(axes_to_axis(tf.reduce_min), argnums=[0]))
tf_impl[lax.reduce_or_p] = axes_to_axis(tf.reduce_any)
tf_impl[lax.reduce_and_p] = axes_to_axis(tf.reduce_all)
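Note the asymmetry in the hunk above: `reduce_or`/`reduce_and` map directly to `tf.reduce_any`/`tf.reduce_all`, which accept `tf.bool`, whereas `reduce_sum`/`prod`/`max`/`min` go through the int8 cast. A short illustrative comparison (assuming TF's usual dtype support):

```python
import tensorflow as tf

x = tf.constant([[True, False], [False, False]])

any_rows = tf.reduce_any(x, axis=1)  # native bool reduction: [True, False]
# tf.reduce_max rejects tf.bool, so the converter reduces over an int8 cast instead:
max_rows = tf.cast(tf.reduce_max(tf.cast(x, tf.int8), axis=1), tf.bool)  # [True, False]
```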
@ -2178,7 +2201,7 @@ def _gather_dimensions_proto(indices_shape, dimension_numbers):
return proto
@partial(bool_to_int8, argnums=0)
@partial(bool_to_int8, argnums=[0])
def _gather(operand, start_indices, *, dimension_numbers, slice_sizes,
indices_are_sorted, unique_indices,
_in_avals, _out_aval):
@ -2446,25 +2469,6 @@ def _sort(*operands: TfVal, dimension: int, is_stable: bool,
      operands[0].shape
  ), f"Invalid {dimension} for ndim {len(operands[0].shape)}"

  # The comparator is a 2N-argument TF function, with arguments [2k] and [2k +1]
  # corresponding to two scalars from operand[k].
  def lexicographic_comparator_old(*tf_args: TfVal) -> TfVal:
    assert len(tf_args) == 2 * len(operands)
    # We build a comparison:
    #   arg[0] < arg[1] or (arg[0] == arg[1] and (arg[2] < arg[3] or ...))
    # all the way to arg[2 * num_keys - 2] < arg[2 * num_keys - 1]
    inside_comparison = None
    for key_idx in range(num_keys - 1, -1, -1):
      a = tf_args[2 * key_idx]
      b = tf_args[2 * key_idx + 1]
      a_lt_b = tf.math.less(a, b)
      if inside_comparison is None:
        inside_comparison = a_lt_b
      else:
        inside_comparison = tf.math.logical_or(
            a_lt_b, tf.math.logical_and(tf.math.equal(a, b), inside_comparison))
    return inside_comparison

  comparator_spec: List[tf.TensorSpec] = []
  comparator_jax_in_avals: List[core.AbstractValue] = []
  for op in operands:


@ -109,13 +109,13 @@ class Jax2TfLimitation(primitive_harness.Limitation):
group_method = getattr(cls, harness.group_name, None)
if harness.group_name in cls.harness_groups_no_limitations:
assert group_method is None, (
f"Harness group {harness.group_name} is both in "
f"Harness group '{harness.group_name}' is both in "
f"'harness_groups_no_limitations' and has a custom "
f"Jax2TfLimitation.classmethod defined (see module docstring)")
return []
else:
assert group_method is not None, (
f"Harness group {harness.group_name} must be either part of "
f"Harness group '{harness.group_name}' must be either part of "
f"'harness_groups_no_limitations' or must have a custom "
f"Jax2TfLimitation.classmethod defined (see module docstring)")
limitations = group_method(harness)
@ -124,16 +124,19 @@ class Jax2TfLimitation(primitive_harness.Limitation):
# We keep here the explicit set of groups for which we don't have limitations
harness_groups_no_limitations = {
"abs", "and", "argmin", "argmax", "atan2", "broadcast",
"broadcast_in_dim", "ceil",
"concatenate", "cos", "cosh", "complex", "conj",
"device_put", "dynamic_slice",
"dynamic_update_slice", "exp", "eq", "floor", "log", "gather", "imag",
"iota", "is_finite", "ne", "not", "or", "pad", "random_split",
"reduce_and", "reduce_prod", "reduce_or", "reduce_sum", "real", "reshape",
"rev",
"select", "shift_left", "shift_right_logical", "shift_right_arithmetic",
"sin", "sinh", "slice", "sqrt", "squeeze", "stop_gradient",
"abs", "add", "add_any", "and", "argmin", "argmax", "atan2",
"bitcast_convert_type", "broadcast", "broadcast_in_dim", "ceil", "clamp",
"concatenate", "cos", "cosh", "complex", "conj", "convert_element_type",
"cummax", "cummin", "device_put", "dynamic_slice",
"dynamic_update_slice", "exp", "eq", "floor", "gather", "ge", "gt", "imag",
"iota", "is_finite", "le", "lt", "log", "mul", "ne", "not", "or", "pad",
"population_count", "random_split",
"reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
"reduce_window_add", "reduce_window_mul", "reduce_window_min", "reduce_window_max",
"real", "reshape", "rev", "scatter_max", "scatter_min",
"select", "select_and_scatter_add",
"shift_left", "shift_right_logical", "shift_right_arithmetic",
"sin", "sinh", "slice", "sqrt", "squeeze", "stop_gradient", "sub",
"tie_in", "transpose", "xor", "zeros_like"
}
@ -173,15 +176,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
cls.helper_get_trig_custom_limitation(np.cosh)
]
@classmethod
def add(cls, harness: primitive_harness.Harness):
return []
@classmethod
# Also called add_jaxvals
def add_any(cls, harness: primitive_harness.Harness):
return []
@classmethod
def asin(cls, harness: primitive_harness.Harness):
return [
@ -228,10 +222,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
def bessel_i1e(cls, harness: primitive_harness.Harness):
return cls.bessel_i0e(harness)
@classmethod
def bitcast_convert_type(cls, harness: primitive_harness.Harness):
return [missing_tf_kernel(dtypes=[np.bool_])]
@classmethod
def cholesky(cls, harness: primitive_harness.Harness):
@ -280,16 +270,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
modes=("eager", "graph", "compiled"))
]
@classmethod
def clamp(cls, harness: primitive_harness.Harness):
return [
missing_tf_kernel(dtypes=[np.complex64, np.complex128]),
]
@classmethod
def convert_element_type(cls, harness: primitive_harness.Harness):
return []
@classmethod
def conv_general_dilated(cls, harness: primitive_harness.Harness):
return [
@ -307,22 +287,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
tol=5e-3)
]
@classmethod
def cummax(cls, harness):
return [
missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
]
@classmethod
def cummin(cls, harness):
return [
missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
# TODO: we get jax2tf AssertionError
missing_tf_kernel(dtypes=[np.uint64],
devices=("cpu", "gpu"),
modes=("eager",)),
]
@classmethod
def cumprod(cls, harness):
return [
@ -421,9 +385,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
@classmethod
def dot_general(cls, harness: primitive_harness.Harness):
return [
missing_tf_kernel(dtypes=[
np.bool_,
],),
missing_tf_kernel(dtypes=[np.bool_],),
# TODO(b/189287598)
Jax2TfLimitation(
"Non-deterministic NaN for dot_general with preferred_element_type on GPU (b/189287598)",
@ -583,14 +545,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
modes=("eager", "graph", "compiled"))
]
@classmethod
def ge(cls, harness: primitive_harness.Harness):
return [missing_tf_kernel(dtypes=[np.bool_])]
@classmethod
def gt(cls, harness: primitive_harness.Harness):
return cls.ge(harness)
@classmethod
def erf(cls, harness: primitive_harness.Harness):
return [
@ -799,14 +753,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
def pow(cls, harness: primitive_harness.Harness):
return cls._pow_test_util(harness)
@classmethod
def le(cls, harness: primitive_harness.Harness):
return cls.ge(harness)
@classmethod
def lt(cls, harness: primitive_harness.Harness):
return cls.ge(harness)
@classmethod
def lgamma(cls, harness: primitive_harness.Harness):
return [
@ -881,8 +827,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
tst.assertAllClose(result_jax[~mask], result_tf[~mask], err_msg=err_msg)
return [
missing_tf_kernel(
dtypes=[np.bool_]),
custom_numeric(
custom_assert=custom_assert,
description=(
@ -901,7 +845,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
tst.assertAllClose(result_jax[~mask], result_tf[~mask], err_msg=err_msg)
return [
missing_tf_kernel(dtypes=[np.bool_]),
custom_numeric(
custom_assert=custom_assert,
description=(
@ -911,10 +854,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
modes=("eager", "graph", "compiled"))
]
@classmethod
def mul(cls, harness: primitive_harness.Harness):
return []
@classmethod
def neg(cls, harness: primitive_harness.Harness):
return [
@ -925,10 +864,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
def nextafter(cls, harness: primitive_harness.Harness):
return [missing_tf_kernel(dtypes=[np.float16, dtypes.bfloat16])]
@classmethod
def population_count(cls, harness: primitive_harness.Harness):
return []
@classmethod
def qr(cls, harness: primitive_harness.Harness):
# See https://github.com/google/jax/pull/3775#issuecomment-659407824;
@ -960,39 +895,14 @@ class Jax2TfLimitation(primitive_harness.Limitation):
@classmethod
def reduce_max(cls, harness: primitive_harness.Harness):
return [
missing_tf_kernel(dtypes=[np.complex64, np.complex128]),
]
# Unlike reduce_window_max, we use a native TF op: tf.reduce_max, which
# does not work for complex
return [missing_tf_kernel(dtypes=[np.complex64, np.complex128])]
@classmethod
def reduce_min(cls, harness: primitive_harness.Harness):
return cls.reduce_max(harness)
@classmethod
def reduce_window_add(cls, harness):
assert "add" == harness.params["computation"].__name__
return []
@classmethod
def reduce_window_max(cls, harness):
assert "max" == harness.params["computation"].__name__
return [
missing_tf_kernel(dtypes=[np.bool_]),
]
@classmethod
def reduce_window_min(cls, harness):
assert "min" == harness.params["computation"].__name__
return [
missing_tf_kernel(dtypes=[np.bool_]),
]
@classmethod
def reduce_window_mul(cls, harness):
assert "mul" == harness.params["computation"].__name__
return []
@classmethod
def regularized_incomplete_beta(cls, harness: primitive_harness.Harness):
return [
@ -1034,29 +944,15 @@ class Jax2TfLimitation(primitive_harness.Limitation):
@classmethod
def scatter_add(cls, harness):
return [
missing_tf_kernel(dtypes=[np.bool_]),
missing_tf_kernel(
dtypes=[np.complex64],
devices="tpu",
)
]
@classmethod
def scatter_max(cls, harness):
return [
missing_tf_kernel(dtypes=[np.bool_]),
]
@classmethod
def scatter_min(cls, harness):
return [
missing_tf_kernel(dtypes=[np.bool_]),
]
@classmethod
def scatter_mul(cls, harness):
return [
missing_tf_kernel(dtypes=[np.bool_],),
missing_tf_kernel(
dtypes=[np.complex64],
devices="tpu",
@ -1078,10 +974,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
devices=("cpu", "gpu"))
]
@classmethod
def select_and_scatter_add(cls, harness):
return []
@classmethod
def sign(cls, harness: primitive_harness.Harness):
return [
@ -1101,13 +993,8 @@ class Jax2TfLimitation(primitive_harness.Limitation):
not harness.params["is_stable"]),
expect_tf_error=False,
skip_comparison=True),
missing_tf_kernel(dtypes=[np.bool_],),
]
@classmethod
def sub(cls, harness):
return []
@classmethod
def svd(cls, harness: primitive_harness.Harness):
# TODO: slow test


@ -1202,7 +1202,11 @@ def _make_scatter_harness(name,
"unimplemented",
devices="tpu",
dtypes=np.complex64,
enabled=(f_lax in [lax.scatter_max, lax.scatter_min]))
enabled=(f_lax in [lax.scatter_max, lax.scatter_min])),
Limitation(
"unimplemented",
dtypes=np.bool_,
enabled=(f_lax in [lax.scatter_add, lax.scatter_mul])),
],
f_lax=f_lax,
shape=shape,
@ -1219,8 +1223,8 @@ for dtype in jtu.dtypes.all:
  for f_lax in [
      lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min
  ]:
    if f_lax in [lax.scatter_add, lax.scatter_mul] and dtype == np.bool_:
      continue
    #if f_lax in [lax.scatter_add, lax.scatter_mul] and dtype == np.bool_:
    #  continue
    _make_scatter_harness("dtypes", dtype=dtype, f_lax=f_lax)
# Validate f_lax/update_jaxpr