26 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
cb4d97aa1f Move jex.ffi to jax.ffi. 2024-12-29 13:06:19 +00:00
Jake VanderPlas
40367a9eaf Cleanup: remove uses of no-op raise_to_shaped 2024-12-12 09:49:06 -08:00
Dan Foreman-Mackey
478ea0dcd6 Allow 64-bit output types from ffi_call regardless of enable_x64 flag. 2024-11-11 15:01:53 -08:00
Ke Wu
7404e0d29d Add typing overloads for jax.extend.ffi.ffi_call() to aid type checkers
PiperOrigin-RevId: 694639758
2024-11-08 14:23:01 -08:00
Dan Foreman-Mackey
21f3353544 Add support for layouts and other advanced features in ffi_call. 2024-10-25 12:31:07 -04:00
Dan Foreman-Mackey
61701af4a2 Rename vmap methods for callbacks. 2024-10-21 15:03:04 -04:00
Dan Foreman-Mackey
0b651f0f45 Make ffi_call return a callable 2024-10-21 12:16:57 -04:00
Dan Foreman-Mackey
1d27d420ac Deprecate the vectorized argument to pure_callback and ffi_call. 2024-10-02 11:33:51 -04:00
Dan Foreman-Mackey
f60c5ccdee Add support for passing array attributes via ffi_call 2024-10-01 19:22:04 -04:00
Dan Foreman-Mackey
d80a89d86b Add support for FFI calls with side effects via ffi_call 2024-09-27 19:46:35 -04:00
Dan Foreman-Mackey
86f48a85b4 Add support for the DeviceLocalLayout API when lowering FFI calls.
This PR updates the FFI lowering rule to support a DeviceLoweringLayout
object as input when specifying the input and output layouts. For now,
this just converts the DLL object to its appropriate list of
minor-to-major integers because that's what the custom call op expects.
2024-09-05 14:30:06 -04:00
Georg Stefan Schmid
24bb8ae443 [ffi] Add support for token inputs and outputs 2024-09-03 18:28:34 +00:00
Dan Foreman-Mackey
79c222eee6 Fix bug in ffi_lowering where custom layouts were ignored.
PiperOrigin-RevId: 664795687
2024-08-19 07:20:06 -07:00
Dan Foreman-Mackey
dad2f576ac Add support for shape polymorphism in ffi_lowering and move lu_pivots_to_permutation lowering out of jaxlib.
The lowering logic for all jaxlib custom calls are currently split between JAX and jaxlib for reasons that are harder to justify now that the compiled calls are split between jaxlib and the relevant plugins. As part of my project to update these calls and simplify the lowering logic, it makes sense to consolidate the lowering rules in JAX instead of jaxlib since the logic is now the same for both GPU and CPU. This update tackles a simple kernel as a test case for what this would look like.

Since the full lowering rule is now implemented in JAX, we can take advantage of the MLIR helpers that are included there, including `jex.ffi.ffi_lowering`, which I needed to update to support shape polymorphism.

Of note: I think it is safe (in a compatibility sense) to delete the lowering code from jaxlib, but it does mean that it won't be possible to lower this operation when `jax.__version__ < jaxlib.__version__`. I think this is okay given our compatibility guarantees, but I'd love a sanity check on that!

Another note, this doesn't actually change the lowered HLO for this op, so we don't need to worry about export compatibility.

PiperOrigin-RevId: 664680250
2024-08-19 01:05:31 -07:00
Dan Foreman-Mackey
96045043a4 Move ir_attribute builder from extend.ffi to interpreters.mlir.
While this function is currently only used for lowering FFI calls, it could be used most places where `ir.*Attr` objects are directly constructed.

PiperOrigin-RevId: 661761712
2024-08-11 01:47:49 -07:00
Dan Foreman-Mackey
b308c64936 Export jaxlib.xla_client.register_custom_call_target as jax.extend.ffi.register_ffi_target.
This means that users of the FFI interface won't need to directly
interact with `jaxlib.xla_client` at all.

I've expanded the doctring a little and changed one default: the default
`api_version` is `1` instead of `0` to be consistent with the new name.
2024-07-19 08:12:25 -04:00
Dan Foreman-Mackey
e9b087d3a8 Add ffi_call function with a similar signature to pure_callback.
This could be useful for supporting the most common use cases for FFI custom
calls. It has several benefits over using the `Primitive` based approach, but
the biggest one (in my opinion) is that it doesn't require interacting with
`mlir` at all. It does have the limitation that transforms would need to be
registered using interfaces like `custom_vjp`, but many users of custom calls
already do that.

~~The easiest to-do item (I think) is to implement batching using a
`vectorized` parameter like `pure_callback`, but we could also think about more
sophisticated vmapping interfaces in the future.~~ Done.

The more difficult to-do is to think about how to support sharding, and we
might actually want to expose an interface similar to the one from
`custom_partitioning`. I have less experience with this part so I'll have to
think some more about it, and feedback would be appreciated!
2024-07-01 09:40:31 -04:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Dan Foreman-Mackey
1fa66590d1 Edit pycapsule docstring to provide a little bit more context
The docstring for the recently added `pycapsule` function in
`jax.extend.ffi` didn't conform to our usual docstring format, so I
updated it and added a little bit more context.
2024-06-07 13:07:03 -04:00
Dan Foreman-Mackey
ac560c0d90 Add helper function for building custom call lowering rules
This function provides sensible defaults for custom call lowering rules
with the goal of reducing the amount of boilerplate required for
implementing custom calls.

Co-authored-by: Sergei Lebedev <slebedev@google.com>
2024-06-06 11:34:08 -04:00
Andrey Portnoy
fa46bbf0e0 Add pycapsule utility to jax.extend.ffi
This allows one to build PyCapsules ready for XLA custom call
registration out of function pointers retrieved from shared libraries
using ctypes. In particular, this obviates the need to create Python
bindings to custom call targets.
2024-06-03 15:33:49 -04:00
Dan Foreman-Mackey
1e206880d3 Move jax.ffi submodule to jax.extend.ffi 2024-05-31 12:34:59 -04:00
Sergei Lebedev
36f6b52e42 Upgrade most .py sources to 3.9
This commit was generated by running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00
Roy Frostig
16d082b002 [jex] replace extend.random.PRNGImpl with extend.random.define_prng_impl
Instead of exposing a constructor, only expose a function that returns an opaque
object representing the defined implementation. This result can still be passed
to `jax.random.key` and `wrap_key_data`.

PiperOrigin-RevId: 578349699
2023-10-31 17:21:54 -07:00
Roy Frostig
2bf9322ccc move wrap_key_data to jax.random
This is a fine function for the public API, rather than `jax.extend`.
2023-09-18 14:38:22 -07:00
Roy Frostig
a69f134cde add jax.extend.random.wrap_key_data 2023-08-26 11:39:25 -07:00