Jevin Jiang
c1bdd1a234
[Mosaic TPU] Allow specifying priority in enqueueDMA.
...
For now we only support priority 0 (on-demand thread) and priority 1 (background thread) on local DMA.
PiperOrigin-RevId: 743780185
2025-04-03 19:39:45 -07:00
Yash Katariya
5b3e419515
Add `auto_axes`, `explicit_axes` and `manual_axes` properties to Mesh and AbstractMesh
...
PiperOrigin-RevId: 743767895
2025-04-03 18:35:28 -07:00
Roy Frostig
f8bbe98a86
require `out_shardings` as a keyword-only argument on public functions
...
PiperOrigin-RevId: 743753215
2025-04-03 17:26:05 -07:00
Christos Perivolaropoulos
26fc1cde4c
[pallas:mgpu] Initial version of inline_mgpu op
...
PiperOrigin-RevId: 743751560
2025-04-03 17:19:05 -07:00
jax authors
a04b5ecfcd
Merge pull request #27717 from froystig:out-shard-perm
...
PiperOrigin-RevId: 743744847
2025-04-03 16:53:32 -07:00
Christos Perivolaropoulos
7583814e35
[mgpu:pallas] Changes to allow the use of WGMMA_TRANSPOSED_LAYOUT.
...
It is up to _handle_transposes() to check that the swizzle dimension is not
transposed rather than `UnswizzleRef.untransform_transpose()`. This allows us to
disable the check in certain situations where mgpu can handle it like wgmma and
swap_p when storing a WGMMA_TRANSPOSED_LAYOUT.
If this check is completely skipped it can cause the kernel to crash at runtime.
Furthermore this CL adds a test to check this behavior.
PiperOrigin-RevId: 743738166
2025-04-03 16:30:18 -07:00
Roy Frostig
bbdea54ccb
add an `out_sharding` option to jax.random.permutation
...
Drop into `Auto` mode in the implementation.
2025-04-03 16:21:45 -07:00
Christos Perivolaropoulos
3901014f9a
[pallas:mgpu] General ref transform handling at lowering time.
...
Replace `_handle_reshape()` and `_handle_index()` with a general
`_handle_transform()` that applies all transforms except tiling and (optionally)
transposes. The implementation is based on
`_untransform_{transpose,reshape,index}()` transform methods on transforms that
find the conjugate of the transpose/reshape/index wrt the transform.
PiperOrigin-RevId: 743731515
2025-04-03 16:07:36 -07:00
kaixih
41868ef06d
format
2025-04-03 21:46:10 +00:00
jax authors
1bd0c58f51
Merge pull request #27691 from gnecula:export_override_lowering
...
PiperOrigin-RevId: 743675002
2025-04-03 13:21:59 -07:00
jax authors
c2eb9c1d9e
Eliminate DeprecationWarning in python3.12+ in jax pallas for ~.
...
The code was using ~ with a boolean, which leads to a new DeprecationWarning. That should only be used with ints.
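As a minimal illustration of the issue described above (not the actual Pallas code, which is not shown here): `~` on a `bool` is bitwise inversion of the underlying integer, so the result is always truthy, and Python 3.12+ emits a `DeprecationWarning` for it. The function names below are hypothetical.

```python
import warnings


def bitwise_not(flag):
    # Bitwise inversion on a bool treats it as an int: ~True == -2, ~False == -1.
    # Python 3.12+ warns here; the filter only keeps this demo quiet.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        return ~flag


def logical_not(flag):
    # Logical negation is what is normally intended for booleans.
    return not flag


def invert_mask(mask, width):
    # On ints, ~ is the right tool: flip the bits, then keep the low `width` bits.
    return ~mask & ((1 << width) - 1)
```

Both results of `bitwise_not` are nonzero, so a condition such as `if ~flag:` is always taken; `not flag` gives the intended negation.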
PiperOrigin-RevId: 743668386
2025-04-03 13:02:33 -07:00
jax authors
d7fc04b682
Merge pull request #27681 from jakevdp:jax-array
...
PiperOrigin-RevId: 743658654
2025-04-03 12:32:00 -07:00
Justin Fu
780c8827f2
[Mosaic GPU] Fix index_invariant slot in warp-specialized pipeline.
...
PiperOrigin-RevId: 743633331
2025-04-03 11:20:05 -07:00
jax authors
24fef3dd77
Merge pull request #27304 from wenscarl:nvfp4_grad_ste
...
PiperOrigin-RevId: 743601775
2025-04-03 09:58:08 -07:00
Sergei Lebedev
f2f9152d57
Moved the `jax.Array` baseclass to C++
...
This allows `ArrayImpl` to directly subclass `jax.Array` without relying on
the expensive virtual subclasses machinery from `abc`.
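For readers unfamiliar with the distinction, the following sketch contrasts a virtual subclass registered through `abc` with a direct subclass. The class names are hypothetical stand-ins, not the real `jax.Array` or `ArrayImpl`, and the sketch shows only the semantic difference, not the cost.

```python
import abc


class ArrayBase(abc.ABC):
    """Hypothetical stand-in for an abstract array base class."""


class VirtualImpl:
    """Hypothetical implementation that does not inherit from ArrayBase."""


# Virtual subclassing: isinstance/issubclass succeed through the ABC registry,
# but ArrayBase never appears in VirtualImpl's method resolution order.
ArrayBase.register(VirtualImpl)


class DirectImpl(ArrayBase):
    """Hypothetical implementation that inherits from ArrayBase directly."""


def in_mro(cls, base):
    # True only for real inheritance, not for registry-based subclassing.
    return base in cls.__mro__
```

Both `isinstance(VirtualImpl(), ArrayBase)` and `isinstance(DirectImpl(), ArrayBase)` are true, but only `DirectImpl` has `ArrayBase` in its MRO.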
PiperOrigin-RevId: 743573028
2025-04-03 08:28:02 -07:00
George Necula
1941714d26
[export] Add support for override_lowering_rules to jax.export.
...
This parameter is already part of the internal API for the
AOT lowering function, here we just expose it to `jax.export`.
2025-04-03 16:13:16 +01:00
Sergei Lebedev
552eea8ebd
[pallas:mosaic_gpu] `emit_pipeline*` now passes the loop indices into the body
...
This replaces the old behavior where `emit_pipeline*` would replace the current
parallel grid with the sequential grid, changing the output of `pl.program_id`.
With this change, `pl.program_id` always works wrt the parallel grid.
PiperOrigin-RevId: 743498194
2025-04-03 03:58:34 -07:00
Sergei Lebedev
ea196dac12
[pallas:mosaic_gpu] Slightly reworded the docstrings for a few recently added primitives
...
PiperOrigin-RevId: 743492343
2025-04-03 03:35:36 -07:00
jax authors
862342d657
Merge pull request #27688 from froystig:out-shard-randint
...
PiperOrigin-RevId: 743452501
2025-04-03 01:09:55 -07:00
jax authors
45e6808bb5
Merge pull request #27084 from danielsuo:switch-fwd
...
PiperOrigin-RevId: 743452172
2025-04-03 01:07:50 -07:00
Roy Frostig
ab816ed8c4
add an `out_sharding` option to jax.random.randint
...
Drop into `Auto` mode in the implementation.
2025-04-02 21:05:19 -07:00
Roy Frostig
2f617631fb
use common `maybe_auto_axes` helper in random.uniform
2025-04-02 17:47:25 -07:00
jax authors
aa06e1650f
Merge pull request #27687 from froystig:out-shard-bits
...
PiperOrigin-RevId: 743343131
2025-04-02 17:44:28 -07:00
jax authors
ffbd5ef67a
Merge pull request #27677 from dfm:dir-lin-custom-transpose
...
PiperOrigin-RevId: 743342672
2025-04-02 17:42:29 -07:00
Roy Frostig
2540fcde11
add an `out_sharding` option to jax.random.bits
...
Drop into `Auto` mode in the implementation.
2025-04-02 17:19:57 -07:00
kaixih
5ddec65086
Remove asserts
2025-04-03 00:00:25 +00:00
cjkkkk
5e0ccb40d6
add option to expose attention residual to user
2025-04-02 22:55:58 +00:00
Sergei Lebedev
9fa5de7b05
[pallas] Removed `pl.device_id`. Use `lax.axis_index` instead.
...
PiperOrigin-RevId: 743307670
2025-04-02 15:45:52 -07:00
jax authors
c8273d7795
Merge pull request #24197 from yhtang:add-k8s-ci
...
PiperOrigin-RevId: 743302226
2025-04-02 15:33:18 -07:00
Sergei Lebedev
9c58a112b3
`jnp.array` no longer accepts None
...
PiperOrigin-RevId: 743291099
2025-04-02 14:58:51 -07:00
Jake VanderPlas
96780f19b0
jax.numpy: support __jax_array__ in several more functions
2025-04-02 14:28:54 -07:00
jax authors
e75b66463c
Merge pull request #27680 from jakevdp:array-api-version
...
PiperOrigin-RevId: 743277981
2025-04-02 14:22:03 -07:00
Jake VanderPlas
a2d62e2d3a
[array_api] update array_api_version to 2024.12
2025-04-02 13:46:07 -07:00
Jake VanderPlas
7f4e8c56fe
jnp.concat and friends: support __jax_array__
2025-04-02 13:17:59 -07:00
Dan Foreman-Mackey
a442fecca8
Fix custom_transpose when composed with custom_jvp and use_direct_linearize=True.
2025-04-02 15:32:15 -04:00
Gunhyun Park
92f7aeab48
Add simple vmap support for lax.ragged_all_to_all.
...
PiperOrigin-RevId: 743230485
2025-04-02 12:10:34 -07:00
Yash Katariya
3d70fc8197
Add pbroadcast insertion for `psum_p` in the traceable. This effectively replaces `psum_p` with `psum2_p` if `varying_axes_in_types` is on. psum_p can be replaced with psum2_p in follow up CLs
...
Also populate the aval of `ShardMapTracer` with `vma`
PiperOrigin-RevId: 743188081
2025-04-02 10:21:21 -07:00
jax authors
056c976ecb
Merge pull request #27660 from froystig:xla-meta-ctx
...
PiperOrigin-RevId: 743178649
2025-04-02 09:59:33 -07:00
Jake VanderPlas
3aeabaedea
jnp.isinf & friends: support __jax_array__
2025-04-02 08:38:54 -07:00
Yash Katariya
2e16367991
Remove the extra stack frame that was introduced in uniform due to dropping the entire function in auto axes.
...
PiperOrigin-RevId: 743148311
2025-04-02 08:30:27 -07:00
Dan Foreman-Mackey
c18139ba7b
Remove legacy GPU kernels for QR decomposition.
...
Following the compatibility timeline described here: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility
On Apr 2, it will have been 6 months since the release of 0.4.34, which is the relevant release for these kernels.
PiperOrigin-RevId: 743142261
2025-04-02 08:10:11 -07:00
Dan Foreman-Mackey
6242ffb1ca
Remove unused Attrs from `lu_pivots_to_permutation` FFI kernel.
...
It has been more than 6 months since the release of 0.4.32 which was the first release to stop including `permutation_size` as an attribute when lowering, so it is now safe (via our compatibility policy) to remove this argument.
PiperOrigin-RevId: 743132169
2025-04-02 07:40:57 -07:00
Yash Katariya
10b2cda90e
Relax the aval check in `select_hlo_lowering_opaque` to only check for shardings if they are not empty. The same thing happens in select_p's sharding rule
...
PiperOrigin-RevId: 743105350
2025-04-02 06:10:33 -07:00
Sergei Lebedev
45d577d3dc
Prepare for disallowing jnp.array(None)
...
PiperOrigin-RevId: 743074472
2025-04-02 04:17:36 -07:00
jax authors
555698103d
Merge pull request #27628 from mattjj:random-gamma-grad-no-more-primitive
...
PiperOrigin-RevId: 743059662
2025-04-02 03:24:59 -07:00
jax authors
2e8ea62b7d
Merge pull request #27463 from gnecula:debug_info_fix_kwargs
...
PiperOrigin-RevId: 743054832
2025-04-02 03:07:00 -07:00
jax authors
398e8b0d79
Merge pull request #27644 from jakevdp:cumulative-jax-array
...
PiperOrigin-RevId: 743046047
2025-04-02 02:37:13 -07:00
George Necula
076d021057
[better_errors] Fix the handling of kwargs for debug_info.
...
kwargs are passed sorted by the actual kwarg keyword. This order
must be accounted for when we construct the `debug_info.arg_names`.
Extended the tests to be more precise about not mixing up kwargs,
e.g., use different shapes and look for the shape in the HLO.
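The ordering constraint can be sketched in plain Python. This is an illustration only, assuming flattened arguments are positional values followed by keyword values in sorted-key order; both helper names are hypothetical and are not the actual `debug_info` code.

```python
def flatten_call(args, kwargs):
    # Hypothetical flattening: positional values first, then keyword values
    # ordered by sorted keyword name rather than by call-site order.
    return list(args) + [kwargs[name] for name in sorted(kwargs)]


def arg_names_for_call(args, kwargs):
    # Names must follow the same sorted order, otherwise a name would be
    # paired with the wrong flattened value.
    positional = [f"args[{i}]" for i in range(len(args))]
    keyword = [f"kwargs['{name}']" for name in sorted(kwargs)]
    return positional + keyword
```

For a call with `args=(1,)` and `kwargs={"z": 2, "a": 3}`, the flattened values are `[1, 3, 2]`, so the names must read `args[0]`, `kwargs['a']`, `kwargs['z']` in that order.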
2025-04-02 10:32:38 +01:00
Roy Frostig
1875c76bd2
let XLA metadata be unset in nested dynamic scopes
...
Treat `None` metadata values as a special instruction not to set (or
to unset, if nested) the corresponding entry.
In particular, this makes it possible to unset metadata within the
sub-computations of higher-order operations (e.g. branches in
conditionals, loop bodies, etc.). This can be used, for example, to
annotate a conditional but not all the operations in its
branches. That is, the HLO for the following function `f` on a scalar
float argument:
```
def cos(x):
  with set_xla_metadata(a=None):
    return jnp.cos(x)

@jax.jit
def f(x):
  with set_xla_metadata(a="b"):
    return jax.lax.cond(x < 0., jnp.sin, cos, x)
```
produces an attribute `a` on the conditional and on the sine, but not
on the cosine.
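The nesting semantics described above can be modeled with a small pure-Python context manager. This is a sketch under stated assumptions, not the implementation of `set_xla_metadata`: the names `set_metadata` and `current_metadata` and the module-level dictionary are hypothetical.

```python
import contextlib

# Hypothetical module-level state: the metadata visible in the current scope.
_current = {}


@contextlib.contextmanager
def set_metadata(**updates):
    """Enter a scope where non-None values are set and None values are unset."""
    global _current
    previous = _current
    merged = dict(previous)
    for key, value in updates.items():
        if value is None:
            # None means "do not set, or unset if an outer scope set it".
            merged.pop(key, None)
        else:
            merged[key] = value
    _current = merged
    try:
        yield
    finally:
        # Restore the enclosing scope's metadata on exit.
        _current = previous


def current_metadata():
    # Return a copy so callers cannot mutate the active scope.
    return dict(_current)
```

Inside `set_metadata(a="b")` the key `a` is visible; a nested `set_metadata(a=None)` hides it for the inner scope only, and it reappears once the inner scope exits.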
2025-04-01 20:25:19 -07:00
Ayaka
f139192201
Add OOB checks to jax.numpy array indexing
...
PiperOrigin-RevId: 742927160
2025-04-01 19:17:57 -07:00