1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 13:26:06 +00:00

41 Commits

Author SHA1 Message Date
George Necula
a0812cd57e [better_errors] Make it explicit that debug_info is not None.
Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.

For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.

See https://github.com/jax-ml/jax/issues/26480 for more details.

PiperOrigin-RevId: 726770483
2025-02-13 22:07:04 -08:00
Dan Foreman-Mackey
cb4d97aa1f Move jex.ffi to jax.ffi. 2024-12-29 13:06:19 +00:00
Dan Foreman-Mackey
4216f8fad0 Accelerate deprecation of legacy JAX FFI calling convention.
In https://github.com/jax-ml/jax/pull/24370, `ffi_call` was updated to return a callable, and the original calling convention was deprecated. This change is part of the deprecation cycle for this calling convention.

PiperOrigin-RevId: 708424223
2024-12-20 14:13:29 -08:00
jax authors
ae46b7564e Merge pull request from froystig:random-dtypes
PiperOrigin-RevId: 698268678
2024-11-19 23:04:06 -08:00
Roy Frostig
4bb81075bc represent random.key_impl of builtin RNGs by canonical string name
We do not have great reason to return specs here, and sticking to
strings instead can help with simple serialization.
2024-11-19 20:58:10 -08:00
Dan Foreman-Mackey
41a0493e56 Add shard map replication rule for ffi_call. 2024-11-14 15:44:31 -08:00
Dan Foreman-Mackey
478ea0dcd6 Allow 64-bit output types from ffi_call regardless of enable_x64 flag. 2024-11-11 15:01:53 -08:00
Dan Foreman-Mackey
1785479cbd Fix segfault caused by uninitialized LAPACK in FFI test. 2024-10-29 10:41:59 -04:00
Sergei Lebedev
321fa00741 Skip testVectorizedDeprecation on Python 3.13 to unblock the CI
PiperOrigin-RevId: 690598772
2024-10-28 07:15:24 -07:00
Dan Foreman-Mackey
21f3353544 Add support for layouts and other advanced features in ffi_call. 2024-10-25 12:31:07 -04:00
jax authors
9a2dd19a92 Merge pull request from andportnoy:aportnoy/unknown-platform-lowering-warning
PiperOrigin-RevId: 688630259
2024-10-22 11:40:39 -07:00
Andrey Portnoy
2aaa108f06 Raise an error when registering a lowering for an unknown platform 2024-10-22 13:29:48 -04:00
Dan Foreman-Mackey
61701af4a2 Rename vmap methods for callbacks. 2024-10-21 15:03:04 -04:00
Dan Foreman-Mackey
0b651f0f45 Make ffi_call return a callable 2024-10-21 12:16:57 -04:00
Dan Foreman-Mackey
1d27d420ac Deprecate the vectorized argument to pure_callback and ffi_call. 2024-10-02 11:33:51 -04:00
Dan Foreman-Mackey
f60c5ccdee Add support for passing array attributes via ffi_call 2024-10-01 19:22:04 -04:00
Dan Foreman-Mackey
1a1e16abcc Remove forward compatibility checks from lowering of LU decomposition.
The forward compatibility window for these checks has passed so it is now safe to remove them.

PiperOrigin-RevId: 680565099
2024-09-30 07:23:56 -07:00
Dan Foreman-Mackey
d80a89d86b Add support for FFI calls with side effects via ffi_call 2024-09-27 19:46:35 -04:00
Dan Foreman-Mackey
86f48a85b4 Add support for the DeviceLocalLayout API when lowering FFI calls.
This PR updates the FFI lowering rule to support a DeviceLoweringLayout
object as input when specifying the input and output layouts. For now,
this just converts the DLL object to its appropriate list of
minor-to-major integers because that's what the custom call op expects.
2024-09-05 14:30:06 -04:00
Georg Stefan Schmid
24bb8ae443 [ffi] Add support for token inputs and outputs 2024-09-03 18:28:34 +00:00
Dan Foreman-Mackey
79c222eee6 Fix bug in ffi_lowering where custom layouts were ignored.
PiperOrigin-RevId: 664795687
2024-08-19 07:20:06 -07:00
Dan Foreman-Mackey
ae5b4284d5 Make ffi_call tests backwards compatible with the released jaxlib.
PiperOrigin-RevId: 662017095
2024-08-12 03:08:49 -07:00
Dan Foreman-Mackey
4f8f66f10b Add more complete tests for attribute serialization when lowering an FFI call.
PiperOrigin-RevId: 661849681
2024-08-11 12:34:02 -07:00
Dan Foreman-Mackey
96045043a4 Move ir_attribute builder from extend.ffi to interpreters.mlir.
While this function is currently only used for lowering FFI calls, it could be used most places where `ir.*Attr` objects are directly constructed.

PiperOrigin-RevId: 661761712
2024-08-11 01:47:49 -07:00
Dan Foreman-Mackey
11d9c2de2c Update GPU implementation of lu_pivots_to_permutation to infer the permutation size directly from the input dimensions, instead of using an input parameter.
I have left an `Attrs` annotation on the FFI binding to support backwards compatibility (this accepts, but ignores, and input `permuatation_size` parameter), but I'm not sure we strictly need that since this op doesn't support exporting anyways.

In anticipation of supporting shape polymorphism I added dimension checks to the kernel to match the ones in the abstract eval.

PiperOrigin-RevId: 660831000
2024-08-08 07:35:47 -07:00
Jake VanderPlas
3fa86a9b32 remove jax.extend.backend.default_backend in favor of jax.backend
I added this two days ago before realizing there is already a canonical API
for this in the top-level namespace, so it should be safe to remove.
2024-08-02 07:07:29 -07:00
Jake VanderPlas
19d185ac8d jax.extend.backend: add semi-private backend utils from xla_bridge
For context, see https://jax.readthedocs.io/en/latest/jep/15856-jex.html.

PiperOrigin-RevId: 658179318
2024-07-31 16:19:30 -07:00
Yash Katariya
30037547d7 Bump minimum jaxlib version to 0.4.31. The corresponding xla_extension_version is 279 and mlir_api_version is 57
PiperOrigin-RevId: 657400413
2024-07-29 18:44:31 -07:00
Dan Foreman-Mackey
e9b087d3a8 Add ffi_call function with a similar signature to pure_callback.
This could be useful for supporting the most common use cases for FFI custom
calls. It has several benefits over using the `Primitive` based approach, but
the biggest one (in my opinion) is that it doesn't require interacting with
`mlir` at all. It does have the limitation that transforms would need to be
registered using interfaces like `custom_vjp`, but many users of custom calls
already do that.

~~The easiest to-do item (I think) is to implement batching using a
`vectorized` parameter like `pure_callback`, but we could also think about more
sophisticated vmapping interfaces in the future.~~ Done.

The more difficult to-do is to think about how to support sharding, and we
might actually want to expose an interface similar to the one from
`custom_partitioning`. I have less experience with this part so I'll have to
think some more about it, and feedback would be appreciated!
2024-07-01 09:40:31 -04:00
Peter Hawkins
07d24e7dcc Bump minimum jaxlib version to v0.4.30.
This corresponds to xla_extension_version 271 and mlir_api_version 57.
2024-06-18 12:35:08 -04:00
Dan Foreman-Mackey
ac560c0d90 Add helper function for building custom call lowering rules
This function provides sensible defaults for custom call lowering rules
with the goal of reducing the amount of boilerplate required for
implementing custom calls.

Co-authored-by: Sergei Lebedev <slebedev@google.com>
2024-06-06 11:34:08 -04:00
Dan Foreman-Mackey
0bf6700e3f Expose XLA FFI headers to bazel build and re-enable tests
This re-enables the tests removed in https://github.com/google/jax/pull/21563
and adds support for exposing the XLA FFI headers in the
`jax.extend.ffi.include_dir` directory during a bazel build. While it's
unlikely that these will be useful for most bazel users, it is good to provide
a consistent interface with the wheel build and to be able to test this feature.

PiperOrigin-RevId: 640194961
2024-06-04 10:14:43 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Yue Sheng
1cef1d9503 jax.clear_backends() is not doing what it is intended to do, users should try to avoid using it.
We decide to move it into `jax.extend`. This CL is the first step which adds a new module `jax.extend.backend`.

PiperOrigin-RevId: 615934218
2024-03-14 16:11:31 -07:00
Roy Frostig
16d082b002 [jex] replace extend.random.PRNGImpl with extend.random.define_prng_impl
Instead of exposing a constructor, only expose a function that returns an opaque
object representing the defined implementation. This result can still be passed
to `jax.random.key` and `wrap_key_data`.

PiperOrigin-RevId: 578349699
2023-10-31 17:21:54 -07:00
Roy Frostig
7b831ba84d test custom PRNG impl construction round trip 2023-10-06 11:18:03 -07:00
Jake VanderPlas
48087cbe8d JEX: add jex.abstract_arrays.array_types 2023-09-19 11:37:05 -07:00
Roy Frostig
2bf9322ccc move wrap_key_data to jax.random
This is a fine function for the public API, rather than `jax.extend`.
2023-09-18 14:38:22 -07:00
Jake VanderPlas
ca39457ea9 JEX: move jax.linear_util to jax.extend.linear_util 2023-08-30 18:32:12 -07:00
Roy Frostig
a69f134cde add jax.extend.random.wrap_key_data 2023-08-26 11:39:25 -07:00
Roy Frostig
a71c0e6ecc create jax.extend.random as a copy of jax.prng
Co-authored-by: Jake Vanderplas <jakevdp@google.com>
PiperOrigin-RevId: 559874051
2023-08-24 14:41:56 -07:00