mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[xla:python] Add a mechanism for "batch partitioning" of FFI calls.
This is the first in a series of changes to add a simple API for supporting a set of common sharding and partitioning patterns for FFI calls. The high level motivation is that custom calls (including FFI calls) are opaque to the SPMD partitioner, and the only ways to customize the partitioning behavior is to (a) explicitly register an `xla::CustomCallPartitoner` with XLA, or (b) use the `jax.experimental.custom_partitioning` APIs. Option (a) isn't generally practical for most use cases where the FFI handler lives in an external binary. Option (b) is flexible, and supports all common use cases, but it requires embedding Python callbacks in to the HLO, which can lead to issues including cache misses. Furthermore, `custom_partitioning` is overpowered for many use cases, where only (what I will call) "batch partitioning" is supported. In this case, "batch partitioning" refers to the behavior of many FFI calls where they can be trivially partitioned on some number of (leading) dimensions, with the same call being executed independently on each shard of data. If the data are sharded on non-batch dimensions, partitioning will still re-shard the data to be replicated on the non-batch dimensions. This kind of partitioning logic applies to all the LAPACK/cuSOLVER/etc.-backed linear algebra functions in jaxlib, as well as some external users of `custom_partitioning`. The approach I'm taking here is to add a new registration function to the XLA client, which let's a user label their FFI call as batch partitionable. Then, when lowering the custom call, the user passes the number of batch dimensions as a frontend attribute, which is then interpreted by the SPMD partitioner. In parallel with this change, shardy has added support for sharding propagation across custom calls using a string representation that is similar in spirit to this approach, but somewhat more general. However, the shardy implementation still requires a Python callback for the partitioning step, so it doesn't (yet!) solve all of the relevant problems with the `custom_partitioning` approach. Ultimately, it should be possible to have the partitioner parse the shardy sharding rule representation, but I wanted to start with the minimal implementation. PiperOrigin-RevId: 724367877
This commit is contained in:
parent
5bc17f7ec3
commit
c521bc6205
@ -27,6 +27,7 @@ from jax._src import deprecations
|
||||
from jax._src import dispatch
|
||||
from jax._src import effects
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.callback import callback_batching_rule
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
@ -87,6 +88,18 @@ def register_ffi_type_id(
|
||||
return xla_client.register_custom_type_id(name, obj, platform=platform)
|
||||
|
||||
|
||||
def register_ffi_target_as_batch_partitionable(name: str) -> None:
|
||||
"""Registers an FFI target as batch partitionable.
|
||||
|
||||
Args:
|
||||
name: the name of the target.
|
||||
"""
|
||||
xla_client.register_custom_call_as_batch_partitionable(name)
|
||||
xla_bridge.register_plugin_callbacks(
|
||||
functools.partial(xla_client.register_custom_call_as_batch_partitionable,
|
||||
name))
|
||||
|
||||
|
||||
def pycapsule(funcptr):
|
||||
"""Wrap a ctypes function pointer in a PyCapsule.
|
||||
|
||||
|
@ -22,4 +22,5 @@ from jax._src.ffi import (
|
||||
pycapsule as pycapsule,
|
||||
register_ffi_target as register_ffi_target,
|
||||
register_ffi_type_id as register_ffi_type_id,
|
||||
register_ffi_target_as_batch_partitionable as register_ffi_target_as_batch_partitionable,
|
||||
)
|
||||
|
@ -161,6 +161,9 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "ffi_test",
|
||||
srcs = ["ffi_test.py"],
|
||||
enable_configs = [
|
||||
"gpu_p100x2",
|
||||
],
|
||||
# TODO(dfm): Remove after removal of jex.ffi imports.
|
||||
deps = ["//jax:extend"],
|
||||
)
|
||||
|
@ -24,19 +24,21 @@ import jax
|
||||
from jax import lax
|
||||
import jax.extend as jex
|
||||
import jax.numpy as jnp
|
||||
import jax.sharding as shd
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.layout import DeviceLocalLayout
|
||||
from jax._src.lib import lapack
|
||||
from jax._src.lib import lapack, xla_extension_version
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lax import linalg as lax_linalg_internal
|
||||
from jax.experimental.shard_map import shard_map
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
jtu.request_cpu_devices(8)
|
||||
|
||||
|
||||
class FfiTest(jtu.JaxTestCase):
|
||||
@ -282,11 +284,10 @@ class FfiTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.run_on_devices("gpu", "cpu")
|
||||
def test_shard_map(self):
|
||||
mesh = jtu.create_mesh((1,), ("i",))
|
||||
mesh = jtu.create_mesh((len(jax.devices()),), ("i",))
|
||||
x = self.rng().randn(8, 4, 5).astype(np.float32)
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=shd.PartitionSpec('i'),
|
||||
out_specs=shd.PartitionSpec('i'))
|
||||
@partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
|
||||
def f(x):
|
||||
return ffi_call_geqrf(x)
|
||||
|
||||
@ -328,5 +329,91 @@ def ffi_call_geqrf(x, _use_extend=False, **kwargs):
|
||||
cuda=partial(call, "cuda"))
|
||||
|
||||
|
||||
class BatchPartitioningTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if xla_extension_version < 312:
|
||||
self.skipTest("Requires XLA extension version >= 312")
|
||||
if jax.device_count() < 2:
|
||||
self.skipTest("Requires multiple devices")
|
||||
if jtu.test_device_matches(["cpu"]):
|
||||
lapack._lapack.initialize()
|
||||
for target_name in ["lapack_sgeqrf_ffi", "cusolver_geqrf_ffi",
|
||||
"hipsolver_geqrf_ffi"]:
|
||||
jax.ffi.register_ffi_target_as_batch_partitionable(target_name)
|
||||
|
||||
@jtu.run_on_devices("gpu", "cpu")
|
||||
def test_shard_map(self):
|
||||
mesh = jtu.create_mesh((len(jax.devices()),), ("i",))
|
||||
x = self.rng().randn(8, 4, 5).astype(np.float32)
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"),
|
||||
check_rep=False)
|
||||
def f(x):
|
||||
return batch_partitionable_ffi_call(x)
|
||||
|
||||
f(x) # eager mode doesn't crash
|
||||
jax.jit(f)(x) # neither does JIT
|
||||
self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text())
|
||||
|
||||
@jtu.run_on_devices("gpu", "cpu")
|
||||
def test_batch_partitioning(self):
|
||||
if config.use_shardy_partitioner.value:
|
||||
self.skipTest(
|
||||
"Shardy does not yet support batch partitioning of FFI calls.")
|
||||
|
||||
def f(x):
|
||||
return batch_partitionable_ffi_call(x)
|
||||
|
||||
mesh = jtu.create_mesh((len(jax.devices()),), ("i",))
|
||||
x = self.rng().randn(8, 4, 5).astype(np.float32)
|
||||
x = jax.device_put(x, jax.NamedSharding(mesh, P("i")))
|
||||
|
||||
f(x) # eager mode doesn't crash
|
||||
jax.jit(f)(x) # neither does JIT
|
||||
self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text())
|
||||
|
||||
|
||||
def batch_partitionable_ffi_call(x):
|
||||
return batch_partitionable_p.bind(x)
|
||||
|
||||
|
||||
batch_partitionable_p = core.Primitive("batch_partitionable")
|
||||
batch_partitionable_p.multiple_results = True
|
||||
dispatch.simple_impl(batch_partitionable_p)
|
||||
|
||||
|
||||
@batch_partitionable_p.def_abstract_eval
|
||||
def _batch_partitionable_abstract_eval(x):
|
||||
return x, core.ShapedArray(x.shape[:-1], x.dtype)
|
||||
|
||||
|
||||
def _batch_partitionable_lowering(target_name, ctx, x):
|
||||
x_aval, = ctx.avals_in
|
||||
num_batch_dims = len(x_aval.shape) - 2
|
||||
frontend_attrs = mlir.ir_attribute({"num_batch_dims": str(num_batch_dims)})
|
||||
return jax.ffi.ffi_lowering(
|
||||
target_name,
|
||||
extra_attributes={"mhlo.frontend_attributes": frontend_attrs}
|
||||
)(ctx, x)
|
||||
|
||||
|
||||
mlir.register_lowering(
|
||||
batch_partitionable_p,
|
||||
partial(_batch_partitionable_lowering, "lapack_sgeqrf_ffi"),
|
||||
platform="cpu",
|
||||
)
|
||||
mlir.register_lowering(
|
||||
batch_partitionable_p,
|
||||
partial(_batch_partitionable_lowering, "cusolver_geqrf_ffi"),
|
||||
platform="cuda",
|
||||
)
|
||||
mlir.register_lowering(
|
||||
batch_partitionable_p,
|
||||
partial(_batch_partitionable_lowering, "hipsolver_geqrf_ffi"),
|
||||
platform="rocm",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user