The previous approach was to report, for several elements
of the cache key, the closest mismatch. Some parts of
the cache key were ignored, which led to "explanation unavailable".
The same happened when we had two keys close to the current
one, each differring in a different part of the key.
No explanation was produced because for each part of the key,
there was a matching key already in the cache, even though
the key taken as a whole did not match.
Now, we scan *all* parts of they key and compute the differences.
We keep track of the "size" of the differences, and we explain
the differences to those keys that are closest (possibly more
than one key if equidistant).
For example, for shape differences we'll report the
closest matching shape. If a type differs in both the dtype
and some parts of the shape, or sharding, it is considered
farther away.
We add new tests and explanations for different
static argnums and argnames.
There are still cases when we do not produce an explanation, but
now the "explanation unavailable" includes a description
of which component of the key is different, and what the
difference is. This may still be hard to understand by the
user but at least they can file a clearer bug.
Refactored the tests, and added a few new ones.
When we print explanations for tracing cache misses,
we use traceback_util to ignore JAX-internal functions.
Here we change the detection mechanism to use
source_info_util, which has a more exhaustive
list of JAX internals.
This removes a lot of uninteresting explanations
from a large benchmark.
jax-fixit
PiperOrigin-RevId: 746703003
Previously, jax.jit returned a function with extra attributes, e.g., `trace`, and `lower`, such that we can use:
```
jax.jit(f).trace(...)
```
The new attributes create problems when `jax.jit` is used along `functools.wraps`.
Essentially, `functools.wraps(jax.jit(f))(wrapper)` is supposed to result in a
function that when invoked will invoke `wrapper` and then presumably `jax.jit(f)`.
This works as expected if you just call the result, but if you try to use it with
`lower` and `trace`, the `wrapper` is bypassed. This is because `wraps` copies the
attributes `trace` and `lower` from `jax.jit(f)` onto the resulting function,
so when `trace` is invoked the `wrapper` is bypassed entirely.
See #27829 and #27825.
The solution proposed here is to make the `trace` and `lower` be class attributes,
so that they are not copied by `functools.wraps`.
Thus, if you try to use `lower` or `trace` on the result of
`functools.wraps(jax.jit(f))()` you will get an error.
That is better than silently ignoring the wrapper.
The workaround is to apply `jax.jit` last among your wrappers.
Fixes: #27829
About half of the tracing-cache-miss explanations in a large benchmark
end up being from JAX-internal functions, such as `jax.numpy` functions.
These cache misses are not what the JAX user wants to see, so we filter
them out, using the same mechanism used for filtering tracebacks.
But this is a contradiction since layouts apply to device local shape and without knowing the sharding, you can't decide the layout. But there are cases where you don't care what the sharding is, you just want to force a row-major layout (for example). **This API should only be used for those cases**.
PiperOrigin-RevId: 744888557
kwargs are passed sorted by the actual kwarg keyword. This order
must be accounted for when we construct the `debug_info.arg_names`.
Extended the tests to be more precise about not mixing up kwargs,
e.g., use different shapes and look for the shape in the HLO.
Imported from GitHub PR https://github.com/jax-ml/jax/pull/27576
This is an experimental extension to attrs. Attrs should be considered both experimental and deprecated.
This PR also includes some fixes for getattr/setattr.
Copybara import of the project:
--
3b1ea1a5f90b28744522670d0498ce5a6b194274 by Matthew Johnson <mattjj@google.com>:
[attrs] experimental appendattr
Merging this change closes#27576
COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/27576 from mattjj:appendattr b93795201b39b8f75890c9228368c994ae1e38e8
PiperOrigin-RevId: 741662724
Since this functionality was added for a dynamic shapes experiment, only enable it when dynamic_shapes config is True.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 740942785
Before this change, we handled attrs for initial-style primitives like jit/scan
like this:
1. the traceable would form a jaxpr and see what attrs were touched (by
jax_getattr or jax_setattr),
2. for each such attr, the traceable would do jax_getattr to get the current
value, tree-flatten, pass the flat valuesinto the (pure) bind, get the new
values out, tree-unflatten, then jax_setattr the result.
That approach would error if the function called `jax_setattr` to set a
previously non-existant attr. That is, this would work:
```python
from jax.experimental.attrs import jax_setattr
class Thing: ...
thing = Thing()
jax_setattr(thing, 'x', 1.0)
```
but it wouldn't work under a `jax.jit`.
This commit makes the same code work under a jit. We just
1. in partial_eval.py's `to_jaxpr`, ensure attrs added during jaxpr formation
are deleted, using a special sentinel value `dne_sentinel` to indicate the
attribute initially did not exist before tracing;
2. in pjit.py's `_get_states`, when reading initial attr values before the
pjit_p bind, if the attribute does not exist we don't try to read it and
instead just use `dne_sentinel` as the value, which is a convenient empty
pytree;
3. in pjit.py's `_attr_token` for jit caching, when forming the cache key based
on the current attr states, we map attrs that don't exist to `dne_sentinel`
(rather than just erroring when the attr doesn't exist, as before).
In short, we use a special value to indicate "does not exist".
If `jax_getattr` supported the 'default' argument, the code would be a little
cleaner since we could avoid the `if hasattr` stuff. And that's probably a
useful feature to have anyway. We can add that in a follow-up.
This PR only makes setattr-to-nonexistant-attr work with jit. We'll add scan
etc in follow-ups.
1. axis_types now takes a `AxisTypes | tuple[AxisTypes, ...] | None`. It doesn't take a dictionary anymore
2. `jax.make_mesh` also takes the same `axis_types` tuple as in point 1.
PiperOrigin-RevId: 736360041
This would also make it easier to deprecate the `with mesh: pjit` path in the future from user code since the new path would be completely tested.
This will also allow us to remove `resource_env` from JAX and the internal API access of `resource_env.physical_mesh` spread throughout codebases internally and externally.
PiperOrigin-RevId: 735602187
Add a mechanism for using the same Var names for Vars that
are aliased. In this PR, we use this for `pjit`, such that the
following `print(jax.make_jaxpr(lambda a: jax.jit(lambda a: a + 1)(a))(0.))`
prints:
```
{ lambda ; a:f32[]. let
b:f32[] = pjit[
name=<lambda>
jaxpr={ lambda ; a:f32[]. let b:f32[] = add a 1.0 in (b,) }
] a
in (b,) }
```
instead of the previous:
```
{ lambda ; a:f32[]. let
b:f32[] = pjit[
name=<lambda>
jaxpr={ lambda ; c:f32[]. let d:f32[] = add c 1.0 in (d,) }
] a
in (b,) }
```
The same mechanism could be used for other higher-order primitives,
e.g., cond, and others.
Also add some typing declarations and rename APIs to use "shared jaxpr"
in lieu of "top-level jaxpr" for those Jaxprs that are used multiple
times and are printed first. I presume that the term "top-level jaxpr"
was picked because these are printed first at top-level. But this is
confusing, because they are really subjaxprs. In fact, there was already
a function `core.pp_toplevel_jaxpr` for printing the top-level Jaxpr,
and there was also `core.pp_top_level_jaxpr` (which now is named
`core.pp_shared_jaxpr`.
* `_partitions` is now canonicalized and only contains `tuples`, `singular strings`, `None` or `UNCONSTRAINED`. No more empty tuples (`P((), 'x')`) and singleton tuples.
* Cache the creating of sharding on ShapedArray since it's expensive to do it a lot of times
* Change the `__hash__` and `__eq__` of `NamedSharding` to depend on `self.spec` instead of `self._parsed_pspec`.
PiperOrigin-RevId: 731745062
Those APIs don't support that right now anyways and they raise an ugly KeyError. Instead we raise a better error here.
I have added a TODO to get the mesh from args so that computation follows data works but we can decide to do that in the future if a lot of users request that and don't want to use `use_mesh`.
PiperOrigin-RevId: 730687231
Previously, we represented a missing arg name with `None`,
and a missing result path with the empty string. We now
adopt the same convention for arg names and use empty strings.
This simplifies the typing, and prevents the string "None" from
appearing in error messages.
I changed how we encode the result paths. Previously for a
function that returns a single array the path was the empty
string (the same as for an unknown path). And for a function
that returns a pair of arrays it was `([0], [1])`. Now we
add the "result" prefix: `("result",)` for a function returning a
single array and `(result[0], result[1])` for a function returning
a pair of arrays.
Finally, in debug_info_test, I removed the `check_tracer_arg_name`
so that all spied tracers are printed with the argument name they
depend on.
* `bitcast_convert_element_type`
* `cumsum`
* `cumlogsumexp`
* `cumprod`
* `cummax`
* `cummin`
* `reduce_window`
* `reduce_window_sum`
* `reduce_window_max`
* `reduce_window_min`
* `select_and_gather_add`
For `reduce_window_...` primitives only trivial windowing is supported along non-replicated dimensions. We can relax the other NotImplemented case in the future.
PiperOrigin-RevId: 729910108