[jax2tf] Support inequalities and min/max for booleans.

For inequalities we cast the boolean operands to int8. For min/max we
rewrite the boolean operations to logical and/or.
George Necula, 2021-06-10 12:42:40 +02:00
parent 5e3be94d8c
commit dd8ab85121
8 changed files with 82 additions and 193 deletions
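As a quick illustration of the user-visible effect (a minimal sketch, not part of the commit; it assumes `jax`, `jax.experimental.jax2tf`, and TensorFlow are installed):

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf

x = np.array([True, False, True])
y = np.array([True, True, False])

# Inequalities on booleans: jax2tf casts the bool operands to int8 before
# calling the TF comparison op.
print(jax2tf.convert(lambda a, b: a < b)(x, y))   # [False  True False]

# min/max on booleans: rewritten to tf.math.logical_and / tf.math.logical_or.
print(jax2tf.convert(jnp.minimum)(x, y))          # [ True False False]
print(jax2tf.convert(jnp.maximum)(x, y))          # [ True  True  True]
```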

View File

@@ -11,6 +11,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
 ## jax 0.2.15 (unreleased)
 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.14...master).
 * New features:
+  * The {func}`jax2tf.convert` supports inequalities and min/max for booleans
+    ({jax-issue}`#6956`).
 * Breaking changes:
   * Support for NumPy 1.16 has been dropped, per the

View File

@@ -2907,6 +2907,12 @@ def _bitcast_convert_type_shape_rule(operand, *, new_dtype):
   return operand.shape

 def _bitcast_convert_type_dtype_rule(operand, *, new_dtype):
+  old_dtype = dtypes.canonicalize_dtype(operand.dtype)
+  if dtypes.issubdtype(old_dtype, np.bool_) or dtypes.issubdtype(old_dtype, np.complexfloating):
+    if old_dtype != new_dtype:
+      raise TypeError(f"`bitcast_convert_type` for operand type ({old_dtype}) cannot have different destination type ({new_dtype})")
+  if np.dtype(old_dtype).itemsize != np.dtype(new_dtype).itemsize:
+    raise TypeError(f"`bitcast_convert_type` for operand type ({old_dtype}) must have destination type ({new_dtype}) of same size.")
   return new_dtype

 def _bitcast_convert_type_translation_rule(c, operand, *, new_dtype):
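For illustration, the new dtype rule rejects size-changing bitcasts and any bool/complex bitcast to a different type (a hedged sketch, not part of the commit):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((3,), dtype=jnp.float32)
lax.bitcast_convert_type(x, jnp.int32)    # OK: float32 and int32 have the same itemsize
# lax.bitcast_convert_type(x, jnp.int16)  # would raise TypeError: different itemsize
# lax.bitcast_convert_type(jnp.array([True]), jnp.int8)
#                                           # would raise TypeError: bool operands
#                                           # must keep their dtype
```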

View File

@@ -1,10 +1,10 @@
 # Primitives with limited JAX support
-*Last generated on: 2021-06-04* (YYYY-MM-DD)
+*Last generated on: 2021-06-12* (YYYY-MM-DD)
 ## Supported data types for primitives
-We use a set of 2570 test harnesses to test
+We use a set of 2604 test harnesses to test
 the implementation of 121 numeric JAX primitives.
 We consider a JAX primitive supported for a particular data
 type if it is supported on at least one device type.
@@ -77,7 +77,7 @@ be updated.
 | digamma | 4 | floating | bool, complex, integer |
 | div | 20 | inexact, integer | bool |
 | dot_general | 245 | all | |
-| dynamic_slice | 32 | all | |
+| dynamic_slice | 64 | all | |
 | dynamic_update_slice | 21 | all | |
 | eig | 72 | inexact | bool, integer |
 | eigh | 36 | inexact | bool, integer |
@@ -134,10 +134,10 @@ be updated.
 | rev | 19 | all | |
 | round | 7 | floating | bool, complex, integer |
 | rsqrt | 6 | inexact | bool, integer |
-| scatter_add | 14 | inexact, integer | bool |
+| scatter_add | 15 | all | |
 | scatter_max | 15 | all | |
 | scatter_min | 19 | all | |
-| scatter_mul | 14 | inexact, integer | bool |
+| scatter_mul | 15 | all | |
 | select | 16 | all | |
 | select_and_gather_add | 15 | floating | bool, complex, integer |
 | select_and_scatter_add | 27 | bool, floating, integer | complex |
@@ -197,8 +197,10 @@ and search for "limitation".
 |eigh|unimplemented|bfloat16, float16|cpu, gpu|
 |lu|unimplemented|bfloat16, float16|cpu, gpu, tpu|
 |qr|unimplemented|bfloat16, float16|cpu, gpu|
+|scatter_add|unimplemented|bool|cpu, gpu, tpu|
 |scatter_max|unimplemented|complex64|tpu|
 |scatter_min|unimplemented|complex64|tpu|
+|scatter_mul|unimplemented|bool|cpu, gpu, tpu|
 |select_and_scatter_add|works only for 2 or more inactive dimensions|all|tpu|
 |svd|complex not implemented. Works in JAX for CPU and GPU with custom kernels|complex|tpu|
 |svd|unimplemented|bfloat16, float16|cpu, gpu|

View File

@@ -1,6 +1,6 @@
 # Primitives with limited support for jax2tf
-*Last generated on (YYYY-MM-DD): 2021-06-04*
+*Last generated on (YYYY-MM-DD): 2021-06-12*
 This document summarizes known limitations of the jax2tf conversion.
 There are several kinds of limitations.
@@ -33,7 +33,7 @@ The errors apply only for certain devices and compilation modes ("eager",
 On TPU only the "compiled" mode is relevant.
 Our priority is to ensure same coverage and numerical behavior with JAX
-in the "compiled" mode, **when using XLA to compile the converted program**.
+in the "compiled" mode, i.e., **when using XLA to compile the converted program**.
 We are pretty close to that goal.
 This table only shows errors for cases that are working in JAX (see [separate
@@ -60,20 +60,15 @@ More detailed information can be found in the
 | --- | --- | --- | --- | --- |
 | bessel_i0e | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
 | bessel_i1e | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
-| bitcast_convert_type | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | cholesky | TF test skipped: Not implemented in JAX: unimplemented | float16 | cpu, gpu | compiled, eager, graph |
 | cholesky | TF error: function not compilable | complex | cpu, gpu | compiled |
 | cholesky | TF error: op not defined for dtype | complex | tpu | compiled, graph |
 | clamp | TF test skipped: Not implemented in JAX: unimplemented | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
-| clamp | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
 | conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type not implemented for integers | int16, int32, int8 | gpu | compiled, eager, graph |
 | conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type=c128 not implemented | complex64 | tpu | compiled, eager, graph |
 | conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type=f64 not implemented | bfloat16, float16, float32 | tpu | compiled, eager, graph |
 | conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type=i64 not implemented | int16, int32, int8 | tpu | compiled, eager, graph |
 | conv_general_dilated | TF error: jax2tf BUG: batch_group_count > 1 not yet converted | all | cpu, gpu, tpu | compiled, eager, graph |
-| cummax | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
-| cummin | TF error: op not defined for dtype | uint64 | cpu, gpu | eager |
-| cummin | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
 | digamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
 | div | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
 | dot_general | TF error: Numeric comparision disabled: Non-deterministic NaN for dot_general with preferred_element_type on GPU (b/189287598) | bfloat16, complex64, float16, float32 | gpu | compiled |
@@ -92,42 +87,31 @@ More detailed information can be found in the
 | erf_inv | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
 | erfc | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
 | fft | TF error: TF function not compileable | complex128, float64 | cpu, gpu | compiled |
-| ge | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
-| gt | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | igamma | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
 | igammac | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
 | integer_pow | TF error: op not defined for dtype | int16, int8, unsigned | cpu, gpu | graph |
-| le | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | lgamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
-| lt | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | lu | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
 | lu | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
-| max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
-| min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | neg | TF error: op not defined for dtype | unsigned | cpu, gpu, tpu | compiled, eager, graph |
 | nextafter | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
 | qr | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph |
 | qr | TF error: op not defined for dtype | bfloat16 | tpu | compiled, eager, graph |
 | reduce_max | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
 | reduce_min | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
-| reduce_window_max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
-| reduce_window_min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | regularized_incomplete_beta | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
 | rem | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
 | round | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
 | rsqrt | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
-| scatter_add | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
+| scatter_add | TF test skipped: Not implemented in JAX: unimplemented | bool | cpu, gpu, tpu | compiled, eager, graph |
 | scatter_add | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
 | scatter_max | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
-| scatter_max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | scatter_min | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
-| scatter_min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
-| scatter_mul | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
+| scatter_mul | TF test skipped: Not implemented in JAX: unimplemented | bool | cpu, gpu, tpu | compiled, eager, graph |
 | scatter_mul | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
 | select_and_gather_add | TF error: jax2tf unimplemented for 64-bit inputs because the current implementation relies on packing two values into a single value. This can be fixed by using a variadic XlaReduceWindow, when available | float64 | cpu, gpu | compiled, eager, graph |
 | select_and_scatter_add | TF test skipped: Not implemented in JAX: works only for 2 or more inactive dimensions | all | tpu | compiled, eager, graph |
 | sign | TF error: sign not defined for unsigned integers | unsigned | cpu, gpu, tpu | compiled, eager, graph |
-| sort | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | svd | TF test skipped: Not implemented in JAX: complex not implemented. Works in JAX for CPU and GPU with custom kernels | complex | tpu | compiled, eager, graph |
 | svd | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph |
 | svd | TF error: function not compilable. Implemented using `tf.linalg.svd` and `tf.linalg.adjoint` | complex | cpu, gpu | compiled |

View File

@@ -33,7 +33,7 @@ The errors apply only for certain devices and compilation modes ("eager",
 On TPU only the "compiled" mode is relevant.
 Our priority is to ensure same coverage and numerical behavior with JAX
-in the "compiled" mode, **when using XLA to compile the converted program**.
+in the "compiled" mode, i.e., **when using XLA to compile the converted program**.
 We are pretty close to that goal.
 This table only shows errors for cases that are working in JAX (see [separate

View File

@@ -1162,6 +1162,8 @@ def _minmax(x: TfVal, y: TfVal, *, is_min: bool,
         partial(lax._minmax_complex_lowering,
                 lax_cmp_pick_x=lax.lt if is_min else lax.gt),
         multiple_results=False)(x, y, _in_avals=_in_avals, _out_aval=_out_aval)
+  elif x.dtype.as_numpy_dtype == np.bool_:
+    return (tf.math.logical_and if is_min else tf.math.logical_or)(x, y)
   else:
     return (tf.math.minimum if is_min else tf.math.maximum)(x, y)
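The boolean branch above relies on the identity that, with `False < True`, `min` is logical AND and `max` is logical OR. A quick standalone check in plain TF (illustrative only, not part of the commit):

```python
import tensorflow as tf

x = tf.constant([True, True, False, False])
y = tf.constant([True, False, True, False])

# min over booleans, computed via int8, equals logical_and ...
min_via_int8 = tf.cast(tf.math.minimum(tf.cast(x, tf.int8), tf.cast(y, tf.int8)), tf.bool)
tf.debugging.assert_equal(tf.math.logical_and(x, y), min_via_int8)

# ... and max over booleans equals logical_or.
max_via_int8 = tf.cast(tf.math.maximum(tf.cast(x, tf.int8), tf.cast(y, tf.int8)), tf.bool)
tf.debugging.assert_equal(tf.math.logical_or(x, y), max_via_int8)
```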
@@ -1287,19 +1289,37 @@ def _not(x):
 tf_impl[lax.not_p] = _not

-def bool_to_int8(f, argnums):
-  """Computes bool valued functions using int8."""
+def bool_to_int8(f, argnums: Sequence[int]):
+  """Computes functions with some bool args and bool results using int8.
+
+  This is needed because some TF ops do not work for bool args, e.g.,
+  inequalities, min/max.
+
+  Args:
+    f: a TF callable to wrap. It will be called with non-boolean arguments.
+    argnums: the positional arguments that may be booleans.
+
+  Returns: a TF callable that can take a mix of boolean positional arguments
+    (in the positions specified by `argnums`) and some non-boolean positional
+    arguments. If there are no boolean arguments, just calls `f`. Otherwise,
+    casts the boolean arguments to `int8`, calls `f`, then casts the result to
+    `bool`.
+  """
   argnums = tf.nest.flatten(argnums)

-  def wrapper(*args, **kwargs):
-    if not any(args[i].dtype == tf.bool for i in argnums):
+  def wrapper(*args: TfVal, **kwargs):
+    argnum_types = {args[i].dtype for i in argnums}
+    if tf.bool not in argnum_types:
       return f(*args, **kwargs)
     else:
+      # All argnums should be boolean
+      assert len(argnum_types) == 1, argnum_types
       args_cast = [(tf.cast(a, tf.int8) if i in argnums else a)
                    for i, a in enumerate(args)]
       if "_in_avals" in kwargs:

         def cast_aval(aval):
+          assert aval.dtype == np.bool_
           return core.ShapedArray(aval.shape, np.int8)

         _in_avals_cast = [
@@ -1321,10 +1341,11 @@ tf_impl[lax.xor_p] = bool_to_int8(tf.bitwise.bitwise_xor, argnums=(0, 1))
 tf_impl[lax.eq_p] = tf.math.equal
 tf_impl[lax.ne_p] = tf.math.not_equal
-tf_impl[lax.ge_p] = tf.math.greater_equal
-tf_impl[lax.gt_p] = tf.math.greater
-tf_impl[lax.le_p] = tf.math.less_equal
-tf_impl[lax.lt_p] = tf.math.less
+tf_impl[lax.ge_p] = bool_to_int8(tf.math.greater_equal, argnums=(0, 1))
+tf_impl[lax.gt_p] = bool_to_int8(tf.math.greater, argnums=(0, 1))
+tf_impl[lax.le_p] = bool_to_int8(tf.math.less_equal, argnums=(0, 1))
+tf_impl[lax.lt_p] = bool_to_int8(tf.math.less, argnums=(0, 1))

 tf_impl[lax_linalg.cholesky_p] = tf.linalg.cholesky
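A standalone sketch of what the wrapper does for one of the comparison ops registered above (`bool_to_int8_demo` is a hypothetical stand-in, not the actual jax2tf helper; comparisons already return `bool`, so only the inputs need casting here):

```python
import tensorflow as tf

def bool_to_int8_demo(f, argnums):
  def wrapper(*args):
    if not any(args[i].dtype == tf.bool for i in argnums):
      return f(*args)
    # Cast the boolean positional arguments to int8 and call the TF op.
    args_cast = [tf.cast(a, tf.int8) if i in argnums else a
                 for i, a in enumerate(args)]
    return f(*args_cast)
  return wrapper

less_with_bools = bool_to_int8_demo(tf.math.less, argnums=(0, 1))
print(less_with_bools(tf.constant([False, True]), tf.constant([True, True])))
# tf.Tensor([ True False], shape=(2,), dtype=bool)
```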
@@ -1346,6 +1367,8 @@ tf_impl[lax.convert_element_type_p] = _convert_element_type

 def _bitcast_convert_type(operand, new_dtype):
+  if operand.dtype == new_dtype:
+    return operand
   return tf.bitcast(operand, _to_tf_dtype(new_dtype))
@@ -1767,13 +1790,13 @@ tf_impl[lax.transpose_p] = _transpose
 axes_to_axis = lambda func: lambda operand, axes: func(operand, axis=axes)

 tf_impl[lax.reduce_sum_p] = (
-    bool_to_int8(axes_to_axis(tf.reduce_sum), argnums=0))
+    bool_to_int8(axes_to_axis(tf.reduce_sum), argnums=[0]))
 tf_impl[lax.reduce_prod_p] = (
-    bool_to_int8(axes_to_axis(tf.reduce_prod), argnums=0))
+    bool_to_int8(axes_to_axis(tf.reduce_prod), argnums=[0]))
 tf_impl[lax.reduce_max_p] = (
-    bool_to_int8(axes_to_axis(tf.reduce_max), argnums=0))
+    bool_to_int8(axes_to_axis(tf.reduce_max), argnums=[0]))
 tf_impl[lax.reduce_min_p] = (
-    bool_to_int8(axes_to_axis(tf.reduce_min), argnums=0))
+    bool_to_int8(axes_to_axis(tf.reduce_min), argnums=[0]))
 tf_impl[lax.reduce_or_p] = axes_to_axis(tf.reduce_any)
 tf_impl[lax.reduce_and_p] = axes_to_axis(tf.reduce_all)
@@ -2178,7 +2201,7 @@ def _gather_dimensions_proto(indices_shape, dimension_numbers):
   return proto

-@partial(bool_to_int8, argnums=0)
+@partial(bool_to_int8, argnums=[0])
 def _gather(operand, start_indices, *, dimension_numbers, slice_sizes,
             indices_are_sorted, unique_indices,
             _in_avals, _out_aval):
@@ -2446,25 +2469,6 @@ def _sort(*operands: TfVal, dimension: int, is_stable: bool,
       operands[0].shape
   ), f"Invalid {dimension} for ndim {len(operands[0].shape)}"

-  # The comparator is a 2N-argument TF function, with arguments [2k] and [2k +1]
-  # corresponding to two scalars from operand[k].
-  def lexicographic_comparator_old(*tf_args: TfVal) -> TfVal:
-    assert len(tf_args) == 2 * len(operands)
-    # We build a comparison:
-    # arg[0] < arg[1] or (arg[0] == arg[1] and (arg[2] < arg[3] or ...))
-    # all the way to arg[2 * num_keys - 2] < arg[2 * num_keys - 1]
-    inside_comparison = None
-    for key_idx in range(num_keys - 1, -1, -1):
-      a = tf_args[2 * key_idx]
-      b = tf_args[2 * key_idx + 1]
-      a_lt_b = tf.math.less(a, b)
-      if inside_comparison is None:
-        inside_comparison = a_lt_b
-      else:
-        inside_comparison = tf.math.logical_or(
-            a_lt_b, tf.math.logical_and(tf.math.equal(a, b), inside_comparison))
-    return inside_comparison
-
   comparator_spec: List[tf.TensorSpec] = []
   comparator_jax_in_avals: List[core.AbstractValue] = []
   for op in operands:

View File

@@ -109,13 +109,13 @@ class Jax2TfLimitation(primitive_harness.Limitation):
     group_method = getattr(cls, harness.group_name, None)
     if harness.group_name in cls.harness_groups_no_limitations:
       assert group_method is None, (
-          f"Harness group {harness.group_name} is both in "
+          f"Harness group '{harness.group_name}' is both in "
           f"'harness_groups_no_limitations' and has a custom "
           f"Jax2TfLimitation.classmethod defined (see module docstring)")
       return []
     else:
       assert group_method is not None, (
-          f"Harness group {harness.group_name} must be either part of "
+          f"Harness group '{harness.group_name}' must be either part of "
          f"'harness_groups_no_limitations' or must have a custom "
          f"Jax2TfLimitation.classmethod defined (see module docstring)")
      limitations = group_method(harness)
@@ -124,16 +124,19 @@ class Jax2TfLimitation(primitive_harness.Limitation):
   # We keep here the explicit set of groups for which we don't have limitations
   harness_groups_no_limitations = {
-      "abs", "and", "argmin", "argmax", "atan2", "broadcast",
-      "broadcast_in_dim", "ceil",
-      "concatenate", "cos", "cosh", "complex", "conj",
-      "device_put", "dynamic_slice",
-      "dynamic_update_slice", "exp", "eq", "floor", "log", "gather", "imag",
-      "iota", "is_finite", "ne", "not", "or", "pad", "random_split",
-      "reduce_and", "reduce_prod", "reduce_or", "reduce_sum", "real", "reshape",
-      "rev",
-      "select", "shift_left", "shift_right_logical", "shift_right_arithmetic",
-      "sin", "sinh", "slice", "sqrt", "squeeze", "stop_gradient",
+      "abs", "add", "add_any", "and", "argmin", "argmax", "atan2",
+      "bitcast_convert_type", "broadcast", "broadcast_in_dim", "ceil", "clamp",
+      "concatenate", "cos", "cosh", "complex", "conj", "convert_element_type",
+      "cummax", "cummin", "device_put", "dynamic_slice",
+      "dynamic_update_slice", "exp", "eq", "floor", "gather", "ge", "gt", "imag",
+      "iota", "is_finite", "le", "lt", "log", "mul", "ne", "not", "or", "pad",
+      "population_count", "random_split",
+      "reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
+      "reduce_window_add", "reduce_window_mul", "reduce_window_min", "reduce_window_max",
+      "real", "reshape", "rev", "scatter_max", "scatter_min",
+      "select", "select_and_scatter_add",
+      "shift_left", "shift_right_logical", "shift_right_arithmetic",
+      "sin", "sinh", "slice", "sqrt", "squeeze", "stop_gradient", "sub",
       "tie_in", "transpose", "xor", "zeros_like"
   }
@@ -173,15 +176,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
         cls.helper_get_trig_custom_limitation(np.cosh)
     ]

-  @classmethod
-  def add(cls, harness: primitive_harness.Harness):
-    return []
-
-  @classmethod
-  # Also called add_jaxvals
-  def add_any(cls, harness: primitive_harness.Harness):
-    return []
-
   @classmethod
   def asin(cls, harness: primitive_harness.Harness):
     return [
@@ -228,10 +222,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
   def bessel_i1e(cls, harness: primitive_harness.Harness):
     return cls.bessel_i0e(harness)

-  @classmethod
-  def bitcast_convert_type(cls, harness: primitive_harness.Harness):
-    return [missing_tf_kernel(dtypes=[np.bool_])]
-
   @classmethod
   def cholesky(cls, harness: primitive_harness.Harness):
@@ -280,16 +270,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
             modes=("eager", "graph", "compiled"))
     ]

-  @classmethod
-  def clamp(cls, harness: primitive_harness.Harness):
-    return [
-        missing_tf_kernel(dtypes=[np.complex64, np.complex128]),
-    ]
-
-  @classmethod
-  def convert_element_type(cls, harness: primitive_harness.Harness):
-    return []
-
   @classmethod
   def conv_general_dilated(cls, harness: primitive_harness.Harness):
     return [
@@ -307,22 +287,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
             tol=5e-3)
     ]

-  @classmethod
-  def cummax(cls, harness):
-    return [
-        missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
-    ]
-
-  @classmethod
-  def cummin(cls, harness):
-    return [
-        missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
-        # TODO: we get jax2tf AssertionError
-        missing_tf_kernel(dtypes=[np.uint64],
-                          devices=("cpu", "gpu"),
-                          modes=("eager",)),
-    ]
-
   @classmethod
   def cumprod(cls, harness):
     return [
@@ -421,9 +385,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
   @classmethod
   def dot_general(cls, harness: primitive_harness.Harness):
     return [
-        missing_tf_kernel(dtypes=[
-            np.bool_,
-        ],),
+        missing_tf_kernel(dtypes=[np.bool_],),
         # TODO(b/189287598)
         Jax2TfLimitation(
             "Non-deterministic NaN for dot_general with preferred_element_type on GPU (b/189287598)",
@@ -583,14 +545,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
             modes=("eager", "graph", "compiled"))
     ]

-  @classmethod
-  def ge(cls, harness: primitive_harness.Harness):
-    return [missing_tf_kernel(dtypes=[np.bool_])]
-
-  @classmethod
-  def gt(cls, harness: primitive_harness.Harness):
-    return cls.ge(harness)
-
   @classmethod
   def erf(cls, harness: primitive_harness.Harness):
     return [
@@ -799,14 +753,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
   def pow(cls, harness: primitive_harness.Harness):
     return cls._pow_test_util(harness)

-  @classmethod
-  def le(cls, harness: primitive_harness.Harness):
-    return cls.ge(harness)
-
-  @classmethod
-  def lt(cls, harness: primitive_harness.Harness):
-    return cls.ge(harness)
-
   @classmethod
   def lgamma(cls, harness: primitive_harness.Harness):
     return [
@@ -881,8 +827,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
       tst.assertAllClose(result_jax[~mask], result_tf[~mask], err_msg=err_msg)

     return [
-        missing_tf_kernel(
-            dtypes=[np.bool_]),
         custom_numeric(
             custom_assert=custom_assert,
             description=(
@@ -901,7 +845,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
       tst.assertAllClose(result_jax[~mask], result_tf[~mask], err_msg=err_msg)

     return [
-        missing_tf_kernel(dtypes=[np.bool_]),
         custom_numeric(
             custom_assert=custom_assert,
             description=(
@@ -911,10 +854,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
             modes=("eager", "graph", "compiled"))
     ]

-  @classmethod
-  def mul(cls, harness: primitive_harness.Harness):
-    return []
-
   @classmethod
   def neg(cls, harness: primitive_harness.Harness):
     return [
@@ -925,10 +864,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
   def nextafter(cls, harness: primitive_harness.Harness):
     return [missing_tf_kernel(dtypes=[np.float16, dtypes.bfloat16])]

-  @classmethod
-  def population_count(cls, harness: primitive_harness.Harness):
-    return []
-
   @classmethod
   def qr(cls, harness: primitive_harness.Harness):
     # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
@@ -960,39 +895,14 @@ class Jax2TfLimitation(primitive_harness.Limitation):
   @classmethod
   def reduce_max(cls, harness: primitive_harness.Harness):
-    return [
-        missing_tf_kernel(dtypes=[np.complex64, np.complex128]),
-    ]
+    # Unlike reduce_window_max, we use a native TF op: tf.reduce_max, which
+    # does not work for complex
+    return [missing_tf_kernel(dtypes=[np.complex64, np.complex128])]

   @classmethod
   def reduce_min(cls, harness: primitive_harness.Harness):
     return cls.reduce_max(harness)

-  @classmethod
-  def reduce_window_add(cls, harness):
-    assert "add" == harness.params["computation"].__name__
-    return []
-
-  @classmethod
-  def reduce_window_max(cls, harness):
-    assert "max" == harness.params["computation"].__name__
-    return [
-        missing_tf_kernel(dtypes=[np.bool_]),
-    ]
-
-  @classmethod
-  def reduce_window_min(cls, harness):
-    assert "min" == harness.params["computation"].__name__
-    return [
-        missing_tf_kernel(dtypes=[np.bool_]),
-    ]
-
-  @classmethod
-  def reduce_window_mul(cls, harness):
-    assert "mul" == harness.params["computation"].__name__
-    return []
-
   @classmethod
   def regularized_incomplete_beta(cls, harness: primitive_harness.Harness):
     return [
@@ -1034,29 +944,15 @@ class Jax2TfLimitation(primitive_harness.Limitation):
   @classmethod
   def scatter_add(cls, harness):
     return [
-        missing_tf_kernel(dtypes=[np.bool_]),
         missing_tf_kernel(
             dtypes=[np.complex64],
             devices="tpu",
         )
     ]

-  @classmethod
-  def scatter_max(cls, harness):
-    return [
-        missing_tf_kernel(dtypes=[np.bool_]),
-    ]
-
-  @classmethod
-  def scatter_min(cls, harness):
-    return [
-        missing_tf_kernel(dtypes=[np.bool_]),
-    ]
-
   @classmethod
   def scatter_mul(cls, harness):
     return [
-        missing_tf_kernel(dtypes=[np.bool_],),
         missing_tf_kernel(
             dtypes=[np.complex64],
             devices="tpu",
@@ -1078,10 +974,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
             devices=("cpu", "gpu"))
     ]

-  @classmethod
-  def select_and_scatter_add(cls, harness):
-    return []
-
   @classmethod
   def sign(cls, harness: primitive_harness.Harness):
     return [
@@ -1101,13 +993,8 @@ class Jax2TfLimitation(primitive_harness.Limitation):
                 not harness.params["is_stable"]),
             expect_tf_error=False,
             skip_comparison=True),
-        missing_tf_kernel(dtypes=[np.bool_],),
     ]

-  @classmethod
-  def sub(cls, harness):
-    return []
-
   @classmethod
   def svd(cls, harness: primitive_harness.Harness):
     # TODO: slow test

View File

@@ -1202,7 +1202,11 @@ def _make_scatter_harness(name,
               "unimplemented",
               devices="tpu",
               dtypes=np.complex64,
-              enabled=(f_lax in [lax.scatter_max, lax.scatter_min]))
+              enabled=(f_lax in [lax.scatter_max, lax.scatter_min])),
+          Limitation(
+              "unimplemented",
+              dtypes=np.bool_,
+              enabled=(f_lax in [lax.scatter_add, lax.scatter_mul])),
       ],
       f_lax=f_lax,
       shape=shape,
@@ -1219,8 +1223,8 @@ for dtype in jtu.dtypes.all:
   for f_lax in [
       lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min
   ]:
-    if f_lax in [lax.scatter_add, lax.scatter_mul] and dtype == np.bool_:
-      continue
+    #if f_lax in [lax.scatter_add, lax.scatter_mul] and dtype == np.bool_:
+    #  continue
     _make_scatter_harness("dtypes", dtype=dtype, f_lax=f_lax)

 # Validate f_lax/update_jaxpr