9119 Commits

Author SHA1 Message Date
jax authors
f3224caf46 [Pallas/Fuser] Add support for pl.Element in fuser BlockSpec
PiperOrigin-RevId: 749492881
2025-04-20 01:13:19 -07:00
jax authors
037dab7a66 Merge pull request #27733 from gspschmid:gschmid/jax_fwd_and_bwd
PiperOrigin-RevId: 749195753
2025-04-18 17:37:11 -07:00
Robert Dyro
0e1b34196b Refactor array serialization into separate JAX and tensorstore logic
Array serialization in array_serialization.py contains a mixture of JAX
specific serialization logic and tensorstore driver. This change separates JAX
and tensorstore methods (a) making serialization more modular and (b)
potentially allowing for alternative array serialization backends in the
future.

Additional clean-up changes include:
- making ocdbt kvstore driver default in tensorstore
- robustified array serialization tests especially on multi-host
- explicit tensorstore array chunking to ensure chunk file size does not blow up

PiperOrigin-RevId: 749175295
2025-04-18 16:07:39 -07:00
jax authors
0ae613ee48 Makes Effort_02 the default value for memory_fitting_level.
PiperOrigin-RevId: 749159983
2025-04-18 15:09:02 -07:00
Yash Katariya
80d1fbac42 Handle sharding param in convert_element_type's batching rule properly by adding the explicit mesh axis on dim 0
PiperOrigin-RevId: 749125322
2025-04-18 13:11:19 -07:00
Dan Foreman-Mackey
c6482ed636 Ensure outputs are tracers when inlining jit. 2025-04-18 14:39:56 -04:00
Yash Katariya
854b2c85db Drop into Auto mode for .at[...].set(...) but instead of taking an out_sharding argument in set, use the input array's sharding. Since this is an update, after .set, the input array's sharding should be preserved.
Fixes https://github.com/jax-ml/jax/issues/28111

PiperOrigin-RevId: 749089846
2025-04-18 11:17:38 -07:00
Peter Hawkins
9515606892 [JAX] Remove jax.lib.xla_client.mlir_api_version and its uses.
(We leave the name exported by JAX to avoid breaking users, but fixed to its last known value.)

PiperOrigin-RevId: 749070199
2025-04-18 10:17:38 -07:00
Peter Hawkins
96865709b1 Allow the CPU collective implementation to be overridden to None.
PiperOrigin-RevId: 749055960
2025-04-18 09:29:11 -07:00
Dan Foreman-Mackey
492cd3d931 Reverts c2ba1790417ca206a4d88b25aef4d5ae510dd717
PiperOrigin-RevId: 749049676
2025-04-18 09:03:12 -07:00
Sharad Vikram
eab1dfccbc [Pallas] Generalize BlockSpec to support different indexing mode for each dim in the block shape
Currently block_shape is tuple[int | None, …]. We propose generalizing block_shape to take in more types in the tuple to more generally support:
* Squeeze dimension (currently None, could be pl.Squeezed())
* Unblocked: currently the entire index_map needs to be Unblocked or not. This will allow individual indices to be Blocked/Unblocked, e.g. pl.BlockSpec((pl.Unblocked(...), 512), …)
* Ragged sizes: the index_map will return a pl.ds with a dynamic size (bounded by some something). For example: pl.BlockSpec((pl.DynamicSizedSlice(512), 1024), lambda i, j: (pl.ds(...), j).

This will make BlockSpecs a lot more flexible and will enable things like doing arbitrary slicing in things like pipeline emitter.

PiperOrigin-RevId: 748881960
2025-04-17 18:46:38 -07:00
Yash Katariya
a2ebdf6d71 Rename with_user_mesh to with_explicit_mesh
PiperOrigin-RevId: 748880870
2025-04-17 18:41:35 -07:00
Yash Katariya
7de522c5a3 Enter into auto mode for .at[...].get(...) a bit earlier so that all ops inside _gather are in auto mode.
Fix select's batching rule where `explicit_mesh_axis` that we capture in `axis_data` was not propagated properly to the `broadcast` happening in `bdim_at_front`.

PiperOrigin-RevId: 748867490
2025-04-17 17:42:13 -07:00
Peter Hawkins
474dcd409d Remove code to support jaxlib < v0.6.
New minimum jaxlib_extension_version is 330.

PiperOrigin-RevId: 748853497
2025-04-17 16:44:41 -07:00
jax authors
c2ba179041 Merge pull request #28103 from dfm:pe-src-info
PiperOrigin-RevId: 748818185
2025-04-17 14:46:55 -07:00
Sergei Lebedev
23c973e4fa [pallas:mosaic] Replaced device_type= with kernel_type in TPUCompilerParams
The `device_type` can be inferred from the `tpu.core_type` on the kernel.
`kernel_type`, on the other hand, can also be used to define specialized
lowering rules for scalar/vector subcores.

PiperOrigin-RevId: 748794989
2025-04-17 13:40:18 -07:00
Parker Schuh
7634230cdc Remove unused jax_spmd_mode flag.
PiperOrigin-RevId: 748792684
2025-04-17 13:32:52 -07:00
Dan Foreman-Mackey
1d652ab7f4 Don't recompute source_info for each tracer during staging. 2025-04-17 15:31:38 -04:00
Yash Katariya
06ad3528e9 Use _make_lengths_same for explicit mode too.
We add `None`'s when ndim > len(sharding.spec) and only remove `None`s when `ndim < len(sharding.spec)`. If sharded axes exist, then we error out when removing specs.

PiperOrigin-RevId: 748735303
2025-04-17 10:48:46 -07:00
Sergei Lebedev
4ceb4b0526 Do not use -> ...
It is a non-standard pytype feature which is not supported by any other type checker.

PiperOrigin-RevId: 748636378
2025-04-17 04:37:22 -07:00
Sergei Lebedev
c576d328bd Added lax.axis_size and switched all existing usage of psum(1, ...) to it
PiperOrigin-RevId: 748604842
2025-04-17 02:22:25 -07:00
Yash Katariya
82215f660e Remove jax_varying_axes_in_types config and rewrite from shard_map_p
PiperOrigin-RevId: 748545142
2025-04-16 22:27:50 -07:00
Yash Katariya
0a9d0bec5b Remove _manual_axes from NamedSharding since we can now track the manual axes on the mesh.
PiperOrigin-RevId: 748534841
2025-04-16 21:49:53 -07:00
jax authors
003713cc4f Merge pull request #28069 from dfm:fix-argums-partial
PiperOrigin-RevId: 748453227
2025-04-16 16:00:23 -07:00
Yash Katariya
a31e53a6c8 Return False in is_env_present if importing kubernetes leads to a ModuleNotFoundError
PiperOrigin-RevId: 748440123
2025-04-16 15:15:27 -07:00
Yash Katariya
5f6b99a143 Fix a bug in reduce_window sharding rule where padding is a tuple but we were checking for a scalar instead. Fixes https://github.com/jax-ml/jax/issues/28070
PiperOrigin-RevId: 748418451
2025-04-16 14:10:13 -07:00
Dan Foreman-Mackey
9afc047bf0 Fix bug in argnums_partial_except when static_argnums is unsorted. 2025-04-16 16:18:10 -04:00
jax authors
74f1d887eb Merge pull request #28018 from Cjkkkk:disable_packed_layout_at_ampere
PiperOrigin-RevId: 748349568
2025-04-16 10:54:25 -07:00
Jevin Jiang
770dae72cb [Pallas][Mosaic][TPU] Add disable_bounds_checks compiler params
When we run the program with "--xla_jf_bounds_check=true", we can selectively disable bounds checks for pallas kernels now.

PiperOrigin-RevId: 748193719
2025-04-16 01:01:27 -07:00
Chris Jones
2beff6a1df [pallas] Fix case of Fusible{ElementDtype,TyRules}.
The first letter was inadvertently made lower-case in the previous re-naming CL.

PiperOrigin-RevId: 748086763
2025-04-15 17:43:44 -07:00
Roy Frostig
90af597786 remove inaccurate inline comment in PRNGKeyArray constructor
PiperOrigin-RevId: 748085747
2025-04-15 17:39:40 -07:00
Roy Frostig
47bc2f55dc convert NumPy RNG key data to uncommitted default-device-backed jax.Array data
Generally, we want to maintain that key data backing a `PRNGKeyArray` is a `jax.Array`. This change converts NumPy arrays on construction.

Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 748077900
2025-04-15 17:11:25 -07:00
jax authors
25e0fe59d5 Merge pull request #27984 from carlosgmartin:logsumexp_doc
PiperOrigin-RevId: 748059520
2025-04-15 16:10:57 -07:00
jax authors
002be7a1ab Merge pull request #28047 from jakevdp:logsoftmax-dep
PiperOrigin-RevId: 748059518
2025-04-15 16:10:27 -07:00
Yash Katariya
655bfcac39 Enable standard_insert_pvary for optimization_barrier which was disabled before.
PiperOrigin-RevId: 748027360
2025-04-15 14:41:08 -07:00
Jake VanderPlas
b271a67bbc Clean up softmax initial deprecation 2025-04-15 14:36:56 -07:00
Jake VanderPlas
ba8877789d Roll back https://github.com/jax-ml/jax/pull/28022 due to test breakages.
Reverts b336daf747940301de5956dce4ebe790298e6b5b

PiperOrigin-RevId: 747988862
2025-04-15 13:00:04 -07:00
Yash Katariya
6e00b5e02d [NFC] Rename standard_insert_pbroadcast to standard_insert_pvary
PiperOrigin-RevId: 747943230
2025-04-15 11:02:45 -07:00
Jake VanderPlas
c56cf4f68d jax.random.bernoulli: use higher-resolution sampler 2025-04-15 08:18:47 -07:00
Georg Stefan Schmid
ae6a18d70d Add jax.fwd_and_bwd 2025-04-15 08:21:18 +00:00
Chris Jones
1926b99bfd [pallas] Fix spelling of 'fusible'.
PiperOrigin-RevId: 747663692
2025-04-14 19:35:59 -07:00
Mark Sandler
0ed0fb7c54 Adds a debugging message to assert, otherwise the error is pretty cryptic.
PiperOrigin-RevId: 747657234
2025-04-14 19:11:02 -07:00
Sharad Vikram
4fa3cd91d3 [Pallas/Fuser] Add basic closed over consts support to pull_block_spec
PiperOrigin-RevId: 747657069
2025-04-14 19:09:04 -07:00
Peter Hawkins
57e33bcbcd Deprecate the contents of jax.util.
PiperOrigin-RevId: 747629222
2025-04-14 17:20:30 -07:00
Ivy Zheng
ab600c3e82 Remove obsolete python key path registry.
PiperOrigin-RevId: 747613761
2025-04-14 16:33:05 -07:00
jax authors
19be20fc6f Merge pull request #27919 from kaixih:enable_doc_scaled_dot_fix
PiperOrigin-RevId: 747578845
2025-04-14 14:55:23 -07:00
Peter Hawkins
8930a67e63 Fix stablehlo version comparison in test utilities.
PiperOrigin-RevId: 747547427
2025-04-14 13:34:32 -07:00
cjkkkk
760d0e0e97 disable packed layout test on old arch prior to Hopper 2025-04-14 20:33:30 +00:00
jax authors
d014912671 Merge pull request #28007 from jakevdp:int-power
PiperOrigin-RevId: 747498460
2025-04-14 11:26:05 -07:00
jax authors
6fcb036b96 Merge pull request #27966 from jakevdp:jit-signature
PiperOrigin-RevId: 747492659
2025-04-14 11:11:02 -07:00