Also adding a device_context so `set_mesh` sets the devices the computation should run on correctly. The device_context however enters concrete devices into tracing and lowering cache but this should be fixed with the other jax context work going on.
PiperOrigin-RevId: 700537898
Set the abstract mesh context manager at the jit tracing boundary by looking at the mesh on the avals. In the future, this context manager will be user settable too.
Abstract mesh context manager is a new context manager with a new context variable and new trace_context entry which governs the cache behavior. If the abstract mesh context manager is not set, the default is `None`.
PiperOrigin-RevId: 698493184
This change also adds AxisTypes to Mesh which are `User`, `Auto` and `Collective`.
In the following changes, I'll remove the `config.sharding_in_types` flag and we'll enter into various modes via AxisTypes mentioned on the mesh.
PiperOrigin-RevId: 696559375
A noticeable amount of time during JAX tracing is spent getting and setting the value of config.State objects, in particular the thread-local values within that state. If we move that logic into C++, we can speed up that code.
There are two main ways we can get a speedup:
* Python thread-local state is based around a dictionary and isn't terribly fast.
* we can have the C++ jit dispatch path directly access the configuration items it needs to include in its cache key. We spend a considerable amount of time in effect eagerly computing cache keys via update_thread_local_jit_state, although most of that is pointless work. Instead, we can have `jit` simply pull the config items it needs on demand.
PiperOrigin-RevId: 693114411
- core_map is like a shard_map but it takes in no inputs and outputs
- we can use it in Pallas to generalize mapping a function over the cores of a chip (e.g. TensorCores in a TPU or SMs in a GPU)
- we specify how the function will be mapped over the device with a `mesh` object. This is also a convenient mechanism for picking the backend for pallas to target
PiperOrigin-RevId: 686036101
**Semantics**
Inside jit, we don't need to talk about concrete devices ever so the semantics stay the same as today i.e. we can lower a NamedSharding with abstract mesh with only mesh axis names and sizes and PartitionSpec. The only restriction is that the number of devices need to be consistent throughout the program when we are tracing.
During compilation, the order of devices throughout the program needs to be consistent (same as before this change).
Outside jit i.e. eager mode, if a `shard_map` or `with_sharding_constraint` contains AbstractMesh, then the input to those primitives should contain a concrete Mesh with the same shape and names as the abstract mesh.
**Why do this?**
There are cases, where you want the change the devices in the mesh but keep the mesh shape the same (axis names and axis sizes). But this leads to a device mismatch error if you have `with_sharding_constraint` or `shard_map` in your computation because they embed concrete devices in their signature.
So to fix the error, you need to change the mesh in `wsc` and `shmap` which will lead to a tracing cache miss (because function id is now different) and consequently a lowering to stableHLO cache miss. Explaining via an example:
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
@jax.jit
def f(x):
y = with_sharding_constraint(x, NamedSharding(mesh1, P('x')))
return y * 2
f(arr_mesh1)
f(arr_mesh2) # DEVICE MISMATCH ERROR!
```
The same problem exists for `shard_map` since it takes a mesh with concrete devices in it's signature.
**Okay, so how do you fix this?**
As mentioned above, we need the above program to work and get tracing and lowering cache hits (**cache hits is the most important** part here)
The approach in this change, allows `with_sharding_constraint` to accept a `NamedSharding(abstract_mesh, pspec)` as input. This leads to no errors downstream and we get tracing and lowering cache hits since we don't encode the concrete devices anymore. Just the axis_names and axis_size of the mesh.
**The important part is that the concrete device information should only come from the arguments. Inside `jax.jit`, you should never reference concrete devices ever.**
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(arr_mesh1.shape_tuple)
@jax.jit
def f(x):
y = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
return y * 2
f(arr_mesh1)
f(arr_mesh2) # tracing and lowering cache hit
```
**One caveat is that this only works with `jax.NamedSharding` but that's fine because `NamedSharding` is the most used `Sharding` in JAX.**
**What about `shard_map`?**
shard_map's signature will be: `shmap(f, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs)`.
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(arr_mesh1.shape_tuple)
@jax.jit
def f(x):
y = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), out_specs=P('x'))
return y * 2
f(arr_mesh1)
f(arr_mesh2) # tracing and lowering cache hit
```
This is a fully backwards change. So your current code will continue to work as is but you can opt-into this new behavior and get all the benefits!
PiperOrigin-RevId: 662670932
This allows code like this:
```python
def f(x):
mesh = pltpu.create_tensorcore_mesh('core')
y = jnp.zeros_like(x)
@state_discharge.run_state
def inner(refs):
x_ref, y_ref = refs
def kernel():
def alloc(sem):
pltpu.async_copy(x_ref, y_ref, sem).wait()
pltpu.run_scoped(alloc, pltpu.SemaphoreType.DMA)
shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
check_rep=False)()
_, y = inner((x, y))
return y
```
Why? pallas_call as an API has a lot of responsibilities:
1. Creating Refs out of Arrays
2. Parallelizing execution over cores (via dimension_semantics and grid)
3. Pipelining
4. Allocating scratch spaces
5. Scalar prefetch
This change allows you to express pallas_call *compositionally* using existing APIs.
1. Creating Refs out of arrays -> run_state
2. Parallelizing execution over cores -> shmap w/ a special mesh
3. Pipelining -> emit_pipeline
4. Allocating scratch spaces (run_scoped, which we could generalize to run_state)
5. Scalar prefetch -> run_scoped + a DMA
The hope is that this allows Pallas to generalize to more backends beyond TPU while becoming more intuitive to write and explain. For now, this lowering path is experimental and not officially exposed but we want to make sure it is possible to support.
PiperOrigin-RevId: 655320587
The OpenXLA project is working on an open source, MLIR, named-axis based propagation (and in the future SP<D partitioning) system that will be dialect agnostic (would work for any dialect - MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.
Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline while it's still working on StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD Partitioner, but next year the Shardy partitioner will be developed, allowing for propagation and partitioning to be completely in MLIR and the first pass in the pipeline. So then we'd have:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations
The following test:
```py
def test_sdy_lowering(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@partial(jax.jit, out_shardings=s)
def f(x):
return x * 2
print(f.lower(arr).as_text())
```
outputs:
```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
sdy.mesh @mesh = <"x"=4, "y"=2>
func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
%c = stablehlo.constant dense<2> : tensor<i64>
%0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
%1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
return %1 : tensor<8x2xi64>
}
}
```
Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially before becoming enabled by default in the future.
PiperOrigin-RevId: 655127611
This is true for mock devices or user specific devices and jax.devices() too.
Fix the tests so that the mock devices are hashable.
PiperOrigin-RevId: 561103167
This is done by returning the same object when constructing mesh if devices.shape, axis_names and flat device list matches.
PiperOrigin-RevId: 560828993
XLA-compatible `Sharding` implementations keep a `DeviceList` object as
`_internal_device_list`. This is used for finding the default memory kind more
quickly in C++, and enables caching of the default memory kind between multiple
`NamedSharding` objects that shares the same `Mesh`. Also it uses an
addressable device within `DeviceList`, which will be required for supporting
multiple device types with different default memory kinds.
PiperOrigin-RevId: 556969789
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
This work is an effort to reduce cyclic dependencies in JAX internals.
Move the _global_to_local and _local_to_global methods out of Mesh and into pxla as free functions. This removes the need for jax._src.mesh to depend on things like avals.
PiperOrigin-RevId: 515667671