9108 Commits

Author SHA1 Message Date
Sharad Vikram
eab1dfccbc [Pallas] Generalize BlockSpec to support different indexing mode for each dim in the block shape
Currently block_shape is tuple[int | None, …]. We propose generalizing block_shape to take in more types in the tuple to more generally support:
* Squeeze dimension (currently None, could be pl.Squeezed())
* Unblocked: currently the entire index_map needs to be Unblocked or not. This will allow individual indices to be Blocked/Unblocked, e.g. pl.BlockSpec((pl.Unblocked(...), 512), …)
* Ragged sizes: the index_map will return a pl.ds with a dynamic size (bounded by some something). For example: pl.BlockSpec((pl.DynamicSizedSlice(512), 1024), lambda i, j: (pl.ds(...), j).

This will make BlockSpecs a lot more flexible and will enable things like doing arbitrary slicing in things like pipeline emitter.

PiperOrigin-RevId: 748881960
2025-04-17 18:46:38 -07:00
Yash Katariya
a2ebdf6d71 Rename with_user_mesh to with_explicit_mesh
PiperOrigin-RevId: 748880870
2025-04-17 18:41:35 -07:00
Yash Katariya
7de522c5a3 Enter into auto mode for .at[...].get(...) a bit earlier so that all ops inside _gather are in auto mode.
Fix select's batching rule where `explicit_mesh_axis` that we capture in `axis_data` was not propagated properly to the `broadcast` happening in `bdim_at_front`.

PiperOrigin-RevId: 748867490
2025-04-17 17:42:13 -07:00
Peter Hawkins
474dcd409d Remove code to support jaxlib < v0.6.
New minimum jaxlib_extension_version is 330.

PiperOrigin-RevId: 748853497
2025-04-17 16:44:41 -07:00
jax authors
c2ba179041 Merge pull request #28103 from dfm:pe-src-info
PiperOrigin-RevId: 748818185
2025-04-17 14:46:55 -07:00
Sergei Lebedev
23c973e4fa [pallas:mosaic] Replaced device_type= with kernel_type in TPUCompilerParams
The `device_type` can be inferred from the `tpu.core_type` on the kernel.
`kernel_type`, on the other hand, can also be used to define specialized
lowering rules for scalar/vector subcores.

PiperOrigin-RevId: 748794989
2025-04-17 13:40:18 -07:00
Parker Schuh
7634230cdc Remove unused jax_spmd_mode flag.
PiperOrigin-RevId: 748792684
2025-04-17 13:32:52 -07:00
Dan Foreman-Mackey
1d652ab7f4 Don't recompute source_info for each tracer during staging. 2025-04-17 15:31:38 -04:00
Yash Katariya
06ad3528e9 Use _make_lengths_same for explicit mode too.
We add `None`'s when ndim > len(sharding.spec) and only remove `None`s when `ndim < len(sharding.spec)`. If sharded axes exist, then we error out when removing specs.

PiperOrigin-RevId: 748735303
2025-04-17 10:48:46 -07:00
Sergei Lebedev
4ceb4b0526 Do not use -> ...
It is a non-standard pytype feature which is not supported by any other type checker.

PiperOrigin-RevId: 748636378
2025-04-17 04:37:22 -07:00
Sergei Lebedev
c576d328bd Added lax.axis_size and switched all existing usage of psum(1, ...) to it
PiperOrigin-RevId: 748604842
2025-04-17 02:22:25 -07:00
Yash Katariya
82215f660e Remove jax_varying_axes_in_types config and rewrite from shard_map_p
PiperOrigin-RevId: 748545142
2025-04-16 22:27:50 -07:00
Yash Katariya
0a9d0bec5b Remove _manual_axes from NamedSharding since we can now track the manual axes on the mesh.
PiperOrigin-RevId: 748534841
2025-04-16 21:49:53 -07:00
jax authors
003713cc4f Merge pull request #28069 from dfm:fix-argums-partial
PiperOrigin-RevId: 748453227
2025-04-16 16:00:23 -07:00
Yash Katariya
a31e53a6c8 Return False in is_env_present if importing kubernetes leads to a ModuleNotFoundError
PiperOrigin-RevId: 748440123
2025-04-16 15:15:27 -07:00
Yash Katariya
5f6b99a143 Fix a bug in reduce_window sharding rule where padding is a tuple but we were checking for a scalar instead. Fixes https://github.com/jax-ml/jax/issues/28070
PiperOrigin-RevId: 748418451
2025-04-16 14:10:13 -07:00
Dan Foreman-Mackey
9afc047bf0 Fix bug in argnums_partial_except when static_argnums is unsorted. 2025-04-16 16:18:10 -04:00
jax authors
74f1d887eb Merge pull request #28018 from Cjkkkk:disable_packed_layout_at_ampere
PiperOrigin-RevId: 748349568
2025-04-16 10:54:25 -07:00
Jevin Jiang
770dae72cb [Pallas][Mosaic][TPU] Add disable_bounds_checks compiler params
When we run the program with "--xla_jf_bounds_check=true", we can selectively disable bounds checks for pallas kernels now.

PiperOrigin-RevId: 748193719
2025-04-16 01:01:27 -07:00
Chris Jones
2beff6a1df [pallas] Fix case of Fusible{ElementDtype,TyRules}.
The first letter was inadvertently made lower-case in the previous re-naming CL.

PiperOrigin-RevId: 748086763
2025-04-15 17:43:44 -07:00
Roy Frostig
90af597786 remove inaccurate inline comment in PRNGKeyArray constructor
PiperOrigin-RevId: 748085747
2025-04-15 17:39:40 -07:00
Roy Frostig
47bc2f55dc convert NumPy RNG key data to uncommitted default-device-backed jax.Array data
Generally, we want to maintain that key data backing a `PRNGKeyArray` is a `jax.Array`. This change converts NumPy arrays on construction.

Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 748077900
2025-04-15 17:11:25 -07:00
jax authors
25e0fe59d5 Merge pull request #27984 from carlosgmartin:logsumexp_doc
PiperOrigin-RevId: 748059520
2025-04-15 16:10:57 -07:00
jax authors
002be7a1ab Merge pull request #28047 from jakevdp:logsoftmax-dep
PiperOrigin-RevId: 748059518
2025-04-15 16:10:27 -07:00
Yash Katariya
655bfcac39 Enable standard_insert_pvary for optimization_barrier which was disabled before.
PiperOrigin-RevId: 748027360
2025-04-15 14:41:08 -07:00
Jake VanderPlas
b271a67bbc Clean up softmax initial deprecation 2025-04-15 14:36:56 -07:00
Jake VanderPlas
ba8877789d Roll back https://github.com/jax-ml/jax/pull/28022 due to test breakages.
Reverts b336daf747940301de5956dce4ebe790298e6b5b

PiperOrigin-RevId: 747988862
2025-04-15 13:00:04 -07:00
Yash Katariya
6e00b5e02d [NFC] Rename standard_insert_pbroadcast to standard_insert_pvary
PiperOrigin-RevId: 747943230
2025-04-15 11:02:45 -07:00
Jake VanderPlas
c56cf4f68d jax.random.bernoulli: use higher-resolution sampler 2025-04-15 08:18:47 -07:00
Chris Jones
1926b99bfd [pallas] Fix spelling of 'fusible'.
PiperOrigin-RevId: 747663692
2025-04-14 19:35:59 -07:00
Mark Sandler
0ed0fb7c54 Adds a debugging message to assert, otherwise the error is pretty cryptic.
PiperOrigin-RevId: 747657234
2025-04-14 19:11:02 -07:00
Sharad Vikram
4fa3cd91d3 [Pallas/Fuser] Add basic closed over consts support to pull_block_spec
PiperOrigin-RevId: 747657069
2025-04-14 19:09:04 -07:00
Peter Hawkins
57e33bcbcd Deprecate the contents of jax.util.
PiperOrigin-RevId: 747629222
2025-04-14 17:20:30 -07:00
Ivy Zheng
ab600c3e82 Remove obsolete python key path registry.
PiperOrigin-RevId: 747613761
2025-04-14 16:33:05 -07:00
jax authors
19be20fc6f Merge pull request #27919 from kaixih:enable_doc_scaled_dot_fix
PiperOrigin-RevId: 747578845
2025-04-14 14:55:23 -07:00
Peter Hawkins
8930a67e63 Fix stablehlo version comparison in test utilities.
PiperOrigin-RevId: 747547427
2025-04-14 13:34:32 -07:00
cjkkkk
760d0e0e97 disable packed layout test on old arch prior to Hopper 2025-04-14 20:33:30 +00:00
jax authors
d014912671 Merge pull request #28007 from jakevdp:int-power
PiperOrigin-RevId: 747498460
2025-04-14 11:26:05 -07:00
jax authors
6fcb036b96 Merge pull request #27966 from jakevdp:jit-signature
PiperOrigin-RevId: 747492659
2025-04-14 11:11:02 -07:00
Jake VanderPlas
42542feac6 jnp.power: better docs for invalid input 2025-04-14 10:42:29 -07:00
jax authors
30669dc219 Merge pull request #27993 from gnecula:explain_timing
PiperOrigin-RevId: 747480248
2025-04-14 10:41:05 -07:00
Jake VanderPlas
ceca6ec1fc jax.jit: deprecate non-standard call signature. 2025-04-14 10:13:05 -07:00
Dan Foreman-Mackey
1b1bd071bc Finalize deprecation of vectorized argument in callbacks.
The `vectorized` argument to `pure_callback` and `ffi_call` was deprecated in JAX v0.4.34 (released Oct 4 2024), then added to the CHANGELOG in v0.4.35 (doh! released Oct 22). The JAX compatibility policy requires 3 months of compatible releases before a deprecation is finalized, so it is time to remove this parameter from the public API. The `vmap_method` parameter can be used instead, and the docs for [`pure_callback`](https://docs.jax.dev/en/latest/_autosummary/jax.pure_callback.html) provide more details.

This change has one other (non-obvious!) affect on the user facing APIs. (Note that this change in behavior has also been protected by a deprecation warning since the `vectorized` parameter was deprecated.) The default behavior of `pure_callback` and `ffi_call` under `vmap` is to now raise an exception, rather than silently producing a loop. To opt in to the previous default behavior, use `vmap_method="sequential"`.

PiperOrigin-RevId: 747413383
2025-04-14 07:43:59 -07:00
jax authors
b6c6c1c258 Merge pull request #27971 from ywrt:patch-1
PiperOrigin-RevId: 747399343
2025-04-14 07:00:10 -07:00
George Necula
b8df474965 [explain_cache_miss] Add to explanations the duration of the missed function call
This enables the user to focus on the most important
call sites.

jax-fixit
2025-04-14 16:08:24 +03:00
jax authors
6ca623f79b Merge pull request #27980 from gnecula:tracing_cache
PiperOrigin-RevId: 747274185
2025-04-13 23:53:16 -07:00
carlosgmartin
2336cd1695 Minor improvements to doc for jax.nn.logsumexp. 2025-04-13 15:17:11 -04:00
George Necula
f070cdecb3 [explain-cache-miss] Improve tracing-cache-miss explanations
The previous approach was to report, for several elements
of the cache key, the closest mismatch. Some parts of
the cache key were ignored, which led to "explanation unavailable".
The same happened when we had two keys close to the current
one, each differring in a different part of the key.
No explanation was produced because for each part of the key,
there was a matching key already in the cache, even though
the key taken as a whole did not match.

Now, we scan *all* parts of they key and compute the differences.
We keep track of the "size" of the differences, and we explain
the differences to those keys that are closest (possibly more
than one key if equidistant).
For example, for shape differences we'll report the
closest matching shape. If a type differs in both the dtype
and some parts of the shape, or sharding, it is considered
farther away.

We add new tests and explanations for  different
static argnums and argnames.

There are still cases when we do not produce an explanation, but
now the "explanation unavailable" includes a description
of which component of the key is different, and what the
difference is. This may still be hard to understand by the
user but at least they can file a clearer bug.

Refactored the tests, and added a few new ones.
2025-04-13 20:44:46 +03:00
Roy Frostig
566d0775a8 unify stages.Lowering and stages.XlaLowering
We no longer have many different implicit types conforming to `Lowering`, only `pxla.MeshComputation` and `pxla.PmapComputation`. Both are `XlaLowering` subtypes. So define just one common base class, call it `Lowering`, and inherit from just that in both concrete internal computation/lowering subtypes.

PiperOrigin-RevId: 746735857
2025-04-12 00:31:14 -07:00
Roy Frostig
99ca14601d revert making Executable an ABC
PiperOrigin-RevId: 746726071
2025-04-11 23:49:25 -07:00