18626 Commits

Author SHA1 Message Date
Yash Katariya
b52bcc1639 Reverts 3c07c10a9a55f9a32390dd10cf3f420bdf3f1ed8
PiperOrigin-RevId: 590700623
jax-v0.4.22 jaxlib-v0.4.22 jax-v0.4.22-rc2
2023-12-13 13:45:14 -08:00
Yash Katariya
e8888065c0 Fix the pjit flakey test. The test was weirdly written in the first place. The current suspicion is that x.copy() inside jit made it flakey. Also delete some duplicate tests from the time of migrating to jax.Array
PiperOrigin-RevId: 590692653
2023-12-13 13:15:30 -08:00
Yash Katariya
2e633522a0 Start jax and jaxlib 0.4.22 release
PiperOrigin-RevId: 590686003
jax-v0.4.22-rc
2023-12-13 12:57:23 -08:00
Yash Katariya
3c07c10a9a Remove the `jax_require_devices_during_lowering flag since it was temporary. Added the semi-breaking change to Changelog.md.
PiperOrigin-RevId: 590684939
2023-12-13 12:48:48 -08:00
jax authors
0281596ec9 Merge pull request #18869 from 8bitmp3:small-fix-jax-pytrees
PiperOrigin-RevId: 590680938
2023-12-13 12:40:06 -08:00
jax authors
d269c581f0 Merge pull request #18582 from 8bitmp3:jax-docs-debugging-101
PiperOrigin-RevId: 590680437
2023-12-13 12:31:58 -08:00
jax authors
4459991d55 Merge pull request #18961 from mattjj:issue18955
PiperOrigin-RevId: 590655131
2023-12-13 11:04:24 -08:00
jax authors
196c97fa0c Merge pull request #18949 from froystig:seed-offset
PiperOrigin-RevId: 590637382
2023-12-13 10:18:40 -08:00
jax authors
5104c6b098 Merge pull request #18951 from mattjj:shmap-varargs-error
PiperOrigin-RevId: 590636629
2023-12-13 10:10:18 -08:00
Matthew Johnson
4ba6bd5108 [shard-map] register cumsum et al with generic rules
fixes #18955
2023-12-13 09:54:01 -08:00
Matthew Johnson
2bff2f4094 [shard-map] fix varargs error message bug
see #18823

Co-authored-by: Chase Roberts <chaser@nvidia.com>
2023-12-13 09:40:39 -08:00
jax authors
77c08b4ada Update XLA dependency to use revision
a487d8ba5d.

PiperOrigin-RevId: 590595972
2023-12-13 07:49:45 -08:00
jax authors
fff7300be4 Update XLA dependency to use revision
a5335a8045.

PiperOrigin-RevId: 590567887
2023-12-13 05:50:23 -08:00
jax authors
c3dba36814 Merge pull request #18933 from superbobry:pyupgrade
PiperOrigin-RevId: 590554955
2023-12-13 04:58:50 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
jax authors
4991aac30f Import new version of Triton
PiperOrigin-RevId: 590496513
2023-12-13 01:24:24 -08:00
jax authors
4183f29cd1 Update XLA dependency to use revision
84d6deeb74.

PiperOrigin-RevId: 590456513
2023-12-12 22:34:55 -08:00
Roy Frostig
671790730e introduce a config flag to control a random seed offset 2023-12-12 18:31:07 -08:00
8bitmp3
7bb4b0fab4 Update working-with-pytrees.md 2023-12-13 00:40:42 +00:00
jax authors
32c99f627e Remove the old cache-key generation code.
We have switched to the new cache-key generation code and
it is stable. Clean up the old code.

Note: since we are still falling back to hashing devices +
platform is the PjRtTopologyDescription serialization has not
been implemented by a backend, we retain those for now.

Testing: test workload.
PiperOrigin-RevId: 590378036
2023-12-12 16:34:32 -08:00
8bitmp3
14407b9a06 Upgrade JAX debugging doc 2023-12-13 00:19:41 +00:00
jax authors
05df8750ce Merge pull request #18946 from jakevdp:dynamic-shape-error
PiperOrigin-RevId: 590343483
2023-12-12 14:48:54 -08:00
Jieying Luo
4fe9e59644 Add jax[cuda12] install variation for using cuda plugin.
PiperOrigin-RevId: 590342149
2023-12-12 14:40:54 -08:00
Jieying Luo
7305b64fa6 Add cuda to _nonexperimental_plugins. It passed existing GPU tests, and have set up continuous testing job.
PiperOrigin-RevId: 590328635
2023-12-12 14:08:48 -08:00
Roy Frostig
8dcc079dd8 increase sample sizes on noisy random tests
... revealed by the partitionable Threefry upgrade; see
https://github.com/google/jax/discussions/18480

PiperOrigin-RevId: 590327421
2023-12-12 14:00:03 -08:00
Jake VanderPlas
a1ee8c1743 Improve shape validation when jax_dynamic_shapes=True 2023-12-12 13:58:46 -08:00
Yash Katariya
f210b0f95a Add a flag jax_require_devices_during_lowering to control if physical devices are passed during lowering to stablehlo. This is temporary to unblock nvidia.
PiperOrigin-RevId: 590318918
2023-12-12 13:34:09 -08:00
jax authors
1851447b3c Merge pull request #18928 from jakevdp:jnp-linalg
PiperOrigin-RevId: 590302414
2023-12-12 12:49:30 -08:00
Jake VanderPlas
fe2ad89209 array api: add jnp.linalg.cross & jnp.linalg.outer 2023-12-12 11:22:31 -08:00
jax authors
2fa90e1d43 Merge pull request #18943 from jakevdp:xla-bridge
PiperOrigin-RevId: 590269567
2023-12-12 11:21:24 -08:00
jax authors
616f4d29bb Merge pull request #18888 from superbobry:pp-improvement
PiperOrigin-RevId: 590269555
2023-12-12 11:12:42 -08:00
Jake VanderPlas
796449e331 Add missing __future__ import 2023-12-12 10:31:52 -08:00
Jieying Luo
8128b54e15 Increase test timeout for Cloud TPU nightly CI.
Recent runs got timeout when it is close to finish. https://github.com/google/jax/actions/runs/7182549669/job/19559323926

PiperOrigin-RevId: 590253735
2023-12-12 10:30:09 -08:00
jax authors
94d58b7270 mesh_utils.create_hybrid_device_mesh: make sorting granules by key user configurable.
When sorting by granule key is disabled, the granules are used to create the mesh in the order in which they appear in the sequence of devices.

PiperOrigin-RevId: 590228169
2023-12-12 09:16:41 -08:00
Sergei Lebedev
840abfb7ab The pretty printer now de-duplicates identical jaxprs
This compresses the output e.g. when a jitted function is called repeatedly
in a Python loop.
2023-12-12 17:14:43 +00:00
George Necula
b077483bfa [export] Add support for serialization and deserialization of Exported
At the moment we can export a JAX function into an Exported and we can invoke an Exported from another JAX function, but there is no way to serialize an Exported to be able to use it in another process.

Exported now has all the features we had in mind, so it is a reasonable time to add a serialization method. The intention is for the serialization to have backwards compatibility guarantees, meaning that we can deserialize an Exported that has been serialized with an older version of JAX. This PR does not add explicit support for versioning, nor backwards compatibility tests. Those will follow.

Here we add serialization and deserialization to bytearray, using the flatbuffers package. We use flatbuffers because it is simple, it has backwards and forwards compatibility guarantees, it is a lightweight dependency that does not require additional build steps, and it is fast (deserialization simply indexes into the bytearray rather than creating a Python structure).

In the process of implementing this we have done some small cleanup of the Exported structure:

  * renamed serialization_version to mlir_module_serialization_version
  * renamed disabled_checks to disabled_safety_checks

This code is tested by changing export_test.py to interpose a serialization followed by a deserialization every time we export.export.

There is a known bug with the serialization of effects, so I disabled one of the export tests. Will fix in a subsequent PR.

PiperOrigin-RevId: 590078785
2023-12-11 23:23:02 -08:00
David Majnemer
4347950d9d Internal only changes.
PiperOrigin-RevId: 590071197
2023-12-11 22:46:15 -08:00
jax authors
d08d2dcd2c Update XLA dependency to use revision
0ae69946cf.

PiperOrigin-RevId: 590063667
2023-12-11 22:11:05 -08:00
jax authors
b225c86f10 Merge pull request #18262 from jakevdp:key-reuse-jaxpr
PiperOrigin-RevId: 589913404
2023-12-11 12:46:55 -08:00
jax authors
3b2902d00e Merge pull request #18925 from gnecula:drop_enable_xla_tests
PiperOrigin-RevId: 589910325
2023-12-11 12:38:52 -08:00
Kevin Gleason
184e3a8800 Integrate StableHLO at openxla/stablehlo@ab709fe4
PiperOrigin-RevId: 589908773
2023-12-11 12:30:50 -08:00
Peter Hawkins
384e29e30d [XLA:Python] Add support for explicitly creating the gloo tcp context.
Pass the context to the CPU client explicitly.

PiperOrigin-RevId: 589898821
2023-12-11 12:05:51 -08:00
Jake VanderPlas
a52d18781e Add experimental static key reuse checking 2023-12-11 12:03:48 -08:00
Jevin Jiang
3651d4c4f5 [XLA:Mosaic] Support tpu.bitcast for i16, i8.
PiperOrigin-RevId: 589881484
2023-12-11 11:14:16 -08:00
jax authors
fc1b4b0496 Merge pull request #18924 from jakevdp:ci-ratchet
PiperOrigin-RevId: 589879388
2023-12-11 11:06:08 -08:00
George Necula
f4a76bd7a9 Remove enable_xla=False test harnesses from export test
This test uses native serialization and thus enable_xla=False
is irrelevant. This drops 2000+ tests.
2023-12-11 20:02:50 +01:00
Jake VanderPlas
b1e9afcaa6 CI: use ratchet to pin actions commits 2023-12-11 10:56:16 -08:00
jax authors
a721ae2a69 Merge pull request #18923 from jakevdp:ci-setup-python
PiperOrigin-RevId: 589869661
2023-12-11 10:36:58 -08:00
George Necula
b5f852f7a1 [jax2tf] Remove backwards compatibility shims (jax2tf.shape_poly)
A while ago we moved the native exporting code out of jax2tf, to
jax.experimental.export, but we left behind shape_poly.py shim for
backwards compatibility. Now we remove this shim.

Replace previous

  from jax.experimental.jax2tf import shape_poly

with

  from jax.experimental.export import shape_poly

PiperOrigin-RevId: 589850402
2023-12-11 09:34:44 -08:00
Jake VanderPlas
bdd8318e36 CI: update actions/setup-python to v5 2023-12-11 09:23:30 -08:00