This type is unused by JAX, so there is no replacement.
(JAX does have an internal PaddingType enum in lax, but it is not present in any APIs, as best I can tell.)
PiperOrigin-RevId: 684451556
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.
We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.
PiperOrigin-RevId: 684447186
* Separate MXU size to MXU contracting size and MXU non-contracting size.
* Rename tile to group for MXU shaped tiling since tile is overused in Mosaic.
PiperOrigin-RevId: 684116306
Fixes https://github.com/jax-ml/jax/issues/23972.
In Pallas, we use `i32` for both `jnp.int32` and `jnp.uint32`, but we need to choose the correct operation (e.g. `arith.extUI` vs `arith.extSI`) or the correct attribute (e.g. `sle` vs `ule` for `arith::CmpIOp`).
In this particular issue, we need to use attributes like `ule` for `jnp.uint32`, but it's currently lowered to attributes for `jnp.int32` such as `sle`.
This PR fixes this issue by distinguishing the attributes to use for signed and unsigned types.
PiperOrigin-RevId: 684065893
Even when the total size of manual axes is 1, and we can skip creating the `ManualComputationOp`, we need to have the body of what was supposed to be the `shard_map` operate under this new context.
PiperOrigin-RevId: 684055903
1) input layout is AUTO and output layout is not AUTO (i.e. default or concrete)
2) input layout is not AUTO (i.e. default or concrete) and output layout is AUTO
This is because there is a conflict in such cases and almost always leads to the wrong layout being chosen by the compiler. For example, let's talk about (1): since input layout is AUTO and output layout is default and since they are aliased, XLA will end up choose default layout for input which is not what you want in majority of the cases.
Erroring is best in such cases and the user can mark the input layout to be default if they want to do that.
The correct choice is to always make both of them AUTO since you want the compiler to choose the best possible layout instead of choosing the input or output layout if the other one is AUTO.
PiperOrigin-RevId: 683688470
There is currently an issue with the Mosaic compiler that prevents emitting code that returns semaphores in the presence of the grid argument.
PiperOrigin-RevId: 683681627
A reshape function that does fold/unfold by touching minimal number of
dimensions to potentially circumvent issues with strided memrefs.
PiperOrigin-RevId: 683663541
This is useful so that we don't have to block on the WGMMA immediately after it runs.
`delay_release=n` means that the input/output buffers will not be mutated by the system
for at least `n` sequential steps following the one when they were kernel arguments.
PiperOrigin-RevId: 683629935
The host_callback module has been deprecated since March 2024, and we are now removing the implementation. We keep the functions so that we can give a nicer error message than AttributeError, and because removing those now break internal pytype checking. We will remove those in the near future.
See https://github.com/google/jax/issues/20385.
PiperOrigin-RevId: 683564340
This PR is similar to https://github.com/jax-ml/jax/pull/23814.
Background: We run tests on both 32-bit and 64-bit environments. Currently, when the tests encounters 64-bit dtypes on 32-bit environments, it enters into a local 64-bit environment using `stack.enter_context(config.enable_x64(True))`. This is not necessary since we also run the same tests on 64-bit environments. This PR makes those test skipped on 32-bit environments.
PiperOrigin-RevId: 683405197