--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:
[jax2tf] Change the conversion of dot_general to use XLA op.
Instead of converting the dot_general to a sea of TF ops, when
we enable_xla we just use the XLA op. This has the advantage
that it also supports the preferred_element_type.
Fixed bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
Revert also previous changes that pinned numpy to 1.19.
One of the changes in numpy 1.20 is to add more type annotations.
However, this sometimes make mypy give errors. A common example is
numpy.take, which with the new type annotation does not appear to
mypy as indexable.
Another change is that np.int and np.bool are deprecated. One
should use np.bool_ or np.int_, or the built-ins bool and int.
TF now handles a few more ops for unsigned types.
Also igamma and igammac support for f16 anf bf16 were added to
JAX, but not yet to TF, hence the new limitations in TF.
--
781492e0120ec915f9fdc83479884908f59d113d by George Necula <gcnecula@gmail.com>:
[jax2tf] Update limitations
Some bugs were fixed on the TF-side, and we can remove
some limitations.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/5449 from gnecula:jax2tf_limit 781492e0120ec915f9fdc83479884908f59d113d
PiperOrigin-RevId: 352381535
--
9be685946252edc67c2c28261b100b9aee68614a by George Necula <gcnecula@gmail.com>:
Change the translation rule for lax.nextafter_p to ensure
broadcasting during translation.
Previously, this was the only binary arithmetic primitive that
did not have broadcasting during translation. Trying to use it
with non-equal shapes resulted in the error:
```
RuntimeError: Internal: RET_CHECK failure
(external/org_tensorflow/tensorflow/compiler/xla/client/xla_builder.cc:748)
non_scalar_shape.value().dimensions() == shape->dimensions() Unimplemented
implicit broadcast.:
This is a bug in JAX's shape-checking rules; please report it!
```
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/5448 from gnecula:nextafter 9be685946252edc67c2c28261b100b9aee68614a
PiperOrigin-RevId: 352367039
We were blocked until now due to limited support for the XlaSort op.
Also removed the patching of the XlaPad output shape; now XlaPad has
shape inference.
* cleaned up the handling of tolerances and custom asserts.
* Removed the harnesss field from the limitations.
* Moved the definitions of the Jax2TfLimitation into its own
file, so it can be reused.
* fixed the summarization of "inexact" types in the generated docs
* added support for 'skip_tf_run' to the limitations
* cleaned up the logging for the jax2tf tests