464 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
c6482ed636 Ensure outputs are tracers when inlining jit. 2025-04-18 14:39:56 -04:00
Dan Foreman-Mackey
492cd3d931 Reverts c2ba1790417ca206a4d88b25aef4d5ae510dd717
PiperOrigin-RevId: 749049676
2025-04-18 09:03:12 -07:00
Dan Foreman-Mackey
1d652ab7f4 Don't recompute source_info for each tracer during staging. 2025-04-17 15:31:38 -04:00
George Necula
b8df474965 [explain_cache_miss] Add to explanations the duration of the missed function call
This enables the user to focus on the most important
call sites.

jax-fixit
2025-04-14 16:08:24 +03:00
George Necula
f070cdecb3 [explain-cache-miss] Improve tracing-cache-miss explanations
The previous approach was to report, for several elements
of the cache key, the closest mismatch. Some parts of
the cache key were ignored, which led to "explanation unavailable".
The same happened when we had two keys close to the current
one, each differing in a different part of the key.
No explanation was produced because for each part of the key,
there was a matching key already in the cache, even though
the key taken as a whole did not match.

Now, we scan *all* parts of the key and compute the differences.
We keep track of the "size" of the differences, and we explain
the differences to those keys that are closest (possibly more
than one key if equidistant).
For example, for shape differences we'll report the
closest matching shape. If a type differs in both the dtype
and some parts of the shape, or sharding, it is considered
farther away.

We add new tests and explanations for different
static argnums and argnames.

There are still cases when we do not produce an explanation, but
now the "explanation unavailable" includes a description
of which component of the key is different, and what the
difference is. This may still be hard to understand by the
user but at least they can file a clearer bug.

Refactored the tests, and added a few new ones.
2025-04-13 20:44:46 +03:00
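
The closest-key idea described in commit f070cdecb3 can be illustrated with a small plain-Python sketch. This is not the JAX implementation: the helper names `key_differences` and `closest_mismatches`, the dictionary layout of a cache key, and the choice of "number of differing components" as the size of a difference are all assumptions made for illustration.

```python
def key_differences(current, cached):
    """Map each differing component to (cached_value, current_value)."""
    return {
        name: (cached.get(name), value)
        for name, value in current.items()
        if cached.get(name) != value
    }


def closest_mismatches(current, cached_keys):
    """Return the differences to the cached key(s) nearest to `current`.

    Assumption: the "size" of a difference is the number of key components
    that differ. Equidistant keys are all reported.
    """
    all_diffs = [key_differences(current, cached) for cached in cached_keys]
    if not all_diffs:
        return []
    smallest = min(len(diff) for diff in all_diffs)
    return [diff for diff in all_diffs if len(diff) == smallest]


# Hypothetical cache keys, reduced to three components.
current_key = {"shape": (8, 8), "dtype": "float32", "static_argnums": (0,)}
cached_keys = [
    {"shape": (8, 4), "dtype": "float32", "static_argnums": (0,)},  # shape differs
    {"shape": (8, 8), "dtype": "int32", "static_argnums": (0,)},    # dtype differs
    {"shape": (2, 2), "dtype": "int32", "static_argnums": (0,)},    # both differ
]

explanations = closest_mismatches(current_key, cached_keys)
for diff in explanations:
    print(diff)
```

Here every individual component of the current key matches some cached key, which is the situation that previously produced no explanation. Comparing whole keys instead reports the two nearest entries and drops the one that differs in both shape and dtype.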
George Necula
dc10200906 [explain-cache-miss] Improve the detection of user file names
When we print explanations for tracing cache misses,
we use traceback_util to ignore JAX-internal functions.
Here we change the detection mechanism to use
source_info_util, which has a more exhaustive
list of JAX internals.

This removes a lot of uninteresting explanations
from a large benchmark.

jax-fixit

PiperOrigin-RevId: 746703003
2025-04-11 21:53:55 -07:00
Yash Katariya
a39b6232be Make sure the order passed to make_jit and _parse_jit_arguments is the same as the order of arguments received in jit API and make it keyword-only
PiperOrigin-RevId: 746527807
2025-04-11 11:18:59 -07:00
George Necula
5adac1cb8a Fix the printing of the function name in tracing-cache-miss explanations
jax-fixit

PiperOrigin-RevId: 746496570
2025-04-11 09:53:57 -07:00
George Necula
7eb397d1e5 Make trace and lower class attributes for jax.jit.
Previously, jax.jit returned a function with extra attributes, e.g., `trace`, and `lower`, such that we can use:

```
jax.jit(f).trace(...)
```

These extra attributes create problems when `jax.jit` is used together with `functools.wraps`.
Essentially, `functools.wraps(jax.jit(f))(wrapper)` is supposed to result in a
function that when invoked will invoke `wrapper` and then presumably `jax.jit(f)`.
This works as expected if you just call the result, but if you try to use it with
`lower` and `trace`, the `wrapper` is bypassed. This is because `wraps` copies the
attributes `trace` and `lower` from `jax.jit(f)` onto the resulting function,
so when `trace` is invoked the `wrapper` is bypassed entirely.

See #27829 and #27825.

The solution proposed here is to make the `trace` and `lower` be class attributes,
so that they are not copied by `functools.wraps`.
Thus, if you try to use `lower` or `trace` on the result of
`functools.wraps(jax.jit(f))()` you will get an error.
That is better than silently ignoring the wrapper.
The workaround is to apply `jax.jit` last among your wrappers.

Fixes: #27829
2025-04-11 14:51:12 +03:00
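
The `functools.wraps` behavior described in commit 7eb397d1e5 can be reproduced without JAX. In the sketch below, `InstanceAttrJit` and `ClassAttrJit` are hypothetical stand-ins for the object returned by `jax.jit`, not the real implementation. The first stores `trace` in the instance `__dict__` (which `functools.wraps` copies), and the second defines it on the class (which `functools.wraps` does not copy).

```python
import functools


class InstanceAttrJit:
    """Hypothetical stand-in: `trace` is stored on the instance."""

    def __init__(self, fun):
        self._fun = fun
        # Lives in the instance __dict__, so functools.wraps will copy it.
        self.trace = lambda *args: ("traced", fun.__name__, args)

    def __call__(self, *args):
        return self._fun(*args)


class ClassAttrJit:
    """Hypothetical stand-in: `trace` is defined on the class."""

    def __init__(self, fun):
        self._fun = fun

    def __call__(self, *args):
        return self._fun(*args)

    def trace(self, *args):
        return ("traced", self._fun.__name__, args)


def f(x):
    return x + 1


def make_wrapper(wrapped):
    @functools.wraps(wrapped)
    def wrapper(*args):
        # Extra behavior that callers expect to always run.
        return wrapped(*args) * 10

    return wrapper


old_style = make_wrapper(InstanceAttrJit(f))
new_style = make_wrapper(ClassAttrJit(f))

print(old_style(1))                 # goes through the wrapper
print(old_style.trace(1))           # copied attribute: the wrapper is bypassed
print(hasattr(new_style, "trace"))  # class attribute was not copied
```

With the class-attribute variant, accessing `new_style.trace` raises `AttributeError` instead of silently skipping the wrapper, which matches the trade-off the commit message describes.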
George Necula
96d38a6b66 [cache_misses] Skip tracing-cache-miss explanations for JAX internal functions
About half of the tracing-cache-miss explanations in a large benchmark
end up being from JAX-internal functions, such as `jax.numpy` functions.
These cache misses are not what the JAX user wants to see, so we filter
them out, using the same mechanism used for filtering tracebacks.
2025-04-11 12:53:38 +03:00
Peter Hawkins
e02faabfb2 Replace references to jax.readthedocs.io with docs.jax.dev.
PiperOrigin-RevId: 745156931
2025-04-08 08:33:49 -07:00
jax authors
9e0368653c Merge pull request #27793 from dfm:lin-out-fwd
PiperOrigin-RevId: 744901130
2025-04-07 17:03:43 -07:00
Yash Katariya
0a72e856cf Add **experimental** with_dll_constraint API. This is for cases when the user wants to let SPMD decide the sharding.
But this is a contradiction since layouts apply to device local shape and without knowing the sharding, you can't decide the layout. But there are cases where you don't care what the sharding is, you just want to force a row-major layout (for example). **This API should only be used for those cases**.

PiperOrigin-RevId: 744888557
2025-04-07 16:21:58 -07:00
Sergei Lebedev
2944e3b2a6 Removed data_dependent_tracing_fallback config option
No internal code needs it any more.

PiperOrigin-RevId: 744870756
2025-04-07 15:27:57 -07:00
Dan Foreman-Mackey
dc00f9bdae Apply output forwarding in lin rule for pjit. 2025-04-07 15:39:33 -04:00
Dan Foreman-Mackey
dbc3bcd3ce Apply forwarding in pjit linearization rule to avoid intermediate copies. 2025-04-07 12:13:58 -04:00
George Necula
076d021057 [better_errors] Fix the handling of kwargs for debug_info.
kwargs are passed sorted by the actual kwarg keyword. This order
must be accounted for when we construct the `debug_info.arg_names`.

Extended the tests to be more precise about not mixing up kwargs,
e.g., use different shapes and look for the shape in the HLO.
2025-04-02 10:32:38 +01:00
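
The ordering issue fixed in commit 076d021057 can be sketched in plain Python. The helpers `flatten_call` and `flat_arg_names` are hypothetical and only mimic the convention stated in the message (keyword arguments arrive sorted by keyword); they are not JAX functions.

```python
def flatten_call(args, kwargs):
    """Flatten a call: positionals first, then kwarg values sorted by keyword."""
    return tuple(args) + tuple(kwargs[key] for key in sorted(kwargs))


def flat_arg_names(positional_names, kwargs):
    """Build names in the same order that flatten_call emits values."""
    return tuple(positional_names) + tuple(sorted(kwargs))


# The caller writes z before a, but the flattened order is a, then z.
call_kwargs = {"z": "value_for_z", "a": "value_for_a"}
values = flatten_call(("value_for_x",), call_kwargs)
names = flat_arg_names(("x",), call_kwargs)

# Pairing names and values positionally is only correct if both use
# the sorted-keyword order.
paired = dict(zip(names, values))
print(paired)
```

If the names were instead taken in the order the caller wrote the keywords, `z` would be paired with the value of `a`, which is the kind of mix-up the extended tests are meant to catch.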
Matthew Johnson
6fba4ecc58 PR #27576: [attrs] experimental appendattr
Imported from GitHub PR https://github.com/jax-ml/jax/pull/27576

This is an experimental extension to attrs. Attrs should be considered both experimental and deprecated.

This PR also includes some fixes for getattr/setattr.
Copybara import of the project:

--
3b1ea1a5f90b28744522670d0498ce5a6b194274 by Matthew Johnson <mattjj@google.com>:

[attrs] experimental appendattr

Merging this change closes #27576

COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/27576 from mattjj:appendattr b93795201b39b8f75890c9228368c994ae1e38e8
PiperOrigin-RevId: 741662724
2025-03-28 15:21:12 -07:00
Yash Katariya
e8038501d0 Fix a bug where jit was forwarding inputs to outputs even when donation was True for those inputs. This caused the output to be marked as deleted since the input was being forwarded to the output.
Since this functionality was added for a dynamic shapes experiment, only enable it when dynamic_shapes config is True.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 740942785
2025-03-26 16:31:11 -07:00
Yash Katariya
ec2f0f5913 [sharding_in_types] Enable auto_axes to work without any mesh context manager. We extract the mesh from out_shardings given. This allows APIs like random.uniform to accept NamedSharding in out_sharding argument and continue to work without a mesh context.
PiperOrigin-RevId: 740852542
2025-03-26 11:56:56 -07:00
Matthew Johnson
b4922df220 [attrs] allow setattr on a previously nonexistent attr
Before this change, we handled attrs for initial-style primitives like jit/scan
like this:
1. the traceable would form a jaxpr and see what attrs were touched (by
   jax_getattr or jax_setattr),
2. for each such attr, the traceable would do jax_getattr to get the current
   value, tree-flatten, pass the flat values into the (pure) bind, get the new
   values out, tree-unflatten, then jax_setattr the result.

That approach would error if the function called `jax_setattr` to set a
previously nonexistent attr. That is, this would work:

```python
from jax.experimental.attrs import jax_setattr
class Thing: ...
thing = Thing()
jax_setattr(thing, 'x', 1.0)
```
but it wouldn't work under a `jax.jit`.

This commit makes the same code work under a jit. We just
1. in partial_eval.py's `to_jaxpr`, ensure attrs added during jaxpr formation
   are deleted, using a special sentinel value `dne_sentinel` to indicate the
   attribute initially did not exist before tracing;
2. in pjit.py's `_get_states`, when reading initial attr values before the
   pjit_p bind, if the attribute does not exist we don't try to read it and
   instead just use `dne_sentinel` as the value, which is a convenient empty
   pytree;
3. in pjit.py's `_attr_token` for jit caching, when forming the cache key based
   on the current attr states, we map attrs that don't exist to `dne_sentinel`
   (rather than just erroring when the attr doesn't exist, as before).

In short, we use a special value to indicate "does not exist".

If `jax_getattr` supported the 'default' argument, the code would be a little
cleaner since we could avoid the `if hasattr` stuff. And that's probably a
useful feature to have anyway. We can add that in a follow-up.

This PR only makes setattr-to-nonexistent-attr work with jit. We'll add scan
etc in follow-ups.
2025-03-25 03:17:11 +00:00
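
The "does not exist" sentinel pattern from commit b4922df220 can be sketched without JAX. Only the name `dne_sentinel` comes from the commit message; the `_DoesNotExist` class and the helpers `read_state`, `restore_state`, and `run_and_rollback` are hypothetical, and the sketch shows the general idea rather than the code in partial_eval.py or pjit.py.

```python
class _DoesNotExist:
    """Hypothetical sentinel type meaning "the attribute did not exist"."""

    def __repr__(self):
        return "dne_sentinel"


dne_sentinel = _DoesNotExist()


def read_state(obj, name):
    # Use the sentinel instead of raising when the attribute is missing.
    return getattr(obj, name, dne_sentinel)


def restore_state(obj, name, value):
    # Restoring the sentinel means deleting whatever was added meanwhile.
    if value is dne_sentinel:
        if hasattr(obj, name):
            delattr(obj, name)
    else:
        setattr(obj, name, value)


def run_and_rollback(fn, obj, names):
    """Run fn(obj), record the resulting values, then undo its attr writes."""
    before = {name: read_state(obj, name) for name in names}
    fn(obj)
    after = {name: read_state(obj, name) for name in names}
    for name, value in before.items():
        restore_state(obj, name, value)
    return before, after


class Thing:
    pass


thing = Thing()
before, after = run_and_rollback(lambda t: setattr(t, "x", 1.0), thing, ["x"])
print(before, after, hasattr(thing, "x"))
```

Because the sentinel is an ordinary value, the initial state can be recorded, compared, and restored uniformly whether or not the attribute existed.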
Yash Katariya
4489303dfc Delete ParsedPartitionSpec and preprocess function and do a couple more cleanups
PiperOrigin-RevId: 738503430
2025-03-19 12:44:13 -07:00
Yash Katariya
dde861af5f Remove the jax Array migration guide from the TOC tree but keep the doc around
PiperOrigin-RevId: 738421256
2025-03-19 09:05:45 -07:00
Yash Katariya
133a885e3b use_mesh and use_concrete_mesh should error when used under jit
PiperOrigin-RevId: 738376533
2025-03-19 06:45:18 -07:00
Yash Katariya
88d4bc3d45 Rename AxisTypes enum to AxisType
PiperOrigin-RevId: 736935746
2025-03-14 11:48:21 -07:00
Yash Katariya
14b9f48535 Allow late binding out_shardings and in_shardings in auto_axes and explicit_axes API
PiperOrigin-RevId: 736535562
2025-03-13 09:37:24 -07:00
Yash Katariya
2d01226b3b Rename some internal APIs (set_abstract_mesh -> use_abstract_mesh and set_concrete_mesh -> use_concrete_mesh)
PiperOrigin-RevId: 736382641
2025-03-12 22:30:05 -07:00
Yash Katariya
c6dcbb6759 [sharding_in_types] Rework the axis_types argument in Mesh and AbstractMesh APIs. The changes are:
1. axis_types now takes an `AxisTypes | tuple[AxisTypes, ...] | None`. It doesn't take a dictionary anymore

2. `jax.make_mesh` also takes the same `axis_types` tuple as in point 1.

PiperOrigin-RevId: 736360041
2025-03-12 20:41:50 -07:00
Yash Katariya
47480b4493 Add a set_mesh API to jax.sharding. set_mesh sets the sharding and never unsets it i.e. this is just __enter__ of a ctx manager without __exit__
PiperOrigin-RevId: 736261724
2025-03-12 14:12:47 -07:00
Yash Katariya
8674495fd7 [sharding_in_types] Make reshard work with np.array.
PiperOrigin-RevId: 736250504
2025-03-12 13:41:42 -07:00
Yash Katariya
76dec38286 Under pjit the with mesh: context will use use_mesh(mesh): jit instead of tracking separately using resource_env.
This would also make it easier to deprecate the `with mesh: pjit` path in the future from user code since the new path would be completely tested.
This will also allow us to remove `resource_env` from JAX and the internal API access of `resource_env.physical_mesh` spread throughout codebases internally and externally.

PiperOrigin-RevId: 735602187
2025-03-10 20:21:02 -07:00
Dan Foreman-Mackey
36d515ed2c A few more fixes for debug_info tests with direct_linearize. 2025-03-08 07:47:24 -05:00
Matthew Johnson
7c2f842353 shard_map and other fixes to direct-linearize
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2025-03-07 21:02:40 +00:00
Yash Katariya
a67ab9fade Just use jit as the string in error messages instead of jit and pjit based on resource_env. This is to start deprecating the need for with mesh and replace it with use_mesh(mesh).
PiperOrigin-RevId: 733959962
2025-03-05 20:09:30 -08:00
Yash Katariya
ba5349f896 Add a note about uneven sharding and with_sharding_constraint. Fixes https://github.com/jax-ml/jax/issues/26946
PiperOrigin-RevId: 733953836
2025-03-05 19:35:03 -08:00
George Necula
a6c47d6f36 Use the same name for aliased Vars when pretty-printing Jaxprs.
Add a mechanism for using the same Var names for Vars that
are aliased. In this PR, we use this for `pjit`, such that the
following `print(jax.make_jaxpr(lambda a: jax.jit(lambda a: a + 1)(a))(0.))`
prints:

```
{ lambda ; a:f32[]. let
    b:f32[] = pjit[
          name=<lambda>
          jaxpr={ lambda ; a:f32[]. let b:f32[] = add a 1.0 in (b,) }
          ] a
    in (b,) }
```

instead of the previous:

```
{ lambda ; a:f32[]. let
    b:f32[] = pjit[
          name=<lambda>
          jaxpr={ lambda ; c:f32[]. let d:f32[] = add c 1.0 in (d,) }
          ] a
    in (b,) }
```

The same mechanism could be used for other higher-order primitives,
e.g., cond, and others.

Also add some typing declarations and rename APIs to use "shared jaxpr"
in lieu of "top-level jaxpr" for those Jaxprs that are used multiple
times and are printed first. I presume that the term "top-level jaxpr"
was picked because these are printed first at top-level. But this is
confusing, because they are really subjaxprs. In fact, there was already
a function `core.pp_toplevel_jaxpr` for printing the top-level Jaxpr,
and there was also `core.pp_top_level_jaxpr` (which now is named
`core.pp_shared_jaxpr`).
2025-03-03 11:38:51 +01:00
Yash Katariya
dda62f576f Make sure default layout is None for input and output layout in all codepaths
PiperOrigin-RevId: 731865511
2025-02-27 14:26:25 -08:00
Yash Katariya
d69da3b012 More cleanups around ParsedPartitionSpec. In a follow up CL, I can remove it from NamedSharding constructor. Deleting ParsedPartitionSpec is remaining but that's after 0.5.2 release.
PiperOrigin-RevId: 731785005
2025-02-27 10:51:04 -08:00
Yash Katariya
034a827a4d Remove _parsed_pspec from everywhere in JAX except for NamedSharding constructor. I'll do that in the next CL since that has a dependency on C++ so needs guards.
PiperOrigin-RevId: 731772222
2025-02-27 10:17:06 -08:00
Yash Katariya
177e1f6ed9 Canonicalize PartitionSpec so that we can delete ParsedPartitionSpec. We need to do this after sharding-in-types to speed up NamedSharding construction and remove a lot of tech debt and unnecessary complexity.
* `_partitions` is now canonicalized and only contains `tuples`, `singular strings`, `None` or `UNCONSTRAINED`. No more empty tuples (`P((), 'x')`) and singleton tuples.

* Cache the creating of sharding on ShapedArray since it's expensive to do it a lot of times

* Change the `__hash__` and `__eq__` of `NamedSharding` to depend on `self.spec` instead of `self._parsed_pspec`.

PiperOrigin-RevId: 731745062
2025-02-27 08:59:25 -08:00
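
The canonical form described in commit 177e1f6ed9 can be sketched in plain Python. The function `canonicalize_partitions` and the local `UNCONSTRAINED` object are hypothetical stand-ins, and the mapping of an empty tuple to `None` and of a singleton tuple to its single string is an assumption consistent with the message, not a copy of the JAX code.

```python
UNCONSTRAINED = object()  # hypothetical stand-in for the real marker


def canonicalize_partitions(partitions):
    """Normalize entries to None, a single string, a tuple, or UNCONSTRAINED."""
    canonical = []
    for entry in partitions:
        if entry is None or entry is UNCONSTRAINED or isinstance(entry, str):
            canonical.append(entry)
        elif isinstance(entry, (tuple, list)):
            if len(entry) == 0:
                canonical.append(None)          # () is treated as unsharded
            elif len(entry) == 1:
                canonical.append(entry[0])      # ('x',) becomes 'x'
            else:
                canonical.append(tuple(entry))  # multi-axis entries stay tuples
        else:
            raise TypeError(f"unsupported partition entry: {entry!r}")
    return tuple(canonical)


spec_a = canonicalize_partitions(((), "x"))
spec_b = canonicalize_partitions((None, ("x",)))
print(spec_a, spec_b, spec_a == spec_b)
```

Once every spec is in one canonical form, equality and hashing can be based on the spec itself, which is in the spirit of the `__hash__` and `__eq__` change listed in the commit.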
Yash Katariya
b707f0bdbb [sharding_in_types] Error out when using auto_axes or explicit_axes API when there is no context mesh.
Those APIs don't support that right now anyways and they raise an ugly KeyError. Instead we raise a better error here.

I have added a TODO to get the mesh from args so that computation follows data works but we can decide to do that in the future if a lot of users request that and don't want to use `use_mesh`.

PiperOrigin-RevId: 730687231
2025-02-24 19:19:49 -08:00
Yash Katariya
6f8bab3c92 Add sharding mismatch to explain_tracing_cache_miss
PiperOrigin-RevId: 730645598
2025-02-24 16:49:49 -08:00
George Necula
1be801bac8 [better_errors] Cleanup use of DebugInfo.arg_names and result_paths
Previously, we represented a missing arg name with `None`,
and a missing result path with the empty string. We now
adopt the same convention for arg names and use empty strings.
This simplifies the typing, and prevents the string "None" from
appearing in error messages.

I changed how we encode the result paths. Previously for a
function that returns a single array the path was the empty
string (the same as for an unknown path). And for a function
that returns a pair of arrays it was `([0], [1])`. Now we
add the "result" prefix: `("result",)` for a function returning a
single array and `(result[0], result[1])` for a function returning
a pair of arrays.

Finally, in debug_info_test, I removed the `check_tracer_arg_name`
so that all spied tracers are printed with the argument name they
depend on.
2025-02-23 08:27:56 +02:00
Yash Katariya
d695aa4c63 [sharding_in_types] Add sharding rules for the following primitives:
* `bitcast_convert_element_type`
  * `cumsum`
  * `cumlogsumexp`
  * `cumprod`
  * `cummax`
  * `cummin`
  * `reduce_window`
  * `reduce_window_sum`
  * `reduce_window_max`
  * `reduce_window_min`
  * `select_and_gather_add`

For `reduce_window_...` primitives only trivial windowing is supported along non-replicated dimensions. We can relax the other NotImplemented case in the future.

PiperOrigin-RevId: 729910108
2025-02-22 10:45:58 -08:00
Yash Katariya
7c4fe2a7cc [sharding_in_types] Allow auto_axes and explicit_axes to take numpy arrays, python scalars.
PiperOrigin-RevId: 729729215
2025-02-21 18:49:02 -08:00
Yash Katariya
629426f89c Allow casting to the same axis type
PiperOrigin-RevId: 729667271
2025-02-21 14:53:28 -08:00
Yash Katariya
401fa9019c Mark in_shardings and out_shardings as Any for typing reasons since they can take pytrees. Fixes https://github.com/jax-ml/jax/issues/26609
PiperOrigin-RevId: 728730349
2025-02-19 10:46:09 -08:00
Yash Katariya
66d04f85e6 Error out if going from Manual -> Auto/Explicit AxisTypes in the auto_axes and explicit_axes API that do mesh_cast implicitly.
Also, improve the error raised by canonicalize_sharding to include the api name and current source location.

PiperOrigin-RevId: 728701237
2025-02-19 09:21:53 -08:00
Yash Katariya
a3edfb43ef Now that sharding_in_types config flag is True, remove the config and all the conditionals
PiperOrigin-RevId: 728653433
2025-02-19 06:53:35 -08:00
Yash Katariya
1079dc4477 Let users pass in pspecs to with_sharding_constraint when use_mesh is set. This is in-line with other APIs which allow pspecs like einsum, reshape, etc
PiperOrigin-RevId: 728392216
2025-02-18 15:47:03 -08:00