121 Commits

Author SHA1 Message Date
Yash Katariya
00d8297071 [sharding_in_types] Set the sharding_in_types config to True. This is a purely internal change and shouldn't affect any public APIs.
Some caveats of enabling sharding-in-types by default: we'll see tracing cache misses, which in turn lead to lowering and compilation cache misses, in the **following cases** (the persistent compilation cache is not affected, so we'll still see a cache hit there):

1. Call `jitted_f(arr_ns)` with an array on `NamedSharding` and again `jitted_f(arr_ps)` with an array of same shape and dtype but now with `PositionalSharding`
    * This leads to a tracing cache miss because on the second call the aval has no sharding, since it's a `PositionalSharding`. The same applies to calling with any sharding other than `NamedSharding`.

2. `jitted_f = jit(f, in_shardings=ns)`. Call `jitted_f(sharded_arr)` and then on the second call you pass a numpy array `jitted_f(numpy_arr)`
   * This also leads to a cache miss because avals currently don't take `in_shardings` into account; the semantics of `in_shardings` are complicated, and I don't think we should change the aval based on them.

**The solution in both cases is to pass an array sharded on the same mesh during both calls to the jitted function.**

PiperOrigin-RevId: 728361493
2025-02-18 14:35:14 -08:00
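The first caveat above can be sketched as follows — a minimal illustration, not code from the change itself; the function `f`, the one-device mesh, and the axis name `'x'` are all assumptions:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a tiny mesh; a single CPU device is enough to illustrate.
mesh = Mesh(np.array(jax.devices()[:1]), ('x',))
ns = NamedSharding(mesh, P())  # replicated over the mesh

@jax.jit
def f(x):
    return x * 2

arr_ns = jax.device_put(jnp.arange(8.0), ns)
y1 = f(arr_ns)   # first call: traces and compiles
y2 = f(arr_ns)   # same shape/dtype/sharding: tracing cache hit

# Calling with an equivalent array under a different sharding type (or a
# bare numpy array, as in case 2) changes the aval's sharding and misses
# the tracing cache, even though the math is identical.
y3 = f(np.arange(8.0))
```

The fix suggested above amounts to keeping both call sites on arrays sharded over the same mesh, as with `arr_ns` here.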
Ayaka
b6361b3e76 Minor format cleanup
Remove 2 redundant whitespaces mentioned in https://github.com/jax-ml/jax/pull/25056#pullrequestreview-2615387492.

PiperOrigin-RevId: 727264168
2025-02-15 04:56:27 -08:00
chaserileyroberts
60f0184637 Added stream annotation support via @compute_on('gpu_stream:#') 2025-02-13 07:15:18 +00:00
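A hedged sketch of how this annotation might be applied. The stream id `1` and the function `f` are illustrative; the `gpu_stream` annotation only takes effect on the GPU backend, so this sketch falls back to a plain jit elsewhere:

```python
import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on

def f(x):
    return jnp.sin(x) * 2.0

if jax.default_backend() == 'gpu':
    # 'gpu_stream:1' is an illustrative stream id; on GPU this asks the
    # annotated computation to run on that stream.
    f = compute_on('gpu_stream:1')(f)

y = jax.jit(f)(jnp.arange(4.0))
```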
Bart Chrzaszcz
6ed4c29c8a #sdy enable test_mem_kind_donation_pinned_host for Shardy.
PiperOrigin-RevId: 725142884
2025-02-10 03:15:25 -08:00
Justin Fu
b6acb9cb7a Fix remat bug on primitives with multiple outputs.
Addresses https://github.com/jax-ml/jax/issues/25841

PiperOrigin-RevId: 715434084
2025-01-14 10:26:58 -08:00
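A minimal sketch of the code path being fixed — the function and the choice of `lax.top_k` as the multi-output primitive are illustrative, not taken from the fix itself:

```python
import jax
import jax.numpy as jnp

# lax.top_k is a primitive with multiple outputs (values and indices);
# wrapping it in jax.checkpoint (remat) exercises differentiation
# through a rematerialized multi-output primitive.
@jax.checkpoint
def f(x):
    vals, idxs = jax.lax.top_k(x, 2)
    return jnp.sum(vals)

x = jnp.array([1.0, 3.0, 2.0])
g = jax.grad(f)(x)  # gradient is 1.0 at the top-2 entries, 0.0 elsewhere
```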
Adam Paszke
ad00ec1dc9 [Mosaic TPU] Guard tests for new features by the libtpu version
PiperOrigin-RevId: 707875450
2024-12-19 05:04:09 -08:00
Peter Hawkins
62e66b684b Don't monkey-patch functions in test_utils to count events for tests.
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.

Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
2024-12-12 09:58:14 -05:00
Bart Chrzaszcz
6f69774c00 #sdy enable test_compute_offload_mesh_with_linear_layout for Shardy.
PiperOrigin-RevId: 704301465
2024-12-09 08:46:48 -08:00
Peter Hawkins
79318a08cf Remove dead code after minimum jaxlib version bump to v0.4.36.
New minimum xla_extension_version is 299, and the new mlir_api_version is 57.

PiperOrigin-RevId: 704280856
2024-12-09 07:35:05 -08:00
jax authors
f160df0444 More thorough propagation of host linear layout. Currently a linear layout on host
can only originate from the entry computation, and propagation only goes strictly down/up.
More needs to be done later if such a layout can originate from host compute itself.
Removed the temporary pattern-match solution.

PiperOrigin-RevId: 702966364
2024-12-04 21:06:34 -08:00
jax authors
4d60db1741 Add test_compute_on_host_shared_sharding in memories_test
PiperOrigin-RevId: 698250352
2024-11-19 21:33:27 -08:00
Yash Katariya
87ce0cbb00 Make GPU work with copy=True and device_put since same device pinned_host -> pinned_host copy is possible.
PiperOrigin-RevId: 694713334
2024-11-08 18:29:47 -08:00
Yash Katariya
0bb30f0777 Propagate CopySemantics from python to C++ transfer APIs so that device_put works correctly in presence of copy/donate options that user specified.
This change only supports pinned_host -> pinned_host copies on the same device. HBM -> HBM copies don't work yet and donation also doesn't work in PJRT.

This CL also sets up the plumbing from JAX to PJRT so that in the future support for missing features can be added easily.

Fixes https://github.com/jax-ml/jax/issues/24521

PiperOrigin-RevId: 694274616
2024-11-07 15:51:54 -08:00
Yash Katariya
6c8e56f43f Finish 0.4.35 release by removing dead code
PiperOrigin-RevId: 689396609
2024-10-24 08:45:43 -07:00
jax authors
dd5426301a Allow a simple host call that uses a host tensor as parameter/result in
linear layout. This CL only handles very simple host call patterns.
A more thorough implementation of propagation of T(1)S(5) will be done
later.

This CL doesn't handle host calls that pass/return tensors that
live on device with linear layout either, which will also be implemented
separately.

PiperOrigin-RevId: 687113203
2024-10-17 18:22:46 -07:00
Bart Chrzaszcz
fb32841b1b #sdy add JAX Shardy support for memories.
PiperOrigin-RevId: 684867097
2024-10-11 09:44:24 -07:00
Yash Katariya
ce2b49787f Skip test_ragged_copy_on_host if xla_extension_version < 290
PiperOrigin-RevId: 683326972
2024-10-07 14:30:34 -07:00
jax authors
e90487e906 Host Offloading: Process "MoveToHost" instructions in the order they are executed.
- This ensures we process "MoveToHost" instructions that reside at the beginning of a host memory instruction offload chain.
- This avoids processing "MoveToHost" instructions out of order, which would create invalid instructions within a host memory instruction offload chain.

PiperOrigin-RevId: 682448060
2024-10-04 14:17:36 -07:00
jax authors
291619c291 Allow custom call computations to contain subcomputations
PiperOrigin-RevId: 682429391
2024-10-04 13:22:14 -07:00
Peter Hawkins
d3f63a66b8 Remove code to support jaxlib <= 0.4.33. 2024-10-04 11:39:05 -04:00
jax authors
be76fb6abf Add host compute offload test: test_offload_take_host.
PiperOrigin-RevId: 682088063
2024-10-03 17:06:46 -07:00
Yash Katariya
203cda6f98 Move test_aot_device_implicit_transfer to pjit_test.py
This test is not specific to compute offload and is more relevant to pjit.

PiperOrigin-RevId: 680599882
2024-09-30 09:10:17 -07:00
Jane Liu
57bef447c6 Enable weight offloading tests that are supported on GPUs now 2024-09-26 23:26:27 -07:00
jax authors
5a1549cccf Merge pull request #23853 from zhenying-liu:remat-scan
PiperOrigin-RevId: 679365040
2024-09-26 18:12:30 -07:00
Jane Liu
adaf54a4bb enable the activation offloading test 2024-09-23 23:07:03 -07:00
Yash Katariya
8b5b71750b Fix jaxpr equation context propagation in jaxpr equations when inline=True.
PiperOrigin-RevId: 675754808
2024-09-17 16:40:36 -07:00
Yash Katariya
a144eb234b Add compute_on_context_manager to thread-local jit state. This is to avoid getting false cache hits.
PiperOrigin-RevId: 671507042
2024-09-05 14:16:13 -07:00
jax authors
97db78ba24 Adds test_compute_offload_with_donation in memories_test
PiperOrigin-RevId: 671410527
2024-09-05 09:57:59 -07:00
Yash Katariya
e1b497078e Rename jtu.create_global_mesh to jtu.create_mesh and use jax.make_mesh inside jtu.create_mesh to get maximum test coverage of the new API.
PiperOrigin-RevId: 670744047
2024-09-03 16:23:07 -07:00
Yash Katariya
969dd89040 Reverts changelist 668370165
PiperOrigin-RevId: 669670355
2024-08-31 08:44:23 -07:00
Yash Katariya
164b884f33 Fix failing tests in CI
PiperOrigin-RevId: 669357019
2024-08-30 09:49:58 -07:00
Yash Katariya
4533aeaf26 Remove jax_enable_memories conditionals from JAX and remove it from tests too.
PiperOrigin-RevId: 662322241
2024-08-12 19:15:43 -07:00
Yash Katariya
c08656c61d [Rollback] We still want to allow multiple meshes in the user program
Reverts dd958adc39550d2758ecdb13809c6d85df7658a2

PiperOrigin-RevId: 661537233
2024-08-09 23:17:46 -07:00
Yash Katariya
abc9ba00e9 Rename count_jit_and_pmap_compiles to count_jit_and_pmap_lowerings
PiperOrigin-RevId: 661496993
2024-08-09 20:03:43 -07:00
Jieying Luo
ccc27a7a5f Remove PJRT version check in memories_test.py that is no longer needed.
PJRT C API version 0.43 dates from February 2024. The Cloud TPU CI uses the 20240228 build, so it should contain the PJRT C API needed for the test: d3b6066f91/.github/workflows/cloud-tpu-ci-nightly.yml (L35).

PiperOrigin-RevId: 660869710
2024-08-08 09:35:41 -07:00
Yash Katariya
dd958adc39 Add mesh_shape to the lowering context. This is to allow custom partitioning to not depend on the mesh context manager to return NamedShardings even if the arguments have NamedShardings on them.
Since `shardy`, the sharding-in-types work, and world 2 dagger are all moving toward making Mesh and PartitionSpec first-class sharding types, let's pull the trigger right now and start fixing these bad user interactions.

Some things will break due to this change: before, passing a NamedSharding and an equivalent PositionalSharding to the same jitted function one after another led to a lowering cache hit, but now it will be a cache miss. In other words: `f(ns); f(ps)  # cache hit before, cache miss now`

In follow-up CLs, we will make the tracing cache aware of the mesh shape too, to fix some other issues related to tracing and lowering cache misses.

PiperOrigin-RevId: 660177423
2024-08-06 18:35:44 -07:00
jax authors
9074e8544f Add test for zero-sized host memory parameter
PiperOrigin-RevId: 660097039
2024-08-06 14:31:41 -07:00
Kanglan Tang
ae541203bc Skip flaky test_weight_offload_with_dp_on_output test on GPU backend.
PiperOrigin-RevId: 660057950
2024-08-06 12:35:53 -07:00
Yash Katariya
489fbc0ed5 Add a test for streaming in closed over constants from host to device
PiperOrigin-RevId: 659711557
2024-08-05 16:00:45 -07:00
Yash Katariya
e6851e6b22 Fix the AOT check for sharding consistency, which skipped checking the devices of the sharding.
Previously, for a TPU-compiled computation, a user could have passed in a committed array on CPU and JAX wouldn't have errored, which is wrong.

This change fixes that. `is_equivalent_to` should also check devices, HloSharding, and memory_kind (so the redundant `memory_kind` check is removed too).

PiperOrigin-RevId: 658794885
2024-08-02 08:15:32 -07:00
Yash Katariya
ec6514cc08 Add donation test for a pure compute offloaded computation
PiperOrigin-RevId: 658187714
2024-07-31 16:49:27 -07:00
Kanglan Tang
a7e071ec42 Skip flaky memories tests on GPU backend.
PiperOrigin-RevId: 658177202
2024-07-31 16:12:52 -07:00
jax authors
cc212457d2 Merge pull request #22481 from zhenying-liu:offloading
PiperOrigin-RevId: 657413977
2024-07-29 19:43:35 -07:00
jax authors
b34c96d48b Fix host compute annotations in the presence of copies.
If an explicit copy operation was present in a function offloaded to the host, adding compute annotations during lowering failed: the operation lowered to a block argument, and the "owner" of a block argument is the block, not an operation. Fixed by filtering out this case when adding compute annotations.

PiperOrigin-RevId: 653151844
2024-07-17 02:29:16 -07:00
Jane Liu
c774d7b29e Enable the passed tests for memories and layout 2024-07-17 05:37:39 +00:00
Peter Hawkins
f488c4cc31 Disable some tests that fail on Cloud TPU. 2024-07-15 16:00:58 -04:00
Eugene Zhulenev
de6339569d [jax] Add a test that runs reduction on host
Check that nested computations generated by offloaded computation are correctly outlined into the host module.

PiperOrigin-RevId: 647771541
2024-06-28 12:53:10 -07:00
Yash Katariya
ba88601b9c Remove the Cloud TPU disable of memories_test.py because everything should work now.
PiperOrigin-RevId: 647711611
2024-06-28 09:43:25 -07:00
Yash Katariya
ba5b3c7941 Make lowering aware of compute_type so that we choose the correct lowering code.
For example, if you have two `lax.linalg.qr` calls (one on `TPU` and another on `device_host`), the `device_host` qr decomposition should be lowered to CPU.

PiperOrigin-RevId: 647705015
2024-06-28 09:21:34 -07:00
Yash Katariya
4e31272c34 Write a test for offloading computations to host inside a shard_map
PiperOrigin-RevId: 647119210
2024-06-26 16:39:58 -07:00