mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Integrate LLVM at llvm/llvm-project@c8b5d30f70
Updates LLVM usage to match [c8b5d30f7077](https://github.com/llvm/llvm-project/commit/c8b5d30f7077) PiperOrigin-RevId: 662906261
This commit is contained in:
parent
6290cd77fc
commit
807dcb5a06
@ -1295,9 +1295,7 @@ def reduce_lowering_rule(reduce_fn, type_to_kind, type_to_identity):
|
||||
kind,
|
||||
x,
|
||||
acc,
|
||||
ir.ArrayAttr.get(
|
||||
[ir.IntegerAttr.get(ir.IntegerType.get_signless(64), a) for a in axes]
|
||||
),
|
||||
axes,
|
||||
)
|
||||
return op.result
|
||||
return _lowering_rule
|
||||
|
@ -3578,12 +3578,7 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
|
||||
auto acc = cast<TypedValue<VectorType>>(multi_reduction_op.getAcc());
|
||||
TPU_ASSERT_OP(layouts_out.front().has_value());
|
||||
|
||||
const ArrayAttr dim_attrs = multi_reduction_op.getReductionDims();
|
||||
SmallVector<int64_t> dims;
|
||||
dims.reserve(dim_attrs.size());
|
||||
for (const Attribute dim_attr : dim_attrs) {
|
||||
dims.push_back(cast<IntegerAttr>(dim_attr).getValue().getSExtValue());
|
||||
}
|
||||
SmallVector<int64_t> dims(multi_reduction_op.getReductionDims());
|
||||
std::sort(dims.begin(), dims.end());
|
||||
|
||||
// Make sure that the accumulator is a splat of the neutral value
|
||||
|
@ -203,8 +203,8 @@ LogicalResult canonicalize_multi_dim_reduction(int hardware_generation,
|
||||
return success();
|
||||
} else if (element_type.isBF16()) {
|
||||
bool reduces_sublanes = false;
|
||||
for (Attribute dim : op.getReductionDims()) {
|
||||
if (cast<IntegerAttr>(dim).getInt() == source_ty.getRank() - 2) {
|
||||
for (int64_t dim : op.getReductionDims()) {
|
||||
if (dim == source_ty.getRank() - 2) {
|
||||
reduces_sublanes = true;
|
||||
}
|
||||
}
|
||||
@ -230,7 +230,7 @@ LogicalResult canonicalize_multi_dim_reduction(int hardware_generation,
|
||||
}
|
||||
auto new_op = builder.create<vector::MultiDimReductionOp>(
|
||||
op.getLoc(), new_acc.getType(), op.getKindAttr(), new_source, new_acc,
|
||||
op.getReductionDims());
|
||||
DenseI64ArrayAttr::get(builder.getContext(), op.getReductionDims()));
|
||||
auto new_result = builder.create<arith::TruncFOp>(op.getLoc(), result_ty,
|
||||
new_op.getResult());
|
||||
op.replaceAllUsesWith(new_result.getResult());
|
||||
|
@ -1277,11 +1277,7 @@ class VectorLayoutInferer {
|
||||
auto src_ty = op.getSourceVectorType();
|
||||
auto dst_ty = dyn_cast<VectorType>(op.getDestType());
|
||||
TPU_CHECK_OP(dst_ty, "only reductions with vector results supported");
|
||||
SmallVector<int64_t> dims;
|
||||
dims.reserve(op.getReductionDims().size());
|
||||
for (Attribute dim_attr : op.getReductionDims()) {
|
||||
dims.push_back(cast<IntegerAttr>(dim_attr).getInt());
|
||||
}
|
||||
llvm::ArrayRef<int64_t> dims = op.getReductionDims();
|
||||
int64_t src_rank = src_ty.getRank();
|
||||
auto acc_layout = getLayout(op.getAcc());
|
||||
TPU_CHECK_OP(is_fully_replicated(acc_layout),
|
||||
@ -1770,9 +1766,8 @@ class VectorLayoutInferer {
|
||||
if (auto reduce =
|
||||
dyn_cast<vector::MultiDimReductionOp>(operand.getOwner())) {
|
||||
bool reduces_tiled_dims = false;
|
||||
for (Attribute dim : reduce.getReductionDims()) {
|
||||
if (cast<IntegerAttr>(dim).getInt() >=
|
||||
reduce.getSourceVectorType().getRank() - 2) {
|
||||
for (int64_t dim : reduce.getReductionDims()) {
|
||||
if (dim >= reduce.getSourceVectorType().getRank() - 2) {
|
||||
reduces_tiled_dims = true;
|
||||
break;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user