Updates LLVM usage to match
[c8b5d30f7077](https://github.com/llvm/llvm-project/commit/c8b5d30f7077)

PiperOrigin-RevId: 662906261
This commit is contained in:
jax authors 2024-08-14 07:09:13 -07:00 committed by jax authors
parent 6290cd77fc
commit 807dcb5a06
4 changed files with 8 additions and 20 deletions

View File

@ -1295,9 +1295,7 @@ def reduce_lowering_rule(reduce_fn, type_to_kind, type_to_identity):
kind,
x,
acc,
ir.ArrayAttr.get(
[ir.IntegerAttr.get(ir.IntegerType.get_signless(64), a) for a in axes]
),
axes,
)
return op.result
return _lowering_rule

View File

@ -3578,12 +3578,7 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
auto acc = cast<TypedValue<VectorType>>(multi_reduction_op.getAcc());
TPU_ASSERT_OP(layouts_out.front().has_value());
const ArrayAttr dim_attrs = multi_reduction_op.getReductionDims();
SmallVector<int64_t> dims;
dims.reserve(dim_attrs.size());
for (const Attribute dim_attr : dim_attrs) {
dims.push_back(cast<IntegerAttr>(dim_attr).getValue().getSExtValue());
}
SmallVector<int64_t> dims(multi_reduction_op.getReductionDims());
std::sort(dims.begin(), dims.end());
// Make sure that the accumulator is a splat of the neutral value

View File

@ -203,8 +203,8 @@ LogicalResult canonicalize_multi_dim_reduction(int hardware_generation,
return success();
} else if (element_type.isBF16()) {
bool reduces_sublanes = false;
for (Attribute dim : op.getReductionDims()) {
if (cast<IntegerAttr>(dim).getInt() == source_ty.getRank() - 2) {
for (int64_t dim : op.getReductionDims()) {
if (dim == source_ty.getRank() - 2) {
reduces_sublanes = true;
}
}
@ -230,7 +230,7 @@ LogicalResult canonicalize_multi_dim_reduction(int hardware_generation,
}
auto new_op = builder.create<vector::MultiDimReductionOp>(
op.getLoc(), new_acc.getType(), op.getKindAttr(), new_source, new_acc,
op.getReductionDims());
DenseI64ArrayAttr::get(builder.getContext(), op.getReductionDims()));
auto new_result = builder.create<arith::TruncFOp>(op.getLoc(), result_ty,
new_op.getResult());
op.replaceAllUsesWith(new_result.getResult());

View File

@ -1277,11 +1277,7 @@ class VectorLayoutInferer {
auto src_ty = op.getSourceVectorType();
auto dst_ty = dyn_cast<VectorType>(op.getDestType());
TPU_CHECK_OP(dst_ty, "only reductions with vector results supported");
SmallVector<int64_t> dims;
dims.reserve(op.getReductionDims().size());
for (Attribute dim_attr : op.getReductionDims()) {
dims.push_back(cast<IntegerAttr>(dim_attr).getInt());
}
llvm::ArrayRef<int64_t> dims = op.getReductionDims();
int64_t src_rank = src_ty.getRank();
auto acc_layout = getLayout(op.getAcc());
TPU_CHECK_OP(is_fully_replicated(acc_layout),
@ -1770,9 +1766,8 @@ class VectorLayoutInferer {
if (auto reduce =
dyn_cast<vector::MultiDimReductionOp>(operand.getOwner())) {
bool reduces_tiled_dims = false;
for (Attribute dim : reduce.getReductionDims()) {
if (cast<IntegerAttr>(dim).getInt() >=
reduce.getSourceVectorType().getRank() - 2) {
for (int64_t dim : reduce.getReductionDims()) {
if (dim >= reduce.getSourceVectorType().getRank() - 2) {
reduces_tiled_dims = true;
break;
}