12232 Commits

Author SHA1 Message Date
jax authors
3eff9d11d2 Internal change
PiperOrigin-RevId: 460434859
2022-07-12 05:21:20 -07:00
Marc van Zee
9b54660183 [jax2tf] Use tf.nest.flatten instead of tf.keras.backend.flatten.
PiperOrigin-RevId: 460395701
2022-07-12 00:54:06 -07:00
Jake VanderPlas
82e81a0c0a Define PyTreeDef on pytree module rather than its parent
PiperOrigin-RevId: 460342835
2022-07-11 17:57:36 -07:00
John QiangZhang
153b6aeb78 Fix limitations-of-call_tf github link typo.
call_tf misspell as "call-tf" on multiple places.

PiperOrigin-RevId: 460335218
2022-07-11 17:12:59 -07:00
Yash Katariya
0bc8f8abeb * Check if the device assignment is the same across input and output shardings.
* Allow mixed inputs only if the sharding matches with what is specified in in_axis_resources.

PiperOrigin-RevId: 460326054
2022-07-11 16:27:11 -07:00
jax authors
11896b68a2 Merge pull request #11429 from sharadmv:for-loop
PiperOrigin-RevId: 460318883
2022-07-11 15:52:42 -07:00
jax authors
d109c604c7 Merge pull request #11451 from jakevdp:doc-execute-promotion
PiperOrigin-RevId: 460305621
2022-07-11 14:51:08 -07:00
Benjamin Kramer
9e16efa98a Integrate LLVM at llvm/llvm-project@71c9757474
Updates LLVM usage to match
[71c9757474c3](https://github.com/llvm/llvm-project/commit/71c9757474c3)

PiperOrigin-RevId: 460299215
2022-07-11 14:21:09 -07:00
Jake VanderPlas
d8d836bfb0 Fix execution issues in type_promotion doc 2022-07-11 14:16:11 -07:00
Yash Katariya
8f09606a40 Make jaxpr_has_pmap work for other primitives too
PiperOrigin-RevId: 460286042
2022-07-11 13:24:16 -07:00
Sharad Vikram
9d610e2de6 Add loop invariant residual fixpoint test 2022-07-11 13:10:03 -07:00
jax authors
183b9e4503 Merge pull request #11397 from jakevdp:diagonal-err
PiperOrigin-RevId: 460236416
2022-07-11 09:52:19 -07:00
jax authors
cc42c8091d Merge pull request #11406 from jakevdp:bcoo-add-batchdim
PiperOrigin-RevId: 460226570
2022-07-11 09:06:42 -07:00
jax authors
df74907257 Merge pull request #11440 from google:random-docstring-fix
PiperOrigin-RevId: 460220463
2022-07-11 08:36:00 -07:00
jax authors
e19d026b3a Merge pull request #11442 from hawkinsp:shards
PiperOrigin-RevId: 460211938
2022-07-11 07:52:06 -07:00
Peter Hawkins
64e0b5d801 Increase bazel sharding of GPU tests.
Reduces the maximum time for some test shards to avoid flaky timeouts.
2022-07-11 14:19:43 +00:00
James Bradbury
64eb46a172
Fix RST formatting in random.py docstring 2022-07-11 02:51:35 -07:00
Sharad Vikram
b666f665ec Rollback of HCB GPU custom call due to internal failures
PiperOrigin-RevId: 460079787
2022-07-10 13:05:27 -07:00
Yash Katariya
5910cdc861 Print the repr of device
PiperOrigin-RevId: 459986696
2022-07-09 17:08:44 -07:00
jax authors
ed51c65576 Merge pull request #11405 from mattjj:djax-vmap
PiperOrigin-RevId: 459958155
2022-07-09 10:38:39 -07:00
Matthew Johnson
5b82ba787c [dynamic-shapes] start basic vmap compatibility 2022-07-09 10:03:40 -07:00
Yash Katariya
09ba51f323 Move _get_array_mapping from gda.py to pxla.py
PiperOrigin-RevId: 459891853
2022-07-08 21:38:06 -07:00
jax authors
df993ea32f Merge pull request #11410 from sharadmv:for-loop
PiperOrigin-RevId: 459879694
2022-07-08 19:37:57 -07:00
Sharad Vikram
bff71b2c4f Add loop-invariant residual optimization for for 2022-07-08 18:54:51 -07:00
jax authors
66ab792fc0 Merge pull request #11383 from YouJiacheng:Enable-HCB-customCall-implementation-on-GPU
PiperOrigin-RevId: 459872063
2022-07-08 18:23:16 -07:00
Anish Tondwalkar
847c04fd8c [mhlo] CosOp -> CosineOp
Aligns the op class name with the mnemonic

PiperOrigin-RevId: 459852934
2022-07-08 16:04:14 -07:00
Anish Tondwalkar
5f7018a62e [mhlo] SinOp -> SineOp
Aligns the op class name with the mnemonic

PiperOrigin-RevId: 459830783
2022-07-08 14:08:12 -07:00
jax authors
1bb1fe0658 Remove workaround for rank-0 zarr chunk layout bug in TensorStore
This has now been fixed in TensorStore.

PiperOrigin-RevId: 459824051
2022-07-08 13:35:29 -07:00
jax authors
dac310c221 Merge pull request #11421 from jakevdp:scalar-meta-nocopy
PiperOrigin-RevId: 459823335
2022-07-08 13:30:20 -07:00
Yash Katariya
bb2c5f111a Resolve TODOs and add some more checks for the jax.Array path.
PiperOrigin-RevId: 459808511
2022-07-08 12:19:19 -07:00
Anish Tondwalkar
a2f2d1fa42 [mhlo] ConvOp -> ConvolutionOp
Aligns the op class name with the mnemonic

PiperOrigin-RevId: 459808502
2022-07-08 12:13:51 -07:00
YouJiacheng
7c707832aa Enable CustomCall implementation on GPU 2022-07-09 02:29:08 +08:00
Jake VanderPlas
e19df1a9bf Use asarray rather than array in ScalarMeta
Why? This will make it so that jnp.int32(x) and friends no longer insert
a gratuitous copy_p operation in the jaxpr
2022-07-08 11:16:40 -07:00
jax authors
5285a15de1 Merge pull request #11419 from hawkinsp:jaxlibcleanup
PiperOrigin-RevId: 459780016
2022-07-08 10:04:39 -07:00
Yash Katariya
229ddecc45 * Remove AUTO from MeshPspecSharding and treat it like _UNSPECIFIED singleton value.
* Support partial mentions of AUTO which is supported by GDA currently and used in pax. Added tests for all of this.
  * As a consequence of this, I lifted the restriction on not providing `in_axis_resources` to pjit under `config.jax_array`.

* Made all auto sharding tests parameterized to test both gda and array.

PiperOrigin-RevId: 459776152
2022-07-08 09:45:23 -07:00
Peter Hawkins
5a7bedca37 Increase shard_count for sparse_test_gpu to 20.
1918d39765 updated the wrong test!

This test is close to the timeout in the GPU CI and flakes sometimes.

PiperOrigin-RevId: 459762867
2022-07-08 08:30:26 -07:00
Peter Hawkins
41b015ab0c Remove stale code from jax/_src/lib/__init__.py
Remove inaccurate/stale __all__.
Remove unused alias _xla_extension_version.
2022-07-08 11:09:58 -04:00
jax authors
928f22cb6b Merge pull request #11418 from hawkinsp:bzl
PiperOrigin-RevId: 459754677
2022-07-08 07:38:13 -07:00
Peter Hawkins
a48f4e116e Change Bazel test rules to generate per-backend test suites. 2022-07-08 14:19:05 +00:00
jax authors
55dcbec5b5 Merge pull request #11407 from hawkinsp:minver
PiperOrigin-RevId: 459740984
2022-07-08 06:04:47 -07:00
Tamara Norman
bc9c4b77d0 Adjust docs to account for what the actual current RNG behavior is
PiperOrigin-RevId: 459712928
2022-07-08 02:55:36 -07:00
jax authors
7ffedb5815 Merge pull request #11400 from jakevdp:deprecate-treeutil
PiperOrigin-RevId: 459681801
2022-07-07 23:05:35 -07:00
jax authors
34fea3d496 Merge pull request #11408 from hawkinsp:sparseshard
PiperOrigin-RevId: 459647331
2022-07-07 18:23:20 -07:00
Peter Hawkins
1918d39765 Increase number of shards for GPU sparse_test to 20. 2022-07-07 21:14:25 -04:00
jax authors
a3e8ae4b1a Merge pull request #11388 from jakevdp:fix-bool-weak
PiperOrigin-RevId: 459641851
2022-07-07 17:43:17 -07:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Jake VanderPlas
adcf30ef6b [sparse] remove deprecated bcoo_add_batch_dim utility 2022-07-07 16:57:36 -07:00
jax authors
44bd311ae7 Merge pull request #11403 from jakevdp:sparse-unary
PiperOrigin-RevId: 459634024
2022-07-07 16:57:23 -07:00
Jake VanderPlas
17de5e4840 jnp.diagonal: raise explicit error if ndim < 2 2022-07-07 16:36:40 -07:00
Jake VanderPlas
56d61d3f2d BUG: ensure that boolean scalars are never marked weak 2022-07-07 15:41:23 -07:00