8184 Commits

Author SHA1 Message Date
George Necula
a27109d1bd [jax2tf] Added documentation explaining how to handle undefined TF ops
Added a test case showing how to mix compileable and non-compileable code.
2021-05-19 08:23:33 +03:00
Peter Hawkins
86887a21e2 Add Mac M1 support to build.py. 2021-05-18 21:53:31 -04:00
Luke Pfister
bcd5deb269 Prevent nans in scale_and_translate
fixes #6780
2021-05-18 15:36:45 -06:00
jax authors
683289c4ad Merge pull request #6764 from hawkinsp:argmax
PiperOrigin-RevId: 374437439
2021-05-18 09:31:53 -07:00
Peter Hawkins
6cc440b79d Fix handling of NaNs in GPU argmax translation rule. 2021-05-18 11:35:54 -04:00
George Necula
bf63107046 [jax2tf] Add support for preferred_element_type for convolutions.
PiperOrigin-RevId: 374347868
2021-05-17 22:34:14 -07:00
jax authors
6a95a8cf50 Merge pull request #6771 from skye:version
PiperOrigin-RevId: 374321498
2021-05-17 18:16:26 -07:00
Skye Wanderman-Milne
63dbb99a66 Update README, etc. for jaxlib 0.1.67 release 2021-05-17 17:48:46 -07:00
jax authors
58a5ef7912 Merge pull request #6766 from jakevdp:py39
PiperOrigin-RevId: 374276022
2021-05-17 14:06:11 -07:00
jax authors
d39261497c Merge pull request #6748 from hawkinsp:complexconv
PiperOrigin-RevId: 374259292
2021-05-17 12:51:35 -07:00
jax authors
bea18073cf Merge pull request #6767 from skye:workspace
PiperOrigin-RevId: 374256702
jaxlib-v0.1.67
2021-05-17 12:40:25 -07:00
Skye Wanderman-Milne
c5e199a42d Update WORKSPACE for jaxlib 0.1.67 2021-05-17 12:15:54 -07:00
Jake VanderPlas
ed66dd4d57 CI: test Python 3.9 with numpy dispatch test 2021-05-17 10:10:52 -07:00
George Necula
149f7fa6c4 Tighten the numerical tolerances for jax2tf.
We want to focus on tolerances for the TF compiled mode. So, we
make the `custom_numeric` apply by default only to `eager` and `graph` modes.

Getting this done required iterating through tests and we used this
opportunity to tighten the tolerances and apply
them more narrowly, e.g., only on CPU or GPU. The tolerances are still
somewhat loose, mainly for linear algebra primitives.

PiperOrigin-RevId: 374215224
2021-05-17 09:36:47 -07:00
George Necula
afe5ec3df4 Improve accuracy of the jax2tf convolution conversion.
Part of the discrepancies were due to JAX using a workaround for
missing complex convolutions on CPU/GPU, while jax2tf was not using
it. We apply the same lowering as JAX, on all platforms.

This allows us to remove custom numeric tolerances and enables complex
convolutions on GPU.

PiperOrigin-RevId: 374199441
2021-05-17 08:18:51 -07:00
jax authors
e049d89657 Merge pull request #6757 from LenaMartens:changelist/373786534
PiperOrigin-RevId: 374183021
2021-05-17 06:30:39 -07:00
jax authors
6a28351acd Merge pull request #6760 from gnecula:tf_limitations
PiperOrigin-RevId: 374167386
2021-05-17 04:27:17 -07:00
George Necula
a08cdb30ff [jax2tf] Update the limitations for unsupported primitives
Also update the documentation.
2021-05-17 10:01:13 +03:00
jax authors
25cc3ece66 Merge pull request #6739 from jakevdp:sparse-op-jvp
PiperOrigin-RevId: 373870776
2021-05-14 14:50:20 -07:00
Jake VanderPlas
2bcae23e20 fix flakes & test tolerances 2021-05-14 14:18:12 -07:00
Roy Frostig
25ff2f4a94 propagate symbolic zeros in sparse op JVPs 2021-05-14 13:44:31 -07:00
Jake VanderPlas
926de5a2bc [sparse] add JVP & transpose rules for coo primitives 2021-05-14 13:43:53 -07:00
jax authors
fd6069c450 Merge pull request #6627 from GJBoth:sparse_ops_jvp
PiperOrigin-RevId: 373852297
2021-05-14 13:16:31 -07:00
Gert-Jan
81903e894b Add JVP rules COO sparse ops.
Updated coo_matvec jvp rule.

Make flake8 happy.
2021-05-14 18:22:28 +00:00
Lena Martens
73e9302fc3 Fix jsp.linalg.lu translation rule to pass backend arg to lower_fun.
If it doesn't, trying to run `lu` with a custom CPU backend when a GPU is
present results in a `Unable to resolve runtime symbol:
`cuda_lu_pivots_to_permutation'` fatal error.
2021-05-14 17:37:09 +01:00
Peter Hawkins
8b5c640608 rollback #6722
PiperOrigin-RevId: 373651549
2021-05-13 13:53:58 -07:00
Peter Hawkins
2f7ef94562 Support complex numbers in jnp.convolve and jnp.correlate. 2021-05-13 09:09:46 -04:00
jax authors
b42e9e3789 Merge pull request #6722 from gnecula:tf_enable_xla_test
PiperOrigin-RevId: 373562914
2021-05-13 06:02:16 -07:00
George Necula
e7568c7ae6 Add additional message to the error when we cannot convert 2021-05-13 09:10:24 +03:00
jax authors
c97cd0a526 Merge pull request #6742 from hawkinsp:scalartype
PiperOrigin-RevId: 373499629
2021-05-12 19:13:12 -07:00
jax authors
43a47b6374 Merge pull request #6683 from bdamoc:update_custom_deriv_doc
PiperOrigin-RevId: 373488172
2021-05-12 17:40:28 -07:00
Peter Hawkins
6d2344d5b8 Change jnp scalar types to consider numpy scalars as instances. 2021-05-12 20:31:49 -04:00
jax authors
db79701732 Merge pull request #6741 from hawkinsp:ascan2
PiperOrigin-RevId: 373422349
2021-05-12 12:10:34 -07:00
Peter Hawkins
724e24d10a Add test from https://github.com/google/jax/pull/5165, with a couple of small improvements.. 2021-05-12 14:58:29 -04:00
jax authors
8f71d20a8b Merge pull request #6738 from hawkinsp:ascan
PiperOrigin-RevId: 373417478
2021-05-12 11:47:03 -07:00
jax authors
d1eea328d7 Merge pull request #6731 from hawkinsp:issue5728
PiperOrigin-RevId: 373411561
2021-05-12 11:20:05 -07:00
Peter Hawkins
44c98ad4e8 Improve JVP rule for scatters with non-overlapping indices.
If the scattered values don't overlap, we don't need complicated masking logic to work out which of the two overlapping values "win".
2021-05-12 14:16:35 -04:00
jax authors
4f2ec864a3 Merge pull request #6736 from hawkinsp:cuda
PiperOrigin-RevId: 373405918
2021-05-12 10:56:45 -07:00
Peter Hawkins
1350d21881 Add regression test for #5728.
This issue appears to have been fixed by jaxlib 0.1.66.
2021-05-12 13:45:16 -04:00
jax authors
ad7c6b5d66 Merge pull request #6646 from hawkinsp:changelog
PiperOrigin-RevId: 373392909
2021-05-12 10:00:47 -07:00
Bogdan Damoc
db39a67cca Update documentation for custom_jvp handling of nondiff_argnums as arguments of _fwd and _bwd rules. 2021-05-12 16:44:52 +01:00
George Necula
f4fa7c7ad0 [jax2tf] Remove dot_general limitation due to XLA fixing crashing bug
PiperOrigin-RevId: 373375341
2021-05-12 08:32:11 -07:00
jax authors
c63e8b5913 Merge pull request #6737 from hawkinsp:nwarn
PiperOrigin-RevId: 373367454
2021-05-12 07:41:10 -07:00
Peter Hawkins
ecaeb94655 Make associative_scan work for boolean arguments. 2021-05-12 10:28:55 -04:00
Peter Hawkins
f83e309fe7 Update changelog. 2021-05-12 09:46:17 -04:00
Peter Hawkins
97d4827ad4 Suppress numpy overflow warning from gamma grad test. 2021-05-12 09:35:36 -04:00
Peter Hawkins
f88e45295f Update installation instructions for CUDA wheels. 2021-05-12 09:24:31 -04:00
George Necula
ba5e11f86f [jax2tf] Improve the conversion of integer_pow for better numerical accuracy.
Previously we simply converted integer_pow to tf.math.pow. JAX instead uses
a series of multiplications. We now use the same lowering strategy as JAX, so
that we have the same numerical result.

Also improved the error messages for assertion failures.

PiperOrigin-RevId: 373351147
2021-05-12 05:45:39 -07:00
George Necula
235eb8c2b4 Copybara import of the project:
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:

[jax2tf] Change the conversion of dot_general to use XLA op.

Instead of converting the dot_general to a sea of TF ops, when
we enable_xla we just use the XLA op. This has the advantage
that it also supports the preferred_element_type.

Fixed bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
2021-05-12 02:30:15 -07:00
jax authors
aa74314c1a Merge pull request #6732 from hawkinsp:gmres2
PiperOrigin-RevId: 373282000
2021-05-11 19:37:04 -07:00