jax authors
3eff9d11d2
Internal change
...
PiperOrigin-RevId: 460434859
2022-07-12 05:21:20 -07:00
Marc van Zee
9b54660183
[jax2tf] Use tf.nest.flatten instead of tf.keras.backend.flatten.
...
PiperOrigin-RevId: 460395701
2022-07-12 00:54:06 -07:00
Jake VanderPlas
82e81a0c0a
Define PyTreeDef on pytree module rather than its parent
...
PiperOrigin-RevId: 460342835
2022-07-11 17:57:36 -07:00
John QiangZhang
153b6aeb78
Fix limitations-of-call_tf github link typo.
...
call_tf misspell as "call-tf" on multiple places.
PiperOrigin-RevId: 460335218
2022-07-11 17:12:59 -07:00
Yash Katariya
0bc8f8abeb
* Check if the device assignment is the same across input and output shardings.
...
* Allow mixed inputs only if the sharding matches with what is specified in in_axis_resources.
PiperOrigin-RevId: 460326054
2022-07-11 16:27:11 -07:00
jax authors
11896b68a2
Merge pull request #11429 from sharadmv:for-loop
...
PiperOrigin-RevId: 460318883
2022-07-11 15:52:42 -07:00
jax authors
d109c604c7
Merge pull request #11451 from jakevdp:doc-execute-promotion
...
PiperOrigin-RevId: 460305621
2022-07-11 14:51:08 -07:00
Benjamin Kramer
9e16efa98a
Integrate LLVM at llvm/llvm-project@71c9757474
...
Updates LLVM usage to match
[71c9757474c3](https://github.com/llvm/llvm-project/commit/71c9757474c3 )
PiperOrigin-RevId: 460299215
2022-07-11 14:21:09 -07:00
Jake VanderPlas
d8d836bfb0
Fix execution issues in type_promotion doc
2022-07-11 14:16:11 -07:00
Yash Katariya
8f09606a40
Make jaxpr_has_pmap
work for other primitives too
...
PiperOrigin-RevId: 460286042
2022-07-11 13:24:16 -07:00
Sharad Vikram
9d610e2de6
Add loop invariant residual fixpoint test
2022-07-11 13:10:03 -07:00
jax authors
183b9e4503
Merge pull request #11397 from jakevdp:diagonal-err
...
PiperOrigin-RevId: 460236416
2022-07-11 09:52:19 -07:00
jax authors
cc42c8091d
Merge pull request #11406 from jakevdp:bcoo-add-batchdim
...
PiperOrigin-RevId: 460226570
2022-07-11 09:06:42 -07:00
jax authors
df74907257
Merge pull request #11440 from google:random-docstring-fix
...
PiperOrigin-RevId: 460220463
2022-07-11 08:36:00 -07:00
jax authors
e19d026b3a
Merge pull request #11442 from hawkinsp:shards
...
PiperOrigin-RevId: 460211938
2022-07-11 07:52:06 -07:00
Peter Hawkins
64e0b5d801
Increase bazel sharding of GPU tests.
...
Reduces the maximum time for some test shards to avoid flaky timeouts.
2022-07-11 14:19:43 +00:00
James Bradbury
64eb46a172
Fix RST formatting in random.py docstring
2022-07-11 02:51:35 -07:00
Sharad Vikram
b666f665ec
Rollback of HCB GPU custom call due to internal failures
...
PiperOrigin-RevId: 460079787
2022-07-10 13:05:27 -07:00
Yash Katariya
5910cdc861
Print the repr of device
...
PiperOrigin-RevId: 459986696
2022-07-09 17:08:44 -07:00
jax authors
ed51c65576
Merge pull request #11405 from mattjj:djax-vmap
...
PiperOrigin-RevId: 459958155
2022-07-09 10:38:39 -07:00
Matthew Johnson
5b82ba787c
[dynamic-shapes] start basic vmap compatibility
2022-07-09 10:03:40 -07:00
Yash Katariya
09ba51f323
Move _get_array_mapping from gda.py to pxla.py
...
PiperOrigin-RevId: 459891853
2022-07-08 21:38:06 -07:00
jax authors
df993ea32f
Merge pull request #11410 from sharadmv:for-loop
...
PiperOrigin-RevId: 459879694
2022-07-08 19:37:57 -07:00
Sharad Vikram
bff71b2c4f
Add loop-invariant residual optimization for for
2022-07-08 18:54:51 -07:00
jax authors
66ab792fc0
Merge pull request #11383 from YouJiacheng:Enable-HCB-customCall-implementation-on-GPU
...
PiperOrigin-RevId: 459872063
2022-07-08 18:23:16 -07:00
Anish Tondwalkar
847c04fd8c
[mhlo] CosOp -> CosineOp
...
Aligns the op class name with the mnemonic
PiperOrigin-RevId: 459852934
2022-07-08 16:04:14 -07:00
Anish Tondwalkar
5f7018a62e
[mhlo] SinOp -> SineOp
...
Aligns the op class name with the mnemonic
PiperOrigin-RevId: 459830783
2022-07-08 14:08:12 -07:00
jax authors
1bb1fe0658
Remove workaround for rank-0 zarr chunk layout bug in TensorStore
...
This has now been fixed in TensorStore.
PiperOrigin-RevId: 459824051
2022-07-08 13:35:29 -07:00
jax authors
dac310c221
Merge pull request #11421 from jakevdp:scalar-meta-nocopy
...
PiperOrigin-RevId: 459823335
2022-07-08 13:30:20 -07:00
Yash Katariya
bb2c5f111a
Resolve TODOs and add some more checks for the jax.Array path.
...
PiperOrigin-RevId: 459808511
2022-07-08 12:19:19 -07:00
Anish Tondwalkar
a2f2d1fa42
[mhlo] ConvOp -> ConvolutionOp
...
Aligns the op class name with the mnemonic
PiperOrigin-RevId: 459808502
2022-07-08 12:13:51 -07:00
YouJiacheng
7c707832aa
Enable CustomCall implementation on GPU
2022-07-09 02:29:08 +08:00
Jake VanderPlas
e19df1a9bf
Use asarray rather than array in ScalarMeta
...
Why? This will make it so that jnp.int32(x) and friends no longer insert
a gratuitous copy_p operation in the jaxpr
2022-07-08 11:16:40 -07:00
jax authors
5285a15de1
Merge pull request #11419 from hawkinsp:jaxlibcleanup
...
PiperOrigin-RevId: 459780016
2022-07-08 10:04:39 -07:00
Yash Katariya
229ddecc45
* Remove AUTO from MeshPspecSharding and treat it like _UNSPECIFIED singleton value.
...
* Support partial mentions of AUTO which is supported by GDA currently and used in pax. Added tests for all of this.
* As a consequence of this, I lifted the restriction on not providing `in_axis_resources` to pjit under `config.jax_array`.
* Made all auto sharding tests parameterized to test both gda and array.
PiperOrigin-RevId: 459776152
2022-07-08 09:45:23 -07:00
Peter Hawkins
5a7bedca37
Increase shard_count for sparse_test_gpu to 20.
...
1918d39765
updated the wrong test!
This test is close to the timeout in the GPU CI and flakes sometimes.
PiperOrigin-RevId: 459762867
2022-07-08 08:30:26 -07:00
Peter Hawkins
41b015ab0c
Remove stale code from jax/_src/lib/__init__.py
...
Remove inaccurate/stale __all__.
Remove unused alias _xla_extension_version.
2022-07-08 11:09:58 -04:00
jax authors
928f22cb6b
Merge pull request #11418 from hawkinsp:bzl
...
PiperOrigin-RevId: 459754677
2022-07-08 07:38:13 -07:00
Peter Hawkins
a48f4e116e
Change Bazel test rules to generate per-backend test suites.
2022-07-08 14:19:05 +00:00
jax authors
55dcbec5b5
Merge pull request #11407 from hawkinsp:minver
...
PiperOrigin-RevId: 459740984
2022-07-08 06:04:47 -07:00
Tamara Norman
bc9c4b77d0
Adjust docs to account for what the actual current RNG behavior is
...
PiperOrigin-RevId: 459712928
2022-07-08 02:55:36 -07:00
jax authors
7ffedb5815
Merge pull request #11400 from jakevdp:deprecate-treeutil
...
PiperOrigin-RevId: 459681801
2022-07-07 23:05:35 -07:00
jax authors
34fea3d496
Merge pull request #11408 from hawkinsp:sparseshard
...
PiperOrigin-RevId: 459647331
2022-07-07 18:23:20 -07:00
Peter Hawkins
1918d39765
Increase number of shards for GPU sparse_test to 20.
2022-07-07 21:14:25 -04:00
jax authors
a3e8ae4b1a
Merge pull request #11388 from jakevdp:fix-bool-weak
...
PiperOrigin-RevId: 459641851
2022-07-07 17:43:17 -07:00
Peter Hawkins
0b4b0ba072
Update minimum jaxlib version to 0.3.14.
2022-07-08 00:36:02 +00:00
Jake VanderPlas
adcf30ef6b
[sparse] remove deprecated bcoo_add_batch_dim utility
2022-07-07 16:57:36 -07:00
jax authors
44bd311ae7
Merge pull request #11403 from jakevdp:sparse-unary
...
PiperOrigin-RevId: 459634024
2022-07-07 16:57:23 -07:00
Jake VanderPlas
17de5e4840
jnp.diagonal: raise explicit error if ndim < 2
2022-07-07 16:36:40 -07:00
Jake VanderPlas
56d61d3f2d
BUG: ensure that boolean scalars are never marked weak
2022-07-07 15:41:23 -07:00