11 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
b308c64936 Export jaxlib.xla_client.register_custom_call_target as jax.extend.ffi.register_ffi_target.
This means that users of the FFI interface won't need to directly
interact with `jaxlib.xla_client` at all.

I've expanded the doctring a little and changed one default: the default
`api_version` is `1` instead of `0` to be consistent with the new name.
2024-07-19 08:12:25 -04:00
Dan Foreman-Mackey
e9b087d3a8 Add ffi_call function with a similar signature to pure_callback.
This could be useful for supporting the most common use cases for FFI custom
calls. It has several benefits over using the `Primitive` based approach, but
the biggest one (in my opinion) is that it doesn't require interacting with
`mlir` at all. It does have the limitation that transforms would need to be
registered using interfaces like `custom_vjp`, but many users of custom calls
already do that.

~~The easiest to-do item (I think) is to implement batching using a
`vectorized` parameter like `pure_callback`, but we could also think about more
sophisticated vmapping interfaces in the future.~~ Done.

The more difficult to-do is to think about how to support sharding, and we
might actually want to expose an interface similar to the one from
`custom_partitioning`. I have less experience with this part so I'll have to
think some more about it, and feedback would be appreciated!
2024-07-01 09:40:31 -04:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Dan Foreman-Mackey
1fa66590d1 Edit pycapsule docstring to provide a little bit more context
The docstring for the recently added `pycapsule` function in
`jax.extend.ffi` didn't conform to our usual docstring format, so I
updated it and added a little bit more context.
2024-06-07 13:07:03 -04:00
Dan Foreman-Mackey
ac560c0d90 Add helper function for building custom call lowering rules
This function provides sensible defaults for custom call lowering rules
with the goal of reducing the amount of boilerplate required for
implementing custom calls.

Co-authored-by: Sergei Lebedev <slebedev@google.com>
2024-06-06 11:34:08 -04:00
Andrey Portnoy
fa46bbf0e0 Add pycapsule utility to jax.extend.ffi
This allows one to build PyCapsules ready for XLA custom call
registration out of function pointers retrieved from shared libraries
using ctypes. In particular, this obviates the need to create Python
bindings to custom call targets.
2024-06-03 15:33:49 -04:00
Dan Foreman-Mackey
1e206880d3 Move jax.ffi submodule to jax.extend.ffi 2024-05-31 12:34:59 -04:00
Sergei Lebedev
36f6b52e42 Upgrade most .py sources to 3.9
This commit was generated by running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00
Roy Frostig
16d082b002 [jex] replace extend.random.PRNGImpl with extend.random.define_prng_impl
Instead of exposing a constructor, only expose a function that returns an opaque
object representing the defined implementation. This result can still be passed
to `jax.random.key` and `wrap_key_data`.

PiperOrigin-RevId: 578349699
2023-10-31 17:21:54 -07:00
Roy Frostig
2bf9322ccc move wrap_key_data to jax.random
This is a fine function for the public API, rather than `jax.extend`.
2023-09-18 14:38:22 -07:00
Roy Frostig
a69f134cde add jax.extend.random.wrap_key_data 2023-08-26 11:39:25 -07:00