diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py
index 46b6543b8..5195f5cca 100644
--- a/jax/_src/ffi.py
+++ b/jax/_src/ffi.py
@@ -27,6 +27,7 @@ from jax._src import deprecations
 from jax._src import dispatch
 from jax._src import effects
 from jax._src import util
+from jax._src import xla_bridge
 from jax._src.callback import callback_batching_rule
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
@@ -87,6 +88,18 @@ def register_ffi_type_id(
   return xla_client.register_custom_type_id(name, obj, platform=platform)
 
 
+def register_ffi_target_as_batch_partitionable(name: str) -> None:
+  """Registers an FFI target as batch partitionable.
+
+  Args:
+    name: the name of the target.
+  """
+  xla_client.register_custom_call_as_batch_partitionable(name)
+  xla_bridge.register_plugin_callbacks(
+      functools.partial(xla_client.register_custom_call_as_batch_partitionable,
+                        name))
+
+
 def pycapsule(funcptr):
   """Wrap a ctypes function pointer in a PyCapsule.
 
diff --git a/jax/ffi.py b/jax/ffi.py
index 1f9be6e8c..6606c58a0 100644
--- a/jax/ffi.py
+++ b/jax/ffi.py
@@ -22,4 +22,5 @@ from jax._src.ffi import (
   pycapsule as pycapsule,
   register_ffi_target as register_ffi_target,
   register_ffi_type_id as register_ffi_type_id,
+  register_ffi_target_as_batch_partitionable as register_ffi_target_as_batch_partitionable,
 )
diff --git a/tests/BUILD b/tests/BUILD
index 07ba4b52f..17fc08ad4 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -161,6 +161,9 @@ jax_multiplatform_test(
 jax_multiplatform_test(
     name = "ffi_test",
     srcs = ["ffi_test.py"],
+    enable_configs = [
+        "gpu_p100x2",
+    ],
     # TODO(dfm): Remove after removal of jex.ffi imports.
     deps = ["//jax:extend"],
 )
diff --git a/tests/ffi_test.py b/tests/ffi_test.py
index 713d4cdcb..bbe44ef56 100644
--- a/tests/ffi_test.py
+++ b/tests/ffi_test.py
@@ -24,19 +24,21 @@ import jax
 from jax import lax
 import jax.extend as jex
 import jax.numpy as jnp
-import jax.sharding as shd
+from jax.sharding import PartitionSpec as P
 
 from jax._src import config
 from jax._src import core
+from jax._src import dispatch
 from jax._src import test_util as jtu
 from jax._src.interpreters import mlir
 from jax._src.layout import DeviceLocalLayout
-from jax._src.lib import lapack
+from jax._src.lib import lapack, xla_extension_version
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.lax import linalg as lax_linalg_internal
 from jax.experimental.shard_map import shard_map
 
 jax.config.parse_flags_with_absl()
+jtu.request_cpu_devices(8)
 
 
 class FfiTest(jtu.JaxTestCase):
@@ -282,11 +284,10 @@ class FfiTest(jtu.JaxTestCase):
 
   @jtu.run_on_devices("gpu", "cpu")
   def test_shard_map(self):
-    mesh = jtu.create_mesh((1,), ("i",))
+    mesh = jtu.create_mesh((len(jax.devices()),), ("i",))
     x = self.rng().randn(8, 4, 5).astype(np.float32)
 
-    @partial(shard_map, mesh=mesh, in_specs=shd.PartitionSpec('i'),
-             out_specs=shd.PartitionSpec('i'))
+    @partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
     def f(x):
       return ffi_call_geqrf(x)
 
@@ -328,5 +329,91 @@ def ffi_call_geqrf(x, _use_extend=False, **kwargs):
       cuda=partial(call, "cuda"))
 
 
+class BatchPartitioningTest(jtu.JaxTestCase):
+  def setUp(self):
+    super().setUp()
+    if xla_extension_version < 312:
+      self.skipTest("Requires XLA extension version >= 312")
+    if jax.device_count() < 2:
+      self.skipTest("Requires multiple devices")
+    if jtu.test_device_matches(["cpu"]):
+      lapack._lapack.initialize()
+    for target_name in ["lapack_sgeqrf_ffi", "cusolver_geqrf_ffi",
+                        "hipsolver_geqrf_ffi"]:
+      jax.ffi.register_ffi_target_as_batch_partitionable(target_name)
+
+  @jtu.run_on_devices("gpu", "cpu")
+  def test_shard_map(self):
+    mesh = jtu.create_mesh((len(jax.devices()),), ("i",))
+    x = self.rng().randn(8, 4, 5).astype(np.float32)
+
+    @partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"),
+             check_rep=False)
+    def f(x):
+      return batch_partitionable_ffi_call(x)
+
+    f(x)  # eager mode doesn't crash
+    jax.jit(f)(x)  # neither does JIT
+    self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text())
+
+  @jtu.run_on_devices("gpu", "cpu")
+  def test_batch_partitioning(self):
+    if config.use_shardy_partitioner.value:
+      self.skipTest(
+          "Shardy does not yet support batch partitioning of FFI calls.")
+
+    def f(x):
+      return batch_partitionable_ffi_call(x)
+
+    mesh = jtu.create_mesh((len(jax.devices()),), ("i",))
+    x = self.rng().randn(8, 4, 5).astype(np.float32)
+    x = jax.device_put(x, jax.NamedSharding(mesh, P("i")))
+
+    f(x)  # eager mode doesn't crash
+    jax.jit(f)(x)  # neither does JIT
+    self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text())
+
+
+def batch_partitionable_ffi_call(x):
+  return batch_partitionable_p.bind(x)
+
+
+batch_partitionable_p = core.Primitive("batch_partitionable")
+batch_partitionable_p.multiple_results = True
+dispatch.simple_impl(batch_partitionable_p)
+
+
+@batch_partitionable_p.def_abstract_eval
+def _batch_partitionable_abstract_eval(x):
+  return x, core.ShapedArray(x.shape[:-1], x.dtype)
+
+
+def _batch_partitionable_lowering(target_name, ctx, x):
+  x_aval, = ctx.avals_in
+  num_batch_dims = len(x_aval.shape) - 2
+  frontend_attrs = mlir.ir_attribute({"num_batch_dims": str(num_batch_dims)})
+  return jax.ffi.ffi_lowering(
+      target_name,
+      extra_attributes={"mhlo.frontend_attributes": frontend_attrs}
+  )(ctx, x)
+
+
+mlir.register_lowering(
+    batch_partitionable_p,
+    partial(_batch_partitionable_lowering, "lapack_sgeqrf_ffi"),
+    platform="cpu",
+)
+mlir.register_lowering(
+    batch_partitionable_p,
+    partial(_batch_partitionable_lowering, "cusolver_geqrf_ffi"),
+    platform="cuda",
+)
+mlir.register_lowering(
+    batch_partitionable_p,
+    partial(_batch_partitionable_lowering, "hipsolver_geqrf_ffi"),
+    platform="rocm",
+)
+
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
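
Note for reviewers: a minimal usage sketch of the new entry point, not part of the patch. It assumes a hypothetical FFI target named "my_geqrf" that was previously built and registered via jax.ffi.register_ffi_target, and the result shapes simply mirror the geqrf-style convention used in the tests above; everything else uses existing public JAX APIs (jax.ffi.ffi_call, jax.jit, jax.sharding).

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Mark the (hypothetical) target as batch partitionable so the partitioner can
# split its leading batch dimension across devices instead of all-gathering
# the operand onto every device.
jax.ffi.register_ffi_target_as_batch_partitionable("my_geqrf")


def my_geqrf(x):
  # jax.ffi.ffi_call returns a callable for the registered target; the output
  # shapes and dtypes here are illustrative only.
  return jax.ffi.ffi_call(
      "my_geqrf",
      (jax.ShapeDtypeStruct(x.shape, x.dtype),
       jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)),
  )(x)


mesh = Mesh(jax.devices(), ("i",))
x = jnp.zeros((8, 4, 5), jnp.float32)
x = jax.device_put(x, NamedSharding(mesh, P("i")))  # shard the batch dimension
out, tau = jax.jit(my_geqrf)(x)  # compiled HLO should contain no all-gather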