This is a follow up to https://github.com/jax-ml/jax/pull/26160 and https://github.com/openxla/xla/pull/21980. See those PRs for more discussion of the motivation for this change.
In this PR, we disable CPU asynchronous execution when running within the body of a host callback, because this can cause deadlocks.
PiperOrigin-RevId: 720918318
Everything passes other than an io callback test due to the lowered `sdy.manual_computation` returning a token. Will be fixed in a follow-up.
PiperOrigin-RevId: 713780181
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.
In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.
PiperOrigin-RevId: 713272197
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.
Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
The motivation for this change is two-fold
* JAX APIs should use jax.Arrays.
* Using jax.Arrays potentially allows keeping the data on device, instead
of always copying it to the host. Note that the version here still always
copies to the host.
If this change breaks you, you can recover the old behavior by changing
jax.pure_callback(
f,
result_shape_dtypes,
*args,
**kwargs,
)
to
jax.pure_callback(
lambda *args: f(*jax.tree.map(np.asarray, args)),
result_shape_dtypes,
*args,
**kwargs,
)
so that the callback function is called with NumPy arrays as before.
I will update the "External callbacks" tutorial in a follow up.
PiperOrigin-RevId: 622457378
The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.
See issue #20385 for more details.
We change the tests to accomodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
Previously, prior to #20433, if the Python callback returned a Python literal
(which is natively a 64-bit value), and the `result_shape_dtypes` specified
a 32-bit expected returned value, we would just get garbage results. In #20433, I introduced
an error in this situation. However, when trying to port the internal code that
uses host_callback to `io_callback`, I am getting many instances of this error.
The common scenario is a Python callback function that returns a Python scalar:
```
def f_host():
return 42.
io_callback(f_host, jax.ShapeDtypeStruct((), np.float32))
```
However, if the `f_host` were called directly JAX would canonicalize
the value `42.` to a float32 (when `jax_enable_x64` is not set). I do not
think that it makes sense for `io_callback` to have stricter behaviour
that a direct call.
In this PR we add a canonicalization step on the returned values of
Python callbacks, which would cast the values to 32-bits.
In some sense this is replacing the change in #20433 to add a canonicalization
step instead of an error.
Previously, the user-provided Python callback function was first
flattened and then the result passed as a primitive parameter to
the callback primitives. This means that two separate io_callback
invocations with the same Python callable will generate different
Jaxprs. To prevent this we defer the flattening to lowering time.
The internal implementation of io_callback and friends currently use .shape and .dtype on the result of the callback. This fails if the callback returns a Python literal.
Fixed the checks that the callback returns values of expected shape and dtype,
and added tests.
Reverts 19e6156ccec0df7a900471df7840bc421da2898b
PiperOrigin-RevId: 619156176
The internal implementation of io_callback and friends currently use .shape and .dtype on the result of the callback. This fails if the callback returns a Python literal.
Fixed the checks that the callback returns values of expected shape and dtype,
and added tests.
PiperOrigin-RevId: 618814787
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
PiperOrigin-RevId: 572587137
We need to keep callback objects alive as long as any running executables are alive. It is possible to discard the Python data structures for an executable before the runtime has finished running that executable, which can lead to a use after free. Instead, make the runtime keep host callbacks alive.
PiperOrigin-RevId: 571141106
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
PiperOrigin-RevId: 568923117
This CL changes `shard_map` to propagate the tokens from/to the outer lowering context and allows `io_callable` with/without ordering to be used inside `shard_map`. As in shardable ordered effects that were recently added, ordered `io_callable` inside `shard_map` has partial ordering, where the ordering is enforced only within a device.
Also fixes a bug where `mlir.eval_dynamic_shape` fails when `ctx` already has `tokens_out` set. Since `LoweringRuleContext` allows `tokens_out` to be set only once, the `ctx.tokens_out` is cleared before it is passed to `lower_fun`, which calls `ctx.set_tokens_out` internally.
PiperOrigin-RevId: 568314239
Ordered effects currently are not allowed in multi-device computations.
This is too restrictive sometimes, e.g., `io_callback(ordered=True)` uses
maximal sharding on one device and the callback would be issued only
once even in multi-device computations.
Here we add support for ordered shardable effects, which behave like
ordered effects except they are allowed in SPMD computations.
Currently, only `callback.IOOrderedEffect` is declared shardable.
In general, if the sharding of the side-effecting operation is not
maximal, then such effects would appear in a partial order, with
effects appearing ordered by program point and unordered among
the different devices at a given program point.
We also generalize the mechanism for tracking runtime tokens and
token buffers to work with multiple devices.
PiperOrigin-RevId: 566242557
Removes callback testing function and uses io_callback
and pure_callback instead. This allows us to remove
some tests from the PureCallbackTest class.
Renames IoPythonCallbackTest -> IoCallbackTest and PurePythonCallbackTest -> PureCallbackTest.
PiperOrigin-RevId: 562285255
This CL allows callers of `pure_callback` and `io_callback` to be able to specify the device to be used to run host callbacks. This is to make it possible to move users of `jax.experimental.host_callback(..., device_index=i)` (deprecated) to `pure_callback` or `io_callback`.
Instead of taking a device index referring to a device in the device assignment, the new API takes sharding to match the look and feel with other JAX APIs. The current implementation supports `SingleDeviceSharding` only since the way sharding is annotated in StableHLO makes it tricky to use anything other than `MAXIMAL` or `MANUAL`. But if we later decide to expand support for other types of sharding, we will be able to do it without changing the API or breaking existing users.
This CL also fixes an issue where `pure_callback` and `io_callback` had different sharding semantics inside `SPMDAxisContext`. Specifically, `io_callback` used to emit `MAXIMAL` even for `SPMDAxisContext`, whereas `pure_callback` used `MANUAL` sharding. The latter seems to make more sense since `SPMDAxisContext` with all manual axes should come with per-device semantics. This CL made both styles of callbacks use consistent sharding by factoring out sharding calculation as a common function.
PiperOrigin-RevId: 560203044
Currently JAX callbacks on TPU raise errors when the called function takes empty arguments or returns empty results. It seems that the send_to_host function works
even with empty arrays, but recv_from_host crashes (crash log below).
Here we work around this issue, by ensuring that only the non-empty results of the Python callback are sent to the device computation and the empty results are replaced with empty constants in the device computation.
This is part of the work to replace uses of host_callback with io_callback.
PiperOrigin-RevId: 559061336
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
Instead, we skip tests that the PJRT C API doesn't support. We had
this tag for feature development so it was easy to broadly disable,
but now we don't expect to need to do that.
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.
PiperOrigin-RevId: 516223009