11814 Commits

Author SHA1 Message Date
Yash Katariya
b52bcc1639 Reverts 3c07c10a9a55f9a32390dd10cf3f420bdf3f1ed8
PiperOrigin-RevId: 590700623
2023-12-13 13:45:14 -08:00
Yash Katariya
3c07c10a9a Remove the `jax_require_devices_during_lowering flag since it was temporary. Added the semi-breaking change to Changelog.md.
PiperOrigin-RevId: 590684939
2023-12-13 12:48:48 -08:00
jax authors
4459991d55 Merge pull request #18961 from mattjj:issue18955
PiperOrigin-RevId: 590655131
2023-12-13 11:04:24 -08:00
jax authors
196c97fa0c Merge pull request #18949 from froystig:seed-offset
PiperOrigin-RevId: 590637382
2023-12-13 10:18:40 -08:00
jax authors
5104c6b098 Merge pull request #18951 from mattjj:shmap-varargs-error
PiperOrigin-RevId: 590636629
2023-12-13 10:10:18 -08:00
Matthew Johnson
4ba6bd5108 [shard-map] register cumsum et al with generic rules
fixes #18955
2023-12-13 09:54:01 -08:00
Matthew Johnson
2bff2f4094 [shard-map] fix varargs error message bug
see #18823

Co-authored-by: Chase Roberts <chaser@nvidia.com>
2023-12-13 09:40:39 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
jax authors
4991aac30f Import new version of Triton
PiperOrigin-RevId: 590496513
2023-12-13 01:24:24 -08:00
Roy Frostig
671790730e introduce a config flag to control a random seed offset 2023-12-12 18:31:07 -08:00
jax authors
32c99f627e Remove the old cache-key generation code.
We have switched to the new cache-key generation code and
it is stable. Clean up the old code.

Note: since we are still falling back to hashing devices +
platform is the PjRtTopologyDescription serialization has not
been implemented by a backend, we retain those for now.

Testing: test workload.
PiperOrigin-RevId: 590378036
2023-12-12 16:34:32 -08:00
jax authors
05df8750ce Merge pull request #18946 from jakevdp:dynamic-shape-error
PiperOrigin-RevId: 590343483
2023-12-12 14:48:54 -08:00
Jieying Luo
7305b64fa6 Add cuda to _nonexperimental_plugins. It passed existing GPU tests, and have set up continuous testing job.
PiperOrigin-RevId: 590328635
2023-12-12 14:08:48 -08:00
Jake VanderPlas
a1ee8c1743 Improve shape validation when jax_dynamic_shapes=True 2023-12-12 13:58:46 -08:00
Yash Katariya
f210b0f95a Add a flag jax_require_devices_during_lowering to control if physical devices are passed during lowering to stablehlo. This is temporary to unblock nvidia.
PiperOrigin-RevId: 590318918
2023-12-12 13:34:09 -08:00
Jake VanderPlas
fe2ad89209 array api: add jnp.linalg.cross & jnp.linalg.outer 2023-12-12 11:22:31 -08:00
jax authors
2fa90e1d43 Merge pull request #18943 from jakevdp:xla-bridge
PiperOrigin-RevId: 590269567
2023-12-12 11:21:24 -08:00
jax authors
616f4d29bb Merge pull request #18888 from superbobry:pp-improvement
PiperOrigin-RevId: 590269555
2023-12-12 11:12:42 -08:00
Jake VanderPlas
796449e331 Add missing __future__ import 2023-12-12 10:31:52 -08:00
jax authors
94d58b7270 mesh_utils.create_hybrid_device_mesh: make sorting granules by key user configurable.
When sorting by granule key is disabled, the granules are used to create the mesh in the order in which they appear in the sequence of devices.

PiperOrigin-RevId: 590228169
2023-12-12 09:16:41 -08:00
Sergei Lebedev
840abfb7ab The pretty printer now de-duplicates identical jaxprs
This compresses the output e.g. when a jitted function is called repeatedly
in a Python loop.
2023-12-12 17:14:43 +00:00
George Necula
b077483bfa [export] Add support for serialization and deserialization of Exported
At the moment we can export a JAX function into an Exported and we can invoke an Exported from another JAX function, but there is no way to serialize an Exported to be able to use it in another process.

Exported now has all the features we had in mind, so it is a reasonable time to add a serialization method. The intention is for the serialization to have backwards compatibility guarantees, meaning that we can deserialize an Exported that has been serialized with an older version of JAX. This PR does not add explicit support for versioning, nor backwards compatibility tests. Those will follow.

Here we add serialization and deserialization to bytearray, using the flatbuffers package. We use flatbuffers because it is simple, it has backwards and forwards compatibility guarantees, it is a lightweight dependency that does not require additional build steps, and it is fast (deserialization simply indexes into the bytearray rather than creating a Python structure).

In the process of implementing this we have done some small cleanup of the Exported structure:

  * renamed serialization_version to mlir_module_serialization_version
  * renamed disabled_checks to disabled_safety_checks

This code is tested by changing export_test.py to interpose a serialization followed by a deserialization every time we export.export.

There is a known bug with the serialization of effects, so I disabled one of the export tests. Will fix in a subsequent PR.

PiperOrigin-RevId: 590078785
2023-12-11 23:23:02 -08:00
jax authors
b225c86f10 Merge pull request #18262 from jakevdp:key-reuse-jaxpr
PiperOrigin-RevId: 589913404
2023-12-11 12:46:55 -08:00
Kevin Gleason
184e3a8800 Integrate StableHLO at openxla/stablehlo@ab709fe4
PiperOrigin-RevId: 589908773
2023-12-11 12:30:50 -08:00
Peter Hawkins
384e29e30d [XLA:Python] Add support for explicitly creating the gloo tcp context.
Pass the context to the CPU client explicitly.

PiperOrigin-RevId: 589898821
2023-12-11 12:05:51 -08:00
Jake VanderPlas
a52d18781e Add experimental static key reuse checking 2023-12-11 12:03:48 -08:00
George Necula
b5f852f7a1 [jax2tf] Remove backwards compatibility shims (jax2tf.shape_poly)
A while ago we moved the native exporting code out of jax2tf, to
jax.experimental.export, but we left behind shape_poly.py shim for
backwards compatibility. Now we remove this shim.

Replace previous

  from jax.experimental.jax2tf import shape_poly

with

  from jax.experimental.export import shape_poly

PiperOrigin-RevId: 589850402
2023-12-11 09:34:44 -08:00
jax authors
9d0a9918b5 Add unroll=True to all the calls of fori_loop
PiperOrigin-RevId: 589842015
2023-12-11 09:04:45 -08:00
Sergei Lebedev
352e10ed68 Effects is now an immutable set
This allows safely using `no_effects` as a default value.

PiperOrigin-RevId: 589836905
2023-12-11 08:45:52 -08:00
Matthew Johnson
25eb913d10 don't call lax.xeinsum from jnp.einsum when str contains '{'
can still call lax.xeinsum directly
2023-12-09 11:11:31 -08:00
Matthew Johnson
9a1a09c28b remove _use_xeinsum from jnp.einsum api
can still call jnp.einsum with a '{' in the spec string to trigger xeinsum, or
just call lax.xeinsum directly
2023-12-09 10:53:22 -08:00
George Necula
e50ef1b383 [jax2tf] Remove backwards compatibility shims (jax2tf.jax_export)
A while ago we moved the native exporting code out of jax2tf, to
jax.experimental.export, but we left behind jax_export.py shim for
backwards compatibility. Now we remove this shim.

Replace previous

  from jax.experimental.jax2tf import jax_export

with

  from jax.experimental.export import export

PiperOrigin-RevId: 589409387
2023-12-09 07:53:34 -08:00
jax authors
709564ab78 Move jit to the callsite.
PiperOrigin-RevId: 589328135
2023-12-08 22:19:56 -08:00
Yash Katariya
10f6a35f83 Add a registry for primitives that require device_assignment during lowering
PiperOrigin-RevId: 589272990
2023-12-08 16:31:41 -08:00
Yash Katariya
5fb8ceca73 Make lowering oblivious to real physical devices. Instead cache lowering on HloSharding only (which is based on logical device numbers)
Make an exception for callbacks and custom_partitioning because they need access to device_assignment during lowering.

PiperOrigin-RevId: 589244695
2023-12-08 14:36:09 -08:00
jax authors
809a37c567 Merge pull request #18881 from superbobry:pyupgrade
PiperOrigin-RevId: 589191161
2023-12-08 11:20:50 -08:00
jax authors
e686ed7e93 Merge pull request #18870 from jakevdp:array-api-tests
PiperOrigin-RevId: 589184048
2023-12-08 10:58:23 -08:00
jax authors
3976c00f58 Merge pull request #18885 from mattjj:defaultdict-tree-flatten-order
PiperOrigin-RevId: 589179088
2023-12-08 10:42:27 -08:00
jax authors
d3f4bbfdd0 Fix cache_used metric implementation and test.
The cache_used metric is incremented once per task and is
used to determine how many tasks are using the Jax
compilation cache. The current implementation and unit
test are not thread safe. This results in the test
failing when unit tests are executed concurrently.

The fix is to make the implementation thread safe and
to update the test to examine the delta in the metric.

Testing: Cloud TPU VM testing; test workload.
PiperOrigin-RevId: 589174850
2023-12-08 10:30:22 -08:00
Matthew Johnson
d2fcf27f93 must flatten defaultdict in key-sorted order, like regular dicts 2023-12-08 10:10:09 -08:00
jax authors
d874435dc1 Merge pull request #18841 from DEKHTIARJonathan:patch-1
PiperOrigin-RevId: 589150737
2023-12-08 09:15:01 -08:00
Jake VanderPlas
4b1077da09 array-api: update test suite & fix nonzero 2023-12-08 08:55:57 -08:00
Sergei Lebedev
36f6b52e42 Upgrade most .py sources to 3.9
This commit was generated by running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00
Sharad Vikram
7af1c149f5 [Pallas/Mosaic] Lower lax.fori_loops to *rolled* loops.
Note that this is a breaking change! Current uses of `lax.fori_loop` inside of kernels
should instead pass `unroll=True` (loops were being unrolled by default and
we are switching that with this change).

PiperOrigin-RevId: 589017485
2023-12-07 22:55:12 -08:00
Parker Schuh
ffb115bf2e Add annotation for donation. This configures XLA's AddBufferDonor directly via donated_args instead of first configuring input_output_aliases. This is best effort anyways, but it works on TPU.
PiperOrigin-RevId: 588974683
2023-12-07 18:46:30 -08:00
jax authors
1189d61bc0 [Pallas] Fix batching rule for kernels with scratch inputs
Scratch inputs do not need a batching dimension.

PiperOrigin-RevId: 588921137
2023-12-07 15:10:12 -08:00
jax authors
e423347dda Declare magic port number for jax.distributed.initialize in cloud TPU environments.
PiperOrigin-RevId: 588920806
2023-12-07 15:02:04 -08:00
jax authors
c2f8e18016 Merge pull request #18862 from superbobry:pp-improvement
PiperOrigin-RevId: 588906413
2023-12-07 14:15:33 -08:00
jax authors
c4239fc05e Merge pull request #18797 from jakevdp:factorial
PiperOrigin-RevId: 588809589
2023-12-07 08:59:57 -08:00
Sergei Lebedev
ea158d3109 Print pjit name= before other params
The jaxpr sometimes gets pretty big, making it hard to see the name.
2023-12-07 16:54:07 +00:00