diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py
index 46b6543b8..5195f5cca 100644
--- a/jax/_src/ffi.py
+++ b/jax/_src/ffi.py
@@ -27,6 +27,7 @@ from jax._src import deprecations
 from jax._src import dispatch
 from jax._src import effects
 from jax._src import util
+from jax._src import xla_bridge
 from jax._src.callback import callback_batching_rule
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
@@ -87,6 +88,18 @@ def register_ffi_type_id(
   return xla_client.register_custom_type_id(name, obj, platform=platform)
 
 
+def register_ffi_target_as_batch_partitionable(name: str) -> None:
+  """Registers an FFI target as batch partitionable.
+
+  Args:
+    name: the name of the target.
+  """
+  xla_client.register_custom_call_as_batch_partitionable(name)
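+  # Also register the target with any PJRT plugin backends once they are
+  # initialized, so plugin-provided platforms pick up the same registration.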
+  xla_bridge.register_plugin_callbacks(
+      functools.partial(xla_client.register_custom_call_as_batch_partitionable,
+                        name))
+
+
 def pycapsule(funcptr):
   """Wrap a ctypes function pointer in a PyCapsule.
 
diff --git a/jax/ffi.py b/jax/ffi.py
index 1f9be6e8c..6606c58a0 100644
--- a/jax/ffi.py
+++ b/jax/ffi.py
@@ -22,4 +22,5 @@ from jax._src.ffi import (
     pycapsule as pycapsule,
     register_ffi_target as register_ffi_target,
     register_ffi_type_id as register_ffi_type_id,
+    register_ffi_target_as_batch_partitionable as register_ffi_target_as_batch_partitionable,
 )
diff --git a/tests/BUILD b/tests/BUILD
index 07ba4b52f..17fc08ad4 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -161,6 +161,9 @@ jax_multiplatform_test(
 jax_multiplatform_test(
     name = "ffi_test",
     srcs = ["ffi_test.py"],
+    enable_configs = [
+        "gpu_p100x2",
+    ],
     # TODO(dfm): Remove after removal of jex.ffi imports.
     deps = ["//jax:extend"],
 )
diff --git a/tests/ffi_test.py b/tests/ffi_test.py
index 713d4cdcb..bbe44ef56 100644
--- a/tests/ffi_test.py
+++ b/tests/ffi_test.py
@@ -24,19 +24,21 @@ import jax
 from jax import lax
 import jax.extend as jex
 import jax.numpy as jnp
-import jax.sharding as shd
+from jax.sharding import PartitionSpec as P
 
 from jax._src import config
 from jax._src import core
+from jax._src import dispatch
 from jax._src import test_util as jtu
 from jax._src.interpreters import mlir
 from jax._src.layout import DeviceLocalLayout
-from jax._src.lib import lapack
+from jax._src.lib import lapack, xla_extension_version
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.lax import linalg as lax_linalg_internal
 from jax.experimental.shard_map import shard_map
 
 jax.config.parse_flags_with_absl()
+jtu.request_cpu_devices(8)
 
 
 class FfiTest(jtu.JaxTestCase):
@@ -282,11 +284,10 @@ class FfiTest(jtu.JaxTestCase):
 
   @jtu.run_on_devices("gpu", "cpu")
   def test_shard_map(self):
-    mesh = jtu.create_mesh((1,), ("i",))
+    mesh = jtu.create_mesh((len(jax.devices()),), ("i",))
     x = self.rng().randn(8, 4, 5).astype(np.float32)
 
-    @partial(shard_map, mesh=mesh, in_specs=shd.PartitionSpec('i'),
-             out_specs=shd.PartitionSpec('i'))
+    @partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
     def f(x):
       return ffi_call_geqrf(x)
 
@@ -328,5 +329,91 @@ def ffi_call_geqrf(x, _use_extend=False, **kwargs):
       cuda=partial(call, "cuda"))
 
 
+class BatchPartitioningTest(jtu.JaxTestCase):
+  def setUp(self):
+    super().setUp()
+    if xla_extension_version < 312:
+      self.skipTest("Requires XLA extension version >= 312")
+    if jax.device_count() < 2:
+      self.skipTest("Requires multiple devices")
+    if jtu.test_device_matches(["cpu"]):
+      lapack._lapack.initialize()
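+    # Register the backend-specific geqrf FFI targets used below as batch
+    # partitionable.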
+    for target_name in ["lapack_sgeqrf_ffi", "cusolver_geqrf_ffi",
+                        "hipsolver_geqrf_ffi"]:
+      jax.ffi.register_ffi_target_as_batch_partitionable(target_name)
+
+  @jtu.run_on_devices("gpu", "cpu")
+  def test_shard_map(self):
+    mesh = jtu.create_mesh((len(jax.devices()),), ("i",))
+    x = self.rng().randn(8, 4, 5).astype(np.float32)
+
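+    # check_rep=False: shard_map cannot infer a replication rule for the
+    # custom FFI primitive.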
+    @partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"),
+             check_rep=False)
+    def f(x):
+      return batch_partitionable_ffi_call(x)
+
+    f(x)  # eager mode doesn't crash
+    jax.jit(f)(x)  # neither does JIT
+    self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text())
+
+  @jtu.run_on_devices("gpu", "cpu")
+  def test_batch_partitioning(self):
+    if config.use_shardy_partitioner.value:
+      self.skipTest(
+          "Shardy does not yet support batch partitioning of FFI calls.")
+
+    def f(x):
+      return batch_partitionable_ffi_call(x)
+
+    mesh = jtu.create_mesh((len(jax.devices()),), ("i",))
+    x = self.rng().randn(8, 4, 5).astype(np.float32)
+    x = jax.device_put(x, jax.NamedSharding(mesh, P("i")))
+
+    f(x)  # eager mode doesn't crash
+    jax.jit(f)(x)  # neither does JIT
+    self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text())
+
+
+def batch_partitionable_ffi_call(x):
+  return batch_partitionable_p.bind(x)
+
+
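+# A test-only primitive that lowers to the geqrf FFI targets registered above,
+# tagging the custom call with a "num_batch_dims" frontend attribute so that
+# the leading batch dimensions can be partitioned.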
+batch_partitionable_p = core.Primitive("batch_partitionable")
+batch_partitionable_p.multiple_results = True
+dispatch.simple_impl(batch_partitionable_p)
+
+
+@batch_partitionable_p.def_abstract_eval
+def _batch_partitionable_abstract_eval(x):
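+  # Mirror geqrf's outputs: the factored matrix and tau. Here tau has shape
+  # batch_dims + (m,), which matches geqrf's min(m, n) because these tests use
+  # m <= n.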
+  return x, core.ShapedArray(x.shape[:-1], x.dtype)
+
+
+def _batch_partitionable_lowering(target_name, ctx, x):
+  x_aval, = ctx.avals_in
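+  # All leading dimensions, excluding the trailing two matrix dimensions, are
+  # treated as batch dimensions.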
+  num_batch_dims = len(x_aval.shape) - 2
+  frontend_attrs = mlir.ir_attribute({"num_batch_dims": str(num_batch_dims)})
+  return jax.ffi.ffi_lowering(
+      target_name,
+      extra_attributes={"mhlo.frontend_attributes": frontend_attrs}
+  )(ctx, x)
+
+
+mlir.register_lowering(
+    batch_partitionable_p,
+    partial(_batch_partitionable_lowering, "lapack_sgeqrf_ffi"),
+    platform="cpu",
+)
+mlir.register_lowering(
+    batch_partitionable_p,
+    partial(_batch_partitionable_lowering, "cusolver_geqrf_ffi"),
+    platform="cuda",
+)
+mlir.register_lowering(
+    batch_partitionable_p,
+    partial(_batch_partitionable_lowering, "hipsolver_geqrf_ffi"),
+    platform="rocm",
+)
+
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())