25219 Commits

Author SHA1 Message Date
Jake VanderPlas
23c1d62910 internal: move more NumPy APIs to ensure_arraylike 2025-01-23 08:48:13 -08:00
Yash Katariya
6b95ad0a53 Don't calculate tracing debug_info twice
PiperOrigin-RevId: 718860522
2025-01-23 07:37:14 -08:00
Yash Katariya
33aa088a5c Remove axis_name from pmap_unmapped_aval_handlers
PiperOrigin-RevId: 718859837
2025-01-23 07:34:38 -08:00
Nitin Srinivasan
e8d40ff1a7 Fix typo and improve readability of workflow documentation
PiperOrigin-RevId: 718838936
2025-01-23 06:24:55 -08:00
Dimitar (Mitko) Asenov
6f609926a6 [Mosaic GPU] Remove an unnecessary restriction in the vector.store lowering
This was made obsolete by:
f89accc56a

PiperOrigin-RevId: 718808561
2025-01-23 04:24:14 -08:00
Dimitar (Mitko) Asenov
f57d603c45 [Mosaic GPU] Simplify enums in the MLIR Mosaic GPU dialect.
This enables us to use them more simply in the current and upcoming Python code. The Python bindings for enum and enum attributes leave much to be desired.

PiperOrigin-RevId: 718795667
2025-01-23 03:38:26 -08:00
Dimitar (Mitko) Asenov
6b747b4109 [Mosaic GPU] Add a result to the WGMMA op definition in the MLIR dialect
PiperOrigin-RevId: 718788390
2025-01-23 03:10:07 -08:00
Dimitar (Mitko) Asenov
3a411d883a [Mosaic GPU] Implement basic WGMMAFragLayout inference and propagation
PiperOrigin-RevId: 718781860
2025-01-23 02:48:04 -08:00
Parker Schuh
f3e27b6c28 Support axis_index using a nested shard_map instead of iota with full to shard.
PiperOrigin-RevId: 718661661
2025-01-22 19:14:37 -08:00
Nitin Srinivasan
9aad6a6827 Add job that runs Bazel single accelerator and multi-accelerator CUDA tests (non-RBE)
PiperOrigin-RevId: 718637923
2025-01-22 17:51:45 -08:00
Yash Katariya
704b2e5fba [sharding_in_types] Make vmap work with shard_map + pallas
PiperOrigin-RevId: 718578207
2025-01-22 16:48:32 -08:00
Peter Hawkins
cd51e9dd14 Speed up name stack printing.
If we repeatedly form tuples by concatenation during printing, we make what should be a linear time operation quadratic.

Also simplify the API contract of extend() to only add a single element, and remove the unused method wrap_name.

PiperOrigin-RevId: 718570432
2025-01-22 16:24:16 -08:00
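The cost described in the commit above can be illustrated with a minimal sketch. This is not the JAX name-stack code; `render_quadratic` and `render_linear` are hypothetical names chosen for illustration, and the only claim is the general one about tuple concatenation in Python.

```python
def render_quadratic(names):
    # Each `+` copies the whole accumulated tuple, so n steps cost O(n^2) overall.
    parts = ()
    for name in names:
        parts = parts + (name,)
    return "/".join(parts)


def render_linear(names):
    # list.append is amortized O(1), so the loop as a whole is O(n).
    parts = []
    for name in names:
        parts.append(name)
    return "/".join(parts)


names = ["jit(f)", "vmap(g)", "scan"]
# Both strategies produce the same string; only the amount of copying differs.
assert render_quadratic(names) == render_linear(names)
```

The difference is invisible for short stacks and only matters when many names are accumulated one element at a time.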
jax authors
0fccabcd49 Merge pull request #26050 from jakevdp:effver-adoption
PiperOrigin-RevId: 718562550
2025-01-22 16:01:35 -08:00
Sharad Vikram
64e9b07ee3 Update debugging Pallas g3doc to remove text about scalar printing restriction
PiperOrigin-RevId: 718560406
2025-01-22 15:55:05 -08:00
Yash Katariya
23d360bded Remove axis_name from unmapped_aval
PiperOrigin-RevId: 718558713
2025-01-22 15:49:04 -08:00
Jake VanderPlas
423be16ecc DOC: mention adoption of EffVer in JEP 2025-01-22 15:44:11 -08:00
jax authors
f6243ff8e1 Merge pull request #25889 from Stella-S-Yan:cache_reset
PiperOrigin-RevId: 718537398
2025-01-22 14:52:05 -08:00
jax authors
1b6080d943 Merge pull request #26044 from hawkinsp:slots
PiperOrigin-RevId: 718537271
2025-01-22 14:50:07 -08:00
Peter Hawkins
f4adcc650f Set __slots__ on core.Trace subclasses.
This is easy to do and makes field accesses on Trace classes slightly faster.
2025-01-22 16:17:54 -05:00
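As a rough illustration of why `__slots__` has to be set on the subclasses themselves, here is a small sketch using plain Python classes. The class names are hypothetical stand-ins, not the actual `core.Trace` hierarchy.

```python
class BaseTrace:
    # Declaring __slots__ removes the per-instance __dict__ for this class.
    __slots__ = ("tag",)

    def __init__(self, tag):
        self.tag = tag


class SubclassWithoutSlots(BaseTrace):
    # No __slots__ here, so instances get a __dict__ again.
    pass


class SubclassWithSlots(BaseTrace):
    # Only the newly added fields are listed; "tag" is inherited from the base.
    __slots__ = ("level",)

    def __init__(self, tag, level):
        super().__init__(tag)
        self.level = level


without_slots = SubclassWithoutSlots("a")
with_slots = SubclassWithSlots("a", 0)

has_dict_without = hasattr(without_slots, "__dict__")  # True
has_dict_with = hasattr(with_slots, "__dict__")        # False
```

A subclass that omits `__slots__` silently regains a `__dict__`, which is why each subclass needs its own declaration to keep the benefit.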
jax authors
fc9356085e Merge pull request #26024 from jakevdp:setops-args
PiperOrigin-RevId: 718467165
2025-01-22 11:37:51 -08:00
Justin Fu
10bb38bb79 [Mosaic GPU] Add manual consumed barrier handling to WS pipeline.
PiperOrigin-RevId: 718451678
2025-01-22 10:59:58 -08:00
Nitin Srinivasan
14029c6762 Replace "gpu" with "cuda" to be specific about the type of gpu tests that are running
Also, make `run_docker_container.sh` executable

PiperOrigin-RevId: 718420712
2025-01-22 09:40:11 -08:00
jax authors
6c76cc4e36 Integrate LLVM at llvm/llvm-project@d33e33fde7
Updates LLVM usage to match
[d33e33fde770](https://github.com/llvm/llvm-project/commit/d33e33fde770)

PiperOrigin-RevId: 718414171
2025-01-22 09:22:07 -08:00
jax authors
5c6b6c124f Update XLA dependency to use revision
f08027150f.

PiperOrigin-RevId: 718409072
2025-01-22 09:06:38 -08:00
jax authors
d6028153c4 Merge pull request #26028 from gnecula:debug_info_more_tests
PiperOrigin-RevId: 718398022
2025-01-22 08:33:25 -08:00
Will Froom
dc16721b52 [XLA:CPU] Use central difference to calculate numerical gradient
PiperOrigin-RevId: 718383754
2025-01-22 07:49:43 -08:00
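For readers unfamiliar with the term, a minimal sketch of a central-difference gradient next to a forward-difference one follows. It is a generic illustration in Python, not the XLA:CPU code touched by the commit above; the function names and the step size `h` are assumptions.

```python
import math


def forward_difference(f, x, h=1e-4):
    # One-sided estimate; truncation error shrinks like O(h).
    return (f(x + h) - f(x)) / h


def central_difference(f, x, h=1e-4):
    # Symmetric estimate; truncation error shrinks like O(h^2).
    return (f(x + h) - f(x - h)) / (2 * h)


x = 1.0
exact = math.cos(x)  # d/dx sin(x) = cos(x)
error_forward = abs(forward_difference(math.sin, x) - exact)
error_central = abs(central_difference(math.sin, x) - exact)
```

For a smooth function and the same step size, the symmetric estimate is typically several orders of magnitude closer to the analytic derivative, at the cost of the same number of function evaluations.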
George Necula
849ccc978b [better_errors] Expand the tests for debug_info
Debugging info is needed for error messages, and for
lowering. For the former, we need debug info inside
tracers. For the latter, inside Jaxprs. We add a
new set of tests that intentionally leak tracers while
tracing and then we check that the tracers have the
expected debug info. We also form Jaxprs and we
check that they have the expected debug info.
We uncovered a few missing debug infos; those are
marked with TODO.
2025-01-22 16:49:16 +01:00
jax authors
e304e9ea16 Merge pull request #25992 from gnecula:debug_info_arg_names
PiperOrigin-RevId: 718216003
2025-01-21 22:17:08 -08:00
Stella S Yan
f87c94db75 Fix cache init when JAX Array is created early (#25768) 2025-01-22 03:44:29 +00:00
Jevin Jiang
908df65a26 [Mosaic TPU] Emulate converting x16 vector to mask if mask packing is supported.
PiperOrigin-RevId: 718133639
2025-01-21 17:23:33 -08:00
jax authors
54bb7f5ddb Remove meaningless template keywords.
This will fix -Wmissing-template-arg-list-after-template-kw warnings.
This warning is error-by-default in Clang.

PiperOrigin-RevId: 718133601
2025-01-21 17:22:04 -08:00
Yash Katariya
3aa55992fe Remove device_context from trace_context because we don't need it there. We can get compilation cache misses (and tracing/lowering cache hit) naturally without putting concrete devices into trace_context.
PiperOrigin-RevId: 718113413
2025-01-21 16:21:36 -08:00
Jake VanderPlas
a69f9dcc19 jax.numpy setops: use ensure_arraylike & avoid asarray 2025-01-21 16:05:49 -08:00
Yash Katariya
051861bbf1 Error out if contracting dimensions are sharded and ask the user to provide the output sharding
PiperOrigin-RevId: 718087084
2025-01-21 15:06:54 -08:00
Dan Foreman-Mackey
aa8c0010e2 Add support for constants in the decomposition of lax.composites.
This change adds support for including "consts" (i.e. closed-over arrays) in the body of the decomposition definition for a `lax.composite` op. The caveat here is that, since the signature of the decomposition must match the composite itself, the values of any consts must be known when lowering so that they can be inlined into the decomposition's HLO. Therefore, there is no support for closing over tracers. Since `lax.composite` doesn't support most transformations anyways, this typically isn't going to be a major limitation except with `jax.jit` (as demonstrated in the tests).

PiperOrigin-RevId: 718048021
2025-01-21 13:28:42 -08:00
Peter Hawkins
afb750cf76 Remove unnecessary use of xla_client.OpMetadata class.
We create this object and immediately turn it into a different object. We can cut out a step here!

PiperOrigin-RevId: 718023353
2025-01-21 12:24:51 -08:00
Tzu-Wei Sung
79bd72e2e8 [Mosaic] Remove hardcoded TARGET_SHAPE and align Python/C++ APIs.
PiperOrigin-RevId: 717973752
2025-01-21 10:24:10 -08:00
Tom Hennigan
7f43316e27 Add an option to simplify keystr output and use a custom separator.
Currently `keystr` just calls `str` on the key entries, leading to quite
verbose output. For example:

    >>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}}
    >>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
    ...   print(jax.tree_util.keystr(path))
    ['foo']['bar']['bat'][0]
    ['foo']['bar']['bat'][1]
    ['foo']['bar']['baz']

This change allows for a new "simple" format where the string representation
of key entries are further simplified. Additionally we allow a custom
separator since it is very common to use `/` (for example to separate module
and parameter names):

    >>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
    ...   print(jax.tree_util.keystr(path, simple=True, separator='/'))
    foo/bar/bat/0
    foo/bar/bat/1
    foo/bar/baz

PiperOrigin-RevId: 717971583
2025-01-21 10:18:42 -08:00
jax authors
96a3ed36c7 Add part (non-quantized K/V pages) of paged_attention_kernel tests back for TPU v6.
The paged_attention_kernel tests for TPU v6 were disabled in the past, but I discovered that all the failing tests have `are_kv_quantized=True`. So we can still test the non-quantized part on TPU v6.

PiperOrigin-RevId: 717969073
2025-01-21 10:12:52 -08:00
jax authors
70a5175d0a Merge pull request #25647 from Rifur13:bwd_pass
PiperOrigin-RevId: 717954065
2025-01-21 09:37:41 -08:00
Sergei Lebedev
4c363766f8 [mosaic_gpu] Removed debug prints in jax._src.lib.mosaic_gpu
PiperOrigin-RevId: 717952850
2025-01-21 09:36:12 -08:00
jax authors
3b5b98163b Merge pull request #25982 from roth-jakob:warning_callbacks
PiperOrigin-RevId: 717951477
2025-01-21 09:29:52 -08:00
jax authors
aaa7e922ed Update XLA dependency to use revision
5388e86b8b.

PiperOrigin-RevId: 717944684
2025-01-21 09:10:48 -08:00
Yash Katariya
d30476a0ae Make make_mesh take visible_axes, hidden_axes and collective_axes as parameters instead of axis_types to make it a cleaner API.
The mesh axis names provided to those parameters should be disjoint i.e. no overlap.

PiperOrigin-RevId: 717941577
2025-01-21 09:02:39 -08:00
Christos Perivolaropoulos
c4643c6156 [mosaic_gpu] Add a function for bitwidth as well as bytewidth.
This is to enable s4 and friends.

PiperOrigin-RevId: 717896640
2025-01-21 06:52:04 -08:00
Adam Paszke
3c8cf3c92e [Pallas] Improve testing for casts from narrow types + test int4
For sub-32bit types there are so few distinct values that we can just
exhaustively test them all.

PiperOrigin-RevId: 717879040
2025-01-21 05:56:10 -08:00
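The exhaustive-testing idea in the commit above can be sketched in a few lines of plain Python. This is an illustration only, not the Pallas test code: `sign_extend` is a hypothetical cast under test, and `reference` is a deliberately simple oracle to compare against.

```python
def sign_extend(bits, width):
    """Cast under test: read the low `width` bits as a two's-complement integer."""
    sign_bit = 1 << (width - 1)
    bits &= (1 << width) - 1
    return (bits ^ sign_bit) - sign_bit


def reference(bits, width):
    """Simple oracle: values at or above 2**(width-1) wrap to negatives."""
    return bits - (1 << width) if bits >= (1 << (width - 1)) else bits


def check_exhaustively(width):
    # A narrow type has only 2**width bit patterns, so every one can be checked.
    for bits in range(1 << width):
        assert sign_extend(bits, width) == reference(bits, width), (bits, width)
    return 1 << width


# int4 has 16 patterns, int8 has 256, int16 has 65536.
checked = {width: check_exhaustively(width) for width in (4, 8, 16)}
```

Because the input space is so small, there is no need for random sampling or hand-picked edge cases; every boundary value is covered by construction.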
George Necula
3f73f7b0eb [better_errors] Ensure debug_info.arg_names is never None.
Most places in the code assumed this already, but often
that usage is error reporting code, which is not yet well tested.

When we cannot get the `inspect.Signature` or when the
args and kwargs do not match the signature, we generate
the flattened argument names as: `args[0]`, `args[1]`,
`kwargs['foo']`, ... Previously, in these cases we
returned `arg_names` as `None`, and then the whole
debug_info ended up being `None`, throwing away even
available information.

We also add support for `api_util.fun_sourceinfo` even
for cases when the `fun.__code__` is not available. In
those cases we used to say that `fun_sourceinfo` is
`None`. Now, we use the string representation of `fun`
to get the name of built-in functions, or we use "<unknown>".
2025-01-21 13:38:10 +01:00
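The fallback naming described in the commit above can be sketched with the standard `inspect` module. This is a simplified illustration, not the JAX implementation: `flat_arg_names` is a hypothetical helper, and unlike the real code it does not flatten pytree arguments.

```python
import inspect


def flat_arg_names(fun, args, kwargs):
    """Return argument names, falling back to positional placeholders."""
    try:
        signature = inspect.signature(fun)
        bound = signature.bind(*args, **kwargs)
    except (ValueError, TypeError):
        # No usable signature, or the call does not match it: still return
        # names such as args[0] and kwargs['foo'] instead of giving up.
        names = [f"args[{i}]" for i in range(len(args))]
        names += [f"kwargs[{key!r}]" for key in sorted(kwargs)]
        return tuple(names)
    # The signature matched, so use the declared parameter names.
    return tuple(bound.arguments)


def example(x, y):
    return x + y


matched = flat_arg_names(example, (1,), {"y": 2})
mismatched = flat_arg_names(example, (1, 2, 3), {"foo": 4})
```

The point of the fallback is that the result is never `None`, so downstream error-reporting code can always label an argument with something.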
Dimitar (Mitko) Asenov
f89accc56a [Mosaic GPU] Add support for converting all fragmented layouts to ir and back.
This will be used in the layout inference and lowering of the dialect WGMMA op

PiperOrigin-RevId: 717836648
2025-01-21 03:27:03 -08:00
Dimitar (Mitko) Asenov
0a89760c24 [Mosaic GPU] Do not use mgpu in wgmma.py
This enables the dialect lowering to depend on `wgmma.py` without creating a circular dependency. I need this in a follow up CL that implements the lowering of the WGMMA dialect op.

PiperOrigin-RevId: 717791901
2025-01-21 01:06:42 -08:00
Yash Katariya
bba5ada525 Make sure reshard and mesh_cast behave properly under eager mode
PiperOrigin-RevId: 717714860
2025-01-20 20:33:09 -08:00