jax authors
700d1e1e63
Merge pull request #10177 from hawkinsp:jaxlib
...
PiperOrigin-RevId: 439975370
jaxlib-v0.3.5
2022-04-06 17:27:38 -07:00
Peter Hawkins
96ba290faf
Jax 0.3.5 and jaxlib 0.3.5 release.
2022-04-06 23:56:41 +00:00
jax authors
7ee6adb1a5
Merge pull request #10173 from jakevdp:bcoo-add-batchdim
...
PiperOrigin-RevId: 439955276
2022-04-06 15:50:10 -07:00
Yash Katariya
6a7a34603d
Move PartitionSpec from sharded_jit.py to pxla.py. The public endpoint is via jax.experimental so that should be used (no changes to the public endpoint).
...
This move is because sharded_jit is being deprecated.
PiperOrigin-RevId: 439948391
2022-04-06 15:19:19 -07:00
Jake VanderPlas
93a24f3b83
[sparse] add bcoo_add_batchdim
2022-04-06 14:44:29 -07:00
Peter Hawkins
bc658e7456
[MHLO] Add direct MHLO lowerings for most linear algebra kernels.
...
PiperOrigin-RevId: 439927594
2022-04-06 13:59:09 -07:00
Yash Katariya
4ed06602d3
Add deprecation warning for sharded_jit.
...
PiperOrigin-RevId: 439926957
2022-04-06 13:54:06 -07:00
Peter Hawkins
3bfa6af2c8
[MHLO] Add MHLO lowering for PRNG kernels.
...
PiperOrigin-RevId: 439919104
2022-04-06 13:23:01 -07:00
Peter Hawkins
b9bb61322c
[MHLO] Prefer backend-specific HLO lowerings instead of non-backend-specific MHLO lowerings.
...
This allows (in subsequent changes) to switch the generic case for translating a primitive to MHLO, even if we can't yet use an MHLO lowering for a backend-specific case yet.
Add a handful of direct MLIR lowerings for primitives that lacked them.
PiperOrigin-RevId: 439912093
2022-04-06 12:53:56 -07:00
Peter Hawkins
4012267a01
Revert: implement jnp.trace
in terms of jnp.diagonal
...
This change appears to blow up compilation times for some models on TPU.
PiperOrigin-RevId: 439880940
2022-04-06 10:46:01 -07:00
jax authors
be64d8ba7c
Merge pull request #10164 from hawkinsp:macarm
...
PiperOrigin-RevId: 439857716
2022-04-06 09:21:53 -07:00
jax authors
21d81c0858
Merge pull request #10160 from lgeiger:jnp-trace
...
PiperOrigin-RevId: 439857693
2022-04-06 09:16:53 -07:00
Peter Hawkins
d073a0f264
Pin a newer @platforms in the Bazel workspace to fix Mac ARM builds.
2022-04-06 14:43:56 +00:00
Lukas Geiger
3e877f39a0
Implement jnp.trace
in terms of jnp.diagonal
2022-04-06 01:07:06 +01:00
jax authors
1df942f276
Merge pull request #10148 from sharadmv:version-info
...
PiperOrigin-RevId: 439686412
2022-04-05 15:06:47 -07:00
Sharad Vikram
d72a7b4054
Add version int tuple __version_info__
to JAX
2022-04-05 13:26:05 -07:00
jax authors
fef367019b
Merge pull request #10140 from jakevdp:jnp-diagonal
...
PiperOrigin-RevId: 439596331
2022-04-05 09:14:14 -07:00
jax authors
28575794cc
Merge pull request #10138 from lucasb-eyer:patch-5
...
PiperOrigin-RevId: 439596234
2022-04-05 09:09:19 -07:00
Peter Hawkins
152c210af2
[MHLO] Implement return type inference for GetTupleElementOp and TupleOp.
...
PiperOrigin-RevId: 439589720
2022-04-05 08:38:38 -07:00
Jake VanderPlas
b7344ed512
jnp.diagonal: implement in terms of gather rather than sum
2022-04-04 17:02:11 -07:00
jax authors
97d834feed
Merge pull request #10139 from jakevdp:fix-gpu-test
...
PiperOrigin-RevId: 439439327
2022-04-04 16:29:02 -07:00
Jake VanderPlas
5a96c0cb18
Skip test outside x64
2022-04-04 16:00:18 -07:00
Jake VanderPlas
1246b6fc73
Separate jax.test_util implementations into public and private sources.
...
Eventually the private functionality will no longer be exported via the jax.test_util submodule.
PiperOrigin-RevId: 439415485
2022-04-04 14:43:39 -07:00
Peter Hawkins
71a5eb263b
[GPU] Force an input buffer copy for double precision complex-to-real IRFFTs.
...
Fixes https://github.com/google/jax/issues/9946
PiperOrigin-RevId: 439414091
2022-04-04 14:38:52 -07:00
Lucas Beyer
f7b749c99c
Explicit doc note about device_put* async
2022-04-04 23:38:51 +02:00
Yash Katariya
6825f654b1
* Disallow any other type other than GDA and ShapedArray for auto sharding.
...
* Raise errors in the following 4 cases when GDAs sharding does not match the input sharding. **In all the 4 cases below, the check only runs once! There is no double checking going on. I have added tests for these cases. Please check them out.**
* Auto sharding
* f_pjitted(gda) -- `_pjit_call_impl` catches this mismatch. Only doing this check when `compiled._auto_spmd_lowering` is True.
* compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch
* NO auto sharding
* f_pjitted(gda) -- This is already covered and tested and happens in `infer_params`
* compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch
PiperOrigin-RevId: 439413895
2022-04-04 14:33:51 -07:00
Jake VanderPlas
4949e78859
Re-land changes from https://github.com/google/jax/pull/10069
...
PiperOrigin-RevId: 439381161
2022-04-04 12:18:43 -07:00
Colin Gaffney
41b6e00141
Enable use of GlobalDeviceArray
(GDA) in T5X Checkpointer. Add a separate unit test, gda_checkpoints_test
, to cover this use case.
...
GDA is locked behind a `use_gda` bool in Checkpointer. The feature is currently not enabled anywhere.
Our follow-up plan is to add code which would enable GDA use throughout T5X, and to fix any remaining issues with Checkpointer.
PiperOrigin-RevId: 439358913
2022-04-04 10:56:07 -07:00
Peter Hawkins
1b8be90801
Remove the jax_enable_mlir flag. MLIR is now the only supported code path.
...
This change does not yet remove all the XLA translation rule code since it may be used in various fallback paths. Only the top-level lowering function is removed. Further cleanup is left to subsequent changes.
PiperOrigin-RevId: 439324450
2022-04-04 08:40:09 -07:00
jax authors
e1bbbf55cd
Merge pull request #10130 from mattjj:no-string-annotations
...
PiperOrigin-RevId: 439174012
2022-04-03 12:11:03 -07:00
Matthew Johnson
c72d8f6b09
remove string annotations from core.py
2022-04-03 11:19:07 -07:00
jax authors
359b614b5f
Merge pull request #10122 from sharadmv:jax2tf-name-stack
...
PiperOrigin-RevId: 439036794
2022-04-02 09:02:21 -07:00
jax authors
e64a57d2c3
Merge pull request #10121 from hawkinsp:hcbcache
...
PiperOrigin-RevId: 439036780
2022-04-02 08:57:24 -07:00
jax authors
cdf4177f92
Merge pull request #10126 from jakevdp:tree-multimap
...
PiperOrigin-RevId: 438956536
2022-04-01 18:52:50 -07:00
Jake VanderPlas
c61a18b346
DOC: switch from tree_multimap to tree_map in docs
2022-04-01 14:52:16 -07:00
Jake VanderPlas
df1ceaeeb1
Deprecate jax.tree_util.tree_multimap
2022-04-01 14:51:54 -07:00
jax authors
9693898e85
Merge pull request #10123 from fabianp:patch-4
...
PiperOrigin-RevId: 438906356
2022-04-01 14:09:08 -07:00
Yash Katariya
0eaeff6fc0
Give auto sharder the mesh information specifically the mesh_shape and the devices
...
ids of devices in the mesh.
PiperOrigin-RevId: 438906211
2022-04-01 14:04:23 -07:00
Fabian Pedregosa
4fd466be1f
remove $ from command line commands
...
A few commands in this file were prefixed with $ which results in an invalid command when copied with the sphinx "copy" button.
2022-04-01 15:50:29 -04:00
jax authors
1c3edc811d
Merge pull request #10110 from pschuh:weakref-bug
...
PiperOrigin-RevId: 438887762
2022-04-01 12:45:35 -07:00
Sharad Vikram
aac8ec8649
Fixes jax2tf
's test_name_scope
to use graph introspection instead of
...
side-effect
2022-04-01 12:39:56 -07:00
Yash Katariya
7b7458b474
Give auto sharder the mesh information specifically the mesh_shape and the devices
...
ids of devices in the mesh.
PiperOrigin-RevId: 438882041
2022-04-01 12:19:25 -07:00
Parker Schuh
df1c478ec5
Fix race condition for weakref destructor by catching rare exceptions.
2022-04-01 12:04:36 -07:00
jax authors
8c3385c542
Expose AutoSharding's mesh_shape and mesh_ids options to JAX.
...
PiperOrigin-RevId: 438874347
2022-04-01 11:47:56 -07:00
Peter Hawkins
208e83ceb7
Avoid retracing when a host_callback.call is called multiple times with the same function.
...
If we build a lambda in the host_callback.call() method, the identity of that lambda is different each time and will never lead to a primitive compilation cache hit. Instead, use a custom wrapper object with hash/equality.
This issue was found in passing while debugging #9970 .
2022-04-01 14:41:14 -04:00
jax authors
a4a551a458
Merge pull request #10119 from jakevdp:pil-fix
...
PiperOrigin-RevId: 438853940
2022-04-01 10:23:00 -07:00
Jake VanderPlas
1f300e729b
CI: pin pillow<9.1 to prevent deprecation warnings
2022-04-01 09:23:27 -07:00
jax authors
e766b96063
Merge pull request #10058 from yotarok:istft
...
PiperOrigin-RevId: 438832534
2022-04-01 08:43:27 -07:00
jax authors
4decbcb00e
Merge pull request #10103 from LenaMartens:changelist/438319917
...
PiperOrigin-RevId: 438821559
2022-04-01 07:40:45 -07:00
Yash Katariya
aa5d6b4a58
Fix the breakage by including --experimental_cc_shared_library as done by TF.
...
PiperOrigin-RevId: 438746867
2022-03-31 23:07:42 -07:00