Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.
For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.
See https://github.com/jax-ml/jax/issues/26480 for more details.
PiperOrigin-RevId: 726770483
In https://github.com/jax-ml/jax/pull/24370, `ffi_call` was updated to return a callable, and the original calling convention was deprecated. This change is part of the deprecation cycle for this calling convention.
PiperOrigin-RevId: 708424223
This PR updates the FFI lowering rule to support a DeviceLoweringLayout
object as input when specifying the input and output layouts. For now,
this just converts the DLL object to its appropriate list of
minor-to-major integers because that's what the custom call op expects.
While this function is currently only used for lowering FFI calls, it could be used most places where `ir.*Attr` objects are directly constructed.
PiperOrigin-RevId: 661761712
I have left an `Attrs` annotation on the FFI binding to support backwards compatibility (this accepts, but ignores, and input `permuatation_size` parameter), but I'm not sure we strictly need that since this op doesn't support exporting anyways.
In anticipation of supporting shape polymorphism I added dimension checks to the kernel to match the ones in the abstract eval.
PiperOrigin-RevId: 660831000
This could be useful for supporting the most common use cases for FFI custom
calls. It has several benefits over using the `Primitive` based approach, but
the biggest one (in my opinion) is that it doesn't require interacting with
`mlir` at all. It does have the limitation that transforms would need to be
registered using interfaces like `custom_vjp`, but many users of custom calls
already do that.
~~The easiest to-do item (I think) is to implement batching using a
`vectorized` parameter like `pure_callback`, but we could also think about more
sophisticated vmapping interfaces in the future.~~ Done.
The more difficult to-do is to think about how to support sharding, and we
might actually want to expose an interface similar to the one from
`custom_partitioning`. I have less experience with this part so I'll have to
think some more about it, and feedback would be appreciated!
This function provides sensible defaults for custom call lowering rules
with the goal of reducing the amount of boilerplate required for
implementing custom calls.
Co-authored-by: Sergei Lebedev <slebedev@google.com>
This re-enables the tests removed in https://github.com/google/jax/pull/21563
and adds support for exposing the XLA FFI headers in the
`jax.extend.ffi.include_dir` directory during a bazel build. While it's
unlikely that these will be useful for most bazel users, it is good to provide
a consistent interface with the wheel build and to be able to test this feature.
PiperOrigin-RevId: 640194961
Instead of exposing a constructor, only expose a function that returns an opaque
object representing the defined implementation. This result can still be passed
to `jax.random.key` and `wrap_key_data`.
PiperOrigin-RevId: 578349699