800 Commits

Author SHA1 Message Date
jax authors
fb55d59143 This CL introduces 'PluginProgram' in IFRT and exposes this in python via xla_client.compile_ifrt_program().
The IFRT `PluginProgram` is simply a wrapper for arbitrary byte-strings: an IFRT backend that recognizes `PluginProgram` can interpret the byte-string in any way it sees fit.

PiperOrigin-RevId: 621258245
2024-04-02 12:20:35 -07:00
David Dunleavy
aade591fdf Move tsl/python to xla/tsl/python
PiperOrigin-RevId: 620320903
2024-03-29 13:15:21 -07:00
Peter Hawkins
478cfa9944 Add an upper bound on JAX's CUDNN version constraint.
Major releases of CUDNN break ABI compatibility, so we cannot allow new major versions.

PiperOrigin-RevId: 620030416
2024-03-28 13:00:36 -07:00
Jevin Jiang
67f4f6032a [XLA:Mosaic] Remove duplicate headers in debug assert insertion.
PiperOrigin-RevId: 619801919
2024-03-27 23:14:05 -07:00
Michael Hudgins
023930decf Fix some load orderings for buildifier
PiperOrigin-RevId: 619575196
2024-03-27 10:28:57 -07:00
jax authors
0be07e6aec Remove support for CUDA 11.
Pin minimal required versions for CUDA to 12.1.

Reverts 910a31d7b7510e3375718ab1ea0d38df7bd2c0d5

PiperOrigin-RevId: 618911489
2024-03-25 11:46:39 -07:00
Sandeep Dasgupta
6ffd55c405 Fixing StableHLO python dependencies on stablehlo:reference_api
PiperOrigin-RevId: 618294054
2024-03-22 14:52:06 -07:00
jax authors
910a31d7b7 Reverts bed4f65438a62777ed100ecec2b0eb3f7cf87a0e
PiperOrigin-RevId: 618249855
2024-03-22 12:10:53 -07:00
jax authors
bed4f65438 Remove support for CUDA 11.
Pin minimal required versions for CUDA to 12.1.

PiperOrigin-RevId: 618195554
2024-03-22 09:05:39 -07:00
Tomás Longeri
7f7e0c00df [Mosaic] Support left shifting relayouts
PiperOrigin-RevId: 618008857
2024-03-21 17:20:30 -07:00
jax authors
2848cda34c Merge pull request #20341 from ROCm:rocm_add_hipStreamWaitEvent
PiperOrigin-RevId: 617893634
2024-03-21 10:41:38 -07:00
Adam Paszke
7d431ad33b Add support for slicing dynamically-shaped memrefs + DMAs between them
This was a little difficult because our current dialect conversion setup assumes 1-1 type conversions.
I think everything works out fine for as long as we never pass memrefs between basic blocks (i.e.
for as long as we never have memrefs as loop carry or return them from conditionals).

TODO: I still need to make sure that the changes to the TPU dialect are backwards-compatible.
I am afraid that the signature change in MemRefSliceOp might not be.
PiperOrigin-RevId: 617755035
2024-03-21 00:56:51 -07:00
Rahul Batra
8575055571 [ROCm]: Add missing hipStreamWaitEvent API call 2024-03-20 16:58:21 +00:00
jax authors
df9cefabc1 jaxlib: Add ifrt_proxy.pyi to build_wheel.py.
PiperOrigin-RevId: 617275734
2024-03-19 13:27:39 -07:00
Peter Hawkins
c2bbf9c577 Remove some code to support older CUDA and CUSPARSE versions.
The minimum CUDA version supported by JAX is CUDA 11.8, which ships with CUSPARSE 11.7.5.

PiperOrigin-RevId: 616892230
2024-03-18 11:25:03 -07:00
Jevin Jiang
7578e10ce3 [XLA:Mosaic] Support dynamic indices in strided load/store.
PiperOrigin-RevId: 615931990
2024-03-14 16:02:08 -07:00
Jevin Jiang
30208fa9cc [XLA:Mosaic] Support strided load/store memref with arbitrary shape as long as last dim size is 128 and dtype is 32bit.
PiperOrigin-RevId: 614862128
2024-03-11 18:22:11 -07:00
Sergei Lebedev
778933dfda Removed inspect.signature() call from jaxlib.triton.dialect.ScanOp
PiperOrigin-RevId: 614772594
2024-03-11 13:30:41 -07:00
Goran Flegar
53364b438c Integrate Triton up to [bfb8e413](bfb8e413b0)
PiperOrigin-RevId: 614740360
2024-03-11 11:43:46 -07:00
Jevin Jiang
75f2f7510f [XLA:Mosaic] Support input offset (replicated, 0) in shapecast.
PiperOrigin-RevId: 613340933
2024-03-06 14:26:00 -08:00
Andrey Portnoy
dcb58bb540 Include <cstdint> in files where it is used 2024-03-06 11:58:15 -05:00
Peter Hawkins
1a193ea189 Fix segfault in cuda_plugin_extension.
The nanobind switch for the GPU callback code means that we are now using the NumPy APIs rather than pybind11's clone of them. It is important to initialize the NumPy APIs before using them in each module.

PiperOrigin-RevId: 613036056
2024-03-05 18:31:50 -08:00
Jevin Jiang
05f54b665c [XLA:Mosaic] Use different MXU shape based on the target
PiperOrigin-RevId: 612906617
2024-03-05 11:14:24 -08:00
Peter Hawkins
feda85dff3 Replace references to xla/python/status_casters.h with xla/pjrt/status_casters.h, which its current home.
PiperOrigin-RevId: 612578488
2024-03-04 14:11:01 -08:00
jax authors
81363cefd7 Merge pull request #19808 from Micky774:cc_check
PiperOrigin-RevId: 612463272
2024-03-04 08:40:13 -08:00
jax authors
7514d5c7aa [triton] Add clustering support and test
PiperOrigin-RevId: 612417957
2024-03-04 05:51:10 -08:00
Meekail Zain
9fff9aeb69 Update 2024-03-03 19:57:26 +00:00
Eugene Zhulenev
1ae2022918 [jax-triton] Do not capture jax-triton calls that require autotuning
PiperOrigin-RevId: 611823473
2024-03-01 10:28:47 -08:00
jax authors
32bb3b0613 Use $(RULEDIR) to avoid an implicit dependency on output_to_genfiles.
PiperOrigin-RevId: 611652089
2024-02-29 17:40:18 -08:00
David Dunleavy
be3e39ad3b Move tsl/cuda to xla/tsl/cuda
PiperOrigin-RevId: 610550833
2024-02-26 15:45:10 -08:00
Tomás Longeri
57e34e1a2c [Mosaic][NFC] Use TypedValue<VectorType> instead of Value for applicable arguments/return values in disassemble and relayout
Ideally we would prefer `TypedValue<VectorType>` everywhere possible for static type checking. However, I tried the type for arrays of vregs, `xla::Array<Value>` to `xla::Array<TypedValue<VectorType>>` and ran into issues because MLIR support for arrays/ranges of `TypedValue`s seems lacking.

For example, I can't find a good way to get a `ValueRange` (which many op constructors take) from an array of `TypedValue`s without creating an intermediate vector of `Value`s. Perhaps an unsafe cast if we make the (probably not guaranteed) assumption that `sizeof(TypedValue)` equals `sizeof(Value)`.

Also note that MLIR itself uses untyped `Value`s for ranges of op results and operands even when the op definition declares them to be of a specific type.

PiperOrigin-RevId: 610509743
2024-02-26 13:34:58 -08:00
Tomás Longeri
2f882ad092 [Mosaic][infer-vector-layout] Fix crash with TPU_CHECK_OP
It was using the `op` variable from the `ExtUIOp` above (because variables declared in initializer of an if statement are available in the else branch).

PiperOrigin-RevId: 610481302
2024-02-26 11:58:51 -08:00
Adam Paszke
516b75dc24 Add pl.num_programs to make it easier to query the dynamic grid size
The new function can be used both in the kernel body and in the block specs.

PiperOrigin-RevId: 610391119
2024-02-26 06:39:03 -08:00
Tomás Longeri
61aa7e89aa [Mosaic] Fix bug in divisibility check in infer_vector_layout load and store rules
PiperOrigin-RevId: 609876232
2024-02-23 17:16:53 -08:00
Tomás Longeri
75cdef7626 [Mosaic][NFC] Prefer mlir aliases for llvm classes/functions within mlir namespace for consistency
(also fix a missing cstdint header to fix linter error)

PiperOrigin-RevId: 609826731
2024-02-23 13:48:03 -08:00
Tomás Longeri
8a43140c2e [Mosaic][apply_vector_layout][NFC] Use LLVM_UNLIKELY in TPU_ASSERT_* macros
PiperOrigin-RevId: 609805325
2024-02-23 12:27:49 -08:00
Eugene Zhulenev
3a69b80774 [jax-triton] Synchronize autotuning stream with a main one
PiperOrigin-RevId: 609792049
2024-02-23 11:42:30 -08:00
Jevin Jiang
f5c0021071 [XLA:Mosaic] Unify ext/trunc in infer vector layout.
PiperOrigin-RevId: 609765653
2024-02-23 10:19:26 -08:00
Tomás Longeri
c9eaca2282 [Mosaic] In apply_vector_layout, verify VectorLayout invariants that depend on target shape when loading them.
VectorLayout::verify was unused.
PiperOrigin-RevId: 609754730
2024-02-23 09:40:57 -08:00
Jevin Jiang
8d6bb0197b [XLA:Mosaic] Support broadcast scalar with a narrower type.
PiperOrigin-RevId: 609475719
2024-02-22 13:22:17 -08:00
Tomás Longeri
8172c067a4 [Mosaic][NFC] Replace tile_indices variable with tile_offsets with more consistent semantics
The old `tile_indices` variable was misleading and confusing because it sometimes stored indices (in the static case) and sometimes offsets with respect to the tile (in the dynamic case).

PiperOrigin-RevId: 609457122
2024-02-22 12:17:20 -08:00
Jevin Jiang
1fcb84dc90 [XLA:Mosaic] Support broadcast one row with padded tiling.
PiperOrigin-RevId: 609435269
2024-02-22 11:13:28 -08:00
Peter Hawkins
ef40b85c8b Don't build the Triton MLIR dialect on Windows
This dialect doesn't build on Windows, but we don't support GPUs on Windows anyway, so we can simply exclude it from the build.

CI failures look like this:
```
C:\npm\prefix\bazel.CMD run --verbose_failures=true //jaxlib/tools:build_wheel -- --output_path=C:\a\jax\jax\jax\dist --jaxlib_git_hash=5f19f7712b485493ac141c44eea3b3eb1ffdfb59 --cpu=AMD64
b"external/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(70): error C2672: 'mlir::Block::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(70): error C2783: 'RetT mlir::Block::walk(FnT &&)': could not deduce template argument for 'ArgT'\r\nexternal/llvm-project/mlir/include\\mlir/IR/Block.h(289): note: see declaration of 'mlir::Block::walk'\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(110): error C2672: 'mlir::OpState::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(110): error C2783: 'enable_if<llvm::function_traits<decay<FnT>::type,std::is_class<T>::value>::num_args==1,RetT>::type mlir::OpState::walk(FnT &&)': could not deduce template argument for 'RetT'\r\n        with\r\n        [\r\n            T=decay<FnT>::type\r\n        ]\r\nexternal/llvm-project/mlir/include\\mlir/IR/OpDefinition.h(165): note: see declaration of 'mlir::OpState::walk'\r\nexternal/llvm-project/mlir/include\\mlir/IR/PatternMatch.h(357): error C2872: 'detail': ambiguous symbol\r\nexternal/llvm-project/mlir/include\\mlir/Rewrite/FrozenRewritePatternSet.h(15): note: could be 'mlir::detail'\r\nbazel-out/x64_windows-opt/bin/external/triton/include\\triton/Dialect/Triton/IR/Ops.h.inc(5826): note: or       'mlir::triton::detail'\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(712): note: see reference to class template instantiation 'mlir::OpRewritePattern<mlir::scf::ForOp>' being compiled\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(741): error C2672: 'mlir::Block::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(741): error C2783: 'RetT mlir::Block::walk(FnT &&)': could not deduce template argument for 'ArgT'\r\nexternal/llvm-project/mlir/include\\mlir/IR/Block.h(289): note: see declaration of 'mlir::Block::walk'\r\n"
    output = subprocess.check_output(cmd)
```
PiperOrigin-RevId: 609153322
2024-02-21 16:02:54 -08:00
jax authors
16b29a6930 Merge pull request #19288 from pearu:pearu/int32-overflow
PiperOrigin-RevId: 608701959
2024-02-20 12:43:16 -08:00
Sergei Lebedev
37f313ab22 Fixed internal CI builds
* Added a noop config_tags_overrides parameter to jax_test()
* Updated BUILD files necessary to run Pallas tests via Bazel
* Changed PallasTest to skip "large" test cases

PiperOrigin-RevId: 608534008
2024-02-20 02:42:14 -08:00
Pearu Peterson
3fa1033ac1 Prevent silent overflow in lapack worker size calculations.
Add -fexceptions to building lapack_kernels
2024-02-20 11:04:06 +02:00
Tomás Longeri
14474acf76 [Mosaic] Fix mistake in error message
PiperOrigin-RevId: 607700109
2024-02-16 08:38:30 -08:00
Tomás Longeri
243e7edc56 [Mosaic] In apply_vector_layout.cc, check layout validity when reading the attribute
This allows us to rely on this throughout the code and replace some checks with TPU_ASSERT_*. They have the semantics of an assert and make it clearer that it is an unexpected internal error (instead of unimplemented or invalid user input that we should handle).

Note: the original error messages for some of these checks were using the wrong input names.
PiperOrigin-RevId: 607463728
2024-02-15 14:51:45 -08:00
Jevin Jiang
a37f2d4a09 [XLA:Mosaic] Prevent generating rotate op if both shift and stride are zeros.
PiperOrigin-RevId: 607433547
2024-02-15 13:19:10 -08:00
Tomás Longeri
72c5aea161 [Mosaic] In apply_vector_layout, prefer returning failure over CHECKs and add macros for this
Rationale is that it's easier to debug from Python.

PiperOrigin-RevId: 607426243
2024-02-15 12:54:33 -08:00