The IFRT `PluginProgram` is simply a wrapper for arbitrary byte-strings: an IFRT backend that recognizes `PluginProgram` can interpret the byte-string in any way it sees fit.
PiperOrigin-RevId: 621258245
This was a little difficult because our current dialect conversion setup assumes 1-1 type conversions.
I think everything works out fine for as long as we never pass memrefs between basic blocks (i.e.
for as long as we never have memrefs as loop carry or return them from conditionals).
TODO: I still need to make sure that the changes to the TPU dialect are backwards-compatible.
I am afraid that the signature change in MemRefSliceOp might not be.
PiperOrigin-RevId: 617755035
The nanobind switch for the GPU callback code means that we are now using the NumPy APIs rather than pybind11's clone of them. It is important to initialize the NumPy APIs before using them in each module.
PiperOrigin-RevId: 613036056
Ideally we would prefer `TypedValue<VectorType>` everywhere possible for static type checking. However, I tried the type for arrays of vregs, `xla::Array<Value>` to `xla::Array<TypedValue<VectorType>>` and ran into issues because MLIR support for arrays/ranges of `TypedValue`s seems lacking.
For example, I can't find a good way to get a `ValueRange` (which many op constructors take) from an array of `TypedValue`s without creating an intermediate vector of `Value`s. Perhaps an unsafe cast if we make the (probably not guaranteed) assumption that `sizeof(TypedValue)` equals `sizeof(Value)`.
Also note that MLIR itself uses untyped `Value`s for ranges of op results and operands even when the op definition declares them to be of a specific type.
PiperOrigin-RevId: 610509743
It was using the `op` variable from the `ExtUIOp` above (because variables declared in initializer of an if statement are available in the else branch).
PiperOrigin-RevId: 610481302
The old `tile_indices` variable was misleading and confusing because it sometimes stored indices (in the static case) and sometimes offsets with respect to the tile (in the dynamic case).
PiperOrigin-RevId: 609457122
This dialect doesn't build on Windows, but we don't support GPUs on Windows anyway, so we can simply exclude it from the build.
CI failures look like this:
```
C:\npm\prefix\bazel.CMD run --verbose_failures=true //jaxlib/tools:build_wheel -- --output_path=C:\a\jax\jax\jax\dist --jaxlib_git_hash=5f19f7712b485493ac141c44eea3b3eb1ffdfb59 --cpu=AMD64
b"external/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(70): error C2672: 'mlir::Block::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(70): error C2783: 'RetT mlir::Block::walk(FnT &&)': could not deduce template argument for 'ArgT'\r\nexternal/llvm-project/mlir/include\\mlir/IR/Block.h(289): note: see declaration of 'mlir::Block::walk'\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(110): error C2672: 'mlir::OpState::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(110): error C2783: 'enable_if<llvm::function_traits<decay<FnT>::type,std::is_class<T>::value>::num_args==1,RetT>::type mlir::OpState::walk(FnT &&)': could not deduce template argument for 'RetT'\r\n with\r\n [\r\n T=decay<FnT>::type\r\n ]\r\nexternal/llvm-project/mlir/include\\mlir/IR/OpDefinition.h(165): note: see declaration of 'mlir::OpState::walk'\r\nexternal/llvm-project/mlir/include\\mlir/IR/PatternMatch.h(357): error C2872: 'detail': ambiguous symbol\r\nexternal/llvm-project/mlir/include\\mlir/Rewrite/FrozenRewritePatternSet.h(15): note: could be 'mlir::detail'\r\nbazel-out/x64_windows-opt/bin/external/triton/include\\triton/Dialect/Triton/IR/Ops.h.inc(5826): note: or 'mlir::triton::detail'\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(712): note: see reference to class template instantiation 'mlir::OpRewritePattern<mlir::scf::ForOp>' being compiled\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(741): error C2672: 'mlir::Block::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(741): error C2783: 'RetT mlir::Block::walk(FnT &&)': could not deduce template argument for 'ArgT'\r\nexternal/llvm-project/mlir/include\\mlir/IR/Block.h(289): note: see declaration of 'mlir::Block::walk'\r\n"
output = subprocess.check_output(cmd)
```
PiperOrigin-RevId: 609153322
* Added a noop config_tags_overrides parameter to jax_test()
* Updated BUILD files necessary to run Pallas tests via Bazel
* Changed PallasTest to skip "large" test cases
PiperOrigin-RevId: 608534008
This allows us to rely on this throughout the code and replace some checks with TPU_ASSERT_*. They have the semantics of an assert and make it clearer that it is an unexpected internal error (instead of unimplemented or invalid user input that we should handle).
Note: the original error messages for some of these checks were using the wrong input names.
PiperOrigin-RevId: 607463728