Yash Katariya
1c68577dcd
Prepare for emergency jax and jaxlib 0.4.23 release
...
PiperOrigin-RevId: 590780824
jax-v0.4.23
jax-v0.4.23-rc
2023-12-13 19:02:24 -08:00
Peter Hawkins
b392622647
Add patch to suppress XLA:GPU logging.
...
PiperOrigin-RevId: 590780227
2023-12-13 18:53:50 -08:00
Yash Katariya
25c16c0b78
Finish jax and jaxlib 0.4.22 release
...
PiperOrigin-RevId: 590775311
2023-12-13 18:26:47 -08:00
Roy Frostig
3380b9feee
split the random generalized normal test and skip its K-S half
...
It is key-sensitive and sometimes slow.
PiperOrigin-RevId: 590756597
2023-12-13 17:01:19 -08:00
jax authors
9198174f63
Merge pull request #18968 from 8bitmp3:update-thinking-in-jax
...
PiperOrigin-RevId: 590726455
2023-12-13 15:10:31 -08:00
8bitmp3
7edc80d635
Update thinking-in-jax working-with-pytrees
2023-12-13 22:57:46 +00:00
jax authors
29ed3cd426
Merge pull request #18581 from 8bitmp3:jax-docs-thinking-in-jax
...
PiperOrigin-RevId: 590701231
2023-12-13 13:53:09 -08:00
Yash Katariya
b52bcc1639
Reverts 3c07c10a9a55f9a32390dd10cf3f420bdf3f1ed8
...
PiperOrigin-RevId: 590700623
jax-v0.4.22
jaxlib-v0.4.22
jax-v0.4.22-rc2
2023-12-13 13:45:14 -08:00
8bitmp3
a4003bf0ae
Upgrade How to think in JAX
2023-12-13 21:39:17 +00:00
Yash Katariya
e8888065c0
Fix the pjit flakey test. The test was weirdly written in the first place. The current suspicion is that x.copy()
inside jit made it flakey. Also delete some duplicate tests from the time of migrating to jax.Array
...
PiperOrigin-RevId: 590692653
2023-12-13 13:15:30 -08:00
Yash Katariya
2e633522a0
Start jax and jaxlib 0.4.22 release
...
PiperOrigin-RevId: 590686003
jax-v0.4.22-rc
2023-12-13 12:57:23 -08:00
Yash Katariya
3c07c10a9a
Remove the `jax_require_devices_during_lowering flag since it was temporary. Added the semi-breaking change to Changelog.md.
...
PiperOrigin-RevId: 590684939
2023-12-13 12:48:48 -08:00
jax authors
0281596ec9
Merge pull request #18869 from 8bitmp3:small-fix-jax-pytrees
...
PiperOrigin-RevId: 590680938
2023-12-13 12:40:06 -08:00
jax authors
d269c581f0
Merge pull request #18582 from 8bitmp3:jax-docs-debugging-101
...
PiperOrigin-RevId: 590680437
2023-12-13 12:31:58 -08:00
jax authors
4459991d55
Merge pull request #18961 from mattjj:issue18955
...
PiperOrigin-RevId: 590655131
2023-12-13 11:04:24 -08:00
jax authors
196c97fa0c
Merge pull request #18949 from froystig:seed-offset
...
PiperOrigin-RevId: 590637382
2023-12-13 10:18:40 -08:00
jax authors
5104c6b098
Merge pull request #18951 from mattjj:shmap-varargs-error
...
PiperOrigin-RevId: 590636629
2023-12-13 10:10:18 -08:00
Matthew Johnson
4ba6bd5108
[shard-map] register cumsum et al with generic rules
...
fixes #18955
2023-12-13 09:54:01 -08:00
Matthew Johnson
2bff2f4094
[shard-map] fix varargs error message bug
...
see #18823
Co-authored-by: Chase Roberts <chaser@nvidia.com>
2023-12-13 09:40:39 -08:00
jax authors
77c08b4ada
Update XLA dependency to use revision
...
a487d8ba5d
.
PiperOrigin-RevId: 590595972
2023-12-13 07:49:45 -08:00
jax authors
fff7300be4
Update XLA dependency to use revision
...
a5335a8045
.
PiperOrigin-RevId: 590567887
2023-12-13 05:50:23 -08:00
jax authors
c3dba36814
Merge pull request #18933 from superbobry:pyupgrade
...
PiperOrigin-RevId: 590554955
2023-12-13 04:58:50 -08:00
Sergei Lebedev
f936613b06
Upgrade remaining sources to Python 3.9
...
This PR is a follow up to #18881 .
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
jax authors
4991aac30f
Import new version of Triton
...
PiperOrigin-RevId: 590496513
2023-12-13 01:24:24 -08:00
jax authors
4183f29cd1
Update XLA dependency to use revision
...
84d6deeb74
.
PiperOrigin-RevId: 590456513
2023-12-12 22:34:55 -08:00
Roy Frostig
671790730e
introduce a config flag to control a random seed offset
2023-12-12 18:31:07 -08:00
8bitmp3
7bb4b0fab4
Update working-with-pytrees.md
2023-12-13 00:40:42 +00:00
jax authors
32c99f627e
Remove the old cache-key generation code.
...
We have switched to the new cache-key generation code and
it is stable. Clean up the old code.
Note: since we are still falling back to hashing devices +
platform is the PjRtTopologyDescription serialization has not
been implemented by a backend, we retain those for now.
Testing: test workload.
PiperOrigin-RevId: 590378036
2023-12-12 16:34:32 -08:00
8bitmp3
14407b9a06
Upgrade JAX debugging doc
2023-12-13 00:19:41 +00:00
jax authors
05df8750ce
Merge pull request #18946 from jakevdp:dynamic-shape-error
...
PiperOrigin-RevId: 590343483
2023-12-12 14:48:54 -08:00
Jieying Luo
4fe9e59644
Add jax[cuda12] install variation for using cuda plugin.
...
PiperOrigin-RevId: 590342149
2023-12-12 14:40:54 -08:00
Jieying Luo
7305b64fa6
Add cuda to _nonexperimental_plugins. It passed existing GPU tests, and have set up continuous testing job.
...
PiperOrigin-RevId: 590328635
2023-12-12 14:08:48 -08:00
Roy Frostig
8dcc079dd8
increase sample sizes on noisy random tests
...
... revealed by the partitionable Threefry upgrade; see
https://github.com/google/jax/discussions/18480
PiperOrigin-RevId: 590327421
2023-12-12 14:00:03 -08:00
Jake VanderPlas
a1ee8c1743
Improve shape validation when jax_dynamic_shapes=True
2023-12-12 13:58:46 -08:00
Yash Katariya
f210b0f95a
Add a flag jax_require_devices_during_lowering
to control if physical devices are passed during lowering to stablehlo. This is temporary to unblock nvidia.
...
PiperOrigin-RevId: 590318918
2023-12-12 13:34:09 -08:00
jax authors
1851447b3c
Merge pull request #18928 from jakevdp:jnp-linalg
...
PiperOrigin-RevId: 590302414
2023-12-12 12:49:30 -08:00
Jake VanderPlas
fe2ad89209
array api: add jnp.linalg.cross & jnp.linalg.outer
2023-12-12 11:22:31 -08:00
jax authors
2fa90e1d43
Merge pull request #18943 from jakevdp:xla-bridge
...
PiperOrigin-RevId: 590269567
2023-12-12 11:21:24 -08:00
jax authors
616f4d29bb
Merge pull request #18888 from superbobry:pp-improvement
...
PiperOrigin-RevId: 590269555
2023-12-12 11:12:42 -08:00
Jake VanderPlas
796449e331
Add missing __future__ import
2023-12-12 10:31:52 -08:00
Jieying Luo
8128b54e15
Increase test timeout for Cloud TPU nightly CI.
...
Recent runs got timeout when it is close to finish. https://github.com/google/jax/actions/runs/7182549669/job/19559323926
PiperOrigin-RevId: 590253735
2023-12-12 10:30:09 -08:00
jax authors
94d58b7270
mesh_utils.create_hybrid_device_mesh
: make sorting granules by key user configurable.
...
When sorting by granule key is disabled, the granules are used to create the mesh in the order in which they appear in the sequence of devices.
PiperOrigin-RevId: 590228169
2023-12-12 09:16:41 -08:00
Sergei Lebedev
840abfb7ab
The pretty printer now de-duplicates identical jaxprs
...
This compresses the output e.g. when a jitted function is called repeatedly
in a Python loop.
2023-12-12 17:14:43 +00:00
George Necula
b077483bfa
[export] Add support for serialization and deserialization of Exported
...
At the moment we can export a JAX function into an Exported and we can invoke an Exported from another JAX function, but there is no way to serialize an Exported to be able to use it in another process.
Exported now has all the features we had in mind, so it is a reasonable time to add a serialization method. The intention is for the serialization to have backwards compatibility guarantees, meaning that we can deserialize an Exported that has been serialized with an older version of JAX. This PR does not add explicit support for versioning, nor backwards compatibility tests. Those will follow.
Here we add serialization and deserialization to bytearray, using the flatbuffers package. We use flatbuffers because it is simple, it has backwards and forwards compatibility guarantees, it is a lightweight dependency that does not require additional build steps, and it is fast (deserialization simply indexes into the bytearray rather than creating a Python structure).
In the process of implementing this we have done some small cleanup of the Exported structure:
* renamed serialization_version to mlir_module_serialization_version
* renamed disabled_checks to disabled_safety_checks
This code is tested by changing export_test.py to interpose a serialization followed by a deserialization every time we export.export.
There is a known bug with the serialization of effects, so I disabled one of the export tests. Will fix in a subsequent PR.
PiperOrigin-RevId: 590078785
2023-12-11 23:23:02 -08:00
David Majnemer
4347950d9d
Internal only changes.
...
PiperOrigin-RevId: 590071197
2023-12-11 22:46:15 -08:00
jax authors
d08d2dcd2c
Update XLA dependency to use revision
...
0ae69946cf
.
PiperOrigin-RevId: 590063667
2023-12-11 22:11:05 -08:00
jax authors
b225c86f10
Merge pull request #18262 from jakevdp:key-reuse-jaxpr
...
PiperOrigin-RevId: 589913404
2023-12-11 12:46:55 -08:00
jax authors
3b2902d00e
Merge pull request #18925 from gnecula:drop_enable_xla_tests
...
PiperOrigin-RevId: 589910325
2023-12-11 12:38:52 -08:00
Kevin Gleason
184e3a8800
Integrate StableHLO at openxla/stablehlo@ab709fe4
...
PiperOrigin-RevId: 589908773
2023-12-11 12:30:50 -08:00
Peter Hawkins
384e29e30d
[XLA:Python] Add support for explicitly creating the gloo tcp context.
...
Pass the context to the CPU client explicitly.
PiperOrigin-RevId: 589898821
2023-12-11 12:05:51 -08:00