25204 Commits

Author SHA1 Message Date
Yash Katariya
23d360bded Remove axis_name from unmapped_aval
PiperOrigin-RevId: 718558713
2025-01-22 15:49:04 -08:00
jax authors
f6243ff8e1 Merge pull request #25889 from Stella-S-Yan:cache_reset
PiperOrigin-RevId: 718537398
2025-01-22 14:52:05 -08:00
jax authors
1b6080d943 Merge pull request #26044 from hawkinsp:slots
PiperOrigin-RevId: 718537271
2025-01-22 14:50:07 -08:00
Peter Hawkins
f4adcc650f Set __slots__ on core.Trace subclasses.
This is easy to do and makes field accesses on Trace classes slightly faster.
2025-01-22 16:17:54 -05:00
jax authors
fc9356085e Merge pull request #26024 from jakevdp:setops-args
PiperOrigin-RevId: 718467165
2025-01-22 11:37:51 -08:00
Justin Fu
10bb38bb79 [Mosaic GPU] Add manual consumed barrier handling to WS pipeline.
PiperOrigin-RevId: 718451678
2025-01-22 10:59:58 -08:00
Nitin Srinivasan
14029c6762 Replace "gpu" with "cuda" to be specific about the type of gpu tests that are running
Also, make `run_docker_container.sh` executable

PiperOrigin-RevId: 718420712
2025-01-22 09:40:11 -08:00
jax authors
6c76cc4e36 Integrate LLVM at llvm/llvm-project@d33e33fde7
Updates LLVM usage to match
[d33e33fde770](https://github.com/llvm/llvm-project/commit/d33e33fde770)

PiperOrigin-RevId: 718414171
2025-01-22 09:22:07 -08:00
jax authors
5c6b6c124f Update XLA dependency to use revision
f08027150f.

PiperOrigin-RevId: 718409072
2025-01-22 09:06:38 -08:00
jax authors
d6028153c4 Merge pull request #26028 from gnecula:debug_info_more_tests
PiperOrigin-RevId: 718398022
2025-01-22 08:33:25 -08:00
Will Froom
dc16721b52 [XLA:CPU] Use central difference to calculate numerical gradient
PiperOrigin-RevId: 718383754
2025-01-22 07:49:43 -08:00
George Necula
849ccc978b [better_errors] Expand the tests for debug_info
Debugging info is needed for error messages, and for
lowering. For the former, we need debug info inside
tracers. For the latter, inside Jaxprs. We add a
new set of tests that intentionally leak tracers while
tracing and then we check that the tracers have the
expected debug info. We also form Jaxprs and we
check that they have the expected debug info.
We uncovered a few missing debug infos, those are
marked with TODO.
2025-01-22 16:49:16 +01:00
jax authors
e304e9ea16 Merge pull request #25992 from gnecula:debug_info_arg_names
PiperOrigin-RevId: 718216003
2025-01-21 22:17:08 -08:00
Stella S Yan
f87c94db75 Fix cache init when JAX Array is created early (#25768) 2025-01-22 03:44:29 +00:00
Jevin Jiang
908df65a26 [Mosaic TPU] Emulate converting x16 vector to mask if mask packing is supported.
PiperOrigin-RevId: 718133639
2025-01-21 17:23:33 -08:00
jax authors
54bb7f5ddb Remove meaningless template keywords.
This will fix -Wmissing-template-arg-list-after-template-kw warnings.
This warning is error-by-default in Clang.

PiperOrigin-RevId: 718133601
2025-01-21 17:22:04 -08:00
Yash Katariya
3aa55992fe Remove device_context from trace_context because we don't need it there. We can get compilation cache misses (and tracing/lowering cache hit) naturally without putting concrete devices into trace_context.
PiperOrigin-RevId: 718113413
2025-01-21 16:21:36 -08:00
Jake VanderPlas
a69f9dcc19 jax.numpy setops: use ensure_arraylike & avoid asarray 2025-01-21 16:05:49 -08:00
Yash Katariya
051861bbf1 Error out if contracting dimensions are sharded and ask the user to provide the output sharding
PiperOrigin-RevId: 718087084
2025-01-21 15:06:54 -08:00
Dan Foreman-Mackey
aa8c0010e2 Add support for constants in the decomposition of lax.composites.
This change adds support for including "consts" (i.e. closed-over arrays) in the body of the decomposition definition for a `lax.composite` op. The caveat here is that, since the signature of the decomposition must match the composite itself, the values of any consts must be known when lowering so that they can be inlined into the decomposition's HLO. Therefore, there is no support for closing over tracers. Since `lax.composite` doesn't support most transformations anyways, this typically isn't going to be a major limitation except with `jax.jit` (as demonstrated in the tests).

PiperOrigin-RevId: 718048021
2025-01-21 13:28:42 -08:00
Peter Hawkins
afb750cf76 Remove unnecessary use of xla_client.OpMetadata class.
We create this object and immediately turn it into a different object. We can cut out a step here!

PiperOrigin-RevId: 718023353
2025-01-21 12:24:51 -08:00
Tzu-Wei Sung
79bd72e2e8 [Mosaic] Remove hardcoded TARGET_SHAPE and align Python/C++ APIs.
PiperOrigin-RevId: 717973752
2025-01-21 10:24:10 -08:00
Tom Hennigan
7f43316e27 Add an option to simplify keystr output and use a custom separator.
Currently `keystr` just calls `str` on the key entries, leading to quite
verbose output. For example:

    >>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}}
    ... for path, _ in jax.tree_util.tree_leaves_with_path(params):
    ...   print(jax.tree_util.keystr(path))
    ['foo']['bar']['bat'][0]
    ['foo']['bar']['bat'][1]
    ['foo']['bar']['baz']

This change allows for a new "simple" format where the string representation
of key entries are further simplified. Additionally we allow a custom
separator since it is very common to use `/` (for example to separate module
and parameter names):

    ... for path, _ in jax.tree_util.tree_leaves_with_path(params):
    ...   print(jax.tree_util.keystr(path, simple=True, separator='/'))
    foo/bar/bat/0
    foo/bar/bat/1
    foo/bar/baz
```

PiperOrigin-RevId: 717971583
2025-01-21 10:18:42 -08:00
jax authors
96a3ed36c7 Add part (non-quantized K/V pages) of paged_attention_kernel tests back for TPU v6.
The paged_attention_kernel tests for TPU v6 was disabled in the past but I discovered that all the failing tests have `are_kv_quantized=True`. So we can still test the non-quantized part on TPU v6.

PiperOrigin-RevId: 717969073
2025-01-21 10:12:52 -08:00
jax authors
70a5175d0a Merge pull request #25647 from Rifur13:bwd_pass
PiperOrigin-RevId: 717954065
2025-01-21 09:37:41 -08:00
Sergei Lebedev
4c363766f8 [mosaic_gpu] Removed debug prints in jax._src.lib.mosaic_gpu
PiperOrigin-RevId: 717952850
2025-01-21 09:36:12 -08:00
jax authors
3b5b98163b Merge pull request #25982 from roth-jakob:warning_callbacks
PiperOrigin-RevId: 717951477
2025-01-21 09:29:52 -08:00
jax authors
aaa7e922ed Update XLA dependency to use revision
5388e86b8b.

PiperOrigin-RevId: 717944684
2025-01-21 09:10:48 -08:00
Yash Katariya
d30476a0ae Make make_mesh take visible_axes, hidden_axes and collective_axes as parameters instead of axis_types to make it a more cleaner API.
The mesh axis names provided to those parameters should be disjoint i.e. no overlap.

PiperOrigin-RevId: 717941577
2025-01-21 09:02:39 -08:00
Christos Perivolaropoulos
c4643c6156 [mosaic_gpu] a function for bitwidth as well as bytewidht.
This is to enable s4 and friends.

PiperOrigin-RevId: 717896640
2025-01-21 06:52:04 -08:00
Adam Paszke
3c8cf3c92e [Pallas] Improve testing for casts from narrow types + test int4
For sub-32bit types there are so few distinct values that we can just
exhaustively test them all.

PiperOrigin-RevId: 717879040
2025-01-21 05:56:10 -08:00
George Necula
3f73f7b0eb [better_errors] Ensure debug_info.arg_names is never None.
Most places in the code assumed this already, but often
that usage is error reporting code, which is not yet well tested.

When we cannot get the `inspect.Signature` or when the
args and kwargs do not match the signature, we generate
the flattened argument names as: `args[0]`, `args[1]`,
`kwargs['foo']`, ... Previously, in these cases we
returned `arg_names` is None, and then the whole
debug_info ended up being `None`, throwing away even
available information.

We also add support for `api_util.fun_sourceinfo` even
for cases when the `fun.__code__` is not available. In
those cases we used to say that `fun_sourceinfo` is
`None`. Now, we use the string representation of `fun`
to get the name of built-in functions, or we use "<unknown>".
2025-01-21 13:38:10 +01:00
Dimitar (Mitko) Asenov
f89accc56a [Mosaic GPU] Add support for converting all fragmented layouts to ir and back.
This will be used in the layout inference and lowering of the dialect WGMMA op

PiperOrigin-RevId: 717836648
2025-01-21 03:27:03 -08:00
Dimitar (Mitko) Asenov
0a89760c24 [Mosaic GPU] Do not use mgpu in wgmma.py
This enables the dialect lowering to depend on `wgmma.py` without creating a circular dependency. I need this in a follow up CL that implements the lowering of the WGMMA dialect op.

PiperOrigin-RevId: 717791901
2025-01-21 01:06:42 -08:00
Yash Katariya
bba5ada525 Make sure reshard and mesh_cast behave properly under eager mode
PiperOrigin-RevId: 717714860
2025-01-20 20:33:09 -08:00
Yash Katariya
a943ebf449 [sharding_in_types] Move the calculation of new_mesh inside the decorator so that functions can be decorated with hidden_axes and visible_axes at the top level
PiperOrigin-RevId: 717701048
2025-01-20 19:41:22 -08:00
Gleb Pobudzey
b8b9f2bc33 Fix the backwards pass and support more block sizes. 2025-01-21 02:48:37 +00:00
Yash Katariya
d50d1e2c40 Don't allow users to query tracer.sharding even under sharding in types mode.
Instead, users should do `tracer.aval.sharding` so that code behaves the same under jit and eager mode.

PiperOrigin-RevId: 717638986
2025-01-20 15:12:47 -08:00
jax authors
7f19b345fb Merge pull request #25984 from Rifur13:use_exp2
PiperOrigin-RevId: 717627777
2025-01-20 14:20:39 -08:00
jax authors
e41f4caa3e Merge pull request #25988 from gnecula:debug_info_tests
PiperOrigin-RevId: 717592689
2025-01-20 11:57:35 -08:00
Yash Katariya
799eb98cac Add reshard API in experimental. Currently for sharding_in_types we have 2 APIs: mesh_cast and reshard. Both work in sharding_in_types mode and affect the sharding of the aval. Following are the semantics of both:
* `mesh_cast`: AxisTypes between src and dst mesh **must** differ. There should be **no "visible" data movement**. The shape of the aval doesn't change.

* `reshard`: Mesh should be the **same** between src and dst (same axis_names, axis_sizes and axis_types). **Data movement is allowed**. The shape of the aval doesn't change.

We might make `reshard` == `device_put`, hence the API is in experimental. This decision can be taken at a later point in time. The reason not to just give `device_put` this power is because `device_put` does a lot of stuff right now (and is going to get even more powers in the near future like cross-host transfers) and it's semantics would be very confusing if we keep piling sharding-in-types stuff on it.

PiperOrigin-RevId: 717588253
2025-01-20 11:39:25 -08:00
George Necula
e5d89e738a [better_errors] Refactor debug info tests
Created debug_info_test.py and moved there some of the
tests involving debug_info. In the future we will put here
more tests for debugging info, and their helper functions.
2025-01-20 20:21:01 +01:00
George Necula
4fd0bb05b1 [better_errors] Finally remove api_util.debug_info.
Following https://github.com/jax-ml/jax/pull/25916 there were a few TODOs
left in the code to remove api_util.debug_info and replace the
one remaining use with api_util.tracing_debug_info.

PiperOrigin-RevId: 717583667
2025-01-20 11:19:53 -08:00
Adam Paszke
543dd94762 [Mosaic TPU] Add a faster implementation for packing b16 to s8 in TPUv6
PiperOrigin-RevId: 717583425
2025-01-20 11:18:22 -08:00
jax authors
a43edb4644 Update XLA dependency to use revision
e3ee51f579.

PiperOrigin-RevId: 717538958
2025-01-20 08:15:10 -08:00
jax authors
ce48f647e7 Merge pull request #25916 from gnecula:debug_info_4
PiperOrigin-RevId: 717519303
2025-01-20 06:49:19 -08:00
George Necula
dcf72b01f4 [better_errors] Improvements in propagation of debugging info
Added some documentation for `TracingDebugInfo` (docstring, comments
about `arg_names`, since it was not obvious to me that this would
flatten the non-static arguments).

Laying the ground for the unification of the old `api_util.debug_info`
and `partial_eval.tracing_debug_info`: we rename the former to
`api_util.tracing_debug_info`, we push inside the calls to
`fun_sourceinfo` and `fun_signature` (which were done by the callers
until now), and we rewrite the latter in terms
of the former. We leave for a future PR the actual replacing of the
latter with the former throughout.

In the process of above, cleaned up the one case when `partial_eval.tracing_debug_info`
received None for the `in_tree` and `out_tracer_thunk`. The function contained
catch-all exception clauses to handle those, but doing so it masked other places
where we fail to collect debug info due to programming mistakes. E.g., in
one place we passed a `WrappedFun` instead of a `Callable`, resulting in missing debugging info.

Added more type declarations.

Added a `state_test` with a failure to track debugging information, manifested
with a leaked tracer without function provenance. Fixing this in a subsequent PR.
2025-01-20 15:09:51 +01:00
Gleb Pobudzey
e7db4d5055
Merge branch 'jax-ml:main' into use_exp2 2025-01-19 23:10:03 -05:00
jax authors
d415c80b86 Update XLA dependency to use revision
665f79fbda.

PiperOrigin-RevId: 717256185
2025-01-19 08:59:38 -08:00
Jakob Roth
4f8699c8a1 Update docs of callbacks
Callback functions should not call into JAX. This information was
missing in the docs of the callbacks. This commit adds this information
to the docs.

See: #25861, #24255
2025-01-19 11:33:20 +01:00