9119 Commits

Author SHA1 Message Date
Matt Bahr
81abbac536 add pascal matrix 2025-03-26 02:11:03 +00:00
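For readers unfamiliar with the term, here is a minimal sketch of a Pascal matrix in plain Python. It is not the code from this commit: the function name `pascal` and the `kind` parameter are illustrative assumptions. The symmetric form has entries C(i + j, i) and the lower-triangular form has entries C(i, j).

```python
from math import comb


def pascal(n, kind="symmetric"):
    """Return an n x n Pascal matrix as a list of lists of ints.

    Hypothetical helper for illustration only; the name and the `kind`
    parameter are assumptions, not the API added by the commit.
    """
    if kind == "symmetric":
        # Entry (i, j) is the binomial coefficient C(i + j, i).
        return [[comb(i + j, i) for j in range(n)] for i in range(n)]
    if kind == "lower":
        # Entry (i, j) is C(i, j); math.comb returns 0 when j > i.
        return [[comb(i, j) for j in range(n)] for i in range(n)]
    raise ValueError(f"unknown kind: {kind!r}")


sym4 = pascal(4)
low3 = pascal(3, kind="lower")
```

Each row of the lower-triangular form is a row of Pascal's triangle padded with zeros.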
Daniel Suo
3a593219d4 [jaxlib:cpu] Cleaning up after callback FFI refactor.
PiperOrigin-RevId: 740547947
2025-03-25 17:41:53 -07:00
Yash Katariya
cc51412019 [sharding_in_types] Add out_sharding to jax.random.normal.
Drop into `Auto` mode internally for the implementation.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 740538785
2025-03-25 17:12:39 -07:00
Yash Katariya
f1a9241187 Add standard_insert_broadcasts to all traceables in lax.py and checks in abstract_eval rules of those primitives.
PiperOrigin-RevId: 740536031
2025-03-25 17:03:18 -07:00
Yash Katariya
087a38988c [sharding_in_types] Add out_sharding to jax.random.uniform.
Drop into `Auto` mode internally for the implementation.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 740529498
2025-03-25 16:42:19 -07:00
Yash Katariya
289fa625e5 [sharding_in_types] Add fold_in support
PiperOrigin-RevId: 740505750
2025-03-25 15:29:32 -07:00
Yash Katariya
ed75189c92 [sharding_in_types] Add support for rng_bit_generator
PiperOrigin-RevId: 740492876
2025-03-25 14:48:27 -07:00
Jacob Burnim
ec06156655 [Pallas] A few fixes for TPU interpret mode:
- Actually de-allocate buffers after a pl.run_scoped.
- Periodically run an explicit garbage collection after
  de-allocating buffers.
- Add no-op implementations for a few internal/testing mosaic primitives
  (prng_seed_p, prng_random_bits_p, assume_p, random_p).
2025-03-25 14:47:33 -07:00
jax authors
b3f63dad9d Merge pull request #27394 from jakevdp:likers-jax-array
PiperOrigin-RevId: 740487903
2025-03-25 14:34:42 -07:00
Sergei Lebedev
0a53c9aad2 [pallas:mosaic_gpu] Updated the tests to use plgpu.kernel
It leads to much more compact kernel definitions, just look at the diff!
The combination of `pl.core_map` and `pl.run_state` is too noisy to easily
follow the kernel logic.

PiperOrigin-RevId: 740479934
2025-03-25 14:11:15 -07:00
Jake VanderPlas
85150471e2 Support __jax_array__ in jnp.full_like & co 2025-03-25 13:45:54 -07:00
Dimitar (Mitko) Asenov
8c44b277be [Mosaic GPU] Add warpgroup lowering for BarrierArrive in Pallas.
PiperOrigin-RevId: 740466565
2025-03-25 13:34:10 -07:00
jax authors
6144b371f7 Merge pull request #27355 from jlperla:identity
PiperOrigin-RevId: 740422195
2025-03-25 11:29:55 -07:00
jax authors
650ced5717 Merge pull request #26673 from nvcastet:split_distributed_gpu_pallas_1
PiperOrigin-RevId: 740417766
2025-03-25 11:17:36 -07:00
jax authors
664598f359 Merge pull request #27313 from jburnim:jburnim_pallas_interpret_mode6
PiperOrigin-RevId: 740381115
2025-03-25 09:43:16 -07:00
Sergei Lebedev
a9266a1521 [pallas:mosaic_gpu] PallasCallTest now runs all tests under both Lane and WG thread semantics
PiperOrigin-RevId: 740371195
2025-03-25 09:10:43 -07:00
Nicolas Castet
8260ab3291 Address review comments 2025-03-25 10:20:11 -05:00
Tori Baker
a7d46e6acc Integrate Triton up to [cdb53266](cdb53266e6)
PiperOrigin-RevId: 740345806
2025-03-25 07:49:41 -07:00
Daniel Suo
411450b8b8 Fix Jax XLA FFI callback handlers for OSS GPU.
OSS Jax builds for GPU backends split `jaxlib` into three wheels. Since we cannot expect a stable C++ ABI among the shared libraries, we refactor to ensure:

1. C++ objects are not created/consumed by different shared libraries.
2. Static objects are declared and defined appropriately.

This PR:

1. Migrates Jax XLA FFI callback handlers from XLA's Internal FFI API to the [External FFI API](https://github.com/openxla/xla/tree/main/xla/ffi#xla-ffi-external-vs-internal-apis). Note that we update both CPU and GPU handlers because we cannot mix Internal and External APIs.
2. Updates how FFI GPU handlers are registered, now analogous to how the original GPU custom call was registered.
3. Adds an `xla::ffi::ExecutionContext` member to `ifrt::PjRtLoadedExecutable` holding opaque pointers to callbacks.
4. Updates Jax `callback.py` to call the new FFI callback handlers.

PiperOrigin-RevId: 740327296
2025-03-25 06:42:05 -07:00
Dimitar (Mitko) Asenov
ad7550de6d [Mosaic GPU] Add warpgroup lowering for SetMaxRegisters in Pallas.
PiperOrigin-RevId: 740318556
2025-03-25 06:09:06 -07:00
jax authors
e1f7fc9d6e Merge pull request #27398 from mattjj:cristian-attrs
PiperOrigin-RevId: 740312383
2025-03-25 05:51:04 -07:00
Dimitar (Mitko) Asenov
ca30ce6919 [Mosaic GPU] Add warpgroup lowering for AxisIndex in Pallas.
PiperOrigin-RevId: 740280136
2025-03-25 03:41:14 -07:00
Matthew Johnson
b4922df220 [attrs] allow setattr on a previously nonexistent attr
Before this change, we handled attrs for initial-style primitives like jit/scan
like this:
1. the traceable would form a jaxpr and see what attrs were touched (by
   jax_getattr or jax_setattr),
2. for each such attr, the traceable would do jax_getattr to get the current
   value, tree-flatten, pass the flat values into the (pure) bind, get the new
   values out, tree-unflatten, then jax_setattr the result.

That approach would raise an error if the function called `jax_setattr` to set a
previously nonexistent attr. That is, this would work:

```python
from jax.experimental.attrs import jax_setattr
class Thing: ...
thing = Thing()
jax_setattr(thing, 'x', 1.0)
```
but it wouldn't work under a `jax.jit`.

This commit makes the same code work under a jit. We just
1. in partial_eval.py's `to_jaxpr`, ensure attrs added during jaxpr formation
   are deleted, using a special sentinel value `dne_sentinel` to indicate the
   attribute initially did not exist before tracing;
2. in pjit.py's `_get_states`, when reading initial attr values before the
   pjit_p bind, if the attribute does not exist we don't try to read it and
   instead just use `dne_sentinel` as the value, which is a convenient empty
   pytree;
3. in pjit.py's `_attr_token` for jit caching, when forming the cache key based
   on the current attr states, we map attrs that don't exist to `dne_sentinel`
   (rather than just erroring when the attr doesn't exist, as before).

In short, we use a special value to indicate "does not exist".

If `jax_getattr` supported the 'default' argument, the code would be a little
cleaner since we could avoid the `if hasattr` stuff. And that's probably a
useful feature to have anyway. We can add that in a follow-up.

This PR only makes setattr-to-nonexistent-attr work with jit. We'll add scan
etc in follow-ups.
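
The "does not exist" sentinel idea described above can be sketched in plain Python without JAX. This is only an illustration under assumptions: `read_attr` and `restore_attr` are hypothetical helpers, and the `dne_sentinel` object here is a stand-in, not the one defined in the JAX source.

```python
class _DoesNotExist:
    """Hypothetical stand-in for the commit's `dne_sentinel`."""

    def __repr__(self):
        return "dne_sentinel"


dne_sentinel = _DoesNotExist()


def read_attr(obj, name):
    # Map a missing attribute to the sentinel instead of raising.
    return getattr(obj, name) if hasattr(obj, name) else dne_sentinel


def restore_attr(obj, name, value):
    # Undo a write: delete the attribute if it did not exist before,
    # otherwise put the earlier value back.
    if value is dne_sentinel:
        if hasattr(obj, name):
            delattr(obj, name)
    else:
        setattr(obj, name, value)


class Thing: ...


thing = Thing()
before = read_attr(thing, "x")    # "x" does not exist yet
thing.x = 1.0                     # e.g. a write made while tracing
restore_attr(thing, "x", before)  # roll back to the initial state
```

Because the sentinel is compared by identity, it cannot collide with a legitimate attribute value such as `None`.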
2025-03-25 03:17:11 +00:00
Yash Katariya
c1904dc7eb Update the mesh docstring to use computation-follows-data and `jax.jit` APIs. Fixes https://github.com/jax-ml/jax/issues/27390
PiperOrigin-RevId: 740104692
2025-03-24 16:07:12 -07:00
Gleb Pobudzey
777d8f2740 [Mosaic GPU] Adding pallas bindings to broadcast over the leading dimension and load a ref into WGMMAColFragLayout format.
PiperOrigin-RevId: 740068368
2025-03-24 14:17:12 -07:00
jax authors
89c7403d61 Merge pull request #27376 from jakevdp:jtu-type-annotations
PiperOrigin-RevId: 740051013
2025-03-24 13:30:33 -07:00
jax authors
24d76ee1a4 Merge pull request #27363 from hawkinsp:pp
PiperOrigin-RevId: 740050893
2025-03-24 13:28:37 -07:00
Jake VanderPlas
7e235e3aee jax.test_util: improve type annotations 2025-03-24 12:43:28 -07:00
Peter Hawkins
13862ec10b Small cleanup to pretty-printer.
Kidger's reimplementation of this code notes that the break mode and indent are unused in the _fits function (851379b8f5/wadler_lindig/_wadler_lindig.py (L166)).
We can make the same optimization here.
2025-03-24 15:29:29 -04:00
Sergei Lebedev
92f231e875 Delay the unflattening in jnp.array
Reverts 53e8eac7134a13c1d28de673e7e3a23b4a837aed

PiperOrigin-RevId: 740012608
2025-03-24 11:32:23 -07:00
Gleb Pobudzey
a2f22cc1de [Mosaic GPU] Adding a primitive to load from memrefs *with* a specified layout.
PiperOrigin-RevId: 739995908
2025-03-24 10:47:25 -07:00
Chris Jones
198d7bb9c2 [pallas] Add support for split into any power-of-two equal parts in Triton lowering.
PiperOrigin-RevId: 739968019
2025-03-24 09:30:37 -07:00
Jacob Burnim
b6b5d95239 [Pallas] In TPU interpret mode, add initial barrier for kernels without one. 2025-03-24 08:38:02 -07:00
jax authors
43ae8d70be Merge pull request #27337 from jburnim:jburnim_interpret_mode7
PiperOrigin-RevId: 739946581
2025-03-24 08:23:55 -07:00
Chris Jones
a2475a66c5 [pallas] Add support for split (into two equal parts) in Triton lowering.
PiperOrigin-RevId: 739855323
2025-03-24 02:06:59 -07:00
Chris Jones
5b0a767d83 [jax] Add ndim and size properties to TransformedRef.
Without these implementations, `ndim` and `size` were retrieved from the underlying, non-transformed reference and were inconsistent with `TransformedRef.shape`.
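
As a rough illustration of the fix, `ndim` and `size` can be derived from the (transformed) shape so that all three stay consistent. `ShapedView` below is a hypothetical class for this sketch, not the real `TransformedRef`.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class ShapedView:
    """Hypothetical stand-in for a ref whose shape reflects its transforms."""

    shape: tuple[int, ...]

    @property
    def ndim(self) -> int:
        # Number of dimensions follows the transformed shape.
        return len(self.shape)

    @property
    def size(self) -> int:
        # Total element count; math.prod of an empty shape is 1 (a scalar).
        return math.prod(self.shape)


view = ShapedView(shape=(2, 3, 4))
scalar = ShapedView(shape=())
```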

PiperOrigin-RevId: 739802491
2025-03-23 21:33:51 -07:00
Jesse Perla
5d79df7e67 Add identity activation
Fix typo
2025-03-23 15:13:17 -07:00
Matthew Johnson
a092df90ba fix a linearize-of-remat-of-while_loop-fixpoint bug
We were using the original unknown-carries-in rather than the fixpoint-updated ones.
2025-03-23 03:50:55 +00:00
Chris Jones
74977938d8 [pallas] Add support for DotAlgorithmPreset.BF16_BF16_F32_X{6,9} in Triton lowering.
PiperOrigin-RevId: 739400359
2025-03-21 22:33:40 -07:00
Chris Jones
396e389001 [pallas] Add _zeros[_like] and _ones[_like] utility functions in Triton lowering.
PiperOrigin-RevId: 739395754
2025-03-21 22:03:32 -07:00
Peter Hawkins
55e408471c [JAX] [XLA:Python] Migrate xla_extension and its type stubs into jaxlib.
Future changes will migrate many of its dependent modules.

PiperOrigin-RevId: 739361786
2025-03-21 18:52:54 -07:00
Praveen Narayanan
2692c5ff98 Lower lax.ragged_dot_general to chlo.ragged_dot in some cases on tpu.
PiperOrigin-RevId: 739348011
2025-03-21 17:36:32 -07:00
Nicolas Castet
6b7744581b [Pallas] [1/3] Move communication primitives from mosaic to core 2025-03-21 16:34:52 -05:00
Peter Hawkins
e71bcde543 Remove some long-stale version guards.
PiperOrigin-RevId: 739279729
2025-03-21 13:23:26 -07:00
Brian Zhao
53e8eac713 Reverts be5713309521d5cf0d2252b9c8f1d38ab50952d1
PiperOrigin-RevId: 739258607
2025-03-21 12:12:45 -07:00
Krishna Haridasan
e23069b39c Allow forcing pallas forward compatibility for some backends
PiperOrigin-RevId: 739249745
2025-03-21 11:42:55 -07:00
Ayaka
7dd78d97fa Add support for configurable error checking categories
PiperOrigin-RevId: 739234594
2025-03-21 10:53:34 -07:00
Jacob Burnim
37b5066d5b [Pallas] Fixes scalar prefetch in TPU interpret mode. 2025-03-21 10:45:57 -07:00
Yash Katariya
3163fbaac4 Add varying manual axes rules to mul_p and convert_element_type_p. There are 2 things that need to be added:
1. At the lax level, before we bind the primitive, we need to insert pbroadcasts if some inputs are varying. This is equivalent to the rewrite rules that shard_map has.

2. In abstract_eval rules of primitives, we need to check if all inputs are varying across the same mesh axes and then add the `varying_manual_axes` to the output ShapedArray.

This in turn requires us to support `pbroadcast2` and `psum2` primitives in shard_map.py. These primitives don't need to insert any pbroadcasts (equivalent to `no_rewrite` in shard_map) but need to do checks and update the output aval in their abstract_eval rules.

* pbroadcast_p: Union the existing aval.varying_manual_axes + axes (passed to pbroadcast) to calculate the output vma. For checks we need to make sure that the intersection of `aval.varying_manual_axes` and `axes` is empty.

* psum2_p: Remove the named axes from aval.varying_manual_axes to calculate the output vma. For checks we need to make sure that the intersection of `aval.varying_manual_axes` and `axes` is NOT empty.
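
The two rules above can be sketched with plain Python sets. This is a simplified model under assumptions: `pbroadcast_vma` and `psum2_vma` are hypothetical names, and axis names are represented as strings rather than via the real `ShapedArray`.

```python
def pbroadcast_vma(vma, axes):
    """Output vma for pbroadcast: union, after checking disjointness."""
    vma, axes = frozenset(vma), frozenset(axes)
    if vma & axes:
        # The input must not already vary over any of the given axes.
        raise ValueError(f"already varying over {sorted(vma & axes)}")
    return vma | axes


def psum2_vma(vma, axes):
    """Output vma for psum2: remove the summed axes, after checking overlap."""
    vma, axes = frozenset(vma), frozenset(axes)
    if not (vma & axes):
        # Mirrors the stated check: the intersection must be non-empty.
        raise ValueError(f"not varying over any of {sorted(axes)}")
    return vma - axes


after_broadcast = pbroadcast_vma({"x"}, {"y"})
after_psum = psum2_vma({"x", "y"}, {"y"})
```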

The majority of primitives should use `standard_insert_pbroadcast` and `standard_vma_rule`; I'll add those to other primitives in follow-up CLs.

PiperOrigin-RevId: 739225392
2025-03-21 10:26:18 -07:00
Ayaka
3bf2eea259 Add AOT support for error checking
PiperOrigin-RevId: 739222389
2025-03-21 10:17:36 -07:00