47 Commits

Loren Maggiore
208194f9a5 context manager methods for AbstractMesh to appease type checker.
PiperOrigin-RevId: 702890537
2024-12-04 15:58:03 -08:00
Yash Katariya
a735bf83e5 Simplify the abstract_mesh and device_context context managers and handle everything via their corresponding configs in config.py
PiperOrigin-RevId: 702852769
2024-12-04 14:04:25 -08:00
Yash Katariya
653f65452d Fix the broken behavior of not resetting the abstract_mesh and device_context properly during __exit__.
PiperOrigin-RevId: 702762477
2024-12-04 09:59:23 -08:00
Yash Katariya
0d2dfea4b1 Add a private set_mesh API to enter sharding_in_types mode. This is how users will enable sharding-in-types mode (with the correct axis types set too, though that doesn't work yet).
Also add a device_context so that `set_mesh` correctly sets the devices the computation should run on. The device_context currently enters concrete devices into the tracing and lowering caches, but this should be fixed by the other JAX context work in progress.
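
A minimal sketch of what a context-manager-style `set_mesh` might look like, assuming module-level state, illustrative names, and a `Mesh` that exposes `abstract_mesh` and `devices` (the real implementation goes through JAX's config machinery):

```python
import contextlib

_abstract_mesh = None   # hypothetical module-level state
_device_context = None  # hypothetical module-level state

@contextlib.contextmanager
def set_mesh(mesh):
    """Enter sharding-in-types mode for the duration of the block."""
    global _abstract_mesh, _device_context
    prev = (_abstract_mesh, _device_context)
    _abstract_mesh = mesh.abstract_mesh          # only axis names and sizes
    _device_context = tuple(mesh.devices.flat)   # concrete devices, kept separate
    try:
        yield
    finally:
        # Restore the previous state on exit so nesting works correctly.
        _abstract_mesh, _device_context = prev
```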

PiperOrigin-RevId: 700537898
2024-11-26 20:01:04 -08:00
Yash Katariya
6763fcfb4e Fix a weird interaction between set_local and empty tuples passed to it.
PiperOrigin-RevId: 700392735
2024-11-26 10:50:05 -08:00
Yash Katariya
627debc78b Create a null_mesh_context internal context manager to handle null contexts properly.
PiperOrigin-RevId: 700167406
2024-11-25 18:32:05 -08:00
Yash Katariya
40fc6598f9 [sharding_in_types] Make the flash_attention forward pass in TPU Pallas work nicely with sharding in types. The backward pass is still broken; I will fix it in follow-up CLs.
Set the abstract mesh context manager at the jit tracing boundary by looking at the mesh on the avals. In the future, this context manager will be user-settable too.

The abstract mesh context manager is a new context manager with a new context variable and a new trace_context entry that governs cache behavior. If the abstract mesh context manager is not set, the default is `None`.
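
A rough sketch of the mechanism, assuming a `contextvars`-based implementation with illustrative names:

```python
import contextlib
import contextvars

# Unset by default, matching the `None` default described above.
abstract_mesh_cv = contextvars.ContextVar("abstract_mesh", default=None)

@contextlib.contextmanager
def use_abstract_mesh(mesh):
    token = abstract_mesh_cv.set(mesh)
    try:
        yield
    finally:
        abstract_mesh_cv.reset(token)

def trace_context():
    # The abstract mesh is part of the tracing cache key, so entering a
    # different abstract mesh invalidates previously cached traces.
    return (abstract_mesh_cv.get(),)  # ...plus the other config entries
```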

PiperOrigin-RevId: 698493184
2024-11-20 13:07:30 -08:00
Yash Katariya
8525ef2b23 [sharding_in_types] Don't emit a wsc under full manual mode to avoid increasing HLO size by a lot
PiperOrigin-RevId: 697048126
2024-11-15 17:42:16 -08:00
Yash Katariya
9a0e9e55d8 [sharding_in_types] Handle collective axes in lowering rules more generally. If any axis is collective, set all dims of aval to unspecified dims in wrap_with_sharding_op.
Also lower shardings with `Collective` axes correctly to HloSharding.

PiperOrigin-RevId: 696703030
2024-11-14 17:32:01 -08:00
Yash Katariya
05716b58b0 [sharding_in_types] Support shard_map with sharding in types. Right now only full manual mode is supported.
This change also adds AxisTypes to Mesh which are `User`, `Auto` and `Collective`.

In follow-up changes, I'll remove the `config.sharding_in_types` flag and we'll enter the various modes via the AxisTypes specified on the mesh.

PiperOrigin-RevId: 696559375
2024-11-14 09:58:03 -08:00
Peter Hawkins
0e8acff5c6 Reverts a913fbf2fddc5b8c1b6c85b159d0eeb1bf65d461
PiperOrigin-RevId: 693360032
2024-11-05 08:32:25 -08:00
jax authors
a913fbf2fd rollback due to data race
Reverts ab47d4687f647de3aa145a9a782fb7b4aaf92af4

PiperOrigin-RevId: 693191298
2024-11-04 21:05:33 -08:00
Peter Hawkins
ab47d4687f [JAX] [XLA:Python] Move JAX configuration objects into C++.
A noticeable amount of time during JAX tracing is spent getting and setting the value of config.State objects, in particular the thread-local values within that state. If we move that logic into C++, we can speed up that code.

There are two main ways we can get a speedup:
* Python thread-local state is based around a dictionary and isn't terribly fast.
* We can have the C++ jit dispatch path directly access the configuration items it needs to include in its cache key. We currently spend a considerable amount of time eagerly computing cache keys via update_thread_local_jit_state, most of which is pointless work. Instead, `jit` can simply pull the config items it needs on demand (see the sketch after this list).
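
A toy illustration of the second point, with illustrative names only, contrasting eagerly mirrored thread-local jit state with pulling config values on demand:

```python
import threading

class State:
    """A config option with a global default and a thread-local override."""
    def __init__(self, default):
        self._default = default
        self._tls = threading.local()  # dictionary-backed, hence the move to C++

    def set_local(self, value):
        self._tls.value = value

    def get(self):
        return getattr(self._tls, "value", self._default)

disable_jit = State(default=False)
default_matmul_precision = State(default=None)

def jit_cache_key():
    # Read just the values the dispatch path needs, at dispatch time, instead
    # of recomputing a thread-local jit state on every config change.
    return (disable_jit.get(), default_matmul_precision.get())
```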

PiperOrigin-RevId: 693114411
2024-11-04 15:39:06 -08:00
Yash Katariya
ca2d1584f8 Remove mesh_utils.create_device_mesh from docs
PiperOrigin-RevId: 687695419
2024-10-19 15:48:42 -07:00
Bart Chrzaszcz
801fe87da6 Do not allow None axis names in meshes.
PiperOrigin-RevId: 686557025
2024-10-16 10:32:25 -07:00
Sharad Vikram
cd78c653e7 [Pallas] Use core_map instead of shard_map for Shmallas
- core_map is like a shard_map, but it takes no inputs and returns no outputs
- we can use it in Pallas to generalize mapping a function over the cores of a chip (e.g. TensorCores in a TPU or SMs in a GPU)
- we specify how the function is mapped over the device with a `mesh` object, which is also a convenient mechanism for picking the backend for Pallas to target (a toy illustration follows this list)
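
A toy illustration of the contract (not the Pallas API): the mapped function closes over the state it touches, and a mesh-like object describes the cores it runs on.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class CoreMesh:
    axis_name: str
    num_cores: int  # e.g. TensorCores on a TPU chip or SMs on a GPU

def core_map(mesh):
    """Toy decorator: runs fn once per core; fn takes no args and returns nothing."""
    def wrap(fn):
        def run():
            for _ in range(mesh.num_cores):  # the real thing runs these in parallel
                fn()
        return run
    return wrap

counter = [0]  # state the kernel closes over instead of taking inputs/outputs

@core_map(CoreMesh("core", num_cores=4))
def kernel():
    counter[0] += 1

kernel()
assert counter[0] == 4
```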

PiperOrigin-RevId: 686036101
2024-10-15 03:26:58 -07:00
Yash Katariya
824ccd7183 [Shardy] Inline meshes when using shardy and get rid of global meshes from the MLIR body.
Also do a couple of cleanups.

PiperOrigin-RevId: 685746298
2024-10-14 10:08:04 -07:00
Yash Katariya
c6f7316d43 Add a private _extremely_unsafe_enter_tracing_context to enter an AbstractMesh into the tracing context. This is a temporary workaround for internal use cases.
PiperOrigin-RevId: 682960902
2024-10-06 14:50:24 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Loren Maggiore
f75c5c6b2d [jax] config option to disable using a mesh as a context manager.
PiperOrigin-RevId: 676475039
2024-09-19 10:42:41 -07:00
jax authors
16eb13e9db Fix empty mesh size and abstract_mesh
* Fix `size` to return 0 rather than 1 for the empty mesh (see the sketch after this list).
* Fix `abstract_mesh` to return an empty abstract mesh.
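
A sketch of the `size` fix, assuming size is the product of the axis sizes:

```python
import math

def mesh_size(axis_sizes: dict[str, int]) -> int:
    # math.prod over an empty collection is 1, so an empty mesh must be
    # special-cased to report 0 devices rather than 1.
    if not axis_sizes:
        return 0
    return math.prod(axis_sizes.values())

assert mesh_size({}) == 0
assert mesh_size({"x": 4, "y": 2}) == 8
```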

PiperOrigin-RevId: 665408468
2024-08-20 10:00:37 -07:00
Yash Katariya
9a8f0a67f5 Add a devices property to AbstractMesh but raise an error in it. This is to make pytype happy
PiperOrigin-RevId: 662712450
2024-08-13 17:37:58 -07:00
Yash Katariya
daa69da321 Introduce jax.sharding.AbstractMesh(shape_tuple: tuple[tuple[str, int], ...]) and allow with_sharding_constraint and shard_map to accept an abstract mesh as input (with_sharding_constraint is via NamedSharding(abstract_mesh, pspec)).
**Semantics**

Inside jit, we never need to talk about concrete devices, so the semantics stay the same as today, i.e. we can lower a NamedSharding with an abstract mesh using only the mesh axis names and sizes and a PartitionSpec. The only restriction is that the number of devices needs to be consistent throughout the program while we are tracing.
During compilation, the order of devices throughout the program needs to be consistent (same as before this change).

Outside jit i.e. eager mode, if a `shard_map` or `with_sharding_constraint` contains AbstractMesh, then the input to those primitives should contain a concrete Mesh with the same shape and names as the abstract mesh.

**Why do this?**

There are cases where you want to change the devices in the mesh but keep the mesh shape the same (axis names and axis sizes). But this leads to a device-mismatch error if you have `with_sharding_constraint` or `shard_map` in your computation, because they embed concrete devices in their signatures.

So to fix the error, you need to change the mesh in `wsc` and `shmap`, which will lead to a tracing cache miss (because the function id is now different) and consequently a StableHLO lowering cache miss. Explaining via an example:

```
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.lax import with_sharding_constraint

mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')

arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

@jax.jit
def f(x):
  y = with_sharding_constraint(x, NamedSharding(mesh1, P('x')))
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # DEVICE MISMATCH ERROR!
```

The same problem exists for `shard_map` since it takes a mesh with concrete devices in its signature.

**Okay, so how do you fix this?**

As mentioned above, we need the above program to work and to get tracing and lowering cache hits (**cache hits are the most important part here**).

The approach in this change allows `with_sharding_constraint` to accept a `NamedSharding(abstract_mesh, pspec)` as input. This leads to no errors downstream, and we get tracing and lowering cache hits since we no longer encode the concrete devices, just the axis names and axis sizes of the mesh.

**The important part is that the concrete device information should only come from the arguments. Inside `jax.jit`, you should never reference concrete devices ever.**

```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')

arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)

@jax.jit
def f(x):
  y = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # tracing and lowering cache hit
```

**One caveat is that this only works with `jax.NamedSharding` but that's fine because `NamedSharding` is the most used `Sharding` in JAX.**

**What about `shard_map`?**

shard_map's signature will be: `shmap(f, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs)`.

```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')

arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)

@jax.jit
def f(x):
  y = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), out_specs=P('x'))(x)
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # tracing and lowering cache hit
```

This is a fully backwards-compatible change. So your current code will continue to work as is, but you can opt into this new behavior and get all the benefits!

PiperOrigin-RevId: 662670932
2024-08-13 15:18:08 -07:00
Yash Katariya
7de3c06147 Delete mesh.Loop now that xmap has been deleted
PiperOrigin-RevId: 656084608
2024-07-25 14:08:32 -07:00
Sharad Vikram
ae8da83357 Shmallas, a.k.a. allow lowering shard_map + run_state to a pallas_call.
This allows code like this:
```python
def f(x):
  mesh = pltpu.create_tensorcore_mesh('core')
  y = jnp.zeros_like(x)
  @state_discharge.run_state
  def inner(refs):
    x_ref, y_ref = refs
    def kernel():
      def alloc(sem):
        pltpu.async_copy(x_ref, y_ref, sem).wait()
      pltpu.run_scoped(alloc, pltpu.SemaphoreType.DMA)
    shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
                        check_rep=False)()
  _, y = inner((x, y))
  return y
```

Why? pallas_call as an API has a lot of responsibilities:
1. Creating Refs out of Arrays
2. Parallelizing execution over cores (via dimension_semantics and grid)
3. Pipelining
4. Allocating scratch spaces
5. Scalar prefetch

This change allows you to express pallas_call *compositionally* using existing APIs.

1. Creating Refs out of arrays -> run_state
2. Parallelizing execution over cores -> shmap w/ a special mesh
3. Pipelining -> emit_pipeline
4. Allocating scratch spaces -> run_scoped (which we could generalize to run_state)
5. Scalar prefetch -> run_scoped + a DMA

The hope is that this allows Pallas to generalize to more backends beyond TPU while becoming more intuitive to write and explain. For now, this lowering path is experimental and not officially exposed but we want to make sure it is possible to support.

PiperOrigin-RevId: 655320587
2024-07-23 15:16:50 -07:00
Bart Chrzaszcz
864178d3a3 #sdy Initial set of changes to allow for lowering to the Shardy dialect.
The OpenXLA project is working on an open source, MLIR, named-axis based propagation (and in the future SPMD partitioning) system that will be dialect agnostic (it would work for any dialect - MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.

Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline while it's still working on StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD Partitioner, but next year the Shardy partitioner will be developed, allowing for propagation and partitioning to be completely in MLIR and the first pass in the pipeline. So then we'd have:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations

The following test:

```py
def test_sdy_lowering(self):
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  np_inp = np.arange(16).reshape(8, 2)
  s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
  arr = jax.device_put(np_inp, s)

  @partial(jax.jit, out_shardings=s)
  def f(x):
    return x * 2

  print(f.lower(arr).as_text())
```

outputs:

```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <"x"=4, "y"=2>
  func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
    %c = stablehlo.constant dense<2> : tensor<i64>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
    %1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
    return %1 : tensor<8x2xi64>
  }
}
```

Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially before becoming enabled by default in the future.

PiperOrigin-RevId: 655127611
2024-07-23 05:32:06 -07:00
Dan Foreman-Mackey
6d35b109fd Rename "Example" to "Examples" in docstrings.
This PR updates all docstrings that previously had a section heading
called "Example" and replaces that with "Examples" to be consistent.
2024-06-21 11:43:16 -04:00
Yash Katariya
6c34a56b87 Add util.cache to jax.clear_caches and move pjit, sharding, array, etc. uses of functools.lru_cache to util.cache so that those caches are cleared when jax.clear_caches is called.
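
A sketch of the pattern with illustrative names: a caching decorator that registers each cache it creates so a single clear call can reset them all.

```python
import functools

_registered_caches = []

def cache(fn):
    cached = functools.lru_cache(maxsize=None)(fn)
    _registered_caches.append(cached)
    return cached

def clear_caches():
    for c in _registered_caches:
        c.cache_clear()

@cache
def expensive(x):
    return x * x

expensive(3)
clear_caches()  # expensive's underlying lru_cache is cleared too
```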
PiperOrigin-RevId: 642359226
2024-06-11 12:46:47 -07:00
Yash Katariya
72f00ebaec Add __str__ to Mesh so that in jaxprs the mesh doesn't print all the device ids.
PiperOrigin-RevId: 599568637
2024-01-18 11:23:25 -08:00
Sergei Lebedev
36f6b52e42 Upgrade most .py sources to 3.9
This commit was generated by running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00
Matthew Johnson
64cb53f624 improve an error message during Mesh creation 2023-12-06 16:43:36 -08:00
Tom Hennigan
1b504bb68e Allow threads to race setting attributes on Mesh.
PiperOrigin-RevId: 584602313
2023-11-22 05:47:56 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Yash Katariya
8ee58117e2 Don't print all the devices in the mesh during ResourceEnv's repr. Just print the mesh shape.
PiperOrigin-RevId: 577305337
2023-10-27 14:25:34 -07:00
Jake VanderPlas
f1fc2adfbd Fix mypy error 2023-08-29 13:25:12 -07:00
Yash Katariya
6072d5993e Any devices passed to jax.sharding.Mesh are required to be hashable.
This is true for mock devices, user-specified devices, and jax.devices() too.

Fix the tests so that the mock devices are hashable.

PiperOrigin-RevId: 561103167
2023-08-29 12:20:54 -07:00
Yash Katariya
a37e2159b3 Don't drop out of C++ fast path if mesh pointers are not equal.
This is done by returning the same object when constructing a Mesh if devices.shape, axis_names, and the flat device list match.
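
An illustrative sketch of the interning idea (not the real Mesh): construction returns a previously built instance when the device array shape, axis names, and flat device list all match, so pointer equality holds on the C++ fast path.

```python
import numpy as np

class Mesh:
    _cache = {}

    def __new__(cls, devices, axis_names):
        devices = np.asarray(devices)
        key = (devices.shape, tuple(axis_names), tuple(devices.flat))
        if key not in cls._cache:
            self = super().__new__(cls)
            self.devices, self.axis_names = devices, tuple(axis_names)
            cls._cache[key] = self
        return cls._cache[key]

# Two constructions with matching devices and axis names are the same object.
assert Mesh([0, 1], ("x",)) is Mesh([0, 1], ("x",))
```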

PiperOrigin-RevId: 560828993
2023-08-28 15:04:05 -07:00
Yash Katariya
242c2c1b52 Use _internal_device_list in __hash__ and __eq__ of Shardings and Mesh to speed them up.
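
A minimal sketch of the delegation, using toy classes rather than the real JAX types: equality and hashing go through one precomputed device-list object instead of walking every device each time.

```python
class DeviceList:
    def __init__(self, devices):
        self._devices = tuple(devices)
        self._hash = hash(self._devices)  # computed once up front

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        return isinstance(other, DeviceList) and self._devices == other._devices

class Sharding:
    def __init__(self, device_list, spec):
        self._internal_device_list = device_list
        self.spec = spec

    def __hash__(self):
        return hash((self._internal_device_list, self.spec))

    def __eq__(self, other):
        return (isinstance(other, Sharding)
                and self._internal_device_list == other._internal_device_list
                and self.spec == other.spec)
```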
PiperOrigin-RevId: 557665385
2023-08-16 18:41:58 -07:00
Hyeontaek Lim
423c8d8d4f [JAX] Use DeviceList in JAX Sharding implementations
XLA-compatible `Sharding` implementations keep a `DeviceList` object as
`_internal_device_list`. This is used for finding the default memory kind more
quickly in C++, and enables caching of the default memory kind between multiple
`NamedSharding` objects that share the same `Mesh`. It also uses an
addressable device within `DeviceList`, which will be required for supporting
multiple device types with different default memory kinds.

PiperOrigin-RevId: 556969789
2023-08-14 18:11:23 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
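
For reference, the corner case in question:

```python
import math
import numpy as np

print(np.prod([]))    # 1.0 -- a float, which leaks into static shape math
print(math.prod([]))  # 1   -- an int, as expected for shape products
```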
2023-04-13 11:48:11 -07:00
Yash Katariya
fdbad53b15 Make _device_assignment a Tuple[Device] so that we don't convert a list to a tuple and vice-versa everywhere
PiperOrigin-RevId: 524002310
2023-04-13 08:03:27 -07:00
Yash Katariya
a3ce08cf1d Override addressable_devices for NamedSharding since the mesh can be the same throughout the program.
PiperOrigin-RevId: 522677209
2023-04-07 13:54:37 -07:00
Peter Hawkins
0f368e4428 Cache __repr__ and device_ids properties on Mesh.
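
A sketch of the caching pattern, assuming `functools.cached_property` and toy classes; the repr string and the device-id array are computed once per Mesh instance.

```python
import functools
import numpy as np

class Device:
    def __init__(self, id):  # stand-in for a real device object
        self.id = id

class Mesh:
    def __init__(self, devices, axis_names):
        self.devices = np.asarray(devices)
        self.axis_names = tuple(axis_names)

    @functools.cached_property
    def device_ids(self):
        # Built once, then stored on the instance by cached_property.
        return np.vectorize(lambda d: d.id)(self.devices)

    @functools.cached_property
    def _repr(self):
        return f"Mesh({dict(zip(self.axis_names, self.devices.shape))})"

    def __repr__(self):
        return self._repr
```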
PiperOrigin-RevId: 522653188
2023-04-07 12:12:14 -07:00
Yash Katariya
8838039287 Override is_fully_addressable() for NamedSharding.
The intent of this change is to speed up is_fully_addressable() when computing it repeatedly over the same mesh.

PiperOrigin-RevId: 522500766
2023-04-06 19:46:29 -07:00
Peter Hawkins
623282715d Split Mesh and ResourceEnv into a new module jax._src.mesh.
This work is an effort to reduce cyclic dependencies in JAX internals.

Move the _global_to_local and _local_to_global methods out of Mesh and into pxla as free functions. This removes the need for jax._src.mesh to depend on things like avals.

PiperOrigin-RevId: 515667671
2023-03-10 10:08:21 -08:00