This avoids a bunch of list bookkeeping in cases where we are iterating only for a side effect and do not care about the results.
I would have named this iter() to match OCaml's list.iter(), but unfortunately iter() is a Python builtin.
PiperOrigin-RevId: 736859418
For future reference, this can be done via
python -m mypy jax --warn-unused-ignores > /tmp/unused.txt
while IFS=: read file line rest; do
echo "$file:$line";
gsed -i "${line}s/ *\# type: ignore\(\[[^]]*\]\)*//" "$file"
done < /tmp/unused.txt
I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,
def f(x: int) -> str: ...
def g(x: int) -> str: ...
callback = f if ... else g # has type object!
* get_aval is not context dependent
* canonicalization does not happen for avals on an empty mesh
* jax.jit does not set abstract mesh context anymore before tracing
* sharding checks have been relaxed for all modes (`Auto`, `Explicit` and `Manual`). This means that `f = lambda x, y: x * y; f(explicit_sharded_arr, np_array)` will be allowed without inserting any mesh_casts even in `Explicit` sharding mode
* Even if use_mesh is not used in explicit sharding mode, computation follows data works!
* Higher order primitives skip canonicalization (pjit_p, while_p, cond_p, for_loop_p, scan_p)
* Check in partial_eval which compares jaxpr_known.outvars == jaxpr.out_avals has been relaxed to not check shardings if any one of the aval has an empty mesh.
As mentioned in https://github.com/jax-ml/jax/issues/26474 we need to relax the typing and sharding rule checks because if we insert `mesh_cast`s, those lead to creation of unnecessary residuals (for literals, numpy arrays, basically anything that has an empty mesh) which is not good.
PiperOrigin-RevId: 726097292
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.
Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).
Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.
Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
The plan is for all `core.Jaxpr` and `lu.WrappedFun` to carry
non-None debug info.
We change `lu.wrap_init` to construct the result paths thunk
whenever it is passed a `debug_info`. The goal is to make sure that
all `WrappedFun` have a debug info with result paths support.
We change some calling conventions for internal functions to not
pass along a separate debug_info if we have a `WrappedFun` or
a `Jaxpr`.
We obtain several improvements in presence of debug infos
in debug_info_test.py
* When current_mesh is Manual and aval mesh is Auto
* When current mesh is set and aval mesh is unset
* Final style primitives skip this canonicalization and they are free to add it in their own `bind` method.
* `mesh_cast` is skipped from this canonicalization to avoid recursion errors.
This is required to make sure that after we hit abstract_eval rule and check_jaxpr, everything is properly typed in JAX's type system.
`Auto` right now is a bit more permissive because we need to keep the current code at HEAD working but `Explicit` and `Manual` are very strict.
PiperOrigin-RevId: 722868091
Previously, we had two almost identical classes: `TracingDebugInfo` and
`JaxprDebugInfo`. The only difference was that `TracingDebugInfo` had
a thunk to return the result paths, while `JaxprDebugInfo` had the
result paths resolved to a tuple. The separation of these types
provided some clarity, but also led to code duplication and
required conversions as the debugging info goes from `WrappedFun`
to a `Jaxpr` and then to `WrappedFun` again.
* Track `explicit_mesh_axis` on `AxisData`.
* Modify `unmapped_aval` to the the above explicit mesh axis and insert it into the right place in the sharding so out_shardings are correct.
* Make `matchaxis` also handle shardings correctly
* All mapped dimensions should be sharded the same way
* spmd_axis_name and explicit sharded arrays cannot be used together
* `out_shardings` parameter on `dot_general`, `broadcast_in_dim`, `reshape`, `reshard` and `mesh_cast` is handled correctly in presence of vmap.
This should eventually help us get rid of `spmd_axis_name` from `vmap`.
PiperOrigin-RevId: 721007659
Added some documentation for `TracingDebugInfo` (docstring, comments
about `arg_names`, since it was not obvious to me that this would
flatten the non-static arguments).
Laying the ground for the unification of the old `api_util.debug_info`
and `partial_eval.tracing_debug_info`: we rename the former to
`api_util.tracing_debug_info`, we push inside the calls to
`fun_sourceinfo` and `fun_signature` (which were done by the callers
until now), and we rewrite the latter in terms
of the former. We leave for a future PR the actual replacing of the
latter with the former throughout.
In the process of above, cleaned up the one case when `partial_eval.tracing_debug_info`
received None for the `in_tree` and `out_tracer_thunk`. The function contained
catch-all exception clauses to handle those, but doing so it masked other places
where we fail to collect debug info due to programming mistakes. E.g., in
one place we passed a `WrappedFun` instead of a `Callable`, resulting in missing debugging info.
Added more type declarations.
Added a `state_test` with a failure to track debugging information, manifested
with a leaked tracer without function provenance. Fixing this in a subsequent PR.
Without this fix, the added test case fails with:
```
...
jax/_src/state/discharge.py:416: in _swap_discharge_rule
z, x_new = _swap_discharge(x, val, idx, tree)
jax/_src/state/discharge.py:421: in _swap_discharge
return transform_swap_array(x, transforms, val)
jax/_src/state/discharge.py:396: in transform_swap_array
result_val = lax_slicing.dynamic_update_slice(
jax/_src/lax/slicing.py:215: in dynamic_update_slice
start_indices = _dynamic_slice_indices(operand, start_indices)
...
AttributeError: 'NoneType' object has no attribute 'ndim'
```
from encountering a None when computing the `result_val`.
Set the abstract mesh context manager at the jit tracing boundary by looking at the mesh on the avals. In the future, this context manager will be user settable too.
Abstract mesh context manager is a new context manager with a new context variable and new trace_context entry which governs the cache behavior. If the abstract mesh context manager is not set, the default is `None`.
PiperOrigin-RevId: 698493184
Performance wise, we should be at parity, although this has not yet been tested.
Authoring wise, the new kernel is significantly smaller and simpler to write.
A major known limitation of this approach, which we have a plan to fix, is the invariant that the `seq_len % grid_size == 0` - we plan to relax this limitation in following CLs.
PiperOrigin-RevId: 689868468
* Uninitialized values
* Custom ref aval construction
This will allow us to replace `run_scoped` with `run_state`, and allow us to change the memory space of initialized values.
Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 688965089
The approach here is to add a new notion to jax, for ragged_prop. Ragged prop is useful for computing the dynamism/raggedness of an output, given a set of inputs. In the limit, if we decide that this is a useful property to have in jax as a first class citizen, we could fold the raggedness into the type system. At the moment, however, it is just a small set of rules implemented per op.
PiperOrigin-RevId: 685827096
As things stand you can partially discharge a jaxpr with
`discharge_state(should_discharge=[...])` but each equation is discharges *all*
its arguments. This means that primitives like `scan_p` and `cond_p` discharge
all references they refer to (no pun intended) regardless of whether the user
asked for it. We provide a special discharge rule that is preferred to the
normal one when present that allows the op to discharge only some of the
references.
This feature is especially useful for pallas kernels because contrary to all
other contexts where jaxprs are expected to eventually be fully discharged,
pallas kernels lower references all the way to the runtime as pointers or
MLIR memrefs.
Here we implement the partial discharge rule for `cond_p` and will implement it
for others in due course.
PiperOrigin-RevId: 681021324
This is a second attempt at this change. The first one was rolled back because of reported failures.
Reverts 411928b9668570bbc3795522aba94cece6894881
PiperOrigin-RevId: 680943744
This changes makes it so that the refs users receive inside their kernels have shapes
matching their block specs. However, the refs are not actually plain refs, but transformed
references that begin with the fully transformed abstract ref and then stack the inverse
of the transformation stack on top of it. This means that all primitives that take in refs
can also see the sequence of transforms the user applied in the block spec, which lets us
verify e.g. that the inputs to WGMMA are correctly tiled, even though their user-visible
shape remains 2D. We should be able to use the same trick in the future to propagate tiling
and better infer the layouts for loads and stores.
PiperOrigin-RevId: 680520185