This PR is similar to https://github.com/jax-ml/jax/pull/24284
Note that `np.testing.assert_allclose()` is changed to `self.assertAllClose()` because the latter is a wrapper with bfloat16 support.
PiperOrigin-RevId: 688581914
Of course no communication can happen across grid dimensions (unlike over the WG dim),
but we need to be able to launch multiple blocks somehow.
PiperOrigin-RevId: 688488660
It's not enough that we have the physical transpose between the order
of tiled dimensions, we also need the user to explicitly transpose the
logical dimensions. This fixes a shape error that was previously hidden
because the RHS was square.
PiperOrigin-RevId: 687350270
Currently, the error message refers to "last two dimensions" which is confusing for a rank-1 case; furthermore, the error does not match the check in the code.
PiperOrigin-RevId: 686520781
- core_map is like a shard_map but it takes in no inputs and outputs
- we can use it in Pallas to generalize mapping a function over the cores of a chip (e.g. TensorCores in a TPU or SMs in a GPU)
- we specify how the function will be mapped over the device with a `mesh` object. This is also a convenient mechanism for picking the backend for pallas to target
PiperOrigin-RevId: 686036101
The approach here is to add a new notion to jax, for ragged_prop. Ragged prop is useful for computing the dynamism/raggedness of an output, given a set of inputs. In the limit, if we decide that this is a useful property to have in jax as a first class citizen, we could fold the raggedness into the type system. At the moment, however, it is just a small set of rules implemented per op.
PiperOrigin-RevId: 685827096
It was previously possible to pass `swizzle` both directly and via `transforms`.
This change eliminates the ambiguity at a slight downgrade to ergonomics.
PiperOrigin-RevId: 684797980
On TPU, instructions differentiate between vectors and scalars, and the corresponding lowering paths are different. Existing Pallas tests only test vector version of operations, but not the scalar version of them. This PR adds tests for scalar elementwise operations.
The structure of the test is similar to the vector version of the tests above.
PiperOrigin-RevId: 684569107
Fixes https://github.com/jax-ml/jax/issues/23972.
In Pallas, we use `i32` for both `jnp.int32` and `jnp.uint32`, but we need to choose the correct operation (e.g. `arith.extUI` vs `arith.extSI`) or the correct attribute (e.g. `sle` vs `ule` for `arith::CmpIOp`).
In this particular issue, we need to use attributes like `ule` for `jnp.uint32`, but it's currently lowered to attributes for `jnp.int32` such as `sle`.
This PR fixes this issue by distinguishing the attributes to use for signed and unsigned types.
PiperOrigin-RevId: 684065893
This is useful so that we don't have to block on the WGMMA immediately after it runs.
`delay_release=n` means that the input/output buffers will not be mutated by the system
for at least `n` sequential steps following the one when they were kernel arguments.
PiperOrigin-RevId: 683629935
This PR is similar to https://github.com/jax-ml/jax/pull/23814.
Background: We run tests on both 32-bit and 64-bit environments. Currently, when the tests encounters 64-bit dtypes on 32-bit environments, it enters into a local 64-bit environment using `stack.enter_context(config.enable_x64(True))`. This is not necessary since we also run the same tests on 64-bit environments. This PR makes those test skipped on 32-bit environments.
PiperOrigin-RevId: 683405197
I also deprecated `jax.experimental.pallas.gpu` in favor of
`jax.experimental.pallas.triton` to avoid confusion with the Mosaic GPU
backend.
PiperOrigin-RevId: 683119193
A barrier must be indexed via `.at` and not directly. I wish we could emit
an instructive error for the latter case, but I couldn't find a good place
to put it.
PiperOrigin-RevId: 681857034
I decided to
* split `async_copy_p` into multiple primitives to avoid having extra
control flow in the lowering rule;
* drop the `async_*` prefix from `async_copy_p` and the corresponding
APIs, because the names felt a bit too long otherwise.
Note that barriers cannot be sliced at the moment. I will address that in
a follow up CL.
PiperOrigin-RevId: 681793650
Otherwise, we can simply pass it in as an argument, but we can avoid updating it
since it will always remain constant. Both programs have equivalent semantics,
but this one can be optimized better since it makes it more apparent that the
cond does not actually modify a ref.
PiperOrigin-RevId: 681482148
This enables us to avoid spurious copies in the cases outlined in [the async operations design note](https://jax.readthedocs.io/en/latest/pallas/async_note.html) but not in general, since JAX and/or XLA could introduce copies because we have value semantics. For a proper solution, we need to introduce some notion of buffer semantics to XLA/HLO and preserve it through the lowering of stateful JAX (maybe by avoiding state discharge altogether).
PiperOrigin-RevId: 681206784