131 Commits

Author SHA1 Message Date
jax authors
ff5a2e8c91 Enable test_scan_offload in memories_test.
PiperOrigin-RevId: 742840628
2025-04-01 14:26:26 -07:00
Yash Katariya
e8038501d0 Fix a bug where jit was forwarding inputs to outputs even when donation was True for those inputs. This caused the output to be marked as deleted since the input was being forwarded to the output.
Since this functionality was added for a dynamic shapes experiment, only enable it when dynamic_shapes config is True.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 740942785
2025-03-26 16:31:11 -07:00
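The donation case described in this commit can be sketched as follows. This is a minimal illustration, not the test from the commit: the function names `identity` and `add_one` are hypothetical, and whether a buffer is actually donated depends on the backend (some backends only emit a warning that donation was not used).

```python
import jax
import jax.numpy as jnp
import numpy as np

# Identity function: the output is exactly the input. This is the case where
# forwarding a donated input buffer to the output would be harmful, because
# the donated buffer is invalidated after the call.
identity = jax.jit(lambda x: x, donate_argnums=0)

# A non-identity function with the same donation setting, for comparison.
add_one = jax.jit(lambda x: x + 1, donate_argnums=0)

x = jnp.arange(4, dtype=jnp.float32)
out = identity(x)  # x is donated; do not read x after this call

# On a build containing the fix, this is expected to be False, i.e. the
# output remains usable even though the input was donated.
output_deleted = out.is_deleted()

y = jnp.arange(4, dtype=jnp.float32)
result = np.asarray(add_one(y))  # y is donated; only the result is used
```

The general rule when donating is to treat the donated argument as consumed and use only the returned arrays afterwards.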
Tom Natan
8bbd738df1 [JAX Shardy] #sdy Unskip another test that is now passing
PiperOrigin-RevId: 738814411
2025-03-20 08:37:29 -07:00
Tom Natan
c098b363fb [JAX Shardy] Unskip stream annotation test when shardy is enabled, since the underlying issue is now resolved.
PiperOrigin-RevId: 738802372
2025-03-20 08:01:52 -07:00
jax authors
bb274f1311 Merge pull request #27274 from yliu120:new_fix_annotation
PiperOrigin-RevId: 738665000
2025-03-19 22:03:16 -07:00
Yunlong Liu
258ed1b0a5 Fixes the stream annotation compute on box.
2025-03-20 04:14:19 +00:00
Yash Katariya
76d9890bb7 Run the stream annotation tests on 2 devices so that it can be tested in TAP
PiperOrigin-RevId: 738113725
2025-03-18 13:01:48 -07:00
Shraiysh
cb2eb15739 PR #22800: Change the default value of print_operand_shape_ to false and print_large_constants_ to true.
Imported from GitHub PR https://github.com/openxla/xla/pull/22800

Operand shape in long hlo text adds redundant information, which shouldn't be required. Changing the default value to off.

The large constants were also printed earlier by default print options, and it is required for parsability and reproducibility. Turning this on by default. This is still controlled by debug option and the default value of that flag disables the large constants, and that behavior is not changed. Just the default print options change here.

Copybara import of the project:

--
e30dea20489b3fb4d03d373fec0391d69486f4aa by Shraiysh Vaishay <svaishay@nvidia.com>:

Change the default value of print_operand_shape_ to false and print_large_constants_ to true.

Operand shape in long hlo text adds redundant information, which
shouldn't be required. Changing the default value to off.

The large constants were also printed earlier by default print options,
and it is required for parsability and reproducibility. Turning this on by default.
This is still controlled by debug option and the default value of that
flag disables the large constants, and that behavior is not changed. Just the
default print options change here.

--
7008af0dd0ce342ecbe9475f1d0e277319f1705a by Shraiysh Vaishay <svaishay@nvidia.com>:

Handle tests

--
b22d5f95cfb7e15f930a2198279a76c38593cc53 by Shraiysh Vaishay <svaishay@nvidia.com>:

Fix more tests

--
d51579cae7359c6426a87ad4a7ff1b4b0c80f74a by Shraiysh Vaishay <svaishay@nvidia.com>:

Fix more tests

Merging this change closes #22800

PiperOrigin-RevId: 735690598
2025-03-11 03:17:04 -07:00
Yash Katariya
07c4c03a05 Remove the skip for test_output_streaming_inside_scan
PiperOrigin-RevId: 733070842
2025-03-03 14:54:03 -08:00
Peter Hawkins
66293d8897 Remove code present to support jaxlib < 0.5.1.
The new minimum xla_extension_version is 317 and the new mlir_api_version is 58.
2025-02-26 07:40:40 -05:00
Yash Katariya
00d8297071 [sharding_in_types] Set the sharding_in_types config to True. This is a purely internal change and shouldn't affect any public APIs.
Some caveats of enabling sharding-in-types by default: we'll see tracing cache misses, which lead to lowering cache misses and compilation cache misses, in the **following cases** (the persistent compilation cache is not affected, so we'll still see a cache hit there):

1. Call `jitted_f(arr_ns)` with an array on `NamedSharding` and again `jitted_f(arr_ps)` with an array of same shape and dtype but now with `PositionalSharding`
    * This leads to a tracing cache miss because on the second call, the aval has no sharding since it's PositionalSharding. This applies to calling with any sharding other than NamedSharding

2. `jitted_f = jit(f, in_shardings=ns)`. Call `jitted_f(sharded_arr)` and then on the second call you pass a numpy array `jitted_f(numpy_arr)`
   * This also leads to a cache miss because the avals currently don't look at in_shardings because the semantics of in_shardings is complicated and I don't think we should change the aval based on in_shardings.

**The solution in both cases is to make sure to pass the array sharded on the same mesh during both calls to jit.**

PiperOrigin-RevId: 728361493
2025-02-18 14:35:14 -08:00
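The recommended pattern from this commit message (pass arrays with the same `NamedSharding` on every call) can be sketched as below. This is an illustrative example under assumptions: it builds a one-device mesh so it runs anywhere, and the `trace_count` counter is a hypothetical way of observing retracing, not how JAX's own tests count cache misses.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

trace_count = 0  # incremented only while JAX traces f

def f(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not per call
    return x * 2

# One-device mesh (an assumption so the sketch runs on any machine).
mesh = Mesh(np.array(jax.devices()[:1]), ('x',))
ns = NamedSharding(mesh, P('x'))

jitted_f = jax.jit(f)

data = np.arange(8, dtype=np.float32)
arr_a = jax.device_put(data, ns)      # sharded with ns
arr_b = jax.device_put(data + 1, ns)  # same shape, dtype, and sharding

out_a = np.asarray(jitted_f(arr_a))   # first call traces f
out_b = np.asarray(jitted_f(arr_b))   # second call reuses the cached trace
```

Because both inputs share shape, dtype, and sharding, the second call hits the tracing cache. Mixing in a NumPy array or an array with a different sharding type on the second call is the situation the commit message warns may retrace.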
Ayaka
b6361b3e76 Minor format cleanup
Remove 2 redundant whitespaces mentioned in https://github.com/jax-ml/jax/pull/25056#pullrequestreview-2615387492.

PiperOrigin-RevId: 727264168
2025-02-15 04:56:27 -08:00
chaserileyroberts
60f0184637 Added stream annotation support via @compute_on('gpu_stream:#')
2025-02-13 07:15:18 +00:00
Bart Chrzaszcz
6ed4c29c8a #sdy enable test_mem_kind_donation_pinned_host for Shardy.
PiperOrigin-RevId: 725142884
2025-02-10 03:15:25 -08:00
Justin Fu
b6acb9cb7a Fix remat bug on primitives with multiple outputs.
Addresses https://github.com/jax-ml/jax/issues/25841

PiperOrigin-RevId: 715434084
2025-01-14 10:26:58 -08:00
Adam Paszke
ad00ec1dc9 [Mosaic TPU] Guard tests for new features by the libtpu version
PiperOrigin-RevId: 707875450
2024-12-19 05:04:09 -08:00
Peter Hawkins
62e66b684b Don't monkey-patch functions in test_utils to count events for tests.
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.

Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
2024-12-12 09:58:14 -05:00
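The event-callback approach described in this commit can be sketched in plain Python. This is a simplified illustration and not JAX's actual implementation: `emit_test_event` stands in for the `util.test_event(...)` function named in the message, and `set_event_callback`, `EventCounter`, and `lower` are hypothetical names.

```python
import threading
from collections import Counter
from typing import Callable, Optional

_callback_lock = threading.Lock()
_event_callback: Optional[Callable[..., None]] = None

def emit_test_event(name: str, *args) -> None:
    """Called at points of interest; does nothing unless a listener is set."""
    callback = _event_callback
    if callback is not None:
        callback(name, *args)

def set_event_callback(callback):
    """Install a listener and return the previous one so it can be restored."""
    global _event_callback
    with _callback_lock:
        previous, _event_callback = _event_callback, callback
    return previous

class EventCounter:
    """Thread-safe listener that counts events by name."""
    def __init__(self):
        self._lock = threading.Lock()
        self.counts = Counter()

    def __call__(self, name, *args):
        with self._lock:
            self.counts[name] += 1

def lower(fn_name: str) -> str:
    # Hypothetical instrumented library function: it reports an event
    # instead of being monkey-patched by the test harness.
    emit_test_event("lowering", fn_name)
    return f"lowered:{fn_name}"

counter = EventCounter()
previous = set_event_callback(counter)
try:
    lower("f")
    lower("g")
finally:
    set_event_callback(previous)  # restore, so later calls are not counted

lower("h")  # no listener registered; this event is ignored
```

The library code never changes identity during a test, so there is no global patching to undo, and the listener guards its own state with a lock.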
Bart Chrzaszcz
6f69774c00 #sdy enable test_compute_offload_mesh_with_linear_layout for Shardy.
PiperOrigin-RevId: 704301465
2024-12-09 08:46:48 -08:00
Peter Hawkins
79318a08cf Remove dead code after minimum jaxlib version bump to v0.4.36.
New minimum xla_extension_version is 299, and the new mlir_api_version is 57.

PiperOrigin-RevId: 704280856
2024-12-09 07:35:05 -08:00
jax authors
f160df0444 More thorough propagation of host linear layout. Currently linear layout on host
can only originate from entry computation. Propagation only goes strictly down/up.
More needs to be done later if such layout can originate from host compute itself.
Removed the temporary pattern match solution.

PiperOrigin-RevId: 702966364
2024-12-04 21:06:34 -08:00
jax authors
4d60db1741 Add test_compute_on_host_shared_sharding in memories_test
PiperOrigin-RevId: 698250352
2024-11-19 21:33:27 -08:00
Yash Katariya
87ce0cbb00 Make GPU work with copy=True and device_put since same device pinned_host -> pinned_host copy is possible.
PiperOrigin-RevId: 694713334
2024-11-08 18:29:47 -08:00
Yash Katariya
0bb30f0777 Propagate CopySemantics from python to C++ transfer APIs so that device_put works correctly in presence of copy/donate options that user specified.
This change only supports pinned_host -> pinned_host copies on the same device. HBM -> HBM copies don't work yet and donation also doesn't work in PJRT.

This CL also sets up the plumbing from JAX to PJRT so that in the future support for missing features can be added easily.

Fixes https://github.com/jax-ml/jax/issues/24521

PiperOrigin-RevId: 694274616
2024-11-07 15:51:54 -08:00
Yash Katariya
6c8e56f43f Finish 0.4.35 release by removing dead code
PiperOrigin-RevId: 689396609
2024-10-24 08:45:43 -07:00
jax authors
dd5426301a Allow simple host call that uses host tensor as parameter/result in
linear layout. This CL only handles very simple host call patterns.
A more thorough implementation of propagation of T(1)S(5) will be done
later.

This CL doesn't handle host calls that pass/return tensors that
live on device with linear layout either, which will also be implemented
separately.

PiperOrigin-RevId: 687113203
2024-10-17 18:22:46 -07:00
Bart Chrzaszcz
fb32841b1b #sdy add JAX Shardy support for memories.
PiperOrigin-RevId: 684867097
2024-10-11 09:44:24 -07:00
Yash Katariya
ce2b49787f Skip test_ragged_copy_on_host if xla_extension_version < 290
PiperOrigin-RevId: 683326972
2024-10-07 14:30:34 -07:00
jax authors
e90487e906 Host Offloading: Process "MoveToHost" instructions in the order they are executed.
- This ensures we process "MoveToHost" instructions that reside at the beginning of a host memory instruction offload chain.
- This avoids processing MoveToHost instructions out of order, creating invalid instructions within a host memory instruction offload chain.

PiperOrigin-RevId: 682448060
2024-10-04 14:17:36 -07:00
jax authors
291619c291 Allow custom call computations to contain subcomputations
PiperOrigin-RevId: 682429391
2024-10-04 13:22:14 -07:00
Peter Hawkins
d3f63a66b8 Remove code to support jaxlib <= 0.4.33.
2024-10-04 11:39:05 -04:00
jax authors
be76fb6abf Add host compute offload test: test_offload_take_host.
PiperOrigin-RevId: 682088063
2024-10-03 17:06:46 -07:00
Yash Katariya
203cda6f98 Move test_aot_device_implicit_transfer to pjit_test.py
This test is not specific to compute offload and is more relevant to pjit.

PiperOrigin-RevId: 680599882
2024-09-30 09:10:17 -07:00
Jane Liu
57bef447c6 Enable weight offloading tests that are supported on GPUs now
2024-09-26 23:26:27 -07:00
jax authors
5a1549cccf Merge pull request #23853 from zhenying-liu:remat-scan
PiperOrigin-RevId: 679365040
2024-09-26 18:12:30 -07:00
Jane Liu
adaf54a4bb enable the activation offloading test
2024-09-23 23:07:03 -07:00
Yash Katariya
8b5b71750b Fix jaxpr equation context propagation in jaxpr equations when inline=True.
PiperOrigin-RevId: 675754808
2024-09-17 16:40:36 -07:00
Yash Katariya
a144eb234b Add compute_on_context_manager to thread local jit state. This is to avoid getting false cache hits
PiperOrigin-RevId: 671507042
2024-09-05 14:16:13 -07:00
jax authors
97db78ba24 Adds test_compute_offload_with_donation in memories_test
PiperOrigin-RevId: 671410527
2024-09-05 09:57:59 -07:00
Yash Katariya
e1b497078e Rename jtu.create_global_mesh to jtu.create_mesh and use jax.make_mesh inside jtu.create_mesh to get maximum test coverage of the new API.
PiperOrigin-RevId: 670744047
2024-09-03 16:23:07 -07:00
Yash Katariya
969dd89040 Reverts changelist 668370165
PiperOrigin-RevId: 669670355
2024-08-31 08:44:23 -07:00
Yash Katariya
164b884f33 Fix failing tests in CI
PiperOrigin-RevId: 669357019
2024-08-30 09:49:58 -07:00
Yash Katariya
4533aeaf26 Remove jax_enable_memories conditionals from JAX and remove it from tests too.
PiperOrigin-RevId: 662322241
2024-08-12 19:15:43 -07:00
Yash Katariya
c08656c61d [Rollback] We still want to allow multiple meshes in the user program
Reverts dd958adc39550d2758ecdb13809c6d85df7658a2

PiperOrigin-RevId: 661537233
2024-08-09 23:17:46 -07:00
Yash Katariya
abc9ba00e9 Rename count_jit_and_pmap_compiles to count_jit_and_pmap_lowerings
PiperOrigin-RevId: 661496993
2024-08-09 20:03:43 -07:00
Jieying Luo
ccc27a7a5f Remove PJRT version check in memories_test.py that is no longer needed.
0.43 is the version at 2024 Feb. Cloud TPU CI uses 20240228 so it should contain the PJRT C API needed for the test d3b6066f91/.github/workflows/cloud-tpu-ci-nightly.yml (L35).

PiperOrigin-RevId: 660869710
2024-08-08 09:35:41 -07:00
Yash Katariya
dd958adc39 Add mesh_shape to the lowering context. This is to allow custom partitioning to not depend on the mesh context manager to return NamedShardings even if the arguments have NamedShardings on them.
Since the `shardy`, sharding-in-types, and world 2 dagger work is all going in the direction of making Mesh and PartitionSpec a first-class sharding type, let's pull the trigger right now to start fixing these bad user interactions.

Some things that will break due to this change: previously, passing a NamedSharding and an equivalent PositionalSharding to the same jitted function one after another would lead to a lowering cache hit. Now it will be a cache miss. In other words: `f(ns); f(ps) # cache hit before`

In followup CLs, we will make the tracing cache aware of the mesh shape too to fix some other issues related to tracing and lowering cache misses

PiperOrigin-RevId: 660177423
2024-08-06 18:35:44 -07:00
jax authors
9074e8544f Add test for zero-sized host memory parameter
PiperOrigin-RevId: 660097039
2024-08-06 14:31:41 -07:00
Kanglan Tang
ae541203bc Skip flaky test_weight_offload_with_dp_on_output test on GPU backend.
PiperOrigin-RevId: 660057950
2024-08-06 12:35:53 -07:00
Yash Katariya
489fbc0ed5 Add a test for streaming in closed over constants from host to device
PiperOrigin-RevId: 659711557
2024-08-05 16:00:45 -07:00
Yash Katariya
e6851e6b22 Fix the AOT check for sharding consistency which skipped checking the devices of the sharding.
Previously, for a computation compiled for TPU, a user could pass in a committed array on CPU and JAX would not raise an error, which is wrong.

This change fixes that. Also `is_equivalent_to` should check for devices, HloSharding and memory_kind (so removing the redundant `memory_kind` check too).

PiperOrigin-RevId: 658794885
2024-08-02 08:15:32 -07:00