tttc3
b1b56ea0b0 Enable pivoted QR on GPU via MAGMA.
Originally noted in #20282, this commit provides a GPU-compatible
implementation of `geqp3` via MAGMA.
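
A minimal usage sketch, assuming the pivoted factorization is exposed through `jax.scipy.linalg.qr(..., pivoting=True)` (the interface referenced by #20282):

```
import numpy as np
import jax.numpy as jnp
from jax.scipy.linalg import qr

a = jnp.asarray(np.random.randn(8, 6), dtype=jnp.float32)
# With this change, the pivoted factorization below can run on GPU via
# MAGMA's geqp3 instead of only on CPU.
q, r, p = qr(a, pivoting=True)
```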
2025-02-12 16:12:42 +00:00
jax authors
e14466a8fb Merge pull request #26447 from jakevdp:refactor-contractions
PiperOrigin-RevId: 726043463
2025-02-12 07:14:30 -08:00
Benjamin Chetioui
c7199fe8a5 [Pallas/Mosaic GPU] Enable progressive lowering for integer addition.
The helpers `_fragmented_array_to_ir` and `_fragmented_array_from_ir` in
`dialect_lowering.py` have been modified so that a fragmented array's
signedness no longer appears in its IR representation.

This is because signedness is a reflection of how we make use of the value,
and not an inherent property of it. The appropriate signedness value to use
to reload a fragmented array from IR must be provided by the caller.

PiperOrigin-RevId: 726030853
2025-02-12 06:29:25 -08:00
jax authors
1e2a5770c9 Merge pull request #26455 from gnecula:debug_info_jaxpr_8
PiperOrigin-RevId: 726023315
2025-02-12 06:03:32 -08:00
George Necula
faa0ad6f33 [better_errors] Continue adding debug info to Jaxprs (step 8)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

These are some leftover changes, in particular those needed when
running with `JAX_USE_DIRECT_LINEARIZE=1`.
2025-02-12 14:23:52 +01:00
Benjamin Chetioui
5ad89006c3 [Pallas/Mosaic GPU] Add initial support for warpgroup semantics in lowering.
This will allow us to lower Pallas kernels using the Mosaic GPU dialect, and
in turn to perform layout inference and optimization automatically.

The change contains lowering rules for `get` and `swap` (which are necessary
to get a basic example to run), as well as for `add`.

The new lowering path can be used by specifying the `Warpgroup` thread
semantics as part of `pallas_call`'s compiler params.
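
A hedged sketch of selecting the new path; the `GPUCompilerParams` and `ThreadSemantics.Warpgroup` spellings are assumptions based on this description, not confirmed API:

```
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

def add_kernel(x_ref, y_ref, o_ref):
    # Exercises the new `get`/`swap` lowering rules (loads/stores) and `add`.
    o_ref[...] = x_ref[...] + y_ref[...]

x = jnp.ones((128,), jnp.float32)
y = jnp.ones((128,), jnp.float32)
out = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    # Assumed spelling of the compiler param selecting warpgroup semantics.
    compiler_params=plgpu.GPUCompilerParams(
        thread_semantics=plgpu.ThreadSemantics.Warpgroup),
)(x, y)
```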

PiperOrigin-RevId: 725958027
2025-02-12 01:47:49 -08:00
Roy Frostig
8720a9c0cd docstrings and API reference doc listing for the traced AOT stage 2025-02-11 22:30:50 -08:00
Yash Katariya
675be0121b Add a custom __reduce__ for UnconstrainedSingleton because it can be pickled and then loaded back, and we need the id of P.UNCONSTRAINED to match before and after loading.
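
For context, a generic sketch (not JAX's actual internals) of how a custom `__reduce__` keeps a singleton's identity stable across pickling:

```
import pickle

class _UnconstrainedSingleton:
    def __reduce__(self):
        # Pickle as "call _get_unconstrained() at load time", so unpickling
        # returns the existing module-level instance rather than a new object.
        return (_get_unconstrained, ())

def _get_unconstrained():
    return UNCONSTRAINED

UNCONSTRAINED = _UnconstrainedSingleton()

assert pickle.loads(pickle.dumps(UNCONSTRAINED)) is UNCONSTRAINED
```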
PiperOrigin-RevId: 725874879
2025-02-11 20:05:48 -08:00
Dan Foreman-Mackey
bba09137dc Match output container to result_shape_dtypes in ffi_call.
Previously, ffi_call would always return a list for multiple results, but if the input `result_shape_dtypes` is a tuple, we should return a tuple.
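
A short sketch of the resulting behavior, using a hypothetical FFI target name:

```
import jax
import jax.numpy as jnp

# Hypothetical target; assumes a handler named "my_op" has been registered.
out_types = (
    jax.ShapeDtypeStruct((128,), jnp.float32),
    jax.ShapeDtypeStruct((128,), jnp.float32),
)
call = jax.ffi.ffi_call("my_op", out_types)
# Because `out_types` is a tuple, the call now returns a tuple; passing a list
# of ShapeDtypeStructs would return a list instead.
# a, b = call(x)
```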

PiperOrigin-RevId: 725834048
2025-02-11 17:33:32 -08:00
Jake VanderPlas
e389b707ba Add public APIs for jax.lax monoidal reductions 2025-02-11 16:00:03 -08:00
jax authors
d0b6c677b0 Merge pull request #26470 from jakevdp:lax-docs
PiperOrigin-RevId: 725804083
2025-02-11 15:58:52 -08:00
Gunhyun Park
6b19bb2091 Allow composites to provide default kwargs with None value
The current behavior crashes when trying to convert NoneType to an MLIR attribute. This change allows a composite to have optional attributes that are omitted when not provided, similar to how default values in MLIR are not shown in the IR.
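
A hedged sketch, assuming the `jax.lax.composite` decorator and an illustrative `scale` attribute:

```
from functools import partial
import jax.numpy as jnp
from jax import lax

@partial(lax.composite, name="my.scaled_tanh")
def scaled_tanh(x, *, scale=None):
    # `scale` becomes a composite attribute; when it is left as None it is now
    # simply omitted instead of crashing the MLIR attribute conversion.
    return lax.tanh(x) if scale is None else lax.tanh(x) * scale

y = scaled_tanh(jnp.ones(4))             # attribute omitted
z = scaled_tanh(jnp.ones(4), scale=2.0)  # attribute emitted
```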

PiperOrigin-RevId: 725786442
2025-02-11 15:05:50 -08:00
Jake VanderPlas
e488956092 jax.lax: improve docs for real, imag, complex, conj, and abs. 2025-02-11 14:12:22 -08:00
Marcello Maggioni
6c6b5ec582 [JAX/Pallas] Add has_side_effect parameter to CompilerParams to stop CSE of operations.
Some Pallas kernels shouldn't be CSEd even if they share the same inputs.
For example, in async Pallas scenarios where a kernel starts some DMAs that are
awaited by the kernel's user (to perform async copies), we can't CSE, or kernels
might wait multiple times on a DMA that happens only once.
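
A hedged sketch of opting a kernel out of CSE; the `pltpu.TPUCompilerParams(has_side_effects=True)` spelling is an assumption based on this description:

```
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def copy_kernel(x_ref, o_ref):
    # Stand-in for a kernel that, e.g., only starts DMAs awaited elsewhere.
    o_ref[...] = x_ref[...]

x = jnp.ones((8, 128), jnp.float32)
out = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    # Marking the kernel as side-effecting keeps XLA from CSE-ing two
    # otherwise identical calls.
    compiler_params=pltpu.TPUCompilerParams(has_side_effects=True),
)(x)
```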

PiperOrigin-RevId: 725752913
2025-02-11 13:33:01 -08:00
Gunhyun Park
7994aa82f8 Delete unused code in _dot_batch_rule
PiperOrigin-RevId: 725725676
2025-02-11 12:16:01 -08:00
Yash Katariya
005c14b4da [sharding_in_types] Error out if the sharding's specs passed to with_sharding_constraint don't refer to Auto axes.
PiperOrigin-RevId: 725679220
2025-02-11 10:16:52 -08:00
jax authors
2c165bffc9 [pallas:triton] Lift dot_general restriction on minimal tile size for the `a` operand.
PiperOrigin-RevId: 725605869
2025-02-11 06:27:16 -08:00
Dan Foreman-Mackey
c502332ed5 Add "sequential_unrolled" vmap method for callbacks.
Like the `sequential` method, this loops over calls to the callback, but in this case, the loop is unrolled.
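
For example, a sketch using `jax.pure_callback`'s `vmap_method` argument with the new method name:

```
import jax
import jax.numpy as jnp
import numpy as np

def host_log1p(x):
    return np.log1p(x)

def f(x):
    return jax.pure_callback(
        host_log1p,
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        x,
        # Unrolled variant of the "sequential" vmap method.
        vmap_method="sequential_unrolled",
    )

y = jax.vmap(f)(jnp.arange(4.0))
```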

PiperOrigin-RevId: 725601366
2025-02-11 06:09:16 -08:00
George Necula
550d1aa187 [better_errors] Continue adding debug info to Jaxprs (step 6)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
2025-02-11 11:28:58 +01:00
Zixuan Jiang
4b1400dbb9 #jax Optimize jax.numpy.take_along_axis along a dimension that satisfies
* the dimension is not the one along which to take values
* the dimension size of the input tensor is 1
* the dimension size of the indices is not 1

Previously, we created a constant zero as the dummy index for that dimension, which is redundant. We can instead squeeze the input tensor and generate the `stablehlo.gather` directly.

In the following example,
```
h = jtu.rand_default(self.rng())((2, 1, 5, 7, 13), np.float32)
g = jtu.rand_int(self.rng(), 0, 7)((2, 3, 5, 11, 1), np.uint8)
q0 = jnp.take_along_axis(h, g, axis=-2)
```
It lowers to the following module before this change,
```
module @jit_foo attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x1x5x7x13xf32> loc("x"), %arg1: tensor<2x3x5x11x1xui8> loc("y")) -> (tensor<2x3x5x11x13xf32> {jax.result_info = ""}) {
    %0 = call @take_along_axis(%arg0, %arg1) : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x13xf32> loc(#loc32)
    return %0 : tensor<2x3x5x11x13xf32> loc(#loc)
  } loc(#loc)
  func.func private @take_along_axis(%arg0: tensor<2x1x5x7x13xf32> loc("jit(foo)/jit(main)/pjit"(#loc31)), %arg1: tensor<2x3x5x11x1xui8> loc("jit(foo)/jit(main)/pjit"(#loc31))) -> tensor<2x3x5x11x13xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x1xi32> loc(#loc33)
    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc32)
    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc34)
    %2 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc35)
    %3 = stablehlo.compare  LT, %0, %2,  SIGNED : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi1> loc(#loc35)
    %c_0 = stablehlo.constant dense<7> : tensor<i32> loc(#loc32)
    %4 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc36)
    %5 = stablehlo.add %0, %4 : tensor<2x3x5x11x1xi32> loc(#loc36)
    %6 = stablehlo.select %3, %5, %0 : tensor<2x3x5x11x1xi1>, tensor<2x3x5x11x1xi32> loc(#loc37)
    %7 = stablehlo.concatenate %1, %6, dim = 4 : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x2xi32> loc(#loc38)
    %c_1 = stablehlo.constant dense<[0, 6]> : tensor<2xi64> loc(#loc39)
    %8 = stablehlo.convert %7 : (tensor<2x3x5x11x2xi32>) -> tensor<2x3x5x11x2xi64> loc(#loc33)
    %c_2 = stablehlo.constant dense<0> : tensor<i64> loc(#loc39)
    %9 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i64>) -> tensor<2x3x5x11x2xi64> loc(#loc40)
    %10 = stablehlo.compare  GE, %8, %9,  SIGNED : (tensor<2x3x5x11x2xi64>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x2xi1> loc(#loc40)
    %11 = stablehlo.broadcast_in_dim %c_1, dims = [4] : (tensor<2xi64>) -> tensor<1x1x1x1x2xi64> loc(#loc34)
    %12 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2, 3, 4] : (tensor<1x1x1x1x2xi64>) -> tensor<2x3x5x11x2xi64> loc(#loc41)
    %13 = stablehlo.compare  LE, %8, %12,  SIGNED : (tensor<2x3x5x11x2xi64>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x2xi1> loc(#loc41)
    %14 = stablehlo.and %10, %13 : tensor<2x3x5x11x2xi1> loc(#loc42)
    %c_3 = stablehlo.constant dense<true> : tensor<i1> loc(#loc43)
    %15 = stablehlo.reduce(%14 init: %c_3) applies stablehlo.and across dimensions = [4] : (tensor<2x3x5x11x2xi1>, tensor<i1>) -> tensor<2x3x5x11xi1> loc(#loc43)
    %16 = "stablehlo.gather"(%arg0, %8) <{dimension_numbers = #stablehlo.gather<offset_dims = [4], collapsed_slice_dims = [1, 3], operand_batching_dims = [0, 2], start_indices_batching_dims = [0, 2], start_index_map = [1, 3], index_vector_dim = 4>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 1, 1, 13>}> : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x13xf32> loc(#loc39)
    %17 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2, 3] : (tensor<2x3x5x11xi1>) -> tensor<2x3x5x11x13xi1> loc(#loc34)
    %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc39)
    %18 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x3x5x11x13xf32> loc(#loc34)
    %19 = stablehlo.select %17, %16, %18 : tensor<2x3x5x11x13xi1>, tensor<2x3x5x11x13xf32> loc(#loc37)
    return %19 : tensor<2x3x5x11x13xf32> loc(#loc32)
  }
}
```

With this change, we have
```
module @jit_foo attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x1x5x7x13xf32> loc("x"), %arg1: tensor<2x3x5x11x1xui8> loc("y")) -> (tensor<2x3x5x11x13xf32> {jax.result_info = ""}) {
    %0 = call @take_along_axis(%arg0, %arg1) : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x13xf32> loc(#loc32)
    return %0 : tensor<2x3x5x11x13xf32> loc(#loc)
  } loc(#loc)
  func.func private @take_along_axis(%arg0: tensor<2x1x5x7x13xf32> loc("jit(foo)/jit(main)/pjit"(#loc31)), %arg1: tensor<2x3x5x11x1xui8> loc("jit(foo)/jit(main)/pjit"(#loc31))) -> tensor<2x3x5x11x13xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x1xi32> loc(#loc33)
    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc32)
    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc34)
    %2 = stablehlo.compare  LT, %0, %1,  SIGNED : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi1> loc(#loc34)
    %c_0 = stablehlo.constant dense<7> : tensor<i32> loc(#loc32)
    %3 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc35)
    %4 = stablehlo.add %0, %3 : tensor<2x3x5x11x1xi32> loc(#loc35)
    %5 = stablehlo.select %2, %4, %0 : tensor<2x3x5x11x1xi1>, tensor<2x3x5x11x1xi32> loc(#loc36)
    %6 = stablehlo.reshape %arg0 : (tensor<2x1x5x7x13xf32>) -> tensor<2x5x7x13xf32> loc(#loc37)
    %c_1 = stablehlo.constant dense<6> : tensor<1xi64> loc(#loc38)
    %7 = stablehlo.convert %5 : (tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi64> loc(#loc33)
    %c_2 = stablehlo.constant dense<0> : tensor<i64> loc(#loc38)
    %8 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i64>) -> tensor<2x3x5x11x1xi64> loc(#loc39)
    %9 = stablehlo.compare  GE, %7, %8,  SIGNED : (tensor<2x3x5x11x1xi64>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x1xi1> loc(#loc39)
    %10 = stablehlo.broadcast_in_dim %c_1, dims = [4] : (tensor<1xi64>) -> tensor<1x1x1x1x1xi64> loc(#loc40)
    %11 = stablehlo.broadcast_in_dim %10, dims = [0, 1, 2, 3, 4] : (tensor<1x1x1x1x1xi64>) -> tensor<2x3x5x11x1xi64> loc(#loc41)
    %12 = stablehlo.compare  LE, %7, %11,  SIGNED : (tensor<2x3x5x11x1xi64>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x1xi1> loc(#loc41)
    %13 = stablehlo.and %9, %12 : tensor<2x3x5x11x1xi1> loc(#loc42)
    %c_3 = stablehlo.constant dense<true> : tensor<i1> loc(#loc43)
    %14 = stablehlo.reduce(%13 init: %c_3) applies stablehlo.and across dimensions = [4] : (tensor<2x3x5x11x1xi1>, tensor<i1>) -> tensor<2x3x5x11xi1> loc(#loc43)
    %15 = "stablehlo.gather"(%6, %7) <{dimension_numbers = #stablehlo.gather<offset_dims = [4], collapsed_slice_dims = [2], operand_batching_dims = [0, 1], start_indices_batching_dims = [0, 2], start_index_map = [2], index_vector_dim = 4>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 1, 13>}> : (tensor<2x5x7x13xf32>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x13xf32> loc(#loc38)
    %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1, 2, 3] : (tensor<2x3x5x11xi1>) -> tensor<2x3x5x11x13xi1> loc(#loc40)
    %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc38)
    %17 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x3x5x11x13xf32> loc(#loc40)
    %18 = stablehlo.select %16, %15, %17 : tensor<2x3x5x11x13xi1>, tensor<2x3x5x11x13xf32> loc(#loc36)
    return %18 : tensor<2x3x5x11x13xf32> loc(#loc32)
  }
}
```

PiperOrigin-RevId: 725506779
2025-02-11 00:08:46 -08:00
Jake VanderPlas
e6fc7f3e87 refactor: move lax_numpy tensor contractions into their own file 2025-02-10 18:56:18 -08:00
jax authors
ffd3faad72 [TPU][Mosaic] Fix missing sfences in smem DMAs
PiperOrigin-RevId: 725376627
2025-02-10 15:51:35 -08:00
jax authors
b7d012281e Merge pull request #26423 from gnecula:debug_info_jaxpr_7
PiperOrigin-RevId: 725317552
2025-02-10 12:58:26 -08:00
Kevin Gleason
319f7f5a2d [StableHLO] Allow composites to use dtype values in attributes
PiperOrigin-RevId: 725305384
2025-02-10 12:21:56 -08:00
jax authors
6a638ac832 Merge pull request #26426 from NKlug:nklug-linalg-solve-docu
PiperOrigin-RevId: 725296564
2025-02-10 11:57:48 -08:00
jax authors
6bedabd386 [TPU][Pallas][XLA] Add a BUILD-time codegen tool that turns a Pallas kernel into a parameterized kernel-loader header that can be used anywhere in C++
The next step here is to write a specialization pass that takes the kernel loaded above and binds values to it (already done in prototype/scratch).

PiperOrigin-RevId: 725271468
2025-02-10 10:45:32 -08:00
jax authors
4f979b4496 Merge pull request #26386 from gnecula:debug_info_jaxpr_5
PiperOrigin-RevId: 725228029
2025-02-10 08:41:17 -08:00
Dan Foreman-Mackey
154e4506c0 Some lax.linalg housekeeping.
The main aim here is to clean up lax.linalg to make it a bit easier to maintain and update with new features (e.g. batch partitioning - coming soon!). In this change, I remove some code duplication by consolidating most of the lowering logic into a helper function and identifying some other common patterns. As part of this, I moved the remaining lowering rules from `jaxlib.lapack` into `lax.linalg`.

PiperOrigin-RevId: 725223882
2025-02-10 08:27:18 -08:00
jax authors
1a8d537728 Merge pull request #26384 from gnecula:debug_info_jaxpr_4
PiperOrigin-RevId: 725210049
2025-02-10 07:42:57 -08:00
jax authors
260a879bbf Merge pull request #26411 from jakevdp:jnp-window-functions
PiperOrigin-RevId: 725195238
2025-02-10 06:46:07 -08:00
jax authors
6740165e4f [Pallas] Add pipeline mode to pltpu
PiperOrigin-RevId: 725133131
2025-02-10 02:36:44 -08:00
Peter Hawkins
cf308a84d9 Use PartitionSpec.UNCONSTRAINED to represent unconstrained dimensions in ParsedPartitionSpec, rather than None.
This makes PartitionSpec and ParsedPartitionSpec more similar, and fixes some TODOs.
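
For reference, a usage sketch of the public counterpart (mesh and axis names are illustrative):

```
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(-1), ("x",))

def f(x):
    # First dim sharded over "x"; second dim left unconstrained for the
    # compiler to choose.
    return jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, P("x", P.UNCONSTRAINED)))

y = jax.jit(f)(np.zeros((8, 8), np.float32))
```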

PiperOrigin-RevId: 724927217
2025-02-09 08:45:01 -08:00
George Necula
817b3e5757 [better_errors] Continue adding debug info to Jaxprs (step 7)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
2025-02-09 18:14:33 +02:00
Nikolas Klug
af794721f9 Document behavior of linalg.solve in case the system matrix is singular 2025-02-08 18:09:13 +01:00
George Necula
1e813e1693 [better_errors] Continue adding debug info to Jaxprs (step 4)
This follows after #26078, #26313, #26348, adding `debug_info` to more calls to `lu.wrap_init`.

As part of this I have changed the primitive `custom_transpose` to take the `transpose` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`.

These changes ensure that all the `lu.wrap_init` and `Jaxpr` calls carry debug_info in `api_test.py:CustomTransposeTest`.
2025-02-08 09:13:55 +02:00
jax authors
289035747e Merge pull request #26407 from jakevdp:printoptions-doc
PiperOrigin-RevId: 724487999
2025-02-07 15:22:52 -08:00
jax authors
a7c20e7fa2 Merge pull request #26406 from jakevdp:lax-docs
PiperOrigin-RevId: 724465108
2025-02-07 14:08:02 -08:00
Yash Katariya
21e1be3320 Don't call get_cur_mesh_sharding if sharding-in-types mode is not enabled
PiperOrigin-RevId: 724461150
2025-02-07 13:55:38 -08:00
Peter Hawkins
f21b0f03b4 Speed up NamedSharding construction.
* Compute the size of a mesh eagerly. We're almost always going to need this, because NamedSharding's constructor asks for it (a construction sketch follows this list).
* Speed up mesh equality. It's likely we have only one mesh, and the identity equality test will hit. Do it first.
* Don't call _prepare_axis_resources in ParsedPartitionSpec construction. This does a bunch of pointless tree flattening and list manipulation, but we know we have exactly one PartitionSpec and can directly do the check we need, which is _check_unique_resources.
* Only call _check_unique_resources on PartitionSpecs; it's easy to avoid doing it in other cases, and then we don't need a bunch of isinstance checks.
* Avoid use of collections.Counter when checking for unique resources. collections.Counter has a surprisingly slow isinstance test.
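
A minimal sketch of the construction path these changes speed up (names illustrative):

```
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(-1), ("x",))
# The constructor touches mesh.size (now computed eagerly) and validates the
# PartitionSpec's resources (now a direct uniqueness check).
sharding = NamedSharding(mesh, P("x"))
```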

PiperOrigin-RevId: 724431847
2025-02-07 12:20:51 -08:00
Jake VanderPlas
17215177fa refactor: move lax_numpy window functions into their own file 2025-02-07 11:21:38 -08:00
jax authors
c0ba36260e Merge pull request #26377 from mattjj:maintain-mutable-array-sharding
PiperOrigin-RevId: 724405629
2025-02-07 11:04:56 -08:00
jax authors
ec477634f1 Merge pull request #26376 from jakevdp:array-creation
PiperOrigin-RevId: 724399604
2025-02-07 10:48:05 -08:00
Sergei Lebedev
e5058079c9 [pallas:mosaic_gpu] Fixed a bug in how delay_release is handled in emit_pipeline
PiperOrigin-RevId: 724395676
2025-02-07 10:37:21 -08:00
Matthew Johnson
719031c1fd [mutable-arrays] persist shardings through xla computations 2025-02-07 18:33:24 +00:00
jax authors
3b470b9530 Merge pull request #26383 from jakevdp:jnp-sorting
PiperOrigin-RevId: 724381260
2025-02-07 10:00:29 -08:00
Jake VanderPlas
08563842b9 DOC: make clear that printoptions are NumPy aliases 2025-02-07 09:56:52 -08:00
Jake VanderPlas
311e6683e4 jax.lax: better docs for hyperbolic trig functions 2025-02-07 09:33:25 -08:00
Dan Foreman-Mackey
c521bc6205 [xla:python] Add a mechanism for "batch partitioning" of FFI calls.
This is the first in a series of changes to add a simple API for supporting a set of common sharding and partitioning patterns for FFI calls. The high-level motivation is that custom calls (including FFI calls) are opaque to the SPMD partitioner, and the only ways to customize the partitioning behavior are to (a) explicitly register an `xla::CustomCallPartitioner` with XLA, or (b) use the `jax.experimental.custom_partitioning` APIs. Option (a) isn't generally practical for most use cases, where the FFI handler lives in an external binary. Option (b) is flexible and supports all common use cases, but it requires embedding Python callbacks into the HLO, which can lead to issues including cache misses. Furthermore, `custom_partitioning` is overpowered for many use cases, where only (what I will call) "batch partitioning" is needed.

In this case, "batch partitioning" refers to the behavior of many FFI calls where they can be trivially partitioned on some number of (leading) dimensions, with the same call being executed independently on each shard of data. If the data are sharded on non-batch dimensions, partitioning will still re-shard the data to be replicated on the non-batch dimensions. This kind of partitioning logic applies to all the LAPACK/cuSOLVER/etc.-backed linear algebra functions in jaxlib, as well as some external users of `custom_partitioning`.
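
To illustrate the semantics only (this is not the new registration mechanism), a shard_map sketch of an op that partitions trivially over a leading batch axis:

```
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()).reshape(-1), ("batch",))

def per_shard(a, b):
    # Stand-in for an FFI-backed op: each shard of the leading batch dimension
    # is processed independently, with no cross-shard communication.
    return jnp.linalg.solve(a, b)

batched = shard_map(per_shard, mesh=mesh,
                    in_specs=(P("batch"), P("batch")), out_specs=P("batch"))
```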

The approach I'm taking here is to add a new registration function to the XLA client, which lets a user label their FFI call as batch partitionable. Then, when lowering the custom call, the user passes the number of batch dimensions as a frontend attribute, which is then interpreted by the SPMD partitioner.

In parallel with this change, shardy has added support for sharding propagation across custom calls using a string representation that is similar in spirit to this approach, but somewhat more general. However, the shardy implementation still requires a Python callback for the partitioning step, so it doesn't (yet!) solve all of the relevant problems with the `custom_partitioning` approach. Ultimately, it should be possible to have the partitioner parse the shardy sharding rule representation, but I wanted to start with the minimal implementation.

PiperOrigin-RevId: 724367877
2025-02-07 09:14:06 -08:00
Jake VanderPlas
d3b3cd369f refactor: move sorting ops out of lax_numpy 2025-02-07 08:18:04 -08:00
jax authors
e56c7dc502 Merge pull request #26344 from Cjkkkk:disable_head_256_on_bw
PiperOrigin-RevId: 724333455
2025-02-07 07:08:18 -08:00