This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.
Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
Added some documentation for `TracingDebugInfo` (docstring, comments
about `arg_names`, since it was not obvious to me that this would
flatten the non-static arguments).
Laying the ground for the unification of the old `api_util.debug_info`
and `partial_eval.tracing_debug_info`: we rename the former to
`api_util.tracing_debug_info`, we push inside the calls to
`fun_sourceinfo` and `fun_signature` (which were done by the callers
until now), and we rewrite the latter in terms
of the former. We leave for a future PR the actual replacing of the
latter with the former throughout.
In the process of above, cleaned up the one case when `partial_eval.tracing_debug_info`
received None for the `in_tree` and `out_tracer_thunk`. The function contained
catch-all exception clauses to handle those, but doing so it masked other places
where we fail to collect debug info due to programming mistakes. E.g., in
one place we passed a `WrappedFun` instead of a `Callable`, resulting in missing debugging info.
Added more type declarations.
Added a `state_test` with a failure to track debugging information, manifested
with a leaked tracer without function provenance. Fixing this in a subsequent PR.
Without this fix, the added test case fails with:
```
...
jax/_src/state/discharge.py:416: in _swap_discharge_rule
z, x_new = _swap_discharge(x, val, idx, tree)
jax/_src/state/discharge.py:421: in _swap_discharge
return transform_swap_array(x, transforms, val)
jax/_src/state/discharge.py:396: in transform_swap_array
result_val = lax_slicing.dynamic_update_slice(
jax/_src/lax/slicing.py:215: in dynamic_update_slice
start_indices = _dynamic_slice_indices(operand, start_indices)
...
AttributeError: 'NoneType' object has no attribute 'ndim'
```
from encountering a None when computing the `result_val`.
* Uninitialized values
* Custom ref aval construction
This will allow us to replace `run_scoped` with `run_state`, and allow us to change the memory space of initialized values.
Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 688965089
Otherwise, we can simply pass it in as an argument, but we can avoid updating it
since it will always remain constant. Both programs have equivalent semantics,
but this one can be optimized better since it makes it more apparent that the
cond does not actually modify a ref.
PiperOrigin-RevId: 681482148
As things stand you can partially discharge a jaxpr with
`discharge_state(should_discharge=[...])` but each equation is discharges *all*
its arguments. This means that primitives like `scan_p` and `cond_p` discharge
all references they refer to (no pun intended) regardless of whether the user
asked for it. We provide a special discharge rule that is preferred to the
normal one when present that allows the op to discharge only some of the
references.
This feature is especially useful for pallas kernels because contrary to all
other contexts where jaxprs are expected to eventually be fully discharged,
pallas kernels lower references all the way to the runtime as pointers or
MLIR memrefs.
Here we implement the partial discharge rule for `cond_p` and will implement it
for others in due course.
PiperOrigin-RevId: 681021324
Add a helper function for setting up hypothesis testing,
with support for selecting an interactive hypothesis profile
that speeds up interactive development.
1. move MutableArray to core.py, and some handlers to their respective files
2. fix a bug in aliasing setup (it was just broken before, now better test coverage)
3. add eager support by enabling get_p, swap_p, and addupdate_p impls
4. improve tests slightly
with help from @sharadmv, @yashkatariya, @dougalm, and others
The basic strategy is to apply discharge_state when lowering a jaxpr with state
effects to HLO, and update the dispatch path accordingly. Specifically:
1. in tests only for now, introduce a MutableArray data type;
2. teach jit to abstract it to a Ref(ShapedArray) type, register an input
handler, etc;
3. call discharge_state in `lower_sharding_computation` to lower a jaxpr with
refs to a jaxpr (and then to an HLO) with extra outputs, and set up aliasing;
4. teach the output side of the dispatch path to drop those outputs.
As an alternative to (3), we could potentially lower away the effects at a
higher level, like in _pjit_lower_cached. They are similar because
_pjit_lower_cached is the only (non-xmap) caller of lower_sharding_computation.
I decided to do it in lower_sharding_computation mainly because that's closer
to where we set up aliases, and I wanted to make mutable arrays correspond to
aliased inputs/outputs on the XLA computation.
Some notes about this change:
* This change upgrades the `RefView` abstraction to store multiple indexers.
This allows doing things like `ref.at[0].at[0]` to recursively create a view
of a `Ref`. `RefView`s therefore encapsluate multiple `NDIndexer`s.
* This generalizes most of the indexing primitive APIs (i.e. get_p, swap_p, addupdate_p)
but does *not* generalize their rules. Most of the rules will raise a
NotImplementedError if you use multiple `NDIndexer`s. Adding support will be
done in a future CL.
* With the above in mind, this change only preserves existing public facing APIs
and adding actual support will involve updating the rules.
PiperOrigin-RevId: 595229523
This PR is a follow up to #18881.
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
PiperOrigin-RevId: 572587137
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
PiperOrigin-RevId: 568923117
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().