This means that users of the FFI interface won't need to directly
interact with `jaxlib.xla_client` at all.
I've expanded the doctring a little and changed one default: the default
`api_version` is `1` instead of `0` to be consistent with the new name.
This could be useful for supporting the most common use cases for FFI custom
calls. It has several benefits over using the `Primitive` based approach, but
the biggest one (in my opinion) is that it doesn't require interacting with
`mlir` at all. It does have the limitation that transforms would need to be
registered using interfaces like `custom_vjp`, but many users of custom calls
already do that.
~~The easiest to-do item (I think) is to implement batching using a
`vectorized` parameter like `pure_callback`, but we could also think about more
sophisticated vmapping interfaces in the future.~~ Done.
The more difficult to-do is to think about how to support sharding, and we
might actually want to expose an interface similar to the one from
`custom_partitioning`. I have less experience with this part so I'll have to
think some more about it, and feedback would be appreciated!
The docstring for the recently added `pycapsule` function in
`jax.extend.ffi` didn't conform to our usual docstring format, so I
updated it and added a little bit more context.
This function provides sensible defaults for custom call lowering rules
with the goal of reducing the amount of boilerplate required for
implementing custom calls.
Co-authored-by: Sergei Lebedev <slebedev@google.com>
This allows one to build PyCapsules ready for XLA custom call
registration out of function pointers retrieved from shared libraries
using ctypes. In particular, this obviates the need to create Python
bindings to custom call targets.
Instead of exposing a constructor, only expose a function that returns an opaque
object representing the defined implementation. This result can still be passed
to `jax.random.key` and `wrap_key_data`.
PiperOrigin-RevId: 578349699