mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Add the JAX-not-implemented to the jax2tf limitations doc
This commit is contained in:
parent
9bd5ad5542
commit
3c89de6eed
@ -199,10 +199,8 @@ and search for "limitation".
|
||||
|reduce_window_max|unimplemented in XLA|complex64|tpu|
|
||||
|reduce_window_min|unimplemented in XLA|complex64|tpu|
|
||||
|reduce_window_mul|unimplemented in XLA|complex64|tpu|
|
||||
|scatter_add|unimplemented|complex64|tpu|
|
||||
|scatter_max|unimplemented|complex64|tpu|
|
||||
|scatter_min|unimplemented|complex64|tpu|
|
||||
|scatter_mul|unimplemented|complex64|tpu|
|
||||
|select_and_scatter_add|works only for 2 or more inactive dimensions|all|tpu|
|
||||
|svd|complex not implemented. Works in JAX for CPU and GPU with custom kernels|complex|tpu|
|
||||
|svd|unimplemented|bfloat16, float16|cpu, gpu|
|
||||
|
@ -67,27 +67,37 @@ More detailed information can be found in the
|
||||
| bessel_i0e | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
|
||||
| bessel_i1e | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
|
||||
| bitcast_convert_type | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| cholesky | TF test skipped: Not implemented in JAX: unimplemented | float16 | cpu, gpu | compiled, eager, graph |
|
||||
| cholesky | TF error: function not compilable | complex | cpu, gpu | compiled |
|
||||
| cholesky | TF error: op not defined for dtype | complex | tpu | compiled, graph |
|
||||
| clamp | TF error: op not defined for dtype | int8, uint16, uint32, uint64 | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| conv_general_dilated | TF error: jax2tf BUG: batch_group_count > 1 not yet converted | all | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| cosh | TF error: op not defined for dtype | float16 | cpu, gpu | eager, graph |
|
||||
| cummax | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
|
||||
| cummax | TF error: op not defined for dtype | complex128 | cpu, gpu | compiled, eager, graph |
|
||||
| cummax | TF error: op not defined for dtype | complex64 | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| cummax | TF error: op not defined for dtype | complex64 | cpu, gpu | compiled, eager, graph |
|
||||
| cummin | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
|
||||
| cummin | TF error: op not defined for dtype | complex128, uint64 | cpu, gpu | compiled, eager, graph |
|
||||
| cummin | TF error: op not defined for dtype | complex64, int8, uint16, uint32 | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| cumprod | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
|
||||
| cumsum | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
|
||||
| digamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
|
||||
| div | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| div | TF error: op not defined for dtype | int16, int8, unsigned | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| dot_general | TF error: op not defined for dtype | int64 | cpu, gpu | compiled |
|
||||
| dot_general | TF error: op not defined for dtype | bool, int16, int8, unsigned | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| eig | TF test skipped: Not implemented in JAX: only supported on CPU in JAX | all | gpu, tpu | compiled, eager, graph |
|
||||
| eig | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu | compiled, eager, graph |
|
||||
| eig | TF error: TF Conversion of eig is not implemented when both compute_left_eigenvectors and compute_right_eigenvectors are set to True | all | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| eig | TF error: function not compilable | all | cpu | compiled |
|
||||
| eigh | TF test skipped: Not implemented in JAX: complex eigh not supported | complex | tpu | compiled, eager, graph |
|
||||
| eigh | TF test skipped: Not implemented in JAX: unimplemented | float16 | cpu | compiled, eager, graph |
|
||||
| eigh | TF test skipped: Not implemented in JAX: unimplemented | float16 | gpu | compiled, eager, graph |
|
||||
| eigh | TF error: function not compilable | complex | cpu, gpu, tpu | compiled |
|
||||
| erf | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
|
||||
| erf_inv | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
|
||||
| erfc | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
|
||||
| fft | TF test skipped: Not implemented in JAX: only 1D FFT is currently supported b/140351181. | all | tpu | compiled, eager, graph |
|
||||
| fft | TF error: TF function not compileable | complex128, float64 | cpu, gpu | compiled |
|
||||
| ge | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| gt | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
|
||||
@ -97,6 +107,7 @@ More detailed information can be found in the
|
||||
| le | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| lgamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
|
||||
| lt | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| lu | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| lu | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
|
||||
| max | TF error: op not defined for dtype | complex128 | cpu, gpu | compiled, eager, graph |
|
||||
| max | TF error: op not defined for dtype | int8, uint16, uint32, uint64 | cpu, gpu | eager, graph |
|
||||
@ -106,15 +117,19 @@ More detailed information can be found in the
|
||||
| neg | TF error: op not defined for dtype | unsigned | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| nextafter | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| population_count | TF error: op not defined for dtype | uint32, uint64 | cpu, gpu | eager, graph |
|
||||
| qr | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph |
|
||||
| qr | TF error: op not defined for dtype | bfloat16 | tpu | compiled, eager, graph |
|
||||
| reduce_max | TF error: op not defined for dtype | complex128 | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| reduce_max | TF error: op not defined for dtype | complex64 | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| reduce_min | TF error: op not defined for dtype | complex128 | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| reduce_min | TF error: op not defined for dtype | complex64 | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| reduce_window_add | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
|
||||
| reduce_window_max | TF test skipped: Not implemented in JAX: unimplemented in XLA | complex64 | tpu | compiled, eager, graph |
|
||||
| reduce_window_max | TF error: op not defined for dtype | complex128 | cpu, gpu | compiled, eager, graph |
|
||||
| reduce_window_max | TF error: op not defined for dtype | bool, complex64 | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| reduce_window_min | TF test skipped: Not implemented in JAX: unimplemented in XLA | complex64 | tpu | compiled, eager, graph |
|
||||
| reduce_window_min | TF error: op not defined for dtype | bool, complex, int8, uint16, uint32, uint64 | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| reduce_window_mul | TF test skipped: Not implemented in JAX: unimplemented in XLA | complex64 | tpu | compiled, eager, graph |
|
||||
| regularized_incomplete_beta | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| rem | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| rem | TF error: op not defined for dtype | float16 | cpu, gpu | eager, graph |
|
||||
@ -124,18 +139,25 @@ More detailed information can be found in the
|
||||
| rsqrt | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
|
||||
| scatter_add | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| scatter_add | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
|
||||
| scatter_max | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
|
||||
| scatter_max | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| scatter_min | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
|
||||
| scatter_min | TF error: op not defined for dtype | bool, complex, int8, uint16, uint32, uint64 | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| scatter_mul | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| scatter_mul | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
|
||||
| select_and_gather_add | TF error: This JAX primitives is not not exposed directly in the JAX API but arises from JVP of `lax.reduce_window` for reducers `lax.max` or `lax.min`. It also arises from second-order VJP of the same. Implemented using XlaReduceWindow | float32 | tpu | compiled, eager, graph |
|
||||
| select_and_gather_add | TF error: jax2tf unimplemented for 64-bit inputs because the current implementation relies on packing two values into a single value. This can be fixed by using a variadic XlaReduceWindow, when available | float64 | cpu, gpu | compiled, eager, graph |
|
||||
| select_and_scatter_add | TF test skipped: Not implemented in JAX: works only for 2 or more inactive dimensions | all | tpu | compiled, eager, graph |
|
||||
| sign | TF error: op not defined for dtype | int16, int8, unsigned | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| sinh | TF error: op not defined for dtype | float16 | cpu, gpu | eager, graph |
|
||||
| sort | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| svd | TF test skipped: Not implemented in JAX: complex not implemented. Works in JAX for CPU and GPU with custom kernels | complex | tpu | compiled, eager, graph |
|
||||
| svd | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph |
|
||||
| svd | TF error: function not compilable. Implemented using `tf.linalg.svd` and `tf.linalg.adjoint` | complex | cpu, gpu | compiled |
|
||||
| svd | TF error: op not defined for dtype | bfloat16 | tpu | compiled, eager, graph |
|
||||
| tie_in | TF test skipped: Not implemented in JAX: requires omnistaging to be disabled | all | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| top_k | TF error: op not defined for dtype | int64, uint64 | cpu, gpu | compiled |
|
||||
| triangular_solve | TF test skipped: Not implemented in JAX: unimplemented | float16 | gpu | compiled, eager, graph |
|
||||
| triangular_solve | TF error: op not defined for dtype | bfloat16 | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| triangular_solve | TF error: op not defined for dtype | float16 | cpu, gpu | eager, graph |
|
||||
|
||||
|
@ -344,7 +344,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
dtypes=[np.complex128],
|
||||
devices=("cpu", "gpu"),
|
||||
),
|
||||
missing_tf_kernel(dtypes=[np.complex64]),
|
||||
missing_tf_kernel(dtypes=[np.complex64], devices=("cpu", "gpu")),
|
||||
custom_numeric(dtypes=np.float16, tol=0.1),
|
||||
custom_numeric(dtypes=dtypes.bfloat16, tol=0.5)
|
||||
]
|
||||
@ -1054,12 +1054,11 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
@classmethod
|
||||
def scatter_add(cls, harness):
|
||||
return [
|
||||
missing_tf_kernel(dtypes=[np.bool_],),
|
||||
missing_tf_kernel(dtypes=[np.bool_]),
|
||||
missing_tf_kernel(
|
||||
dtypes=[np.complex64],
|
||||
devices="tpu",
|
||||
),
|
||||
]
|
||||
)]
|
||||
|
||||
@classmethod
|
||||
def scatter_max(cls, harness):
|
||||
|
@ -99,7 +99,8 @@ class JaxPrimitiveTest(jtu.JaxTestCase):
|
||||
for h in harnesses:
|
||||
harness_groups[h.group_name].append(h)
|
||||
for l in h.jax_unimplemented:
|
||||
unique_limitations[hash(unique_hash(h, l))] = (h, l)
|
||||
if l.enabled:
|
||||
unique_limitations[hash(unique_hash(h, l))] = (h, l)
|
||||
|
||||
primitive_coverage_table = ["""
|
||||
| Primitive | Total test harnesses | dtypes supported on at least one device | dtypes NOT tested on any device |
|
||||
|
@ -145,7 +145,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
harnesses = [
|
||||
h for h in primitive_harness.all_harnesses
|
||||
if h.filter(h, include_jax_unimpl=False)
|
||||
if h.filter(h, include_jax_unimpl=True)
|
||||
]
|
||||
print(f"Found {len(harnesses)} test harnesses that work in JAX")
|
||||
|
||||
@ -154,6 +154,16 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
tuple([np.dtype(d).name for d in l.dtypes]), l.modes)
|
||||
|
||||
unique_limitations: Dict[Any, Tuple[primitive_harness.Harness, Jax2TfLimitation]] = {}
|
||||
for h in harnesses:
|
||||
for l in h.jax_unimplemented:
|
||||
if l.enabled:
|
||||
# Fake a Jax2TFLimitation from the Limitation
|
||||
tfl = Jax2TfLimitation(description="Not implemented in JAX: " + l.description,
|
||||
devices = l.devices,
|
||||
dtypes = l.dtypes,
|
||||
expect_tf_error = False,
|
||||
skip_tf_run = True)
|
||||
unique_limitations[hash(unique_hash(h, tfl))] = (h, tfl)
|
||||
for h in harnesses:
|
||||
for l in Jax2TfLimitation.limitations_for_harness(h):
|
||||
unique_limitations[hash(unique_hash(h, l))] = (h, l)
|
||||
|
Loading…
x
Reference in New Issue
Block a user