mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic] Expose C API for VectorLayout, VRegDataBounds
This is in preparation for Python bindings PiperOrigin-RevId: 579355000
This commit is contained in:
parent
953f4670d8
commit
1c1dd7c8c7
@ -186,8 +186,12 @@ cc_library(
|
||||
deps = [
|
||||
":tpu_dialect",
|
||||
":tpu_inc_gen",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:CAPIIR",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/log:check",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -15,13 +15,114 @@ limitations under the License.
|
||||
|
||||
#include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h"
|
||||
|
||||
#include "mlir/include/mlir/CAPI/Pass.h"
|
||||
#include "mlir/include/mlir/CAPI/Registration.h"
|
||||
#include "mlir/include/mlir/CAPI/Support.h"
|
||||
#include "mlir/include/mlir/IR/Attributes.h"
|
||||
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/MemAlloc.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Registration.h"
|
||||
#include "mlir/CAPI/Wrap.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/layout.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
|
||||
|
||||
// TODO(tlongeri): null pointer checks?
|
||||
|
||||
namespace {
|
||||
DEFINE_C_API_PTR_METHODS(MlirTpuVectorLayout, mlir::tpu::VectorLayout);
|
||||
DEFINE_C_API_PTR_METHODS(MlirTpuVregDataBounds, mlir::tpu::VRegDataBounds);
|
||||
|
||||
MlirTpuImplicitDim wrap(mlir::tpu::VectorLayout::ImplicitDim implicit_dim) {
|
||||
switch (implicit_dim) {
|
||||
case mlir::tpu::VectorLayout::ImplicitDim::kNone:
|
||||
return MlirTpuImplicitDimNone;
|
||||
case mlir::tpu::VectorLayout::ImplicitDim::kMinor:
|
||||
return MlirTpuImplicitDimMinor;
|
||||
case mlir::tpu::VectorLayout::ImplicitDim::kSecondMinor:
|
||||
return MlirTpuImplicitDimSecondMinor;
|
||||
}
|
||||
LOG(FATAL) << "Invalid implicit dim (C++)";
|
||||
}
|
||||
mlir::tpu::VectorLayout::ImplicitDim unwrap(MlirTpuImplicitDim implicit_dim) {
|
||||
switch (implicit_dim) {
|
||||
case MlirTpuImplicitDimNone:
|
||||
return mlir::tpu::VectorLayout::ImplicitDim::kNone;
|
||||
case MlirTpuImplicitDimMinor:
|
||||
return mlir::tpu::VectorLayout::ImplicitDim::kMinor;
|
||||
case MlirTpuImplicitDimSecondMinor:
|
||||
return mlir::tpu::VectorLayout::ImplicitDim::kSecondMinor;
|
||||
}
|
||||
LOG(FATAL) << "Invalid implicit dim (C)";
|
||||
}
|
||||
mlir::tpu::Direction unwrap(MlirTpuDirection direction) {
|
||||
switch (direction) {
|
||||
case MlirTpuDirectionSublanes:
|
||||
return mlir::tpu::Direction::kSublanes;
|
||||
case MlirTpuImplicitDimMinor:
|
||||
return mlir::tpu::Direction::kLanes;
|
||||
case MlirTpuImplicitDimSecondMinor:
|
||||
return mlir::tpu::Direction::kSubelements;
|
||||
}
|
||||
LOG(FATAL) << "Invalid direction (C)";
|
||||
}
|
||||
MlirTpuLayoutOffsets wrap(mlir::tpu::LayoutOffsets offsets) {
|
||||
return {offsets[0].value_or(-1), offsets[1].value_or(-1)};
|
||||
}
|
||||
mlir::tpu::LayoutOffsets unwrap(MlirTpuLayoutOffsets offsets) {
|
||||
auto translateOffset = [](int64_t offset) {
|
||||
CHECK_GE(offset, -1);
|
||||
return offset == -1 ? std::nullopt : mlir::tpu::LayoutOffset{offset};
|
||||
};
|
||||
return {translateOffset(offsets.sublane), translateOffset(offsets.lane)};
|
||||
}
|
||||
std::array<bool, 2> unwrap(MlirTpuBoolTargetTuple arr) {
|
||||
return {arr.sublane, arr.lane};
|
||||
}
|
||||
std::array<int64_t, 2> unwrap(MlirTpuI64TargetTuple arr) {
|
||||
return {arr.sublane, arr.lane};
|
||||
}
|
||||
MlirTpuI64TargetTuple wrap(std::array<int64_t, 2> arr) {
|
||||
return {arr[0], arr[1]};
|
||||
}
|
||||
|
||||
mlir::OpBuilder mlirTpuInsertionPointToOpBuilder(
|
||||
MlirTpuInsertionPoint insertion_point) {
|
||||
mlir::Operation *ref_operation = unwrap(insertion_point.ref_operation);
|
||||
return ref_operation == nullptr
|
||||
? mlir::OpBuilder::atBlockEnd(unwrap(insertion_point.block))
|
||||
: mlir::OpBuilder(ref_operation);
|
||||
}
|
||||
|
||||
// We do not use the names wrap/unwrap for MlirTpuI64ArrayRef because whether
|
||||
// they should refer to SmallVector or ArrayRef is ambiguous
|
||||
MlirTpuI64ArrayRef mlirTpuI64ArrayRefFromLlvmSmallVector(
|
||||
const mlir::SmallVector<int64_t> &vec) {
|
||||
// TODO(tlongeri): It would be good to steal the buffer from implicit_shape,
|
||||
// but there are no public member functions for this.
|
||||
int64_t *ptr =
|
||||
static_cast<int64_t *>(llvm::safe_malloc(vec.size() * sizeof(int64_t)));
|
||||
memcpy(ptr, vec.data(), vec.size() * sizeof(int64_t));
|
||||
return {ptr, vec.size()};
|
||||
}
|
||||
llvm::ArrayRef<int64_t> mlirTpuI64ArrayRefToLlvmArrayRef(
|
||||
MlirTpuI64ArrayRef tpu_array_ref) {
|
||||
return {tpu_array_ref.ptr, tpu_array_ref.size};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(TPU, tpu, mlir::tpu::TPUDialect);
|
||||
@ -51,6 +152,160 @@ void mlirTPUAnalyzePotentialCommunication(MlirOperation op,
|
||||
*has_custom_barrier = result.second;
|
||||
}
|
||||
|
||||
MlirTpuVectorLayout mlirTpuVectorLayoutCreate(int bitwidth,
|
||||
MlirTpuLayoutOffsets offsets,
|
||||
MlirTpuI64TargetTuple tiling,
|
||||
MlirTpuImplicitDim implicit_dim) {
|
||||
return wrap(new mlir::tpu::VectorLayout(
|
||||
bitwidth, unwrap(offsets), unwrap(tiling), unwrap(implicit_dim)));
|
||||
}
|
||||
|
||||
void mlirTpuVectorLayoutDestroy(MlirTpuVectorLayout layout) {
|
||||
delete unwrap(layout);
|
||||
}
|
||||
|
||||
int mlirTpuVectorLayoutGetBitwidth(MlirTpuVectorLayout layout) {
|
||||
return unwrap(layout)->bitwidth();
|
||||
}
|
||||
|
||||
MlirTpuLayoutOffsets mlirTpuVectorLayoutGetOffsets(MlirTpuVectorLayout layout) {
|
||||
return wrap(unwrap(layout)->offsets());
|
||||
}
|
||||
|
||||
MlirTpuI64TargetTuple mlirTpuVectorLayoutGetTiling(MlirTpuVectorLayout layout) {
|
||||
return wrap(unwrap(layout)->tiling());
|
||||
}
|
||||
|
||||
MlirTpuImplicitDim mlirTpuVectorLayoutGetImplicitDim(
|
||||
MlirTpuVectorLayout layout) {
|
||||
return wrap(unwrap(layout)->implicit_dim());
|
||||
}
|
||||
|
||||
int mlirTpuVectorLayoutGetPacking(MlirTpuVectorLayout layout) {
|
||||
return unwrap(layout)->packing();
|
||||
}
|
||||
|
||||
int mlirTpuVectorLayoutGetLayoutRank(MlirTpuVectorLayout layout) {
|
||||
return unwrap(layout)->layout_rank();
|
||||
}
|
||||
|
||||
bool mlirTpuVectorLayoutEquals(MlirTpuVectorLayout lhs,
|
||||
MlirTpuVectorLayout rhs) {
|
||||
return *unwrap(lhs) == *unwrap(rhs);
|
||||
}
|
||||
|
||||
int64_t mlirTpuVectorLayoutTilesPerVreg(MlirTpuVectorLayout layout,
|
||||
MlirTpuI64TargetTuple target_shape) {
|
||||
return unwrap(layout)->tilesPerVreg(unwrap(target_shape));
|
||||
}
|
||||
|
||||
int64_t mlirTpuVectorLayoutSublanesPerTile(MlirTpuVectorLayout layout,
|
||||
MlirTpuI64TargetTuple target_shape) {
|
||||
return unwrap(layout)->sublanesPerTile(unwrap(target_shape));
|
||||
}
|
||||
|
||||
MlirTpuI64TargetTuple mlirTpuVectorLayoutVregSlice(
|
||||
MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape) {
|
||||
return wrap(unwrap(layout)->vregSlice(unwrap(target_shape)));
|
||||
}
|
||||
|
||||
MlirTpuI64ArrayRef mlirTpuVectorLayoutImplicitShape(MlirTpuVectorLayout layout,
|
||||
MlirTpuI64ArrayRef shape) {
|
||||
mlir::SmallVector<int64_t> implicit_shape =
|
||||
unwrap(layout)->implicitShape(mlirTpuI64ArrayRefToLlvmArrayRef(shape));
|
||||
return mlirTpuI64ArrayRefFromLlvmSmallVector(implicit_shape);
|
||||
}
|
||||
|
||||
MlirTpuI64ArrayRef mlirTpuVectorLayoutTileArrayShape(
|
||||
MlirTpuVectorLayout layout, MlirTpuI64ArrayRef shape,
|
||||
MlirTpuI64TargetTuple target_shape) {
|
||||
mlir::SmallVector<int64_t> tile_array_shape = unwrap(layout)->tileArrayShape(
|
||||
mlirTpuI64ArrayRefToLlvmArrayRef(shape), unwrap(target_shape));
|
||||
return mlirTpuI64ArrayRefFromLlvmSmallVector(tile_array_shape);
|
||||
}
|
||||
|
||||
MlirTpuVregDataBounds mlirTpuVectorLayoutTileDataBounds(
|
||||
MlirTpuVectorLayout layout, MlirContext ctx, int64_t *full_shape,
|
||||
int64_t *idxs, size_t size, MlirTpuI64TargetTuple target_shape,
|
||||
MlirTpuBoolTargetTuple allow_replicated) {
|
||||
std::unique_ptr<mlir::tpu::VRegDataBounds> ptr =
|
||||
unwrap(layout)->tileDataBounds(
|
||||
unwrap(ctx), llvm::ArrayRef<int64_t>{full_shape, size},
|
||||
llvm::ArrayRef<int64_t>{idxs, size}, unwrap(target_shape),
|
||||
unwrap(allow_replicated));
|
||||
return wrap(ptr.release());
|
||||
}
|
||||
|
||||
bool mlirTpuVectorLayoutHasNaturalTopology(MlirTpuVectorLayout layout,
|
||||
MlirTpuI64TargetTuple target_shape) {
|
||||
return unwrap(layout)->hasNaturalTopology(unwrap(target_shape));
|
||||
}
|
||||
|
||||
bool mlirTpuVectorLayoutHasNativeTiling(MlirTpuVectorLayout layout,
|
||||
MlirTpuI64TargetTuple target_shape) {
|
||||
return unwrap(layout)->hasNativeTiling(unwrap(target_shape));
|
||||
}
|
||||
|
||||
bool mlirTpuVectorLayoutGeneralizes(MlirTpuVectorLayout layout,
|
||||
MlirTpuVectorLayout other,
|
||||
MlirTpuI64ArrayRef shape,
|
||||
MlirTpuI64TargetTuple target_shape) {
|
||||
return unwrap(layout)->generalizes(*unwrap(other),
|
||||
mlirTpuI64ArrayRefToLlvmArrayRef(shape),
|
||||
unwrap(target_shape));
|
||||
}
|
||||
|
||||
bool mlirTpuVectorLayoutEquivalentTo(MlirTpuVectorLayout layout,
|
||||
MlirTpuVectorLayout other,
|
||||
MlirTpuI64ArrayRef shape,
|
||||
MlirTpuI64TargetTuple target_shape) {
|
||||
return unwrap(layout)->equivalentTo(*unwrap(other),
|
||||
mlirTpuI64ArrayRefToLlvmArrayRef(shape),
|
||||
unwrap(target_shape));
|
||||
}
|
||||
|
||||
void mlirTpuVregDataBoundsDestroy(MlirTpuVregDataBounds data_bounds) {
|
||||
delete unwrap(data_bounds);
|
||||
}
|
||||
|
||||
bool mlirTpuVregDataBoundsMaskVariesAlong(MlirTpuVregDataBounds data_bounds,
|
||||
MlirTpuDirection direction,
|
||||
MlirTpuI64TargetTuple target_shape) {
|
||||
return unwrap(data_bounds)
|
||||
->maskVariesAlong(unwrap(direction), unwrap(target_shape));
|
||||
}
|
||||
|
||||
bool mlirTpuVregDataBoundsIsComplete(MlirTpuVregDataBounds data_bounds,
|
||||
MlirTpuI64TargetTuple target_shape) {
|
||||
return unwrap(data_bounds)->isComplete(unwrap(target_shape));
|
||||
}
|
||||
|
||||
MlirValue mlirTpuVregDataBoundsGetVectorMask(
|
||||
MlirTpuVregDataBounds data_bounds, MlirTpuInsertionPoint insertion_point,
|
||||
MlirLocation location, int generation, MlirTpuI64TargetTuple target_shape) {
|
||||
mlir::OpBuilder builder = mlirTpuInsertionPointToOpBuilder(insertion_point);
|
||||
auto failure_or_mask = unwrap(data_bounds)
|
||||
->getVectorMask(builder, unwrap(location),
|
||||
generation, unwrap(target_shape));
|
||||
if (failed(failure_or_mask)) {
|
||||
return wrap(mlir::Value());
|
||||
} else {
|
||||
return wrap(failure_or_mask.value());
|
||||
}
|
||||
}
|
||||
|
||||
MlirAttribute mlirTpuVregDataBoundsGetSublaneMask(
|
||||
MlirTpuVregDataBounds data_bounds, MlirContext ctx,
|
||||
MlirTpuI64TargetTuple target_shape) {
|
||||
return wrap(
|
||||
unwrap(data_bounds)->getSublaneMask(unwrap(ctx), unwrap(target_shape)));
|
||||
}
|
||||
}
|
||||
|
||||
#include "mlir/CAPI/Pass.h" // IWYU pragma: keep
|
||||
#include "mlir/CAPI/Support.h" // IWYU pragma: keep
|
||||
|
||||
extern "C" {
|
||||
using namespace mlir::tpu;
|
||||
|
||||
#include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_passes.capi.cc.inc"
|
||||
|
@ -13,10 +13,20 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Refer to the corresponding C++ declarations in layout.h and
|
||||
// apply_vector_layout.h for documentation on the functions in this file
|
||||
|
||||
#ifndef JAXLIB_MOSAIC_DIALECT_TPU_INTEGRATIONS_C_TPU_DIALECT_H_
|
||||
#define JAXLIB_MOSAIC_DIALECT_TPU_INTEGRATIONS_C_TPU_DIALECT_H_
|
||||
|
||||
#include "mlir/include/mlir-c/IR.h"
|
||||
#ifndef __cplusplus
|
||||
#include <stdbool.h>
|
||||
#endif
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Support.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
@ -35,6 +45,149 @@ mlirTPUTiledLayoutAttrGetTiles(MlirAttribute attr);
|
||||
MLIR_CAPI_EXPORTED void mlirTPUAnalyzePotentialCommunication(
|
||||
MlirOperation op, bool* has_communication, bool* has_custom_barrier);
|
||||
|
||||
typedef enum MlirTpuImplicitDim {
|
||||
MlirTpuImplicitDimNone = 0,
|
||||
MlirTpuImplicitDimMinor = 1,
|
||||
MlirTpuImplicitDimSecondMinor = 2,
|
||||
} MlirTpuImplicitDim;
|
||||
|
||||
typedef enum MlirTpuDirection {
|
||||
MlirTpuDirectionSublanes,
|
||||
MlirTpuDirectionLanes,
|
||||
MlirTpuDirectionSubelements
|
||||
} MlirTpuDirection;
|
||||
|
||||
// Opaque reference to an owned layout
|
||||
typedef struct MlirTpuVectorLayout {
|
||||
void* ptr;
|
||||
} MlirTpuVectorLayout;
|
||||
|
||||
// Opaque reference to owned data bounds
|
||||
typedef struct MlirTpuVregDataBounds {
|
||||
void* ptr;
|
||||
} MlirTpuVregDataBounds;
|
||||
|
||||
// mlir::ArrayRef<int64_t> equivalent
|
||||
// Unlike mlir::ArrayRef, the data may or may not be owned (this should be
|
||||
// defined by the producer of the struct).
|
||||
typedef struct MlirTpuI64ArrayRef {
|
||||
int64_t* ptr;
|
||||
size_t size;
|
||||
} MlirTpuI64ArrayRef;
|
||||
|
||||
typedef struct MlirTpuLayoutOffsets {
|
||||
// Use -1 for replicated
|
||||
int64_t sublane;
|
||||
int64_t lane;
|
||||
} MlirTpuLayoutOffsets;
|
||||
|
||||
typedef struct MlirTpuI64TargetTuple {
|
||||
int64_t sublane;
|
||||
int64_t lane;
|
||||
} MlirTpuI64TargetTuple;
|
||||
|
||||
typedef struct MlirTpuBoolTargetTuple {
|
||||
bool sublane;
|
||||
bool lane;
|
||||
} MlirTpuBoolTargetTuple;
|
||||
|
||||
// An insertion point within a block.
|
||||
// The MLIR C API does not already have a similar struct, unfortunately.
|
||||
typedef struct MlirTpuInsertionPoint {
|
||||
MlirBlock block; // Only used when ref_operation is unspecified (null)
|
||||
MlirOperation ref_operation;
|
||||
} MlirTpuInsertionPoint;
|
||||
|
||||
// Caller owns the returned object and is responsible for calling
|
||||
// mlirTpuVectorLayoutDestroy
|
||||
MLIR_CAPI_EXPORTED MlirTpuVectorLayout mlirTpuVectorLayoutCreate(
|
||||
int bitwidth, MlirTpuLayoutOffsets offsets, MlirTpuI64TargetTuple tiling,
|
||||
MlirTpuImplicitDim implicit_dim);
|
||||
|
||||
MLIR_CAPI_EXPORTED void mlirTpuVectorLayoutDestroy(MlirTpuVectorLayout);
|
||||
|
||||
MLIR_CAPI_EXPORTED int mlirTpuVectorLayoutGetBitwidth(
|
||||
MlirTpuVectorLayout layout);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirTpuLayoutOffsets
|
||||
mlirTpuVectorLayoutGetOffsets(MlirTpuVectorLayout layout);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirTpuI64TargetTuple
|
||||
mlirTpuVectorLayoutGetTiling(MlirTpuVectorLayout layout);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirTpuImplicitDim
|
||||
mlirTpuVectorLayoutGetImplicitDim(MlirTpuVectorLayout layout);
|
||||
|
||||
MLIR_CAPI_EXPORTED int mlirTpuVectorLayoutGetPacking(
|
||||
MlirTpuVectorLayout layout);
|
||||
|
||||
MLIR_CAPI_EXPORTED int mlirTpuVectorLayoutGetLayoutRank(
|
||||
MlirTpuVectorLayout layout);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutEquals(MlirTpuVectorLayout lhs,
|
||||
MlirTpuVectorLayout rhs);
|
||||
|
||||
MLIR_CAPI_EXPORTED int64_t mlirTpuVectorLayoutTilesPerVreg(
|
||||
MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape);
|
||||
|
||||
MLIR_CAPI_EXPORTED int64_t mlirTpuVectorLayoutSublanesPerTile(
|
||||
MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirTpuI64TargetTuple mlirTpuVectorLayoutVregSlice(
|
||||
MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape);
|
||||
|
||||
// Caller is responsible for calling free on the returned pointer
|
||||
MLIR_CAPI_EXPORTED MlirTpuI64ArrayRef mlirTpuVectorLayoutImplicitShape(
|
||||
MlirTpuVectorLayout layout, MlirTpuI64ArrayRef shape);
|
||||
|
||||
// Caller is responsible for calling free on the returned pointer.
|
||||
MLIR_CAPI_EXPORTED MlirTpuI64ArrayRef mlirTpuVectorLayoutTileArrayShape(
|
||||
MlirTpuVectorLayout layout, MlirTpuI64ArrayRef shape,
|
||||
MlirTpuI64TargetTuple target_shape);
|
||||
|
||||
// Caller owns the returned object and is responsible for calling
|
||||
// mlirTpuVectorLayoutVregDataBoundsDestroy
|
||||
MLIR_CAPI_EXPORTED MlirTpuVregDataBounds mlirTpuVectorLayoutTileDataBounds(
|
||||
MlirTpuVectorLayout layout, MlirContext ctx, int64_t* full_shape,
|
||||
int64_t* idxs, size_t size, MlirTpuI64TargetTuple target_shape,
|
||||
MlirTpuBoolTargetTuple allow_replicated);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutHasNaturalTopology(
|
||||
MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutHasNativeTiling(
|
||||
MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape);
|
||||
|
||||
// `shape` is optional, pass a shape with a null `ptr` to return true iff the
|
||||
// "generalizes" relationship applies to all shapes.
|
||||
MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutGeneralizes(
|
||||
MlirTpuVectorLayout layout, MlirTpuVectorLayout other,
|
||||
MlirTpuI64ArrayRef shape, MlirTpuI64TargetTuple target_shape);
|
||||
|
||||
// `shape` is optional, pass a shape with a null `ptr` to return true iff the
|
||||
// "equivalent to" relationship applies to all shapes.
|
||||
MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutEquivalentTo(
|
||||
MlirTpuVectorLayout layout, MlirTpuVectorLayout other,
|
||||
MlirTpuI64ArrayRef shape, MlirTpuI64TargetTuple target_shape);
|
||||
|
||||
MLIR_CAPI_EXPORTED void mlirTpuVregDataBoundsDestroy(
|
||||
MlirTpuVregDataBounds data_bounds);
|
||||
|
||||
bool mlirTpuVregDataBoundsMaskVariesAlong(MlirTpuVregDataBounds data_bounds,
|
||||
MlirTpuDirection direction,
|
||||
MlirTpuI64TargetTuple target_shape);
|
||||
|
||||
bool mlirTpuVregDataBoundsIsComplete(MlirTpuVregDataBounds data_bounds,
|
||||
MlirTpuI64TargetTuple target_shape);
|
||||
// Returns null on failure
|
||||
MlirValue mlirTpuVregDataBoundsGetVectorMask(
|
||||
MlirTpuVregDataBounds data_bounds, MlirTpuInsertionPoint insertion_point,
|
||||
MlirLocation location, int generation, MlirTpuI64TargetTuple target_shape);
|
||||
|
||||
MlirAttribute mlirTpuVregDataBoundsGetSublaneMask(
|
||||
MlirTpuVregDataBounds data_bounds, MlirContext ctx,
|
||||
MlirTpuI64TargetTuple target_shape);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
Loading…
x
Reference in New Issue
Block a user