59 Commits

Author SHA1 Message Date
George Necula
817b3e5757 [better_errors] Continue adding debug info to Jaxprs (step 7)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
2025-02-09 18:14:33 +02:00
George Necula
dcf72b01f4 [better_errors] Improvements in propagation of debugging info
Added some documentation for `TracingDebugInfo` (docstring, comments
about `arg_names`, since it was not obvious to me that this would
flatten the non-static arguments).

Laying the ground for the unification of the old `api_util.debug_info`
and `partial_eval.tracing_debug_info`: we rename the former to
`api_util.tracing_debug_info`, we push inside the calls to
`fun_sourceinfo` and `fun_signature` (which were done by the callers
until now), and we rewrite the latter in terms
of the former. We leave for a future PR the actual replacing of the
latter with the former throughout.

In the process of above, cleaned up the one case when `partial_eval.tracing_debug_info`
received None for the `in_tree` and `out_tracer_thunk`. The function contained
catch-all exception clauses to handle those, but doing so it masked other places
where we fail to collect debug info due to programming mistakes. E.g., in
one place we passed a `WrappedFun` instead of a `Callable`, resulting in missing debugging info.

Added more type declarations.

Added a `state_test` with a failure to track debugging information, manifested
with a leaked tracer without function provenance. Fixing this in a subsequent PR.
2025-01-20 15:09:51 +01:00
George Necula
f9dfe7f646 [better_errors] More cleanup 2025-01-15 10:22:29 +00:00
Peter Hawkins
c61b2f6b81 Make JAX test suite pass (at least most of the time) with multiple threads enabled.
Add a new jtu.thread_unsafe_test_class() decorator to tag entire `TestCase` classes as thread-hostile.

PiperOrigin-RevId: 714037277
2025-01-10 06:58:46 -08:00
Christos Perivolaropoulos
aaabb9752f Partial discharge for scan_p ops.
PiperOrigin-RevId: 707558502
2024-12-18 08:23:06 -08:00
Jacob Burnim
af5013568a Fix error when swapping a ref with a trivial indexing transform.
Without this fix, the added test case fails with:
```
...
jax/_src/state/discharge.py:416: in _swap_discharge_rule
    z, x_new = _swap_discharge(x, val, idx, tree)
jax/_src/state/discharge.py:421: in _swap_discharge
    return transform_swap_array(x, transforms, val)
jax/_src/state/discharge.py:396: in transform_swap_array
    result_val = lax_slicing.dynamic_update_slice(
jax/_src/lax/slicing.py:215: in dynamic_update_slice
    start_indices = _dynamic_slice_indices(operand, start_indices)
...
AttributeError: 'NoneType' object has no attribute 'ndim'
```
from encountering a None when computing the `result_val`.
2024-12-06 09:12:14 -08:00
George Necula
0831e2e340 [shape_poly] Adding shape polymorphism support for the state primitives. 2024-11-21 06:17:01 -08:00
Dougal Maclaurin
478b750c29 Reverts f281c6f46475270a57a02416469226315377592c
PiperOrigin-RevId: 693339094
2024-11-05 07:17:14 -08:00
Dougal Maclaurin
f281c6f464 Reverts ec39b592f7c096b0b8183723feaab2ed0d001041
PiperOrigin-RevId: 692949053
2024-11-04 06:54:06 -08:00
Dougal Maclaurin
ec39b592f7 Remove lattice system from JAX, especially raise_to_shaped (except as a no-op for backwards compat)
PiperOrigin-RevId: 692557993
2024-11-02 17:03:50 -07:00
Dougal Maclaurin
f355dcf34b Remove UnshapedArray values from JAX (it remains as an abstract class).
Part of a plan to move away from our "abstract value" lattice to more traditional types.

PiperOrigin-RevId: 691626481
2024-10-30 18:53:51 -07:00
Sharad Vikram
ce8ecbd16d Add an extension mechanism to run_state that allows:
* Uninitialized values
* Custom ref aval construction

This will allow us to replace `run_scoped` with `run_state`, and allow us to change the memory space of initialized values.

Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 688965089
2024-10-23 08:00:56 -07:00
Adam Paszke
c9f946ef57 Only thread a discharged ref value through a cond when it changes in some branch
Otherwise, we can simply pass it in as an argument, but we can avoid updating it
since it will always remain constant. Both programs have equivalent semantics,
but this one can be optimized better since it makes it more apparent that the
cond does not actually modify a ref.

PiperOrigin-RevId: 681482148
2024-10-02 09:29:07 -07:00
Christos Perivolaropoulos
84fc011e27 Introducing partial discharge rules and implementations for cond_p
As things stand you can partially discharge a jaxpr with
`discharge_state(should_discharge=[...])` but each equation is discharges *all*
its arguments. This means that primitives like `scan_p` and `cond_p` discharge
all references they refer to (no pun intended) regardless of whether the user
asked for it. We provide a special discharge rule that is preferred to the
normal one when present that allows the op to discharge only some of the
references.

This feature is especially useful for pallas kernels because contrary to all
other contexts where jaxprs are expected to eventually be fully discharged,
pallas kernels lower references all the way to the runtime as pointers or
MLIR memrefs.

Here we implement the partial discharge rule for `cond_p` and will implement it
for others in due course.

PiperOrigin-RevId: 681021324
2024-10-01 08:03:58 -07:00
George Necula
d3454f374e Add some hypothesis testing utilities and developer documentation.
Add a helper function for setting up hypothesis testing,
with support for selecting an interactive hypothesis profile
that speeds up interactive development.
2024-07-15 17:05:32 +02:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Dougal
e63b35d550 Add discharge rules for scan with mutable arrays. Move mutable array tests to separate file.
Co-authored-by: Matt Johnson <mattjj@google.com>
2024-05-02 14:36:16 -04:00
Dougal
29368e6a8e Add a zeros rule for mutable arrays and test it using a custom vjp.
add jit compatibility (have pjit jvp instantiate all ref tangents)

Co-authored-by: Matt Johnson <mattjj@google.com>
2024-04-12 15:22:07 -04:00
Matthew Johnson
46a516275f [mutable-arrays] enable refs without cps, and not just at top level
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-04-03 16:23:19 -07:00
jax authors
ce0d0c17c3 Merge pull request #20218 from mattjj:mutable-arrays-closure
PiperOrigin-RevId: 615463712
2024-03-13 10:23:23 -07:00
Matthew Johnson
649cd50681 [mutable-arrays] support closed-over mutable arrays in jit 2024-03-13 09:59:03 -07:00
Jake VanderPlas
6771a59181 [key reuse] add jax.random.clone 2024-03-08 09:06:00 -08:00
Matthew Johnson
3a403f2a0e [mutable-arrays] move MutableArray, add eager, improve tests, fix bug
1. move MutableArray to core.py, and some handlers to their respective files
2. fix a bug in aliasing setup (it was just broken before, now better test coverage)
3. add eager support by enabling get_p, swap_p, and addupdate_p impls
4. improve tests slightly
2024-03-01 15:03:23 -08:00
Matthew Johnson
ab0f7061ad [mutable-arrays] allow state effects in jit by building in run_state
with help from @sharadmv, @yashkatariya, @dougalm, and others

The basic strategy is to apply discharge_state when lowering a jaxpr with state
effects to HLO, and update the dispatch path accordingly. Specifically:
1. in tests only for now, introduce a MutableArray data type;
2. teach jit to abstract it to a Ref(ShapedArray) type, register an input
   handler, etc;
3. call discharge_state in `lower_sharding_computation` to lower a jaxpr with
   refs to a jaxpr (and then to an HLO) with extra outputs, and set up aliasing;
4. teach the output side of the dispatch path to drop those outputs.

As an alternative to (3), we could potentially lower away the effects at a
higher level, like in _pjit_lower_cached. They are similar because
_pjit_lower_cached is the only (non-xmap) caller of lower_sharding_computation.
I decided to do it in lower_sharding_computation mainly because that's closer
to where we set up aliases, and I wanted to make mutable arrays correspond to
aliased inputs/outputs on the XLA computation.
2024-02-29 21:50:19 -08:00
Jake VanderPlas
d08e9a03d8 [key reuse] add eager checks 2024-02-29 15:30:19 -08:00
Matthew Johnson
4a8babb101 integrate attrs in jax.jit
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
Sharad Vikram
836563fadf [Pallas] Refactor indexing primitives to use NDIndexer abstraction
Some notes about this change:
* This change upgrades the `RefView` abstraction to store multiple indexers.
  This allows doing things like `ref.at[0].at[0]` to recursively create a view
  of a `Ref`. `RefView`s therefore encapsluate multiple `NDIndexer`s.
* This generalizes most of the indexing primitive APIs (i.e. get_p, swap_p, addupdate_p)
  but does *not* generalize their rules. Most of the rules will raise a
  NotImplementedError if you use multiple `NDIndexer`s. Adding support will be
  done in a future CL.
* With the above in mind, this change only preserves existing public facing APIs
  and adding actual support will involve updating the rules.

PiperOrigin-RevId: 595229523
2024-01-02 15:53:40 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Sergei Lebedev
f9087ab0c6 MAINT Drop underscore from the name of externally-referenced state objects 2023-10-13 21:30:13 +01:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Sergei Lebedev
2f70ae700a Migrate another subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.

PiperOrigin-RevId: 572587137
2023-10-11 08:46:06 -07:00
Matthew Johnson
5715db4832 [run_state] add pjit run_state discharge rule and basic test 2023-10-05 21:14:00 -07:00
Matthew Johnson
29af93b4cb [run_state] add scan nested test, tweak rule name to mention 'state' 2023-10-04 11:48:42 -07:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Sergei Lebedev
df7f6a06c0 MAINT Use a generator expression in tuple([... for ... in ...])
In a few cases I also replaced tuple([*xs, *ys]) with (*xs, ys), because
tuple literals support unpacking as well.
2023-09-21 22:25:38 +01:00
Jake VanderPlas
2f878a7168 Tests: set jax_legacy_prng_key='error' 2023-08-28 10:56:09 -07:00
Sharad Vikram
2eb9a6ab1b Disable test_set_vmap on GPU due to flakiness
PiperOrigin-RevId: 558336356
2023-08-18 22:57:54 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Sharad Vikram
c446b42522 Add discharge rules for scan/while 2023-07-06 22:30:35 +00:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Sharad Vikram
907782d2e5 Deduplicate references closed over across branches of a lax.cond.
This fixes a correctness issue that could crop up when doing `run_state(cond)`.

PiperOrigin-RevId: 540795172
2023-06-15 23:58:14 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
jax authors
3c1f3abba2 Merge pull request #15149 from sharadmv:runstate
PiperOrigin-RevId: 521809360
2023-04-04 10:56:25 -07:00
Sharad Vikram
5101184ad4 Add initial implementation of a run_state primitive 2023-04-03 21:32:32 -07:00
Peter Hawkins
6cc1bf54a1 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Sharad Vikram
4960e656af Refactor Ref abstract type to contain other AbstractValues 2023-02-23 17:02:40 -08:00
Sharad Vikram
a6c4c87f3e Add JaxprInputEffect and refactor StateEffects to use it 2023-02-21 16:30:06 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
jax authors
398aaaacc7 Add support for Ellipsis as an index for stateful operations.
PiperOrigin-RevId: 497466879
2022-12-23 22:46:50 -08:00
Jake VanderPlas
4a6bbde409 Move jax.linear_util to jax._src.linear_util 2022-12-20 14:49:27 -08:00