[Mosaic] Expose C API for VectorLayout, VRegDataBounds

This is in preparation for Python bindings

PiperOrigin-RevId: 579355000
Tomás Longeri 2023-11-03 18:23:24 -07:00 committed by jax authors
parent 953f4670d8
commit 1c1dd7c8c7
3 changed files with 418 additions and 6 deletions
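
As a rough usage sketch (not part of this diff), a client of the new C API creates a layout, queries it, and destroys it. The bitwidth, tiling, and target-shape values below are made-up illustrations, not requirements of the API:

// Hypothetical caller of the C API added by this commit. All functions and
// types are declared in tpu_dialect.h below.
#include <inttypes.h>
#include <stdio.h>
#include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h"

void example(void) {
  MlirTpuLayoutOffsets offsets = {0, 0};    // -1 would mean "replicated"
  MlirTpuI64TargetTuple tiling = {8, 128};  // assumed (sublane, lane) tiling
  MlirTpuVectorLayout layout =
      mlirTpuVectorLayoutCreate(32, offsets, tiling, MlirTpuImplicitDimNone);
  MlirTpuI64TargetTuple target_shape = {8, 128};  // assumed target shape
  printf("tiles per vreg: %" PRId64 "\n",
         mlirTpuVectorLayoutTilesPerVreg(layout, target_shape));
  mlirTpuVectorLayoutDestroy(layout);  // the caller owns the layout
}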

jaxlib/mosaic/dialect/tpu/BUILD

@@ -186,8 +186,12 @@ cc_library(
deps = [
":tpu_dialect",
":tpu_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
],
)

jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc

@@ -15,13 +15,114 @@ limitations under the License.
#include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h"
#include "mlir/include/mlir/CAPI/Pass.h"
#include "mlir/include/mlir/CAPI/Registration.h"
#include "mlir/include/mlir/CAPI/Support.h"
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
#include <array>
#include <cstdint>
#include <cstring>
#include <memory>
#include <optional>
#include <utility>
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MemAlloc.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
// TODO(tlongeri): null pointer checks?
namespace {
DEFINE_C_API_PTR_METHODS(MlirTpuVectorLayout, mlir::tpu::VectorLayout);
DEFINE_C_API_PTR_METHODS(MlirTpuVregDataBounds, mlir::tpu::VRegDataBounds);
MlirTpuImplicitDim wrap(mlir::tpu::VectorLayout::ImplicitDim implicit_dim) {
switch (implicit_dim) {
case mlir::tpu::VectorLayout::ImplicitDim::kNone:
return MlirTpuImplicitDimNone;
case mlir::tpu::VectorLayout::ImplicitDim::kMinor:
return MlirTpuImplicitDimMinor;
case mlir::tpu::VectorLayout::ImplicitDim::kSecondMinor:
return MlirTpuImplicitDimSecondMinor;
}
LOG(FATAL) << "Invalid implicit dim (C++)";
}
mlir::tpu::VectorLayout::ImplicitDim unwrap(MlirTpuImplicitDim implicit_dim) {
switch (implicit_dim) {
case MlirTpuImplicitDimNone:
return mlir::tpu::VectorLayout::ImplicitDim::kNone;
case MlirTpuImplicitDimMinor:
return mlir::tpu::VectorLayout::ImplicitDim::kMinor;
case MlirTpuImplicitDimSecondMinor:
return mlir::tpu::VectorLayout::ImplicitDim::kSecondMinor;
}
LOG(FATAL) << "Invalid implicit dim (C)";
}
mlir::tpu::Direction unwrap(MlirTpuDirection direction) {
switch (direction) {
case MlirTpuDirectionSublanes:
return mlir::tpu::Direction::kSublanes;
case MlirTpuDirectionLanes:
return mlir::tpu::Direction::kLanes;
case MlirTpuDirectionSubelements:
return mlir::tpu::Direction::kSubelements;
}
LOG(FATAL) << "Invalid direction (C)";
}
MlirTpuLayoutOffsets wrap(mlir::tpu::LayoutOffsets offsets) {
return {offsets[0].value_or(-1), offsets[1].value_or(-1)};
}
mlir::tpu::LayoutOffsets unwrap(MlirTpuLayoutOffsets offsets) {
auto translateOffset = [](int64_t offset) {
CHECK_GE(offset, -1);
return offset == -1 ? std::nullopt : mlir::tpu::LayoutOffset{offset};
};
return {translateOffset(offsets.sublane), translateOffset(offsets.lane)};
}
std::array<bool, 2> unwrap(MlirTpuBoolTargetTuple arr) {
return {arr.sublane, arr.lane};
}
std::array<int64_t, 2> unwrap(MlirTpuI64TargetTuple arr) {
return {arr.sublane, arr.lane};
}
MlirTpuI64TargetTuple wrap(std::array<int64_t, 2> arr) {
return {arr[0], arr[1]};
}
mlir::OpBuilder mlirTpuInsertionPointToOpBuilder(
MlirTpuInsertionPoint insertion_point) {
mlir::Operation *ref_operation = unwrap(insertion_point.ref_operation);
return ref_operation == nullptr
? mlir::OpBuilder::atBlockEnd(unwrap(insertion_point.block))
: mlir::OpBuilder(ref_operation);
}
// We do not use the names wrap/unwrap for MlirTpuI64ArrayRef because whether
// they should refer to SmallVector or ArrayRef is ambiguous
MlirTpuI64ArrayRef mlirTpuI64ArrayRefFromLlvmSmallVector(
const mlir::SmallVector<int64_t> &vec) {
// TODO(tlongeri): It would be good to steal the buffer from vec, but there
// are no public member functions for this.
int64_t *ptr =
static_cast<int64_t *>(llvm::safe_malloc(vec.size() * sizeof(int64_t)));
memcpy(ptr, vec.data(), vec.size() * sizeof(int64_t));
return {ptr, vec.size()};
}
llvm::ArrayRef<int64_t> mlirTpuI64ArrayRefToLlvmArrayRef(
MlirTpuI64ArrayRef tpu_array_ref) {
return {tpu_array_ref.ptr, tpu_array_ref.size};
}
} // namespace
extern "C" {
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(TPU, tpu, mlir::tpu::TPUDialect);
@@ -51,6 +152,160 @@ void mlirTPUAnalyzePotentialCommunication(MlirOperation op,
*has_custom_barrier = result.second;
}
MlirTpuVectorLayout mlirTpuVectorLayoutCreate(int bitwidth,
MlirTpuLayoutOffsets offsets,
MlirTpuI64TargetTuple tiling,
MlirTpuImplicitDim implicit_dim) {
return wrap(new mlir::tpu::VectorLayout(
bitwidth, unwrap(offsets), unwrap(tiling), unwrap(implicit_dim)));
}
void mlirTpuVectorLayoutDestroy(MlirTpuVectorLayout layout) {
delete unwrap(layout);
}
int mlirTpuVectorLayoutGetBitwidth(MlirTpuVectorLayout layout) {
return unwrap(layout)->bitwidth();
}
MlirTpuLayoutOffsets mlirTpuVectorLayoutGetOffsets(MlirTpuVectorLayout layout) {
return wrap(unwrap(layout)->offsets());
}
MlirTpuI64TargetTuple mlirTpuVectorLayoutGetTiling(MlirTpuVectorLayout layout) {
return wrap(unwrap(layout)->tiling());
}
MlirTpuImplicitDim mlirTpuVectorLayoutGetImplicitDim(
MlirTpuVectorLayout layout) {
return wrap(unwrap(layout)->implicit_dim());
}
int mlirTpuVectorLayoutGetPacking(MlirTpuVectorLayout layout) {
return unwrap(layout)->packing();
}
int mlirTpuVectorLayoutGetLayoutRank(MlirTpuVectorLayout layout) {
return unwrap(layout)->layout_rank();
}
bool mlirTpuVectorLayoutEquals(MlirTpuVectorLayout lhs,
MlirTpuVectorLayout rhs) {
return *unwrap(lhs) == *unwrap(rhs);
}
int64_t mlirTpuVectorLayoutTilesPerVreg(MlirTpuVectorLayout layout,
MlirTpuI64TargetTuple target_shape) {
return unwrap(layout)->tilesPerVreg(unwrap(target_shape));
}
int64_t mlirTpuVectorLayoutSublanesPerTile(MlirTpuVectorLayout layout,
MlirTpuI64TargetTuple target_shape) {
return unwrap(layout)->sublanesPerTile(unwrap(target_shape));
}
MlirTpuI64TargetTuple mlirTpuVectorLayoutVregSlice(
MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape) {
return wrap(unwrap(layout)->vregSlice(unwrap(target_shape)));
}
MlirTpuI64ArrayRef mlirTpuVectorLayoutImplicitShape(MlirTpuVectorLayout layout,
MlirTpuI64ArrayRef shape) {
mlir::SmallVector<int64_t> implicit_shape =
unwrap(layout)->implicitShape(mlirTpuI64ArrayRefToLlvmArrayRef(shape));
return mlirTpuI64ArrayRefFromLlvmSmallVector(implicit_shape);
}
MlirTpuI64ArrayRef mlirTpuVectorLayoutTileArrayShape(
MlirTpuVectorLayout layout, MlirTpuI64ArrayRef shape,
MlirTpuI64TargetTuple target_shape) {
mlir::SmallVector<int64_t> tile_array_shape = unwrap(layout)->tileArrayShape(
mlirTpuI64ArrayRefToLlvmArrayRef(shape), unwrap(target_shape));
return mlirTpuI64ArrayRefFromLlvmSmallVector(tile_array_shape);
}
MlirTpuVregDataBounds mlirTpuVectorLayoutTileDataBounds(
MlirTpuVectorLayout layout, MlirContext ctx, int64_t *full_shape,
int64_t *idxs, size_t size, MlirTpuI64TargetTuple target_shape,
MlirTpuBoolTargetTuple allow_replicated) {
std::unique_ptr<mlir::tpu::VRegDataBounds> ptr =
unwrap(layout)->tileDataBounds(
unwrap(ctx), llvm::ArrayRef<int64_t>{full_shape, size},
llvm::ArrayRef<int64_t>{idxs, size}, unwrap(target_shape),
unwrap(allow_replicated));
return wrap(ptr.release());
}
bool mlirTpuVectorLayoutHasNaturalTopology(MlirTpuVectorLayout layout,
MlirTpuI64TargetTuple target_shape) {
return unwrap(layout)->hasNaturalTopology(unwrap(target_shape));
}
bool mlirTpuVectorLayoutHasNativeTiling(MlirTpuVectorLayout layout,
MlirTpuI64TargetTuple target_shape) {
return unwrap(layout)->hasNativeTiling(unwrap(target_shape));
}
bool mlirTpuVectorLayoutGeneralizes(MlirTpuVectorLayout layout,
MlirTpuVectorLayout other,
MlirTpuI64ArrayRef shape,
MlirTpuI64TargetTuple target_shape) {
return unwrap(layout)->generalizes(*unwrap(other),
mlirTpuI64ArrayRefToLlvmArrayRef(shape),
unwrap(target_shape));
}
bool mlirTpuVectorLayoutEquivalentTo(MlirTpuVectorLayout layout,
MlirTpuVectorLayout other,
MlirTpuI64ArrayRef shape,
MlirTpuI64TargetTuple target_shape) {
return unwrap(layout)->equivalentTo(*unwrap(other),
mlirTpuI64ArrayRefToLlvmArrayRef(shape),
unwrap(target_shape));
}
void mlirTpuVregDataBoundsDestroy(MlirTpuVregDataBounds data_bounds) {
delete unwrap(data_bounds);
}
bool mlirTpuVregDataBoundsMaskVariesAlong(MlirTpuVregDataBounds data_bounds,
MlirTpuDirection direction,
MlirTpuI64TargetTuple target_shape) {
return unwrap(data_bounds)
->maskVariesAlong(unwrap(direction), unwrap(target_shape));
}
bool mlirTpuVregDataBoundsIsComplete(MlirTpuVregDataBounds data_bounds,
MlirTpuI64TargetTuple target_shape) {
return unwrap(data_bounds)->isComplete(unwrap(target_shape));
}
MlirValue mlirTpuVregDataBoundsGetVectorMask(
MlirTpuVregDataBounds data_bounds, MlirTpuInsertionPoint insertion_point,
MlirLocation location, int generation, MlirTpuI64TargetTuple target_shape) {
mlir::OpBuilder builder = mlirTpuInsertionPointToOpBuilder(insertion_point);
auto failure_or_mask = unwrap(data_bounds)
->getVectorMask(builder, unwrap(location),
generation, unwrap(target_shape));
if (failed(failure_or_mask)) {
return wrap(mlir::Value());
} else {
return wrap(failure_or_mask.value());
}
}
MlirAttribute mlirTpuVregDataBoundsGetSublaneMask(
MlirTpuVregDataBounds data_bounds, MlirContext ctx,
MlirTpuI64TargetTuple target_shape) {
return wrap(
unwrap(data_bounds)->getSublaneMask(unwrap(ctx), unwrap(target_shape)));
}
}
#include "mlir/CAPI/Pass.h" // IWYU pragma: keep
#include "mlir/CAPI/Support.h" // IWYU pragma: keep
extern "C" {
using namespace mlir::tpu;
#include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_passes.capi.cc.inc"

jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h

@@ -13,10 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Refer to the corresponding C++ declarations in layout.h and
// apply_vector_layout.h for documentation on the functions in this file.
#ifndef JAXLIB_MOSAIC_DIALECT_TPU_INTEGRATIONS_C_TPU_DIALECT_H_
#define JAXLIB_MOSAIC_DIALECT_TPU_INTEGRATIONS_C_TPU_DIALECT_H_
#include "mlir/include/mlir-c/IR.h"
#ifndef __cplusplus
#include <stdbool.h>
#endif
#include <stddef.h>
#include <stdint.h>
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#ifdef __cplusplus
extern "C" {
@@ -35,6 +45,149 @@ mlirTPUTiledLayoutAttrGetTiles(MlirAttribute attr);
MLIR_CAPI_EXPORTED void mlirTPUAnalyzePotentialCommunication(
MlirOperation op, bool* has_communication, bool* has_custom_barrier);
typedef enum MlirTpuImplicitDim {
MlirTpuImplicitDimNone = 0,
MlirTpuImplicitDimMinor = 1,
MlirTpuImplicitDimSecondMinor = 2,
} MlirTpuImplicitDim;
typedef enum MlirTpuDirection {
MlirTpuDirectionSublanes,
MlirTpuDirectionLanes,
MlirTpuDirectionSubelements
} MlirTpuDirection;
// Opaque reference to an owned layout
typedef struct MlirTpuVectorLayout {
void* ptr;
} MlirTpuVectorLayout;
// Opaque reference to owned data bounds
typedef struct MlirTpuVregDataBounds {
void* ptr;
} MlirTpuVregDataBounds;
// mlir::ArrayRef<int64_t> equivalent
// Unlike mlir::ArrayRef, the data may or may not be owned (this should be
// defined by the producer of the struct).
typedef struct MlirTpuI64ArrayRef {
int64_t* ptr;
size_t size;
} MlirTpuI64ArrayRef;
typedef struct MlirTpuLayoutOffsets {
// Use -1 for replicated
int64_t sublane;
int64_t lane;
} MlirTpuLayoutOffsets;
typedef struct MlirTpuI64TargetTuple {
int64_t sublane;
int64_t lane;
} MlirTpuI64TargetTuple;
typedef struct MlirTpuBoolTargetTuple {
bool sublane;
bool lane;
} MlirTpuBoolTargetTuple;
// An insertion point within a block.
// The MLIR C API does not already have a similar struct, unfortunately.
typedef struct MlirTpuInsertionPoint {
MlirBlock block; // Only used when ref_operation is unspecified (null)
MlirOperation ref_operation;
} MlirTpuInsertionPoint;
// Caller owns the returned object and is responsible for calling
// mlirTpuVectorLayoutDestroy
MLIR_CAPI_EXPORTED MlirTpuVectorLayout mlirTpuVectorLayoutCreate(
int bitwidth, MlirTpuLayoutOffsets offsets, MlirTpuI64TargetTuple tiling,
MlirTpuImplicitDim implicit_dim);
MLIR_CAPI_EXPORTED void mlirTpuVectorLayoutDestroy(MlirTpuVectorLayout layout);
MLIR_CAPI_EXPORTED int mlirTpuVectorLayoutGetBitwidth(
MlirTpuVectorLayout layout);
MLIR_CAPI_EXPORTED MlirTpuLayoutOffsets
mlirTpuVectorLayoutGetOffsets(MlirTpuVectorLayout layout);
MLIR_CAPI_EXPORTED MlirTpuI64TargetTuple
mlirTpuVectorLayoutGetTiling(MlirTpuVectorLayout layout);
MLIR_CAPI_EXPORTED MlirTpuImplicitDim
mlirTpuVectorLayoutGetImplicitDim(MlirTpuVectorLayout layout);
MLIR_CAPI_EXPORTED int mlirTpuVectorLayoutGetPacking(
MlirTpuVectorLayout layout);
MLIR_CAPI_EXPORTED int mlirTpuVectorLayoutGetLayoutRank(
MlirTpuVectorLayout layout);
MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutEquals(MlirTpuVectorLayout lhs,
MlirTpuVectorLayout rhs);
MLIR_CAPI_EXPORTED int64_t mlirTpuVectorLayoutTilesPerVreg(
MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape);
MLIR_CAPI_EXPORTED int64_t mlirTpuVectorLayoutSublanesPerTile(
MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape);
MLIR_CAPI_EXPORTED MlirTpuI64TargetTuple mlirTpuVectorLayoutVregSlice(
MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape);
// Caller is responsible for calling free on the returned pointer.
MLIR_CAPI_EXPORTED MlirTpuI64ArrayRef mlirTpuVectorLayoutImplicitShape(
MlirTpuVectorLayout layout, MlirTpuI64ArrayRef shape);
// Caller is responsible for calling free on the returned pointer.
MLIR_CAPI_EXPORTED MlirTpuI64ArrayRef mlirTpuVectorLayoutTileArrayShape(
MlirTpuVectorLayout layout, MlirTpuI64ArrayRef shape,
MlirTpuI64TargetTuple target_shape);
// Caller owns the returned object and is responsible for calling
// mlirTpuVregDataBoundsDestroy.
MLIR_CAPI_EXPORTED MlirTpuVregDataBounds mlirTpuVectorLayoutTileDataBounds(
MlirTpuVectorLayout layout, MlirContext ctx, int64_t* full_shape,
int64_t* idxs, size_t size, MlirTpuI64TargetTuple target_shape,
MlirTpuBoolTargetTuple allow_replicated);
MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutHasNaturalTopology(
MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape);
MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutHasNativeTiling(
MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape);
// `shape` is optional; pass a shape with a null `ptr` to return true iff the
// "generalizes" relationship applies to all shapes.
MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutGeneralizes(
MlirTpuVectorLayout layout, MlirTpuVectorLayout other,
MlirTpuI64ArrayRef shape, MlirTpuI64TargetTuple target_shape);
// `shape` is optional; pass a shape with a null `ptr` to return true iff the
// "equivalent to" relationship applies to all shapes.
MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutEquivalentTo(
MlirTpuVectorLayout layout, MlirTpuVectorLayout other,
MlirTpuI64ArrayRef shape, MlirTpuI64TargetTuple target_shape);
MLIR_CAPI_EXPORTED void mlirTpuVregDataBoundsDestroy(
MlirTpuVregDataBounds data_bounds);
MLIR_CAPI_EXPORTED bool mlirTpuVregDataBoundsMaskVariesAlong(
MlirTpuVregDataBounds data_bounds, MlirTpuDirection direction,
MlirTpuI64TargetTuple target_shape);
MLIR_CAPI_EXPORTED bool mlirTpuVregDataBoundsIsComplete(
MlirTpuVregDataBounds data_bounds, MlirTpuI64TargetTuple target_shape);
// Returns null on failure.
MLIR_CAPI_EXPORTED MlirValue mlirTpuVregDataBoundsGetVectorMask(
MlirTpuVregDataBounds data_bounds, MlirTpuInsertionPoint insertion_point,
MlirLocation location, int generation, MlirTpuI64TargetTuple target_shape);
MLIR_CAPI_EXPORTED MlirAttribute mlirTpuVregDataBoundsGetSublaneMask(
MlirTpuVregDataBounds data_bounds, MlirContext ctx,
MlirTpuI64TargetTuple target_shape);
#ifdef __cplusplus
}
#endif
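
To make the header's conventions concrete, here is a hedged sketch (the helper and constant names are made up for illustration): offsets use -1 as the "replicated" sentinel, and the shape argument of the generalizes/equivalent-to queries is optional via a null ptr.

#include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h"

// Hypothetical helper: quantify "generalizes" over all shapes by passing an
// MlirTpuI64ArrayRef with a null ptr, per the comments above.
static bool GeneralizesForAllShapes(MlirTpuVectorLayout a,
                                    MlirTpuVectorLayout b,
                                    MlirTpuI64TargetTuple target_shape) {
  MlirTpuI64ArrayRef no_shape = {nullptr, 0};
  return mlirTpuVectorLayoutGeneralizes(a, b, no_shape, target_shape);
}

// Offsets replicated along both sublanes and lanes use the -1 sentinel from
// MlirTpuLayoutOffsets.
static const MlirTpuLayoutOffsets kReplicatedOffsets = {-1, -1};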