8184 Commits

Author SHA1 Message Date
jax authors
87abab713d Merge pull request #6785 from GregCT:changelist/373551581
PiperOrigin-RevId: 376922795
2021-06-01 14:47:15 -07:00
Luke Pfister
c28b472ae2 Add test to ensure gradient is finite 2021-06-01 15:28:25 -06:00
jax authors
edd203e305 Merge pull request #6726 from njunge94:auxiliary_solver_data
PiperOrigin-RevId: 376899659
2021-06-01 12:58:39 -07:00
Stephan Hoyer
d09839124b
Fix wrong method name 2021-06-01 12:42:55 -07:00
Peter Hawkins
7db0c56a22 [JAX] Change how JAX manages XLA platforms.
* Combine the concepts of "platform" and "backend". The main upshot of this is that the tpu_driver backend requires users to write `jit(..., backend="tpu_driver")` if mixing CPU and TPU execution, however I doubt users are writing that because it didn't work to mix CPU and tpu_driver before.
* Initialize all platforms at startup, rather than lazily initializing platforms on demand. This makes it easy to do things like "list the available platforms".
* Don't use two levels of caching. Cache backends only in xla_bridge.py, not xla_client.py.

PiperOrigin-RevId: 376883261
2021-06-01 11:44:31 -07:00
jax authors
87332d01c6 Merge pull request #6861 from ho-oto:patch-1
PiperOrigin-RevId: 376821439
2021-06-01 06:47:56 -07:00
jax authors
8c1e1464b8 Merge pull request #6864 from gnecula:tf_einsum
PiperOrigin-RevId: 376814238
2021-06-01 05:59:47 -07:00
jax authors
9d2491cc4f Merge pull request #6856 from lkhphuc:patch-1
PiperOrigin-RevId: 376813353
2021-06-01 05:52:41 -07:00
George Necula
973171bb6d [jax2tf] Add support for pjit. 2021-06-01 14:32:59 +03:00
jax authors
4fb9715943 Merge pull request #6855 from PhilipVinc:logsumexp
PiperOrigin-RevId: 376786399
2021-06-01 02:16:22 -07:00
George Necula
c07d54aab0 [jax2tf] Add shape polymorphism support for jnp.einsum.
The main problem was that jnp.einsum uses opt_einsum.contract_path
to parse the specification string and compute the order or the
contractions. This function wants to compute the sizes of operands
and intermediate results, and will fail if some dimensions are
polymorphic.

The (partial) solution here is to replace the operands with
jax.ShapeDtypeStruct with a fixed size for all dimension variables,
then call opt_einsum.contract_path and use that result if there
is only one contraction. We abort if there are multiple contractions.
This behavior is clearly sound. If there were multiple contractions,
perhaps their order would be different with different dimension sizes.
2021-05-31 19:06:15 +03:00
Filippo Vicentini
c0c8e0d0a3 make logsumexp work with complex numbers 2021-05-31 16:01:57 +02:00
ho-oto
160c3e9357
rename v to vh 2021-05-31 22:18:24 +09:00
jax authors
e0f285fd21 Merge pull request #6839 from jakevdp:reshape-doc
PiperOrigin-RevId: 376645137
2021-05-31 02:30:04 -07:00
Phúc Lê Khắc
f30a36dbbc Typo.
Update Common_Gotchas_in_JAX.md
2021-05-29 17:22:40 +01:00
jax authors
44b1791b7a Copybara import of the project:
--
8226dfc8a4974b4c8031ee267fa5327e778140ee by Nicholas Junge <nicholas.junge@web.de>:

Handle negative values for list-like sections in jnp.split

PiperOrigin-RevId: 376302305
2021-05-27 20:33:49 -07:00
Rebecca Chen
5065e1bb93 Add missing typing.Optional type annotations to function parameters.
PiperOrigin-RevId: 376300297
2021-05-27 20:10:23 -07:00
Jake VanderPlas
dded0e38b3 DOC: add notes to jax.numpy docstrings about returning copies rather than views 2021-05-27 18:05:45 -07:00
jax authors
44fcd71091 Merge pull request #6851 from njunge94:split-fix
PiperOrigin-RevId: 376286728
2021-05-27 17:55:04 -07:00
Nicholas Junge
8226dfc8a4 Handle negative values for list-like sections in jnp.split 2021-05-27 18:25:18 +02:00
jax authors
4ad332e83f Merge pull request #6841 from colemanliyah:fix_fingerprint
PiperOrigin-RevId: 376023432
2021-05-26 14:02:05 -07:00
Liyah Coleman
b68f2c99fd fixed fingerprint debugging message to be compatible with current min jaxlib version 2021-05-26 19:03:28 +00:00
jax authors
7150a10047 Merge pull request #6835 from jakevdp:sharp-bits
PiperOrigin-RevId: 375988125
2021-05-26 11:21:44 -07:00
jax authors
9deeb733ea Merge pull request #6836 from jakevdp:numpy-doc
PiperOrigin-RevId: 375987921
2021-05-26 11:17:57 -07:00
jax authors
0da0caa57a Merge pull request #6840 from hawkinsp:ci
PiperOrigin-RevId: 375987148
2021-05-26 11:14:28 -07:00
Peter Hawkins
f07ccf0074 Use short tracebacks in CI builds.
Often useful information is hard to see in the GitHub UI with the default traceback verbosity of pytest.
2021-05-26 13:43:38 -04:00
Jake VanderPlas
d844609c6d DOC: add section to Sharp Bits discussing implicit list conversions 2021-05-26 09:03:42 -07:00
jax authors
ba422f2a36 Merge pull request #6825 from colemanliyah:master
PiperOrigin-RevId: 375955559
2021-05-26 09:00:20 -07:00
Gregory Thornton
03a1ee9269 Update Jax linesearch to behave more like Scipy 2021-05-26 12:49:56 +01:00
George Necula
62603fde67 Copybara import of the project:
--
746a232632652233f649b15d94f3ed2fd0ccc1fb by George Necula <gcnecula@gmail.com>:

[jax2tf] Updates known limitations.

This PR fixes several issues:
  * It updates the documentation of the known limitations
  * Increases the numerical tolerance for conv_general_dilated on GPU, to
  address test flakiness.
  * Adds a workaround for a TF bug that results in a crash when
  trying to extract the optimized HLO.

--
4302101aed30a2c7625a2dd5acbe1ca17f9540e4 by George Necula <gcnecula@gmail.com>:

Added limitation for dot_general on GPU

--
207f66a970b7f596e1b265c7aa91fa56e27e7d51 by George Necula <gcnecula@gmail.com>:

Added limitation for dot_general on GPU

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6837 from gnecula:tf_adjust_lim 207f66a970b7f596e1b265c7aa91fa56e27e7d51
PiperOrigin-RevId: 375910042
2021-05-26 04:03:07 -07:00
George Necula
ada35ae045 [jax2tf] Update limitations to account for tf.math.igamma improvements
PiperOrigin-RevId: 375888440
2021-05-26 01:14:20 -07:00
Jake VanderPlas
80a310f8d8 DOC: add note about array views in numpy docs 2021-05-25 16:54:28 -07:00
Katherine Wu
00e030247e
Enable custom gradients SavedModel option in test 2021-05-25 16:07:09 -07:00
jax authors
3b973ac04a Merge pull request #6822 from hawkinsp:take
PiperOrigin-RevId: 375794394
2021-05-25 14:22:07 -07:00
Nicholas Junge
0308527f55 Add auxiliary data support in custom_linear_solve 2021-05-25 18:00:46 +02:00
Liyah Coleman
369ca134fc add fingerprint to debugging log 2021-05-24 23:31:25 +00:00
jax authors
d81a13ccea Merge pull request #6826 from akbir:update_bazel_mac_arm
PiperOrigin-RevId: 375548691
2021-05-24 13:46:21 -07:00
akbir khan
dc81610bb6 updated to official bazel.4.1.0 2021-05-24 21:40:08 +01:00
jax authors
0573169aaf Merge pull request #6821 from hawkinsp:versions
PiperOrigin-RevId: 375499903
2021-05-24 10:09:14 -07:00
Peter Hawkins
dacf31f202 Check for NumPy and SciPy versions during jaxlib builds. 2021-05-24 12:39:37 -04:00
Peter Hawkins
c2b0f72d66 Fix handling of empty dimensions in jnp.take(). 2021-05-24 11:59:41 -04:00
jax authors
070295494a Merge pull request #6820 from hawkinsp:xla
PiperOrigin-RevId: 375476354
2021-05-24 08:03:28 -07:00
George Necula
a64d685e63 [jax2tf] Cleanup limitations for rev in light of improvements in TensorFlow.
PiperOrigin-RevId: 375475438
2021-05-24 07:56:59 -07:00
Peter Hawkins
e87173e88b Update XLA. 2021-05-24 10:43:40 -04:00
jax authors
6743d771dc Merge pull request #6819 from lgeiger:replace-pow
PiperOrigin-RevId: 375461396
2021-05-24 06:05:00 -07:00
Lukas Geiger
3a2e80ef51 Replace pow() with srqt() or square() where possible 2021-05-24 10:43:35 +01:00
jax authors
7ea7cea687 Merge pull request #6816 from gnecula:bfloat16_random
PiperOrigin-RevId: 375417596
2021-05-23 22:58:35 -07:00
jax authors
89d208b62b Merge pull request #6807 from lgeiger:reuse-jvp-ans
PiperOrigin-RevId: 375366940
2021-05-23 11:13:57 -07:00
George Necula
70f0110b32 Fix dtypes.issubdtype when called with "bfloat16" (as string)
Fixes: #6813
2021-05-23 19:32:45 +03:00
George Necula
74638a4553 [jax2tf] Improve conversion of sign and abs, to account for TF limitations
PiperOrigin-RevId: 375274010
2021-05-22 11:03:01 -07:00