[jax2tf] Add the JAX-not-implemented to the jax2tf limitations doc

This commit is contained in:
George Necula 2021-01-29 12:48:56 +01:00
parent 9bd5ad5542
commit 3c89de6eed
5 changed files with 39 additions and 9 deletions

View File

@ -199,10 +199,8 @@ and search for "limitation".
|reduce_window_max|unimplemented in XLA|complex64|tpu|
|reduce_window_min|unimplemented in XLA|complex64|tpu|
|reduce_window_mul|unimplemented in XLA|complex64|tpu|
|scatter_add|unimplemented|complex64|tpu|
|scatter_max|unimplemented|complex64|tpu|
|scatter_min|unimplemented|complex64|tpu|
|scatter_mul|unimplemented|complex64|tpu|
|select_and_scatter_add|works only for 2 or more inactive dimensions|all|tpu|
|svd|complex not implemented. Works in JAX for CPU and GPU with custom kernels|complex|tpu|
|svd|unimplemented|bfloat16, float16|cpu, gpu|

View File

@ -67,27 +67,37 @@ More detailed information can be found in the
| bessel_i0e | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| bessel_i1e | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| bitcast_convert_type | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| cholesky | TF test skipped: Not implemented in JAX: unimplemented | float16 | cpu, gpu | compiled, eager, graph |
| cholesky | TF error: function not compilable | complex | cpu, gpu | compiled |
| cholesky | TF error: op not defined for dtype | complex | tpu | compiled, graph |
| clamp | TF error: op not defined for dtype | int8, uint16, uint32, uint64 | cpu, gpu, tpu | compiled, eager, graph |
| conv_general_dilated | TF error: jax2tf BUG: batch_group_count > 1 not yet converted | all | cpu, gpu, tpu | compiled, eager, graph |
| cosh | TF error: op not defined for dtype | float16 | cpu, gpu | eager, graph |
| cummax | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
| cummax | TF error: op not defined for dtype | complex128 | cpu, gpu | compiled, eager, graph |
| cummax | TF error: op not defined for dtype | complex64 | cpu, gpu, tpu | compiled, eager, graph |
| cummax | TF error: op not defined for dtype | complex64 | cpu, gpu | compiled, eager, graph |
| cummin | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
| cummin | TF error: op not defined for dtype | complex128, uint64 | cpu, gpu | compiled, eager, graph |
| cummin | TF error: op not defined for dtype | complex64, int8, uint16, uint32 | cpu, gpu, tpu | compiled, eager, graph |
| cumprod | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
| cumsum | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
| digamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| div | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
| div | TF error: op not defined for dtype | int16, int8, unsigned | cpu, gpu, tpu | compiled, eager, graph |
| dot_general | TF error: op not defined for dtype | int64 | cpu, gpu | compiled |
| dot_general | TF error: op not defined for dtype | bool, int16, int8, unsigned | cpu, gpu, tpu | compiled, eager, graph |
| eig | TF test skipped: Not implemented in JAX: only supported on CPU in JAX | all | gpu, tpu | compiled, eager, graph |
| eig | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu | compiled, eager, graph |
| eig | TF error: TF Conversion of eig is not implemented when both compute_left_eigenvectors and compute_right_eigenvectors are set to True | all | cpu, gpu, tpu | compiled, eager, graph |
| eig | TF error: function not compilable | all | cpu | compiled |
| eigh | TF test skipped: Not implemented in JAX: complex eigh not supported | complex | tpu | compiled, eager, graph |
| eigh | TF test skipped: Not implemented in JAX: unimplemented | float16 | cpu | compiled, eager, graph |
| eigh | TF test skipped: Not implemented in JAX: unimplemented | float16 | gpu | compiled, eager, graph |
| eigh | TF error: function not compilable | complex | cpu, gpu, tpu | compiled |
| erf | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| erf_inv | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
| erfc | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| fft | TF test skipped: Not implemented in JAX: only 1D FFT is currently supported b/140351181. | all | tpu | compiled, eager, graph |
| fft | TF error: TF function not compilable | complex128, float64 | cpu, gpu | compiled |
| ge | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| gt | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
@ -97,6 +107,7 @@ More detailed information can be found in the
| le | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| lgamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| lt | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| lu | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
| lu | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
| max | TF error: op not defined for dtype | complex128 | cpu, gpu | compiled, eager, graph |
| max | TF error: op not defined for dtype | int8, uint16, uint32, uint64 | cpu, gpu | eager, graph |
@ -106,15 +117,19 @@ More detailed information can be found in the
| neg | TF error: op not defined for dtype | unsigned | cpu, gpu, tpu | compiled, eager, graph |
| nextafter | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
| population_count | TF error: op not defined for dtype | uint32, uint64 | cpu, gpu | eager, graph |
| qr | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph |
| qr | TF error: op not defined for dtype | bfloat16 | tpu | compiled, eager, graph |
| reduce_max | TF error: op not defined for dtype | complex128 | cpu, gpu, tpu | compiled, eager, graph |
| reduce_max | TF error: op not defined for dtype | complex64 | cpu, gpu, tpu | compiled, eager, graph |
| reduce_min | TF error: op not defined for dtype | complex128 | cpu, gpu, tpu | compiled, eager, graph |
| reduce_min | TF error: op not defined for dtype | complex64 | cpu, gpu, tpu | compiled, eager, graph |
| reduce_window_add | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
| reduce_window_max | TF test skipped: Not implemented in JAX: unimplemented in XLA | complex64 | tpu | compiled, eager, graph |
| reduce_window_max | TF error: op not defined for dtype | complex128 | cpu, gpu | compiled, eager, graph |
| reduce_window_max | TF error: op not defined for dtype | bool, complex64 | cpu, gpu, tpu | compiled, eager, graph |
| reduce_window_min | TF test skipped: Not implemented in JAX: unimplemented in XLA | complex64 | tpu | compiled, eager, graph |
| reduce_window_min | TF error: op not defined for dtype | bool, complex, int8, uint16, uint32, uint64 | cpu, gpu, tpu | compiled, eager, graph |
| reduce_window_mul | TF test skipped: Not implemented in JAX: unimplemented in XLA | complex64 | tpu | compiled, eager, graph |
| regularized_incomplete_beta | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
| rem | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
| rem | TF error: op not defined for dtype | float16 | cpu, gpu | eager, graph |
@ -124,18 +139,25 @@ More detailed information can be found in the
| rsqrt | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| scatter_add | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_add | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
| scatter_max | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
| scatter_max | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
| scatter_min | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
| scatter_min | TF error: op not defined for dtype | bool, complex, int8, uint16, uint32, uint64 | cpu, gpu, tpu | compiled, eager, graph |
| scatter_mul | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_mul | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
| select_and_gather_add | TF error: This JAX primitive is not exposed directly in the JAX API but arises from JVP of `lax.reduce_window` for reducers `lax.max` or `lax.min`. It also arises from second-order VJP of the same. Implemented using XlaReduceWindow | float32 | tpu | compiled, eager, graph |
| select_and_gather_add | TF error: jax2tf unimplemented for 64-bit inputs because the current implementation relies on packing two values into a single value. This can be fixed by using a variadic XlaReduceWindow, when available | float64 | cpu, gpu | compiled, eager, graph |
| select_and_scatter_add | TF test skipped: Not implemented in JAX: works only for 2 or more inactive dimensions | all | tpu | compiled, eager, graph |
| sign | TF error: op not defined for dtype | int16, int8, unsigned | cpu, gpu, tpu | compiled, eager, graph |
| sinh | TF error: op not defined for dtype | float16 | cpu, gpu | eager, graph |
| sort | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| svd | TF test skipped: Not implemented in JAX: complex not implemented. Works in JAX for CPU and GPU with custom kernels | complex | tpu | compiled, eager, graph |
| svd | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph |
| svd | TF error: function not compilable. Implemented using `tf.linalg.svd` and `tf.linalg.adjoint` | complex | cpu, gpu | compiled |
| svd | TF error: op not defined for dtype | bfloat16 | tpu | compiled, eager, graph |
| tie_in | TF test skipped: Not implemented in JAX: requires omnistaging to be disabled | all | cpu, gpu, tpu | compiled, eager, graph |
| top_k | TF error: op not defined for dtype | int64, uint64 | cpu, gpu | compiled |
| triangular_solve | TF test skipped: Not implemented in JAX: unimplemented | float16 | gpu | compiled, eager, graph |
| triangular_solve | TF error: op not defined for dtype | bfloat16 | cpu, gpu, tpu | compiled, eager, graph |
| triangular_solve | TF error: op not defined for dtype | float16 | cpu, gpu | eager, graph |

View File

@ -344,7 +344,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
dtypes=[np.complex128],
devices=("cpu", "gpu"),
),
missing_tf_kernel(dtypes=[np.complex64]),
missing_tf_kernel(dtypes=[np.complex64], devices=("cpu", "gpu")),
custom_numeric(dtypes=np.float16, tol=0.1),
custom_numeric(dtypes=dtypes.bfloat16, tol=0.5)
]
@ -1054,12 +1054,11 @@ class Jax2TfLimitation(primitive_harness.Limitation):
@classmethod
def scatter_add(cls, harness):
return [
missing_tf_kernel(dtypes=[np.bool_],),
missing_tf_kernel(dtypes=[np.bool_]),
missing_tf_kernel(
dtypes=[np.complex64],
devices="tpu",
),
]
)]
@classmethod
def scatter_max(cls, harness):

View File

@ -99,7 +99,8 @@ class JaxPrimitiveTest(jtu.JaxTestCase):
for h in harnesses:
harness_groups[h.group_name].append(h)
for l in h.jax_unimplemented:
unique_limitations[hash(unique_hash(h, l))] = (h, l)
if l.enabled:
unique_limitations[hash(unique_hash(h, l))] = (h, l)
primitive_coverage_table = ["""
| Primitive | Total test harnesses | dtypes supported on at least one device | dtypes NOT tested on any device |

View File

@ -145,7 +145,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
harnesses = [
h for h in primitive_harness.all_harnesses
if h.filter(h, include_jax_unimpl=False)
if h.filter(h, include_jax_unimpl=True)
]
print(f"Found {len(harnesses)} test harnesses that work in JAX")
@ -154,6 +154,16 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
tuple([np.dtype(d).name for d in l.dtypes]), l.modes)
unique_limitations: Dict[Any, Tuple[primitive_harness.Harness, Jax2TfLimitation]] = {}
for h in harnesses:
for l in h.jax_unimplemented:
if l.enabled:
# Fake a Jax2TfLimitation from the Limitation
tfl = Jax2TfLimitation(description="Not implemented in JAX: " + l.description,
devices = l.devices,
dtypes = l.dtypes,
expect_tf_error = False,
skip_tf_run = True)
unique_limitations[hash(unique_hash(h, tfl))] = (h, tfl)
for h in harnesses:
for l in Jax2TfLimitation.limitations_for_harness(h):
unique_limitations[hash(unique_hash(h, l))] = (h, l)