George Necula
a27109d1bd
[jax2tf] Added documentation explaining how to handle undefined TF ops
...
Added a test case showing how to mix compileable and non-compileable code.
2021-05-19 08:23:33 +03:00
Peter Hawkins
86887a21e2
Add Mac M1 support to build.py.
2021-05-18 21:53:31 -04:00
Luke Pfister
bcd5deb269
Prevent nans in scale_and_translate
...
fixes #6780
2021-05-18 15:36:45 -06:00
jax authors
683289c4ad
Merge pull request #6764 from hawkinsp:argmax
...
PiperOrigin-RevId: 374437439
2021-05-18 09:31:53 -07:00
Peter Hawkins
6cc440b79d
Fix handling of NaNs in GPU argmax translation rule.
2021-05-18 11:35:54 -04:00
George Necula
bf63107046
[jax2tf] Add support for preferred_element_type for convolutions.
...
PiperOrigin-RevId: 374347868
2021-05-17 22:34:14 -07:00
jax authors
6a95a8cf50
Merge pull request #6771 from skye:version
...
PiperOrigin-RevId: 374321498
2021-05-17 18:16:26 -07:00
Skye Wanderman-Milne
63dbb99a66
Update README, etc. for jaxlib 0.1.67 release
2021-05-17 17:48:46 -07:00
jax authors
58a5ef7912
Merge pull request #6766 from jakevdp:py39
...
PiperOrigin-RevId: 374276022
2021-05-17 14:06:11 -07:00
jax authors
d39261497c
Merge pull request #6748 from hawkinsp:complexconv
...
PiperOrigin-RevId: 374259292
2021-05-17 12:51:35 -07:00
jax authors
bea18073cf
Merge pull request #6767 from skye:workspace
...
PiperOrigin-RevId: 374256702
jaxlib-v0.1.67
2021-05-17 12:40:25 -07:00
Skye Wanderman-Milne
c5e199a42d
Update WORKSPACE for jaxlib 0.1.67
2021-05-17 12:15:54 -07:00
Jake VanderPlas
ed66dd4d57
CI: test Python 3.9 with numpy dispatch test
2021-05-17 10:10:52 -07:00
George Necula
149f7fa6c4
Tighten the numerical tolerances for jax2tf.
...
We want to focus on tolerances for the TF compiled mode. So, we
make the `custom_numeric` apply by default only to `eager` and `graph` modes.
Getting this done required iterating through tests and we used this
opportunity to tighten the tolerances and apply
them more narrowly, e.g., only on CPU or GPU. The tolerances are still
somewhat loose, mainly for linear algebra primitives.
PiperOrigin-RevId: 374215224
2021-05-17 09:36:47 -07:00
George Necula
afe5ec3df4
Improve accuracy of the jax2tf convolution conversion.
...
Part of the discrepancies were due to JAX using a workaround for
missing complex convolutions on CPU/GPU, while jax2tf was not using
it. We apply the same lowering as JAX, on all platforms.
This allows us to remove custom numeric tolerances and enables complex
convolutions on GPU.
PiperOrigin-RevId: 374199441
2021-05-17 08:18:51 -07:00
jax authors
e049d89657
Merge pull request #6757 from LenaMartens:changelist/373786534
...
PiperOrigin-RevId: 374183021
2021-05-17 06:30:39 -07:00
jax authors
6a28351acd
Merge pull request #6760 from gnecula:tf_limitations
...
PiperOrigin-RevId: 374167386
2021-05-17 04:27:17 -07:00
George Necula
a08cdb30ff
[jax2tf] Update the limitations for unsupported primitives
...
Also update the documentation.
2021-05-17 10:01:13 +03:00
jax authors
25cc3ece66
Merge pull request #6739 from jakevdp:sparse-op-jvp
...
PiperOrigin-RevId: 373870776
2021-05-14 14:50:20 -07:00
Jake VanderPlas
2bcae23e20
fix flakes & test tolerances
2021-05-14 14:18:12 -07:00
Roy Frostig
25ff2f4a94
propagate symbolic zeros in sparse op JVPs
2021-05-14 13:44:31 -07:00
Jake VanderPlas
926de5a2bc
[sparse] add JVP & transpose rules for coo primitives
2021-05-14 13:43:53 -07:00
jax authors
fd6069c450
Merge pull request #6627 from GJBoth:sparse_ops_jvp
...
PiperOrigin-RevId: 373852297
2021-05-14 13:16:31 -07:00
Gert-Jan
81903e894b
Add JVP rules COO sparse ops.
...
Updated coo_matvec jvp rule.
Make flake8 happy.
2021-05-14 18:22:28 +00:00
Lena Martens
73e9302fc3
Fix jsp.linalg.lu
translation rule to pass backend arg to lower_fun
.
...
If it doesn't, trying to run `lu` with a custom CPU backend when a GPU is
present results in a `Unable to resolve runtime symbol:
`cuda_lu_pivots_to_permutation'` fatal error.
2021-05-14 17:37:09 +01:00
Peter Hawkins
8b5c640608
rollback #6722
...
PiperOrigin-RevId: 373651549
2021-05-13 13:53:58 -07:00
Peter Hawkins
2f7ef94562
Support complex numbers in jnp.convolve and jnp.correlate.
2021-05-13 09:09:46 -04:00
jax authors
b42e9e3789
Merge pull request #6722 from gnecula:tf_enable_xla_test
...
PiperOrigin-RevId: 373562914
2021-05-13 06:02:16 -07:00
George Necula
e7568c7ae6
Add additional message to the error when we cannot convert
2021-05-13 09:10:24 +03:00
jax authors
c97cd0a526
Merge pull request #6742 from hawkinsp:scalartype
...
PiperOrigin-RevId: 373499629
2021-05-12 19:13:12 -07:00
jax authors
43a47b6374
Merge pull request #6683 from bdamoc:update_custom_deriv_doc
...
PiperOrigin-RevId: 373488172
2021-05-12 17:40:28 -07:00
Peter Hawkins
6d2344d5b8
Change jnp scalar types to consider numpy scalars as instances.
2021-05-12 20:31:49 -04:00
jax authors
db79701732
Merge pull request #6741 from hawkinsp:ascan2
...
PiperOrigin-RevId: 373422349
2021-05-12 12:10:34 -07:00
Peter Hawkins
724e24d10a
Add test from https://github.com/google/jax/pull/5165 , with a couple of small improvements..
2021-05-12 14:58:29 -04:00
jax authors
8f71d20a8b
Merge pull request #6738 from hawkinsp:ascan
...
PiperOrigin-RevId: 373417478
2021-05-12 11:47:03 -07:00
jax authors
d1eea328d7
Merge pull request #6731 from hawkinsp:issue5728
...
PiperOrigin-RevId: 373411561
2021-05-12 11:20:05 -07:00
Peter Hawkins
44c98ad4e8
Improve JVP rule for scatters with non-overlapping indices.
...
If the scattered values don't overlap, we don't need complicated masking logic to work out which of the two overlapping values "win".
2021-05-12 14:16:35 -04:00
jax authors
4f2ec864a3
Merge pull request #6736 from hawkinsp:cuda
...
PiperOrigin-RevId: 373405918
2021-05-12 10:56:45 -07:00
Peter Hawkins
1350d21881
Add regression test for #5728 .
...
This issue appears to have been fixed by jaxlib 0.1.66.
2021-05-12 13:45:16 -04:00
jax authors
ad7c6b5d66
Merge pull request #6646 from hawkinsp:changelog
...
PiperOrigin-RevId: 373392909
2021-05-12 10:00:47 -07:00
Bogdan Damoc
db39a67cca
Update documentation for custom_jvp handling of nondiff_argnums as arguments of _fwd and _bwd rules.
2021-05-12 16:44:52 +01:00
George Necula
f4fa7c7ad0
[jax2tf] Remove dot_general limitation due to XLA fixing crashing bug
...
PiperOrigin-RevId: 373375341
2021-05-12 08:32:11 -07:00
jax authors
c63e8b5913
Merge pull request #6737 from hawkinsp:nwarn
...
PiperOrigin-RevId: 373367454
2021-05-12 07:41:10 -07:00
Peter Hawkins
ecaeb94655
Make associative_scan work for boolean arguments.
2021-05-12 10:28:55 -04:00
Peter Hawkins
f83e309fe7
Update changelog.
2021-05-12 09:46:17 -04:00
Peter Hawkins
97d4827ad4
Suppress numpy overflow warning from gamma grad test.
2021-05-12 09:35:36 -04:00
Peter Hawkins
f88e45295f
Update installation instructions for CUDA wheels.
2021-05-12 09:24:31 -04:00
George Necula
ba5e11f86f
[jax2tf] Improve the conversion of integer_pow for better numerical accuracy.
...
Previously we simply converted integer_pow to tf.math.pow. JAX instead uses
a series of multiplications. We now use the same lowering strategy as JAX, so
that we have the same numerical result.
Also improved the error messages for assertion failures.
PiperOrigin-RevId: 373351147
2021-05-12 05:45:39 -07:00
George Necula
235eb8c2b4
Copybara import of the project:
...
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:
[jax2tf] Change the conversion of dot_general to use XLA op.
Instead of converting the dot_general to a sea of TF ops, when
we enable_xla we just use the XLA op. This has the advantage
that it also supports the preferred_element_type.
Fixed bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
2021-05-12 02:30:15 -07:00
jax authors
aa74314c1a
Merge pull request #6732 from hawkinsp:gmres2
...
PiperOrigin-RevId: 373282000
2021-05-11 19:37:04 -07:00