12 Commits

Author SHA1 Message Date
Peter Hawkins
66293d8897 Remove code present to support jaxlib < 0.5.1.
The new minimum xla_extension_version is 317 and the new mlir_api_version is 58.
2025-02-26 07:40:40 -05:00
Dan Foreman-Mackey
9298018afa Enable shardy batch partitionable FFI test.
PiperOrigin-RevId: 726171678
2025-02-12 13:17:40 -08:00
Dan Foreman-Mackey
c502332ed5 Add "sequential_unrolled" vmap method for callbacks.
Like the `sequential` method, this loops over calls to the callback, but in this case, the loop is unrolled.

PiperOrigin-RevId: 725601366
2025-02-11 06:09:16 -08:00
Dan Foreman-Mackey
ba8c3a925b Fix missing batch partitioning hook in GPU plugin.
I had somehow missed properly registering the "batch partitionable" registration hook on the GPU plugin, causing a segfault when the missing pointer was accessed. This fixes that and updates the tests to make sure that the registration code is executed even without multiple devices.

PiperOrigin-RevId: 725312473
2025-02-10 12:42:17 -08:00
Dan Foreman-Mackey
c521bc6205 [xla:python] Add a mechanism for "batch partitioning" of FFI calls.
This is the first in a series of changes to add a simple API for supporting a set of common sharding and partitioning patterns for FFI calls. The high level motivation is that custom calls (including FFI calls) are opaque to the SPMD partitioner, and the only ways to customize the partitioning behavior is to (a) explicitly register an `xla::CustomCallPartitoner` with XLA, or (b) use the `jax.experimental.custom_partitioning` APIs. Option (a) isn't generally practical for most use cases where the FFI handler lives in an external binary. Option (b) is flexible, and supports all common use cases, but it requires embedding Python callbacks in to the HLO, which can lead to issues including cache misses. Furthermore, `custom_partitioning` is overpowered for many use cases, where only (what I will call) "batch partitioning" is supported.

In this case, "batch partitioning" refers to the behavior of many FFI calls where they can be trivially partitioned on some number of (leading) dimensions, with the same call being executed independently on each shard of data. If the data are sharded on non-batch dimensions, partitioning will still re-shard the data to be replicated on the non-batch dimensions. This kind of partitioning logic applies to all the LAPACK/cuSOLVER/etc.-backed linear algebra functions in jaxlib, as well as some external users of `custom_partitioning`.

The approach I'm taking here is to add a new registration function to the XLA client, which let's a user label their FFI call as batch partitionable. Then, when lowering the custom call, the user passes the number of batch dimensions as a frontend attribute, which is then interpreted by the SPMD partitioner.

In parallel with this change, shardy has added support for sharding propagation across custom calls using a string representation that is similar in spirit to this approach, but somewhat more general. However, the shardy implementation still requires a Python callback for the partitioning step, so it doesn't (yet!) solve all of the relevant problems with the `custom_partitioning` approach. Ultimately, it should be possible to have the partitioner parse the shardy sharding rule representation, but I wanted to start with the minimal implementation.

PiperOrigin-RevId: 724367877
2025-02-07 09:14:06 -08:00
Gunhyun Park
20555f63da Lower np.ndarray to DenseElementsAttr instead of ArrayAttr.
PiperOrigin-RevId: 721833949
2025-01-31 11:06:06 -08:00
Jake VanderPlas
1ee015674f [internal] add deprecation test utilities 2025-01-10 11:54:09 -08:00
jax authors
56f0f9534d Merge pull request #25633 from dfm:move-ffi
PiperOrigin-RevId: 712863350
2025-01-07 04:40:21 -08:00
Dan Foreman-Mackey
cb4d97aa1f Move jex.ffi to jax.ffi. 2024-12-29 13:06:19 +00:00
Dan Foreman-Mackey
690fa1d90c Remove failing ffi test
The FFI headers aren't properly exposed during a bazel build, so these
tests are failing. I'll re-enable them when I get a chance to get that
working properly.
2024-05-31 15:36:33 -04:00
Dan Foreman-Mackey
1e206880d3 Move jax.ffi submodule to jax.extend.ffi 2024-05-31 12:34:59 -04:00
Dan Foreman-Mackey
88790711e8 Package XLA FFI headers with jaxlib wheel
The new "typed" API that XLA provides for foreign function calls is
header-only and packaging it as part of jaxlib could simplify the open
source workflow for building custom calls.

It's not completely obvious that we need to include this, because jaxlib
isn't strictly required as a _build_ dependency for FFI calls, although
it typically will be required as a _run time_ dependency. Also, it
probably wouldn't be too painful for external projects to use the
headers directly from the openxla/xla repo.

All that being said, I wanted to figure out how to do this, and it has
been requested a few times.
2024-05-22 12:28:38 -04:00