145 Commits

Author SHA1 Message Date
George Necula
dd8ab85121 [jax2tf] Support inequality and min/max for booleans.
For inequalities we add casts to int8. For min/max we rewrite
to logical operations and/or.
2021-06-12 21:08:37 +03:00
Peter Hawkins
e9611eb090 Move jax.ad_util to jax._src.ad_util.
Expose ad_util.stop_gradient_p as jax.lax.stop_gradient_p. stop_gradient() is already under the external lax namespace.

PiperOrigin-RevId: 378011152
2021-06-07 14:51:34 -07:00
Marc van Zee
9fed620119 Adds support for lax.dynamic_slice_p when XLA is disabled.
PiperOrigin-RevId: 377909682
2021-06-07 07:22:48 -07:00
George Necula
d243258b86 [jax2tf] Implement inequalities and friends for complex numbers.
This requires re-using JAX's lowering rule for comparisons of
complex numbers to use lexicographic comparison.
2021-06-04 17:56:44 +03:00
George Necula
74638a4553 [jax2tf] Improve conversion of sign and abs, to account for TF limitations
PiperOrigin-RevId: 375274010
2021-05-22 11:03:01 -07:00
George Necula
1f83460b72 [jax2tf] Improve conversion of pad when enable_xla=True, add tests.
If enable_xla we should directly use the XLA conversion, without
trying to see whether the primitive can actually be converted
to non XLA TF ops. This is happening elsewhere, but it was not
happening for pad.

Also discovered that we were never turning enable_xla=False in
tests, even when we had special harnesses for it.

This is another attempt for the rolled-back #6722.

PiperOrigin-RevId: 374597824
2021-05-19 01:49:01 -07:00
George Necula
bf63107046 [jax2tf] Add support for preferred_element_type for convolutions.
PiperOrigin-RevId: 374347868
2021-05-17 22:34:14 -07:00
George Necula
a08cdb30ff [jax2tf] Update the limitations for unsupported primitives
Also update the documentation.
2021-05-17 10:01:13 +03:00
Peter Hawkins
8b5c640608 rollback #6722
PiperOrigin-RevId: 373651549
2021-05-13 13:53:58 -07:00
jax authors
b42e9e3789 Merge pull request #6722 from gnecula:tf_enable_xla_test
PiperOrigin-RevId: 373562914
2021-05-13 06:02:16 -07:00
George Necula
f4fa7c7ad0 [jax2tf] Remove dot_general limitation due to XLA fixing crashing bug
PiperOrigin-RevId: 373375341
2021-05-12 08:32:11 -07:00
George Necula
ba5e11f86f [jax2tf] Improve the conversion of integer_pow for better numerical accuracy.
Previously we simply converted integer_pow to tf.math.pow. JAX instead uses
a series of multiplications. We now use the same lowering strategy as JAX, so
that we have the same numerical result.

Also improved the error messages for assertion failures.

PiperOrigin-RevId: 373351147
2021-05-12 05:45:39 -07:00
George Necula
235eb8c2b4 Copybara import of the project:
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:

[jax2tf] Change the conversion of dot_general to use XLA op.

Instead of converting the dot_general to a sea of TF ops, when
we enable_xla we just use the XLA op. This has the advantage
that it also supports the preferred_element_type.

Fixed bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
2021-05-12 02:30:15 -07:00
George Necula
281ae7c12d [jax2tf] Improve conversion of pad when enable_xla=True, add tests.
If enable_xla we should directly use the XLA conversion, without
trying to see whether the primitive can actually be converted
to non XLA TF ops. This is happening elsewhere, but it was not
happening for pad.

Also discovered that we were never turning enable_xla=False in
tests, even when we had special harnesses for it.
2021-05-11 12:17:55 +03:00
George Necula
e2d546638c [jax2tf] Re-organized the tests for shape polymorphism
Added primitive harnesses and rewrote many existing tests in terms
of those.

Fixed the shape polymorphism for jnp.where.
2021-04-15 13:27:33 +03:00
George Necula
5750ec074a Fix scatter 2021-04-08 10:42:38 +03:00
George Necula
6176ac1cb6 [jax2tf] Fix bug in dot_general.
The case when there were batch dimensions but RHS has only
one inner dimmension was handled incorrectly. Add test also.
2021-04-05 15:57:17 +03:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
George Necula
469e1f05e1 [jax2tf] Updated the limitations
Update the documentation to reflect the improvements in TF support for dtypes
2021-03-03 12:38:00 +01:00
Jake VanderPlas
41b7a0f770 Re-land #4850 weak types change 2021-02-09 09:07:52 -08:00
George Necula
f105517ea2 Fixed mypy type errors for numpy 1.20
Revert also previous changes that pinned numpy to 1.19.

One of the changes in numpy 1.20 is to add more type annotations.
However, this sometimes make mypy give errors. A common example is
numpy.take, which with the new type annotation does not appear to
mypy as indexable.

Another change is that np.int and np.bool are deprecated. One
should use np.bool_ or np.int_, or the built-ins bool and int.
2021-02-05 10:40:47 +02:00
George Necula
ec2301a9ce Update limitations docs 2021-02-02 10:31:04 +02:00
Peter Hawkins
046961ad4e Disable bfloat16 in jax2tf eigh test.
We don't have any bfloat16 implementations, so we shouldn't be testing it.
2021-02-01 21:51:46 -05:00
George Necula
c2a32be2a2 [jax2tf] Update the limitations due to improvements in TF
TF now handles a few more ops for unsigned types.
Also igamma and igammac support for f16 anf bf16 were added to
JAX, but not yet to TF, hence the new limitations in TF.
2021-01-29 10:18:31 +01:00
Peter Hawkins
929a684a39 Small cleanups to dependency structure.
PiperOrigin-RevId: 352853244
2021-01-20 12:43:28 -08:00
George Necula
6bf634921e Copybara import of the project:
--
781492e0120ec915f9fdc83479884908f59d113d by George Necula <gcnecula@gmail.com>:

[jax2tf] Update limitations

Some bugs were fixed on the TF-side, and we can remove
some limitations.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/5449 from gnecula:jax2tf_limit 781492e0120ec915f9fdc83479884908f59d113d
PiperOrigin-RevId: 352381535
2021-01-18 03:38:22 -08:00
George Necula
67b5af97f7 Copybara import of the project:
--
9be685946252edc67c2c28261b100b9aee68614a by George Necula <gcnecula@gmail.com>:

Change the translation rule for lax.nextafter_p to ensure
broadcasting during translation.

Previously, this was the only binary arithmetic primitive that
did not have broadcasting during translation. Trying to use it
with non-equal shapes resulted in the error:

```
 RuntimeError: Internal: RET_CHECK failure
(external/org_tensorflow/tensorflow/compiler/xla/client/xla_builder.cc:748)
non_scalar_shape.value().dimensions() == shape->dimensions() Unimplemented
implicit broadcast.:
       This is a bug in JAX's shape-checking rules; please report it!
```

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/5448 from gnecula:nextafter 9be685946252edc67c2c28261b100b9aee68614a
PiperOrigin-RevId: 352367039
2021-01-18 01:59:55 -08:00
George Necula
ca27c83c0c [jax2tf] Cleanup the way jax2tf limitations are specified
Now each limitation is very explicit about the modes it applies
to, with the default being all modes. There is no more special
casing for TPUs.
2021-01-15 12:10:41 +02:00
George Necula
8275da2433 [jax2tf] Adjusting test tolerances
PiperOrigin-RevId: 351846168
2021-01-14 11:54:27 -08:00
George Necula
274d970644 [jax2tf] Finish the conversion of lax.sort
We were blocked until now due to limited support for the XlaSort op.
Also removed the patching of the XlaPad output shape; now XlaPad has
shape inference.
2021-01-07 16:26:45 +02:00
George Necula
1fb9cd34b2 [jax2tf] Moved text from the documentation into the limitations for custom checks
* cleaned up the handling of tolerances and custom asserts.
* Removed the harnesss field from the limitations.
* Moved the definitions of the Jax2TfLimitation into its own
  file, so it can be reused.
2020-12-31 14:56:31 +02:00
George Necula
cf6da862bc [jax2tf] Refinements for the auto-generated limitations documentation
* fixed the summarization of "inexact" types in the generated docs
* added support for 'skip_tf_run' to the limitations
* cleaned up the logging for the jax2tf tests
2020-12-29 07:04:34 +02:00
George Necula
1f098d4328 Fixes suggested by @bchetioui 2020-12-28 13:41:52 +02:00
George Necula
e76ce041f2 Added the tolerances and custom asserts as limitations 2020-12-28 13:13:19 +02:00
George Necula
a5fbc2865d Major refactoring of the jax2tf test harnesses.
See the PR description for details.
2020-12-27 17:56:54 +02:00
Matthew Johnson
3dee321fb8 rollback of #4850 2020-12-23 11:01:58 -08:00
Benjamin Chetioui
c09a73abda [jax2tf] Systematize broadcasting tests for binary elementwise harnesses.
Also add broadcasting tests to min and max, and splits the logic
for add and mul.
2020-12-17 15:52:54 +01:00
Benjamin Chetioui
08f2c7652b [jax2tf] Clean up and expand binary elementwise harnesses. 2020-12-16 17:24:32 +01:00
jax authors
43f603bae6 Merge pull request #5187 from bchetioui:remake_elementwise_harnesses
PiperOrigin-RevId: 347816538
2020-12-16 07:15:45 -08:00
Benjamin Chetioui
4f9000a78f [jax2tf] Minor fixes in round harness. 2020-12-15 13:27:48 +01:00
Benjamin Chetioui
7d4835786a [jax2tf] Expand coverage for unary elementwise ops. 2020-12-14 09:35:19 +01:00
Jake VanderPlas
c63097bc90 Add weak_type argument to convert_element_type_p 2020-12-10 11:10:21 -08:00
Benjamin Chetioui
99ad7195fb [jax2tf] Added testing for the conversion of stop_gradient. 2020-12-05 10:31:35 +01:00
Benjamin Chetioui
cdba4a73fb [jax2tf] Added testing for the conversion of tie_in. 2020-12-05 10:31:35 +01:00
Benjamin Chetioui
e6d12772ea [jax2tf] Added testing for the conversion of add_jaxvals_p.
Note that tests skip types that are not generally compatible with
jax2tf, e.g. core.Unit, which can be passed to this primitive.
2020-12-05 10:31:35 +01:00
Benjamin Chetioui
7a4448392e [jax2tf] Added testing for the conversion of bitcast_convert_type. 2020-12-05 10:31:35 +01:00
Benjamin Chetioui
d62fe001e2 [jax2tf] Added testing for the conversion of sub. 2020-12-05 10:31:35 +01:00
Benjamin Chetioui
099ff26351 [jax2tf] Added testing for the conversion of device_put. 2020-12-05 10:31:35 +01:00
Benjamin Chetioui
fdcb1bbde5 [jax2tf] Added testing for the conversion of rev. 2020-12-05 10:31:35 +01:00
Benjamin Chetioui
5bf5f0bde4 [jax2tf] Added testing for the conversion of reshape. 2020-12-05 10:31:35 +01:00