DOC: Fix rendering of jax2tf docs

This commit is contained in:
Lukas Geiger 2021-03-23 00:55:11 +01:00
parent e1646bf2b8
commit 6386057cf0
4 changed files with 10 additions and 9 deletions

View File

@ -1,6 +1,6 @@
# Primitives with limited JAX support
*Last generated on: 2021-03-03* (YYYY-MM-DD)
*Last generated on: 2021-03-23* (YYYY-MM-DD)
## Supported data types for primitives
@ -39,7 +39,7 @@ be updated.
| Primitive | Total test harnesses | dtypes supported on at least one device | dtypes NOT tested on any device |
| --- | --- | --- | --- | --- |
| --- | --- | --- | --- |
| abs | 10 | inexact, signed | bool, unsigned |
| acos | 6 | inexact | bool, integer |
| acosh | 6 | inexact | bool, integer |
@ -183,7 +183,7 @@ and search for "limitation".
| Affected primitive | Description of limitation | Affected dtypes | Affected devices |
| --- | --- | --- | --- | --- |
| --- | --- | --- | --- |
|cholesky|unimplemented|float16|cpu, gpu|
|cummax|unimplemented|complex64|tpu|
|cummin|unimplemented|complex64|tpu|

View File

@ -1,6 +1,6 @@
# Primitives with limited support for jax2tf
*Last generated on (YYYY-MM-DD): 2021-03-03*
*Last generated on (YYYY-MM-DD): 2021-03-23*
This document summarizes known limitations of the jax2tf conversion.
There are several kinds of limitations.
@ -53,7 +53,7 @@ More detailed information can be found in the
| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes |
| --- | --- | --- | --- | --- | ---|
| --- | --- | --- | --- | --- |
| acos | TF error: op not defined for dtype | complex128 | cpu, gpu | eager, graph |
| acos | TF error: op not defined for dtype | bfloat16, complex64, float16 | cpu, gpu | eager, graph |
| acosh | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
@ -72,6 +72,7 @@ More detailed information can be found in the
| cholesky | TF error: op not defined for dtype | complex | tpu | compiled, graph |
| clamp | TF error: op not defined for dtype | int8, uint16, uint32, uint64 | cpu, gpu, tpu | compiled, eager, graph |
| conv_general_dilated | TF error: jax2tf BUG: batch_group_count > 1 not yet converted | all | cpu, gpu, tpu | compiled, eager, graph |
| conv_general_dilated | TF error: op not defined for dtype | complex | gpu | compiled, eager, graph |
| cosh | TF error: op not defined for dtype | float16 | cpu, gpu | eager, graph |
| cummax | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
| cummax | TF error: op not defined for dtype | complex128 | cpu, gpu | compiled, eager, graph |
@ -167,7 +168,7 @@ with jax2tf. The following table lists that cases when this does not quite hold:
| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes |
| --- | --- | --- | --- | --- | ---|
| --- | --- | --- | --- | --- |
| acos | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph |
| acosh | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph |
| asin | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph |

View File

@ -103,7 +103,7 @@ class JaxPrimitiveTest(jtu.JaxTestCase):
primitive_coverage_table = ["""
| Primitive | Total test harnesses | dtypes supported on at least one device | dtypes NOT tested on any device |
| --- | --- | --- | --- | --- |"""]
| --- | --- | --- | --- |"""]
all_dtypes = set(jtu.dtypes.all)
for group_name in sorted(harness_groups.keys()):
@ -120,7 +120,7 @@ class JaxPrimitiveTest(jtu.JaxTestCase):
print(f"Found {len(unique_limitations)} unique limitations")
primitive_unimpl_table = ["""
| Affected primitive | Description of limitation | Affected dtypes | Affected devices |
| --- | --- | --- | --- | --- |"""]
| --- | --- | --- | --- |"""]
for h, l in sorted(
unique_limitations.values(), key=lambda pair: unique_hash(*pair)):
devices = ", ".join(l.devices)

View File

@ -173,7 +173,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
tf_error_table = [
"""
| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes |
| --- | --- | --- | --- | --- | ---|"""
| --- | --- | --- | --- | --- |"""
]
tf_numerical_discrepancies_table = list(tf_error_table) # a copy
for h, l in sorted(