7 Commits

Author SHA1 Message Date
Peter Hawkins
e02faabfb2 Replace references to jax.readthedocs.io with docs.jax.dev.
PiperOrigin-RevId: 745156931
2025-04-08 08:33:49 -07:00
Sergei Lebedev
352e10ed68 Effects is now an immutable set
This allows safely using `no_effects` as a default value.

PiperOrigin-RevId: 589836905
2023-12-11 08:45:52 -08:00
George Necula
32ee27b5cb [callbacks] Add support for shardable ordered effects.
Ordered effects currently are not allowed in multi-device computations.
This is too restrictive sometimes, e.g., `io_callback(ordered=True)` uses
maximal sharding on one device and the callback would be issued only
once even in multi-device computations.

Here we add support for ordered shardable effects, which behave like
ordered effects except they are allowed in SPMD computations.
Currently, only `callback.IOOrderedEffect` is declared shardable.

In general, if the sharding of the side-effecting operation is not
maximal, then such effects would appear in a partial order, with
effects appearing ordered by program point and unordered among
the different devices at a given program point.

We also generalize the mechanism for tracking runtime tokens and
token buffers to work with multiple devices.

PiperOrigin-RevId: 566242557
2023-09-18 02:50:25 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Sharad Vikram
a6c4c87f3e Add JaxprInputEffect and refactor StateEffects to use it 2023-02-21 16:30:06 -08:00
Sharad Vikram
af2306c0a8 Refactor effects system to use effect types, not objects 2023-02-17 17:40:08 -08:00