If enable_xla is set, we should directly use the XLA conversion, without
first trying to see whether the primitive can be converted
to non-XLA TF ops. This already happens elsewhere, but it was not
happening for pad.
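A minimal sketch of the intended dispatch; every helper name here
(convert_pad, xla_pad_lowering, can_lower_without_xla, tf_pad_lowering)
is hypothetical and only illustrates the order of the checks, not the
actual jax2tf code:
```
def convert_pad(operand, padding_value, padding_config, *, enable_xla: bool):
  # When XLA ops are allowed, go straight to the XLA lowering; do not first
  # check whether a non-XLA TF lowering would also work.
  if enable_xla:
    return xla_pad_lowering(operand, padding_value, padding_config)   # hypothetical
  # Only without XLA do we try the plain-TF path, which does not support
  # every padding_config.
  if can_lower_without_xla(padding_config):                           # hypothetical
    return tf_pad_lowering(operand, padding_value, padding_config)    # hypothetical
  raise NotImplementedError("pad with this padding_config needs enable_xla=True")
```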
Also discovered that we were never setting enable_xla=False in
tests, even though we had special harnesses for it.
This is another attempt at the rolled-back #6722.
PiperOrigin-RevId: 374597824
Previously we simply converted integer_pow to tf.math.pow, whereas JAX lowers
it to a series of multiplications. We now use the same lowering strategy as JAX,
so that we get the same numerical result.
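A minimal sketch of the multiplication-based lowering, assuming the exponent
is a static Python int (the function name and exact structure are illustrative,
not the actual jax2tf code):
```
import tensorflow as tf

def integer_pow_via_multiplications(x, y: int):
  # Exponentiation by repeated squaring, mirroring JAX's lowering of
  # integer_pow; y is a static Python int.
  if y < 0:
    x, y = tf.math.reciprocal(x), -y
  acc = None
  while y > 0:
    if y & 1:
      acc = x if acc is None else acc * x
    y >>= 1
    if y > 0:
      x = x * x
  # x ** 0 == 1 for any x.
  return tf.ones_like(x) if acc is None else acc
```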
Also improved the error messages for assertion failures.
PiperOrigin-RevId: 373351147
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:
[jax2tf] Change the conversion of dot_general to use the XLA op.
Instead of converting dot_general to a sea of TF ops, when
enable_xla is set we just use the XLA op. This has the advantage
that it also supports preferred_element_type.
Fixed a bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.
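For reference, a small JAX-side example of what preferred_element_type enables
(illustrative only, not code from this change):
```
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((2, 3), dtype=jnp.int8)
rhs = jnp.ones((3, 4), dtype=jnp.int8)
# Contract dimension 1 of lhs with dimension 0 of rhs (a plain matmul),
# accumulating in int32 even though both inputs are int8.
out = lax.dot_general(lhs, rhs,
                      dimension_numbers=(((1,), (0,)), ((), ())),
                      preferred_element_type=jnp.int32)
```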
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
Also revert previous changes that pinned numpy to 1.19.
One of the changes in numpy 1.20 is the addition of more type annotations.
However, this sometimes makes mypy report errors. A common example is
numpy.take, which with the new type annotations does not appear to
mypy as indexable.
Another change is that np.int and np.bool are deprecated. One
should use np.int_ or np.bool_, or the built-ins int and bool.
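For example (a small illustration of the deprecation, not code from this change):
```
import numpy as np

# Deprecated since numpy 1.20 (np.int and np.bool were just aliases
# for the built-ins):
#   np.zeros(3, dtype=np.int)
#   np.array([True, False], dtype=np.bool)

# Use the built-ins or the numpy scalar types instead:
x = np.zeros(3, dtype=int)                    # or np.int_
mask = np.array([True, False], dtype=bool)    # or np.bool_
```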
TF now handles a few more ops for unsigned types.
Also, igamma and igammac support for f16 and bf16 was added to
JAX, but not yet to TF, hence the new limitations in TF.
--
781492e0120ec915f9fdc83479884908f59d113d by George Necula <gcnecula@gmail.com>:
[jax2tf] Update limitations
Some bugs were fixed on the TF side, so we can remove
some limitations.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/5449 from gnecula:jax2tf_limit 781492e0120ec915f9fdc83479884908f59d113d
PiperOrigin-RevId: 352381535
--
9be685946252edc67c2c28261b100b9aee68614a by George Necula <gcnecula@gmail.com>:
Change the translation rule for lax.nextafter_p to ensure
broadcasting during translation.
Previously, this was the only binary arithmetic primitive that
did not broadcast its operands during translation. Using it
with operands of unequal shapes resulted in the error:
```
RuntimeError: Internal: RET_CHECK failure
(external/org_tensorflow/tensorflow/compiler/xla/client/xla_builder.cc:748)
non_scalar_shape.value().dimensions() == shape->dimensions() Unimplemented
implicit broadcast.:
This is a bug in JAX's shape-checking rules; please report it!
```
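For illustration, the kind of call that used to trigger this (a small JAX
example, not code from this change):
```
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((3, 4), dtype=jnp.float32)
y = jnp.float32(1.0)   # scalar operand
# Previously, translating this raised the RET_CHECK failure shown above;
# with the fix the scalar operand is broadcast during translation.
z = jax.jit(lax.nextafter)(x, y)   # shape (3, 4)
```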
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/5448 from gnecula:nextafter 9be685946252edc67c2c28261b100b9aee68614a
PiperOrigin-RevId: 352367039
We were blocked until now due to limited support for the XlaSort op.
Also removed the patching of the XlaPad output shape; now XlaPad has
shape inference.
* Cleaned up the handling of tolerances and custom asserts.
* Removed the harness field from the limitations.
* Moved the definition of Jax2TfLimitation into its own
file, so it can be reused.
* Fixed the summarization of "inexact" types in the generated docs.
* Added support for 'skip_tf_run' to the limitations.
* Cleaned up the logging for the jax2tf tests.