23 Commits

Author SHA1 Message Date
Roy Frostig
b79fe34b94 re-export RNG primitives from jax.extend.core.primitives 2024-07-17 12:24:20 -07:00
Dan Foreman-Mackey
e9b087d3a8 Add ffi_call function with a similar signature to pure_callback.
This could be useful for supporting the most common use cases for FFI custom
calls. It has several benefits over using the `Primitive` based approach, but
the biggest one (in my opinion) is that it doesn't require interacting with
`mlir` at all. It does have the limitation that transforms would need to be
registered using interfaces like `custom_vjp`, but many users of custom calls
already do that.

~~The easiest to-do item (I think) is to implement batching using a
`vectorized` parameter like `pure_callback`, but we could also think about more
sophisticated vmapping interfaces in the future.~~ Done.

The more difficult to-do is to think about how to support sharding, and we
might actually want to expose an interface similar to the one from
`custom_partitioning`. I have less experience with this part so I'll have to
think some more about it, and feedback would be appreciated!
2024-07-01 09:40:31 -04:00
Dan Foreman-Mackey
ac560c0d90 Add helper function for building custom call lowering rules
This function provides sensible defaults for custom call lowering rules
with the goal of reducing the amount of boilerplate required for
implementing custom calls.

Co-authored-by: Sergei Lebedev <slebedev@google.com>
2024-06-06 11:34:08 -04:00
Andrey Portnoy
fa46bbf0e0 Add pycapsule utility to jax.extend.ffi
This allows one to build PyCapsules ready for XLA custom call
registration out of function pointers retrieved from shared libraries
using ctypes. In particular, this obviates the need to create Python
bindings to custom call targets.
2024-06-03 15:33:49 -04:00
Dan Foreman-Mackey
1e206880d3 Move jax.ffi submodule to jax.extend.ffi 2024-05-31 12:34:59 -04:00
Chase Roberts
74c2e25314 Add more imports to jax extend 2024-04-17 15:13:17 -07:00
Roy Frostig
09415607bb fix up extend:core build rule
We want `pytype_strict_library` here.

PiperOrigin-RevId: 624337356
2024-04-12 17:31:10 -07:00
Roy Frostig
65034b3da4 add and populate jax.extend.core.primitives 2024-04-10 09:27:42 -07:00
Yue Sheng
1cef1d9503 jax.clear_backends() is not doing what it is intended to do, users should try to avoid using it.
We decide to move it into `jax.extend`. This CL is the first step which adds a new module `jax.extend.backend`.

PiperOrigin-RevId: 615934218
2024-03-14 16:11:31 -07:00
Peter Hawkins
e558feaa5e Deprecate support for the mhlo dialect.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.

PiperOrigin-RevId: 598550225
2024-01-15 02:13:40 -08:00
Peter Hawkins
78543f7bb8 Add jax.extend.mlir.
Some users of JAX want to use the MLIR dialects defined in jaxlib. In particular, these need to be used by custom lowering rules. Add a semi-public (jax.extend) API to access these, rather than having them use jax._src.lib.mlir.

PiperOrigin-RevId: 588448489
2023-12-06 09:16:43 -08:00
Peter Hawkins
7fa0f464fd [bazel] Add a BUILD file for jax/extend, and add more granular targets for individual pieces of extend.
In general we'd like to use more granular BUILD targets rather than larger monolithic targets. If nothing else, they interact better with pytype.

This change is in preparation for adding the JAX MLIR bindings to jax.extend, since they are something that JAX users sometimes need especially for defining custom ops.

PiperOrigin-RevId: 587893573
2023-12-04 17:48:50 -08:00
Roy Frostig
16d082b002 [jex] replace extend.random.PRNGImpl with extend.random.define_prng_impl
Instead of exposing a constructor, only expose a function that returns an opaque
object representing the defined implementation. This result can still be passed
to `jax.random.key` and `wrap_key_data`.

PiperOrigin-RevId: 578349699
2023-10-31 17:21:54 -07:00
Jake VanderPlas
6da4750c3b [random] remove internal uses of deprecated prng.seed_with_impl() 2023-10-17 13:18:08 -07:00
Jake VanderPlas
4ba7590d85 export jax.extend.source_info_util.current
PiperOrigin-RevId: 573290435
2023-10-13 12:31:11 -07:00
Jake VanderPlas
4e463c0aa2 JEX: add jax.extend.source_info_util 2023-10-13 09:36:00 -07:00
Roy Frostig
5158e251b6 identify PRNG schemes on key arrays, and recognize them in key constructors
Specifically:

* Introduce `jax.random.key_impl`, which accepts a key array and
  returns a hashable identifier of its PRNG implementation.

* Accept this identifier optionally as the `impl` argument to
  `jax.random.key` and `wrap_key_data`.

This now works:

```python
k1 = jax.random.key(72, impl='threefry2x32')
impl = jax.random.key_impl(k1)
k2 = jax.random.key(72, impl=impl)
assert arrays_equal(k1, k2)
assert k1.dtype == k2.dtype
```

This change also set up an internal PRNG registry and register
built-in implementations, to simplify various places where we
essentially reconstruct such a registry from scratch (such as in
tests).

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-10-06 10:15:08 -07:00
Jake VanderPlas
48087cbe8d JEX: add jex.abstract_arrays.array_types 2023-09-19 11:37:05 -07:00
Roy Frostig
2bf9322ccc move wrap_key_data to jax.random
This is a fine function for the public API, rather than `jax.extend`.
2023-09-18 14:38:22 -07:00
Jake VanderPlas
ca39457ea9 JEX: move jax.linear_util to jax.extend.linear_util 2023-08-30 18:32:12 -07:00
Roy Frostig
a69f134cde add jax.extend.random.wrap_key_data 2023-08-26 11:39:25 -07:00
Roy Frostig
a71c0e6ecc create jax.extend.random as a copy of jax.prng
Co-authored-by: Jake Vanderplas <jakevdp@google.com>
PiperOrigin-RevId: 559874051
2023-08-24 14:41:56 -07:00
Roy Frostig
ca008f37e3 initiate jax.extend via docs and top-level module set-up 2023-05-15 15:47:06 -07:00