Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
[jax2tf] Support inequality and min/max for booleans.
For inequalities we cast the boolean operands to int8. For min/max we rewrite to the logical operations and/or.
This commit is contained in:
parent 5e3be94d8c
commit dd8ab85121
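A minimal sketch of the two lowerings described in the commit message, written directly against TensorFlow rather than through the jax2tf machinery; the tensors `x` and `y` are made up for illustration:

    import tensorflow as tf

    x = tf.constant([True, True, False, False])
    y = tf.constant([True, False, True, False])

    # Inequalities: the TF comparison ops have no bool kernel (see the limitations
    # table below), so compare int8 casts; the comparison already returns bools.
    gt = tf.math.greater(tf.cast(x, tf.int8), tf.cast(y, tf.int8))  # [False, True, False, False]

    # min/max: for booleans these coincide with logical and/or.
    min_xy = tf.math.logical_and(x, y)  # [True, False, False, False]
    max_xy = tf.math.logical_or(x, y)   # [True, True, True, False]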
@@ -11,6 +11,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
 ## jax 0.2.15 (unreleased)
 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.14...master).
 * New features:
+  * The {func}`jax2tf.convert` supports inequalities and min/max for booleans
+    ({jax-issue}`#6956`).

 * Breaking changes:
   * Support for NumPy 1.16 has been dropped, per the
@@ -2907,6 +2907,12 @@ def _bitcast_convert_type_shape_rule(operand, *, new_dtype):
   return operand.shape

 def _bitcast_convert_type_dtype_rule(operand, *, new_dtype):
+  old_dtype = dtypes.canonicalize_dtype(operand.dtype)
+  if dtypes.issubdtype(old_dtype, np.bool_) or dtypes.issubdtype(old_dtype, np.complexfloating):
+    if old_dtype != new_dtype:
+      raise TypeError(f"`bitcast_convert_type` for operand type ({old_dtype}) cannot have different destination type ({new_dtype})")
+  if np.dtype(old_dtype).itemsize != np.dtype(new_dtype).itemsize:
+    raise TypeError(f"`bitcast_convert_type` for operand type ({old_dtype}) must have destination type ({new_dtype}) of same size.")
   return new_dtype

 def _bitcast_convert_type_translation_rule(c, operand, *, new_dtype):
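A simplified, standalone mirror of the checks added above, written with plain NumPy instead of JAX's internal `dtypes` helpers, showing which bitcasts the rule now accepts and rejects:

    import numpy as np

    def bitcast_dtype_rule(old_dtype, new_dtype):
      old_dtype, new_dtype = np.dtype(old_dtype), np.dtype(new_dtype)
      # bool and complex operands may only "bitcast" to their own type.
      if np.issubdtype(old_dtype, np.bool_) or np.issubdtype(old_dtype, np.complexfloating):
        if old_dtype != new_dtype:
          raise TypeError(f"cannot bitcast {old_dtype} to different type {new_dtype}")
      # Source and destination must have the same byte width.
      if old_dtype.itemsize != new_dtype.itemsize:
        raise TypeError(f"{old_dtype} and {new_dtype} differ in size")
      return new_dtype

    bitcast_dtype_rule(np.float32, np.int32)  # ok: both 4 bytes
    bitcast_dtype_rule(np.bool_, np.bool_)    # ok: identity
    # bitcast_dtype_rule(np.bool_, np.int8)   # raises TypeError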
@@ -1,10 +1,10 @@
 # Primitives with limited JAX support

-*Last generated on: 2021-06-04* (YYYY-MM-DD)
+*Last generated on: 2021-06-12* (YYYY-MM-DD)

 ## Supported data types for primitives

-We use a set of 2570 test harnesses to test
+We use a set of 2604 test harnesses to test
 the implementation of 121 numeric JAX primitives.
 We consider a JAX primitive supported for a particular data
 type if it is supported on at least one device type.
@@ -77,7 +77,7 @@ be updated.
 | digamma | 4 | floating | bool, complex, integer |
 | div | 20 | inexact, integer | bool |
 | dot_general | 245 | all | |
-| dynamic_slice | 32 | all | |
+| dynamic_slice | 64 | all | |
 | dynamic_update_slice | 21 | all | |
 | eig | 72 | inexact | bool, integer |
 | eigh | 36 | inexact | bool, integer |
@@ -134,10 +134,10 @@ be updated.
 | rev | 19 | all | |
 | round | 7 | floating | bool, complex, integer |
 | rsqrt | 6 | inexact | bool, integer |
-| scatter_add | 14 | inexact, integer | bool |
+| scatter_add | 15 | all | |
 | scatter_max | 15 | all | |
 | scatter_min | 19 | all | |
-| scatter_mul | 14 | inexact, integer | bool |
+| scatter_mul | 15 | all | |
 | select | 16 | all | |
 | select_and_gather_add | 15 | floating | bool, complex, integer |
 | select_and_scatter_add | 27 | bool, floating, integer | complex |
@@ -197,8 +197,10 @@ and search for "limitation".
 |eigh|unimplemented|bfloat16, float16|cpu, gpu|
 |lu|unimplemented|bfloat16, float16|cpu, gpu, tpu|
 |qr|unimplemented|bfloat16, float16|cpu, gpu|
+|scatter_add|unimplemented|bool|cpu, gpu, tpu|
 |scatter_max|unimplemented|complex64|tpu|
 |scatter_min|unimplemented|complex64|tpu|
+|scatter_mul|unimplemented|bool|cpu, gpu, tpu|
 |select_and_scatter_add|works only for 2 or more inactive dimensions|all|tpu|
 |svd|complex not implemented. Works in JAX for CPU and GPU with custom kernels|complex|tpu|
 |svd|unimplemented|bfloat16, float16|cpu, gpu|
@@ -1,6 +1,6 @@
 # Primitives with limited support for jax2tf

-*Last generated on (YYYY-MM-DD): 2021-06-04*
+*Last generated on (YYYY-MM-DD): 2021-06-12*

 This document summarizes known limitations of the jax2tf conversion.
 There are several kinds of limitations.
@@ -33,7 +33,7 @@ The errors apply only for certain devices and compilation modes ("eager",
 On TPU only the "compiled" mode is relevant.

 Our priority is to ensure same coverage and numerical behavior with JAX
-in the "compiled" mode, **when using XLA to compile the converted program**.
+in the "compiled" mode, i.e., **when using XLA to compile the converted program**.
 We are pretty close to that goal.

 This table only shows errors for cases that are working in JAX (see [separate
@@ -60,20 +60,15 @@ More detailed information can be found in the
 | --- | --- | --- | --- | --- |
 | bessel_i0e | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
 | bessel_i1e | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
-| bitcast_convert_type | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | cholesky | TF test skipped: Not implemented in JAX: unimplemented | float16 | cpu, gpu | compiled, eager, graph |
 | cholesky | TF error: function not compilable | complex | cpu, gpu | compiled |
 | cholesky | TF error: op not defined for dtype | complex | tpu | compiled, graph |
 | clamp | TF test skipped: Not implemented in JAX: unimplemented | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
-| clamp | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
 | conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type not implemented for integers | int16, int32, int8 | gpu | compiled, eager, graph |
 | conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type=c128 not implemented | complex64 | tpu | compiled, eager, graph |
 | conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type=f64 not implemented | bfloat16, float16, float32 | tpu | compiled, eager, graph |
 | conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type=i64 not implemented | int16, int32, int8 | tpu | compiled, eager, graph |
 | conv_general_dilated | TF error: jax2tf BUG: batch_group_count > 1 not yet converted | all | cpu, gpu, tpu | compiled, eager, graph |
-| cummax | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
-| cummin | TF error: op not defined for dtype | uint64 | cpu, gpu | eager |
-| cummin | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
 | digamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
 | div | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
 | dot_general | TF error: Numeric comparision disabled: Non-deterministic NaN for dot_general with preferred_element_type on GPU (b/189287598) | bfloat16, complex64, float16, float32 | gpu | compiled |
@@ -92,42 +87,31 @@ More detailed information can be found in the
 | erf_inv | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
 | erfc | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
 | fft | TF error: TF function not compileable | complex128, float64 | cpu, gpu | compiled |
-| ge | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
-| gt | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | igamma | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
 | igammac | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
 | integer_pow | TF error: op not defined for dtype | int16, int8, unsigned | cpu, gpu | graph |
-| le | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | lgamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
-| lt | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | lu | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
 | lu | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
-| max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
-| min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | neg | TF error: op not defined for dtype | unsigned | cpu, gpu, tpu | compiled, eager, graph |
 | nextafter | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
 | qr | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph |
 | qr | TF error: op not defined for dtype | bfloat16 | tpu | compiled, eager, graph |
 | reduce_max | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
 | reduce_min | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
-| reduce_window_max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
-| reduce_window_min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | regularized_incomplete_beta | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
 | rem | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
 | round | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
 | rsqrt | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
-| scatter_add | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
+| scatter_add | TF test skipped: Not implemented in JAX: unimplemented | bool | cpu, gpu, tpu | compiled, eager, graph |
 | scatter_add | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
 | scatter_max | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
-| scatter_max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | scatter_min | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
-| scatter_min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
-| scatter_mul | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
+| scatter_mul | TF test skipped: Not implemented in JAX: unimplemented | bool | cpu, gpu, tpu | compiled, eager, graph |
 | scatter_mul | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
 | select_and_gather_add | TF error: jax2tf unimplemented for 64-bit inputs because the current implementation relies on packing two values into a single value. This can be fixed by using a variadic XlaReduceWindow, when available | float64 | cpu, gpu | compiled, eager, graph |
 | select_and_scatter_add | TF test skipped: Not implemented in JAX: works only for 2 or more inactive dimensions | all | tpu | compiled, eager, graph |
 | sign | TF error: sign not defined for unsigned integers | unsigned | cpu, gpu, tpu | compiled, eager, graph |
-| sort | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | svd | TF test skipped: Not implemented in JAX: complex not implemented. Works in JAX for CPU and GPU with custom kernels | complex | tpu | compiled, eager, graph |
 | svd | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph |
 | svd | TF error: function not compilable. Implemented using `tf.linalg.svd` and `tf.linalg.adjoint` | complex | cpu, gpu | compiled |
@@ -33,7 +33,7 @@ The errors apply only for certain devices and compilation modes ("eager",
 On TPU only the "compiled" mode is relevant.

 Our priority is to ensure same coverage and numerical behavior with JAX
-in the "compiled" mode, **when using XLA to compile the converted program**.
+in the "compiled" mode, i.e., **when using XLA to compile the converted program**.
 We are pretty close to that goal.

 This table only shows errors for cases that are working in JAX (see [separate
@@ -1162,6 +1162,8 @@ def _minmax(x: TfVal, y: TfVal, *, is_min: bool,
         partial(lax._minmax_complex_lowering,
                 lax_cmp_pick_x=lax.lt if is_min else lax.gt),
         multiple_results=False)(x, y, _in_avals=_in_avals, _out_aval=_out_aval)
+  elif x.dtype.as_numpy_dtype == np.bool_:
+    return (tf.math.logical_and if is_min else tf.math.logical_or)(x, y)
   else:
     return (tf.math.minimum if is_min else tf.math.maximum)(x, y)

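The boolean branch added above relies on the fact that, with False < True, minimum coincides with logical and and maximum with logical or. A quick standalone check in plain NumPy (not jax2tf code):

    import itertools
    import numpy as np

    for a, b in itertools.product([False, True], repeat=2):
      assert np.minimum(a, b) == (a and b)  # min == logical and
      assert np.maximum(a, b) == (a or b)   # max == logical or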
@@ -1287,19 +1289,37 @@ def _not(x):
 tf_impl[lax.not_p] = _not


-def bool_to_int8(f, argnums):
-  """Computes bool valued functions using int8."""
+def bool_to_int8(f, argnums: Sequence[int]):
+  """Computes functions with some bool args and bool results using int8.
+
+  This is needed because some TF ops do not work for bool args, e.g.,
+  inequalities, min/max.
+
+  Args:
+    f: a TF callable to wrap. It will be called with non-boolean arguments.
+    argnums: the positional arguments that may be booleans.
+
+  Returns: a TF callable that can take a mix of boolean positional arguments
+    (in the positions specified by `argnums`) and some non-boolean positional
+    arguments. If there are no boolean arguments, just calls `f`. Otherwise,
+    casts the boolean arguments to `int8`, calls `f`, then casts the result to
+    `bool`.
+  """
   argnums = tf.nest.flatten(argnums)

-  def wrapper(*args, **kwargs):
-    if not any(args[i].dtype == tf.bool for i in argnums):
+  def wrapper(*args: TfVal, **kwargs):
+    argnum_types = {args[i].dtype for i in argnums}
+    if tf.bool not in argnum_types:
       return f(*args, **kwargs)
     else:
+      # All argnums should be boolean
+      assert len(argnum_types) == 1, argnum_types
       args_cast = [(tf.cast(a, tf.int8) if i in argnums else a)
                    for i, a in enumerate(args)]
       if "_in_avals" in kwargs:
+
         def cast_aval(aval):
+          assert aval.dtype == np.bool_
           return core.ShapedArray(aval.shape, np.int8)

         _in_avals_cast = [
@@ -1321,10 +1341,11 @@ tf_impl[lax.xor_p] = bool_to_int8(tf.bitwise.bitwise_xor, argnums=(0, 1))

 tf_impl[lax.eq_p] = tf.math.equal
 tf_impl[lax.ne_p] = tf.math.not_equal
-tf_impl[lax.ge_p] = tf.math.greater_equal
-tf_impl[lax.gt_p] = tf.math.greater
-tf_impl[lax.le_p] = tf.math.less_equal
-tf_impl[lax.lt_p] = tf.math.less
+tf_impl[lax.ge_p] = bool_to_int8(tf.math.greater_equal, argnums=(0, 1))
+tf_impl[lax.gt_p] = bool_to_int8(tf.math.greater, argnums=(0, 1))
+tf_impl[lax.le_p] = bool_to_int8(tf.math.less_equal, argnums=(0, 1))
+tf_impl[lax.lt_p] = bool_to_int8(tf.math.less, argnums=(0, 1))

 tf_impl[lax_linalg.cholesky_p] = tf.linalg.cholesky

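A rough usage sketch of the wrapped comparisons; `bool_to_int8_simplified` below is a cut-down stand-in for the real `bool_to_int8` above (which additionally rewrites the `_in_avals` passed through kwargs):

    import tensorflow as tf

    def bool_to_int8_simplified(f, argnums):
      argnums = tf.nest.flatten(argnums)
      def wrapper(*args, **kwargs):
        # If no boolean arguments are present, call f unchanged.
        if not any(args[i].dtype == tf.bool for i in argnums):
          return f(*args, **kwargs)
        # Otherwise cast the boolean positional arguments to int8 first.
        args_cast = [tf.cast(a, tf.int8) if i in argnums else a
                     for i, a in enumerate(args)]
        return f(*args_cast, **kwargs)
      return wrapper

    ge = bool_to_int8_simplified(tf.math.greater_equal, argnums=(0, 1))
    print(ge(tf.constant([True, False]), tf.constant([False, False])))
    # tf.Tensor([ True  True], shape=(2,), dtype=bool)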
@@ -1346,6 +1367,8 @@ tf_impl[lax.convert_element_type_p] = _convert_element_type


 def _bitcast_convert_type(operand, new_dtype):
+  if operand.dtype == new_dtype:
+    return operand
   return tf.bitcast(operand, _to_tf_dtype(new_dtype))

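The early return added above matters because the limitations table earlier in this commit lists `bool` as a dtype for which TF's bitcast op is not defined, and under the new lax dtype rule the only boolean `bitcast_convert_type` that can reach this lowering is the identity one. A tiny sketch of the distinction:

    import tensorflow as tf

    x = tf.constant([True, False])
    same = x  # bool -> bool: the `operand.dtype == new_dtype` branch returns the operand unchanged
    z = tf.bitcast(tf.constant([1.0], tf.float32), tf.int32)  # other dtypes still go through tf.bitcast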
@@ -1767,13 +1790,13 @@ tf_impl[lax.transpose_p] = _transpose
 axes_to_axis = lambda func: lambda operand, axes: func(operand, axis=axes)

 tf_impl[lax.reduce_sum_p] = (
-    bool_to_int8(axes_to_axis(tf.reduce_sum), argnums=0))
+    bool_to_int8(axes_to_axis(tf.reduce_sum), argnums=[0]))
 tf_impl[lax.reduce_prod_p] = (
-    bool_to_int8(axes_to_axis(tf.reduce_prod), argnums=0))
+    bool_to_int8(axes_to_axis(tf.reduce_prod), argnums=[0]))
 tf_impl[lax.reduce_max_p] = (
-    bool_to_int8(axes_to_axis(tf.reduce_max), argnums=0))
+    bool_to_int8(axes_to_axis(tf.reduce_max), argnums=[0]))
 tf_impl[lax.reduce_min_p] = (
-    bool_to_int8(axes_to_axis(tf.reduce_min), argnums=0))
+    bool_to_int8(axes_to_axis(tf.reduce_min), argnums=[0]))
 tf_impl[lax.reduce_or_p] = axes_to_axis(tf.reduce_any)
 tf_impl[lax.reduce_and_p] = axes_to_axis(tf.reduce_all)

@@ -2178,7 +2201,7 @@ def _gather_dimensions_proto(indices_shape, dimension_numbers):
   return proto


-@partial(bool_to_int8, argnums=0)
+@partial(bool_to_int8, argnums=[0])
 def _gather(operand, start_indices, *, dimension_numbers, slice_sizes,
             indices_are_sorted, unique_indices,
             _in_avals, _out_aval):
@@ -2446,25 +2469,6 @@ def _sort(*operands: TfVal, dimension: int, is_stable: bool,
       operands[0].shape
   ), f"Invalid {dimension} for ndim {len(operands[0].shape)}"

-  # The comparator is a 2N-argument TF function, with arguments [2k] and [2k +1]
-  # corresponding to two scalars from operand[k].
-  def lexicographic_comparator_old(*tf_args: TfVal) -> TfVal:
-    assert len(tf_args) == 2 * len(operands)
-    # We build a comparison:
-    # arg[0] < arg[1] or (arg[0] == arg[1] and (arg[2] < arg[3] or ...))
-    # all the way to arg[2 * num_keys - 2] < arg[2 * num_keys - 1]
-    inside_comparison = None
-    for key_idx in range(num_keys - 1, -1, -1):
-      a = tf_args[2 * key_idx]
-      b = tf_args[2 * key_idx + 1]
-      a_lt_b = tf.math.less(a, b)
-      if inside_comparison is None:
-        inside_comparison = a_lt_b
-      else:
-        inside_comparison = tf.math.logical_or(
-            a_lt_b, tf.math.logical_and(tf.math.equal(a, b), inside_comparison))
-    return inside_comparison
-
   comparator_spec: List[tf.TensorSpec] = []
   comparator_jax_in_avals: List[core.AbstractValue] = []
   for op in operands:
@@ -109,13 +109,13 @@ class Jax2TfLimitation(primitive_harness.Limitation):
     group_method = getattr(cls, harness.group_name, None)
     if harness.group_name in cls.harness_groups_no_limitations:
       assert group_method is None, (
-          f"Harness group {harness.group_name} is both in "
+          f"Harness group '{harness.group_name}' is both in "
           f"'harness_groups_no_limitations' and has a custom "
           f"Jax2TfLimitation.classmethod defined (see module docstring)")
       return []
     else:
       assert group_method is not None, (
-          f"Harness group {harness.group_name} must be either part of "
+          f"Harness group '{harness.group_name}' must be either part of "
          f"'harness_groups_no_limitations' or must have a custom "
          f"Jax2TfLimitation.classmethod defined (see module docstring)")
     limitations = group_method(harness)
@@ -124,16 +124,19 @@ class Jax2TfLimitation(primitive_harness.Limitation):

   # We keep here the explicit set of groups for which we don't have limitations
   harness_groups_no_limitations = {
-      "abs", "and", "argmin", "argmax", "atan2", "broadcast",
-      "broadcast_in_dim", "ceil",
-      "concatenate", "cos", "cosh", "complex", "conj",
-      "device_put", "dynamic_slice",
-      "dynamic_update_slice", "exp", "eq", "floor", "log", "gather", "imag",
-      "iota", "is_finite", "ne", "not", "or", "pad", "random_split",
-      "reduce_and", "reduce_prod", "reduce_or", "reduce_sum", "real", "reshape",
-      "rev",
-      "select", "shift_left", "shift_right_logical", "shift_right_arithmetic",
-      "sin", "sinh", "slice", "sqrt", "squeeze", "stop_gradient",
+      "abs", "add", "add_any", "and", "argmin", "argmax", "atan2",
+      "bitcast_convert_type", "broadcast", "broadcast_in_dim", "ceil", "clamp",
+      "concatenate", "cos", "cosh", "complex", "conj", "convert_element_type",
+      "cummax", "cummin", "device_put", "dynamic_slice",
+      "dynamic_update_slice", "exp", "eq", "floor", "gather", "ge", "gt", "imag",
+      "iota", "is_finite", "le", "lt", "log", "mul", "ne", "not", "or", "pad",
+      "population_count", "random_split",
+      "reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
+      "reduce_window_add", "reduce_window_mul", "reduce_window_min", "reduce_window_max",
+      "real", "reshape", "rev", "scatter_max", "scatter_min",
+      "select", "select_and_scatter_add",
+      "shift_left", "shift_right_logical", "shift_right_arithmetic",
+      "sin", "sinh", "slice", "sqrt", "squeeze", "stop_gradient", "sub",
       "tie_in", "transpose", "xor", "zeros_like"
   }

@@ -173,15 +176,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
         cls.helper_get_trig_custom_limitation(np.cosh)
     ]

-  @classmethod
-  def add(cls, harness: primitive_harness.Harness):
-    return []
-
-  @classmethod
-  # Also called add_jaxvals
-  def add_any(cls, harness: primitive_harness.Harness):
-    return []
-
   @classmethod
   def asin(cls, harness: primitive_harness.Harness):
     return [
@@ -228,10 +222,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
   def bessel_i1e(cls, harness: primitive_harness.Harness):
     return cls.bessel_i0e(harness)

-  @classmethod
-  def bitcast_convert_type(cls, harness: primitive_harness.Harness):
-    return [missing_tf_kernel(dtypes=[np.bool_])]
-
   @classmethod
   def cholesky(cls, harness: primitive_harness.Harness):

@@ -280,16 +270,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
             modes=("eager", "graph", "compiled"))
     ]

-  @classmethod
-  def clamp(cls, harness: primitive_harness.Harness):
-    return [
-        missing_tf_kernel(dtypes=[np.complex64, np.complex128]),
-    ]
-
-  @classmethod
-  def convert_element_type(cls, harness: primitive_harness.Harness):
-    return []
-
   @classmethod
   def conv_general_dilated(cls, harness: primitive_harness.Harness):
     return [
@@ -307,22 +287,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
             tol=5e-3)
     ]

-  @classmethod
-  def cummax(cls, harness):
-    return [
-        missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
-    ]
-
-  @classmethod
-  def cummin(cls, harness):
-    return [
-        missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
-        # TODO: we get jax2tf AssertionError
-        missing_tf_kernel(dtypes=[np.uint64],
-                          devices=("cpu", "gpu"),
-                          modes=("eager",)),
-    ]
-
   @classmethod
   def cumprod(cls, harness):
     return [
@@ -421,9 +385,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
   @classmethod
   def dot_general(cls, harness: primitive_harness.Harness):
     return [
-        missing_tf_kernel(dtypes=[
-            np.bool_,
-        ],),
+        missing_tf_kernel(dtypes=[np.bool_],),
         # TODO(b/189287598)
         Jax2TfLimitation(
             "Non-deterministic NaN for dot_general with preferred_element_type on GPU (b/189287598)",
@@ -583,14 +545,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
             modes=("eager", "graph", "compiled"))
     ]

-  @classmethod
-  def ge(cls, harness: primitive_harness.Harness):
-    return [missing_tf_kernel(dtypes=[np.bool_])]
-
-  @classmethod
-  def gt(cls, harness: primitive_harness.Harness):
-    return cls.ge(harness)
-
   @classmethod
   def erf(cls, harness: primitive_harness.Harness):
     return [
@@ -799,14 +753,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
   def pow(cls, harness: primitive_harness.Harness):
     return cls._pow_test_util(harness)

-  @classmethod
-  def le(cls, harness: primitive_harness.Harness):
-    return cls.ge(harness)
-
-  @classmethod
-  def lt(cls, harness: primitive_harness.Harness):
-    return cls.ge(harness)
-
   @classmethod
   def lgamma(cls, harness: primitive_harness.Harness):
     return [
@@ -881,8 +827,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
       tst.assertAllClose(result_jax[~mask], result_tf[~mask], err_msg=err_msg)

     return [
-        missing_tf_kernel(
-            dtypes=[np.bool_]),
         custom_numeric(
             custom_assert=custom_assert,
             description=(
@@ -901,7 +845,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
       tst.assertAllClose(result_jax[~mask], result_tf[~mask], err_msg=err_msg)

     return [
-        missing_tf_kernel(dtypes=[np.bool_]),
         custom_numeric(
             custom_assert=custom_assert,
             description=(
@@ -911,10 +854,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
             modes=("eager", "graph", "compiled"))
     ]

-  @classmethod
-  def mul(cls, harness: primitive_harness.Harness):
-    return []
-
   @classmethod
   def neg(cls, harness: primitive_harness.Harness):
     return [
@@ -925,10 +864,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
   def nextafter(cls, harness: primitive_harness.Harness):
     return [missing_tf_kernel(dtypes=[np.float16, dtypes.bfloat16])]

-  @classmethod
-  def population_count(cls, harness: primitive_harness.Harness):
-    return []
-
   @classmethod
   def qr(cls, harness: primitive_harness.Harness):
     # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
@@ -960,39 +895,14 @@ class Jax2TfLimitation(primitive_harness.Limitation):

   @classmethod
   def reduce_max(cls, harness: primitive_harness.Harness):
-    return [
-        missing_tf_kernel(dtypes=[np.complex64, np.complex128]),
-    ]
+    # Unlike reduce_window_max, we use a native TF op: tf.reduce_max, which
+    # does not work for complex
+    return [missing_tf_kernel(dtypes=[np.complex64, np.complex128])]

   @classmethod
   def reduce_min(cls, harness: primitive_harness.Harness):
     return cls.reduce_max(harness)

-  @classmethod
-  def reduce_window_add(cls, harness):
-    assert "add" == harness.params["computation"].__name__
-    return []
-
-  @classmethod
-  def reduce_window_max(cls, harness):
-    assert "max" == harness.params["computation"].__name__
-    return [
-        missing_tf_kernel(dtypes=[np.bool_]),
-    ]
-
-  @classmethod
-  def reduce_window_min(cls, harness):
-    assert "min" == harness.params["computation"].__name__
-    return [
-        missing_tf_kernel(dtypes=[np.bool_]),
-    ]
-
-  @classmethod
-  def reduce_window_mul(cls, harness):
-    assert "mul" == harness.params["computation"].__name__
-    return []
-
   @classmethod
   def regularized_incomplete_beta(cls, harness: primitive_harness.Harness):
     return [
@@ -1034,29 +944,15 @@ class Jax2TfLimitation(primitive_harness.Limitation):
   @classmethod
   def scatter_add(cls, harness):
     return [
-        missing_tf_kernel(dtypes=[np.bool_]),
         missing_tf_kernel(
             dtypes=[np.complex64],
             devices="tpu",
         )
     ]

-  @classmethod
-  def scatter_max(cls, harness):
-    return [
-        missing_tf_kernel(dtypes=[np.bool_]),
-    ]
-
-  @classmethod
-  def scatter_min(cls, harness):
-    return [
-        missing_tf_kernel(dtypes=[np.bool_]),
-    ]
-
   @classmethod
   def scatter_mul(cls, harness):
     return [
-        missing_tf_kernel(dtypes=[np.bool_],),
         missing_tf_kernel(
             dtypes=[np.complex64],
             devices="tpu",
@@ -1078,10 +974,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
             devices=("cpu", "gpu"))
     ]

-  @classmethod
-  def select_and_scatter_add(cls, harness):
-    return []
-
   @classmethod
   def sign(cls, harness: primitive_harness.Harness):
     return [
@@ -1101,13 +993,8 @@ class Jax2TfLimitation(primitive_harness.Limitation):
                      not harness.params["is_stable"]),
             expect_tf_error=False,
             skip_comparison=True),
-        missing_tf_kernel(dtypes=[np.bool_],),
     ]

-  @classmethod
-  def sub(cls, harness):
-    return []
-
   @classmethod
   def svd(cls, harness: primitive_harness.Harness):
     # TODO: slow test
@@ -1202,7 +1202,11 @@ def _make_scatter_harness(name,
             "unimplemented",
             devices="tpu",
             dtypes=np.complex64,
-            enabled=(f_lax in [lax.scatter_max, lax.scatter_min]))
+            enabled=(f_lax in [lax.scatter_max, lax.scatter_min])),
+        Limitation(
+            "unimplemented",
+            dtypes=np.bool_,
+            enabled=(f_lax in [lax.scatter_add, lax.scatter_mul])),
       ],
       f_lax=f_lax,
       shape=shape,
@@ -1219,8 +1223,8 @@ for dtype in jtu.dtypes.all:
   for f_lax in [
       lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min
   ]:
-    if f_lax in [lax.scatter_add, lax.scatter_mul] and dtype == np.bool_:
-      continue
+    #if f_lax in [lax.scatter_add, lax.scatter_mul] and dtype == np.bool_:
+    #  continue
     _make_scatter_harness("dtypes", dtype=dtype, f_lax=f_lax)

 # Validate f_lax/update_jaxpr