Some caveats of enabling sharding-in-types by default are that we'll see tracing cache misses, which will lead to lowering and compilation cache misses, in the **following cases** (the persistent compilation cache is not affected, so we'll still see a cache hit there):
1. Call `jitted_f(arr_ns)` with an array on a `NamedSharding`, and then call `jitted_f(arr_ps)` with an array of the same shape and dtype but with a `PositionalSharding`.
    * This leads to a tracing cache miss, because on the second call the aval has no sharding since it's a `PositionalSharding`. The same applies to calling with any sharding other than `NamedSharding`.
2. `jitted_f = jit(f, in_shardings=ns)`. Call `jitted_f(sharded_arr)`, and then on the second call pass a numpy array: `jitted_f(numpy_arr)`.
    * This also leads to a cache miss, because avals currently don't look at `in_shardings`: the semantics of `in_shardings` are complicated, and I don't think we should change the aval based on it.
**The solution in both cases is to make sure the array passed in both calls to the jitted function is sharded on the same mesh.**
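A minimal sketch of case 1 and the fix; the device count, axis name, and array contents are illustrative, and `PositionalSharding` (deprecated in newer JAX releases) appears only to mirror the scenario above:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec, PositionalSharding

devices = jax.devices()[:2]
mesh = Mesh(devices, axis_names=("x",))
ns = NamedSharding(mesh, PartitionSpec("x"))

@jax.jit
def f(a):
    return a * 2

arr_ns = jax.device_put(jnp.arange(8.0), ns)
f(arr_ns)  # traces and compiles

# Same shape/dtype, different sharding type: tracing cache miss.
arr_ps = jax.device_put(jnp.arange(8.0), PositionalSharding(devices))
f(arr_ps)

# Re-using a NamedSharding on the same mesh hits the tracing cache.
f(jax.device_put(jnp.arange(8.0) + 1, ns))  # cache hit
```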
PiperOrigin-RevId: 728361493
The previous approach (globally monkey-patching functions from tests) has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.
Instead, add a new `util.test_event(...)` function that can be called at points of interest in the program. `test_utils` registers a callback that is invoked when an event is received. This avoids the need for thread-unsafe global monkey patches.
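A minimal sketch of this event/callback pattern; the helper names and the lock-based registry are illustrative, not JAX's exact internal API:

```python
import threading
from typing import Callable

_lock = threading.Lock()
_callbacks: dict[str, list[Callable[..., None]]] = {}

def test_event(name: str, *args) -> None:
    """Called at points of interest; a no-op unless a test registered a callback."""
    with _lock:
        callbacks = list(_callbacks.get(name, ()))
    for cb in callbacks:
        cb(*args)

def register_event_callback(name: str, cb: Callable[..., None]) -> None:
    with _lock:
        _callbacks.setdefault(name, []).append(cb)

def clear_event_callbacks() -> None:
    with _lock:
        _callbacks.clear()
```

Because tests only register callbacks instead of replacing globals, concurrently running test threads no longer stomp on each other's patches.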
The linear layout can only originate from the entry computation; propagation only goes strictly down/up.
More needs to be done later if such a layout can originate from host compute itself.
Removed the temporary pattern-match solution.
PiperOrigin-RevId: 702966364
This change only supports pinned_host -> pinned_host copies on the same device. HBM -> HBM copies don't work yet, and donation also doesn't work in PJRT.
This CL also sets up the plumbing from JAX to PJRT so that support for the missing features can be added easily in the future.
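A minimal sketch of the supported copy, assuming a runtime that exposes the `pinned_host` memory kind; the array contents are illustrative:

```python
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

dev = jax.devices()[0]
pinned = SingleDeviceSharding(dev, memory_kind="pinned_host")

# Place an array in pinned host memory, then copy
# pinned_host -> pinned_host on the same device.
x = jax.device_put(jnp.arange(4.0), pinned)
y = jax.device_put(x, pinned)
```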
Fixes https://github.com/jax-ml/jax/issues/24521
PiperOrigin-RevId: 694274616
This CL only handles very simple host-call patterns for the linear layout. A more thorough implementation of T(1)S(5) propagation will be done later.
This CL also doesn't handle host calls that pass/return tensors living on device with a linear layout; that will be implemented separately.
PiperOrigin-RevId: 687113203
- This ensures we process "MoveToHost" instructions that reside at the beginning of a host memory offload instruction chain.
- This avoids processing "MoveToHost" instructions out of order, which would create invalid instructions within a host memory offload instruction chain.
PiperOrigin-RevId: 682448060
Since `shardy`, the sharding-in-types work, and world 2 dagger are all going in the direction of making Mesh and PartitionSpec first-class sharding types, let's pull the trigger right now and start fixing these bad user interactions.
Some things will break due to this change: previously, passing a NamedSharding and an equivalent PositionalSharding to the same jitted function one after another led to a lowering cache hit, but now it will be a cache miss. In other words, `f(ns); f(ps)` was a cache hit before and is a cache miss now.
In follow-up CLs, we will make the tracing cache aware of the mesh shape too, to fix some other issues related to tracing and lowering cache misses.
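A short sketch of what "equivalent" means here; the device count and the 1-D array are illustrative:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec, PositionalSharding

devices = jax.devices()[:2]
ns = NamedSharding(Mesh(devices, ("x",)), PartitionSpec("x"))
ps = PositionalSharding(devices)

f = jax.jit(lambda a: a + 1)

# Both shardings place a 1-D array identically, but they are different
# Python types, so these two calls no longer share a lowering cache entry.
f(jax.device_put(jnp.arange(8.0), ns))
f(jax.device_put(jnp.arange(8.0), ps))
```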
PiperOrigin-RevId: 660177423
So previously, for a TPU-compiled computation, a user could have passed in a committed array on CPU and JAX wouldn't have errored, which is wrong.
This change fixes that. Also, `is_equivalent_to` should check for devices, HloSharding, and memory_kind (so the redundant `memory_kind` check is removed too).
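A sketch of the now-fixed behavior, assuming a machine with both TPU and CPU backends; pinning the computation via `in_shardings` is just one way to commit it to a TPU:

```python
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

# A jitted function committed to a TPU device.
f = jax.jit(lambda a: a + 1,
            in_shardings=SingleDeviceSharding(jax.devices("tpu")[0]))

cpu_arr = jax.device_put(jnp.arange(4.0), jax.devices("cpu")[0])  # committed to CPU
f(cpu_arr)  # now raises an error instead of silently running
```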
PiperOrigin-RevId: 658794885
If an explicit copy operation was present in a function offloaded to the host, adding compute annotations during lowering failed, since the operation lowered to a block argument and the "owner" of a block argument is the block, not an operation. Fixed by filtering out this case when adding compute annotations.
PiperOrigin-RevId: 653151844
For example, if you have two `lax.linalg.qr` calls (one on `TPU` and another on `device_host`), we should lower the `device_host` qr decomposition to CPU.
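A minimal sketch of that situation, assuming the `jax.experimental.compute_on` host-offload API; the shapes and the final reduction are illustrative:

```python
import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on

@jax.jit
def f(x):
    # Runs on the default accelerator (e.g. TPU).
    q1, r1 = jax.lax.linalg.qr(x, full_matrices=False)
    # Offloaded to the host; this qr should lower to the CPU decomposition.
    with compute_on("device_host"):
        q2, r2 = jax.lax.linalg.qr(x, full_matrices=False)
    return r1 + r2

f(jnp.ones((8, 4)))
```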
PiperOrigin-RevId: 647705015