9119 Commits

Author SHA1 Message Date
Jevin Jiang
c1bdd1a234 [Mosaic TPU] Allow specify priority in enqueueDMA.
For now we only support priority 0 (on-demand thread) and priority 1 (background thread) on local DMA.

PiperOrigin-RevId: 743780185
2025-04-03 19:39:45 -07:00
Yash Katariya
5b3e419515 Add auto_axes, explicit_axes and manual_axes properties to Mesh and AbstractMesh
PiperOrigin-RevId: 743767895
2025-04-03 18:35:28 -07:00
Roy Frostig
f8bbe98a86 require out_shardings as a keyword-only argument on public functions
PiperOrigin-RevId: 743753215
2025-04-03 17:26:05 -07:00
Christos Perivolaropoulos
26fc1cde4c [pallas:mgpu] Initial version of inline_mgpu op
PiperOrigin-RevId: 743751560
2025-04-03 17:19:05 -07:00
jax authors
a04b5ecfcd Merge pull request #27717 from froystig:out-shard-perm
PiperOrigin-RevId: 743744847
2025-04-03 16:53:32 -07:00
Christos Perivolaropoulos
7583814e35 [mgpu:pallas] Changes to allow the use of WGMMA_TRANSPOSED_LAYOUT.
It is up to _handle_transposes() to check that the swizzle dimension is not
transposed rather than `UnswizzleRef.untransform_transpose()`. This allows us to
disable the check in certain situations where mgpu can handle it like wgmma and
swap_p when storing a WGMMA_TRANSPOSED_LAYOUT.

If this check is completely skipped it can cause the kernel to crash at runtime.

Furthermore this CL adds a test to check this behavior.

PiperOrigin-RevId: 743738166
2025-04-03 16:30:18 -07:00
Roy Frostig
bbdea54ccb add an out_sharding option to jax.random.permutation
Drop into `Auto` mode in the implementation.
2025-04-03 16:21:45 -07:00
Christos Perivolaropoulos
3901014f9a [pallas:mgpu] General ref transform handling at lowering time.
Replace `_handle_reshape()` and `_handle_index()` with a general
`_handle_transform()` that applies all transforms except tiling and (optionally)
transposes. The implementation is based on
`_untransform_{transpose,reshape,index}()` transform methods on transforms that
find the conjugate of the transpose/reshape/index wrt the transform.

PiperOrigin-RevId: 743731515
2025-04-03 16:07:36 -07:00
kaixih
41868ef06d format 2025-04-03 21:46:10 +00:00
jax authors
1bd0c58f51 Merge pull request #27691 from gnecula:export_override_lowering
PiperOrigin-RevId: 743675002
2025-04-03 13:21:59 -07:00
jax authors
c2eb9c1d9e Eliminate DeprecationWarning in python3.12+ in jax pallas for ~.
The code was using ~ with a boolean, which leads to a new DeprecationWarning. That should only be used with ints.

PiperOrigin-RevId: 743668386
2025-04-03 13:02:33 -07:00
jax authors
d7fc04b682 Merge pull request #27681 from jakevdp:jax-array
PiperOrigin-RevId: 743658654
2025-04-03 12:32:00 -07:00
Justin Fu
780c8827f2 [Mosaic GPU] Fix index_invariant slot in warp-specialized pipeline.
PiperOrigin-RevId: 743633331
2025-04-03 11:20:05 -07:00
jax authors
24fef3dd77 Merge pull request #27304 from wenscarl:nvfp4_grad_ste
PiperOrigin-RevId: 743601775
2025-04-03 09:58:08 -07:00
Sergei Lebedev
f2f9152d57 Moved the jax.Array baseclass to C++
This allows `ArrayImpl` to directly subclass `jax.Array` without relying on
the expensive virtual subclasses machinery from `abc`.

PiperOrigin-RevId: 743573028
2025-04-03 08:28:02 -07:00
George Necula
1941714d26 [export] Add support for override_lowering_rules to jax.export.
This parameter is already part of the internal API for the
AOT lowering function, here we just expose it to `jax.export`.
2025-04-03 16:13:16 +01:00
Sergei Lebedev
552eea8ebd [pallas:mosaic_gpu] emit_pipeline* now passes the loop indices into the body
This replaces the old behavior where `emit_pipeline*` would replace the current
parallel grid with the sequential grid, changing the output of `pl.program_id`.
With this change, `pl.program_id` always works wrt the parallel grid.

PiperOrigin-RevId: 743498194
2025-04-03 03:58:34 -07:00
Sergei Lebedev
ea196dac12 [pallas:mosaic_gpu] Slightly reworded the docstrings for a few recently added primitives
PiperOrigin-RevId: 743492343
2025-04-03 03:35:36 -07:00
jax authors
862342d657 Merge pull request #27688 from froystig:out-shard-randint
PiperOrigin-RevId: 743452501
2025-04-03 01:09:55 -07:00
jax authors
45e6808bb5 Merge pull request #27084 from danielsuo:switch-fwd
PiperOrigin-RevId: 743452172
2025-04-03 01:07:50 -07:00
Roy Frostig
ab816ed8c4 add an out_sharding option to jax.random.randint
Drop into `Auto` mode in the implementation.
2025-04-02 21:05:19 -07:00
Roy Frostig
2f617631fb use common maybe_auto_axes helper in random.uniform 2025-04-02 17:47:25 -07:00
jax authors
aa06e1650f Merge pull request #27687 from froystig:out-shard-bits
PiperOrigin-RevId: 743343131
2025-04-02 17:44:28 -07:00
jax authors
ffbd5ef67a Merge pull request #27677 from dfm:dir-lin-custom-transpose
PiperOrigin-RevId: 743342672
2025-04-02 17:42:29 -07:00
Roy Frostig
2540fcde11 add an out_sharding option to jax.random.bits
Drop into `Auto` mode in the implementation.
2025-04-02 17:19:57 -07:00
kaixih
5ddec65086 Remove asserts 2025-04-03 00:00:25 +00:00
cjkkkk
5e0ccb40d6 add option to expose attention residual to user 2025-04-02 22:55:58 +00:00
Sergei Lebedev
9fa5de7b05 [pallas] Removed pl.device_id. Use lax.axis_index instead.
PiperOrigin-RevId: 743307670
2025-04-02 15:45:52 -07:00
jax authors
c8273d7795 Merge pull request #24197 from yhtang:add-k8s-ci
PiperOrigin-RevId: 743302226
2025-04-02 15:33:18 -07:00
Sergei Lebedev
9c58a112b3 jnp.array no longer accepts None
PiperOrigin-RevId: 743291099
2025-04-02 14:58:51 -07:00
Jake VanderPlas
96780f19b0 jax.numpy: support __jax_array__ in several more functions 2025-04-02 14:28:54 -07:00
jax authors
e75b66463c Merge pull request #27680 from jakevdp:array-api-version
PiperOrigin-RevId: 743277981
2025-04-02 14:22:03 -07:00
Jake VanderPlas
a2d62e2d3a [array_api] update array_api_version to 2024.12 2025-04-02 13:46:07 -07:00
Jake VanderPlas
7f4e8c56fe jnp.concat and friends: support __jax_array__ 2025-04-02 13:17:59 -07:00
Dan Foreman-Mackey
a442fecca8 Fix custom_transpose when composed with custom_jvp and use_direct_linearize=True. 2025-04-02 15:32:15 -04:00
Gunhyun Park
92f7aeab48 Add simple vmap support for lax.ragged_all_to_all.
PiperOrigin-RevId: 743230485
2025-04-02 12:10:34 -07:00
Yash Katariya
3d70fc8197 Add pbroadcast insertion for psum_p in the traceable. This effectively replaces psum_p with psum2_p if varying_axes_in_types is on. psum_p can be replaced with psum2_p in follow up CLs
Also populate the aval of `ShardMapTracer` with `vma`

PiperOrigin-RevId: 743188081
2025-04-02 10:21:21 -07:00
jax authors
056c976ecb Merge pull request #27660 from froystig:xla-meta-ctx
PiperOrigin-RevId: 743178649
2025-04-02 09:59:33 -07:00
Jake VanderPlas
3aeabaedea jnp.isinf & friends: support __jax_array__ 2025-04-02 08:38:54 -07:00
Yash Katariya
2e16367991 Remove the extra stack frame that was introduce in uniform due to dropping the entire function in auto axes.
PiperOrigin-RevId: 743148311
2025-04-02 08:30:27 -07:00
Dan Foreman-Mackey
c18139ba7b Remove legacy GPU kernels for QR decomposition.
Following the compatibility timeline described here: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility

On Apr 2, it will have been 6 months since the release of 0.4.34 which is the relevant release for this kernels.

PiperOrigin-RevId: 743142261
2025-04-02 08:10:11 -07:00
Dan Foreman-Mackey
6242ffb1ca Remove unused Attrs from lu_pivots_to_permutation FFI kernel.
It has been more than 6 months since the release of 0.4.32 which was the first release to stop including `permutation_size` as an attribute when lowering, so it is now safe (via our compatibility policy) to remove this argument.

PiperOrigin-RevId: 743132169
2025-04-02 07:40:57 -07:00
Yash Katariya
10b2cda90e Relax the aval check in select_hlo_lowering_opaque to only check for shardings if they are not empty. The same thing happens in select_p's sharding rule
PiperOrigin-RevId: 743105350
2025-04-02 06:10:33 -07:00
Sergei Lebedev
45d577d3dc Prepare for disallowing jnp.array(None)
PiperOrigin-RevId: 743074472
2025-04-02 04:17:36 -07:00
jax authors
555698103d Merge pull request #27628 from mattjj:random-gamma-grad-no-more-primitive
PiperOrigin-RevId: 743059662
2025-04-02 03:24:59 -07:00
jax authors
2e8ea62b7d Merge pull request #27463 from gnecula:debug_info_fix_kwargs
PiperOrigin-RevId: 743054832
2025-04-02 03:07:00 -07:00
jax authors
398e8b0d79 Merge pull request #27644 from jakevdp:cumulative-jax-array
PiperOrigin-RevId: 743046047
2025-04-02 02:37:13 -07:00
George Necula
076d021057 [better_errors] Fix the handling of kwargs for debug_info.
kwargs are passed sorted by the actual kwarg keyword. This order
must be accounted for when we construct the `debug_info.arg_names`.

Extended the tests to be more precise about not mixing up kwargs,
e.g., use different shapes and look for the shape in the HLO.
2025-04-02 10:32:38 +01:00
Roy Frostig
1875c76bd2 let XLA metadata be unset in nested dynamic scopes
Treat `None` metadata values as a special instruction not to set (or
to unset, if nested) the corresponding entry.

In particular, this makes it possible to unset metadata within the
sub-computations of higher-order operations (e.g. branches in
conditionals, loop bodies, etc.). This can be used, for example, to
annotate a conditional but not all the operations in its
branches. That is, the HLO for the following function `f` on a scalar
float argument:

```
def cos(x):
  with set_xla_metadata(a=None):
    return jnp.cos(x)

@jax.jit
def f(x):
  with set_xla_metadata(a="b"):
    return jax.lax.cond(x < 0., jnp.sin, cos, x)
```

produces an attribute `a` on the conditional and on the sine, but not
on the cosine.
2025-04-01 20:25:19 -07:00
Ayaka
f139192201 Add OOB checks to jax.numpy array indexing
PiperOrigin-RevId: 742927160
2025-04-01 19:17:57 -07:00