mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic] apply_vector_layout C++ rewrite: Fix updateSlice[FromRange] for empty slices
PiperOrigin-RevId: 574665060
This commit is contained in:
parent
e144f71c33
commit
a7f279f5e8
@ -223,11 +223,25 @@ bool incrementIndex(const MutableArrayRef<int64_t> idx,
|
||||
return false;
|
||||
}
|
||||
|
||||
bool sliceIsEmpty(const absl::Span<const int64_t> starts,
|
||||
const absl::Span<const int64_t> limits) {
|
||||
for (auto [s, l] : llvm::zip_equal(starts, limits)) {
|
||||
CHECK_LE(s, l);
|
||||
if (s == l) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// An alternative to xla::Array::UpdateSlice that takes a single value
|
||||
template <typename T>
|
||||
void updateSlice(xla::Array<T> &arr, const T &value,
|
||||
const absl::Span<const int64_t> starts,
|
||||
const absl::Span<const int64_t> limits) {
|
||||
if (sliceIsEmpty(starts, limits)) {
|
||||
return;
|
||||
}
|
||||
SmallVector<int64_t> idx(toArrayRef(starts));
|
||||
do {
|
||||
arr(idx) = value;
|
||||
@ -239,6 +253,9 @@ template <typename T, typename Range>
|
||||
void updateSliceFromRange(xla::Array<T> &arr, Range data,
|
||||
const absl::Span<const int64_t> starts,
|
||||
const absl::Span<const int64_t> limits) {
|
||||
if (sliceIsEmpty(starts, limits)) {
|
||||
return;
|
||||
}
|
||||
SmallVector<int64_t> idx(toArrayRef(starts));
|
||||
auto data_it = data.begin();
|
||||
do {
|
||||
|
Loading…
x
Reference in New Issue
Block a user