mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00

This `ffi_call` function could be useful for supporting the most common use cases for FFI custom calls. It has several benefits over the `Primitive`-based approach, but the biggest one (in my opinion) is that it doesn't require interacting with `mlir` at all. It does have the limitation that transforms need to be registered using interfaces like `custom_vjp`, but many users of custom calls already do that. ~~The easiest to-do item (I think) is to implement batching using a `vectorized` parameter like `pure_callback`, but we could also think about more sophisticated vmapping interfaces in the future.~~ Done: batching is supported through a `vectorized` parameter. The more difficult to-do is supporting sharding, and we might want to expose an interface similar to the one from `custom_partitioning`. I have less experience with this part, so I'll have to think more about it; feedback would be appreciated!
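Because an FFI call gets no automatic differentiation rule, its gradient has to be supplied by the caller, for example with `custom_vjp`. The following is a minimal sketch of that pattern, not the implementation from this change. A real FFI target needs compiled C++ code, so the FFI call is replaced here by a `jnp.sin` stand-in. The names `ffi_sin` and `_sin_kernel` are hypothetical and used only for illustration.

```python
import jax
import jax.numpy as jnp


def _sin_kernel(x):
    # Hypothetical stand-in for the FFI call. In real code this would dispatch
    # to a compiled target via ffi_call; plain jnp is used here so the sketch
    # runs without any C++ handler.
    return jnp.sin(x)


@jax.custom_vjp
def ffi_sin(x):
    # Primal function that callers use. With a real FFI call, JAX has no
    # differentiation rule for the custom call, so the VJP below supplies one.
    return _sin_kernel(x)


def _ffi_sin_fwd(x):
    # Forward pass: compute the output and keep x as the residual.
    return _sin_kernel(x), x


def _ffi_sin_bwd(x, g):
    # Backward pass: hand-written VJP, since d/dx sin(x) = cos(x).
    # Returns one cotangent per primal argument, as a tuple.
    return (g * jnp.cos(x),)


ffi_sin.defvjp(_ffi_sin_fwd, _ffi_sin_bwd)

# Usage: gradients flow through the hand-written rule, and jit still works.
x = jnp.linspace(0.0, 1.0, 5)
grads = jax.grad(lambda v: ffi_sin(v).sum())(x)
```

When swapping in a real FFI target, only the body of `_sin_kernel` would change. The exact `ffi_call` arguments have varied between JAX releases, so check them against the installed version.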
12 lines
166 B
ReStructuredText
``jax.extend.ffi`` module
=========================

.. automodule:: jax.extend.ffi

.. autosummary::
  :toctree: _autosummary

  ffi_call
  ffi_lowering
  pycapsule