132 Commits

Author SHA1 Message Date
George Necula
235eb8c2b4 Copybara import of the project:
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:

[jax2tf] Change the conversion of dot_general to use XLA op.

Instead of converting the dot_general to a sea of TF ops, when
we enable_xla we just use the XLA op. This has the advantage
that it also supports the preferred_element_type.

Fixed bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
2021-05-12 02:30:15 -07:00
George Necula
e2d546638c [jax2tf] Re-organized the tests for shape polymorphism
Added primitive harnesses and rewrote many existing tests in terms
of those.

Fixed the shape polymorphism for jnp.where.
2021-04-15 13:27:33 +03:00
George Necula
5750ec074a Fix scatter 2021-04-08 10:42:38 +03:00
George Necula
6176ac1cb6 [jax2tf] Fix bug in dot_general.
The case when there were batch dimensions but RHS has only
one inner dimmension was handled incorrectly. Add test also.
2021-04-05 15:57:17 +03:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
George Necula
469e1f05e1 [jax2tf] Updated the limitations
Update the documentation to reflect the improvements in TF support for dtypes
2021-03-03 12:38:00 +01:00
Jake VanderPlas
41b7a0f770 Re-land #4850 weak types change 2021-02-09 09:07:52 -08:00
George Necula
f105517ea2 Fixed mypy type errors for numpy 1.20
Revert also previous changes that pinned numpy to 1.19.

One of the changes in numpy 1.20 is to add more type annotations.
However, this sometimes make mypy give errors. A common example is
numpy.take, which with the new type annotation does not appear to
mypy as indexable.

Another change is that np.int and np.bool are deprecated. One
should use np.bool_ or np.int_, or the built-ins bool and int.
2021-02-05 10:40:47 +02:00
George Necula
ec2301a9ce Update limitations docs 2021-02-02 10:31:04 +02:00
Peter Hawkins
046961ad4e Disable bfloat16 in jax2tf eigh test.
We don't have any bfloat16 implementations, so we shouldn't be testing it.
2021-02-01 21:51:46 -05:00
George Necula
c2a32be2a2 [jax2tf] Update the limitations due to improvements in TF
TF now handles a few more ops for unsigned types.
Also igamma and igammac support for f16 anf bf16 were added to
JAX, but not yet to TF, hence the new limitations in TF.
2021-01-29 10:18:31 +01:00
Peter Hawkins
929a684a39 Small cleanups to dependency structure.
PiperOrigin-RevId: 352853244
2021-01-20 12:43:28 -08:00
George Necula
6bf634921e Copybara import of the project:
--
781492e0120ec915f9fdc83479884908f59d113d by George Necula <gcnecula@gmail.com>:

[jax2tf] Update limitations

Some bugs were fixed on the TF-side, and we can remove
some limitations.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/5449 from gnecula:jax2tf_limit 781492e0120ec915f9fdc83479884908f59d113d
PiperOrigin-RevId: 352381535
2021-01-18 03:38:22 -08:00
George Necula
67b5af97f7 Copybara import of the project:
--
9be685946252edc67c2c28261b100b9aee68614a by George Necula <gcnecula@gmail.com>:

Change the translation rule for lax.nextafter_p to ensure
broadcasting during translation.

Previously, this was the only binary arithmetic primitive that
did not have broadcasting during translation. Trying to use it
with non-equal shapes resulted in the error:

```
 RuntimeError: Internal: RET_CHECK failure
(external/org_tensorflow/tensorflow/compiler/xla/client/xla_builder.cc:748)
non_scalar_shape.value().dimensions() == shape->dimensions() Unimplemented
implicit broadcast.:
       This is a bug in JAX's shape-checking rules; please report it!
```

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/5448 from gnecula:nextafter 9be685946252edc67c2c28261b100b9aee68614a
PiperOrigin-RevId: 352367039
2021-01-18 01:59:55 -08:00
George Necula
ca27c83c0c [jax2tf] Cleanup the way jax2tf limitations are specified
Now each limitation is very explicit about the modes it applies
to, with the default being all modes. There is no more special
casing for TPUs.
2021-01-15 12:10:41 +02:00
George Necula
8275da2433 [jax2tf] Adjusting test tolerances
PiperOrigin-RevId: 351846168
2021-01-14 11:54:27 -08:00
George Necula
274d970644 [jax2tf] Finish the conversion of lax.sort
We were blocked until now due to limited support for the XlaSort op.
Also removed the patching of the XlaPad output shape; now XlaPad has
shape inference.
2021-01-07 16:26:45 +02:00
George Necula
1fb9cd34b2 [jax2tf] Moved text from the documentation into the limitations for custom checks
* cleaned up the handling of tolerances and custom asserts.
* Removed the harnesss field from the limitations.
* Moved the definitions of the Jax2TfLimitation into its own
  file, so it can be reused.
2020-12-31 14:56:31 +02:00
George Necula
cf6da862bc [jax2tf] Refinements for the auto-generated limitations documentation
* fixed the summarization of "inexact" types in the generated docs
* added support for 'skip_tf_run' to the limitations
* cleaned up the logging for the jax2tf tests
2020-12-29 07:04:34 +02:00
George Necula
1f098d4328 Fixes suggested by @bchetioui 2020-12-28 13:41:52 +02:00
George Necula
e76ce041f2 Added the tolerances and custom asserts as limitations 2020-12-28 13:13:19 +02:00
George Necula
a5fbc2865d Major refactoring of the jax2tf test harnesses.
See the PR description for details.
2020-12-27 17:56:54 +02:00
Matthew Johnson
3dee321fb8 rollback of #4850 2020-12-23 11:01:58 -08:00
Benjamin Chetioui
c09a73abda [jax2tf] Systematize broadcasting tests for binary elementwise harnesses.
Also add broadcasting tests to min and max, and splits the logic
for add and mul.
2020-12-17 15:52:54 +01:00
Benjamin Chetioui
08f2c7652b [jax2tf] Clean up and expand binary elementwise harnesses. 2020-12-16 17:24:32 +01:00
jax authors
43f603bae6 Merge pull request #5187 from bchetioui:remake_elementwise_harnesses
PiperOrigin-RevId: 347816538
2020-12-16 07:15:45 -08:00
Benjamin Chetioui
4f9000a78f [jax2tf] Minor fixes in round harness. 2020-12-15 13:27:48 +01:00
Benjamin Chetioui
7d4835786a [jax2tf] Expand coverage for unary elementwise ops. 2020-12-14 09:35:19 +01:00
Jake VanderPlas
c63097bc90 Add weak_type argument to convert_element_type_p 2020-12-10 11:10:21 -08:00
Benjamin Chetioui
99ad7195fb [jax2tf] Added testing for the conversion of stop_gradient. 2020-12-05 10:31:35 +01:00
Benjamin Chetioui
cdba4a73fb [jax2tf] Added testing for the conversion of tie_in. 2020-12-05 10:31:35 +01:00
Benjamin Chetioui
e6d12772ea [jax2tf] Added testing for the conversion of add_jaxvals_p.
Note that tests skip types that are not generally compatible with
jax2tf, e.g. core.Unit, which can be passed to this primitive.
2020-12-05 10:31:35 +01:00
Benjamin Chetioui
7a4448392e [jax2tf] Added testing for the conversion of bitcast_convert_type. 2020-12-05 10:31:35 +01:00
Benjamin Chetioui
d62fe001e2 [jax2tf] Added testing for the conversion of sub. 2020-12-05 10:31:35 +01:00
Benjamin Chetioui
099ff26351 [jax2tf] Added testing for the conversion of device_put. 2020-12-05 10:31:35 +01:00
Benjamin Chetioui
fdcb1bbde5 [jax2tf] Added testing for the conversion of rev. 2020-12-05 10:31:35 +01:00
Benjamin Chetioui
5bf5f0bde4 [jax2tf] Added testing for the conversion of reshape. 2020-12-05 10:31:35 +01:00
Benjamin Chetioui
809c54e8f8 [jax2tf] Added testing for the conversion of pow. 2020-12-05 10:31:35 +01:00
Benjamin Chetioui
92821fe176 [jax2tf] Added testing for the conversion of integer_pow. 2020-12-05 10:29:57 +01:00
Matthew Johnson
25b3d843b4 fix jax2tf 2020-12-04 16:40:59 -08:00
Benjamin Chetioui
236a6492e4 [jax2tf] Added testing for the conversion of convert_element_type. 2020-12-04 15:27:02 +01:00
jax authors
1135244550 Merge pull request #4959 from bchetioui:issue4952
PiperOrigin-RevId: 345110593
2020-12-01 15:17:39 -08:00
Benjamin Chetioui
f93ca369a8 [jax2tf] Fix and test conversion of round.
Fixes google/jax#4952.
2020-12-01 17:28:57 +01:00
Benjamin Chetioui
e4d5be27f6 [jax2tf] Fix and test conversion of iota 2020-11-20 18:51:30 +01:00
Benjamin Chetioui
e65ed8b647 [jax2tf] Added testing for real_p/imag_p 2020-11-20 18:51:30 +01:00
Benjamin Chetioui
559e507734 [jax2tf] Add testing for the conversion of complex_p. 2020-11-20 18:51:30 +01:00
Benjamin Chetioui
f7480d65f1 [jax2tf] Added tests for the conversion of reducers. 2020-11-20 18:51:20 +01:00
Benjamin Chetioui
7c471b6b01 [jax2tf] Added testing code for div and rem. 2020-11-19 18:42:50 +01:00
jax authors
050b8795a7 Merge pull request #4926 from bchetioui:test_transpose
PiperOrigin-RevId: 343157729
2020-11-18 14:32:20 -08:00
jax authors
0673047658 Merge pull request #4927 from bchetioui:test_zeros_like
PiperOrigin-RevId: 342948171
2020-11-17 14:35:01 -08:00