Currently `keystr` just calls `str` on the key entries, leading to quite
verbose output. For example:
>>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}}
... for path, _ in jax.tree_util.tree_leaves_with_path(params):
... print(jax.tree_util.keystr(path))
['foo']['bar']['bat'][0]
['foo']['bar']['bat'][1]
['foo']['bar']['baz']
This change allows for a new "simple" format where the string representation
of key entries are further simplified. Additionally we allow a custom
separator since it is very common to use `/` (for example to separate module
and parameter names):
... for path, _ in jax.tree_util.tree_leaves_with_path(params):
... print(jax.tree_util.keystr(path, simple=True, separator='/'))
foo/bar/bat/0
foo/bar/bat/1
foo/bar/baz
```
PiperOrigin-RevId: 717971583
The paged_attention_kernel tests for TPU v6 was disabled in the past but I discovered that all the failing tests have `are_kv_quantized=True`. So we can still test the non-quantized part on TPU v6.
PiperOrigin-RevId: 717969073
Most places in the code assumed this already, but often
that usage is error reporting code, which is not yet well tested.
When we cannot get the `inspect.Signature` or when the
args and kwargs do not match the signature, we generate
the flattened argument names as: `args[0]`, `args[1]`,
`kwargs['foo']`, ... Previously, in these cases we
returned `arg_names` is None, and then the whole
debug_info ended up being `None`, throwing away even
available information.
We also add support for `api_util.fun_sourceinfo` even
for cases when the `fun.__code__` is not available. In
those cases we used to say that `fun_sourceinfo` is
`None`. Now, we use the string representation of `fun`
to get the name of built-in functions, or we use "<unknown>".
* `mesh_cast`: AxisTypes between src and dst mesh **must** differ. There should be **no "visible" data movement**. The shape of the aval doesn't change.
* `reshard`: Mesh should be the **same** between src and dst (same axis_names, axis_sizes and axis_types). **Data movement is allowed**. The shape of the aval doesn't change.
We might make `reshard` == `device_put`, hence the API is in experimental. This decision can be taken at a later point in time. The reason not to just give `device_put` this power is because `device_put` does a lot of stuff right now (and is going to get even more powers in the near future like cross-host transfers) and it's semantics would be very confusing if we keep piling sharding-in-types stuff on it.
PiperOrigin-RevId: 717588253
Created debug_info_test.py and moved there some of the
tests involving debug_info. In the future we will put here
more tests for debugging info, and their helper functions.
Following https://github.com/jax-ml/jax/pull/25916 there were a few TODOs
left in the code to remove api_util.debug_info and replace the
one remaining use with api_util.tracing_debug_info.
PiperOrigin-RevId: 717583667
Added some documentation for `TracingDebugInfo` (docstring, comments
about `arg_names`, since it was not obvious to me that this would
flatten the non-static arguments).
Laying the ground for the unification of the old `api_util.debug_info`
and `partial_eval.tracing_debug_info`: we rename the former to
`api_util.tracing_debug_info`, we push inside the calls to
`fun_sourceinfo` and `fun_signature` (which were done by the callers
until now), and we rewrite the latter in terms
of the former. We leave for a future PR the actual replacing of the
latter with the former throughout.
In the process of above, cleaned up the one case when `partial_eval.tracing_debug_info`
received None for the `in_tree` and `out_tracer_thunk`. The function contained
catch-all exception clauses to handle those, but doing so it masked other places
where we fail to collect debug info due to programming mistakes. E.g., in
one place we passed a `WrappedFun` instead of a `Callable`, resulting in missing debugging info.
Added more type declarations.
Added a `state_test` with a failure to track debugging information, manifested
with a leaked tracer without function provenance. Fixing this in a subsequent PR.
Also make the `axes` parameter optional of hidden_axes and visible_axes functions. If axes is optional, you drop into full hidden/visible mode.
PiperOrigin-RevId: 716771872
* mesh_cast only works when the axis types between src and dst mesh changes. Hence the name!
* No explicit data movement is allowed. Specs containing axes that are visible cannot be different between src and dst shardings.
* src and dst mesh axis_names and axis_sizes should be the same.
TODO: Make `shardings` parameter to `mesh_cast` optional.
PiperOrigin-RevId: 716727084
Replace `with set_mesh(mesh):` with `with use_mesh(mesh):` context manager
Also expose `AxisTypes` and `use_mesh` into public API via `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.
PiperOrigin-RevId: 716446406
and multiple devices.
Whenever this happens we can essentially introduce an effects barrier
instead of doing the normal device -> host -> device transfer.
Fixes https://github.com/jax-ml/jax/issues/25671.
PiperOrigin-RevId: 716309978
This will basically drop the gather operation into full auto mode and add a sharding constraint on the output given by the user via `out_spec`.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 716295953
- Change the `async_load` lowering to manage the single thread context.
- Use a predicate for the top-level arrive_expect. If we want to hide this further, we can have a warp-group level op that lowers to a single-threaded context.
PiperOrigin-RevId: 716219730
* Split on 1 dimension only and the splitting dimension should be unsharded.
`operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`
* Merging into 1 dimension only and all the merging dimensions should be unsharded.
`operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`
* Split into singleton dimensions i.e. adding extra dims of size 1
`operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`
* Merge singleton dimensions i.e. removing extra dims of size 1
`operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`
* Identity reshape
`operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`
These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide the out_sharding.
PiperOrigin-RevId: 716216240
The motivation behind this change is twofold:
1. it simplifies test writing (no need to produce arbitrary, manual, non-splat
constants to produce arguments with a strided layout);
2. it'll allow running layout inference on different `FuncOp`s in isolation,
before inlining.
While the primary motivation is to simplify test writing for upcoming changes,
`2.` is useful if we ever intend to call functions whose body's layout we have
inferred from other functions. It's not clear to me that we have a use case for
that, but the theoretical benefit is worth pointing out.
Crucially, layout inference does not set default layouts for `FuncOp`s, since
the caller may choose a different layout for its arguments. As a result, there
is also no layout inference rule for `func.FuncOp`.
PiperOrigin-RevId: 716158516
As it stands, there is only ever one element in this list (see b/384741132) and only the 0th element is ever used so we can simplify.
This is a potentially breaking change for external users, but (as stated in the [documentation](https://jax.readthedocs.io/en/latest/aot.html#debug-information-and-analyses-when-available)) no guarantees are made on this type, which is intended for debugging purposes and not intended to be a reliable public API.
PiperOrigin-RevId: 715837855
Mosaic is not more aggressive in its inference of large 2nd minor layouts,
which causes slight problems for Pallas pipelines. This will be addressed
shortly.
PiperOrigin-RevId: 715714752