Migrate JAX MLIR Python dialect extensions to nanobind.
Now that https://github.com/llvm/llvm-project/pull/117922 has landed upstream, we can work towards removing our last uses of pybind11.

PiperOrigin-RevId: 705872751
This commit is contained in:
parent
5a3fa500b5
commit
64eae324ee
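
At its core the migration is mechanical: pybind11's module macro, argument annotations, and exception helpers map almost one-to-one onto nanobind equivalents, as the hunks below show file by file. A minimal sketch of a migrated extension module (the module and function here are illustrative, not taken from this change):

    // Hypothetical extension, post-migration. PYBIND11_MODULE(name, m, ...)
    // becomes NB_MODULE(name, m); py::arg becomes nb::arg; most m.def(...)
    // binding code carries over unchanged.
    #include "nanobind/nanobind.h"

    namespace nb = nanobind;

    NB_MODULE(_example_ext, m) {
      m.def(
          "scale", [](int x, int factor) { return x * factor; },
          nb::arg("x"), nb::arg("factor") = 2);
    }
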
@@ -102,7 +102,7 @@ class LoweringRuleContext:
 @dataclasses.dataclass
 class LoweringResult:
-  """Keeps pybind11 objects alive."""
+  """Keeps python objects alive."""
 
   module: ir.Module
   grid: tuple[int, ...]
@@ -27,11 +27,6 @@ namespace jax {
 //
 // Failing statuses become Python exceptions; OK Status() becomes None.
 //
-// Given there can be only a single global pybind11 type_caster for the
-// `absl::Status` type, and given XLA wants a custom exception being raised,
-// we use a dedicated helper to implement this feature without relying on a
-// global `type_caster`.
-//
 // For example:
 //
 // - Functions without arguments:

@@ -25,9 +25,9 @@ limitations under the License.
 
 namespace jax {
 
-// See kernel_pybind11_helpers.h for info on descriptor objects. We separate out
-// the functionality that doesn't require pybind11 for building CUDA libraries,
-// since older versions nvcc don't seem to be able to compile pybind11.
+// See kernel_nanobind_helpers.h for info on descriptor objects. We separate out
+// the functionality that doesn't require nanobind for building CUDA libraries,
+// since older versions nvcc don't seem to be able to compile nanobind.
 
 // Packs a descriptor object into a byte string.
 template <typename T>
@@ -158,6 +158,7 @@ py_extension(
         "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi",
         "@llvm-project//mlir:CAPIIRHeaders",
         "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
+        "@nanobind",
     ],
 )
@@ -177,8 +178,10 @@ py_extension(
         "@com_google_absl//absl/log:check",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:CAPIIRHeaders",
-        "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
-        "@pybind11",
+        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps",
+        "@nanobind",
+        "@xla//xla/python:nb_numpy",
+        "@xla//xla/tsl/python/lib/core:numpy",
     ],
 )
@@ -203,8 +206,8 @@ pybind_extension(
         ":jaxlib_mlir_capi_shared_library",
         "//jaxlib/triton:triton_dialect_capi_headers",
        "@llvm-project//mlir:CAPIIRHeaders",
-        "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
-        "@pybind11",
+        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps",
+        "@nanobind",
     ],
 )
@@ -266,9 +269,9 @@ py_extension(
         "@llvm-project//mlir:CAPISCFHeaders",
         "@llvm-project//mlir:CAPITransformsHeaders",
         "@llvm-project//mlir:CAPIVectorHeaders",
-        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
+        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
         "@local_config_python//:headers",
-        "@pybind11",
+        "@nanobind",
         "@shardy//shardy/integrations/c:sdy_capi_headers",
     ],
 )
@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-// clang-format: off
-// pybind11 must be included before mlir/Bindings/Python/PybindAdaptors.h,
-// otherwise this code will not build on Windows.
-#include "pybind11/pybind11.h"
-// clang-format: on
-
 #include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/PybindAdaptors.h"  // IWYU pragma: keep
+#include "mlir/Bindings/Python/NanobindAdaptors.h"  // IWYU pragma: keep
+#include "nanobind/nanobind.h"
 #include "jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h"
 
-PYBIND11_MODULE(_mosaic_gpu_ext, m, py::mod_gil_not_used()) {
+namespace nb = nanobind;
+
+NB_MODULE(_mosaic_gpu_ext, m) {
   m.def(
       "register_dialect",
       [](MlirContext context, bool load) {
@@ -33,5 +30,5 @@ PYBIND11_MODULE(_mosaic_gpu_ext, m, py::mod_gil_not_used()) {
           mlirDialectHandleLoadDialect(dialect, context);
         }
       },
-      py::arg("context"), py::arg("load") = true);
+      nb::arg("context"), nb::arg("load") = true);
 }
@@ -1,6 +1,6 @@
 // Registers MLIR dialects used by JAX.
 // This module is called by mlir/__init__.py during initialization.
-#include <pybind11/pybind11.h>
+#include <nanobind/nanobind.h>
 
 #include "mlir-c/Dialect/Arith.h"
 #include "mlir-c/Dialect/Func.h"
@@ -13,17 +13,17 @@
 #include "mlir-c/Dialect/SCF.h"
 #include "mlir-c/Dialect/Vector.h"
 #include "mlir-c/Transforms.h"
-#include "mlir/Bindings/Python/PybindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "shardy/integrations/c/passes.h"
 
 
-namespace py = pybind11;
+namespace nb = nanobind;
 
 #define REGISTER_DIALECT(name) \
   MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \
   mlirDialectHandleInsertDialect(name##_dialect, registry)
 
-PYBIND11_MODULE(register_jax_dialects, m, py::mod_gil_not_used()) {
+NB_MODULE(register_jax_dialects, m) {
   m.doc() = "Registers upstream MLIR dialects used by JAX.";
 
   m.def("register_dialects", [](MlirDialectRegistry registry) {
@@ -27,12 +27,6 @@ limitations under the License.
 #include <variant>
 #include <vector>
 
-// clang-format: off
-// pybind11 must be included before mlir/Bindings/Python/PybindAdaptors.h,
-// otherwise this code will not build on Windows.
-#include "pybind11/pybind11.h"
-// clang-format: on
-
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
@@ -46,19 +40,24 @@ limitations under the License.
 #include "mlir-c/Dialect/Func.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Support.h"
-#include "mlir/Bindings/Python/PybindAdaptors.h"
-#include "pybind11/attr.h"
-#include "pybind11/cast.h"
-#include "pybind11/detail/common.h"
-#include "pybind11/numpy.h"
-#include "pybind11/pytypes.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"  // IWYU pragma: keep
+#include "nanobind/nanobind.h"
+#include "nanobind/stl/optional.h"  // IWYU pragma: keep
+#include "nanobind/stl/pair.h"  // IWYU pragma: keep
+#include "nanobind/stl/string.h"  // IWYU pragma: keep
+#include "nanobind/stl/variant.h"  // IWYU pragma: keep
+#include "nanobind/stl/vector.h"  // IWYU pragma: keep
 #include "absl/log/check.h"
 #include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h"
+#include "xla/python/nb_numpy.h"
+#include "xla/tsl/python/lib/core/numpy.h"
 
 // TODO(tlongeri): Can I add my own return type annotations to functions?
 // TODO(tlongeri): I don't understand why MLIR uses the C API to implement
 // Python bindings. Do we have a reason to do that?
 
+namespace nb = nanobind;
+
 namespace {
 constexpr const char LAYOUT_DEFS_MODULE[] =
     "jax.jaxlib.mosaic.python.layout_defs";
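
One consequence of the include changes above: pybind11 registers most STL casters from a single pybind11/stl.h, while nanobind requires one opt-in header per container, which is why nanobind/stl/optional.h, pair.h, string.h, variant.h, and vector.h are now pulled in explicitly. A hedged sketch of why the opt-in matters (illustrative module, not from this diff):

    #include <optional>
    #include <string>

    #include "nanobind/nanobind.h"
    #include "nanobind/stl/optional.h"  // required for std::optional arguments
    #include "nanobind/stl/string.h"    // required for std::string values

    namespace nb = nanobind;

    NB_MODULE(_stl_example, m) {
      // Without the two nanobind/stl/ includes above, no caster for the
      // argument type is registered and calls from Python fail to resolve.
      m.def("greet", [](std::optional<std::string> name) {
        return "hello, " + name.value_or("world");
      });
    }
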
@@ -75,20 +74,21 @@ constexpr MlirTpuI64TargetTuple TARGET_SHAPE{8, 128};
 class NotImplementedException : public std::runtime_error {
   using runtime_error::runtime_error;
 };
 
 }  // namespace
 
 template <>
-struct py::detail::type_caster<MlirTpuImplicitDim> {
-  PYBIND11_TYPE_CASTER(MlirTpuImplicitDim, const_name("ImplicitDim | None"));
+struct nb::detail::type_caster<MlirTpuImplicitDim> {
+  NB_TYPE_CASTER(MlirTpuImplicitDim, const_name("ImplicitDim | None"));
 
-  bool load(handle src, bool) {
+  bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {
     if (src.is_none()) {
       value = MlirTpuImplicitDimNone;
       return true;
     }
     auto implicit_dim_cls =
-        py::module_::import(LAYOUT_DEFS_MODULE).attr("ImplicitDim");
-    if (!py::isinstance(src, implicit_dim_cls)) {
+        nb::module_::import_(LAYOUT_DEFS_MODULE).attr("ImplicitDim");
+    if (!nb::isinstance(src, implicit_dim_cls)) {
       return false;
     }
     if (src.is(implicit_dim_cls.attr("MINOR"))) {
@@ -96,36 +96,36 @@ struct py::detail::type_caster<MlirTpuImplicitDim> {
     } else if (src.is(implicit_dim_cls.attr("SECOND_MINOR"))) {
       value = MlirTpuImplicitDimSecondMinor;
     } else {
-      throw py::value_error();
+      return false;
     }
     return true;
   }
 
-  static handle cast(MlirTpuImplicitDim implicit_dim,
-                     return_value_policy /* policy */, handle /* parent */) {
+  static handle from_cpp(MlirTpuImplicitDim implicit_dim, rv_policy policy,
+                         cleanup_list* cleanup) noexcept {
     auto implicit_dim_cls =
-        py::module_::import(LAYOUT_DEFS_MODULE).attr("ImplicitDim");
+        nb::module_::import_(LAYOUT_DEFS_MODULE).attr("ImplicitDim");
     switch (implicit_dim) {
       case MlirTpuImplicitDimNone:
-        return py::none().release();
+        return nb::none().release();
       case MlirTpuImplicitDimMinor:
-        return static_cast<py::object>(implicit_dim_cls.attr("MINOR"))
+        return static_cast<nb::object>(implicit_dim_cls.attr("MINOR"))
             .release();
       case MlirTpuImplicitDimSecondMinor:
-        return static_cast<py::object>(implicit_dim_cls.attr("SECOND_MINOR"))
+        return static_cast<nb::object>(implicit_dim_cls.attr("SECOND_MINOR"))
            .release();
     }
   }
 };
 
 template <>
-struct py::detail::type_caster<MlirTpuDirection> {
-  PYBIND11_TYPE_CASTER(MlirTpuDirection, const_name("Direction"));
+struct nb::detail::type_caster<MlirTpuDirection> {
+  NB_TYPE_CASTER(MlirTpuDirection, const_name("Direction"));
 
-  bool load(handle src, bool) {
+  bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {
     auto direction_cls =
-        py::module_::import(LAYOUT_DEFS_MODULE).attr("Direction");
-    if (!py::isinstance(src, direction_cls)) {
+        nb::module_::import_(LAYOUT_DEFS_MODULE).attr("Direction");
+    if (!nb::isinstance(src, direction_cls)) {
       return false;
     }
     if (src.is(direction_cls.attr("LANES"))) {
@@ -135,26 +135,28 @@ struct py::detail::type_caster<MlirTpuDirection> {
     } else if (src.is(direction_cls.attr("SUBELEMENTS"))) {
       value = MlirTpuDirectionSubelements;
     } else {
-      throw py::value_error();
+      return false;
     }
     return true;
   }
 
-  static handle cast(MlirTpuDirection direction,
-                     return_value_policy /* policy */, handle /* parent */) {
+  static handle from_cpp(MlirTpuDirection direction, rv_policy /* policy */,
+                         cleanup_list* /* cleanup */) noexcept {
     auto direction_cls =
-        py::module_::import(LAYOUT_DEFS_MODULE).attr("ImplicitDim");
+        nb::module_::import_(LAYOUT_DEFS_MODULE).attr("ImplicitDim");
     switch (direction) {
       case MlirTpuDirectionLanes:
-        return static_cast<py::object>(direction_cls.attr("LANES")).release();
+        return static_cast<nb::object>(direction_cls.attr("LANES")).release();
       case MlirTpuDirectionSublanes:
-        return static_cast<py::object>(direction_cls.attr("SUBLANES"))
+        return static_cast<nb::object>(direction_cls.attr("SUBLANES"))
            .release();
      case MlirTpuDirectionSubelements:
-        return static_cast<py::object>(direction_cls.attr("SUBELEMENTS"))
+        return static_cast<nb::object>(direction_cls.attr("SUBELEMENTS"))
            .release();
      default:
-        throw py::value_error();
+        PyErr_Format(PyExc_ValueError, "Invalid MlirTpuDirection: %d",
+                     static_cast<int>(direction));
+        return nb::handle();
     }
   }
 };
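
The caster rewrites above track nanobind's renamed type_caster protocol: pybind11's load(handle, bool) becomes from_python(handle, uint8_t, cleanup_list*) noexcept, and the static cast(...) becomes from_cpp(...) noexcept. Because both hooks are noexcept, failure is signalled by returning false (or a null handle with a Python error set) rather than by throwing, which is why the throw py::value_error() branches became return false. A standalone sketch for a hypothetical enum, under those assumptions:

    #include "nanobind/nanobind.h"

    namespace nb = nanobind;

    enum MyDim { kDimNone = 0, kDimMinor = 1 };  // hypothetical C enum

    template <>
    struct nb::detail::type_caster<MyDim> {
      NB_TYPE_CASTER(MyDim, const_name("MyDim"));

      // noexcept: report failure by returning false, never by throwing.
      bool from_python(handle src, uint8_t, cleanup_list*) noexcept {
        int v;
        if (!nb::try_cast<int>(src, v) || v < kDimNone || v > kDimMinor)
          return false;
        value = static_cast<MyDim>(v);
        return true;
      }

      // noexcept: on failure, set a Python error and return a null handle.
      static handle from_cpp(MyDim dim, rv_policy, cleanup_list*) noexcept {
        return nb::int_(static_cast<int>(dim)).release();
      }
    };
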
@@ -163,9 +165,9 @@ namespace {
 // Handler for use with MLIR C API print functions. The 2nd parameter is an
 // opaque pointer to "user data" that should always be a string.
 void printToString(MlirStringRef c_mlir_str, void* opaque_string) {
-    std::string* str = static_cast<std::string*>(opaque_string);
-    CHECK(str != nullptr);
-    str->append(c_mlir_str.data, c_mlir_str.length);
+  std::string* str = static_cast<std::string*>(opaque_string);
+  CHECK(str != nullptr);
+  str->append(c_mlir_str.data, c_mlir_str.length);
 }
 
 class DiagnosticCapture {
@@ -217,141 +219,134 @@ class DiagnosticCapture {
   const MlirDiagnosticHandlerID id_;
 };
 
-template <typename T>
-class Holder {};
-
-// Holder class for MlirTpuVectorLayout, to deal properly with destruction.
-// TODO(tlongeri): It would be nice to not have a seemingly unnecessary
-// "pointer-to-pointer" (MlirTpuVectorLayout is basically an opaque pointer).
-// But I'm not sure if that's possible since pybind expects get() to return a
-// true pointer type.
-template <>
-class Holder<MlirTpuVectorLayout> {
- public:
-  Holder(MlirTpuVectorLayout layout) : ptr(new MlirTpuVectorLayout(layout)) {}
-  Holder(MlirTpuVectorLayout* layout) : ptr(layout) {}
-  Holder(Holder<MlirTpuVectorLayout>&& other) = default;
-  ~Holder() { mlirTpuVectorLayoutDestroy(*ptr); }
-  MlirTpuVectorLayout* get() { return ptr.get(); }
-
- private:
-  std::unique_ptr<MlirTpuVectorLayout> ptr;
-};
-}  // namespace
-
-PYBIND11_DECLARE_HOLDER_TYPE(T, Holder<T>);
-
-namespace {
-py::object toPyLayoutOffset(int64_t offset) {
+nb::object toPyLayoutOffset(int64_t offset) {
   CHECK_GE(offset, -1);
   if (offset == -1) {
-    return py::module_::import(LAYOUT_DEFS_MODULE).attr("REPLICATED");
+    return nb::module_::import_(LAYOUT_DEFS_MODULE).attr("REPLICATED");
   } else {
-    return py::int_(offset);
+    return nb::int_(offset);
   }
 }
 
 // TODO(tlongeri): Would `type_caster`s let me avoid defining all of these
 // to/from functions?
-int64_t offsetFromPyOffset(py::object py_offset) {
-  if (py::isinstance<py::int_>(py_offset)) {
-    int64_t offset = py::cast<py::int_>(py_offset);
+int64_t offsetFromPyOffset(nb::object py_offset) {
+  if (nb::isinstance<nb::int_>(py_offset)) {
+    int64_t offset = nb::cast<int64_t>(py_offset);
     if (offset < 0) {
-      throw py::value_error("Invalid py layout offset");
+      throw nb::value_error("Invalid py layout offset");
     }
     return offset;
   } else if (py_offset.equal(
-                 py::module_::import(LAYOUT_DEFS_MODULE).attr("REPLICATED"))) {
+                 nb::module_::import_(LAYOUT_DEFS_MODULE).attr("REPLICATED"))) {
     return -1;
   } else {
-    throw py::type_error("Invalid layout offset type");
+    throw nb::type_error("Invalid layout offset type");
   }
 }
 
 template <typename T>
-llvm::SmallVector<T> sequenceToSmallVector(py::sequence seq) {
-  return llvm::map_to_vector(
-      seq, [](py::handle handle) { return py::cast<T>(handle); });
+llvm::SmallVector<T> sequenceToSmallVector(nb::sequence seq) {
+  llvm::SmallVector<T> out;
+  out.reserve(nb::len(seq));
+  for (nb::handle elem : seq) {
+    out.push_back(nb::cast<T>(elem));
+  }
+  return out;
 }
 
-py::tuple toPyTuple(const int64_t* data, size_t count) {
-  py::tuple tuple(count);
+nb::tuple toPyTuple(const int64_t* data, size_t count) {
+  nb::tuple tuple = nb::steal<nb::tuple>(PyTuple_New(count));
   for (size_t i = 0; i < count; ++i) {
-    tuple[i] = data[i];
+    PyTuple_SET_ITEM(tuple.ptr(), i, nb::int_(data[i]).release().ptr());
   }
   return tuple;
 }
 
-py::tuple toPyTuple(MlirTpuI64TargetTuple tuple) {
-  return py::make_tuple(tuple.sublane, tuple.lane);
+nb::tuple toPyTuple(MlirTpuI64TargetTuple tuple) {
+  return nb::make_tuple(tuple.sublane, tuple.lane);
 }
 
 // Unwraps the current default insertion point
 // ValueError is raised if default insertion point is not set
 MlirTpuInsertionPoint getDefaultInsertionPoint() {
-  py::object insertion_point =
-      py::module_::import(IR_MODULE).attr("InsertionPoint").attr("current");
-  py::object ref_operation = insertion_point.attr("ref_operation");
-  return {py::cast<MlirBlock>(insertion_point.attr("block")),
+  nb::object insertion_point =
+      nb::module_::import_(IR_MODULE).attr("InsertionPoint").attr("current");
+  nb::object ref_operation = insertion_point.attr("ref_operation");
+  return {nb::cast<MlirBlock>(insertion_point.attr("block")),
           ref_operation.is_none()
              ? MlirOperation{nullptr}
-              : py::cast<MlirOperation>(insertion_point.attr("ref_operation"))};
+              : nb::cast<MlirOperation>(insertion_point.attr("ref_operation"))};
 }
 
 // Unwraps the current default location
 // ValueError is raised if default location is not set
 MlirLocation getDefaultLocation() {
-  return py::cast<MlirLocation>(
-      py::module_::import(IR_MODULE).attr("Location").attr("current"));
+  return nb::cast<MlirLocation>(
+      nb::module_::import_(IR_MODULE).attr("Location").attr("current"));
 }
 
 // Unwraps the current default MLIR context
 // ValueError is raised if default context is not set
 MlirContext getDefaultContext() {
-  return py::cast<MlirContext>(
-      py::module_::import(IR_MODULE).attr("Context").attr("current"));
+  return nb::cast<MlirContext>(
+      nb::module_::import_(IR_MODULE).attr("Context").attr("current"));
 }
 
+struct PyTpuVectorLayout {
+  PyTpuVectorLayout(MlirTpuVectorLayout layout) : layout(layout) {}
+  ~PyTpuVectorLayout() { mlirTpuVectorLayoutDestroy(layout); }
+  PyTpuVectorLayout(const PyTpuVectorLayout&) = delete;
+  PyTpuVectorLayout& operator=(const PyTpuVectorLayout&) = delete;
+
+  MlirTpuVectorLayout layout;
+};
+
 }  // namespace
 
-PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
+NB_MODULE(_tpu_ext, m) {
+  tsl::ImportNumpy();
   mlirRegisterTPUPasses();  // Register all passes on load.
   mlirTpuRegisterMosaicSerdePass();
 
-  py::class_<MlirTpuApplyVectorLayoutContext>(m, "ApplyVectorLayoutCtx",
-                                              py::module_local())
-      .def(py::init([](int hardware_generation, py::tuple target_shape,
-                       py::tuple mxu_shape, int max_sublanes_in_scratch) {
-        if (target_shape.size() != 2) {
-          throw py::value_error("target_shape should be of length 2");
-        }
-        if (mxu_shape.size() != 2) {
-          throw py::value_error("mxu_shape should be of length 2");
-        }
-        return MlirTpuApplyVectorLayoutContext{
-            .hardware_generation = hardware_generation,
-            .target_shape = {target_shape[0].cast<int64_t>(),
-                             target_shape[1].cast<int64_t>()},
-            .mxu_shape = {mxu_shape[0].cast<int64_t>(),
-                          mxu_shape[1].cast<int64_t>()},
-            .max_sublanes_in_scratch = max_sublanes_in_scratch};
-      }),
-      py::arg("hardware_generation") = -1,
-      py::arg("target_shape") = toPyTuple(TARGET_SHAPE),
-      py::arg("mxu_shape") = py::make_tuple(128, 128),
-      py::arg("max_sublanes_in_scratch") = 0);
+  nb::class_<MlirTpuApplyVectorLayoutContext>(m, "ApplyVectorLayoutCtx")
+      .def(
+          "__init__",
+          [](MlirTpuApplyVectorLayoutContext* self, int hardware_generation,
+             nb::tuple target_shape, nb::tuple mxu_shape,
+             int max_sublanes_in_scratch) {
+            if (target_shape.size() != 2) {
+              throw nb::value_error("target_shape should be of length 2");
+            }
+            if (mxu_shape.size() != 2) {
+              throw nb::value_error("mxu_shape should be of length 2");
+            }
+            new (self) MlirTpuApplyVectorLayoutContext{
+                .hardware_generation = hardware_generation,
+                .target_shape = {nb::cast<int64_t>(target_shape[0]),
+                                 nb::cast<int64_t>(target_shape[1])},
+                .mxu_shape = {nb::cast<int64_t>(mxu_shape[0]),
+                              nb::cast<int64_t>(mxu_shape[1])},
+                .max_sublanes_in_scratch = max_sublanes_in_scratch};
+          },
+          nb::arg("hardware_generation") = -1,
+          nb::arg("target_shape") = toPyTuple(TARGET_SHAPE),
+          nb::arg("mxu_shape") = nb::make_tuple(128, 128),
+          nb::arg("max_sublanes_in_scratch") = 0);
 
-  py::class_<MlirTpuVregDataBounds>(m, "VRegDataBounds", py::module_local())
+  nb::class_<MlirTpuVregDataBounds>(m, "VRegDataBounds")
       .def("mask_varies_along",
           [](MlirTpuVregDataBounds self, MlirTpuDirection direction) {
             return mlirTpuVregDataBoundsMaskVariesAlong(self, direction,
                                                         TARGET_SHAPE);
           })
-      .def_property_readonly("complete",
-                             [](MlirTpuVregDataBounds self) {
-                               return mlirTpuVregDataBoundsIsComplete(
-                                   self, TARGET_SHAPE);
-                             })
+      .def_prop_ro("complete",
+                   [](MlirTpuVregDataBounds self) {
+                     return mlirTpuVregDataBoundsIsComplete(self, TARGET_SHAPE);
+                   })
       .def("get_vector_mask",
           [](MlirTpuVregDataBounds self, int generation) {
             // TODO: Does this work? Test in Python
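
Two recurring patterns in the hunk above. First, the bindings stop relying on a pybind11 holder type: PYBIND11_DECLARE_HOLDER_TYPE and the Holder<MlirTpuVectorLayout> pointer-to-pointer workaround are replaced by a plain RAII struct (PyTpuVectorLayout) whose destructor frees the opaque C handle. Second, constructors move from py::init factory lambdas to nanobind's "__init__" form, which receives a raw self pointer and constructs in place with placement new. A reduced sketch with hypothetical names:

    #include "nanobind/nanobind.h"

    namespace nb = nanobind;

    // Stand-ins for an opaque C API handle and its lifecycle functions.
    struct CHandle { int* p; };
    CHandle c_handle_create(int v) { return CHandle{new int(v)}; }
    void c_handle_destroy(CHandle h) { delete h.p; }

    // RAII wrapper playing the role of the removed custom holder type.
    struct PyHandle {
      explicit PyHandle(CHandle h) : handle(h) {}
      ~PyHandle() { c_handle_destroy(handle); }
      PyHandle(const PyHandle&) = delete;
      PyHandle& operator=(const PyHandle&) = delete;
      CHandle handle;
    };

    NB_MODULE(_holder_example, m) {
      nb::class_<PyHandle>(m, "Handle")
          // nanobind __init__ lambdas construct into uninitialized storage
          // via placement new, replacing py::init(factory-lambda).
          .def("__init__",
               [](PyHandle* self, int v) {
                 new (self) PyHandle(c_handle_create(v));
               },
               nb::arg("v"));
    }
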
@@ -370,63 +365,79 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
 
   // TODO(tlongeri): More precise argument type annotations. There currently
   // seems to be no way to define your own?
-  py::class_<MlirTpuVectorLayout, Holder<MlirTpuVectorLayout>>(
-      m, "VectorLayout", py::module_local())
-      .def(py::init([](int bitwidth, py::tuple offsets, py::tuple tiling,
-                       MlirTpuImplicitDim implicit_dim) {
-        if (offsets.size() != 2) {
-          throw py::value_error("Offsets should be of length 2");
-        }
-        if (tiling.size() != 2) {
-          throw py::value_error("Tiling should be of length 2");
-        }
-        MlirTpuVectorLayout layout = mlirTpuVectorLayoutCreate(
-            bitwidth,
-            {offsetFromPyOffset(offsets[0]),
-             offsetFromPyOffset(offsets[1])},
-            {tiling[0].cast<int64_t>(), tiling[1].cast<int64_t>()},
-            implicit_dim);
-        if (!mlirTpuVectorLayoutIsValid(layout, TARGET_SHAPE)) {
-          throw py::value_error("Layout not valid for target shape");
-        }
-        return layout;
-      }),
-      py::arg("bitwidth"), py::arg("offsets"), py::arg("tiling"),
-      py::arg("implicit_dim"))
-      .def_property_readonly("bitwidth", mlirTpuVectorLayoutGetBitwidth,
-                             "The bitwidth of the stored values.")
-      .def_property_readonly(
+  nb::class_<PyTpuVectorLayout>(m, "VectorLayout")
+      .def(
+          "__init__",
+          [](PyTpuVectorLayout* self, int bitwidth, nb::tuple offsets,
+             nb::tuple tiling, MlirTpuImplicitDim implicit_dim) {
+            if (offsets.size() != 2) {
+              throw nb::value_error("Offsets should be of length 2");
+            }
+            if (tiling.size() != 2) {
+              throw nb::value_error("Tiling should be of length 2");
+            }
+            MlirTpuVectorLayout layout = mlirTpuVectorLayoutCreate(
+                bitwidth,
+                {offsetFromPyOffset(offsets[0]),
+                 offsetFromPyOffset(offsets[1])},
+                {nb::cast<int64_t>(tiling[0]), nb::cast<int64_t>(tiling[1])},
+                implicit_dim);
+            if (!mlirTpuVectorLayoutIsValid(layout, TARGET_SHAPE)) {
+              throw nb::value_error("Layout not valid for target shape");
+            }
+            new (self) PyTpuVectorLayout(layout);
+          },
+          nb::arg("bitwidth"), nb::arg("offsets"), nb::arg("tiling"),
+          nb::arg("implicit_dim").none())
+      .def_prop_ro(
+          "bitwidth",
+          [](const PyTpuVectorLayout& self) {
+            return mlirTpuVectorLayoutGetBitwidth(self.layout);
+          },
+          "The bitwidth of the stored values.")
+      .def_prop_ro(
           "offsets",
-          [](MlirTpuVectorLayout self) {
-            MlirTpuLayoutOffsets offsets = mlirTpuVectorLayoutGetOffsets(self);
-            return py::make_tuple(toPyLayoutOffset(offsets.sublane),
+          [](const PyTpuVectorLayout& self) {
+            MlirTpuLayoutOffsets offsets =
+                mlirTpuVectorLayoutGetOffsets(self.layout);
+            return nb::make_tuple(toPyLayoutOffset(offsets.sublane),
                                   toPyLayoutOffset(offsets.lane));
           },
           "The coordinates of the first valid element. If an offset is "
           "REPLICATED, then any offset is valid as the value does not vary "
           "across sublanes or lanes respectively.")
-      .def_property_readonly(
+      .def_prop_ro(
           "tiling",
-          [](MlirTpuVectorLayout self) {
-            return toPyTuple(mlirTpuVectorLayoutGetTiling(self));
+          [](const PyTpuVectorLayout& self) {
+            return toPyTuple(mlirTpuVectorLayoutGetTiling(self.layout));
           },
           "The tiling used to lay out values (see the XLA docs). For values of "
           "bitwidth < 32, an implicit (32 // bitwidth, 1) tiling is appended "
           "to the one specified as an attribute.")
-      .def_property_readonly(
-          "implicit_dim", mlirTpuVectorLayoutGetImplicitDim,
+      .def_prop_ro(
+          "implicit_dim",
+          [](const PyTpuVectorLayout& self) {
+            return mlirTpuVectorLayoutGetImplicitDim(self.layout);
+          },
           "If specified, the value has an implicit dim inserted in either "
           "minormost or second minormost position.")
-      .def_property_readonly(
-          "packing", mlirTpuVectorLayoutGetPacking,
+      .def_prop_ro(
+          "packing",
+          [](const PyTpuVectorLayout& self) {
+            return mlirTpuVectorLayoutGetPacking(self.layout);
+          },
           "Returns the number of values stored in a vreg entry.")
-      .def_property_readonly(
-          "layout_rank", mlirTpuVectorLayoutGetLayoutRank,
+      .def_prop_ro(
+          "layout_rank",
+          [](const PyTpuVectorLayout& self) {
+            return mlirTpuVectorLayoutGetLayoutRank(self.layout);
+          },
           "The number of minormost dimensions tiled by this layout.")
-      .def_property_readonly(
+      .def_prop_ro(
           "has_natural_topology",
-          [](MlirTpuVectorLayout self) {
-            return mlirTpuVectorLayoutHasNaturalTopology(self, TARGET_SHAPE);
+          [](const PyTpuVectorLayout& self) {
+            return mlirTpuVectorLayoutHasNaturalTopology(self.layout,
+                                                         TARGET_SHAPE);
           },
           "True, if every vector register has a layout without jumps.\n"
           "\n"
@@ -434,33 +445,35 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
           "always leads to a contiguous traversal of the (second) minormost "
           "dimension of data. This is only true for 32-bit types, since "
           "narrower types use two level tiling.")
-      .def_property_readonly(
+      .def_prop_ro(
           "has_native_tiling",
-          [](MlirTpuVectorLayout self) {
-            return mlirTpuVectorLayoutHasNativeTiling(self, TARGET_SHAPE);
+          [](const PyTpuVectorLayout& self) {
+            return mlirTpuVectorLayoutHasNativeTiling(self.layout,
+                                                      TARGET_SHAPE);
           },
           "True, if every vector register has a natural \"packed\" topology.\n"
           "\n"
           "This is equivalent to has_natural_topology for 32-bit types, but "
           "generalizes it to narrower values with packed layouts too.")
-      .def_property_readonly(
+      .def_prop_ro(
           "tiles_per_vreg",
-          [](MlirTpuVectorLayout self) {
-            return mlirTpuVectorLayoutTilesPerVreg(self, TARGET_SHAPE);
+          [](const PyTpuVectorLayout& self) {
+            return mlirTpuVectorLayoutTilesPerVreg(self.layout, TARGET_SHAPE);
           },
           "How many tiles fit in each vector register.")
-      .def_property_readonly(
+      .def_prop_ro(
           "sublanes_per_tile",
-          [](MlirTpuVectorLayout self) {
-            return mlirTpuVectorLayoutSublanesPerTile(self, TARGET_SHAPE);
+          [](const PyTpuVectorLayout& self) {
+            return mlirTpuVectorLayoutSublanesPerTile(self.layout,
+                                                      TARGET_SHAPE);
           },
           "The number of sublanes necessary to store each tile.")
-      .def_property_readonly(
+      .def_prop_ro(
           "vreg_slice",
-          [](MlirTpuVectorLayout self) {
+          [](const PyTpuVectorLayout& self) {
             MlirTpuI64TargetTuple vreg_slice =
-                mlirTpuVectorLayoutVregSlice(self, TARGET_SHAPE);
-            return py::module_::import(LAYOUT_DEFS_MODULE)
+                mlirTpuVectorLayoutVregSlice(self.layout, TARGET_SHAPE);
+            return nb::module_::import_(LAYOUT_DEFS_MODULE)
                 .attr("TargetTuple")(vreg_slice.sublane, vreg_slice.lane);
           },
           "Returns the size of a window contained in a single vreg.\n"
@@ -469,34 +482,34 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
           "rows, so only the minormost dimension can increase.")
       .def(
           "implicit_shape",
-          [](MlirTpuVectorLayout self, py::sequence shape) {
+          [](const PyTpuVectorLayout& self, nb::sequence shape) {
             llvm::SmallVector<int64_t> implicit_shape_vec =
                 sequenceToSmallVector<int64_t>(shape);
             MlirTpuI64ArrayRef implicit_shape =
                 mlirTpuVectorLayoutImplicitShape(
-                    self,
+                    self.layout,
                     {implicit_shape_vec.data(), implicit_shape_vec.size()});
-            py::tuple ret = toPyTuple(implicit_shape.ptr, implicit_shape.size);
+            nb::tuple ret = toPyTuple(implicit_shape.ptr, implicit_shape.size);
             free(implicit_shape.ptr);
             return ret;
           },
-          py::arg("shape"))
+          nb::arg("shape"))
       .def(
           "tile_array_shape",
-          [](MlirTpuVectorLayout self, py::sequence shape) {
+          [](const PyTpuVectorLayout& self, nb::sequence shape) {
             llvm::SmallVector<int64_t> tile_array_shape_vec =
                 sequenceToSmallVector<int64_t>(shape);
             MlirTpuI64ArrayRef tile_array_shape =
                 mlirTpuVectorLayoutTileArrayShape(
-                    self,
+                    self.layout,
                     {tile_array_shape_vec.data(), tile_array_shape_vec.size()},
                     TARGET_SHAPE);
-            py::tuple ret =
+            nb::tuple ret =
                 toPyTuple(tile_array_shape.ptr, tile_array_shape.size);
             free(tile_array_shape.ptr);
             return ret;
           },
-          py::arg("shape"),
+          nb::arg("shape"),
           "Returns the shape of an ndarray of vregs needed to represent a "
           "value.\n"
           "\n"
@@ -511,34 +524,33 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
           "  shape: The shape of the ndarray to tile.")
       .def(
           "tile_data_bounds",
-          [](MlirTpuVectorLayout self, py::sequence shape, py::sequence ixs,
-             std::variant<bool, py::tuple> allow_replicated) {
+          [](const PyTpuVectorLayout& self, nb::sequence shape,
+             nb::sequence ixs, std::variant<bool, nb::tuple> allow_replicated) {
             llvm::SmallVector<int64_t> shape_vec =
                 sequenceToSmallVector<int64_t>(shape);
             llvm::SmallVector<int64_t> ixs_vec =
                 sequenceToSmallVector<int64_t>(ixs);
             if (shape_vec.size() != ixs_vec.size()) {
-              throw py::value_error(
+              throw nb::value_error(
                   "Expected shape and ixs to have the same size");
             }
             return std::visit(
                 [&](auto ar) {
                   if constexpr (std::is_same_v<decltype(ar), bool>) {
                     return mlirTpuVectorLayoutTileDataBounds(
-                        self, getDefaultContext(), shape_vec.data(),
+                        self.layout, getDefaultContext(), shape_vec.data(),
                         ixs_vec.data(), shape_vec.size(), TARGET_SHAPE,
                         {ar, ar});
                   } else {
                     return mlirTpuVectorLayoutTileDataBounds(
-                        self, getDefaultContext(), shape_vec.data(),
+                        self.layout, getDefaultContext(), shape_vec.data(),
                         ixs_vec.data(), shape_vec.size(), TARGET_SHAPE,
-                        {ar[0].template cast<bool>(),
-                         ar[1].template cast<bool>()});
+                        {nb::cast<bool>(ar[0]), nb::cast<bool>(ar[1])});
                   }
                 },
                 allow_replicated);
           },
-          py::arg("shape"), py::arg("ixs"), py::arg("allow_replicated") = false,
+          nb::arg("shape"), nb::arg("ixs"), nb::arg("allow_replicated") = false,
           "Returns the bounds of the given tile that hold useful data.\n"
           "\n"
           "Arguments:\n"
@@ -556,19 +568,19 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
           "within the tile selected by idx.")
       .def(
           "generalizes",
-          [](MlirTpuVectorLayout self, MlirTpuVectorLayout other,
-             std::optional<py::sequence> shape) {
+          [](const PyTpuVectorLayout& self, const PyTpuVectorLayout& other,
+             std::optional<nb::sequence> shape) {
             if (shape) {
               llvm::SmallVector<int64_t> shape_vec =
                   sequenceToSmallVector<int64_t>(*shape);
               return mlirTpuVectorLayoutGeneralizes(
-                  self, other, {shape_vec.data(), shape_vec.size()},
-                  TARGET_SHAPE);
+                  self.layout, other.layout,
+                  {shape_vec.data(), shape_vec.size()}, TARGET_SHAPE);
             }
-            return mlirTpuVectorLayoutGeneralizes(self, other, {nullptr, 0},
-                                                  TARGET_SHAPE);
+            return mlirTpuVectorLayoutGeneralizes(self.layout, other.layout,
+                                                  {nullptr, 0}, TARGET_SHAPE);
           },
-          py::arg("other"), py::arg("shape") = std::nullopt,
+          nb::arg("other"), nb::arg("shape").none() = std::nullopt,
           "Returns True if the other layout is a special case of this one.\n"
           "\n"
           "In here, other is considered \"a special case\" when the set of "
@@ -595,19 +607,19 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
           "does not hold the other way around for some shapes.")
       .def(
           "equivalent_to",
-          [](MlirTpuVectorLayout self, MlirTpuVectorLayout other,
-             std::optional<py::sequence> shape) {
+          [](const PyTpuVectorLayout& self, const PyTpuVectorLayout& other,
+             std::optional<nb::sequence> shape) {
            if (shape) {
              llvm::SmallVector<int64_t> shape_vec =
                  sequenceToSmallVector<int64_t>(*shape);
              return mlirTpuVectorLayoutEquivalentTo(
-                  self, other, {shape_vec.data(), shape_vec.size()},
-                  TARGET_SHAPE);
+                  self.layout, other.layout,
+                  {shape_vec.data(), shape_vec.size()}, TARGET_SHAPE);
            }
-            return mlirTpuVectorLayoutEquivalentTo(self, other, {nullptr, 0},
-                                                   TARGET_SHAPE);
+            return mlirTpuVectorLayoutEquivalentTo(self.layout, other.layout,
+                                                   {nullptr, 0}, TARGET_SHAPE);
          },
-          py::arg("other"), py::arg("shape") = std::nullopt,
+          nb::arg("other"), nb::arg("shape").none() = std::nullopt,
           "Returns True if the two layouts are equivalent.\n"
           "\n"
           "That is, when all potential vector entries where the value can be "
@@ -619,55 +631,65 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
           "  shape: An optional shape of the vector to which both layouts "
           "apply. More layouts are considered equivalent when the shape is "
           "specified. Also see the docstring of the generalizes method.")
-      .def("__eq__", mlirTpuVectorLayoutEquals)
-      .def("__repr__",
-           [](MlirTpuVectorLayout self) {
-             std::string str;
-             mlirTpuVectorLayoutPrint(self, printToString, &str);
-             return str;
-           });
+      .def("__eq__",
+           [](const PyTpuVectorLayout& self, const PyTpuVectorLayout& other) {
+             return mlirTpuVectorLayoutEquals(self.layout, other.layout);
+           })
+      .def("__repr__", [](const PyTpuVectorLayout& self) {
+        std::string str;
+        mlirTpuVectorLayoutPrint(self.layout, printToString, &str);
+        return str;
+      });
 
   // TODO(tlongeri): Can we make the first parameter a VectorType?
   m.def("assemble",
-        [](const MlirType ty, MlirTpuVectorLayout layout,
-           // TODO(tlongeri): Remove py::array::c_style, I only added it because
-           // I couldn't find a simple way to iterate over array data, but it
-           // causes yet another unnecessary copy.
-           py::array_t<PyObject*, py::array::c_style> np_arr) -> MlirOperation {
+        [](const MlirType ty, const PyTpuVectorLayout& layout,
+           nb::object np_arr_obj) -> MlirOperation {
+          // TODO(tlongeri): Remove nb::array::c_style, I only added it because
+          // I couldn't find a simple way to iterate over array data, but it
+          // causes yet another unnecessary copy.
+          xla::nb_numpy_ndarray np_arr = xla::nb_numpy_ndarray::ensure(
+              np_arr_obj, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_ALIGNED);
          if (!mlirTypeIsAVector(ty)) {
-            throw py::type_error("Expected vector type");
+            throw nb::type_error("Expected vector type");
          }
          llvm::SmallVector<MlirValue> vals(np_arr.size());
          for (int64_t i = 0; i < np_arr.size(); ++i) {
-            vals.data()[i] = py::cast<MlirValue>(py::handle(np_arr.data()[i]));
+            vals.data()[i] = nb::cast<MlirValue>(nb::handle(
+                reinterpret_cast<PyObject* const*>(np_arr.data())[i]));
          }
          llvm::SmallVector<int64_t> shape(np_arr.ndim());
          for (int64_t i = 0; i < np_arr.ndim(); ++i) {
            shape.data()[i] = np_arr.shape()[i];
          }
          return mlirTpuAssemble(
-              getDefaultInsertionPoint(), ty, layout,
+              getDefaultInsertionPoint(), ty, layout.layout,
              MlirTpuValueArray{MlirTpuI64ArrayRef{shape.data(), shape.size()},
                                vals.data()},
              TARGET_SHAPE);
        });
-  m.def("disassemble", [](MlirTpuVectorLayout layout, MlirValue val) {
+  m.def("disassemble", [](const PyTpuVectorLayout& layout, MlirValue val) {
     DiagnosticCapture diag_capture(getDefaultContext());
-    MlirTpuValueArray val_arr = mlirTpuDisassemble(getDefaultInsertionPoint(),
-                                                   layout, val, TARGET_SHAPE);
+    MlirTpuValueArray val_arr = mlirTpuDisassemble(
+        getDefaultInsertionPoint(), layout.layout, val, TARGET_SHAPE);
    if (val_arr.vals == nullptr) {
      diag_capture.throwIfError();
-      throw py::value_error("Failed to disassemble");
+      throw nb::value_error("Failed to disassemble");
    }
-    py::array_t<PyObject*> np_vals(
-        llvm::ArrayRef<int64_t>{val_arr.shape.ptr, val_arr.shape.size});
+    xla::nb_numpy_ndarray np_vals(
+        /*dtype=*/xla::nb_dtype("O"),
+        /*shape=*/
+        absl::Span<int64_t const>(val_arr.shape.ptr, val_arr.shape.size),
+        /*strides=*/std::nullopt);
    for (ssize_t i = 0; i < np_vals.size(); ++i) {
-      np_vals.mutable_data()[i] = py::cast(val_arr.vals[i]).release().ptr();
+      reinterpret_cast<PyObject**>(np_vals.mutable_data())[i] =
+          nb::cast(val_arr.vals[i]).release().ptr();
    }
    free(val_arr.shape.ptr);
    free(val_arr.vals);
    return np_vals;
  });
 
   m.def("apply_layout_op",
         [](int hardware_generation, const MlirOperation c_op) {
           DiagnosticCapture diag_capture(getDefaultContext());
@@ -678,25 +700,27 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
             throw std::runtime_error("applyLayoutOp failed");
           }
         });
-  m.def("relayout",
-        [](MlirValue v, MlirTpuVectorLayout src, MlirTpuVectorLayout dst,
-           MlirTpuApplyVectorLayoutContext apply_layout_ctx) {
-          DiagnosticCapture diag_capture(getDefaultContext());
-          MlirValue new_v = mlirTpuRelayout(getDefaultInsertionPoint(), v, src,
-                                            dst, apply_layout_ctx);
-          if (new_v.ptr == nullptr) {
-            diag_capture.throwIfError();
-            throw py::value_error("Failed to relayout");
-          }
-          return new_v;
-        });
-  py::register_exception_translator([](std::exception_ptr p) {
-    try {
-      if (p) std::rethrow_exception(p);
-    } catch (const NotImplementedException& e) {
-      PyErr_SetString(PyExc_NotImplementedError, e.what());
-    }
-  });
+  m.def("relayout", [](MlirValue v, const PyTpuVectorLayout& src,
+                       const PyTpuVectorLayout& dst,
+                       MlirTpuApplyVectorLayoutContext apply_layout_ctx) {
+    DiagnosticCapture diag_capture(getDefaultContext());
+    MlirValue new_v = mlirTpuRelayout(getDefaultInsertionPoint(), v, src.layout,
+                                      dst.layout, apply_layout_ctx);
+    if (new_v.ptr == nullptr) {
+      diag_capture.throwIfError();
+      throw nb::value_error("Failed to relayout");
+    }
+    return new_v;
+  });
+  nb::register_exception_translator(
+      [](const std::exception_ptr& p, void*) {
+        try {
+          if (p) std::rethrow_exception(p);
+        } catch (const NotImplementedException& e) {
+          PyErr_SetString(PyExc_NotImplementedError, e.what());
+        }
+      },
+      nullptr);
 
   m.def(
       "register_dialect",
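
The exception translator registration above also changes shape: nanobind's register_exception_translator takes a capture-free function of (const std::exception_ptr&, void*) plus a payload pointer, instead of pybind11's single-argument callable. A sketch with a hypothetical exception type:

    #include <exception>
    #include <stdexcept>

    #include "nanobind/nanobind.h"

    namespace nb = nanobind;

    // Hypothetical C++ exception to surface as Python's TimeoutError.
    struct TimeoutException : std::runtime_error {
      using std::runtime_error::runtime_error;
    };

    NB_MODULE(_exc_example, m) {
      nb::register_exception_translator(
          [](const std::exception_ptr& p, void* /*payload*/) {
            try {
              std::rethrow_exception(p);
            } catch (const TimeoutException& e) {
              // Unmatched exceptions propagate to the next translator.
              PyErr_SetString(PyExc_TimeoutError, e.what());
            }
          },
          nullptr);  // payload forwarded as the translator's void* argument
    }
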
@@ -707,21 +731,25 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
           mlirDialectHandleLoadDialect(dialect, context);
         }
       },
-      py::arg("context"), py::arg("load") = true);
+      nb::arg("context"), nb::arg("load") = true);
 
   m.def("private_is_tiled_layout", [](MlirAttribute attr) {
     return mlirTPUAttributeIsATiledLayoutAttr(attr);
   });
-  m.def("private_get_tiles", [](MlirAttribute attr) -> py::object {
+  m.def("private_get_tiles", [](MlirAttribute attr) -> nb::object {
     MlirAttribute encoded_tiles = mlirTPUTiledLayoutAttrGetTiles(attr);
-    py::tuple py_tiles(mlirArrayAttrGetNumElements(encoded_tiles));
+    nb::tuple py_tiles = nb::steal<nb::tuple>(
+        PyTuple_New(mlirArrayAttrGetNumElements(encoded_tiles)));
     for (intptr_t i = 0; i < mlirArrayAttrGetNumElements(encoded_tiles); ++i) {
       MlirAttribute tile = mlirArrayAttrGetElement(encoded_tiles, i);
-      py::tuple py_tile(mlirDenseArrayGetNumElements(tile));
+      nb::tuple py_tile =
+          nb::steal<nb::tuple>(PyTuple_New(mlirDenseArrayGetNumElements(tile)));
       for (intptr_t j = 0; j < mlirDenseArrayGetNumElements(tile); ++j) {
-        py_tile[j] = mlirDenseI64ArrayGetElement(tile, j);
+        PyTuple_SET_ITEM(
+            py_tile.ptr(), j,
+            nb::cast(mlirDenseI64ArrayGetElement(tile, j)).release().ptr());
       }
-      py_tiles[i] = py_tile;
+      PyTuple_SET_ITEM(py_tiles.ptr(), i, py_tile.release().ptr());
     }
     return py_tiles;
   });
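
The same ownership pattern shows up in toPyTuple earlier and in private_get_tiles above: nanobind's nb::tuple, unlike py::tuple, has no constructor taking a size, so tuples are allocated with the raw CPython API and wrapped with nb::steal so the handle owns the fresh reference; PyTuple_SET_ITEM then steals each element reference, hence the explicit release(). A reduced sketch:

    #include <cstddef>

    #include "nanobind/nanobind.h"

    namespace nb = nanobind;

    // Builds (0, 1, ..., n - 1) the way this change builds its tuples.
    nb::tuple iota_tuple(size_t n) {
      // PyTuple_New returns a new reference; nb::steal takes ownership.
      nb::tuple out = nb::steal<nb::tuple>(PyTuple_New(n));
      for (size_t i = 0; i < n; ++i) {
        // PyTuple_SET_ITEM steals a reference, so we release() ours.
        PyTuple_SET_ITEM(out.ptr(), i, nb::int_(i).release().ptr());
      }
      return out;
    }
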
@@ -737,7 +765,7 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
   m.def("private_replace_all_uses_with", [](MlirOperation op,
                                             std::vector<MlirValue> vals) {
     if (vals.size() != mlirOperationGetNumResults(op)) {
-      throw py::value_error("length mismatch in replace_all_uses_with");
+      throw nb::value_error("length mismatch in replace_all_uses_with");
     }
     for (int i = 0; i < vals.size(); ++i) {
       mlirValueReplaceAllUsesOfWith(mlirOperationGetResult(op, i), vals[i]);
@@ -747,7 +775,7 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
       [](MlirValue old, MlirValue new_val, MlirOperation except) {
         for (intptr_t i = 0; i < mlirOperationGetNumOperands(except); ++i) {
           if (mlirValueEqual(mlirOperationGetOperand(except, i), new_val)) {
-            throw py::value_error("new val already used in except");
+            throw nb::value_error("new val already used in except");
          }
        }
        mlirValueReplaceAllUsesOfWith(old, new_val);
@@ -785,7 +813,7 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
   });
   m.def("private_move_all_regions", [](MlirOperation src, MlirOperation dst) {
     if (mlirOperationGetNumRegions(src) != mlirOperationGetNumRegions(dst)) {
-      throw py::value_error(
+      throw nb::value_error(
           "Region counts do not match in src operation and dst operations");
     }
     for (intptr_t i = 0; i < mlirOperationGetNumRegions(src); ++i) {

@@ -16,13 +16,13 @@ limitations under the License.
 #include <optional>
 
 #include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/PybindAdaptors.h"
-#include "pybind11/detail/common.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "nanobind/nanobind.h"
 #include "jaxlib/triton/triton_dialect_capi.h"
 
-namespace py = pybind11;
+namespace nb = nanobind;
 
-PYBIND11_MODULE(_triton_ext, m, py::mod_gil_not_used()) {
+NB_MODULE(_triton_ext, m) {
   //
   // Dialects.
   //
@@ -36,20 +36,20 @@ PYBIND11_MODULE(_triton_ext, m, py::mod_gil_not_used()) {
         mlirDialectHandleLoadDialect(dialect, context);
       }
     },
-      py::arg("context"), py::arg("load") = true);
+      nb::arg("context"), nb::arg("load") = true);
 
   //
   // Types.
   //
 
-  mlir::python::adaptors::mlir_type_subclass(m, "PointerType",
+  mlir::python::nanobind_adaptors::mlir_type_subclass(m, "PointerType",
                                              mlirTritonIsAPointer)
       .def_classmethod(
           "get",
-          [](py::object cls, MlirType pointee_type, int64_t address_space) {
+          [](nb::object cls, MlirType pointee_type, int64_t address_space) {
            return cls(mlirTritonPointerTypeGet(pointee_type, address_space));
          },
-          py::arg("cls"), py::arg("pointee_type"), py::arg("address_space"),
+          nb::arg("cls"), nb::arg("pointee_type"), nb::arg("address_space"),
           "Creates a PointerType type.")
       .def_property_readonly("pointee_type", [](MlirType self) {
         return mlirTritonPointerTypeGetPointeeType(self);