Add and test support for partitioning of batch dimensions in lax.linalg.

On CPU and GPU, almost all of the primitives in lax.linalg are backed by custom calls that support simple semantics when batch dimensions are sharded. Before this change, all linalg operations on CPU and GPU will insert an `all-gather` before being executed when called on sharded inputs, even when that shouldn't be necessary. This change adds support for this type of partitioning, to cover a wide range of use cases.

There are a few remaining GPU ops that don't support partitioning either because they are backed by HLO ops that don't partition properly (Cholesky factorization and triangular solves), or because they're still using descriptors with problem dimensions in kernel. I'm going to fix these in follow up changes.

PiperOrigin-RevId: 731732301
This commit is contained in:
Dan Foreman-Mackey 2025-02-27 08:15:45 -08:00 committed by jax authors
parent de4d047852
commit f93c2a1aa5
6 changed files with 292 additions and 20 deletions

View File

@ -18,6 +18,7 @@ from collections.abc import Callable
import enum
from functools import partial
import math
import string
from typing import Any, Literal, overload
import numpy as np
@ -31,6 +32,8 @@ from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src.core import ShapedArray, is_constant_dim, is_constant_shape
from jax._src.custom_partitioning_sharding_rule import (
sdy_sharding_rule_to_mlir, str_to_sdy_sharding_rule)
from jax._src import ffi
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@ -1643,7 +1646,8 @@ def _lu_pivots_to_permutation_gpu_lowering(ctx, pivots, *,
permutation_size,
target_name_prefix):
del permutation_size # unused
rule = ffi.ffi_lowering(f"{target_name_prefix}_lu_pivots_to_permutation")
rule = _linalg_ffi_lowering(f"{target_name_prefix}_lu_pivots_to_permutation",
num_non_batch_dims=1, column_major=False)
return rule(ctx, pivots)
@ -2395,12 +2399,15 @@ def _triangular_solve_cpu_lower(
if ctx.is_forward_compat() or jaxlib_version <= (0, 5, 1):
alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype)),
alpha_aval = ShapedArray((), a_aval.dtype),
batch_partitionable = False
else:
alpha = ()
alpha_aval = ()
batch_partitionable = True
rule = _linalg_ffi_lowering(target_name,
[a_aval, b_aval, *alpha_aval],
operand_output_aliases={1: 0})
operand_output_aliases={1: 0},
batch_partitionable=batch_partitionable)
return rule(ctx, a, b, *alpha,
side=_matrix_side_attr(left_side),
uplo=_matrix_uplo_attr(lower),
@ -2740,14 +2747,38 @@ def _column_major_matrix_layout(dim: int) -> tuple[int, ...]:
# The layout for a batch of matrices with Fortran order.
return (dim - 2, dim - 1) + tuple(range(dim - 3, -1, -1))
def _sdy_rule_for_aval(letters, num_batch_dims, aval):
return " ".join(
("...", *(next(letters) for _ in range(len(aval.shape) - num_batch_dims)))
)
def _build_sdy_sharding_rule(num_batch_dims, avals_in, avals_out):
letters = iter(string.ascii_letters)
lhs = ", ".join(
_sdy_rule_for_aval(letters, num_batch_dims, a) for a in avals_in)
rhs = ", ".join(
_sdy_rule_for_aval(letters, num_batch_dims, a) for a in avals_out)
sdy_sharding_rule = str_to_sdy_sharding_rule(f"{lhs} -> {rhs}")
return sdy_sharding_rule_to_mlir(
sdy_sharding_rule,
[mlir.aval_to_ir_type(a) for a in avals_in],
[mlir.aval_to_ir_type(a) for a in avals_out])
def _linalg_ffi_lowering(target_name, avals_in=None, avals_out=None,
operand_output_aliases=None, column_major=True):
operand_output_aliases=None, column_major=True,
num_non_batch_dims=2, batch_partitionable=True):
# A lightweight wrapper around ffi.ffi_lowering that can automatically set
# the layouts appropriately for column-major matrices, which most handlers
# used here will expect.
def rule(ctx, *args, **kwargs):
avals_in_ = ctx.avals_in if avals_in is None else avals_in
avals_out_ = ctx.avals_out if avals_out is None else avals_out
# TODO(danfm): Add support for shape polymorphism and batch partitioning.
has_dynamic_shape = any(
not is_constant_shape(aval.shape) for aval in (*avals_in_, *avals_out_))
batch_partitionable_ = batch_partitionable and not has_dynamic_shape
max_num_dims = max(len(v.shape) for v in avals_in_)
ctx = ctx.replace(avals_in=avals_in_, avals_out=avals_out_)
operand_layouts = [
@ -2758,8 +2789,18 @@ def _linalg_ffi_lowering(target_name, avals_in=None, avals_out=None,
_column_major_matrix_layout(len(aval.shape))
if column_major and len(aval.shape) == max_num_dims else None
for aval in avals_out_]
num_batch_dims = max_num_dims - num_non_batch_dims
frontend_attrs = mlir.ir_attribute({"num_batch_dims": str(num_batch_dims)})
if batch_partitionable_:
extra_attributes = {"mhlo.frontend_attributes": frontend_attrs}
if config.use_shardy_partitioner.value:
extra_attributes["sdy.sharding_rule"] = _build_sdy_sharding_rule(
num_batch_dims, avals_in_, avals_out_)
else:
extra_attributes = None
rule = ffi.ffi_lowering(target_name, operand_layouts=operand_layouts,
result_layouts=result_layouts,
operand_output_aliases=operand_output_aliases)
operand_output_aliases=operand_output_aliases,
extra_attributes=extra_attributes)
return rule(ctx, *args, **kwargs)
return rule

View File

@ -38,20 +38,16 @@ for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
if _cuda_linalg:
for _name, _value in _cuda_linalg.registrations().items():
# TODO(danfm): remove after JAX 0.5.1 release
api_version = (1
if _name.endswith("lu_pivots_to_permutation")
or _name.endswith("_ffi") else 0)
xla_client.register_custom_call_target(
_name, _value, platform="CUDA", api_version=api_version
_name, _value, platform="CUDA", api_version=1
)
xla_client.register_custom_call_as_batch_partitionable(
"cu_lu_pivots_to_permutation")
if _hip_linalg:
for _name, _value in _hip_linalg.registrations().items():
# TODO(danfm): remove after JAX 0.5.1 release
api_version = (1
if _name.endswith("lu_pivots_to_permutation")
or _name.endswith("_ffi") else 0)
xla_client.register_custom_call_target(
_name, _value, platform="ROCM", api_version=api_version
_name, _value, platform="ROCM", api_version=1
)
xla_client.register_custom_call_as_batch_partitionable(
"hip_lu_pivots_to_permutation")

View File

@ -44,7 +44,10 @@ for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
if _cusolver:
for _name, _value in _cusolver.registrations().items():
# TODO(danfm): Clean up after all legacy custom calls are ported.
api_version = 1 if _name.endswith("_ffi") else 0
api_version = 0
if _name.endswith("_ffi"):
api_version = 1
xla_client.register_custom_call_as_batch_partitionable(_name)
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
api_version=api_version)
@ -60,6 +63,7 @@ for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
if _cuhybrid:
for _name, _value in _cuhybrid.registrations().items():
xla_client.register_custom_call_as_batch_partitionable(_name)
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
api_version=1)
@ -91,7 +95,10 @@ for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
if _hipsolver:
for _name, _value in _hipsolver.registrations().items():
# TODO(danfm): Clean up after all legacy custom calls are ported.
api_version = 1 if _name.endswith("_ffi") else 0
api_version = 0
if _name.endswith("_ffi"):
api_version = 1
xla_client.register_custom_call_as_batch_partitionable(_name)
xla_client.register_custom_call_target(_name, _value, platform="ROCM",
api_version=api_version)
@ -107,6 +114,7 @@ for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
if _hiphybrid:
for _name, _value in _hiphybrid.registrations().items():
xla_client.register_custom_call_as_batch_partitionable(_name)
xla_client.register_custom_call_target(_name, _value, platform="ROCM",
api_version=1)

View File

@ -21,11 +21,12 @@ from .cpu._lapack import eig
from .cpu._lapack import schur
for _name, _value in _lapack.registrations().items():
api_version = 0
if _name.endswith("_ffi"):
api_version = 1
xla_client.register_custom_call_as_batch_partitionable(_name)
xla_client.register_custom_call_target(
_name,
_value,
platform="cpu",
api_version=(1 if _name.endswith("_ffi") else 0),
_name, _value, platform="cpu", api_version=api_version
)

View File

@ -698,6 +698,22 @@ jax_multiplatform_test(
},
)
jax_multiplatform_test(
name = "linalg_sharding_test",
srcs = ["linalg_sharding_test.py"],
enable_backends = [
"cpu",
],
enable_configs = [
"gpu_p100x2",
"gpu_p100x2_shardy",
"gpu_p100x2_pjrt_c_api",
],
tags = [
"multiaccelerator",
],
)
jax_multiplatform_test(
name = "magma_linalg_test",
srcs = ["magma_linalg_test.py"],

View File

@ -0,0 +1,210 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from absl.testing import absltest
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax._src import config
from jax._src import test_util as jtu
from jax.sharding import PartitionSpec as P
config.parse_flags_with_absl()
jtu.request_cpu_devices(8)
float_types = jtu.dtypes.floating
complex_types = jtu.dtypes.complex
CPU_ONLY_FUN_AND_SHAPES = [
# These functions are supported on GPU, but partitioning support will
# require updates to GSPMD, since they are lowered directly to HLO ops
# instead of custom calls on GPU.
(lax.linalg.cholesky, ((6, 6),)),
(lax.linalg.triangular_solve, ((6, 6), (4, 6))),
# The GPU kernel for this function still uses an opaque descriptor to
# encode the input shapes so it is not partitionable.
# TODO(danfm): Update the kernel and enable this test on GPU.
(lax.linalg.tridiagonal_solve, ((6,), (6,), (6,), (6, 4))),
# These functions are only supported on CPU.
(lax.linalg.hessenberg, ((6, 6),)),
(lax.linalg.schur, ((6, 6),)),
]
CPU_AND_GPU_FUN_AND_SHAPES = [
(lax.linalg.eig, ((6, 6),)),
(lax.linalg.eigh, ((6, 6),)),
(lax.linalg.lu, ((10, 6),)),
(lax.linalg.qr, ((6, 6),)),
(lax.linalg.svd, ((10, 6),)),
(lax.linalg.tridiagonal, ((6, 6),)),
]
ALL_FUN_AND_SHAPES = CPU_ONLY_FUN_AND_SHAPES + CPU_AND_GPU_FUN_AND_SHAPES
class LinalgShardingTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if jax.device_count() < 2:
self.skipTest("Requires multiple devices")
def get_fun_and_shapes(self, fun_and_shapes, grad=False):
if (jtu.test_device_matches(["gpu"])
and fun_and_shapes not in CPU_AND_GPU_FUN_AND_SHAPES):
self.skipTest(f"{fun_and_shapes[0].__name__} not supported on GPU")
if not grad:
return fun_and_shapes
fun, shapes = fun_and_shapes
if fun in (lax.linalg.schur, lax.linalg.hessenberg, lax.linalg.tridiagonal):
self.skipTest(f"{fun.__name__} does not support differentation")
if jtu.test_device_matches(["gpu"]) and fun in (
lax.linalg.eig, lax.linalg.lu, lax.linalg.qr
):
self.skipTest(
f"JVP of {fun.__name__} uses triangular solve on GPU, which doesn't "
"support batch partitioning yet")
if fun == lax.linalg.eig:
fun = functools.partial(
fun,
compute_left_eigenvectors=False,
compute_right_eigenvectors=False,
)
if fun == lax.linalg.svd:
fun = functools.partial(fun, full_matrices=False)
return fun, shapes
def get_args(self, shapes, dtype, batch_size=None):
rng = jtu.rand_default(self.rng())
def arg_maker(shape):
if batch_size is not None:
x = rng((batch_size, *shape), dtype)
else:
x = rng(shape, dtype)
if len(shape) == 2 and shape[0] == shape[1]:
x = np.matmul(x, np.swapaxes(np.conj(x), -1, -2))
return x
return tuple(arg_maker(shape) for shape in shapes)
@jtu.sample_product(
fun_and_shapes=ALL_FUN_AND_SHAPES,
dtype=float_types + complex_types,
)
@jtu.run_on_devices("gpu", "cpu")
def test_batch_axis_sharding(self, fun_and_shapes, dtype):
fun, shapes = self.get_fun_and_shapes(fun_and_shapes)
args = self.get_args(shapes, dtype, batch_size=8)
mesh = jtu.create_mesh((2,), ("i",))
sharding = jax.NamedSharding(mesh, P("i"))
args_sharded = jax.device_put(args, sharding)
fun_jit = jax.jit(fun)
expected = fun(*args)
actual = fun_jit(*args_sharded)
self.assertAllClose(actual, expected)
self.assertNotIn("all-", fun_jit.lower(*args_sharded).compile().as_text())
vmap_fun = jax.vmap(fun)
vmap_fun_jit = jax.jit(vmap_fun)
actual = vmap_fun_jit(*args_sharded)
self.assertAllClose(actual, expected)
self.assertNotIn(
"all-", vmap_fun_jit.lower(*args_sharded).compile().as_text())
@jtu.sample_product(
fun_and_shapes=ALL_FUN_AND_SHAPES,
dtype=float_types + complex_types,
)
@jtu.run_on_devices("gpu", "cpu")
def test_non_batch_axis_sharding(self, fun_and_shapes, dtype):
fun, shapes = self.get_fun_and_shapes(fun_and_shapes)
args = self.get_args(shapes, dtype)
mesh = jtu.create_mesh((2,), ("i",))
sharding = jax.NamedSharding(mesh, P("i"))
args_sharded = jax.device_put(args, sharding)
fun_jit = jax.jit(fun)
expected = fun(*args)
actual = fun_jit(*args_sharded)
self.assertAllClose(actual, expected)
self.assertIn(
"all-gather", fun_jit.lower(*args_sharded).compile().as_text())
@jtu.sample_product(
fun_and_shapes=ALL_FUN_AND_SHAPES,
dtype=float_types + complex_types,
)
@jtu.run_on_devices("gpu", "cpu")
def test_batch_axis_sharding_jvp(self, fun_and_shapes, dtype):
fun, shapes = self.get_fun_and_shapes(fun_and_shapes, grad=True)
primals = self.get_args(shapes, dtype, batch_size=8)
tangents = tuple(map(jnp.ones_like, primals))
def jvp_fun(primals, tangents):
return jax.jvp(fun, primals, tangents)
mesh = jtu.create_mesh((2,), ("i",))
sharding = jax.NamedSharding(mesh, P("i"))
primals_sharded = jax.device_put(primals, sharding)
tangents_sharded = jax.device_put(tangents, sharding)
jvp_fun_jit = jax.jit(jvp_fun)
_, expected = jvp_fun(primals, tangents)
for args in [
(primals_sharded, tangents_sharded),
(primals, tangents_sharded),
(primals_sharded, tangents),
]:
_, actual = jvp_fun_jit(*args)
self.assertAllClose(actual, expected)
hlo = jvp_fun_jit.lower(primals_sharded, tangents_sharded).compile()
self.assertNotIn("all-", hlo.as_text())
@jtu.sample_product(
fun_and_shapes=ALL_FUN_AND_SHAPES,
dtype=float_types + complex_types,
)
@jtu.run_on_devices("gpu", "cpu")
def test_batch_axis_sharding_vjp(self, fun_and_shapes, dtype):
fun, shapes = self.get_fun_and_shapes(fun_and_shapes, grad=True)
primals = self.get_args(shapes, dtype, batch_size=8)
out, vjp_fun = jax.vjp(fun, *primals)
tangents = jax.tree.map(jnp.ones_like, out)
mesh = jtu.create_mesh((2,), ("i",))
sharding = jax.NamedSharding(mesh, P("i"))
tangents_sharded = jax.device_put(tangents, sharding)
vjp_fun_jit = jax.jit(vjp_fun)
expected = vjp_fun(tangents)
actual = vjp_fun_jit(tangents_sharded)
self.assertAllClose(actual, expected)
hlo = vjp_fun_jit.lower(tangents_sharded).compile()
self.assertNotIn("all-", hlo.as_text())
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())