75 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
2ae018ed8e Unconditionally skip async deadlock test for pure_callback.
PiperOrigin-RevId: 721012451
2025-01-29 09:49:01 -08:00
Dan Foreman-Mackey
9d39ab305a Disable async dispatch within the body of a host callback.
This is a follow up to https://github.com/jax-ml/jax/pull/26160 and https://github.com/openxla/xla/pull/21980. See those PRs for more discussion of the motivation for this change.

In this PR, we disable CPU asynchronous execution when running within the body of a host callback, because this can cause deadlocks.

PiperOrigin-RevId: 720918318
2025-01-29 04:24:33 -08:00
Peter Hawkins
c61b2f6b81 Make JAX test suite pass (at least most of the time) with multiple threads enabled.
Add a new jtu.thread_unsafe_test_class() decorator to tag entire `TestCase` classes as thread-hostile.

PiperOrigin-RevId: 714037277
2025-01-10 06:58:46 -08:00
Bart Chrzaszcz
dc53c563bb #sdy enable pure callbacks and debug prints in JAX.
Everything passes other than an io callback test due to the lowered `sdy.manual_computation` returning a token. Will be fixed in a follow-up.

PiperOrigin-RevId: 713780181
2025-01-09 13:37:51 -08:00
Peter Hawkins
51b9fe3010 [JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU devices directly.
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.

In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.

PiperOrigin-RevId: 713272197
2025-01-08 06:37:44 -08:00
Peter Hawkins
62e66b684b Don't monkey-patch functions in test_utils to count events for tests.
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.

Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
2024-12-12 09:58:14 -05:00
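The event-hook idea in the commit above can be sketched roughly as follows. This is an illustrative model only: `emit_test_event`, `register_event_listener`, and `unregister_event_listener` are hypothetical names standing in for `util.test_event(...)` and the test_utils registration, and do not reflect JAX's actual signatures.

```python
import threading
from collections import Counter

# Hypothetical stand-ins for the mechanism described above; not JAX's API.
_listeners = []
_listeners_lock = threading.Lock()

def register_event_listener(callback):
    with _listeners_lock:
        _listeners.append(callback)

def unregister_event_listener(callback):
    with _listeners_lock:
        _listeners.remove(callback)

def emit_test_event(name, *args):
    """Called at points of interest; does nothing if no listener is registered."""
    with _listeners_lock:
        listeners = list(_listeners)
    for listener in listeners:
        listener(name, *args)

def lower_computation(x):
    # Library code announces the event instead of being monkey-patched.
    emit_test_event("lower_computation")
    return x * 2

counts = Counter()
counts_lock = threading.Lock()

def count_events(name, *args):
    with counts_lock:
        counts[name] += 1

register_event_listener(count_events)
threads = [threading.Thread(target=lower_computation, args=(i,)) for i in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
unregister_event_listener(count_events)
lower_computation(0)  # not counted: the listener has been removed
```

Because the library function is never replaced, concurrent tests do not race on a shared global patch; only the listener list and the counter need locking.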
Dan Foreman-Mackey
61701af4a2 Rename vmap methods for callbacks.
2024-10-21 15:03:04 -04:00
Dan Foreman-Mackey
1d27d420ac Deprecate the vectorized argument to pure_callback and ffi_call.
2024-10-02 11:33:51 -04:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Yash Katariya
0d5dae09ff Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
Jake VanderPlas
a861c55a28 test cleanup: use ExitStack to reduce test boilerplate
2024-06-06 14:18:27 -07:00
Yash Katariya
395d3cb79e Bump minimum jaxlib version to 0.4.27
xla_extension_version is 261 and mlir_api_version is 56

PiperOrigin-RevId: 631579739
2024-05-07 16:07:59 -07:00
Matthew Johnson
192e2b6ce9 relax a side-effects test that was erroneously checking against a canonical order
PiperOrigin-RevId: 625131535
2024-04-15 17:13:35 -07:00
Matthew Johnson
4f90d365b2 [callbacks] allow unordered effects in batched while_loop if predicate is not batched
2024-04-11 21:43:59 -07:00
Matthew Johnson
8037e7b08f [callbacks] io_callback batching rule accidentally called pure_callback
2024-04-11 20:45:46 -07:00
Sergei Lebedev
9616900cc9 jax.pure_callback and jax.experimental.io_callback now use jax.Arrays
The motivation for this change is two-fold:

* JAX APIs should use jax.Arrays.
* Using jax.Arrays potentially allows keeping the data on device, instead
  of always copying it to the host. Note that the version here still always
  copies to the host.

If this change breaks you, you can recover the old behavior by changing

    jax.pure_callback(
        f,
        result_shape_dtypes,
        *args,
        **kwargs,
    )

to

    jax.pure_callback(
        lambda *args: f(*jax.tree.map(np.asarray, args)),
        result_shape_dtypes,
        *args,
        **kwargs,
    )

so that the callback function is called with NumPy arrays as before.

I will update the "External callbacks" tutorial in a follow up.

PiperOrigin-RevId: 622457378
2024-04-06 09:30:08 -07:00
George Necula
a510f03ef8 [callback] Add a flag to implement host_callback in terms of io_callback.
The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.

See issue #20385 for more details.

We change the tests to accommodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
2024-04-05 08:51:30 +01:00
George Necula
35b1cb799a [callback] Allow external callbacks to return 64-bit values in 32-bit mode
Previously, prior to #20433, if the Python callback returned a Python literal
(which is natively a 64-bit value), and the `result_shape_dtypes` specified
a 32-bit expected returned value, we would just get garbage results. In #20433, I introduced
an error in this situation. However, when trying to port the internal code that
uses host_callback to `io_callback`, I am getting many instances of this error.
The common scenario is a Python callback function that returns a Python scalar:

```
import jax
import numpy as np
from jax.experimental import io_callback

def f_host():
  return 42.

io_callback(f_host, jax.ShapeDtypeStruct((), np.float32))
```

However, if `f_host` were called directly, JAX would canonicalize
the value `42.` to a float32 (when `jax_enable_x64` is not set). I do not
think that it makes sense for `io_callback` to have stricter behaviour
than a direct call.

In this PR we add a canonicalization step on the returned values of
Python callbacks, which would cast the values to 32-bits.

In some sense this replaces the error introduced in #20433 with a
canonicalization step.
2024-04-03 11:15:11 +01:00
George Necula
bff24c6d6f [callback] Improve caching effectiveness in presence of callbacks.
Previously, the user-provided Python callback function was first
flattened and then the result passed as a primitive parameter to
the callback primitives. This means that two separate io_callback
invocations with the same Python callable will generate different
Jaxprs. To prevent this we defer the flattening to lowering time.
2024-04-02 15:33:24 +02:00
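A rough model of why eager flattening hurts caching, assuming a cache keyed on whatever object is stored as the primitive parameter. The `trace` function, `trace_cache`, and `flatten_callback` are hypothetical illustrations, not JAX internals.

```python
# Simplified model, not JAX internals: a cache keyed on the callable that is
# stored as a primitive parameter.
trace_cache = {}

def trace(callback_param):
    """Return (cached value, whether the lookup was a cache hit)."""
    hit = callback_param in trace_cache
    if not hit:
        trace_cache[callback_param] = f"jaxpr<{len(trace_cache)}>"
    return trace_cache[callback_param], hit

def flatten_callback(fn):
    # Hypothetical flattening wrapper: every call builds a new function object.
    def flat_fn(*flat_args):
        return fn(*flat_args)
    return flat_fn

def user_callback(x):
    return x + 1

# Eager flattening: the same user callable yields two distinct cache keys.
_, hit_a = trace(flatten_callback(user_callback))
_, hit_b = trace(flatten_callback(user_callback))

# Deferred flattening: the user's callable itself is the key.
_, hit_c = trace(user_callback)
_, hit_d = trace(user_callback)
```

In this model only the last lookup is a hit, since function objects compare by identity and each wrapper is a fresh object.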
Sergei Lebedev
ec73c4031a Do not deadlock the GPU if a pure_callback dispatches a GPU kernel
PiperOrigin-RevId: 619656442
2024-03-27 14:26:03 -07:00
Parker Schuh
0b09762efd Guard host transfers inside pure_callbacks from deadlocking the TPU.
Also fix python/callback.cc to not swallow errors in numpy conversions.

PiperOrigin-RevId: 619375128
2024-03-26 18:36:39 -07:00
George Necula
75db481299 [callback] Fix io_callback for callbacks that return Python literals.
The internal implementation of io_callback and friends currently uses .shape and .dtype on the result of the callback. This fails if the callback returns a Python literal.

Fixed the checks that the callback returns values of expected shape and dtype,
and added tests.

Reverts 19e6156ccec0df7a900471df7840bc421da2898b

PiperOrigin-RevId: 619156176
2024-03-26 05:32:41 -07:00
Qiao Zhang
19e6156cce Reverts 2a4e1caac465bb4cb448e7370b23a822cce699e2
PiperOrigin-RevId: 618908505
2024-03-25 11:36:13 -07:00
George Necula
2a4e1caac4 [callback] Fix io_callback for callbacks that return Python literals.
The internal implementation of io_callback and friends currently uses .shape and .dtype on the result of the callback. This fails if the callback returns a Python literal.

Fixed the checks that the callback returns values of expected shape and dtype,
and added tests.

PiperOrigin-RevId: 618814787
2024-03-25 05:53:52 -07:00
Jake VanderPlas
84e49bd6ce Remove internal references to deprecated jax.experimental.maps
2024-03-19 09:24:52 -07:00
Parker Schuh
9a00721a54 Propagate effects errors to the results (only if effects are enabled).
This will now happen when results of effectful computations are
converted to numpy arrays.

PiperOrigin-RevId: 615883363
2024-03-14 13:32:32 -07:00
Sergei Lebedev
a8e2ee9b65 Log the exception if the callback passed to jax.*_callback raises
PiperOrigin-RevId: 615407343
2024-03-13 07:23:29 -07:00
Sharad Vikram
d7bf9563e6 Fix bug where axis size selected incorrectly in pure_callback vmap rule
Fixes #19978
2024-02-26 15:09:20 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Jieying Luo
c9db50cfd0 Enable python_callback_test for stream executor.
python_callback_test is supported for GPU stream executor. TPU stream executor was deprecated.

PiperOrigin-RevId: 578960299
2023-11-02 13:26:59 -07:00
Sergei Lebedev
f9087ab0c6 MAINT Drop underscore from the name of externally-referenced state objects
2023-10-13 21:30:13 +01:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Sergei Lebedev
2f70ae700a Migrate another subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.

PiperOrigin-RevId: 572587137
2023-10-11 08:46:06 -07:00
Peter Hawkins
15126504a7 [JAX] Keep CPU host callbacks alive via IFRT, rather than by attaching them to the Python object.
We need to keep callback objects alive as long as any running executables are alive. It is possible to discard the Python data structures for an executable before the runtime has finished running that executable, which can lead to a use after free. Instead, make the runtime keep host callbacks alive.

PiperOrigin-RevId: 571141106
2023-10-05 15:07:03 -07:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Junwhan Ahn
c72d72c5a7 Add support for (ordered) io_callable within shard_map
This CL changes `shard_map` to propagate the tokens from/to the outer lowering context and allows `io_callable` with/without ordering to be used inside `shard_map`. As in shardable ordered effects that were recently added, ordered `io_callable` inside `shard_map` has partial ordering, where the ordering is enforced only within a device.

Also fixes a bug where `mlir.eval_dynamic_shape` fails when `ctx` already has `tokens_out` set. Since `LoweringRuleContext` allows `tokens_out` to be set only once, the `ctx.tokens_out` is cleared before it is passed to `lower_fun`, which calls `ctx.set_tokens_out` internally.

PiperOrigin-RevId: 568314239
2023-09-25 13:56:43 -07:00
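The "ordering is enforced only within a device" property from the commit above can be illustrated with a small checker. This is a sketch under the assumption that each observed callback is recorded as a `(device_id, program_point)` pair; the function name and trace format are hypothetical.

```python
from collections import defaultdict

def respects_per_device_order(trace):
    """trace: list of (device_id, program_point) in observed execution order.

    Returns True if, on every device, program points appear in strictly
    increasing order. Interleaving across devices is not constrained.
    """
    last_seen = defaultdict(lambda: -1)
    for device_id, program_point in trace:
        if program_point <= last_seen[device_id]:
            return False
        last_seen[device_id] = program_point
    return True

# Device 1 reaches program point 1 before device 0 does: still allowed.
interleaved = [(0, 0), (1, 0), (1, 1), (0, 1)]
# Device 0 runs program point 1 before program point 0: not allowed.
reordered = [(0, 1), (0, 0), (1, 0), (1, 1)]

ok_interleaved = respects_per_device_order(interleaved)
ok_reordered = respects_per_device_order(reordered)
```

A total order would additionally fix the interleaving across devices, which is exactly what the partial ordering leaves open.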
George Necula
32ee27b5cb [callbacks] Add support for shardable ordered effects.
Ordered effects currently are not allowed in multi-device computations.
This is too restrictive sometimes, e.g., `io_callback(ordered=True)` uses
maximal sharding on one device and the callback would be issued only
once even in multi-device computations.

Here we add support for ordered shardable effects, which behave like
ordered effects except they are allowed in SPMD computations.
Currently, only `callback.IOOrderedEffect` is declared shardable.

In general, if the sharding of the side-effecting operation is not
maximal, then such effects would appear in a partial order, with
effects appearing ordered by program point and unordered among
the different devices at a given program point.

We also generalize the mechanism for tracking runtime tokens and
token buffers to work with multiple devices.

PiperOrigin-RevId: 566242557
2023-09-18 02:50:25 -07:00
Jake Hall
f59a4163fa Test changes for out-of-tree backend.
2023-09-14 12:18:37 +01:00
George Necula
f27816af30 [callback] Enable 64-bit types and add tests.
This takes advantage of a recent fix in XLA:TPU to enable
64-bit host transfers.

PiperOrigin-RevId: 562890507
2023-09-05 14:23:28 -07:00
George Necula
01c068eabd [callback] Some test cleanup.
Removes callback testing function and uses io_callback
and pure_callback instead. This allows us to remove
some tests from the PureCallbackTest class.

Renames IoPythonCallbackTest -> IoCallbackTest and PurePythonCallbackTest -> PureCallbackTest.

PiperOrigin-RevId: 562285255
2023-09-02 21:51:07 -07:00
Junwhan Ahn
c35bc81605 Add an optional sharding argument to pure_callback and io_callback
This CL allows callers of `pure_callback` and `io_callback` to be able to specify the device to be used to run host callbacks. This is to make it possible to move users of `jax.experimental.host_callback(..., device_index=i)` (deprecated) to `pure_callback` or `io_callback`.

Instead of taking a device index referring to a device in the device assignment, the new API takes sharding to match the look and feel with other JAX APIs. The current implementation supports `SingleDeviceSharding` only since the way sharding is annotated in StableHLO makes it tricky to use anything other than `MAXIMAL` or `MANUAL`. But if we later decide to expand support for other types of sharding, we will be able to do it without changing the API or breaking existing users.

This CL also fixes an issue where `pure_callback` and `io_callback` had different sharding semantics inside `SPMDAxisContext`. Specifically, `io_callback` used to emit `MAXIMAL` even for `SPMDAxisContext`, whereas `pure_callback` used `MANUAL` sharding. The latter seems to make more sense since `SPMDAxisContext` with all manual axes should come with per-device semantics. This CL made both styles of callbacks use consistent sharding by factoring out sharding calculation as a common function.

PiperOrigin-RevId: 560203044
2023-08-25 14:57:35 -07:00
George Necula
26f091e446 [callback] Disable stream_executor tests.
PiperOrigin-RevId: 559252832
2023-08-22 16:15:00 -07:00
George Necula
8891503f87 [callback] Add workaround for TPU host_callback not supporting empty arrays.
Currently JAX callbacks on TPU raise errors when the called function takes empty arguments or returns empty results. It seems that the send_to_host function works
even with empty arrays, but recv_from_host crashes (crash log below).

Here we work around this issue, by ensuring that only the non-empty results of the Python callback are sent to the device computation and the empty results are replaced with empty constants in the device computation.

This is part of the work to replace uses of host_callback with io_callback.

PiperOrigin-RevId: 559061336
2023-08-22 03:47:18 -07:00
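The workaround in the commit above can be modeled in plain NumPy. This is a sketch under assumptions: `split_results` and `rebuild_results` are hypothetical helper names, and the "device side" is simulated on the host rather than lowered into a device computation.

```python
import numpy as np

def split_results(results):
    """Host side: keep only non-empty results for transfer.

    The shapes and dtypes are returned separately; in a real lowering they
    would already be known from the declared result shapes.
    """
    specs = [(r.shape, r.dtype) for r in results]
    non_empty = [r for r in results if r.size > 0]
    return non_empty, specs

def rebuild_results(non_empty, specs):
    """Device side (modeled here): re-insert empty constants in position."""
    remaining = iter(non_empty)
    return [
        np.zeros(shape, dtype) if 0 in shape else next(remaining)
        for shape, dtype in specs
    ]

results = [
    np.arange(3, dtype=np.float32),
    np.zeros((0, 2), dtype=np.int32),   # empty result: never transferred
    np.array(1.5, dtype=np.float32),    # scalar: shape (), size 1
]
sent, specs = split_results(results)
rebuilt = rebuild_results(sent, specs)
```

Only two of the three results cross the host-to-device boundary here, while the caller still sees three results with the original shapes and dtypes.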
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Sharad Vikram
61f22676b0 Add maximal sharding for pure_callback not inside of a shard_map
2023-05-11 13:28:37 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config
2023-04-21 11:51:22 -07:00
Skye Wanderman-Milne
ef5e4a4035 Remove 'pjrt_c_api_unimplemented' pytest mark.
Instead, we skip tests that the PJRT C API doesn't support. We had
this tag for feature development so it was easy to broadly disable,
but now we don't expect to need to do that.
2023-03-24 23:14:54 +00:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Peter Hawkins
1925aa1109 Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.

PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Matthew Johnson
1f67351f56 [shard_map] make debug_print work with shard_map, eager and jit
2023-03-08 20:38:03 -08:00