52 Commits

Author SHA1 Message Date
George Necula
2a4e1caac4 [callback] Fix io_callback for callbacks that return Python literals.
The internal implementation of io_callback and friends currently use .shape and .dtype on the result of the callback. This fails if the callback returns a Python literal.

Fixed the checks that the callback returns values of expected shape and dtype,
and added tests.

PiperOrigin-RevId: 618814787
2024-03-25 05:53:52 -07:00
Jake VanderPlas
84e49bd6ce Remove internal references to deprecated jax.experimental.maps 2024-03-19 09:24:52 -07:00
Parker Schuh
9a00721a54 Propagate effects errors to the results (only if effects are enabled).
This will now happen when results of effectful computations are
converted to numpy arrays.

PiperOrigin-RevId: 615883363
2024-03-14 13:32:32 -07:00
Sergei Lebedev
a8e2ee9b65 Log the exception if the callback passed to jax.*_callback raises
PiperOrigin-RevId: 615407343
2024-03-13 07:23:29 -07:00
Sharad Vikram
d7bf9563e6 Fix bug where axis size selected incorrectly in pure_callback vmap rule
Fixes #19978
2024-02-26 15:09:20 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Jieying Luo
c9db50cfd0 Enable python_callback_test for stream executor.
python_callback_test is supported for GPU stream executor. TPU stream executor was deprecated.

PiperOrigin-RevId: 578960299
2023-11-02 13:26:59 -07:00
Sergei Lebedev
f9087ab0c6 MAINT Drop underscore from the name of externally-referenced state objects 2023-10-13 21:30:13 +01:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Sergei Lebedev
2f70ae700a Migrate another subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.

PiperOrigin-RevId: 572587137
2023-10-11 08:46:06 -07:00
Peter Hawkins
15126504a7 [JAX] Keep CPU host callbacks alive via IFRT, rather than by attaching them to the Python object.
We need to keep callback objects alive as long as any running executables are alive. It is possible to discard the Python data structures for an executable before the runtime has finished running that executable, which can lead to a use after free. Instead, make the runtime keep host callbacks alive.

PiperOrigin-RevId: 571141106
2023-10-05 15:07:03 -07:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Junwhan Ahn
c72d72c5a7 Add support for (ordered) io_callable within shard_map
This CL changes `shard_map` to propagate the tokens from/to the outer lowering context and allows `io_callable` with/without ordering to be used inside `shard_map`. As in shardable ordered effects that were recently added, ordered `io_callable` inside `shard_map` has partial ordering, where the ordering is enforced only within a device.

Also fixes a bug where `mlir.eval_dynamic_shape` fails when `ctx` already has `tokens_out` set. Since `LoweringRuleContext` allows `tokens_out` to be set only once, the `ctx.tokens_out` is cleared before it is passed to `lower_fun`, which calls `ctx.set_tokens_out` internally.

PiperOrigin-RevId: 568314239
2023-09-25 13:56:43 -07:00
George Necula
32ee27b5cb [callbacks] Add support for shardable ordered effects.
Ordered effects currently are not allowed in multi-device computations.
This is too restrictive sometimes, e.g., `io_callback(ordered=True)` uses
maximal sharding on one device and the callback would be issued only
once even in multi-device computations.

Here we add support for ordered shardable effects, which behave like
ordered effects except they are allowed in SPMD computations.
Currently, only `callback.IOOrderedEffect` is declared shardable.

In general, if the sharding of the side-effecting operation is not
maximal, then such effects would appear in a partial order, with
effects appearing ordered by program point and unordered among
the different devices at a given program point.

We also generalize the mechanism for tracking runtime tokens and
token buffers to work with multiple devices.

PiperOrigin-RevId: 566242557
2023-09-18 02:50:25 -07:00
Jake Hall
f59a4163fa Test changes for out-of-tree backend. 2023-09-14 12:18:37 +01:00
George Necula
f27816af30 [callback] Enable 64-bit types and add tests.
This takes advantage of a recent fix in XLA:TPU to enable
64-bit host transfers.

PiperOrigin-RevId: 562890507
2023-09-05 14:23:28 -07:00
George Necula
01c068eabd [callback] Some test cleanup.
Removes callback testing function and uses io_callback
and pure_callback instead. This allows us to remove
some tests from the PureCallbackTest class.

Renames IoPythonCallbackTest -> IoCallbackTest and PurePythonCallbackTest -> PureCallbackTest.

PiperOrigin-RevId: 562285255
2023-09-02 21:51:07 -07:00
Junwhan Ahn
c35bc81605 Add an optional sharding argument to pure_callback and io_callback
This CL allows callers of `pure_callback` and `io_callback` to be able to specify the device to be used to run host callbacks. This is to make it possible to move users of `jax.experimental.host_callback(..., device_index=i)` (deprecated) to `pure_callback` or `io_callback`.

Instead of taking a device index referring to a device in the device assignment, the new API takes sharding to match the look and feel with other JAX APIs. The current implementation supports `SingleDeviceSharding` only since the way sharding is annotated in StableHLO makes it tricky to use anything other than `MAXIMAL` or `MANUAL`. But if we later decide to expand support for other types of sharding, we will be able to do it without changing the API or breaking existing users.

This CL also fixes an issue where `pure_callback` and `io_callback` had different sharding semantics inside `SPMDAxisContext`. Specifically, `io_callback` used to emit `MAXIMAL` even for `SPMDAxisContext`, whereas `pure_callback` used `MANUAL` sharding. The latter seems to make more sense since `SPMDAxisContext` with all manual axes should come with per-device semantics. This CL made both styles of callbacks use consistent sharding by factoring out sharding calculation as a common function.

PiperOrigin-RevId: 560203044
2023-08-25 14:57:35 -07:00
George Necula
26f091e446 [callback] Disable stream_executor tests.
PiperOrigin-RevId: 559252832
2023-08-22 16:15:00 -07:00
George Necula
8891503f87 [callback] Add workaround for TPU host_callback not supporting empty arrays.
Currently JAX callbacks on TPU raise errors when the called function takes empty arguments or returns empty results. It seems that the send_to_host function works
even with empty arrays, but recv_from_host crashes (crash log below).

Here we work around this issue, by ensuring that only the non-empty results of the Python callback are sent to the device computation and the empty results are replaced with empty constants in the device computation.

This is part of the work to replace uses of host_callback with io_callback.

PiperOrigin-RevId: 559061336
2023-08-22 03:47:18 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Sharad Vikram
61f22676b0 Add maximal sharding for pure_callback not inside of a shard_map 2023-05-11 13:28:37 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Skye Wanderman-Milne
ef5e4a4035 Remove 'pjrt_c_api_unimplemented' pytest mark.
Instead, we skip tests that the PJRT C API doesn't support. We had
this tag for feature development so it was easy to broadly disable,
but now we don't expect to need to do that.
2023-03-24 23:14:54 +00:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Peter Hawkins
1925aa1109 Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.

PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Matthew Johnson
1f67351f56 [shard_map] make debug_print work with shard_map, eager and jit 2023-03-08 20:38:03 -08:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
Yash Katariya
418c2f9d2a Rename in_axis_resources and out_axis_resources with in_shardings and out_shardings. This is just a simple name replacement. It does not change any of the current pjit semantics and doesn't break any code.
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PatitionSpecs to in_shardings and out_shardings.

PiperOrigin-RevId: 510671300
2023-02-18 10:00:36 -08:00
Sharad Vikram
af2306c0a8 Refactor effects system to use effect types, not objects 2023-02-17 17:40:08 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Peter Hawkins
428189f8fb Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Skye Wanderman-Milne
93cd07efb8 Add PJRT C API to Cloud TPU test matrix
Also shortens the job names so the full name is visible from the
github UI (this was driving me crazy), and marks a new test that can't
be run on the PJRT C API yet.

Example run: https://github.com/google/jax/actions/runs/4019968334
2023-01-27 01:06:21 +00:00
Sharad Vikram
3de5c2b716 Add IO callback 2023-01-17 13:55:05 -08:00
Parker Schuh
30d64f38f1 Add 'hard xmap' support for pure_callback.
PiperOrigin-RevId: 501689068
2023-01-12 15:56:50 -08:00
Skye Wanderman-Milne
f90b5eed52 Add pjrt_c_api_unimplemented pytest marker to skip unsupported tests.
Also adds `test_util.pytest_mark_if_available` helper function.
2023-01-12 22:17:23 +00:00
Jake VanderPlas
f09fd8a4e9 [x64] minor test-only updates for better type safety 2022-11-30 15:18:40 -08:00
Peter Hawkins
320d531521 Increase the minimum jaxlib version to 0.3.22.
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
Jingxin Ye
59374c1cd8 skip some tests if runtime is stream_executor
DETAILS:
Run on CloudTPU v2-8 and found some tests in debugging_primitives_test
fail due to stream_executor runtime cannot support host callback.
Since host callback only support TFRT, so that skip all those types if
runtime type is stream_executor.

TESTED:
passed unit test on both TPU v2-8 and CPU.
2022-10-18 17:42:33 +00:00
Sharad Vikram
a60ca9f051 Test that array layout is preserved in Python callbacks
PiperOrigin-RevId: 478852392
2022-10-04 12:14:47 -07:00
Jake VanderPlas
1c55f265dd pure_callback: fix batching rule for multiple arguments 2022-09-30 15:35:42 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
00636617c0 jax.test_util: add capture_stdout context manager 2022-09-12 15:21:52 -07:00
Sharad Vikram
e1410bd16b Use lowering as impl rule for pure_callback 2022-09-01 15:29:31 -07:00
Sharad Vikram
311a9cb5d9 Throw error when 64-bit dtypes used incorrectly in jax.pure_callback 2022-08-31 12:31:04 -07:00
Yash Katariya
d77848bcc9 Enable jax_array on CPU for the entire JAX test suite!
PiperOrigin-RevId: 468726200
2022-08-19 10:04:35 -07:00
jax authors
39d54bdbf6 Merge pull request #11928 from sharadmv:pure-callback
PiperOrigin-RevId: 468611094
2022-08-18 20:33:34 -07:00
Sharad Vikram
b0fdf10a63 Apply suggestions from code review
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-18 10:50:50 -07:00
Sharad Vikram
393bca122d Expose pure callback and enable rank polymorphic callbacks 2022-08-17 10:56:42 -07:00
Sharad Vikram
53a44b8a35 Remove jit-of-pmap in callback test
PiperOrigin-RevId: 467738629
2022-08-15 12:50:25 -07:00