rocm_jax/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md
George Necula 235eb8c2b4 Copybara import of the project:
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:

[jax2tf] Change the conversion of dot_general to use XLA op.

Instead of converting the dot_general to a sea of TF ops, when
we enable_xla we just use the XLA op. This has the advantage
that it also supports the preferred_element_type.

Fixed bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
2021-05-12 02:30:15 -07:00

9.0 KiB

Primitives with limited JAX support

Last generated on: 2021-05-12 (YYYY-MM-DD)

Supported data types for primitives

We use a set of 2418 test harnesses to test the implementation of 121 numeric JAX primitives. We consider a JAX primitive supported for a particular data type if it is supported on at least one device type. The following table shows the dtypes at which primitives are supported in JAX and at which dtypes they are not supported on any device. (In reality, this shows for each primitive what dtypes are not covered by the current test harnesses on any device.)

The set of supported dtypes include 64-bit types (float64, int64, uint64, complex128) only if the flag --jax_enable_x64 or the JAX_ENABLE_X64 environment variable are set.

We use below the following abbreviations for sets of dtypes:

  • signed = int8, int16, int32, int64
  • unsigned = uint8, uint16, uint32, uint64
  • integer = signed, unsigned
  • floating = float16, bfloat16, float32, float64
  • complex = complex64, complex128
  • inexact = floating, complex
  • all = integer, inexact, bool

See the comment in primitive_harness.py for a description of test harnesses and their definitions.

Note that automated tests will fail if new limitations appear, but they won't when limitations are fixed. If you see a limitation that you think it does not exist anymore, please ask for this file to be updated.

Primitive Total test harnesses dtypes supported on at least one device dtypes NOT tested on any device
abs 10 inexact, signed bool, unsigned
acos 6 inexact bool, integer
acosh 6 inexact bool, integer
add 16 inexact, integer bool
add_any 14 inexact, integer bool
and 11 bool, integer inexact
argmax 22 bool, floating, integer complex
argmin 22 bool, floating, integer complex
asin 6 inexact bool, integer
asinh 6 inexact bool, integer
atan 6 inexact bool, integer
atan2 6 floating bool, complex, integer
atanh 6 inexact bool, integer
bessel_i0e 4 floating bool, complex, integer
bessel_i1e 4 floating bool, complex, integer
bitcast_convert_type 41 all
broadcast 17 all
broadcast_in_dim 19 all
ceil 4 floating bool, complex, integer
cholesky 30 inexact bool, integer
clamp 17 floating, integer bool, complex
complex 4 float32, float64 bfloat16, bool, complex, float16, integer
concatenate 17 all
conj 5 complex, float32, float64 bfloat16, bool, float16, integer
conv_general_dilated 58 inexact bool, integer
convert_element_type 201 all
cos 6 inexact bool, integer
cosh 6 inexact bool, integer
cummax 17 inexact, integer bool
cummin 17 inexact, integer bool
cumprod 17 inexact, integer bool
cumsum 17 inexact, integer bool
custom_linear_solve 4 float32, float64 bfloat16, bool, complex, float16, integer
device_put 16 all
digamma 4 floating bool, complex, integer
div 20 inexact, integer bool
dot_general 245 all
dynamic_slice 32 all
dynamic_update_slice 21 all
eig 72 inexact bool, integer
eigh 36 inexact bool, integer
eq 17 all
erf 4 floating bool, complex, integer
erf_inv 4 floating bool, complex, integer
erfc 4 floating bool, complex, integer
exp 6 inexact bool, integer
expm1 6 inexact bool, integer
fft 20 complex, float32, float64 bfloat16, bool, float16, integer
floor 4 floating bool, complex, integer
gather 37 all
ge 15 bool, floating, integer complex
gt 15 bool, floating, integer complex
igamma 6 floating bool, complex, integer
igammac 6 floating bool, complex, integer
imag 2 complex bool, floating, integer
integer_pow 34 inexact, integer bool
iota 16 inexact, integer bool
is_finite 4 floating bool, complex, integer
le 15 bool, floating, integer complex
lgamma 4 floating bool, complex, integer
log 6 inexact bool, integer
log1p 6 inexact bool, integer
lt 15 bool, floating, integer complex
lu 18 inexact bool, integer
max 29 all
min 29 all
mul 16 inexact, integer bool
ne 17 all
neg 14 inexact, integer bool
nextafter 6 floating bool, complex, integer
or 11 bool, integer inexact
pad 90 all
population_count 8 integer bool, inexact
pow 10 inexact bool, integer
qr 60 inexact bool, integer
random_gamma 4 float32, float64 bfloat16, bool, complex, float16, integer
random_split 5 uint32 all
real 2 complex bool, floating, integer
reduce_and 1 bool inexact, integer
reduce_max 15 all
reduce_min 15 all
reduce_or 1 bool inexact, integer
reduce_prod 14 inexact, integer bool
reduce_sum 14 inexact, integer bool
reduce_window_add 33 inexact, integer bool
reduce_window_max 37 all
reduce_window_min 15 all
reduce_window_mul 42 inexact, integer bool
regularized_incomplete_beta 4 floating bool, complex, integer
rem 18 floating, integer bool, complex
reshape 19 all
rev 19 all
round 7 floating bool, complex, integer
rsqrt 6 inexact bool, integer
scatter_add 14 inexact, integer bool
scatter_max 15 all
scatter_min 19 all
scatter_mul 14 inexact, integer bool
select 16 all
select_and_gather_add 15 floating bool, complex, integer
select_and_scatter_add 27 bool, floating, integer complex
shift_left 10 integer bool, inexact
shift_right_arithmetic 10 integer bool, inexact
shift_right_logical 10 integer bool, inexact
sign 14 inexact, integer bool
sin 6 inexact bool, integer
sinh 6 inexact bool, integer
slice 24 all
sort 29 all
sqrt 6 inexact bool, integer
squeeze 23 all
stop_gradient 15 all
sub 16 inexact, integer bool
svd 120 inexact bool, integer
tan 6 inexact bool, integer
tanh 6 inexact bool, integer
top_k 15 bool, floating, integer complex
transpose 17 all
triangular_solve 26 inexact bool, integer
xor 11 bool, integer inexact
zeros_like 15 all

Partially implemented data types for primitives

In some cases, a primitive is supported in JAX at a given data type but it may be missing implementations for some of the devices. For example, the eigen decomposition (lax.eig) is implemented in JAX using custom kernels only on CPU and GPU. There is no TPU implementation. In other cases, there are either bugs or not-yet-implemented cases in the XLA compiler for different devices.

The following table shows which of the supported data types are partially implemented for each primitive. This table already excludes the unsupported data types (previous table).

In order to see the actual errors for all entries look through the logs of the test_jax_implemented from jax_primitives_coverage_test.py and search for "limitation".

Affected primitive Description of limitation Affected dtypes Affected devices
cholesky unimplemented float16 cpu, gpu
cummax unimplemented complex64 tpu
cummin unimplemented complex64 tpu
cumprod unimplemented complex64 tpu
dot_general preferred_element_type=c128 not implemented complex64 tpu
dot_general preferred_element_type=f64 crashes (b/187884887) bfloat16, float16, float32 tpu
dot_general preferred_element_type=i64 not implemented int16, int32, int8 tpu
eig only supported on CPU in JAX all tpu, gpu
eig unimplemented bfloat16, float16 cpu
eigh complex eigh not supported complex tpu
eigh unimplemented bfloat16, float16 cpu, gpu
lu unimplemented bfloat16, float16 cpu, gpu, tpu
qr unimplemented bfloat16, float16 cpu, gpu
reduce_window_max unimplemented in XLA complex64 tpu
reduce_window_min unimplemented in XLA complex64 tpu
reduce_window_mul unimplemented in XLA complex64 tpu
scatter_max unimplemented complex64 tpu
scatter_min unimplemented complex64 tpu
select_and_scatter_add works only for 2 or more inactive dimensions all tpu
svd complex not implemented. Works in JAX for CPU and GPU with custom kernels complex tpu
svd unimplemented bfloat16, float16 cpu, gpu
triangular_solve unimplemented float16 gpu

Table generation

To regenerate this table run (on a CPU machine):

  JAX_OUTPUT_LIMITATIONS_DOC=1 JAX_ENABLE_X64=1 python jax/experimental/jax2tf/tests/jax_primitives_coverage_test.py JaxPrimitiveTest.test_generate_primitives_coverage_doc