- core_map is like a shard_map but it takes in no inputs and outputs
- we can use it in Pallas to generalize mapping a function over the cores of a chip (e.g. TensorCores in a TPU or SMs in a GPU)
- we specify how the function will be mapped over the device with a `mesh` object. This is also a convenient mechanism for picking the backend for pallas to target
PiperOrigin-RevId: 686036101
This is because for Shardy, GSPMDSharding doesn't work, so `device_put` on a mesh with different device order needs `NamedSharding` support. Bonus is that the logic is now simplified wrt the previous version in `_different_device_order_reshard`.
This will also allow us to remove OpSharding usage in other projects which require such kind of permutation capabilities.
PiperOrigin-RevId: 685925636
The approach here is to add a new notion to jax, for ragged_prop. Ragged prop is useful for computing the dynamism/raggedness of an output, given a set of inputs. In the limit, if we decide that this is a useful property to have in jax as a first class citizen, we could fold the raggedness into the type system. At the moment, however, it is just a small set of rules implemented per op.
PiperOrigin-RevId: 685827096
Cases where we error
* batch dimensions not having consistent sharding (ignore None)
* contracting dimensions not having consistent sharding (ignore None)
* lhs.mesh != rhs.mesh
* if batch dimension and tensor dimension sharding match -> Error
PiperOrigin-RevId: 684983567
It was previously possible to pass `swizzle` both directly and via `transforms`.
This change eliminates the ambiguity at a slight downgrade to ergonomics.
PiperOrigin-RevId: 684797980
On TPU, instructions differentiate between vectors and scalars, and the corresponding lowering paths are different. Existing Pallas tests only test vector version of operations, but not the scalar version of them. This PR adds tests for scalar elementwise operations.
The structure of the test is similar to the vector version of the tests above.
PiperOrigin-RevId: 684569107
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.
We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.
PiperOrigin-RevId: 684447186
Fixes https://github.com/jax-ml/jax/issues/23972.
In Pallas, we use `i32` for both `jnp.int32` and `jnp.uint32`, but we need to choose the correct operation (e.g. `arith.extUI` vs `arith.extSI`) or the correct attribute (e.g. `sle` vs `ule` for `arith::CmpIOp`).
In this particular issue, we need to use attributes like `ule` for `jnp.uint32`, but it's currently lowered to attributes for `jnp.int32` such as `sle`.
This PR fixes this issue by distinguishing the attributes to use for signed and unsigned types.
PiperOrigin-RevId: 684065893
1) input layout is AUTO and output layout is not AUTO (i.e. default or concrete)
2) input layout is not AUTO (i.e. default or concrete) and output layout is AUTO
This is because there is a conflict in such cases and almost always leads to the wrong layout being chosen by the compiler. For example, let's talk about (1): since input layout is AUTO and output layout is default and since they are aliased, XLA will end up choose default layout for input which is not what you want in majority of the cases.
Erroring is best in such cases and the user can mark the input layout to be default if they want to do that.
The correct choice is to always make both of them AUTO since you want the compiler to choose the best possible layout instead of choosing the input or output layout if the other one is AUTO.
PiperOrigin-RevId: 683688470
A reshape function that does fold/unfold by touching minimal number of
dimensions to potentially circumvent issues with strided memrefs.
PiperOrigin-RevId: 683663541
This is useful so that we don't have to block on the WGMMA immediately after it runs.
`delay_release=n` means that the input/output buffers will not be mutated by the system
for at least `n` sequential steps following the one when they were kernel arguments.
PiperOrigin-RevId: 683629935
The host_callback module has been deprecated since March 2024, and we are now removing the implementation. We keep the functions so that we can give a nicer error message than AttributeError, and because removing those now break internal pytype checking. We will remove those in the near future.
See https://github.com/google/jax/issues/20385.
PiperOrigin-RevId: 683564340
This PR is similar to https://github.com/jax-ml/jax/pull/23814.
Background: We run tests on both 32-bit and 64-bit environments. Currently, when the tests encounters 64-bit dtypes on 32-bit environments, it enters into a local 64-bit environment using `stack.enter_context(config.enable_x64(True))`. This is not necessary since we also run the same tests on 64-bit environments. This PR makes those test skipped on 32-bit environments.
PiperOrigin-RevId: 683405197
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
I also deprecated `jax.experimental.pallas.gpu` in favor of
`jax.experimental.pallas.triton` to avoid confusion with the Mosaic GPU
backend.
PiperOrigin-RevId: 683119193