[Mosaic][NFC] Factor out xla-array related utils in a separate file.

Also added tests.

PiperOrigin-RevId: 721424194
This commit is contained in:
Tzu-Wei Sung 2025-01-30 09:49:03 -08:00 committed by jax authors
parent bb951136e9
commit d4758b6d5e
5 changed files with 252 additions and 82 deletions

View File

@ -39,6 +39,7 @@ py_library(
cc_library(
name = "tpu_dialect",
srcs = [
"dialect/tpu/array_util.cc",
"dialect/tpu/layout.cc",
"dialect/tpu/tpu_dialect.cc",
"dialect/tpu/tpu_ops.cc",
@ -49,6 +50,7 @@ cc_library(
"dialect/tpu/transforms/*.cc",
]),
hdrs = [
"dialect/tpu/array_util.h",
"dialect/tpu/layout.h",
"dialect/tpu/tpu_dialect.h",
"dialect/tpu/util.h",
@ -250,6 +252,17 @@ cc_test(
],
)
cc_test(
name = "array_util_test",
srcs = ["dialect/tpu/array_util_test.cc"],
deps = [
":tpu_dialect",
"//testing/base/public:gunit_main",
"@llvm-project//mlir:Support",
"@xla//xla:array",
],
)
filegroup(
name = "extension_srcs",
srcs = [

View File

@ -0,0 +1,54 @@
/* Copyright 2025 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/mosaic/dialect/tpu/array_util.h"
#include <cstdint>
#include "absl/log/check.h"
#include "absl/types/span.h"
#include "llvm/include/llvm/ADT/STLExtras.h"
#include "mlir/include/mlir/Support/LLVM.h"
namespace mlir::tpu::internal {
bool sliceIsEmpty(const absl::Span<const int64_t> starts,
const absl::Span<const int64_t> limits) {
for (auto [s, l] : llvm::zip_equal(starts, limits)) {
CHECK_LE(s, l);
if (s == l) {
return true;
}
}
return false;
}
bool incrementSliceIndex(const MutableArrayRef<int64_t> idx,
const absl::Span<const int64_t> starts,
const absl::Span<const int64_t> limits) {
const int64_t nd = idx.size();
CHECK_EQ(nd, starts.size());
CHECK_EQ(nd, limits.size());
for (int64_t i = nd - 1; i >= 0; --i) {
++idx[i];
if (idx[i] < limits[i]) {
return true;
}
idx[i] = starts[i];
}
return false;
}
} // namespace mlir::tpu::internal

View File

@ -0,0 +1,102 @@
/* Copyright 2025 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_ARRAY_UTIL_H_
#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_ARRAY_UTIL_H_
#include <cstdint>
#include "absl/log/check.h"
#include "absl/types/span.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "jaxlib/mosaic/dialect/tpu/util.h"
#include "xla/array.h"
namespace mlir::tpu {
namespace internal {
// Returns true if the slice is empty.
// `starts` and `limits` must be the same length.
bool sliceIsEmpty(absl::Span<const int64_t> starts,
absl::Span<const int64_t> limits);
// Increments the slice index.
// Returns true if the slice index is in bounds.
// `idx`, `starts` and `limits` must be the same length.
bool incrementSliceIndex(MutableArrayRef<int64_t> idx,
absl::Span<const int64_t> starts,
absl::Span<const int64_t> limits);
} // namespace internal
template <typename T>
ArrayRef<T> XlaArrayToFlatArrayRef(const xla::Array<T> &arr) {
return ArrayRef<T>(arr.data(), arr.num_elements());
}
template <typename T, typename Range>
xla::Array<T> XlaArrayFromShapeAndValues(ArrayRef<int64_t> sizes, Range vals) {
// TODO(tlongeri): is there no way to avoid default initialization in the
// constructor?
xla::Array<T> arr(sizes);
arr.SetValues(vals);
return arr;
}
// An alternative to `xla::Array::UpdateSlice` that takes a single value.
template <typename T>
void updateSlice(xla::Array<T> &arr, const T &value,
const absl::Span<const int64_t> starts,
const absl::Span<const int64_t> limits) {
if (internal::sliceIsEmpty(starts, limits)) {
return;
}
SmallVector<int64_t> idx(toArrayRef(starts));
do {
arr(idx) = value;
} while (internal::incrementSliceIndex(idx, starts, limits));
}
// An alternative to `xla::Array::UpdateSlice` that takes a range of data.
template <typename T, typename Range>
void updateSliceFromRange(xla::Array<T> &arr, Range data,
const absl::Span<const int64_t> starts,
const absl::Span<const int64_t> limits) {
if (internal::sliceIsEmpty(starts, limits)) {
return;
}
SmallVector<int64_t> idx(toArrayRef(starts));
auto in_bounds = [&] {
for (int64_t i = 0; i < idx.size(); ++i) {
if (idx[i] >= arr.dim(i)) {
return false;
}
}
return true;
};
auto data_it = data.begin();
do {
if (in_bounds()) {
arr(idx) = *data_it;
}
++data_it;
} while (internal::incrementSliceIndex(idx, starts, limits));
CHECK(data_it == data.end());
}
} // namespace mlir::tpu
#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_ARRAY_UTIL_H_

View File

@ -0,0 +1,82 @@
/* Copyright 2025 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/mosaic/dialect/tpu/array_util.h"
#include <cstdint>
#include <initializer_list>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "mlir/include/mlir/Support/LLVM.h"
#include "xla/array.h"
namespace mlir::tpu {
namespace {
using ::testing::Address;
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::StrEq;
TEST(ArrayUtilTest, XlaArrayToFlatArrayRef) {
xla::Array<int32_t> arr({2, 3}, 0);
arr.FillIota(0);
ArrayRef<int32_t> ref = XlaArrayToFlatArrayRef(arr);
ASSERT_EQ(ref.size(), arr.num_elements());
EXPECT_THAT(ref, ElementsAre(0, 1, 2, 3, 4, 5));
// Make sure it's not a copy but a view.
int* ptr = arr.begin();
for (int i = 0; i < ref.size() && ptr != arr.end(); ++i, ++ptr) {
EXPECT_THAT(ref[i], Address(Eq(ptr)));
}
}
TEST(ArrayUtilTest, XlaArrayFromShapeAndValues) {
xla::Array<int32_t> arr = XlaArrayFromShapeAndValues<int32_t>(
{2, 3}, std::initializer_list<int32_t>{0, 1, 2, 3, 4, 5});
EXPECT_THAT(arr.ToString(), StrEq(R"([[0, 1, 2],
[3, 4, 5]])"));
}
TEST(ArrayUtilTest, UpdateSlice) {
xla::Array<int32_t> arr({4, 5}, 0);
updateSlice(arr, 1, {1, 1}, {3, 4});
EXPECT_THAT(arr.ToString(), StrEq(R"([[0, 0, 0, 0, 0],
[0, 1, 1, 1, 0],
[0, 1, 1, 1, 0],
[0, 0, 0, 0, 0]])"));
}
TEST(ArrayUtilTest, UpdateSliceFromRange) {
xla::Array<int32_t> arr({4, 5}, 0);
updateSliceFromRange(arr, std::initializer_list<int32_t>{1, 2, 3, 4, 5, 6},
{1, 1}, {3, 4});
EXPECT_THAT(arr.ToString(), StrEq(R"([[0, 0, 0, 0, 0],
[0, 1, 2, 3, 0],
[0, 4, 5, 6, 0],
[0, 0, 0, 0, 0]])"));
}
} // namespace
} // namespace mlir::tpu

View File

@ -60,6 +60,7 @@
#include "mlir/include/mlir/IR/Builders.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/include/mlir/IR/OperationSupport.h"
#include "jaxlib/mosaic/dialect/tpu/array_util.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h"
@ -181,36 +182,6 @@ SmallVector<xla::Array<Value>> split(const xla::Array<Value> &vregs, int axis) {
return chunks;
};
template <typename T>
ArrayRef<T> XlaArrayToFlatArrayRef(xla::Array<T> xla_array) {
return ArrayRef<T>(xla_array.data(), xla_array.num_elements());
}
template <typename T, typename Range>
xla::Array<T> XlaArrayFromShapeAndValues(ArrayRef<int64_t> sizes, Range vals) {
// TODO(tlongeri): is there no way to avoid default initialization in the
// constructor?
xla::Array<T> arr(sizes);
arr.SetValues(vals);
return arr;
}
bool incrementSliceIndex(const MutableArrayRef<int64_t> idx,
const absl::Span<const int64_t> starts,
const absl::Span<const int64_t> limits) {
const int64_t nd = idx.size();
CHECK_EQ(nd, starts.size());
CHECK_EQ(nd, limits.size());
for (int64_t i = nd - 1; i >= 0; --i) {
++idx[i];
if (idx[i] < limits[i]) {
return true;
}
idx[i] = starts[i];
}
return false;
}
bool incrementIndex(const MutableArrayRef<int64_t> idx,
const absl::Span<const int64_t> limits) {
const int64_t nd = idx.size();
@ -225,58 +196,6 @@ bool incrementIndex(const MutableArrayRef<int64_t> idx,
return false;
}
bool sliceIsEmpty(const absl::Span<const int64_t> starts,
const absl::Span<const int64_t> limits) {
for (auto [s, l] : llvm::zip_equal(starts, limits)) {
CHECK_LE(s, l);
if (s == l) {
return true;
}
}
return false;
}
// An alternative to xla::Array::UpdateSlice that takes a single value
template <typename T>
void updateSlice(xla::Array<T> &arr, const T &value,
const absl::Span<const int64_t> starts,
const absl::Span<const int64_t> limits) {
if (sliceIsEmpty(starts, limits)) {
return;
}
SmallVector<int64_t> idx(toArrayRef(starts));
do {
arr(idx) = value;
} while (incrementSliceIndex(idx, starts, limits));
}
// An alternative to xla::Array::UpdateSlice that takes a range of data
template <typename T, typename Range>
void updateSliceFromRange(xla::Array<T> &arr, Range data,
const absl::Span<const int64_t> starts,
const absl::Span<const int64_t> limits) {
if (sliceIsEmpty(starts, limits)) {
return;
}
SmallVector<int64_t> idx(toArrayRef(starts));
auto in_bounds = [&] {
for (int64_t i = 0; i < idx.size(); ++i) {
if (idx[i] >= arr.dim(i)) {
return false;
}
}
return true;
};
auto data_it = data.begin();
do {
if (in_bounds()) {
arr(idx) = *data_it;
}
++data_it;
} while (incrementSliceIndex(idx, starts, limits));
CHECK(data_it == data.end());
}
FailureOr<int64_t> getIntConst(Value v, bool silent = false) {
if (auto constant_op = v.getDefiningOp<arith::ConstantOp>()) {
if (auto integer_attr = dyn_cast<IntegerAttr>(constant_op.getValue())) {