11033 Commits

Author SHA1 Message Date
jax authors
700d1e1e63 Merge pull request #10177 from hawkinsp:jaxlib
PiperOrigin-RevId: 439975370
jaxlib-v0.3.5
2022-04-06 17:27:38 -07:00
Peter Hawkins
96ba290faf Jax 0.3.5 and jaxlib 0.3.5 release. 2022-04-06 23:56:41 +00:00
jax authors
7ee6adb1a5 Merge pull request #10173 from jakevdp:bcoo-add-batchdim
PiperOrigin-RevId: 439955276
2022-04-06 15:50:10 -07:00
Yash Katariya
6a7a34603d Move PartitionSpec from sharded_jit.py to pxla.py. The public endpoint is via jax.experimental so that should be used (no changes to the public endpoint).
This move is because sharded_jit is being deprecated.

PiperOrigin-RevId: 439948391
2022-04-06 15:19:19 -07:00
Jake VanderPlas
93a24f3b83 [sparse] add bcoo_add_batchdim 2022-04-06 14:44:29 -07:00
Peter Hawkins
bc658e7456 [MHLO] Add direct MHLO lowerings for most linear algebra kernels.
PiperOrigin-RevId: 439927594
2022-04-06 13:59:09 -07:00
Yash Katariya
4ed06602d3 Add deprecation warning for sharded_jit.
PiperOrigin-RevId: 439926957
2022-04-06 13:54:06 -07:00
Peter Hawkins
3bfa6af2c8 [MHLO] Add MHLO lowering for PRNG kernels.
PiperOrigin-RevId: 439919104
2022-04-06 13:23:01 -07:00
Peter Hawkins
b9bb61322c [MHLO] Prefer backend-specific HLO lowerings instead of non-backend-specific MHLO lowerings.
This allows (in subsequent changes) to switch the generic case for translating a primitive to MHLO, even if we can't yet use an MHLO lowering for a backend-specific case yet.

Add a handful of direct MLIR lowerings for primitives that lacked them.

PiperOrigin-RevId: 439912093
2022-04-06 12:53:56 -07:00
Peter Hawkins
4012267a01 Revert: implement jnp.trace in terms of jnp.diagonal
This change appears to blow up compilation times for some models on TPU.

PiperOrigin-RevId: 439880940
2022-04-06 10:46:01 -07:00
jax authors
be64d8ba7c Merge pull request #10164 from hawkinsp:macarm
PiperOrigin-RevId: 439857716
2022-04-06 09:21:53 -07:00
jax authors
21d81c0858 Merge pull request #10160 from lgeiger:jnp-trace
PiperOrigin-RevId: 439857693
2022-04-06 09:16:53 -07:00
Peter Hawkins
d073a0f264 Pin a newer @platforms in the Bazel workspace to fix Mac ARM builds. 2022-04-06 14:43:56 +00:00
Lukas Geiger
3e877f39a0 Implement jnp.trace in terms of jnp.diagonal 2022-04-06 01:07:06 +01:00
jax authors
1df942f276 Merge pull request #10148 from sharadmv:version-info
PiperOrigin-RevId: 439686412
2022-04-05 15:06:47 -07:00
Sharad Vikram
d72a7b4054 Add version int tuple __version_info__ to JAX 2022-04-05 13:26:05 -07:00
jax authors
fef367019b Merge pull request #10140 from jakevdp:jnp-diagonal
PiperOrigin-RevId: 439596331
2022-04-05 09:14:14 -07:00
jax authors
28575794cc Merge pull request #10138 from lucasb-eyer:patch-5
PiperOrigin-RevId: 439596234
2022-04-05 09:09:19 -07:00
Peter Hawkins
152c210af2 [MHLO] Implement return type inference for GetTupleElementOp and TupleOp.
PiperOrigin-RevId: 439589720
2022-04-05 08:38:38 -07:00
Jake VanderPlas
b7344ed512 jnp.diagonal: implement in terms of gather rather than sum 2022-04-04 17:02:11 -07:00
jax authors
97d834feed Merge pull request #10139 from jakevdp:fix-gpu-test
PiperOrigin-RevId: 439439327
2022-04-04 16:29:02 -07:00
Jake VanderPlas
5a96c0cb18 Skip test outside x64 2022-04-04 16:00:18 -07:00
Jake VanderPlas
1246b6fc73 Separate jax.test_util implementations into public and private sources.
Eventually the private functionality will no longer be exported via the jax.test_util submodule.

PiperOrigin-RevId: 439415485
2022-04-04 14:43:39 -07:00
Peter Hawkins
71a5eb263b [GPU] Force an input buffer copy for double precision complex-to-real IRFFTs.
Fixes https://github.com/google/jax/issues/9946

PiperOrigin-RevId: 439414091
2022-04-04 14:38:52 -07:00
Lucas Beyer
f7b749c99c
Explicit doc note about device_put* async 2022-04-04 23:38:51 +02:00
Yash Katariya
6825f654b1 * Disallow any other type other than GDA and ShapedArray for auto sharding.
* Raise errors in the following 4 cases when GDAs sharding does not match the input sharding. **In all the 4 cases below, the check only runs once! There is no double checking going on. I have added tests for these cases. Please check them out.**
  * Auto sharding
    * f_pjitted(gda) -- `_pjit_call_impl` catches this mismatch. Only doing this check when `compiled._auto_spmd_lowering` is True.
    * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch
  * NO auto sharding
    * f_pjitted(gda) -- This is already covered and tested and happens in `infer_params`
    * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch

PiperOrigin-RevId: 439413895
2022-04-04 14:33:51 -07:00
Jake VanderPlas
4949e78859 Re-land changes from https://github.com/google/jax/pull/10069
PiperOrigin-RevId: 439381161
2022-04-04 12:18:43 -07:00
Colin Gaffney
41b6e00141 Enable use of GlobalDeviceArray (GDA) in T5X Checkpointer. Add a separate unit test, gda_checkpoints_test, to cover this use case.
GDA is locked behind a `use_gda` bool in Checkpointer. The feature is currently not enabled anywhere.

Our follow-up plan is to add code which would enable GDA use throughout T5X, and to fix any remaining issues with Checkpointer.

PiperOrigin-RevId: 439358913
2022-04-04 10:56:07 -07:00
Peter Hawkins
1b8be90801 Remove the jax_enable_mlir flag. MLIR is now the only supported code path.
This change does not yet remove all the XLA translation rule code since it may be used in various fallback paths. Only the top-level lowering function is removed. Further cleanup is left to subsequent changes.

PiperOrigin-RevId: 439324450
2022-04-04 08:40:09 -07:00
jax authors
e1bbbf55cd Merge pull request #10130 from mattjj:no-string-annotations
PiperOrigin-RevId: 439174012
2022-04-03 12:11:03 -07:00
Matthew Johnson
c72d8f6b09 remove string annotations from core.py 2022-04-03 11:19:07 -07:00
jax authors
359b614b5f Merge pull request #10122 from sharadmv:jax2tf-name-stack
PiperOrigin-RevId: 439036794
2022-04-02 09:02:21 -07:00
jax authors
e64a57d2c3 Merge pull request #10121 from hawkinsp:hcbcache
PiperOrigin-RevId: 439036780
2022-04-02 08:57:24 -07:00
jax authors
cdf4177f92 Merge pull request #10126 from jakevdp:tree-multimap
PiperOrigin-RevId: 438956536
2022-04-01 18:52:50 -07:00
Jake VanderPlas
c61a18b346 DOC: switch from tree_multimap to tree_map in docs 2022-04-01 14:52:16 -07:00
Jake VanderPlas
df1ceaeeb1 Deprecate jax.tree_util.tree_multimap 2022-04-01 14:51:54 -07:00
jax authors
9693898e85 Merge pull request #10123 from fabianp:patch-4
PiperOrigin-RevId: 438906356
2022-04-01 14:09:08 -07:00
Yash Katariya
0eaeff6fc0 Give auto sharder the mesh information specifically the mesh_shape and the devices
ids of devices in the mesh.

PiperOrigin-RevId: 438906211
2022-04-01 14:04:23 -07:00
Fabian Pedregosa
4fd466be1f
remove $ from command line commands
A few commands in this file were prefixed with $ which results in an invalid command when copied with the sphinx "copy" button.
2022-04-01 15:50:29 -04:00
jax authors
1c3edc811d Merge pull request #10110 from pschuh:weakref-bug
PiperOrigin-RevId: 438887762
2022-04-01 12:45:35 -07:00
Sharad Vikram
aac8ec8649 Fixes jax2tf's test_name_scope to use graph introspection instead of
side-effect
2022-04-01 12:39:56 -07:00
Yash Katariya
7b7458b474 Give auto sharder the mesh information specifically the mesh_shape and the devices
ids of devices in the mesh.

PiperOrigin-RevId: 438882041
2022-04-01 12:19:25 -07:00
Parker Schuh
df1c478ec5 Fix race condition for weakref destructor by catching rare exceptions. 2022-04-01 12:04:36 -07:00
jax authors
8c3385c542 Expose AutoSharding's mesh_shape and mesh_ids options to JAX.
PiperOrigin-RevId: 438874347
2022-04-01 11:47:56 -07:00
Peter Hawkins
208e83ceb7 Avoid retracing when a host_callback.call is called multiple times with the same function.
If we build a lambda in the host_callback.call() method, the identity of that lambda is different each time and will never lead to a primitive compilation cache hit. Instead, use a custom wrapper object with hash/equality.

This issue was found in passing while debugging #9970.
2022-04-01 14:41:14 -04:00
jax authors
a4a551a458 Merge pull request #10119 from jakevdp:pil-fix
PiperOrigin-RevId: 438853940
2022-04-01 10:23:00 -07:00
Jake VanderPlas
1f300e729b CI: pin pillow<9.1 to prevent deprecation warnings 2022-04-01 09:23:27 -07:00
jax authors
e766b96063 Merge pull request #10058 from yotarok:istft
PiperOrigin-RevId: 438832534
2022-04-01 08:43:27 -07:00
jax authors
4decbcb00e Merge pull request #10103 from LenaMartens:changelist/438319917
PiperOrigin-RevId: 438821559
2022-04-01 07:40:45 -07:00
Yash Katariya
aa5d6b4a58 Fix the breakage by including --experimental_cc_shared_library as done by TF.
PiperOrigin-RevId: 438746867
2022-03-31 23:07:42 -07:00