mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
DOC: Fix rendering of jax2tf docs
This commit is contained in:
parent
e1646bf2b8
commit
6386057cf0
@ -1,6 +1,6 @@
|
||||
# Primitives with limited JAX support
|
||||
|
||||
*Last generated on: 2021-03-03* (YYYY-MM-DD)
|
||||
*Last generated on: 2021-03-23* (YYYY-MM-DD)
|
||||
|
||||
## Supported data types for primitives
|
||||
|
||||
@ -39,7 +39,7 @@ be updated.
|
||||
|
||||
|
||||
| Primitive | Total test harnesses | dtypes supported on at least one device | dtypes NOT tested on any device |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| --- | --- | --- | --- |
|
||||
| abs | 10 | inexact, signed | bool, unsigned |
|
||||
| acos | 6 | inexact | bool, integer |
|
||||
| acosh | 6 | inexact | bool, integer |
|
||||
@ -183,7 +183,7 @@ and search for "limitation".
|
||||
|
||||
|
||||
| Affected primitive | Description of limitation | Affected dtypes | Affected devices |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| --- | --- | --- | --- |
|
||||
|cholesky|unimplemented|float16|cpu, gpu|
|
||||
|cummax|unimplemented|complex64|tpu|
|
||||
|cummin|unimplemented|complex64|tpu|
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Primitives with limited support for jax2tf
|
||||
|
||||
*Last generated on (YYYY-MM-DD): 2021-03-03*
|
||||
*Last generated on (YYYY-MM-DD): 2021-03-23*
|
||||
|
||||
This document summarizes known limitations of the jax2tf conversion.
|
||||
There are several kinds of limitations.
|
||||
@ -53,7 +53,7 @@ More detailed information can be found in the
|
||||
|
||||
|
||||
| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes |
|
||||
| --- | --- | --- | --- | --- | ---|
|
||||
| --- | --- | --- | --- | --- |
|
||||
| acos | TF error: op not defined for dtype | complex128 | cpu, gpu | eager, graph |
|
||||
| acos | TF error: op not defined for dtype | bfloat16, complex64, float16 | cpu, gpu | eager, graph |
|
||||
| acosh | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
|
||||
@ -72,6 +72,7 @@ More detailed information can be found in the
|
||||
| cholesky | TF error: op not defined for dtype | complex | tpu | compiled, graph |
|
||||
| clamp | TF error: op not defined for dtype | int8, uint16, uint32, uint64 | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| conv_general_dilated | TF error: jax2tf BUG: batch_group_count > 1 not yet converted | all | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| conv_general_dilated | TF error: op not defined for dtype | complex | gpu | compiled, eager, graph |
|
||||
| cosh | TF error: op not defined for dtype | float16 | cpu, gpu | eager, graph |
|
||||
| cummax | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
|
||||
| cummax | TF error: op not defined for dtype | complex128 | cpu, gpu | compiled, eager, graph |
|
||||
@ -167,7 +168,7 @@ with jax2tf. The following table lists that cases when this does not quite hold:
|
||||
|
||||
|
||||
| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes |
|
||||
| --- | --- | --- | --- | --- | ---|
|
||||
| --- | --- | --- | --- | --- |
|
||||
| acos | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph |
|
||||
| acosh | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph |
|
||||
| asin | May return different but still correct results | complex | cpu, gpu, tpu | eager, graph |
|
||||
|
@ -103,7 +103,7 @@ class JaxPrimitiveTest(jtu.JaxTestCase):
|
||||
|
||||
primitive_coverage_table = ["""
|
||||
| Primitive | Total test harnesses | dtypes supported on at least one device | dtypes NOT tested on any device |
|
||||
| --- | --- | --- | --- | --- |"""]
|
||||
| --- | --- | --- | --- |"""]
|
||||
all_dtypes = set(jtu.dtypes.all)
|
||||
|
||||
for group_name in sorted(harness_groups.keys()):
|
||||
@ -120,7 +120,7 @@ class JaxPrimitiveTest(jtu.JaxTestCase):
|
||||
print(f"Found {len(unique_limitations)} unique limitations")
|
||||
primitive_unimpl_table = ["""
|
||||
| Affected primitive | Description of limitation | Affected dtypes | Affected devices |
|
||||
| --- | --- | --- | --- | --- |"""]
|
||||
| --- | --- | --- | --- |"""]
|
||||
for h, l in sorted(
|
||||
unique_limitations.values(), key=lambda pair: unique_hash(*pair)):
|
||||
devices = ", ".join(l.devices)
|
||||
|
@ -173,7 +173,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
tf_error_table = [
|
||||
"""
|
||||
| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes |
|
||||
| --- | --- | --- | --- | --- | ---|"""
|
||||
| --- | --- | --- | --- | --- |"""
|
||||
]
|
||||
tf_numerical_discrepancies_table = list(tf_error_table) # a copy
|
||||
for h, l in sorted(
|
||||
|
Loading…
x
Reference in New Issue
Block a user