15 Commits

Author SHA1 Message Date
George Necula
235eb8c2b4 Copybara import of the project:
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:

[jax2tf] Change the conversion of dot_general to use XLA op.

Instead of converting the dot_general to a sea of TF ops, when
we enable_xla we just use the XLA op. This has the advantage
that it also supports the preferred_element_type.

Fixed bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
2021-05-12 02:30:15 -07:00
George Necula
e0a87ef062 [jax2tf] Update limitations due to updates in TF 2021-03-29 16:54:11 +03:00
Lukas Geiger
6386057cf0 DOC: Fix rendering of jax2tf docs 2021-03-23 00:55:11 +01:00
George Necula
469e1f05e1 [jax2tf] Updated the limitations
Update the documentation to reflect the improvements in TF support for dtypes
2021-03-03 12:38:00 +01:00
George Necula
ec2301a9ce Update limitations docs 2021-02-02 10:31:04 +02:00
George Necula
3c89de6eed [jax2tf] Add the JAX-not-implemented to the jax2tf limitations doc 2021-01-29 14:20:38 +01:00
George Necula
c2a32be2a2 [jax2tf] Update the limitations due to improvements in TF
TF now handles a few more ops for unsigned types.
Also igamma and igammac support for f16 anf bf16 were added to
JAX, but not yet to TF, hence the new limitations in TF.
2021-01-29 10:18:31 +01:00
George Necula
67b5af97f7 Copybara import of the project:
--
9be685946252edc67c2c28261b100b9aee68614a by George Necula <gcnecula@gmail.com>:

Change the translation rule for lax.nextafter_p to ensure
broadcasting during translation.

Previously, this was the only binary arithmetic primitive that
did not have broadcasting during translation. Trying to use it
with non-equal shapes resulted in the error:

```
 RuntimeError: Internal: RET_CHECK failure
(external/org_tensorflow/tensorflow/compiler/xla/client/xla_builder.cc:748)
non_scalar_shape.value().dimensions() == shape->dimensions() Unimplemented
implicit broadcast.:
       This is a bug in JAX's shape-checking rules; please report it!
```

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/5448 from gnecula:nextafter 9be685946252edc67c2c28261b100b9aee68614a
PiperOrigin-RevId: 352367039
2021-01-18 01:59:55 -08:00
George Necula
ca27c83c0c [jax2tf] Cleanup the way jax2tf limitations are specified
Now each limitation is very explicit about the modes it applies
to, with the default being all modes. There is no more special
casing for TPUs.
2021-01-15 12:10:41 +02:00
George Necula
1fb9cd34b2 [jax2tf] Moved text from the documentation into the limitations for custom checks
* cleaned up the handling of tolerances and custom asserts.
* Removed the harnesss field from the limitations.
* Moved the definitions of the Jax2TfLimitation into its own
  file, so it can be reused.
2020-12-31 14:56:31 +02:00
George Necula
cf6da862bc [jax2tf] Refinements for the auto-generated limitations documentation
* fixed the summarization of "inexact" types in the generated docs
* added support for 'skip_tf_run' to the limitations
* cleaned up the logging for the jax2tf tests
2020-12-29 07:04:34 +02:00
George Necula
e3af2c7798 Rename the test file, and enable the jax_primitives_coverage_test.
PiperOrigin-RevId: 349301827
2020-12-28 10:28:57 -08:00
George Necula
1f098d4328 Fixes suggested by @bchetioui 2020-12-28 13:41:52 +02:00
George Necula
e76ce041f2 Added the tolerances and custom asserts as limitations 2020-12-28 13:13:19 +02:00
George Necula
a5fbc2865d Major refactoring of the jax2tf test harnesses.
See the PR description for details.
2020-12-27 17:56:54 +02:00