jax authors
f3224caf46
[Pallas/Fuser] Add support for pl.Element in fuser BlockSpec
...
PiperOrigin-RevId: 749492881
2025-04-20 01:13:19 -07:00
jax authors
037dab7a66
Merge pull request #27733 from gspschmid:gschmid/jax_fwd_and_bwd
...
PiperOrigin-RevId: 749195753
2025-04-18 17:37:11 -07:00
Robert Dyro
0e1b34196b
Refactor array serialization into separate JAX and tensorstore logic
...
Array serialization in array_serialization.py contains a mixture of JAX
specific serialization logic and tensorstore driver. This change separates JAX
and tensorstore methods (a) making serialization more modular and (b)
potentially allowing for alternative array serialization backends in the
future.
Additional clean-up changes include:
- making ocdbt kvstore driver default in tensorstore
- robustified array serialization tests especially on multi-host
- explicit tensorstore array chunking to ensure chunk file size does not blow up
PiperOrigin-RevId: 749175295
2025-04-18 16:07:39 -07:00
jax authors
0ae613ee48
Makes Effort_02 the default value for memory_fitting_level.
...
PiperOrigin-RevId: 749159983
2025-04-18 15:09:02 -07:00
Yash Katariya
80d1fbac42
Handle sharding
param in convert_element_type's batching rule properly by adding the explicit mesh axis on dim 0
...
PiperOrigin-RevId: 749125322
2025-04-18 13:11:19 -07:00
Dan Foreman-Mackey
c6482ed636
Ensure outputs are tracers when inlining jit.
2025-04-18 14:39:56 -04:00
Yash Katariya
854b2c85db
Drop into Auto
mode for .at[...].set(...)
but instead of taking an out_sharding
argument in set
, use the input array's sharding
. Since this is an update, after .set
, the input array's sharding should be preserved.
...
Fixes https://github.com/jax-ml/jax/issues/28111
PiperOrigin-RevId: 749089846
2025-04-18 11:17:38 -07:00
Peter Hawkins
9515606892
[JAX] Remove jax.lib.xla_client.mlir_api_version and its uses.
...
(We leave the name exported by JAX to avoid breaking users, but fixed to its last known value.)
PiperOrigin-RevId: 749070199
2025-04-18 10:17:38 -07:00
Peter Hawkins
96865709b1
Allow the CPU collective implementation to be overridden to None.
...
PiperOrigin-RevId: 749055960
2025-04-18 09:29:11 -07:00
Dan Foreman-Mackey
492cd3d931
Reverts c2ba1790417ca206a4d88b25aef4d5ae510dd717
...
PiperOrigin-RevId: 749049676
2025-04-18 09:03:12 -07:00
Sharad Vikram
eab1dfccbc
[Pallas] Generalize BlockSpec to support different indexing mode for each dim in the block shape
...
Currently block_shape is tuple[int | None, …]. We propose generalizing block_shape to take in more types in the tuple to more generally support:
* Squeeze dimension (currently None, could be pl.Squeezed())
* Unblocked: currently the entire index_map needs to be Unblocked or not. This will allow individual indices to be Blocked/Unblocked, e.g. pl.BlockSpec((pl.Unblocked(...), 512), …)
* Ragged sizes: the index_map will return a pl.ds with a dynamic size (bounded by some something). For example: pl.BlockSpec((pl.DynamicSizedSlice(512), 1024), lambda i, j: (pl.ds(...), j).
This will make BlockSpecs a lot more flexible and will enable things like doing arbitrary slicing in things like pipeline emitter.
PiperOrigin-RevId: 748881960
2025-04-17 18:46:38 -07:00
Yash Katariya
a2ebdf6d71
Rename with_user_mesh
to with_explicit_mesh
...
PiperOrigin-RevId: 748880870
2025-04-17 18:41:35 -07:00
Yash Katariya
7de522c5a3
Enter into auto mode for .at[...].get(...)
a bit earlier so that all ops inside _gather
are in auto mode.
...
Fix select's batching rule where `explicit_mesh_axis` that we capture in `axis_data` was not propagated properly to the `broadcast` happening in `bdim_at_front`.
PiperOrigin-RevId: 748867490
2025-04-17 17:42:13 -07:00
Peter Hawkins
474dcd409d
Remove code to support jaxlib < v0.6.
...
New minimum jaxlib_extension_version is 330.
PiperOrigin-RevId: 748853497
2025-04-17 16:44:41 -07:00
jax authors
c2ba179041
Merge pull request #28103 from dfm:pe-src-info
...
PiperOrigin-RevId: 748818185
2025-04-17 14:46:55 -07:00
Sergei Lebedev
23c973e4fa
[pallas:mosaic] Replaced device_type=
with kernel_type
in TPUCompilerParams
...
The `device_type` can be inferred from the `tpu.core_type` on the kernel.
`kernel_type`, on the other hand, can also be used to define specialized
lowering rules for scalar/vector subcores.
PiperOrigin-RevId: 748794989
2025-04-17 13:40:18 -07:00
Parker Schuh
7634230cdc
Remove unused jax_spmd_mode flag.
...
PiperOrigin-RevId: 748792684
2025-04-17 13:32:52 -07:00
Dan Foreman-Mackey
1d652ab7f4
Don't recompute source_info for each tracer during staging.
2025-04-17 15:31:38 -04:00
Yash Katariya
06ad3528e9
Use _make_lengths_same for explicit mode too.
...
We add `None`'s when ndim > len(sharding.spec) and only remove `None`s when `ndim < len(sharding.spec)`. If sharded axes exist, then we error out when removing specs.
PiperOrigin-RevId: 748735303
2025-04-17 10:48:46 -07:00
Sergei Lebedev
4ceb4b0526
Do not use -> ...
...
It is a non-standard pytype feature which is not supported by any other type checker.
PiperOrigin-RevId: 748636378
2025-04-17 04:37:22 -07:00
Sergei Lebedev
c576d328bd
Added lax.axis_size
and switched all existing usage of psum(1, ...)
to it
...
PiperOrigin-RevId: 748604842
2025-04-17 02:22:25 -07:00
Yash Katariya
82215f660e
Remove jax_varying_axes_in_types config and rewrite
from shard_map_p
...
PiperOrigin-RevId: 748545142
2025-04-16 22:27:50 -07:00
Yash Katariya
0a9d0bec5b
Remove _manual_axes from NamedSharding since we can now track the manual axes on the mesh.
...
PiperOrigin-RevId: 748534841
2025-04-16 21:49:53 -07:00
jax authors
003713cc4f
Merge pull request #28069 from dfm:fix-argums-partial
...
PiperOrigin-RevId: 748453227
2025-04-16 16:00:23 -07:00
Yash Katariya
a31e53a6c8
Return False in is_env_present
if importing kubernetes leads to a ModuleNotFoundError
...
PiperOrigin-RevId: 748440123
2025-04-16 15:15:27 -07:00
Yash Katariya
5f6b99a143
Fix a bug in reduce_window sharding rule where padding is a tuple but we were checking for a scalar instead. Fixes https://github.com/jax-ml/jax/issues/28070
...
PiperOrigin-RevId: 748418451
2025-04-16 14:10:13 -07:00
Dan Foreman-Mackey
9afc047bf0
Fix bug in argnums_partial_except when static_argnums is unsorted.
2025-04-16 16:18:10 -04:00
jax authors
74f1d887eb
Merge pull request #28018 from Cjkkkk:disable_packed_layout_at_ampere
...
PiperOrigin-RevId: 748349568
2025-04-16 10:54:25 -07:00
Jevin Jiang
770dae72cb
[Pallas][Mosaic][TPU] Add disable_bounds_checks
compiler params
...
When we run the program with "--xla_jf_bounds_check=true", we can selectively disable bounds checks for pallas kernels now.
PiperOrigin-RevId: 748193719
2025-04-16 01:01:27 -07:00
Chris Jones
2beff6a1df
[pallas] Fix case of Fusible{ElementDtype,TyRules}
.
...
The first letter was inadvertently made lower-case in the previous re-naming CL.
PiperOrigin-RevId: 748086763
2025-04-15 17:43:44 -07:00
Roy Frostig
90af597786
remove inaccurate inline comment in PRNGKeyArray
constructor
...
PiperOrigin-RevId: 748085747
2025-04-15 17:39:40 -07:00
Roy Frostig
47bc2f55dc
convert NumPy RNG key data to uncommitted default-device-backed jax.Array
data
...
Generally, we want to maintain that key data backing a `PRNGKeyArray` is a `jax.Array`. This change converts NumPy arrays on construction.
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 748077900
2025-04-15 17:11:25 -07:00
jax authors
25e0fe59d5
Merge pull request #27984 from carlosgmartin:logsumexp_doc
...
PiperOrigin-RevId: 748059520
2025-04-15 16:10:57 -07:00
jax authors
002be7a1ab
Merge pull request #28047 from jakevdp:logsoftmax-dep
...
PiperOrigin-RevId: 748059518
2025-04-15 16:10:27 -07:00
Yash Katariya
655bfcac39
Enable standard_insert_pvary for optimization_barrier which was disabled before.
...
PiperOrigin-RevId: 748027360
2025-04-15 14:41:08 -07:00
Jake VanderPlas
b271a67bbc
Clean up softmax initial deprecation
2025-04-15 14:36:56 -07:00
Jake VanderPlas
ba8877789d
Roll back https://github.com/jax-ml/jax/pull/28022 due to test breakages.
...
Reverts b336daf747940301de5956dce4ebe790298e6b5b
PiperOrigin-RevId: 747988862
2025-04-15 13:00:04 -07:00
Yash Katariya
6e00b5e02d
[NFC] Rename standard_insert_pbroadcast
to standard_insert_pvary
...
PiperOrigin-RevId: 747943230
2025-04-15 11:02:45 -07:00
Jake VanderPlas
c56cf4f68d
jax.random.bernoulli: use higher-resolution sampler
2025-04-15 08:18:47 -07:00
Georg Stefan Schmid
ae6a18d70d
Add jax.fwd_and_bwd
2025-04-15 08:21:18 +00:00
Chris Jones
1926b99bfd
[pallas] Fix spelling of 'fusible'.
...
PiperOrigin-RevId: 747663692
2025-04-14 19:35:59 -07:00
Mark Sandler
0ed0fb7c54
Adds a debugging message to assert, otherwise the error is pretty cryptic.
...
PiperOrigin-RevId: 747657234
2025-04-14 19:11:02 -07:00
Sharad Vikram
4fa3cd91d3
[Pallas/Fuser] Add basic closed over consts support to pull_block_spec
...
PiperOrigin-RevId: 747657069
2025-04-14 19:09:04 -07:00
Peter Hawkins
57e33bcbcd
Deprecate the contents of jax.util.
...
PiperOrigin-RevId: 747629222
2025-04-14 17:20:30 -07:00
Ivy Zheng
ab600c3e82
Remove obsolete python key path registry.
...
PiperOrigin-RevId: 747613761
2025-04-14 16:33:05 -07:00
jax authors
19be20fc6f
Merge pull request #27919 from kaixih:enable_doc_scaled_dot_fix
...
PiperOrigin-RevId: 747578845
2025-04-14 14:55:23 -07:00
Peter Hawkins
8930a67e63
Fix stablehlo version comparison in test utilities.
...
PiperOrigin-RevId: 747547427
2025-04-14 13:34:32 -07:00
cjkkkk
760d0e0e97
disable packed layout test on old arch prior to Hopper
2025-04-14 20:33:30 +00:00
jax authors
d014912671
Merge pull request #28007 from jakevdp:int-power
...
PiperOrigin-RevId: 747498460
2025-04-14 11:26:05 -07:00
jax authors
6fcb036b96
Merge pull request #27966 from jakevdp:jit-signature
...
PiperOrigin-RevId: 747492659
2025-04-14 11:11:02 -07:00