[xla:python] Add a mechanism for "batch partitioning" of FFI calls.

This is the first in a series of changes to add a simple API for supporting a set of common sharding and partitioning patterns for FFI calls. The high-level motivation is that custom calls (including FFI calls) are opaque to the SPMD partitioner, and the only ways to customize the partitioning behavior are to (a) explicitly register an `xla::CustomCallPartitioner` with XLA, or (b) use the `jax.experimental.custom_partitioning` APIs. Option (a) isn't generally practical for most use cases where the FFI handler lives in an external binary. Option (b) is flexible and supports all common use cases, but it requires embedding Python callbacks into the HLO, which can lead to issues including cache misses. Furthermore, `custom_partitioning` is overpowered for many use cases, where only (what I will call) "batch partitioning" is needed.

In this context, "batch partitioning" refers to the behavior of many FFI calls that can be trivially partitioned along some number of (leading) dimensions, with the same call executed independently on each shard of data. If the data are sharded along non-batch dimensions, the partitioner will still re-shard them to be replicated along those non-batch dimensions. This kind of partitioning logic applies to all the LAPACK/cuSOLVER/etc.-backed linear algebra functions in jaxlib, as well as to some external users of `custom_partitioning`.
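For concreteness, here is a rough sketch (not part of this change) of the semantics that batch partitioning is meant to reproduce, written with `shard_map`: the same factorization runs independently on each shard of the leading dimension, with no cross-device communication. The mesh construction, array shapes, and the choice of `jnp.linalg.qr` as the per-shard operation are illustrative.

```python
from functools import partial

import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Illustrative 1-D mesh over all available devices.
mesh = Mesh(np.array(jax.devices()), ("i",))

# A stack of independent matrices, sharded along the leading (batch) axis.
x = jnp.ones((2 * len(jax.devices()), 4, 5), dtype=jnp.float32)
x = jax.device_put(x, NamedSharding(mesh, P("i")))

# "Batch partitioning" semantics: each device runs the same factorization on
# its own slice of the batch dimension.
@partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
def per_shard_qr(x):
  return jnp.linalg.qr(x, mode="reduced")

q, r = per_shard_qr(x)
```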

The approach I'm taking here is to add a new registration function to the XLA client that lets a user label their FFI call as batch partitionable. Then, when lowering the custom call, the user passes the number of batch dimensions as a frontend attribute, which is interpreted by the SPMD partitioner.
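Concretely, the intended usage mirrors the test added below: register the target once as batch partitionable, then attach `num_batch_dims` as a frontend attribute when lowering. A condensed sketch follows, using a hypothetical target name `my_qr_ffi`; the `mlir` module here is a JAX-internal helper, as in the test.

```python
import jax
from jax._src.interpreters import mlir  # JAX-internal lowering utilities

# (a) Tell XLA that this FFI target can be partitioned along its leading
# batch dimensions. "my_qr_ffi" is a placeholder target name.
jax.ffi.register_ffi_target_as_batch_partitionable("my_qr_ffi")

# (b) When lowering the corresponding primitive, report how many leading
# dimensions are batch dimensions via a frontend attribute that the SPMD
# partitioner interprets.
def _my_qr_lowering(ctx, x):
  x_aval, = ctx.avals_in
  num_batch_dims = len(x_aval.shape) - 2  # everything but the matrix dims
  frontend_attrs = mlir.ir_attribute({"num_batch_dims": str(num_batch_dims)})
  return jax.ffi.ffi_lowering(
      "my_qr_ffi",
      extra_attributes={"mhlo.frontend_attributes": frontend_attrs},
  )(ctx, x)
```

The full version, including the test primitive and its per-platform lowering registrations, is in the test file changes below.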

In parallel with this change, shardy has added support for sharding propagation across custom calls using a string representation that is similar in spirit to this approach, but somewhat more general. However, the shardy implementation still requires a Python callback for the partitioning step, so it doesn't (yet!) solve all of the relevant problems with the `custom_partitioning` approach. Ultimately, it should be possible to have the partitioner parse the shardy sharding rule representation, but I wanted to start with the minimal implementation.

PiperOrigin-RevId: 724367877
Dan Foreman-Mackey 2025-02-07 09:13:34 -08:00 committed by jax authors
parent 5bc17f7ec3
commit c521bc6205
4 changed files with 109 additions and 5 deletions


@@ -27,6 +27,7 @@ from jax._src import deprecations
from jax._src import dispatch
from jax._src import effects
from jax._src import util
from jax._src import xla_bridge
from jax._src.callback import callback_batching_rule
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@@ -87,6 +88,18 @@ def register_ffi_type_id(
  return xla_client.register_custom_type_id(name, obj, platform=platform)


def register_ffi_target_as_batch_partitionable(name: str) -> None:
  """Registers an FFI target as batch partitionable.

  Args:
    name: the name of the target.
  """
  xla_client.register_custom_call_as_batch_partitionable(name)
  xla_bridge.register_plugin_callbacks(
      functools.partial(xla_client.register_custom_call_as_batch_partitionable,
                        name))


def pycapsule(funcptr):
  """Wrap a ctypes function pointer in a PyCapsule.


@@ -22,4 +22,5 @@ from jax._src.ffi import (
  pycapsule as pycapsule,
  register_ffi_target as register_ffi_target,
  register_ffi_type_id as register_ffi_type_id,
  register_ffi_target_as_batch_partitionable as register_ffi_target_as_batch_partitionable,
)


@@ -161,6 +161,9 @@ jax_multiplatform_test(
jax_multiplatform_test(
    name = "ffi_test",
    srcs = ["ffi_test.py"],
    enable_configs = [
        "gpu_p100x2",
    ],
    # TODO(dfm): Remove after removal of jex.ffi imports.
    deps = ["//jax:extend"],
)


@@ -24,19 +24,21 @@ import jax
from jax import lax
import jax.extend as jex
import jax.numpy as jnp
import jax.sharding as shd
from jax.sharding import PartitionSpec as P
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src.interpreters import mlir
from jax._src.layout import DeviceLocalLayout
from jax._src.lib import lapack
from jax._src.lib import lapack, xla_extension_version
from jax._src.lib.mlir.dialects import hlo
from jax._src.lax import linalg as lax_linalg_internal
from jax.experimental.shard_map import shard_map
jax.config.parse_flags_with_absl()
jtu.request_cpu_devices(8)
class FfiTest(jtu.JaxTestCase):
@@ -282,11 +284,10 @@ class FfiTest(jtu.JaxTestCase):
  @jtu.run_on_devices("gpu", "cpu")
  def test_shard_map(self):
    mesh = jtu.create_mesh((1,), ("i",))
    mesh = jtu.create_mesh((len(jax.devices()),), ("i",))
    x = self.rng().randn(8, 4, 5).astype(np.float32)

    @partial(shard_map, mesh=mesh, in_specs=shd.PartitionSpec('i'),
             out_specs=shd.PartitionSpec('i'))
    @partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
    def f(x):
      return ffi_call_geqrf(x)
@@ -328,5 +329,91 @@ def ffi_call_geqrf(x, _use_extend=False, **kwargs):
      cuda=partial(call, "cuda"))


class BatchPartitioningTest(jtu.JaxTestCase):
  def setUp(self):
    super().setUp()
    if xla_extension_version < 312:
      self.skipTest("Requires XLA extension version >= 312")
    if jax.device_count() < 2:
      self.skipTest("Requires multiple devices")
    if jtu.test_device_matches(["cpu"]):
      lapack._lapack.initialize()
    for target_name in ["lapack_sgeqrf_ffi", "cusolver_geqrf_ffi",
                        "hipsolver_geqrf_ffi"]:
      jax.ffi.register_ffi_target_as_batch_partitionable(target_name)

  @jtu.run_on_devices("gpu", "cpu")
  def test_shard_map(self):
    mesh = jtu.create_mesh((len(jax.devices()),), ("i",))
    x = self.rng().randn(8, 4, 5).astype(np.float32)

    @partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"),
             check_rep=False)
    def f(x):
      return batch_partitionable_ffi_call(x)

    f(x)  # eager mode doesn't crash
    jax.jit(f)(x)  # neither does JIT
    self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text())

  @jtu.run_on_devices("gpu", "cpu")
  def test_batch_partitioning(self):
    if config.use_shardy_partitioner.value:
      self.skipTest(
          "Shardy does not yet support batch partitioning of FFI calls.")

    def f(x):
      return batch_partitionable_ffi_call(x)

    mesh = jtu.create_mesh((len(jax.devices()),), ("i",))
    x = self.rng().randn(8, 4, 5).astype(np.float32)
    x = jax.device_put(x, jax.NamedSharding(mesh, P("i")))

    f(x)  # eager mode doesn't crash
    jax.jit(f)(x)  # neither does JIT
    self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text())
def batch_partitionable_ffi_call(x):
  return batch_partitionable_p.bind(x)


batch_partitionable_p = core.Primitive("batch_partitionable")
batch_partitionable_p.multiple_results = True
dispatch.simple_impl(batch_partitionable_p)


@batch_partitionable_p.def_abstract_eval
def _batch_partitionable_abstract_eval(x):
  return x, core.ShapedArray(x.shape[:-1], x.dtype)


def _batch_partitionable_lowering(target_name, ctx, x):
  x_aval, = ctx.avals_in
  num_batch_dims = len(x_aval.shape) - 2
  frontend_attrs = mlir.ir_attribute({"num_batch_dims": str(num_batch_dims)})
  return jax.ffi.ffi_lowering(
      target_name,
      extra_attributes={"mhlo.frontend_attributes": frontend_attrs}
  )(ctx, x)


mlir.register_lowering(
    batch_partitionable_p,
    partial(_batch_partitionable_lowering, "lapack_sgeqrf_ffi"),
    platform="cpu",
)
mlir.register_lowering(
    batch_partitionable_p,
    partial(_batch_partitionable_lowering, "cusolver_geqrf_ffi"),
    platform="cuda",
)
mlir.register_lowering(
    batch_partitionable_p,
    partial(_batch_partitionable_lowering, "hipsolver_geqrf_ffi"),
    platform="rocm",
)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())