This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.
Instead, add a new `util.test_event(...)` function that can be called at points of interest in the program. `test_utils` registers a callback that is invoked when an event fires, which avoids the need for thread-unsafe global monkey-patching.
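A minimal sketch of the pattern (the registration details here are illustrative, not the actual implementation):
```
# Illustrative sketch only; the real hook in test_utils may differ.
_test_event_callback = None  # set by test_utils in tests, None in production

def test_event(name, *args):
  # Called at points of interest; a no-op unless a test registered a callback.
  if _test_event_callback is not None:
    _test_event_callback(name, *args)

# Hypothetical registration helper in test_utils:
def set_test_event_callback(cb):
  global _test_event_callback
  _test_event_callback = cb
```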
Also add a `device_context` so that `set_mesh` correctly sets the devices the computation should run on. The `device_context` currently enters concrete devices into the tracing and lowering caches, but this should be fixed by the other JAX context work in progress.
PiperOrigin-RevId: 700537898
* Set the abstract_mesh context manager during pjit_p.bind at the top level too, since scan builds a jaxpr during its lowering in `_scan_impl` (do the same for the AOT path).
* Set the abstract mesh only once if it's not already set, and don't override an already-set context. This means that only the top-level jit sets the context manager.
* Add dynamic_slice and dynamic_update_slice sharding rules since scan calls into them.
* scan only allows `xs` whose 0th dim is fully replicated, i.e. `None` (see the sketch below).
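A minimal sketch of the intended usage (the setup here is illustrative and assumes at least two devices):
```
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices()[:2], 'x')
# The scanned-over (0th) dim must be replicated (None); other dims may be sharded.
xs = jax.device_put(jnp.arange(16.).reshape(4, 4),
                    NamedSharding(mesh, P(None, 'x')))

@jax.jit
def f(xs):
  return jax.lax.scan(lambda c, x: (c + x.sum(), x * 2), 0., xs)

carry, ys = f(xs)
```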
PiperOrigin-RevId: 699014167
Set the abstract mesh context manager at the jit tracing boundary by looking at the mesh on the avals. In the future, this context manager will be user-settable too.
The abstract mesh context manager is a new context manager backed by a new context variable and a new `trace_context` entry, which governs cache behavior. If the abstract mesh context manager is not set, the default is `None`.
PiperOrigin-RevId: 698493184
When we have a cache miss in `_cpp_pjit` we want to compile the function and
store the executable. Previously we had a roundabout way of getting hold of that
executable. We'd trace the function to a jaxpr but we wouldn't lower and compile
it ourselves. Instead, we'd call `pjit_p.bind`. The layers of the tracing onion
would be peeled off and eventually we'd hit the `pjit_p` impl rule,
`_pjit_call_impl`. This rule has its own cache. With luck we'd also miss *that*
cache, and then `_pjit_call_impl` would lower and compile the jaxpr and store
the executable in `most_recent_pjit_call_executable`. We'd eventually pop the
stack back up to the `_cpp_pjit` cache miss and then we'd get hold of the
compiled object by looking up `most_recent_pjit_call_executable`.
There's room for bugs here if we hit one cache but not the other. For example,
if we miss the `_cpp_pjit` cache but we hit the `_pjit_call_impl` cache then we
won't compile the executable. Normally that would just mean that the `_cpp_pjit`
cache won't be populated. But if we've previously hit a function with the same
jaxpr but slightly different compilation parameters (e.g. device IDs) then we'll
get a bogus hit in `most_recent_pjit_call_executable` and we'll add an incorrect
cache entry. The divergent cache behavior you need to trigger this started
happening with the "stackless" change because the tracing context became a
bigger part of the cache key and `_cpp_pjit` and `_pjit_call_impl` will in
general have different tracing contexts.
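To make the failure mode concrete, here's a toy illustration of the hazard (not JAX's actual code; all names are invented): two cache layers whose keys can diverge, communicating through a "most recent executable" global.
```
inner_cache = {}             # analogue of the _pjit_call_impl cache
outer_cache = {}             # analogue of the _cpp_pjit cache
most_recent_executable = None

def compile_fn(jaxpr, devices):
  return f'executable({jaxpr}, {devices})'

def inner_call(jaxpr, devices):
  global most_recent_executable
  if (jaxpr, devices) not in inner_cache:   # inner miss: compile and record
    inner_cache[(jaxpr, devices)] = compile_fn(jaxpr, devices)
    most_recent_executable = inner_cache[(jaxpr, devices)]
  return inner_cache[(jaxpr, devices)]

def outer_call(jaxpr, devices, trace_ctx):
  # The outer key includes the tracing context, so the two caches can
  # miss/hit independently (the "divergent cache behavior" above).
  key = (jaxpr, devices, trace_ctx)
  if key not in outer_cache:                # outer miss
    inner_call(jaxpr, devices)
    # BUG: if inner_call hit its cache, most_recent_executable still holds
    # whatever was compiled last -- possibly for different devices.
    outer_cache[key] = most_recent_executable
  return outer_cache[key]

outer_call('jaxpr1', 'devB', 'ctx1')  # compiles for devB
outer_call('jaxpr1', 'devA', 'ctx1')  # compiles for devA
outer_call('jaxpr1', 'devB', 'ctx2')  # outer miss, inner hit: stores the devA executable!
```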
With this change, we remove the whole `most_recent_pjit_call_executable` system.
Instead `_cpp_pjit` lowers, compiles and runs the jaxpr itself and obtains the
executable directly rather than calling into `pjit_p.bind`. We do call into
`pjit_p.bind` if we're not in an eval context, but in that case we don't expect
to be able to populate the `_cpp_pjit` cache anyway.
Performance-wise, we should be at parity, although this has not yet been tested.
Authoring-wise, the new kernel is significantly smaller and simpler to write.
A major known limitation of this approach is the invariant that `seq_len % grid_size == 0`; we plan to relax this limitation in follow-up CLs.
PiperOrigin-RevId: 689868468
This allows us to get more cache hits globally. For example:
Before:
```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache miss
```
After:
```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache hit
```
Reverts b615266175effe4aefeb903620a19f3719a604da
PiperOrigin-RevId: 675746175
This fixes a tracing cache miss that occurred when you called `eval_shape` with a weak-type input, got a strong-type output back, and then passed that output back in.
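A minimal sketch of the pattern (a hypothetical repro, not the exact case from the linked issue):
```
import jax

def f(x):
  return x + 1

x = 1.0                     # Python float: weak-type input
out = jax.eval_shape(f, x)  # previously returned a strong-type ShapeDtypeStruct
jax.eval_shape(f, out)      # passing it back in used to miss the tracing cache
```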
Fixes: https://github.com/google/jax/issues/23302
PiperOrigin-RevId: 668949430
This massively simplifies the number of checks we need and improves dispatch time too. It also fixes a donation bug hit in serving code, related to layouts and the non-standardization of the default layout in JAX.
PiperOrigin-RevId: 668527139
This allows us to get more cache hits globally. For example:
Before:
```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr) # cpp cache miss
```
After:
```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr) # cpp cache hit
```
Also, we can remove the hack (which I didn't like) in multihost_utils.py.
PiperOrigin-RevId: 665574475
Not doing the resharding leads to incorrect outputs on GPU and a crash on TPU.
Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
**Semantics**
Inside jit, we never need to talk about concrete devices, so the semantics stay the same as today, i.e. we can lower a `NamedSharding` with an abstract mesh (only mesh axis names and sizes) and a `PartitionSpec`. The only restriction is that the number of devices needs to be consistent throughout the program while we are tracing.
During compilation, the order of devices throughout the program needs to be consistent (same as before this change).
Outside jit, i.e. in eager mode, if a `shard_map` or `with_sharding_constraint` contains an `AbstractMesh`, then the inputs to those primitives should carry a concrete `Mesh` with the same shape and names as the abstract mesh.
**Why do this?**
There are cases where you want to change the devices in the mesh but keep the mesh shape the same (axis names and axis sizes). But this leads to a device-mismatch error if you have `with_sharding_constraint` or `shard_map` in your computation, because they embed concrete devices in their signatures.
So to fix the error, you need to change the mesh in `wsc` and `shmap`, which leads to a tracing cache miss (because the function id is now different) and consequently a miss in the lowering-to-StableHLO cache. Explaining via an example:
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
@jax.jit
def f(x):
  y = with_sharding_constraint(x, NamedSharding(mesh1, P('x')))
  return y * 2
f(arr_mesh1)
f(arr_mesh2) # DEVICE MISMATCH ERROR!
```
The same problem exists for `shard_map` since it takes a mesh with concrete devices in its signature.
**Okay, so how do you fix this?**
As mentioned above, we need the above program to work and to get tracing and lowering cache hits (**cache hits are the most important part** here).
The approach in this change allows `with_sharding_constraint` to accept a `NamedSharding(abstract_mesh, pspec)` as input. This leads to no errors downstream, and we get tracing and lowering cache hits since we no longer encode concrete devices, just the axis names and axis sizes of the mesh.
**The important part is that the concrete device information should only come from the arguments. Inside `jax.jit`, you should never reference concrete devices.**
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
# Create an abstract mesh from mesh1; since both meshes have the same shape
# (names and axis sizes), this is fine.
abstract_mesh = jax.sharding.AbstractMesh(arr_mesh1.shape_tuple)
@jax.jit
def f(x):
  y = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
  return y * 2
f(arr_mesh1)
f(arr_mesh2) # tracing and lowering cache hit
```
**One caveat is that this only works with `jax.NamedSharding` but that's fine because `NamedSharding` is the most used `Sharding` in JAX.**
**What about `shard_map`?**
shard_map's signature will be: `shmap(f, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs)`.
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
# Create an abstract mesh from mesh1; since both meshes have the same shape
# (names and axis sizes), this is fine.
abstract_mesh = jax.sharding.AbstractMesh(arr_mesh1.shape_tuple)
@jax.jit
def f(x):
  y = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), out_specs=P('x'))(x)
  return y * 2
f(arr_mesh1)
f(arr_mesh2) # tracing and lowering cache hit
```
This is a fully backwards-compatible change, so your current code will continue to work as is, but you can opt into the new behavior and get all the benefits!
PiperOrigin-RevId: 662670932
```
with mesh:
  out = pjit(lambda: 1)()
```
The sharding of `out` was a `GSPMDSharding`, which is not ideal. This change fixes that and returns a `NamedSharding` instead.
This is also required for `Shardy` integration.
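A quick check of the new behavior might look like this (a sketch; the assertion reflects the intended fix):
```
import jax
from jax.sharding import Mesh, NamedSharding
from jax.experimental.pjit import pjit

mesh = Mesh(jax.devices(), 'x')
with mesh:
  out = pjit(lambda: 1)()
# After this change, out.sharding is a NamedSharding rather than a GSPMDSharding.
assert isinstance(out.sharding, NamedSharding)
```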
PiperOrigin-RevId: 658842350