2022-09-27 10:06:10 -07:00
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

2022-12-14 15:07:04 -08:00
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570

2023-03-13 08:49:39 -07:00
from jax._src.sharding import Sharding as Sharding
from jax._src.sharding_impls import (
2024-06-05 09:06:36 -07:00
  XLACompatibleSharding as _deprecated_XLACompatibleSharding,
2022-11-02 19:12:32 -07:00
  NamedSharding as NamedSharding,
2022-09-27 10:06:10 -07:00
  SingleDeviceSharding as SingleDeviceSharding,
  PmapSharding as PmapSharding,
2023-02-17 17:10:27 -08:00
  GSPMDSharding as GSPMDSharding,
2022-11-02 19:12:32 -07:00
  PositionalSharding as PositionalSharding,
2022-09-27 10:06:10 -07:00
)
2023-04-06 11:42:45 -07:00
from jax._src.partition_spec import (
  PartitionSpec as PartitionSpec,
)
2023-02-09 05:47:59 -08:00
from jax._src.interpreters.pxla import Mesh as Mesh
Introduce `jax.sharding.AbstractMesh(shape_tuple: tuple[tuple[str, int], ...])` and allow `with_sharding_constraint` and `shard_map` to accept an abstract mesh as input (`with_sharding_constraint` is via `NamedSharding(abstract_mesh, pspec)`).
**Semantics**
Inside jit, we never need to refer to concrete devices, so the semantics stay the same as today, i.e. we can lower a `NamedSharding` with an abstract mesh using only the mesh axis names and sizes and the `PartitionSpec`. The only restriction is that the number of devices needs to be consistent throughout the program while we are tracing.
During compilation, the order of devices throughout the program needs to be consistent (same as before this change).
Outside jit, i.e. in eager mode, if a `shard_map` or `with_sharding_constraint` contains an `AbstractMesh`, then the input to those primitives should contain a concrete `Mesh` with the same shape and names as the abstract mesh.
**Why do this?**
There are cases where you want to change the devices in the mesh but keep the mesh shape (axis names and axis sizes) the same. But this leads to a device mismatch error if you have `with_sharding_constraint` or `shard_map` in your computation, because they embed concrete devices in their signature.
So to fix the error, you need to change the mesh in `wsc` and `shmap`, which leads to a tracing cache miss (because the function id is now different) and consequently a lowering-to-StableHLO cache miss. Explaining via an example:
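```
import jax
import numpy as np
from jax.lax import with_sharding_constraint
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes at least 4 devices are available.
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

@jax.jit
def f(x):
  # The constraint embeds the concrete devices of mesh1.
  y = with_sharding_constraint(x, NamedSharding(mesh1, P('x')))
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # DEVICE MISMATCH ERROR!
```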
The same problem exists for `shard_map` since it takes a mesh with concrete devices in its signature.
**Okay, so how do you fix this?**
As mentioned above, we need the above program to work and get tracing and lowering cache hits (**cache hits are the most important** part here).
The approach in this change allows `with_sharding_constraint` to accept a `NamedSharding(abstract_mesh, pspec)` as input. This leads to no errors downstream and we get tracing and lowering cache hits since we no longer encode the concrete devices, just the axis names and axis sizes of the mesh.
**The important part is that the concrete device information should only come from the arguments. Inside `jax.jit`, you should never reference concrete devices.**
```
import jax
import numpy as np
from jax.lax import with_sharding_constraint
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

# Creating the abstract mesh from mesh1, but since both meshes have the same
# shape (axis names and axis sizes), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)

@jax.jit
def f(x):
  y = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # tracing and lowering cache hit
```
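To make the cache argument concrete, here is a toy model in plain Python. It does not use JAX and is not how JAX implements its caches; `ToyMesh` and `ToyTraceCache` are hypothetical names invented for this sketch. It only illustrates the idea from the text: a key that includes concrete device ids differs between two meshes of the same shape, while a key built from axis names and sizes alone does not.

```python
from typing import NamedTuple


class ToyMesh(NamedTuple):
  """Hypothetical stand-in for a concrete mesh: devices plus named axes."""
  device_ids: tuple
  axis_names: tuple
  axis_sizes: tuple

  @property
  def shape_tuple(self):
    # Only axis names and sizes, no device information.
    return tuple(zip(self.axis_names, self.axis_sizes))


class ToyTraceCache:
  """Hypothetical cache that counts how often a key has to be 'traced'."""

  def __init__(self):
    self._entries = {}
    self.misses = 0

  def lookup(self, key):
    if key not in self._entries:
      self.misses += 1
      self._entries[key] = f"traced<{key}>"
    return self._entries[key]


mesh1 = ToyMesh(device_ids=(0, 1), axis_names=('x',), axis_sizes=(2,))
mesh2 = ToyMesh(device_ids=(2, 3), axis_names=('x',), axis_sizes=(2,))

# Keying on the concrete mesh: different devices give different keys.
concrete_cache = ToyTraceCache()
concrete_cache.lookup(mesh1)
concrete_cache.lookup(mesh2)

# Keying on the abstract shape: both meshes map to the same key.
abstract_cache = ToyTraceCache()
abstract_cache.lookup(mesh1.shape_tuple)
abstract_cache.lookup(mesh2.shape_tuple)

print(concrete_cache.misses, abstract_cache.misses)
```

In this toy, the second lookup with the abstract key reuses the first entry, which is the analogue of the tracing and lowering cache hit described above.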
**One caveat is that this only works with `jax.NamedSharding` but that's fine because `NamedSharding` is the most used `Sharding` in JAX.**
**What about `shard_map`?**
shard_map's signature will be: `shmap(f, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs)`.
```
import jax
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

# Creating the abstract mesh from mesh1, but since both meshes have the same
# shape (axis names and axis sizes), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)

@jax.jit
def f(x):
  # shard_map returns a function, so apply it to x.
  y = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), out_specs=P('x'))(x)
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # tracing and lowering cache hit
```
This is a fully backwards-compatible change, so your current code will continue to work as is, but you can opt into this new behavior and get all the benefits!
PiperOrigin-RevId: 662670932
2024-08-13 15:17:30 -07:00
from jax._src.mesh import AbstractMesh as AbstractMesh
2024-06-05 09:06:36 -07:00

_deprecations = {
  # Added Jun 4, 2024.
  "XLACompatibleSharding": (
    (
      "jax.sharding.XLACompatibleSharding is deprecated. Use"
      " jax.sharding.Sharding instead."
    ),
    _deprecated_XLACompatibleSharding,
  )
}

import typing
if typing.TYPE_CHECKING:
  XLACompatibleSharding = _deprecated_XLACompatibleSharding
else:
  from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
  __getattr__ = _deprecation_getattr(__name__, _deprecations)
  del _deprecation_getattr
del typing
|