40 Commits

Author SHA1 Message Date
Peter Hawkins
27e19239ca Fix triton capi_objects target to depend on MLIR CAPIIRObjects bazel
target.

"...Objects" targets should only depend on other "...Objects" targets in
MLIR land. Don't mix them.
2024-09-06 01:06:27 +00:00
Sergei Lebedev
f3b91b2042 Export PointerType and register_dialect from jaxlib.triton.dialect
The `... as ...` form tells the type checker that the name is exported.
See #7570.

PiperOrigin-RevId: 671318047
2024-09-05 04:15:32 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Sergei Lebedev
778933dfda Removed inspect.signature() call from jaxlib.triton.dialect.ScanOp
PiperOrigin-RevId: 614772594
2024-03-11 13:30:41 -07:00
Goran Flegar
53364b438c Integrate Triton up to [bfb8e413](bfb8e413b0)
PiperOrigin-RevId: 614740360
2024-03-11 11:43:46 -07:00
jax authors
32bb3b0613 Use $(RULEDIR) to avoid an implicit dependency on output_to_genfiles.
PiperOrigin-RevId: 611652089
2024-02-29 17:40:18 -08:00
Peter Hawkins
ef40b85c8b Don't build the Triton MLIR dialect on Windows
This dialect doesn't build on Windows, but we don't support GPUs on Windows anyway, so we can simply exclude it from the build.

CI failures look like this:
```
C:\npm\prefix\bazel.CMD run --verbose_failures=true //jaxlib/tools:build_wheel -- --output_path=C:\a\jax\jax\jax\dist --jaxlib_git_hash=5f19f7712b485493ac141c44eea3b3eb1ffdfb59 --cpu=AMD64
b"external/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(70): error C2672: 'mlir::Block::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(70): error C2783: 'RetT mlir::Block::walk(FnT &&)': could not deduce template argument for 'ArgT'\r\nexternal/llvm-project/mlir/include\\mlir/IR/Block.h(289): note: see declaration of 'mlir::Block::walk'\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(110): error C2672: 'mlir::OpState::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(110): error C2783: 'enable_if<llvm::function_traits<decay<FnT>::type,std::is_class<T>::value>::num_args==1,RetT>::type mlir::OpState::walk(FnT &&)': could not deduce template argument for 'RetT'\r\n        with\r\n        [\r\n            T=decay<FnT>::type\r\n        ]\r\nexternal/llvm-project/mlir/include\\mlir/IR/OpDefinition.h(165): note: see declaration of 'mlir::OpState::walk'\r\nexternal/llvm-project/mlir/include\\mlir/IR/PatternMatch.h(357): error C2872: 'detail': ambiguous symbol\r\nexternal/llvm-project/mlir/include\\mlir/Rewrite/FrozenRewritePatternSet.h(15): note: could be 'mlir::detail'\r\nbazel-out/x64_windows-opt/bin/external/triton/include\\triton/Dialect/Triton/IR/Ops.h.inc(5826): note: or       'mlir::triton::detail'\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(712): note: see reference to class template instantiation 'mlir::OpRewritePattern<mlir::scf::ForOp>' being compiled\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(741): error C2672: 'mlir::Block::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(741): error C2783: 'RetT mlir::Block::walk(FnT &&)': could not deduce template argument for 'ArgT'\r\nexternal/llvm-project/mlir/include\\mlir/IR/Block.h(289): note: see declaration of 'mlir::Block::walk'\r\n"
    output = subprocess.check_output(cmd)
```
PiperOrigin-RevId: 609153322
2024-02-21 16:02:54 -08:00
Sergei Lebedev
37f313ab22 Fixed internal CI builds
* Added a noop config_tags_overrides parameter to jax_test()
* Updated BUILD files necessary to run Pallas tests via Bazel
* Changed PallasTest to skip "large" test cases

PiperOrigin-RevId: 608534008
2024-02-20 02:42:14 -08:00
Sergei Lebedev
881436240e Inlined triton.compat
We no longer need a compatibility layer, since Pallas does not use any Triton
IR building APIs.

PiperOrigin-RevId: 606948415
2024-02-14 05:23:15 -08:00
Sergei Lebedev
cee8bf7030 Changed lower_jaxpr_to_triton_ir to use ir.Value instead of triton.compat APIs
PiperOrigin-RevId: 606603607
2024-02-13 06:40:51 -08:00
Sergei Lebedev
6a7d1dceff Added ir.Value-based versions of load and store in triton.compat
PiperOrigin-RevId: 606597830
2024-02-13 06:13:31 -08:00
Sergei Lebedev
0ddf37145e Added ir.Value-based versions of some triton.compat APIs
PiperOrigin-RevId: 606581304
2024-02-13 05:00:57 -08:00
Sergei Lebedev
7dd887dc84 Updated the Triton lowering of jnp.transpose
Triton no longer restricts its transpose to just two axes.

PiperOrigin-RevId: 606188599
2024-02-12 02:42:33 -08:00
Sergei Lebedev
8a8601fb53 Moved atomic_* operations out of the Triton compatibility layer
PiperOrigin-RevId: 605536620
2024-02-08 23:21:25 -08:00
Sergei Lebedev
4c505f8bac Inlined tl.core._to_tensor
PiperOrigin-RevId: 605270480
2024-02-08 04:05:38 -08:00
Sergei Lebedev
5e2e609a9b _triton_ext no longer links in MLIR C APIs
I re-used the same trick we do for the TPU dialect. Specifically, _triton_ext no longer depends on :triton_dialect_capi. Instead

* we include Triton dialect C bindings into :jaxlib_mlir_capi_objects
* and _triton_ext depends on :jaxlib_mlir_capi_objects and a header-only cc_library providing Triton dialect C bindings

This is a fork of #19680 with a few internal-only fixes.

PiperOrigin-RevId: 604929377
2024-02-07 03:39:29 -08:00
Sergei Lebedev
01b995e01f Add MLIR_CAPI_EXPORTED to Triton dialect C bindings
PiperOrigin-RevId: 604912455
2024-02-07 02:14:50 -08:00
Jake VanderPlas
4d5c557264 jaxlib/triton: ensure python 3.9 compatibility 2024-02-06 11:39:51 -08:00
Sergei Lebedev
9e94e6ef71 Fixed a typo in min/max Triton lowering rules
PiperOrigin-RevId: 604424404
2024-02-05 14:06:09 -08:00
Sergei Lebedev
0596c804cb Inlined all builder.create_* calls
The lowering code now directly calls into the auto-generated MLIR bindings
for the respective dialects.

PiperOrigin-RevId: 604260837
2024-02-05 03:00:33 -08:00
Sergei Lebedev
c77d45d511 Migrated tl.semantic.cast to lower directly to MLIR
PiperOrigin-RevId: 604255461
2024-02-05 02:34:58 -08:00
Sergei Lebedev
06d3280890 Inlined some Triton-specific abstractions
* tensor is now just a container for dtype/shape with no extra methods;
* constexpr is not used;
* all APIs assume that the arguments are tensors and do no ->tensor conversion.

PiperOrigin-RevId: 603894852
2024-02-03 00:56:47 -08:00
Sergei Lebedev
cda40ece87 Removed wrap_with_builder and tensor.to from the Triton compatibility layer
The only API using it was cast.

I also added a test which covers int1->int8 casts.

PiperOrigin-RevId: 603771797
2024-02-02 13:23:52 -08:00
Sergei Lebedev
28eff4f9b8 Migrated dot to lower directly to Triton IR
PiperOrigin-RevId: 603768074
2024-02-02 13:09:25 -08:00
Sergei Lebedev
5867a05cdd Migrated store/load to lower directly to Triton IR
PiperOrigin-RevId: 603764118
2024-02-02 12:53:42 -08:00
Sergei Lebedev
e1ea936fc1 Added a custom lowering rule for pow which special-cases weak dtypes
PiperOrigin-RevId: 603635095
2024-02-02 03:06:43 -08:00
Sergei Lebedev
d9f42c56b8 Fixed a few ir.Value type-tensor dtype mismatches in the Pallas lowering code on GPU
PiperOrigin-RevId: 603632044
2024-02-02 02:52:08 -08:00
Sergei Lebedev
9e76e380cc Temporarily switch triton.compat to use Triton APIs for math and semantic operations
This is only meant as a short-term fix to unblock internal users.

PiperOrigin-RevId: 602707085
2024-01-30 06:30:22 -08:00
Goran Flegar
66308c30ad Integrate Triton up to [9f816a7b](9f816a7b98)
PiperOrigin-RevId: 602641874
2024-01-30 01:16:11 -08:00
Sergei Lebedev
fad3e749a1 Migrated remaining operations from the math namespace to lower directly to Triton IR
PiperOrigin-RevId: 602390761
2024-01-29 08:10:03 -08:00
Sergei Lebedev
07f8f700ca Migrated atomic operations to lower directly to Triton IR
PiperOrigin-RevId: 602384705
2024-01-29 07:45:31 -08:00
Sergei Lebedev
cc5f565b89 Ported a subset of binary operations to lower directly to Triton IR
PiperOrigin-RevId: 601806008
2024-01-26 10:57:01 -08:00
Sergei Lebedev
273cb27047 compat.tensor __*__ methods no longer do implicit broadcasting
This change makes it simpler to lower binary operations to Triton IR
bypassing Triton Python bindings.

PiperOrigin-RevId: 601796719
2024-01-26 10:13:51 -08:00
Sergei Lebedev
f15cad4651 Lower a subset of math primitives directly to Triton IR
Note that all primitives are now lowered to libdevice calls. Previously,
some of them were lowered to the MLIR arith dialect, and some to libdevice
calls, without any apparent reason for doing so.

PiperOrigin-RevId: 601259707
2024-01-24 15:55:09 -08:00
Sergei Lebedev
af49b01e1f Migrated a subset of triton.compat to directly use IR builders
PiperOrigin-RevId: 598826331
2024-01-16 06:46:08 -08:00
Sergei Lebedev
5b7a0d9c91 Pallas now uses MLIR Python builders to lower to Triton IR
This allows us to drop a dependency on the Triton Python package in the future,
and delegate ->ptx compilation to XLA.

PiperOrigin-RevId: 597640756
2024-01-11 13:33:26 -08:00
Sergei Lebedev
6174145386 Removed the Triton dependency from the BUILD file
PiperOrigin-RevId: 597336551
2024-01-10 13:17:47 -08:00
Sergei Lebedev
ba10775eda Added a compatibility overlay for Triton Python APIs
Follow up changes will gradually re-implement these APIs using the MLIR
builders added in google/jax#19159.

PiperOrigin-RevId: 597023799
2024-01-09 13:13:56 -08:00
Sergei Lebedev
f219482212 The Triton MLIR bindings now include auto-generated wrappers for enums
PiperOrigin-RevId: 596873541
2024-01-09 03:00:47 -08:00
Sergei Lebedev
e6c890171b Generate Python bindings for the Triton MLIR dialect
The bindings are not yet included in the jaxlib wheel. I will do that in a
follow up PR.

PiperOrigin-RevId: 595174466
2024-01-02 11:55:05 -08:00