This PR updates the FFI lowering rule to support a DeviceLocalLayout
object as input when specifying the input and output layouts. For now,
this just converts the DLL object to its appropriate list of
minor-to-major integers because that's what the custom call op expects.
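As a rough illustration of that conversion, here is a minimal sketch. It assumes the layout is described by a major-to-minor tuple of dimension indices; the helper name `to_minor_to_major` is hypothetical and is not the function used in this PR.

```python
from typing import Sequence


def to_minor_to_major(major_to_minor: Sequence[int]) -> list[int]:
    """Reverse a major-to-minor dimension order into minor-to-major.

    Hypothetical helper for illustration only; the input is assumed to be
    the dimension order stored on the layout object.
    """
    order = list(major_to_minor)
    # A valid layout must be a permutation of the dimension indices.
    if sorted(order) != list(range(len(order))):
        raise ValueError(f"not a permutation of dimensions: {order}")
    # Minor-to-major is the same permutation read in the opposite direction.
    return order[::-1]


# Row-major rank-3 array: dimension 0 is most major, dimension 2 most minor.
row_major = to_minor_to_major((0, 1, 2))
# Column-major rank-2 array: dimension 1 is most major, dimension 0 most minor.
column_major = to_minor_to_major((1, 0))
```

The permutation check is only there to catch malformed input early; the conversion itself is a plain reversal.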
The lowering logic for all jaxlib custom calls is currently split between JAX and jaxlib for reasons that are harder to justify now that the compiled calls are split between jaxlib and the relevant plugins. As part of my project to update these calls and simplify the lowering logic, it makes sense to consolidate the lowering rules in JAX instead of jaxlib since the logic is now the same for both GPU and CPU. This update tackles a simple kernel as a test case for what this would look like.
Since the full lowering rule is now implemented in JAX, we can take advantage of the MLIR helpers that are included there, including `jex.ffi.ffi_lowering`, which I needed to update to support shape polymorphism.
Of note: I think it is safe (in a compatibility sense) to delete the lowering code from jaxlib, but it does mean that it won't be possible to lower this operation when `jax.__version__ < jaxlib.__version__`. I think this is okay given our compatibility guarantees, but I'd love a sanity check on that!
Another note: this doesn't actually change the lowered HLO for this op, so we don't need to worry about export compatibility.
PiperOrigin-RevId: 664680250
While this function is currently only used for lowering FFI calls, it could be used in most places where `ir.*Attr` objects are directly constructed.
PiperOrigin-RevId: 661761712
This means that users of the FFI interface won't need to directly
interact with `jaxlib.xla_client` at all.
I've expanded the docstring a little and changed one default: the default
`api_version` is `1` instead of `0` to be consistent with the new name.
This could be useful for supporting the most common use cases for FFI custom
calls. It has several benefits over using the `Primitive`-based approach, but
the biggest one (in my opinion) is that it doesn't require interacting with
`mlir` at all. It does have the limitation that transforms would need to be
registered using interfaces like `custom_vjp`, but many users of custom calls
already do that.
~~The easiest to-do item (I think) is to implement batching using a
`vectorized` parameter like `pure_callback`, but we could also think about more
sophisticated vmapping interfaces in the future.~~ Done.
The more difficult to-do is to think about how to support sharding, and we
might actually want to expose an interface similar to the one from
`custom_partitioning`. I have less experience with this part so I'll have to
think some more about it, and feedback would be appreciated!
The docstring for the recently added `pycapsule` function in
`jax.extend.ffi` didn't conform to our usual docstring format, so I
updated it and added a little bit more context.
This function provides sensible defaults for custom call lowering rules
with the goal of reducing the amount of boilerplate required for
implementing custom calls.
Co-authored-by: Sergei Lebedev <slebedev@google.com>
This allows one to build PyCapsules ready for XLA custom call
registration out of function pointers retrieved from shared libraries
using ctypes. In particular, this obviates the need to create Python
bindings to custom call targets.
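To make the idea concrete, here is a minimal standard-library sketch of wrapping a ctypes function pointer in a PyCapsule. The name `make_capsule` is hypothetical, and a ctypes callback stands in for a symbol loaded from a shared library; this illustrates the technique and is not necessarily the exact code added here.

```python
import ctypes


def make_capsule(funcptr):
    """Wrap a ctypes function pointer in a nameless PyCapsule.

    Hypothetical helper for illustration; it relies only on the CPython
    C API function PyCapsule_New exposed through ctypes.pythonapi.
    """
    destructor_type = ctypes.CFUNCTYPE(None, ctypes.py_object)
    new_capsule = ctypes.pythonapi.PyCapsule_New
    new_capsule.restype = ctypes.py_object
    new_capsule.argtypes = (ctypes.c_void_p, ctypes.c_char_p, destructor_type)
    address = ctypes.cast(funcptr, ctypes.c_void_p)
    # A null destructor: the capsule does not own the function it points to.
    return new_capsule(address, None, destructor_type(0))


# Stand-in for a function pointer retrieved from a shared library.
@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int)
def add_one(x):
    return x + 1


capsule = make_capsule(add_one)

# Read the pointer back out to confirm the capsule holds the same address.
get_pointer = ctypes.pythonapi.PyCapsule_GetPointer
get_pointer.restype = ctypes.c_void_p
get_pointer.argtypes = (ctypes.py_object, ctypes.c_char_p)
stored_address = get_pointer(capsule, None)
expected_address = ctypes.cast(add_one, ctypes.c_void_p).value
```

With a real shared library, the function pointer would typically come from something like `ctypes.cdll.LoadLibrary(path).symbol_name`, and the library object must be kept alive for as long as the capsule is in use.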
Instead of exposing a constructor, only expose a function that returns an opaque
object representing the defined implementation. This result can still be passed
to `jax.random.key` and `wrap_key_data`.
PiperOrigin-RevId: 578349699