9384 Commits

Author SHA1 Message Date
Tom Hennigan
7f43316e27 Add an option to simplify keystr output and use a custom separator.
Currently `keystr` just calls `str` on the key entries, leading to quite
verbose output. For example:

    >>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}}
    >>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
    ...   print(jax.tree_util.keystr(path))
    ['foo']['bar']['bat'][0]
    ['foo']['bar']['bat'][1]
    ['foo']['bar']['baz']

This change allows for a new "simple" format where the string representations
of key entries are further simplified. Additionally we allow a custom
separator since it is very common to use `/` (for example to separate module
and parameter names):

    >>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
    ...   print(jax.tree_util.keystr(path, simple=True, separator='/'))
    foo/bar/bat/0
    foo/bar/bat/1
    foo/bar/baz

PiperOrigin-RevId: 717971583
2025-01-21 10:18:42 -08:00
jax authors
96a3ed36c7 Add part (non-quantized K/V pages) of paged_attention_kernel tests back for TPU v6.
The paged_attention_kernel tests for TPU v6 were disabled in the past, but I discovered that all the failing tests have `are_kv_quantized=True`. So we can still test the non-quantized part on TPU v6.

PiperOrigin-RevId: 717969073
2025-01-21 10:12:52 -08:00
jax authors
70a5175d0a Merge pull request #25647 from Rifur13:bwd_pass
PiperOrigin-RevId: 717954065
2025-01-21 09:37:41 -08:00
Yash Katariya
d30476a0ae Make make_mesh take visible_axes, hidden_axes and collective_axes as parameters instead of axis_types to make it a cleaner API.
The mesh axis names provided to those parameters should be disjoint i.e. no overlap.

PiperOrigin-RevId: 717941577
2025-01-21 09:02:39 -08:00
Adam Paszke
3c8cf3c92e [Pallas] Improve testing for casts from narrow types + test int4
For sub-32bit types there are so few distinct values that we can just
exhaustively test them all.

PiperOrigin-RevId: 717879040
2025-01-21 05:56:10 -08:00
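The exhaustive strategy described in this commit can be illustrated with a small, pure-Python sketch. It does not use Pallas; `widen_under_test` is a hypothetical stand-in for the cast being tested, and `sign_extend` is an assumed reference implementation. The point is only that a 4-, 8-, or 16-bit type has few enough bit patterns to enumerate all of them.

```python
def sign_extend(bits, width):
    """Reference: read the low `width` bits of `bits` as two's complement."""
    sign_bit = 1 << (width - 1)
    return (bits & (sign_bit - 1)) - (bits & sign_bit)


def widen_under_test(bits, width):
    """Hypothetical stand-in for the narrow-to-wide cast being tested."""
    value = bits & ((1 << width) - 1)
    return value - (1 << width) if value >= (1 << (width - 1)) else value


def exhaustive_check(cast, reference, width):
    """Compare `cast` with `reference` on every `width`-bit pattern."""
    for bits in range(1 << width):
        assert cast(bits, width) == reference(bits, width), (width, bits)
    return 1 << width  # number of patterns checked


# int4 has 16 patterns, int8 has 256, int16 has 65536: all cheap to enumerate.
for width in (4, 8, 16):
    exhaustive_check(widen_under_test, sign_extend, width)
```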
George Necula
3f73f7b0eb [better_errors] Ensure debug_info.arg_names is never None.
Most places in the code assumed this already, but often
that usage is error reporting code, which is not yet well tested.

When we cannot get the `inspect.Signature` or when the
args and kwargs do not match the signature, we generate
the flattened argument names as: `args[0]`, `args[1]`,
`kwargs['foo']`, ... Previously, in these cases we
returned `None` for `arg_names`, and then the whole
debug_info ended up being `None`, throwing away even
available information.

We also add support for `api_util.fun_sourceinfo` even
for cases when the `fun.__code__` is not available. In
those cases we used to say that `fun_sourceinfo` is
`None`. Now, we use the string representation of `fun`
to get the name of built-in functions, or we use "<unknown>".
2025-01-21 13:38:10 +01:00
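The fallback naming described in this commit can be sketched in plain Python. This is not the JAX implementation: `fallback_arg_names` is a hypothetical helper, it does not flatten pytree arguments, and sorting the keyword names is an assumption made here for determinism.

```python
import inspect


def fallback_arg_names(fun, args, kwargs):
    """Return argument names, or positional placeholders if binding fails."""
    try:
        # Raises ValueError/TypeError when no signature is available, and
        # TypeError when args/kwargs do not match the signature.
        bound = inspect.signature(fun).bind(*args, **kwargs)
    except (ValueError, TypeError):
        names = [f"args[{i}]" for i in range(len(args))]
        names += [f"kwargs[{k!r}]" for k in sorted(kwargs)]
        return tuple(names)
    return tuple(bound.arguments)


def f(x, y):
    return x + y


print(fallback_arg_names(f, (1, 2), {}))
print(fallback_arg_names(f, (1, 2, 3), {"foo": 4}))
```

The second call does not match the signature of `f`, so the helper falls back to the `args[i]` / `kwargs['name']` placeholders instead of returning nothing.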
Yash Katariya
bba5ada525 Make sure reshard and mesh_cast behave properly under eager mode
PiperOrigin-RevId: 717714860
2025-01-20 20:33:09 -08:00
Yash Katariya
a943ebf449 [sharding_in_types] Move the calculation of new_mesh inside the decorator so that functions can be decorated with hidden_axes and visible_axes at the top level
PiperOrigin-RevId: 717701048
2025-01-20 19:41:22 -08:00
Gleb Pobudzey
b8b9f2bc33 Fix the backwards pass and support more block sizes.
2025-01-21 02:48:37 +00:00
Yash Katariya
d50d1e2c40 Don't allow users to query tracer.sharding even under sharding in types mode.
Instead, users should do `tracer.aval.sharding` so that code behaves the same under jit and eager mode.

PiperOrigin-RevId: 717638986
2025-01-20 15:12:47 -08:00
jax authors
e41f4caa3e Merge pull request #25988 from gnecula:debug_info_tests
PiperOrigin-RevId: 717592689
2025-01-20 11:57:35 -08:00
Yash Katariya
799eb98cac Add reshard API in experimental. Currently for sharding_in_types we have 2 APIs: mesh_cast and reshard. Both work in sharding_in_types mode and affect the sharding of the aval. Following are the semantics of both:
* `mesh_cast`: AxisTypes between src and dst mesh **must** differ. There should be **no "visible" data movement**. The shape of the aval doesn't change.

* `reshard`: Mesh should be the **same** between src and dst (same axis_names, axis_sizes and axis_types). **Data movement is allowed**. The shape of the aval doesn't change.

We might make `reshard` == `device_put`, hence the API is in experimental. This decision can be taken at a later point in time. The reason not to just give `device_put` this power is that `device_put` does a lot of stuff right now (and is going to get even more powers in the near future, like cross-host transfers) and its semantics would be very confusing if we kept piling sharding-in-types stuff on it.

PiperOrigin-RevId: 717588253
2025-01-20 11:39:25 -08:00
George Necula
e5d89e738a [better_errors] Refactor debug info tests
Created debug_info_test.py and moved some of the
tests involving debug_info there. In the future we will add
more tests for debugging info, and their helper functions, here.
2025-01-20 20:21:01 +01:00
George Necula
4fd0bb05b1 [better_errors] Finally remove api_util.debug_info.
Following https://github.com/jax-ml/jax/pull/25916 there were a few TODOs
left in the code to remove api_util.debug_info and replace the
one remaining use with api_util.tracing_debug_info.

PiperOrigin-RevId: 717583667
2025-01-20 11:19:53 -08:00
Adam Paszke
543dd94762 [Mosaic TPU] Add a faster implementation for packing b16 to s8 in TPUv6
PiperOrigin-RevId: 717583425
2025-01-20 11:18:22 -08:00
George Necula
dcf72b01f4 [better_errors] Improvements in propagation of debugging info
Added some documentation for `TracingDebugInfo` (docstring, comments
about `arg_names`, since it was not obvious to me that this would
flatten the non-static arguments).

Laying the ground for the unification of the old `api_util.debug_info`
and `partial_eval.tracing_debug_info`: we rename the former to
`api_util.tracing_debug_info`, we push inside the calls to
`fun_sourceinfo` and `fun_signature` (which were done by the callers
until now), and we rewrite the latter in terms
of the former. We leave for a future PR the actual replacing of the
latter with the former throughout.

In the process, cleaned up the one case where `partial_eval.tracing_debug_info`
received None for the `in_tree` and `out_tracer_thunk`. The function contained
catch-all exception clauses to handle those, but in doing so it masked other places
where we fail to collect debug info due to programming mistakes. E.g., in
one place we passed a `WrappedFun` instead of a `Callable`, resulting in missing debugging info.

Added more type declarations.

Added a `state_test` with a failure to track debugging information, manifested
with a leaked tracer without function provenance. Fixing this in a subsequent PR.
2025-01-20 15:09:51 +01:00
Yash Katariya
c7f8d17f5a Expose hidden_axes via jax namespace as public API. Also mention it as a workaround for primitives we don't support yet.
PiperOrigin-RevId: 716839003
2025-01-17 16:48:58 -08:00
Yash Katariya
12b59f8e53 Rename hidden_mode -> hidden_axes and hidden_mode_ctx -> use_hidden_axes. Same for visible mode and visible_mode_ctx.
Also make the `axes` parameter of the hidden_axes and visible_axes functions optional. If `axes` is omitted, you drop into full hidden/visible mode.

PiperOrigin-RevId: 716771872
2025-01-17 13:01:07 -08:00
Peter Hawkins
efab6945ca Remove code that supported jaxlib < 0.5.
The new xla_extension_version is 303 and the new mlir_api_version is 57.
2025-01-17 14:22:27 -05:00
Yash Katariya
695c02b1c4 [sharding_in_types] Rename sharding_cast to mesh_cast and add a few restrictions:
* mesh_cast only works when the axis types between src and dst mesh change. Hence the name!

* No explicit data movement is allowed. Specs containing axes that are visible cannot be different between src and dst shardings.

* src and dst mesh axis_names and axis_sizes should be the same.

TODO: Make `shardings` parameter to `mesh_cast` optional.
PiperOrigin-RevId: 716727084
2025-01-17 10:53:43 -08:00
jax authors
4d20052f7a Merge pull request #25642 from Rifur13:numerical_stability
PiperOrigin-RevId: 716714783
2025-01-17 10:19:36 -08:00
Jake VanderPlas
f83175fc94 [key reuse] fix signature for device_put
2025-01-17 09:47:50 -08:00
Yash Katariya
ce85b89884 [sharding_in_types] Error out for reshape for splits like this: (4, 6, 8) -> (4, 4, 2, 6)
PiperOrigin-RevId: 716653203
2025-01-17 06:58:29 -08:00
Yash Katariya
af667199db [sharding_in_types] Rename .at[...].get(out_spec) to .at[...].get(out_sharding).
PiperOrigin-RevId: 716466870
2025-01-16 18:56:52 -08:00
Yash Katariya
97cd748376 Rename out_type -> out_sharding parameter on einsum
PiperOrigin-RevId: 716454800
2025-01-16 18:16:52 -08:00
Yash Katariya
49224d6cdb Replace Auto/User/Collective AxisTypes names with Hidden/Visible/Collective.
Replace `with set_mesh(mesh):` with `with use_mesh(mesh):` context manager

Also expose `AxisTypes` and `use_mesh` into public API via `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.

PiperOrigin-RevId: 716446406
2025-01-16 17:55:54 -08:00
Gleb Pobudzey
2cdd9b7dd9 Fixing bwd attention test tolerance level
2025-01-17 01:41:51 +00:00
Parker Schuh
f2f552c108 Allow resharding between tokens on a single device
and multiple devices.

Whenever this happens we can essentially introduce an effects barrier
instead of doing the normal device -> host -> device transfer.

Fixes https://github.com/jax-ml/jax/issues/25671.

PiperOrigin-RevId: 716309978
2025-01-16 11:24:22 -08:00
Yash Katariya
b23c42372b [sharding_in_types] If an indexing operation hits into gather_p, error out saying to use .at[...].get(out_spec=...) instead.
This will basically drop the gather operation into full auto mode and add a sharding constraint on the output given by the user via `out_spec`.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 716295953
2025-01-16 10:51:15 -08:00
Adam Paszke
8954e71d73 [Mosaic TPU] Improve support for int16->int32 casts in TPUv4
PiperOrigin-RevId: 716250236
2025-01-16 08:44:10 -08:00
Benjamin Chetioui
d3bf243342 [Mosaic GPU] Add layout inference for splat arith.ConstantOps and vector.SplatOps.
PiperOrigin-RevId: 716224880
2025-01-16 07:18:35 -08:00
Dimitar (Mitko) Asenov
24884071b9 [MosaicGPU] Remove the single_thread context from top-level dialect code.
- Change the `async_load` lowering to manage the single thread context.
- Use a predicate for the top-level arrive_expect. If we want to hide this further, we can have a warp-group level op that lowers to a single-threaded context.

PiperOrigin-RevId: 716219730
2025-01-16 06:59:32 -08:00
Benjamin Chetioui
3366c92782 [Mosaic GPU][NFC] Simplify and clean up layout inference tests to use FuncOps.
PiperOrigin-RevId: 716216260
2025-01-16 06:48:57 -08:00
Yash Katariya
c6b5ac5c7b [sharding_in_types] Expand reshape's sharding rule to add support for the following cases:
* Split on 1 dimension only and the splitting dimension should be unsharded.

  `operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`

* Merging into 1 dimension only and all the merging dimensions should be unsharded.

  `operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`

* Split into singleton dimensions i.e. adding extra dims of size 1

  `operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`

* Merge singleton dimensions i.e. removing extra dims of size 1

  `operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`

* Identity reshape

  `operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`

These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide the out_sharding.

PiperOrigin-RevId: 716216240
2025-01-16 06:47:26 -08:00
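The unambiguous cases listed in this commit can be approximated with a small pure-Python sketch. This is not JAX's sharding rule: `infer_reshape_sharding` is a hypothetical helper, dimensions are modeled as `(size, axis)` pairs with `axis=None` meaning unsharded, and singleton dimensions are assumed to be unsharded.

```python
from math import prod


def infer_reshape_sharding(operand, new_shape):
    """operand: list of (size, axis) pairs; new_shape: list of ints.

    Returns (size, axis) pairs for the result, or raises ValueError when
    the reshape is not one of the unambiguous cases.
    """
    if prod(s for s, _ in operand) != prod(new_shape):
        raise ValueError("total number of elements must match")
    # Singleton dims carry no data layout, so compare shapes without them.
    src = [(s, ax) for s, ax in operand if s != 1]
    dst = [s for s in new_shape if s != 1]
    # Strip the matching prefix and suffix; those dims keep their sharding.
    i = 0
    while i < min(len(src), len(dst)) and src[i][0] == dst[i]:
        i += 1
    j = 0
    while j < min(len(src), len(dst)) - i and src[-1 - j][0] == dst[-1 - j]:
        j += 1
    src_mid = src[i:len(src) - j]
    dst_mid = dst[i:len(dst) - j]
    if not src_mid and not dst_mid:
        mid = []  # identity, or only singleton dims added/removed
    elif len(src_mid) == 1 and src_mid[0][1] is None:
        mid = [(s, None) for s in dst_mid]  # split one unsharded dim
    elif len(dst_mid) == 1 and all(ax is None for _, ax in src_mid):
        mid = [(dst_mid[0], None)]  # merge unsharded dims into one
    else:
        raise ValueError("ambiguous reshape; provide out_sharding")
    non_singleton = iter(src[:i] + mid + src[len(src) - j:])
    return [(1, None) if s == 1 else next(non_singleton) for s in new_shape]


# Split of the unsharded last dimension: (4@x, 6@y, 8) -> (4@x, 6@y, 2, 2, 2)
print(infer_reshape_sharding([(4, "x"), (6, "y"), (8, None)], [4, 6, 2, 2, 2]))
```

Under these assumptions, a reshape such as `(4, 6, 8) -> (4, 4, 2, 6)` is neither a single split nor a single merge, so the sketch raises and asks for an explicit output sharding.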
Dimitar (Mitko) Asenov
ce03cf976e [MosaicGPU] Move gpu_address_space_to_nvptx inside utils.py and use it.
PiperOrigin-RevId: 716214822
2025-01-16 06:41:51 -08:00
Adam Paszke
ef4dbd9cb9 [Mosaic TPU] Add support for packing to 16-bit integers on TPUv4
And refactor some test conditions to better match what we really support.
The tests were failing on older TPUs.

PiperOrigin-RevId: 716214098
2025-01-16 06:39:23 -08:00
Dimitar (Mitko) Asenov
22417ae28e [MosaicGPU] Extract code into a new method BarrierRef.from_dialect_barrier_memref and implement support for 1D barrier memrefs.
PiperOrigin-RevId: 716180182
2025-01-16 04:30:43 -08:00
Benjamin Chetioui
bc7204f003 [Mosaic GPU] Allow querying layouts from a FuncOp's block arguments if set.
The motivation behind this change is twofold:

1. it simplifies test writing (no need to produce arbitrary, manual, non-splat
   constants to produce arguments with a strided layout);
2. it'll allow running layout inference on different `FuncOp`s in isolation,
   before inlining.

While the primary motivation is to simplify test writing for upcoming changes,
`2.` is useful if we ever intend to call functions whose body's layout we have
inferred from other functions. It's not clear to me that we have a use case for
that, but the theoretical benefit is worth pointing out.

Crucially, layout inference does not set default layouts for `FuncOp`s, since
the caller may choose a different layout for its arguments. As a result, there
is also no layout inference rule for `func.FuncOp`.

PiperOrigin-RevId: 716158516
2025-01-16 03:05:41 -08:00
Zachary Garrett
f7d097f7cc Make utils for reporting function name work with functools.partial by using the inner .func attribute if the object doesn't have a __name__ attribute. functools.partial objects do not have __name__ attributes by default.
PiperOrigin-RevId: 715881812
2025-01-15 11:40:59 -08:00
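The lookup described in this commit can be sketched as follows. `fun_name` is a hypothetical helper, not the actual JAX utility, and the `"<unnamed function>"` fallback string is an assumption made for illustration.

```python
import functools


def fun_name(fun):
    """Best-effort name for a callable, looking through functools.partial."""
    name = getattr(fun, "__name__", None)
    if name is None and isinstance(fun, functools.partial):
        # partial objects have no __name__ by default; use the wrapped function.
        return fun_name(fun.func)
    return name if name is not None else "<unnamed function>"


def scale(x, factor):
    return x * factor


double = functools.partial(scale, factor=2)
print(fun_name(scale))
print(fun_name(double))
```

Checking `__name__` first means a partial that was given an explicit `__name__` keeps it, which matches the commit's "if the object doesn't have a `__name__` attribute" condition.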
jax authors
41993fdb24 Merge pull request #25755 from ROCm:ci_rnn_final-upstream
PiperOrigin-RevId: 715856939
2025-01-15 10:40:54 -08:00
Zac Mustin
2d72e8de84 Jax: Stop returning a list of cost-analyses.
As it stands, there is only ever one element in this list (see b/384741132) and only the 0th element is ever used so we can simplify.

This is a potentially breaking change for external users, but (as stated in the [documentation](https://jax.readthedocs.io/en/latest/aot.html#debug-information-and-analyses-when-available)) no guarantees are made on this type, which is intended for debugging purposes and not intended to be a reliable public API.

PiperOrigin-RevId: 715837855
2025-01-15 09:53:59 -08:00
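Code that has to work both before and after this change could normalize the result with a small helper. This is a sketch only: `first_cost_analysis` is a hypothetical name, it takes whatever a cost-analysis call returned rather than calling JAX itself, and the dictionary contents below are illustrative values, not real measurements.

```python
def first_cost_analysis(result):
    """Accept either the old list-of-analyses form or the new single value."""
    if isinstance(result, (list, tuple)):
        # Old form: a one-element list; only element 0 was ever used.
        return result[0] if result else None
    return result


old_style = [{"flops": 128.0}]  # illustrative values only
new_style = {"flops": 128.0}
print(first_cost_analysis(old_style) == first_cost_analysis(new_style))
```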
jax authors
70c1ee5d9c Merge pull request #25876 from gnecula:debug_info_3
PiperOrigin-RevId: 715831527
2025-01-15 09:35:03 -08:00
jax authors
2e5e4799fd Merge pull request #25880 from jakevdp:fix-gather
PiperOrigin-RevId: 715804120
2025-01-15 08:10:44 -08:00
Adam Paszke
aa19f9c4c4 [Pallas TPU] Temporarily strengthen restrictions on Pallas tests
Mosaic is now more aggressive in its inference of large 2nd minor layouts,
which causes slight problems for Pallas pipelines. This will be addressed
shortly.

PiperOrigin-RevId: 715714752
2025-01-15 02:32:14 -08:00
George Necula
f9dfe7f646 [better_errors] More cleanup
2025-01-15 10:22:29 +00:00
jax authors
c18492be65 [pallas][mosaic kernel export] Add initial support for exporting a dynamic shapes (placeholder bound) kernel out of mosaic, via pallas as both MLIR and jaxpr.
PiperOrigin-RevId: 715629439
2025-01-14 20:34:11 -08:00
Ruturaj4
fe68eb8b25 [ROCm] Implement RNN support
2025-01-14 19:04:49 -06:00
Justin Fu
cc9f6e7528 [Pallas] Fix GQA triton kernel test.
PiperOrigin-RevId: 715576240
2025-01-14 16:40:55 -08:00
Peter Hawkins
d1810b42cb Temporarily disable GQA attention tests on GPU, which were broken by a Triton integrate.
PiperOrigin-RevId: 715516188
2025-01-14 13:48:37 -08:00
Justin Fu
ff5cb811e6 [Mosaic GPU] Enable x64 tests for mosaic gpu.
PiperOrigin-RevId: 715496496
2025-01-14 13:02:48 -08:00