The helpers `_fragmented_array_to_ir` and `_fragmented_array_from_ir` in
`dialect_lowering.py` have been modified so that a fragmented array's
signedness no longer appears in its IR representation.
This is because signedness is a reflection of how we make use of the value,
not an inherent property of it. The appropriate signedness to use when
reloading a fragmented array from IR must now be provided by the caller.
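To make the new contract concrete, here is a toy stand-in (not the actual Mosaic GPU helpers; the names, fields, and signatures below are illustrative only):

```python
# Toy stand-in (not the real Mosaic GPU code) illustrating the new contract:
# the IR form carries registers and layout only, and signedness must be
# supplied again by whoever reloads the value.
from dataclasses import dataclass

@dataclass
class FakeFragmentedArray:
  registers: tuple
  layout: str
  is_signed: bool | None = None

def fragmented_array_to_ir(arr):
  # Signedness is intentionally dropped here.
  return (arr.registers, arr.layout)

def fragmented_array_from_ir(ir_repr, *, is_signed=None):
  # The caller decides how the reloaded value should be interpreted.
  registers, layout = ir_repr
  return FakeFragmentedArray(registers, layout, is_signed)

arr = FakeFragmentedArray((1, 2, 3), "strided", is_signed=True)
restored = fragmented_array_from_ir(fragmented_array_to_ir(arr), is_signed=True)
```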
PiperOrigin-RevId: 726030853
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.
These are some leftover changes, in particular those needed when
running with `JAX_USE_DIRECT_LINEARIZE=1`.
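For reference, one way to run with that flag; setting the environment variable before importing JAX is assumed to be enough for the experimental config option to pick it up:

```python
# Set the flag before importing jax so the experimental config picks it up.
import os
os.environ["JAX_USE_DIRECT_LINEARIZE"] = "1"

import jax

# Any transformation that goes through linearization exercises the new path.
print(jax.grad(lambda x: x * x)(3.0))  # 6.0
```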
This will allow us to lower Pallas kernels using the Mosaic GPU dialect, and
in turn to perform layout inference and optimization automatically.
The change contains lowering rules for `get` and `swap` (which are necessary
to get a basic example to run), as well as for `add`.
The new lowering path can be used by specifying the `Warpgroup` thread
semantics as part of `pallas_call`'s compiler params.
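A minimal sketch of how this might be selected; the module paths, `GPUCompilerParams`, and the `thread_semantics` parameter name are assumptions about the Pallas / Mosaic GPU API surface, not a definitive example:

```python
# Sketch only: class and parameter names below are assumptions.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

def add_one_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0

x = jnp.zeros((128, 128), jnp.float32)
y = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    compiler_params=plgpu.GPUCompilerParams(
        thread_semantics=plgpu.ThreadSemantics.Warpgroup,  # assumed name/export
    ),
)(x)
```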
PiperOrigin-RevId: 725958027
Previously, `ffi_call` would always return a list for multiple results, but if the input `result_shape_dtypes` is a tuple, we should return a tuple.
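For example, assuming an FFI handler registered under the placeholder name `"my_custom_call"` with two outputs:

```python
# "my_custom_call" is a placeholder target name; this assumes a handler with
# two outputs has already been registered for the current platform.
import jax
import jax.numpy as jnp

x = jnp.ones((4,), jnp.float32)
out_types = (
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    jax.ShapeDtypeStruct(x.shape, x.dtype),
)

# Because result_shape_dtypes is a tuple, the call now returns a tuple
# (previously it would have returned a list).
a, b = jax.ffi.ffi_call("my_custom_call", out_types)(x)
```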
PiperOrigin-RevId: 725834048
The current behavior crashes when trying to convert `NoneType` to an MLIR attribute. This change allows a composite to have optional attributes that are omitted when they are not provided, similar to how default values in MLIR are not shown in the IR.
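A rough sketch of the kind of usage this enables, assuming keyword arguments of the decomposition are lowered as composite attributes (the decorator pattern and attribute handling here are assumptions about the `lax.composite` API, not a definitive example):

```python
# Sketch: assumes keyword arguments of the decomposition are lowered as
# composite attributes, and that a None-valued attribute is now simply
# omitted from the IR instead of crashing during attribute conversion.
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax

@partial(lax.composite, name="my.scale")
def my_scale(x, *, scale=2.0, label=None):
  return x * scale

# label is None, so (with this change) no "label" attribute appears in the
# lowered composite op.
print(jax.jit(my_scale).lower(jnp.arange(4.0)).as_text())
```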
PiperOrigin-RevId: 725786442
Some Pallas kernels shouldn't be CSEd even if they share the same inputs.
For example, in async Pallas scenarios where a kernel starts some DMAs
that are waited on by the kernel's consumer (to perform async copies), we can't
CSE the kernels, or they might wait multiple times on a DMA that happens only once.
PiperOrigin-RevId: 725752913
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.
Here I changed the `custom_jvp_call` primitive to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).
Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
The next step here is to write a specialization pass that takes the kernel loaded above and binds values to it (already done in prototype/scratch).
PiperOrigin-RevId: 725271468
The main aim here is to clean up `lax.linalg` to make it a bit easier to maintain and update with new features (e.g. batch partitioning - coming soon!). In this change, I removed some code duplication by consolidating most of the lowering logic into a helper function and identifying some other common patterns. As part of this, I moved the remaining lowering rules from `jaxlib.lapack` into `lax.linalg`.
PiperOrigin-RevId: 725223882
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.
Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
This follows after #26078, #26313, #26348, adding `debug_info` to more calls to `lu.wrap_init`.
As part of this I have changed the primitive `custom_transpose` to take the `transpose` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`.
These changes ensure that all the `lu.wrap_init` calls and `Jaxpr` constructions in `api_test.py:CustomTransposeTest` carry `debug_info`.
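For context, a minimal sketch of the pattern this series moves toward: wrapping a plain callable as a `lu.WrappedFun` that carries debug info. The internal helpers shown (`api_util.debug_info`, the `debug_info=` keyword of `lu.wrap_init`) and their exact arguments are assumptions about the current internals:

```python
# Sketch of the pattern: instead of passing a bare callable, the transpose
# rule is wrapped as a lu.WrappedFun that carries debug_info. Internal
# signatures may differ from what is shown here.
from jax._src import api_util
from jax._src import linear_util as lu

def my_transpose(residuals, cotangent):
  return (cotangent,)

dbg = api_util.debug_info("custom_transpose", my_transpose,
                          (None, None), {})
wrapped_transpose = lu.wrap_init(my_transpose, debug_info=dbg)
```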
* Compute the size of a mesh eagerly. We're almost always going to need this, because NamedSharding's constructor asks for it.
* Speed up mesh equality. It's likely we have only one mesh, and the identity equality test will hit. Do it first.
* Don't call _prepare_axis_resources in ParsedPartitionSpec construction. This does a bunch of pointless tree flattening and list manipulation, but we know we have exactly one PartitionSpec and can directly do the check we need, which is _check_unique_resources.
* Only call _check_unique_resources on PartitionSpecs; it's easy to avoid doing it in other cases, and then we don't need a bunch of isinstance checks.
* Avoid use of collections.Counter when checking for unique resources. collections.Counter has a surprisingly slow isinstance test (see the sketch below).
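A rough sketch of the kind of rewrite the last bullet describes (not JAX's actual `_check_unique_resources`; names and the error message are illustrative):

```python
# Illustrative only: detect duplicate axis resources with a plain set
# instead of building a collections.Counter for every PartitionSpec.
def check_unique_resources(resources, arg_name="PartitionSpec"):
  seen, dupes = set(), set()
  for r in resources:
    if r is None:
      continue
    if r in seen:
      dupes.add(r)
    seen.add(r)
  if dupes:
    raise ValueError(
        f"A single {arg_name} can't mention an axis more than once: "
        f"{sorted(map(str, dupes))}")
```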
PiperOrigin-RevId: 724431847
This is the first in a series of changes to add a simple API for supporting a set of common sharding and partitioning patterns for FFI calls. The high-level motivation is that custom calls (including FFI calls) are opaque to the SPMD partitioner, and the only ways to customize the partitioning behavior are to (a) explicitly register an `xla::CustomCallPartitioner` with XLA, or (b) use the `jax.experimental.custom_partitioning` APIs. Option (a) isn't generally practical for most use cases where the FFI handler lives in an external binary. Option (b) is flexible and supports all common use cases, but it requires embedding Python callbacks into the HLO, which can lead to issues including cache misses. Furthermore, `custom_partitioning` is overpowered for many use cases, where only (what I will call) "batch partitioning" is needed.
In this case, "batch partitioning" refers to the behavior of many FFI calls where they can be trivially partitioned on some number of (leading) dimensions, with the same call being executed independently on each shard of data. If the data are sharded on non-batch dimensions, partitioning will still re-shard the data to be replicated on the non-batch dimensions. This kind of partitioning logic applies to all the LAPACK/cuSOLVER/etc.-backed linear algebra functions in jaxlib, as well as some external users of `custom_partitioning`.
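To make the term concrete, here is a toy, non-FFI illustration of these semantics, using `jax.vmap` as a stand-in for running the same call independently on each batch element (an analogy only, not the new API):

```python
# Toy analogy only: a per-matrix operation applied independently along the
# leading (batch) dimension. Batch partitioning lets shards along that
# dimension run the same call independently, with no cross-shard communication.
import jax
import jax.numpy as jnp

def per_matrix_op(a):            # stands in for one FFI call on one shard
  return jnp.linalg.inv(a)

a = jnp.stack([jnp.eye(3) * (i + 1) for i in range(8)])  # leading dim = batch
out = jax.vmap(per_matrix_op)(a)  # each batch entry is computed independently
```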
The approach I'm taking here is to add a new registration function to the XLA client, which lets a user label their FFI call as batch partitionable. Then, when lowering the custom call, the user passes the number of batch dimensions as a frontend attribute, which is then interpreted by the SPMD partitioner.
In parallel with this change, shardy has added support for sharding propagation across custom calls using a string representation that is similar in spirit to this approach, but somewhat more general. However, the shardy implementation still requires a Python callback for the partitioning step, so it doesn't (yet!) solve all of the relevant problems with the `custom_partitioning` approach. Ultimately, it should be possible to have the partitioner parse the shardy sharding rule representation, but I wanted to start with the minimal implementation.
PiperOrigin-RevId: 724367877