[jax2tf] Support inequality and min/max for booleans.
For inequalities we add casts to int8; for min/max on booleans we rewrite to the logical operations and/or.
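In TF terms, the two rewrites amount to the following sketch (plain TensorFlow, for illustration only; the converter wires this through its own machinery):

```python
import tensorflow as tf

x = tf.constant([True, False, True])
y = tf.constant([True, True, False])

# min/max over booleans become logical and/or (since False < True):
bool_min = tf.math.logical_and(x, y)   # elementwise minimum
bool_max = tf.math.logical_or(x, y)    # elementwise maximum

# inequalities are computed after casting the bool operands to int8;
# the comparison result is already boolean, so no cast back is needed:
bool_ge = tf.math.greater_equal(tf.cast(x, tf.int8), tf.cast(y, tf.int8))
```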
This commit is contained in:
parent 5e3be94d8c, commit dd8ab85121
@@ -11,6 +11,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.2.15 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.14...master).
* New features:
  * The {func}`jax2tf.convert` supports inequalities and min/max for booleans
    ({jax-issue}`#6956`).

* Breaking changes:
  * Support for NumPy 1.16 has been dropped, per the
@@ -2907,6 +2907,12 @@ def _bitcast_convert_type_shape_rule(operand, *, new_dtype):
  return operand.shape

def _bitcast_convert_type_dtype_rule(operand, *, new_dtype):
  old_dtype = dtypes.canonicalize_dtype(operand.dtype)
  if dtypes.issubdtype(old_dtype, np.bool_) or dtypes.issubdtype(old_dtype, np.complexfloating):
    if old_dtype != new_dtype:
      raise TypeError(f"`bitcast_convert_type` for operand type ({old_dtype}) cannot have different destination type ({new_dtype})")
  if np.dtype(old_dtype).itemsize != np.dtype(new_dtype).itemsize:
    raise TypeError(f"`bitcast_convert_type` for operand type ({old_dtype}) must have destination type ({new_dtype}) of same size.")
  return new_dtype

def _bitcast_convert_type_translation_rule(c, operand, *, new_dtype):
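The new dtype rule means bitcasting is allowed only between types of the same size, and bool/complex operands may not change type at all. A hypothetical check (not part of the change itself) of how this behaves:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.5, -2.0], dtype=jnp.float32)
lax.bitcast_convert_type(x, jnp.int32)    # ok: 4 bytes -> 4 bytes, not bool/complex

b = jnp.array([True, False])
# lax.bitcast_convert_type(b, jnp.int8)   # TypeError under the new rule:
#                                         # bool operands may not change type
```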
@@ -1,10 +1,10 @@
# Primitives with limited JAX support

*Last generated on: 2021-06-04* (YYYY-MM-DD)
*Last generated on: 2021-06-12* (YYYY-MM-DD)

## Supported data types for primitives

We use a set of 2570 test harnesses to test
We use a set of 2604 test harnesses to test
the implementation of 121 numeric JAX primitives.
We consider a JAX primitive supported for a particular data
type if it is supported on at least one device type.
@@ -77,7 +77,7 @@ be updated.
| digamma | 4 | floating | bool, complex, integer |
| div | 20 | inexact, integer | bool |
| dot_general | 245 | all | |
| dynamic_slice | 32 | all | |
| dynamic_slice | 64 | all | |
| dynamic_update_slice | 21 | all | |
| eig | 72 | inexact | bool, integer |
| eigh | 36 | inexact | bool, integer |
@@ -134,10 +134,10 @@ be updated.
| rev | 19 | all | |
| round | 7 | floating | bool, complex, integer |
| rsqrt | 6 | inexact | bool, integer |
| scatter_add | 14 | inexact, integer | bool |
| scatter_add | 15 | all | |
| scatter_max | 15 | all | |
| scatter_min | 19 | all | |
| scatter_mul | 14 | inexact, integer | bool |
| scatter_mul | 15 | all | |
| select | 16 | all | |
| select_and_gather_add | 15 | floating | bool, complex, integer |
| select_and_scatter_add | 27 | bool, floating, integer | complex |
@@ -197,8 +197,10 @@ and search for "limitation".
|eigh|unimplemented|bfloat16, float16|cpu, gpu|
|lu|unimplemented|bfloat16, float16|cpu, gpu, tpu|
|qr|unimplemented|bfloat16, float16|cpu, gpu|
|scatter_add|unimplemented|bool|cpu, gpu, tpu|
|scatter_max|unimplemented|complex64|tpu|
|scatter_min|unimplemented|complex64|tpu|
|scatter_mul|unimplemented|bool|cpu, gpu, tpu|
|select_and_scatter_add|works only for 2 or more inactive dimensions|all|tpu|
|svd|complex not implemented. Works in JAX for CPU and GPU with custom kernels|complex|tpu|
|svd|unimplemented|bfloat16, float16|cpu, gpu|
@@ -1,6 +1,6 @@
# Primitives with limited support for jax2tf

*Last generated on (YYYY-MM-DD): 2021-06-04*
*Last generated on (YYYY-MM-DD): 2021-06-12*

This document summarizes known limitations of the jax2tf conversion.
There are several kinds of limitations.
@@ -33,7 +33,7 @@ The errors apply only for certain devices and compilation modes ("eager",
On TPU only the "compiled" mode is relevant.

Our priority is to ensure same coverage and numerical behavior with JAX
in the "compiled" mode, **when using XLA to compile the converted program**.
in the "compiled" mode, i.e., **when using XLA to compile the converted program**.
We are pretty close to that goal.

This table only shows errors for cases that are working in JAX (see [separate
@@ -60,20 +60,15 @@ More detailed information can be found in the
| --- | --- | --- | --- | --- |
| bessel_i0e | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| bessel_i1e | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| bitcast_convert_type | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| cholesky | TF test skipped: Not implemented in JAX: unimplemented | float16 | cpu, gpu | compiled, eager, graph |
| cholesky | TF error: function not compilable | complex | cpu, gpu | compiled |
| cholesky | TF error: op not defined for dtype | complex | tpu | compiled, graph |
| clamp | TF test skipped: Not implemented in JAX: unimplemented | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
| clamp | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
| conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type not implemented for integers | int16, int32, int8 | gpu | compiled, eager, graph |
| conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type=c128 not implemented | complex64 | tpu | compiled, eager, graph |
| conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type=f64 not implemented | bfloat16, float16, float32 | tpu | compiled, eager, graph |
| conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type=i64 not implemented | int16, int32, int8 | tpu | compiled, eager, graph |
| conv_general_dilated | TF error: jax2tf BUG: batch_group_count > 1 not yet converted | all | cpu, gpu, tpu | compiled, eager, graph |
| cummax | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
| cummin | TF error: op not defined for dtype | uint64 | cpu, gpu | eager |
| cummin | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
| digamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| div | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
| dot_general | TF error: Numeric comparision disabled: Non-deterministic NaN for dot_general with preferred_element_type on GPU (b/189287598) | bfloat16, complex64, float16, float32 | gpu | compiled |
@@ -92,42 +87,31 @@ More detailed information can be found in the
| erf_inv | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
| erfc | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| fft | TF error: TF function not compileable | complex128, float64 | cpu, gpu | compiled |
| ge | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| gt | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| igamma | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
| igammac | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
| integer_pow | TF error: op not defined for dtype | int16, int8, unsigned | cpu, gpu | graph |
| le | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| lgamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| lt | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| lu | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
| lu | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
| max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| neg | TF error: op not defined for dtype | unsigned | cpu, gpu, tpu | compiled, eager, graph |
| nextafter | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
| qr | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph |
| qr | TF error: op not defined for dtype | bfloat16 | tpu | compiled, eager, graph |
| reduce_max | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
| reduce_min | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
| reduce_window_max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| reduce_window_min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| regularized_incomplete_beta | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
| rem | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
| round | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| rsqrt | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| scatter_add | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_add | TF test skipped: Not implemented in JAX: unimplemented | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_add | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
| scatter_max | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
| scatter_max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_min | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
| scatter_min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_mul | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_mul | TF test skipped: Not implemented in JAX: unimplemented | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_mul | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
| select_and_gather_add | TF error: jax2tf unimplemented for 64-bit inputs because the current implementation relies on packing two values into a single value. This can be fixed by using a variadic XlaReduceWindow, when available | float64 | cpu, gpu | compiled, eager, graph |
| select_and_scatter_add | TF test skipped: Not implemented in JAX: works only for 2 or more inactive dimensions | all | tpu | compiled, eager, graph |
| sign | TF error: sign not defined for unsigned integers | unsigned | cpu, gpu, tpu | compiled, eager, graph |
| sort | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| svd | TF test skipped: Not implemented in JAX: complex not implemented. Works in JAX for CPU and GPU with custom kernels | complex | tpu | compiled, eager, graph |
| svd | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph |
| svd | TF error: function not compilable. Implemented using `tf.linalg.svd` and `tf.linalg.adjoint` | complex | cpu, gpu | compiled |
@@ -33,7 +33,7 @@ The errors apply only for certain devices and compilation modes ("eager",
On TPU only the "compiled" mode is relevant.

Our priority is to ensure same coverage and numerical behavior with JAX
in the "compiled" mode, **when using XLA to compile the converted program**.
in the "compiled" mode, i.e., **when using XLA to compile the converted program**.
We are pretty close to that goal.

This table only shows errors for cases that are working in JAX (see [separate
@@ -1162,6 +1162,8 @@ def _minmax(x: TfVal, y: TfVal, *, is_min: bool,
        partial(lax._minmax_complex_lowering,
                lax_cmp_pick_x=lax.lt if is_min else lax.gt),
        multiple_results=False)(x, y, _in_avals=_in_avals, _out_aval=_out_aval)
  elif x.dtype.as_numpy_dtype == np.bool_:
    return (tf.math.logical_and if is_min else tf.math.logical_or)(x, y)
  else:
    return (tf.math.minimum if is_min else tf.math.maximum)(x, y)
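With the boolean branch above in place, an end-to-end conversion of a boolean max and a boolean comparison should go through. A hypothetical smoke test (assuming jax and tensorflow are installed):

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf

f = lambda x, y: (jnp.maximum(x, y), x > y)   # bool max and bool inequality
f_tf = jax2tf.convert(f)

x = np.array([True, False, True])
y = np.array([False, False, True])
print(f_tf(x, y))   # maximum lowers to logical_or; ">" compares via int8
```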
@@ -1287,19 +1289,37 @@ def _not(x):
tf_impl[lax.not_p] = _not


def bool_to_int8(f, argnums):
  """Computes bool valued functions using int8."""
def bool_to_int8(f, argnums: Sequence[int]):
  """Computes functions with some bool args and bool results using int8.

  This is needed because some TF ops do not work for bool args, e.g.,
  inequalities, min/max.

  Args:
    f: a TF callable to wrap. It will be called with non-boolean arguments.
    argnums: the positional arguments that may be booleans.

  Returns: a TF callable that can take a mix of boolean positional arguments
    (in the positions specified by `argnums`) and some non-boolean positional
    arguments. If there are no boolean arguments, just calls `f`. Otherwise,
    casts the boolean arguments to `int8`, calls `f`, then casts the result to
    `bool`.
  """
  argnums = tf.nest.flatten(argnums)

  def wrapper(*args, **kwargs):
    if not any(args[i].dtype == tf.bool for i in argnums):
  def wrapper(*args: TfVal, **kwargs):
    argnum_types = {args[i].dtype for i in argnums}
    if tf.bool not in argnum_types:
      return f(*args, **kwargs)
    else:
      # All argnums should be boolean
      assert len(argnum_types) == 1, argnum_types
      args_cast = [(tf.cast(a, tf.int8) if i in argnums else a)
                   for i, a in enumerate(args)]
      if "_in_avals" in kwargs:

        def cast_aval(aval):
          assert aval.dtype == np.bool_
          return core.ShapedArray(aval.shape, np.int8)

        _in_avals_cast = [
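For readers skimming the diff, a stripped-down version of the wrapper (hypothetical name, and without the `_in_avals`/`_out_aval` rewriting that the real `bool_to_int8` performs) shows the call pattern:

```python
import tensorflow as tf

def _bool_args_via_int8(f, argnums):
  """Calls `f` after casting the bool args listed in `argnums` to int8."""
  def wrapper(*args):
    if not any(args[i].dtype == tf.bool for i in argnums):
      return f(*args)                       # nothing boolean: call through
    cast = [tf.cast(a, tf.int8) if i in argnums else a
            for i, a in enumerate(args)]
    return tf.cast(f(*cast), tf.bool)       # cast the result back to bool
  return wrapper

bool_ge = _bool_args_via_int8(tf.math.greater_equal, argnums=(0, 1))
print(bool_ge(tf.constant([True, False]), tf.constant([False, False])))
# -> [True, True]
```

In the hunks below, the real wrapper is used both by direct call (the `ge`/`gt`/`le`/`lt` registrations) and as a `@partial(bool_to_int8, argnums=[0])` decorator on `_gather`.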
@@ -1321,10 +1341,11 @@ tf_impl[lax.xor_p] = bool_to_int8(tf.bitwise.bitwise_xor, argnums=(0, 1))

tf_impl[lax.eq_p] = tf.math.equal
tf_impl[lax.ne_p] = tf.math.not_equal
tf_impl[lax.ge_p] = tf.math.greater_equal
tf_impl[lax.gt_p] = tf.math.greater
tf_impl[lax.le_p] = tf.math.less_equal
tf_impl[lax.lt_p] = tf.math.less

tf_impl[lax.ge_p] = bool_to_int8(tf.math.greater_equal, argnums=(0, 1))
tf_impl[lax.gt_p] = bool_to_int8(tf.math.greater, argnums=(0, 1))
tf_impl[lax.le_p] = bool_to_int8(tf.math.less_equal, argnums=(0, 1))
tf_impl[lax.lt_p] = bool_to_int8(tf.math.less, argnums=(0, 1))

tf_impl[lax_linalg.cholesky_p] = tf.linalg.cholesky
@@ -1346,6 +1367,8 @@ tf_impl[lax.convert_element_type_p] = _convert_element_type


def _bitcast_convert_type(operand, new_dtype):
  if operand.dtype == new_dtype:
    return operand
  return tf.bitcast(operand, _to_tf_dtype(new_dtype))
@@ -1767,13 +1790,13 @@ tf_impl[lax.transpose_p] = _transpose
axes_to_axis = lambda func: lambda operand, axes: func(operand, axis=axes)

tf_impl[lax.reduce_sum_p] = (
    bool_to_int8(axes_to_axis(tf.reduce_sum), argnums=0))
    bool_to_int8(axes_to_axis(tf.reduce_sum), argnums=[0]))
tf_impl[lax.reduce_prod_p] = (
    bool_to_int8(axes_to_axis(tf.reduce_prod), argnums=0))
    bool_to_int8(axes_to_axis(tf.reduce_prod), argnums=[0]))
tf_impl[lax.reduce_max_p] = (
    bool_to_int8(axes_to_axis(tf.reduce_max), argnums=0))
    bool_to_int8(axes_to_axis(tf.reduce_max), argnums=[0]))
tf_impl[lax.reduce_min_p] = (
    bool_to_int8(axes_to_axis(tf.reduce_min), argnums=0))
    bool_to_int8(axes_to_axis(tf.reduce_min), argnums=[0]))
tf_impl[lax.reduce_or_p] = axes_to_axis(tf.reduce_any)
tf_impl[lax.reduce_and_p] = axes_to_axis(tf.reduce_all)
@@ -2178,7 +2201,7 @@ def _gather_dimensions_proto(indices_shape, dimension_numbers):
  return proto


@partial(bool_to_int8, argnums=0)
@partial(bool_to_int8, argnums=[0])
def _gather(operand, start_indices, *, dimension_numbers, slice_sizes,
            indices_are_sorted, unique_indices,
            _in_avals, _out_aval):
@@ -2446,25 +2469,6 @@ def _sort(*operands: TfVal, dimension: int, is_stable: bool,
      operands[0].shape
  ), f"Invalid {dimension} for ndim {len(operands[0].shape)}"

  # The comparator is a 2N-argument TF function, with arguments [2k] and [2k +1]
  # corresponding to two scalars from operand[k].
  def lexicographic_comparator_old(*tf_args: TfVal) -> TfVal:
    assert len(tf_args) == 2 * len(operands)
    # We build a comparison:
    # arg[0] < arg[1] or (arg[0] == arg[1] and (arg[2] < arg[3] or ...))
    # all the way to arg[2 * num_keys - 2] < arg[2 * num_keys - 1]
    inside_comparison = None
    for key_idx in range(num_keys - 1, -1, -1):
      a = tf_args[2 * key_idx]
      b = tf_args[2 * key_idx + 1]
      a_lt_b = tf.math.less(a, b)
      if inside_comparison is None:
        inside_comparison = a_lt_b
      else:
        inside_comparison = tf.math.logical_or(
            a_lt_b, tf.math.logical_and(tf.math.equal(a, b), inside_comparison))
    return inside_comparison

  comparator_spec: List[tf.TensorSpec] = []
  comparator_jax_in_avals: List[core.AbstractValue] = []
  for op in operands:
@@ -109,13 +109,13 @@ class Jax2TfLimitation(primitive_harness.Limitation):
    group_method = getattr(cls, harness.group_name, None)
    if harness.group_name in cls.harness_groups_no_limitations:
      assert group_method is None, (
          f"Harness group {harness.group_name} is both in "
          f"Harness group '{harness.group_name}' is both in "
          f"'harness_groups_no_limitations' and has a custom "
          f"Jax2TfLimitation.classmethod defined (see module docstring)")
      return []
    else:
      assert group_method is not None, (
          f"Harness group {harness.group_name} must be either part of "
          f"Harness group '{harness.group_name}' must be either part of "
          f"'harness_groups_no_limitations' or must have a custom "
          f"Jax2TfLimitation.classmethod defined (see module docstring)")
      limitations = group_method(harness)
@@ -124,16 +124,19 @@ class Jax2TfLimitation(primitive_harness.Limitation):

  # We keep here the explicit set of groups for which we don't have limitations
  harness_groups_no_limitations = {
      "abs", "and", "argmin", "argmax", "atan2", "broadcast",
      "broadcast_in_dim", "ceil",
      "concatenate", "cos", "cosh", "complex", "conj",
      "device_put", "dynamic_slice",
      "dynamic_update_slice", "exp", "eq", "floor", "log", "gather", "imag",
      "iota", "is_finite", "ne", "not", "or", "pad", "random_split",
      "reduce_and", "reduce_prod", "reduce_or", "reduce_sum", "real", "reshape",
      "rev",
      "select", "shift_left", "shift_right_logical", "shift_right_arithmetic",
      "sin", "sinh", "slice", "sqrt", "squeeze", "stop_gradient",
      "abs", "add", "add_any", "and", "argmin", "argmax", "atan2",
      "bitcast_convert_type", "broadcast", "broadcast_in_dim", "ceil", "clamp",
      "concatenate", "cos", "cosh", "complex", "conj", "convert_element_type",
      "cummax", "cummin", "device_put", "dynamic_slice",
      "dynamic_update_slice", "exp", "eq", "floor", "gather", "ge", "gt", "imag",
      "iota", "is_finite", "le", "lt", "log", "mul", "ne", "not", "or", "pad",
      "population_count", "random_split",
      "reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
      "reduce_window_add", "reduce_window_mul", "reduce_window_min", "reduce_window_max",
      "real", "reshape", "rev", "scatter_max", "scatter_min",
      "select", "select_and_scatter_add",
      "shift_left", "shift_right_logical", "shift_right_arithmetic",
      "sin", "sinh", "slice", "sqrt", "squeeze", "stop_gradient", "sub",
      "tie_in", "transpose", "xor", "zeros_like"
  }
@@ -173,15 +176,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
        cls.helper_get_trig_custom_limitation(np.cosh)
    ]

  @classmethod
  def add(cls, harness: primitive_harness.Harness):
    return []

  @classmethod
  # Also called add_jaxvals
  def add_any(cls, harness: primitive_harness.Harness):
    return []

  @classmethod
  def asin(cls, harness: primitive_harness.Harness):
    return [
@@ -228,10 +222,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
  def bessel_i1e(cls, harness: primitive_harness.Harness):
    return cls.bessel_i0e(harness)

  @classmethod
  def bitcast_convert_type(cls, harness: primitive_harness.Harness):
    return [missing_tf_kernel(dtypes=[np.bool_])]

  @classmethod
  def cholesky(cls, harness: primitive_harness.Harness):
@@ -280,16 +270,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
            modes=("eager", "graph", "compiled"))
    ]

  @classmethod
  def clamp(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(dtypes=[np.complex64, np.complex128]),
    ]

  @classmethod
  def convert_element_type(cls, harness: primitive_harness.Harness):
    return []

  @classmethod
  def conv_general_dilated(cls, harness: primitive_harness.Harness):
    return [
@@ -307,22 +287,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
            tol=5e-3)
    ]

  @classmethod
  def cummax(cls, harness):
    return [
        missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
    ]

  @classmethod
  def cummin(cls, harness):
    return [
        missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
        # TODO: we get jax2tf AssertionError
        missing_tf_kernel(dtypes=[np.uint64],
                          devices=("cpu", "gpu"),
                          modes=("eager",)),
    ]

  @classmethod
  def cumprod(cls, harness):
    return [
@@ -421,9 +385,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
  @classmethod
  def dot_general(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(dtypes=[
            np.bool_,
        ],),
        missing_tf_kernel(dtypes=[np.bool_],),
        # TODO(b/189287598)
        Jax2TfLimitation(
            "Non-deterministic NaN for dot_general with preferred_element_type on GPU (b/189287598)",
@@ -583,14 +545,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
            modes=("eager", "graph", "compiled"))
    ]

  @classmethod
  def ge(cls, harness: primitive_harness.Harness):
    return [missing_tf_kernel(dtypes=[np.bool_])]

  @classmethod
  def gt(cls, harness: primitive_harness.Harness):
    return cls.ge(harness)

  @classmethod
  def erf(cls, harness: primitive_harness.Harness):
    return [
@@ -799,14 +753,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
  def pow(cls, harness: primitive_harness.Harness):
    return cls._pow_test_util(harness)

  @classmethod
  def le(cls, harness: primitive_harness.Harness):
    return cls.ge(harness)

  @classmethod
  def lt(cls, harness: primitive_harness.Harness):
    return cls.ge(harness)

  @classmethod
  def lgamma(cls, harness: primitive_harness.Harness):
    return [
@@ -881,8 +827,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
      tst.assertAllClose(result_jax[~mask], result_tf[~mask], err_msg=err_msg)

    return [
        missing_tf_kernel(
            dtypes=[np.bool_]),
        custom_numeric(
            custom_assert=custom_assert,
            description=(
@@ -901,7 +845,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
      tst.assertAllClose(result_jax[~mask], result_tf[~mask], err_msg=err_msg)

    return [
        missing_tf_kernel(dtypes=[np.bool_]),
        custom_numeric(
            custom_assert=custom_assert,
            description=(
@@ -911,10 +854,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
            modes=("eager", "graph", "compiled"))
    ]

  @classmethod
  def mul(cls, harness: primitive_harness.Harness):
    return []

  @classmethod
  def neg(cls, harness: primitive_harness.Harness):
    return [
@@ -925,10 +864,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
  def nextafter(cls, harness: primitive_harness.Harness):
    return [missing_tf_kernel(dtypes=[np.float16, dtypes.bfloat16])]

  @classmethod
  def population_count(cls, harness: primitive_harness.Harness):
    return []

  @classmethod
  def qr(cls, harness: primitive_harness.Harness):
    # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
@@ -960,39 +895,14 @@ class Jax2TfLimitation(primitive_harness.Limitation):

  @classmethod
  def reduce_max(cls, harness: primitive_harness.Harness):
    return [
        missing_tf_kernel(dtypes=[np.complex64, np.complex128]),
    ]
    # Unlike reduce_window_max, we use a native TF op: tf.reduce_max, which
    # does not work for complex
    return [missing_tf_kernel(dtypes=[np.complex64, np.complex128])]

  @classmethod
  def reduce_min(cls, harness: primitive_harness.Harness):
    return cls.reduce_max(harness)


  @classmethod
  def reduce_window_add(cls, harness):
    assert "add" == harness.params["computation"].__name__
    return []

  @classmethod
  def reduce_window_max(cls, harness):
    assert "max" == harness.params["computation"].__name__
    return [
        missing_tf_kernel(dtypes=[np.bool_]),
    ]

  @classmethod
  def reduce_window_min(cls, harness):
    assert "min" == harness.params["computation"].__name__
    return [
        missing_tf_kernel(dtypes=[np.bool_]),
    ]

  @classmethod
  def reduce_window_mul(cls, harness):
    assert "mul" == harness.params["computation"].__name__
    return []

  @classmethod
  def regularized_incomplete_beta(cls, harness: primitive_harness.Harness):
    return [
@@ -1034,29 +944,15 @@ class Jax2TfLimitation(primitive_harness.Limitation):
  @classmethod
  def scatter_add(cls, harness):
    return [
        missing_tf_kernel(dtypes=[np.bool_]),
        missing_tf_kernel(
            dtypes=[np.complex64],
            devices="tpu",
        )
    ]

  @classmethod
  def scatter_max(cls, harness):
    return [
        missing_tf_kernel(dtypes=[np.bool_]),
    ]

  @classmethod
  def scatter_min(cls, harness):
    return [
        missing_tf_kernel(dtypes=[np.bool_]),
    ]

  @classmethod
  def scatter_mul(cls, harness):
    return [
        missing_tf_kernel(dtypes=[np.bool_],),
        missing_tf_kernel(
            dtypes=[np.complex64],
            devices="tpu",
@@ -1078,10 +974,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
            devices=("cpu", "gpu"))
    ]

  @classmethod
  def select_and_scatter_add(cls, harness):
    return []

  @classmethod
  def sign(cls, harness: primitive_harness.Harness):
    return [
@@ -1101,13 +993,8 @@ class Jax2TfLimitation(primitive_harness.Limitation):
                not harness.params["is_stable"]),
            expect_tf_error=False,
            skip_comparison=True),
        missing_tf_kernel(dtypes=[np.bool_],),
    ]

  @classmethod
  def sub(cls, harness):
    return []

  @classmethod
  def svd(cls, harness: primitive_harness.Harness):
    # TODO: slow test
@@ -1202,7 +1202,11 @@ def _make_scatter_harness(name,
              "unimplemented",
              devices="tpu",
              dtypes=np.complex64,
              enabled=(f_lax in [lax.scatter_max, lax.scatter_min]))
              enabled=(f_lax in [lax.scatter_max, lax.scatter_min])),
          Limitation(
              "unimplemented",
              dtypes=np.bool_,
              enabled=(f_lax in [lax.scatter_add, lax.scatter_mul])),
      ],
      f_lax=f_lax,
      shape=shape,
@@ -1219,8 +1223,8 @@ for dtype in jtu.dtypes.all:
  for f_lax in [
      lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min
  ]:
    if f_lax in [lax.scatter_add, lax.scatter_mul] and dtype == np.bool_:
      continue
    #if f_lax in [lax.scatter_add, lax.scatter_mul] and dtype == np.bool_:
    #  continue
    _make_scatter_harness("dtypes", dtype=dtype, f_lax=f_lax)

# Validate f_lax/update_jaxpr