1006 Commits

Author SHA1 Message Date
Ruturaj4
a2d79936df [ROCM] Fix BUILD.bazel library source paths 2024-08-07 09:18:20 -05:00
jax authors
cc9665749f Merge pull request #22901 from ROCm:ci_test_harness_vmap
PiperOrigin-RevId: 660089572
2024-08-06 14:04:57 -07:00
Ruturaj4
707cdd4706 [ROCM] Fix hipsolverSsyevd tests to align with ROCm behavior. 2024-08-06 14:10:09 -05:00
Dan Foreman-Mackey
23da11b609 Re-land FFI port of GPU LU decomposition after fixing XLA FFI memory leak.
PiperOrigin-RevId: 659867028
2024-08-06 02:13:21 -07:00
Paweł Paruzel
b2a469b361 Port Eigenvalue Decompositions to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 659492696
2024-08-05 03:18:13 -07:00
John Ryan
56ff247c2e Reverts 80560663d3fab4c0c3f87d7c8e52fb9931526dbb
PiperOrigin-RevId: 659334027
2024-08-04 12:11:30 -07:00
Adam Paszke
f85b8e677b [Mosaic TPU] Add support for bf16 reductions
PiperOrigin-RevId: 658787017
2024-08-02 07:42:27 -07:00
Adam Paszke
e88887eda5 [Mosaic TPU] Add a missing reshape in relayout
The fact that src generalizes dst does not mean that they have the same implicit
tile shape (if one has an implicit dim and the other one doesn't, then they will
differ by a singleton dimension).
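A toy illustration of the singleton-dimension mismatch described above (a NumPy stand-in with hypothetical shapes, not the Mosaic relayout code):

```python
import numpy as np

# A layout with an implicit (second-minor) dim treats data as if it had an
# extra leading singleton dimension; a layout without one does not. The
# shapes below are chosen only to illustrate the mismatch.
src_tile_shape = (1, 8, 128)   # implicit dim present -> extra singleton
dst_tile_shape = (8, 128)      # no implicit dim

data = np.zeros(src_tile_shape)
# Even though src "generalizes" dst, a reshape is still needed to drop
# (or add) the singleton dimension before the tile shapes agree:
reshaped = data.reshape(dst_tile_shape)
assert reshaped.shape == dst_tile_shape
```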

PiperOrigin-RevId: 658775019
2024-08-02 06:44:31 -07:00
Adam Paszke
959657a489 [Mosaic TPU] Remove special handling of implicit dim in relayout
Now all changes happen inside the dedicated functions.

PiperOrigin-RevId: 658763465
2024-08-02 05:46:26 -07:00
Dan Foreman-Mackey
80560663d3 Enable the FFI implementation of the GPU Getrf handler.
PiperOrigin-RevId: 658755392
2024-08-02 05:07:02 -07:00
Adam Paszke
99625ff577 [Mosaic TPU] Break out implicit dim changes from relayout
PiperOrigin-RevId: 658752228
2024-08-02 04:50:40 -07:00
Paweł Paruzel
6b0b222a38 Activate LU Decomposition to XLA's FFI
PiperOrigin-RevId: 658721697
2024-08-02 02:22:53 -07:00
Dan Foreman-Mackey
8df0c3a9cc Port Getrf GPU kernel from custom call to FFI.
PiperOrigin-RevId: 658550170
2024-08-01 15:02:25 -07:00
Dan Foreman-Mackey
f20efc630f Move jaxlib GPU handlers to separate build target.
In anticipation of refactoring the jaxlib GPU custom calls into FFI calls, this change moves the implementation of `BlasHandlePool`, `SolverHandlePool`, and `SpSolverHandlePool` into a new target.

PiperOrigin-RevId: 658497960
2024-08-01 12:30:04 -07:00
Adam Paszke
0307438c3d [NFC][Mosaic TPU] Separate out retiling from relayout
PiperOrigin-RevId: 658335679
2024-08-01 03:09:15 -07:00
Adam Paszke
0734345279 [NFC][Mosaic TPU] Start breaking up relayout into smaller pieces
We're constantly hitting unimplemented relayouts, but it's hard to even know what's
in there given the way the code is written. This is the first of a few clean-up CLs
that aims to partition the process into steps with clear responsibilities. It should
help us better understand what's missing.

PiperOrigin-RevId: 658318811
2024-08-01 02:02:09 -07:00
jax authors
a911d76982 Rollback due to internal test failure
PiperOrigin-RevId: 658185213
2024-07-31 16:40:03 -07:00
George Necula
65450d165e Remove forward compatibility mode for old PRNG custom call on GPU
The backend support for the new custom call was added on June 28th.
Also add backwards compatibility test for the new custom call.

PiperOrigin-RevId: 658011228
2024-07-31 08:10:17 -07:00
Dan Foreman-Mackey
618754d829 Move some common helper functions from lapack_kernels to ffi_helpers.
Two helper functions for implementing FFI calls were included directly alongside jaxlib's CPU kernels, but they will be useful for the GPU kernels as well. This moves those functions into ffi_helpers so that they are accessible from there too.

PiperOrigin-RevId: 658002501
2024-07-31 07:38:33 -07:00
Adam Paszke
9dba6eb16a [Mosaic TPU] Add support for 1D windows
PiperOrigin-RevId: 657976726
2024-07-31 05:58:19 -07:00
Adam Paszke
e0415c1865 [Mosaic TPU] Don't fold the accumulator into matmul if it has multiple uses
PiperOrigin-RevId: 657967724
2024-07-31 05:19:52 -07:00
Peter Hawkins
fd23b8733d Bump minimum SciPy version to 1.10.
SciPy 1.9.0 was released July 29, 2022, which is 24 months ago.

PiperOrigin-RevId: 657215038
2024-07-29 08:50:18 -07:00
Dan Foreman-Mackey
ff4e0b1214 Rearrange the LAPACK handler definitions in jaxlib to avoid duplicate handler errors.
When linking the jaxlib `cpu_kernels` target and importing JAX, we currently silently fail to instantiate the CPU backend. This refactor means that we only ever define one version of the handlers.

PiperOrigin-RevId: 657186057
2024-07-29 06:59:44 -07:00
Tomás Longeri
0f834cdf24 [Mosaic TPU] Enable lane broadcast for packed types and offsets outside of first tile, and fix some broadcast infer logic
PiperOrigin-RevId: 656201666
2024-07-25 19:48:20 -07:00
Jevin Jiang
b1b7d0465e [XLA:Mosaic] Support any int type upcast.
Also fixed the int4 unpacking.

PiperOrigin-RevId: 656119043
2024-07-25 15:39:38 -07:00
Sergei Lebedev
5e418f5ab2 Added argument validation to mosaic_gpu_init_tma_desc
This should help with understanding cuTensorMapEncodeTiled failures, since
CUDA doesn't provide any details beyond the error return code.

Note that this change also ensures that TMA descriptors are 64-byte aligned.
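A minimal sketch of the kind of up-front validation described above (a hypothetical helper, not the actual `mosaic_gpu_init_tma_desc` code):

```python
def validate_tma_desc_addr(desc_addr: int) -> None:
    """Reject misaligned TMA descriptor addresses with a readable error.

    CUDA requires TMA descriptors (CUtensorMap) to be 64-byte aligned;
    checking before calling cuTensorMapEncodeTiled gives a clearer error
    than the bare error return code CUDA would otherwise produce.
    """
    if desc_addr % 64 != 0:
        raise ValueError(
            f"TMA descriptor address {desc_addr:#x} is not 64-byte aligned")

validate_tma_desc_addr(0x1000)  # aligned: passes silently
```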

PiperOrigin-RevId: 656062820
2024-07-25 13:16:34 -07:00
Paweł Paruzel
ae40c87919 Activate Cholesky Factorization Kernel to XLA's FFI
PiperOrigin-RevId: 655990468
2024-07-25 09:59:28 -07:00
Tomás Longeri
220ec2aa69 [Mosaic TPU] (8,128),-2 -> (8,128) for non-zero and replicated 2nd minor offset
Also fix bug where relayouts for fully replicated source assumed it was a no-op without checking implicit dims

PiperOrigin-RevId: 655746766
2024-07-24 16:58:35 -07:00
Adam Paszke
dbe8f56353 [Mosaic GPU] Strengthen cluster-related tests by covering more cluster shapes
In particular, test trivial collectives (over singleton cluster axes), collectives
over more than 2 devices, and clusters larger than 8 devices. This uncovered a few
more bugs in the implementation.

PiperOrigin-RevId: 655686102
2024-07-24 13:43:52 -07:00
Paweł Paruzel
54fe6e68a0 Port Triangular Solve to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 655484166
2024-07-24 02:15:41 -07:00
Sharad Vikram
cfd9d8f548 [Pallas/TPU] Allow reading DMA semaphores in Pallas
PiperOrigin-RevId: 655384701
2024-07-23 19:08:45 -07:00
Jevin Jiang
59e944dadf [XLA:Mosaic] Pass rewrite ctx of apply-vector-layout pass to relayout function.
We will implement a more efficient relayout according to the configs in the rewrite ctx, such as `hardware_generation` and `max_sublanes_in_scratch`. So it makes sense to change the relayout interface to take the ctx (including the Python bindings). Now we can define a rewrite ctx in `apply_vector_layout_test` as well, which makes it easier to test some advanced cases (e.g., `mxu_shape` changes, or `max_sublanes_in_scratch` changes for rotate and relayout).

PiperOrigin-RevId: 655350013
2024-07-23 16:50:45 -07:00
Jevin Jiang
7e2107b1ee [XLA:Mosaic] Create apply layout pass with ctx instead of config list.
This CL removes the funcOp from the RewriteContext of apply-vector-layout-pass (since only one function is using it) and uses a context to create the pass instead of a long list of arguments. We will need to add more args (the target's bank counts) to create apply-vector-layout.

PiperOrigin-RevId: 655329321
2024-07-23 15:43:26 -07:00
Bart Chrzaszcz
864178d3a3 #sdy Initial set of changes to allow for lowering to the Shardy dialect.
The OpenXLA project is working on an open-source, MLIR-based, named-axis propagation (and, in the future, SPMD partitioning) system that will be dialect agnostic (it would work for any dialect: MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.

Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline while it's still working on StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD Partitioner, but next year the Shardy partitioner will be developed, allowing for propagation and partitioning to be completely in MLIR and the first pass in the pipeline. So then we'd have:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations

The following test:

```py
def test_sdy_lowering(self):
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  np_inp = np.arange(16).reshape(8, 2)
  s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
  arr = jax.device_put(np_inp, s)

  @partial(jax.jit, out_shardings=s)
  def f(x):
    return x * 2

  print(f.lower(arr).as_text())
```

outputs:

```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <"x"=4, "y"=2>
  func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
    %c = stablehlo.constant dense<2> : tensor<i64>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
    %1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
    return %1 : tensor<8x2xi64>
  }
}
```

Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially before becoming enabled by default in the future.

PiperOrigin-RevId: 655127611
2024-07-23 05:32:06 -07:00
Adam Paszke
a2b2fbf513 [Mosaic GPU] Add early support for block clusters and multicast TMA
PiperOrigin-RevId: 655057490
2024-07-23 00:50:20 -07:00
Tomás Longeri
d350ef779c [Mosaic TPU][apply-vector-layout] Do not broadcast in copy_one_sublane
This affects the (packing, 128) -> (8 * packing, 128) and 32-bit (8, 128),-2 -> (8, 128) retilings:
- No longer always broadcast the first sublane of a vreg before blending, which is usually unnecessary. Rotate instead, unless dst requires replicated offsets in (1, 128) -> (8, 128).
  For (8, 128),-2 -> (8, 128), with our current restrictions, the first vreg always already has the sublane in the right position, so the broadcast is always wasteful.
- Unclear if rotate is always better than broadcast, but it doesn't make sense to broadcast the first vreg yet rotate the others.

This is some cleanup prior to removing some offset restrictions for (8, 128),-2 -> (8, 128).

PiperOrigin-RevId: 654935883
2024-07-22 16:31:27 -07:00
Tomás Longeri
5f18a2e27b [Mosaic TPU] Enable (packing, 128) -> (8 * packing, 128) retiling
PiperOrigin-RevId: 654922099
2024-07-22 15:47:21 -07:00
Tomás Longeri
bf42564172 [Mosaic TPU][NFC] Remove unused toArrayRef for std::array
It's unused, buggy (it returns a reference to a local copy of the array), and `ArrayRef` already has a ctor that takes a `std::array`.

PiperOrigin-RevId: 654916697
2024-07-22 15:29:06 -07:00
Tomás Longeri
f4b09234a0 [Mosaic TPU] Set in_bounds for transfer_read used in replicated loads
This is in preparation for integrating changes from MLIR:
2ee5586ac7 (diff-3cbcc8f6c740f2d6e16f5a0c19daf4bb8224ad92d9e430fc10c935587a67dcce)

Also don't pass in `padding` since there is a builder that uses `padding` of zero as default.

PiperOrigin-RevId: 654370142
2024-07-20 16:26:18 -07:00
Jevin Jiang
faf89ab0da [XLA:Mosaic] Simplify the logic in converting dynamic roll to Log(N) static ops.
PiperOrigin-RevId: 654065156
2024-07-19 11:11:22 -07:00
Tomás Longeri
a9772494b2 [Mosaic] Simplify vector.shape_cast rules and cover more cases
- Sublane unfolding was not being checked for non-empty implicit dims e.g. (2, 2, 128, 1) -> (2, 256) would not work
- Noop squeeze/unsqueeze paths in infer-vector-layout, when the source has ImplicitDim::kNone, were forcing native tiling for some reason
- 1D lane squeeze was always assigning a bitwidth of 32.
- Maybe others

PiperOrigin-RevId: 653910942
2024-07-19 00:55:48 -07:00
jax authors
378a830322 Add support for multi row shift.
PiperOrigin-RevId: 653395441
2024-07-17 16:19:14 -07:00
Jevin Jiang
63a3e6736c [XLA:Mosaic] Extend support of tpu bitcast with offsets and implicit dim.
* If the bitwidth does not change after the bitcast:
  - We can bitcast the input with any vector layout.
* If the bitwidth changes after the bitcast:
  - We can bitcast the input with a sublane offset which is a multiple of the ratio of bitwidths.
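The offset rule above can be sketched as a small validity check (a hypothetical helper written for illustration, not the Mosaic implementation):

```python
def bitcast_sublane_offset_ok(src_bitwidth: int, dst_bitwidth: int,
                              sublane_offset: int) -> bool:
    """Check whether a sublane offset is valid for a tpu.bitcast."""
    # If the bitwidth is unchanged, any vector layout (hence any offset) works.
    if src_bitwidth == dst_bitwidth:
        return True
    # Otherwise the sublane offset must be a multiple of the bitwidth ratio.
    ratio = max(src_bitwidth, dst_bitwidth) // min(src_bitwidth, dst_bitwidth)
    return sublane_offset % ratio == 0

assert bitcast_sublane_offset_ok(32, 32, 3)      # same bitwidth: any offset
assert bitcast_sublane_offset_ok(32, 16, 2)      # ratio 2, offset 2: valid
assert not bitcast_sublane_offset_ok(32, 16, 1)  # ratio 2, offset 1: invalid
```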

PiperOrigin-RevId: 653375579
2024-07-17 15:10:27 -07:00
Adam Paszke
a335839ab8 [Mosaic TPU] Update transpose unrolling for new TPUs
PiperOrigin-RevId: 653348218
2024-07-17 13:44:01 -07:00
Tomás Longeri
da02ba196e [Mosaic] Most relayouts should work for any matched implicit dim, or on mismatched but equivalent ones
Also fix bug in (1, 128 * packing) -> (packing, 128) retiling where the part index could be incremented OOB.

Note: Many relayouts might be inefficient for implicit dims. If, for example, implicit dim is kSecondMinor, retiling might blend tiles that are only padding. This also applies to kNone implicit dim with small shapes, however, so any optimizations should be written based on the implicit shape.
PiperOrigin-RevId: 653209744
2024-07-17 06:17:32 -07:00
jax authors
764ec92118 Add support for elementwise op canonicalization in fp32 for older hardware.
PiperOrigin-RevId: 651959463
2024-07-12 19:58:55 -07:00
Jevin Jiang
aa16485457 [XLA:Mosaic] Support memref shapecast.
This CL supports memref shapecast:
1. If the tile is (1, 128), we support a shapecast on any dim.
2. If the shapecast is on the sublane dim, we only support tile-aligned shapes.
3. If the shapecast is on a non-tiling dim, we support any shapecast.
4. All other cases are considered invalid memref shapecasts.
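The four cases above can be summarized as a small decision function (a sketch under assumed semantics, with a hypothetical name and simplified shape handling, not the actual Mosaic code):

```python
def shapecast_supported(tile, src_shape, dst_shape) -> bool:
    """Decide whether a memref shapecast falls into a supported case."""
    # Case 1: with a (1, 128) tile, any shapecast is supported.
    if tile == (1, 128):
        return True
    # Treat the last two dims as the tiled dims; all leading dims are
    # "non-tiling" dims.
    if src_shape[-2:] == dst_shape[-2:]:
        # Case 3: only non-tiling dims change, so any shapecast works.
        return True
    # Case 2: the sublane dim changes; both shapes must be tile aligned.
    sublanes = tile[0]
    if src_shape[-2] % sublanes == 0 and dst_shape[-2] % sublanes == 0:
        return True
    # Case 4: everything else is an invalid memref shapecast.
    return False
```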

PiperOrigin-RevId: 651924552
2024-07-12 17:05:03 -07:00
Sharad Vikram
7016ca4829 [Mosaic] Strengthen check on return types from RegionOp
PiperOrigin-RevId: 651879359
2024-07-12 13:59:50 -07:00
Paweł Paruzel
5cce394428 Port Householder Product to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 651691430
2024-07-12 01:36:41 -07:00
Sharad Vikram
2cbe6caa50 [Pallas/Mosaic] Add support for returning values from run_scoped
PiperOrigin-RevId: 651600628
2024-07-11 18:37:09 -07:00