Migrate JAX MLIR Python dialect extensions to nanobind.

Now that https://github.com/llvm/llvm-project/pull/117922 has landed upstream, we can work towards removing our last uses of pybind11.

PiperOrigin-RevId: 705872751
Peter Hawkins 2024-12-13 07:07:44 -08:00 committed by jax authors
parent 5a3fa500b5
commit 64eae324ee
8 changed files with 307 additions and 284 deletions
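
For readers following the port, here is a minimal sketch of the mechanical pattern the diff below applies in every extension (illustrative code, not part of the commit; the _example_ext module and Point class are made up): the module macro, argument annotations, class and property definitions, and type_caster hooks are renamed, while the overall binding structure stays the same.

#include <cmath>
#include "nanobind/nanobind.h"

namespace nb = nanobind;

struct Point {
  Point(double x, double y) : x(x), y(y) {}
  double x, y;
};

// was: PYBIND11_MODULE(_example_ext, m, py::mod_gil_not_used())
NB_MODULE(_example_ext, m) {
  // was: py::class_<Point>(m, "Point", py::module_local())
  nb::class_<Point>(m, "Point")
      .def(nb::init<double, double>())
      .def_rw("x", &Point::x)  // was: .def_readwrite
      .def_rw("y", &Point::y)
      .def_prop_ro("norm",     // was: .def_property_readonly
                   [](const Point& p) { return std::hypot(p.x, p.y); });

  m.def(
      "distance",
      [](const Point& a, const Point& b) {
        return std::hypot(a.x - b.x, a.y - b.y);
      },
      nb::arg("a"), nb::arg("b"));  // was: py::arg("a"), py::arg("b")
}

Custom type_casters change in the same mechanical way: load/cast become from_python/from_cpp with noexcept signatures, STL conversions need explicit nanobind/stl/*.h includes, and the py::module_local() annotation is simply dropped; all of this shows up in the _tpu_ext changes below.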


@@ -102,7 +102,7 @@ class LoweringRuleContext:
@dataclasses.dataclass
class LoweringResult:
"""Keeps pybind11 objects alive."""
"""Keeps python objects alive."""
module: ir.Module
grid: tuple[int, ...]


@@ -27,11 +27,6 @@ namespace jax {
//
// Failing statuses become Python exceptions; OK Status() becomes None.
//
// Given there can be only a single global pybind11 type_caster for the
// `absl::Status` type, and given XLA wants a custom exception being raised,
// we use a dedicated helper to implement this feature without relying on a
// global `type_caster`.
//
// For example:
//
// - Functions without arguments:
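
As a rough sketch of that convention (illustrative only, not the helper this file defines; the ThrowIfNotOk name is made up, and the real code raises a dedicated XLA exception type rather than a generic RuntimeError):

#include <stdexcept>
#include "absl/status/status.h"

// Turns a failing absl::Status into a C++ exception at the binding boundary;
// pybind11 and nanobind both translate std::runtime_error into a Python
// RuntimeError, and a wrapped function returning void surfaces as None on
// success.
inline void ThrowIfNotOk(const absl::Status& status) {
  if (!status.ok()) {
    throw std::runtime_error(status.ToString());
  }
}

// Hypothetical usage inside a module definition:
//   m.def("do_thing", [](int x) { ThrowIfNotOk(DoThing(x)); });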


@@ -25,9 +25,9 @@ limitations under the License.
namespace jax {
// See kernel_pybind11_helpers.h for info on descriptor objects. We separate out
// the functionality that doesn't require pybind11 for building CUDA libraries,
// since older versions nvcc don't seem to be able to compile pybind11.
// See kernel_nanobind_helpers.h for info on descriptor objects. We separate out
// the functionality that doesn't require nanobind for building CUDA libraries,
// since older versions nvcc don't seem to be able to compile nanobind.
// Packs a descriptor object into a byte string.
template <typename T>
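
As a rough illustration of that descriptor-packing idea (the names below are hypothetical, not the helpers this header declares): a trivially copyable descriptor struct is copied into an opaque byte string on the binding side and copied back out on the kernel side, which is why none of it needs nanobind.

#include <cstring>
#include <string>
#include <type_traits>

template <typename T>
std::string PackDescriptorBytes(const T& descriptor) {
  static_assert(std::is_trivially_copyable_v<T>,
                "descriptors must be trivially copyable");
  return std::string(reinterpret_cast<const char*>(&descriptor), sizeof(T));
}

template <typename T>
T UnpackDescriptorBytes(const std::string& opaque) {
  static_assert(std::is_trivially_copyable_v<T>,
                "descriptors must be trivially copyable");
  // A real implementation would verify opaque.size() == sizeof(T) and report
  // a mismatch through its error handling.
  T descriptor;
  std::memcpy(&descriptor, opaque.data(), sizeof(T));
  return descriptor;
}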


@@ -158,6 +158,7 @@ py_extension(
"//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
"@nanobind",
],
)
@@ -177,8 +178,10 @@ py_extension(
"@com_google_absl//absl/log:check",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
"@pybind11",
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps",
"@nanobind",
"@xla//xla/python:nb_numpy",
"@xla//xla/tsl/python/lib/core:numpy",
],
)
@@ -203,8 +206,8 @@ pybind_extension(
":jaxlib_mlir_capi_shared_library",
"//jaxlib/triton:triton_dialect_capi_headers",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
"@pybind11",
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps",
"@nanobind",
],
)
@@ -266,9 +269,9 @@ py_extension(
"@llvm-project//mlir:CAPISCFHeaders",
"@llvm-project//mlir:CAPITransformsHeaders",
"@llvm-project//mlir:CAPIVectorHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
"@local_config_python//:headers",
"@pybind11",
"@nanobind",
"@shardy//shardy/integrations/c:sdy_capi_headers",
],
)


@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// clang-format: off
// pybind11 must be included before mlir/Bindings/Python/PybindAdaptors.h,
// otherwise this code will not build on Windows.
#include "pybind11/pybind11.h"
// clang-format: on
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h" // IWYU pragma: keep
#include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep
#include "nanobind/nanobind.h"
#include "jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h"
PYBIND11_MODULE(_mosaic_gpu_ext, m, py::mod_gil_not_used()) {
namespace nb = nanobind;
NB_MODULE(_mosaic_gpu_ext, m) {
m.def(
"register_dialect",
[](MlirContext context, bool load) {
@@ -33,5 +30,5 @@ PYBIND11_MODULE(_mosaic_gpu_ext, m, py::mod_gil_not_used()) {
mlirDialectHandleLoadDialect(dialect, context);
}
},
py::arg("context"), py::arg("load") = true);
nb::arg("context"), nb::arg("load") = true);
}


@@ -1,6 +1,6 @@
// Registers MLIR dialects used by JAX.
// This module is called by mlir/__init__.py during initialization.
#include <pybind11/pybind11.h>
#include <nanobind/nanobind.h>
#include "mlir-c/Dialect/Arith.h"
#include "mlir-c/Dialect/Func.h"
@@ -13,17 +13,17 @@
#include "mlir-c/Dialect/SCF.h"
#include "mlir-c/Dialect/Vector.h"
#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "shardy/integrations/c/passes.h"
namespace py = pybind11;
namespace nb = nanobind;
#define REGISTER_DIALECT(name) \
MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \
mlirDialectHandleInsertDialect(name##_dialect, registry)
PYBIND11_MODULE(register_jax_dialects, m, py::mod_gil_not_used()) {
NB_MODULE(register_jax_dialects, m) {
m.doc() = "Registers upstream MLIR dialects used by JAX.";
m.def("register_dialects", [](MlirDialectRegistry registry) {


@@ -27,12 +27,6 @@ limitations under the License.
#include <variant>
#include <vector>
// clang-format: off
// pybind11 must be included before mlir/Bindings/Python/PybindAdaptors.h,
// otherwise this code will not build on Windows.
#include "pybind11/pybind11.h"
// clang-format: on
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
@@ -46,19 +40,24 @@ limitations under the License.
#include "mlir-c/Dialect/Func.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "pybind11/attr.h"
#include "pybind11/cast.h"
#include "pybind11/detail/common.h"
#include "pybind11/numpy.h"
#include "pybind11/pytypes.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep
#include "nanobind/nanobind.h"
#include "nanobind/stl/optional.h" // IWYU pragma: keep
#include "nanobind/stl/pair.h" // IWYU pragma: keep
#include "nanobind/stl/string.h" // IWYU pragma: keep
#include "nanobind/stl/variant.h" // IWYU pragma: keep
#include "nanobind/stl/vector.h" // IWYU pragma: keep
#include "absl/log/check.h"
#include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h"
#include "xla/python/nb_numpy.h"
#include "xla/tsl/python/lib/core/numpy.h"
// TODO(tlongeri): Can I add my own return type annotations to functions?
// TODO(tlongeri): I don't understand why MLIR uses the C API to implement
// Python bindings. Do we have a reason to do that?
namespace nb = nanobind;
namespace {
constexpr const char LAYOUT_DEFS_MODULE[] =
"jax.jaxlib.mosaic.python.layout_defs";
@@ -75,20 +74,21 @@ constexpr MlirTpuI64TargetTuple TARGET_SHAPE{8, 128};
class NotImplementedException : public std::runtime_error {
using runtime_error::runtime_error;
};
} // namespace
template <>
struct py::detail::type_caster<MlirTpuImplicitDim> {
PYBIND11_TYPE_CASTER(MlirTpuImplicitDim, const_name("ImplicitDim | None"));
struct nb::detail::type_caster<MlirTpuImplicitDim> {
NB_TYPE_CASTER(MlirTpuImplicitDim, const_name("ImplicitDim | None"));
bool load(handle src, bool) {
bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {
if (src.is_none()) {
value = MlirTpuImplicitDimNone;
return true;
}
auto implicit_dim_cls =
py::module_::import(LAYOUT_DEFS_MODULE).attr("ImplicitDim");
if (!py::isinstance(src, implicit_dim_cls)) {
nb::module_::import_(LAYOUT_DEFS_MODULE).attr("ImplicitDim");
if (!nb::isinstance(src, implicit_dim_cls)) {
return false;
}
if (src.is(implicit_dim_cls.attr("MINOR"))) {
@@ -96,36 +96,36 @@ struct py::detail::type_caster<MlirTpuImplicitDim> {
} else if (src.is(implicit_dim_cls.attr("SECOND_MINOR"))) {
value = MlirTpuImplicitDimSecondMinor;
} else {
throw py::value_error();
return false;
}
return true;
}
static handle cast(MlirTpuImplicitDim implicit_dim,
return_value_policy /* policy */, handle /* parent */) {
static handle from_cpp(MlirTpuImplicitDim implicit_dim, rv_policy policy,
cleanup_list* cleanup) noexcept {
auto implicit_dim_cls =
py::module_::import(LAYOUT_DEFS_MODULE).attr("ImplicitDim");
nb::module_::import_(LAYOUT_DEFS_MODULE).attr("ImplicitDim");
switch (implicit_dim) {
case MlirTpuImplicitDimNone:
return py::none().release();
return nb::none().release();
case MlirTpuImplicitDimMinor:
return static_cast<py::object>(implicit_dim_cls.attr("MINOR"))
return static_cast<nb::object>(implicit_dim_cls.attr("MINOR"))
.release();
case MlirTpuImplicitDimSecondMinor:
return static_cast<py::object>(implicit_dim_cls.attr("SECOND_MINOR"))
return static_cast<nb::object>(implicit_dim_cls.attr("SECOND_MINOR"))
.release();
}
}
};
template <>
struct py::detail::type_caster<MlirTpuDirection> {
PYBIND11_TYPE_CASTER(MlirTpuDirection, const_name("Direction"));
struct nb::detail::type_caster<MlirTpuDirection> {
NB_TYPE_CASTER(MlirTpuDirection, const_name("Direction"));
bool load(handle src, bool) {
bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {
auto direction_cls =
py::module_::import(LAYOUT_DEFS_MODULE).attr("Direction");
if (!py::isinstance(src, direction_cls)) {
nb::module_::import_(LAYOUT_DEFS_MODULE).attr("Direction");
if (!nb::isinstance(src, direction_cls)) {
return false;
}
if (src.is(direction_cls.attr("LANES"))) {
@@ -135,26 +135,28 @@ struct py::detail::type_caster<MlirTpuDirection> {
} else if (src.is(direction_cls.attr("SUBELEMENTS"))) {
value = MlirTpuDirectionSubelements;
} else {
throw py::value_error();
return false;
}
return true;
}
static handle cast(MlirTpuDirection direction,
return_value_policy /* policy */, handle /* parent */) {
static handle from_cpp(MlirTpuDirection direction, rv_policy /* policy */,
cleanup_list* /* cleanup */) noexcept {
auto direction_cls =
py::module_::import(LAYOUT_DEFS_MODULE).attr("ImplicitDim");
nb::module_::import_(LAYOUT_DEFS_MODULE).attr("ImplicitDim");
switch (direction) {
case MlirTpuDirectionLanes:
return static_cast<py::object>(direction_cls.attr("LANES")).release();
return static_cast<nb::object>(direction_cls.attr("LANES")).release();
case MlirTpuDirectionSublanes:
return static_cast<py::object>(direction_cls.attr("SUBLANES"))
return static_cast<nb::object>(direction_cls.attr("SUBLANES"))
.release();
case MlirTpuDirectionSubelements:
return static_cast<py::object>(direction_cls.attr("SUBELEMENTS"))
return static_cast<nb::object>(direction_cls.attr("SUBELEMENTS"))
.release();
default:
throw py::value_error();
PyErr_Format(PyExc_ValueError, "Invalid MlirTpuDirection: %d",
static_cast<int>(direction));
return nb::handle();
}
}
};
@@ -163,9 +165,9 @@ namespace {
// Handler for use with MLIR C API print functions. The 2nd parameter is an
// opaque pointer to "user data" that should always be a string.
void printToString(MlirStringRef c_mlir_str, void* opaque_string) {
std::string* str = static_cast<std::string*>(opaque_string);
CHECK(str != nullptr);
str->append(c_mlir_str.data, c_mlir_str.length);
std::string* str = static_cast<std::string*>(opaque_string);
CHECK(str != nullptr);
str->append(c_mlir_str.data, c_mlir_str.length);
}
class DiagnosticCapture {
@@ -217,141 +219,134 @@ class DiagnosticCapture {
const MlirDiagnosticHandlerID id_;
};
template <typename T>
class Holder {};
// Holder class for MlirTpuVectorLayout, to deal properly with destruction.
// TODO(tlongeri): It would be nice to not have a seemingly unnecessary
// "pointer-to-pointer" (MlirTpuVectorLayout is basically an opaque pointer).
// But I'm not sure if that's possible since pybind expects get() to return a
// true pointer type.
template <>
class Holder<MlirTpuVectorLayout> {
public:
Holder(MlirTpuVectorLayout layout) : ptr(new MlirTpuVectorLayout(layout)) {}
Holder(MlirTpuVectorLayout* layout) : ptr(layout) {}
Holder(Holder<MlirTpuVectorLayout>&& other) = default;
~Holder() { mlirTpuVectorLayoutDestroy(*ptr); }
MlirTpuVectorLayout* get() { return ptr.get(); }
private:
std::unique_ptr<MlirTpuVectorLayout> ptr;
};
} // namespace
PYBIND11_DECLARE_HOLDER_TYPE(T, Holder<T>);
namespace {
py::object toPyLayoutOffset(int64_t offset) {
nb::object toPyLayoutOffset(int64_t offset) {
CHECK_GE(offset, -1);
if (offset == -1) {
return py::module_::import(LAYOUT_DEFS_MODULE).attr("REPLICATED");
return nb::module_::import_(LAYOUT_DEFS_MODULE).attr("REPLICATED");
} else {
return py::int_(offset);
return nb::int_(offset);
}
}
// TODO(tlongeri): Would `type_caster`s let me avoid defining all of these
// to/from functions?
int64_t offsetFromPyOffset(py::object py_offset) {
if (py::isinstance<py::int_>(py_offset)) {
int64_t offset = py::cast<py::int_>(py_offset);
int64_t offsetFromPyOffset(nb::object py_offset) {
if (nb::isinstance<nb::int_>(py_offset)) {
int64_t offset = nb::cast<int64_t>(py_offset);
if (offset < 0) {
throw py::value_error("Invalid py layout offset");
throw nb::value_error("Invalid py layout offset");
}
return offset;
} else if (py_offset.equal(
py::module_::import(LAYOUT_DEFS_MODULE).attr("REPLICATED"))) {
nb::module_::import_(LAYOUT_DEFS_MODULE).attr("REPLICATED"))) {
return -1;
} else {
throw py::type_error("Invalid layout offset type");
throw nb::type_error("Invalid layout offset type");
}
}
template <typename T>
llvm::SmallVector<T> sequenceToSmallVector(py::sequence seq) {
return llvm::map_to_vector(
seq, [](py::handle handle) { return py::cast<T>(handle); });
llvm::SmallVector<T> sequenceToSmallVector(nb::sequence seq) {
llvm::SmallVector<T> out;
out.reserve(nb::len(seq));
for (nb::handle elem : seq) {
out.push_back(nb::cast<T>(elem));
}
return out;
}
py::tuple toPyTuple(const int64_t* data, size_t count) {
py::tuple tuple(count);
nb::tuple toPyTuple(const int64_t* data, size_t count) {
nb::tuple tuple = nb::steal<nb::tuple>(PyTuple_New(count));
for (size_t i = 0; i < count; ++i) {
tuple[i] = data[i];
PyTuple_SET_ITEM(tuple.ptr(), i, nb::int_(data[i]).release().ptr());
}
return tuple;
}
py::tuple toPyTuple(MlirTpuI64TargetTuple tuple) {
return py::make_tuple(tuple.sublane, tuple.lane);
nb::tuple toPyTuple(MlirTpuI64TargetTuple tuple) {
return nb::make_tuple(tuple.sublane, tuple.lane);
}
// Unwraps the current default insertion point
// ValueError is raised if default insertion point is not set
MlirTpuInsertionPoint getDefaultInsertionPoint() {
py::object insertion_point =
py::module_::import(IR_MODULE).attr("InsertionPoint").attr("current");
py::object ref_operation = insertion_point.attr("ref_operation");
return {py::cast<MlirBlock>(insertion_point.attr("block")),
nb::object insertion_point =
nb::module_::import_(IR_MODULE).attr("InsertionPoint").attr("current");
nb::object ref_operation = insertion_point.attr("ref_operation");
return {nb::cast<MlirBlock>(insertion_point.attr("block")),
ref_operation.is_none()
? MlirOperation{nullptr}
: py::cast<MlirOperation>(insertion_point.attr("ref_operation"))};
: nb::cast<MlirOperation>(insertion_point.attr("ref_operation"))};
}
// Unwraps the current default location
// ValueError is raised if default location is not set
MlirLocation getDefaultLocation() {
return py::cast<MlirLocation>(
py::module_::import(IR_MODULE).attr("Location").attr("current"));
return nb::cast<MlirLocation>(
nb::module_::import_(IR_MODULE).attr("Location").attr("current"));
}
// Unwraps the current default MLIR context
// ValueError is raised if default context is not set
MlirContext getDefaultContext() {
return py::cast<MlirContext>(
py::module_::import(IR_MODULE).attr("Context").attr("current"));
return nb::cast<MlirContext>(
nb::module_::import_(IR_MODULE).attr("Context").attr("current"));
}
struct PyTpuVectorLayout {
PyTpuVectorLayout(MlirTpuVectorLayout layout) : layout(layout) {}
~PyTpuVectorLayout() { mlirTpuVectorLayoutDestroy(layout); }
PyTpuVectorLayout(const PyTpuVectorLayout&) = delete;
PyTpuVectorLayout& operator=(const PyTpuVectorLayout&) = delete;
MlirTpuVectorLayout layout;
};
} // namespace
PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
NB_MODULE(_tpu_ext, m) {
tsl::ImportNumpy();
mlirRegisterTPUPasses(); // Register all passes on load.
mlirTpuRegisterMosaicSerdePass();
py::class_<MlirTpuApplyVectorLayoutContext>(m, "ApplyVectorLayoutCtx",
py::module_local())
.def(py::init([](int hardware_generation, py::tuple target_shape,
py::tuple mxu_shape, int max_sublanes_in_scratch) {
if (target_shape.size() != 2) {
throw py::value_error("target_shape should be of length 2");
}
if (mxu_shape.size() != 2) {
throw py::value_error("mxu_shape should be of length 2");
}
return MlirTpuApplyVectorLayoutContext{
.hardware_generation = hardware_generation,
.target_shape = {target_shape[0].cast<int64_t>(),
target_shape[1].cast<int64_t>()},
.mxu_shape = {mxu_shape[0].cast<int64_t>(),
mxu_shape[1].cast<int64_t>()},
.max_sublanes_in_scratch = max_sublanes_in_scratch};
}),
py::arg("hardware_generation") = -1,
py::arg("target_shape") = toPyTuple(TARGET_SHAPE),
py::arg("mxu_shape") = py::make_tuple(128, 128),
py::arg("max_sublanes_in_scratch") = 0);
nb::class_<MlirTpuApplyVectorLayoutContext>(m, "ApplyVectorLayoutCtx")
.def(
"__init__",
[](MlirTpuApplyVectorLayoutContext* self, int hardware_generation,
nb::tuple target_shape, nb::tuple mxu_shape,
int max_sublanes_in_scratch) {
if (target_shape.size() != 2) {
throw nb::value_error("target_shape should be of length 2");
}
if (mxu_shape.size() != 2) {
throw nb::value_error("mxu_shape should be of length 2");
}
new (self) MlirTpuApplyVectorLayoutContext{
.hardware_generation = hardware_generation,
.target_shape = {nb::cast<int64_t>(target_shape[0]),
nb::cast<int64_t>(target_shape[1])},
.mxu_shape = {nb::cast<int64_t>(mxu_shape[0]),
nb::cast<int64_t>(mxu_shape[1])},
.max_sublanes_in_scratch = max_sublanes_in_scratch};
},
nb::arg("hardware_generation") = -1,
nb::arg("target_shape") = toPyTuple(TARGET_SHAPE),
nb::arg("mxu_shape") = nb::make_tuple(128, 128),
nb::arg("max_sublanes_in_scratch") = 0);
py::class_<MlirTpuVregDataBounds>(m, "VRegDataBounds", py::module_local())
nb::class_<MlirTpuVregDataBounds>(m, "VRegDataBounds")
.def("mask_varies_along",
[](MlirTpuVregDataBounds self, MlirTpuDirection direction) {
return mlirTpuVregDataBoundsMaskVariesAlong(self, direction,
TARGET_SHAPE);
})
.def_property_readonly("complete",
[](MlirTpuVregDataBounds self) {
return mlirTpuVregDataBoundsIsComplete(
self, TARGET_SHAPE);
})
.def_prop_ro("complete",
[](MlirTpuVregDataBounds self) {
return mlirTpuVregDataBoundsIsComplete(self, TARGET_SHAPE);
})
.def("get_vector_mask",
[](MlirTpuVregDataBounds self, int generation) {
// TODO: Does this work? Test in Python
@@ -370,63 +365,79 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
// TODO(tlongeri): More precise argument type annotations. There currently
// seems to be no way to define your own?
py::class_<MlirTpuVectorLayout, Holder<MlirTpuVectorLayout>>(
m, "VectorLayout", py::module_local())
.def(py::init([](int bitwidth, py::tuple offsets, py::tuple tiling,
MlirTpuImplicitDim implicit_dim) {
if (offsets.size() != 2) {
throw py::value_error("Offsets should be of length 2");
}
if (tiling.size() != 2) {
throw py::value_error("Tiling should be of length 2");
}
MlirTpuVectorLayout layout = mlirTpuVectorLayoutCreate(
bitwidth,
{offsetFromPyOffset(offsets[0]),
offsetFromPyOffset(offsets[1])},
{tiling[0].cast<int64_t>(), tiling[1].cast<int64_t>()},
implicit_dim);
if (!mlirTpuVectorLayoutIsValid(layout, TARGET_SHAPE)) {
throw py::value_error("Layout not valid for target shape");
nb::class_<PyTpuVectorLayout>(m, "VectorLayout")
.def(
"__init__",
[](PyTpuVectorLayout* self, int bitwidth, nb::tuple offsets,
nb::tuple tiling, MlirTpuImplicitDim implicit_dim) {
if (offsets.size() != 2) {
throw nb::value_error("Offsets should be of length 2");
}
return layout;
}),
py::arg("bitwidth"), py::arg("offsets"), py::arg("tiling"),
py::arg("implicit_dim"))
.def_property_readonly("bitwidth", mlirTpuVectorLayoutGetBitwidth,
"The bitwidth of the stored values.")
.def_property_readonly(
if (tiling.size() != 2) {
throw nb::value_error("Tiling should be of length 2");
}
MlirTpuVectorLayout layout = mlirTpuVectorLayoutCreate(
bitwidth,
{offsetFromPyOffset(offsets[0]),
offsetFromPyOffset(offsets[1])},
{nb::cast<int64_t>(tiling[0]), nb::cast<int64_t>(tiling[1])},
implicit_dim);
if (!mlirTpuVectorLayoutIsValid(layout, TARGET_SHAPE)) {
throw nb::value_error("Layout not valid for target shape");
}
new (self) PyTpuVectorLayout(layout);
},
nb::arg("bitwidth"), nb::arg("offsets"), nb::arg("tiling"),
nb::arg("implicit_dim").none())
.def_prop_ro(
"bitwidth",
[](const PyTpuVectorLayout& self) {
return mlirTpuVectorLayoutGetBitwidth(self.layout);
},
"The bitwidth of the stored values.")
.def_prop_ro(
"offsets",
[](MlirTpuVectorLayout self) {
MlirTpuLayoutOffsets offsets = mlirTpuVectorLayoutGetOffsets(self);
return py::make_tuple(toPyLayoutOffset(offsets.sublane),
[](const PyTpuVectorLayout& self) {
MlirTpuLayoutOffsets offsets =
mlirTpuVectorLayoutGetOffsets(self.layout);
return nb::make_tuple(toPyLayoutOffset(offsets.sublane),
toPyLayoutOffset(offsets.lane));
},
"The coordinates of the first valid element. If an offset is "
"REPLICATED, then any offset is valid as the value does not vary "
"across sublanes or lanes respectively.")
.def_property_readonly(
.def_prop_ro(
"tiling",
[](MlirTpuVectorLayout self) {
return toPyTuple(mlirTpuVectorLayoutGetTiling(self));
[](const PyTpuVectorLayout& self) {
return toPyTuple(mlirTpuVectorLayoutGetTiling(self.layout));
},
"The tiling used to lay out values (see the XLA docs). For values of "
"bitwidth < 32, an implicit (32 // bitwidth, 1) tiling is appended "
"to the one specified as an attribute.")
.def_property_readonly(
"implicit_dim", mlirTpuVectorLayoutGetImplicitDim,
.def_prop_ro(
"implicit_dim",
[](const PyTpuVectorLayout& self) {
return mlirTpuVectorLayoutGetImplicitDim(self.layout);
},
"If specified, the value has an implicit dim inserted in either "
"minormost or second minormost position.")
.def_property_readonly(
"packing", mlirTpuVectorLayoutGetPacking,
.def_prop_ro(
"packing",
[](const PyTpuVectorLayout& self) {
return mlirTpuVectorLayoutGetPacking(self.layout);
},
"Returns the number of values stored in a vreg entry.")
.def_property_readonly(
"layout_rank", mlirTpuVectorLayoutGetLayoutRank,
.def_prop_ro(
"layout_rank",
[](const PyTpuVectorLayout& self) {
return mlirTpuVectorLayoutGetLayoutRank(self.layout);
},
"The number of minormost dimensions tiled by this layout.")
.def_property_readonly(
.def_prop_ro(
"has_natural_topology",
[](MlirTpuVectorLayout self) {
return mlirTpuVectorLayoutHasNaturalTopology(self, TARGET_SHAPE);
[](const PyTpuVectorLayout& self) {
return mlirTpuVectorLayoutHasNaturalTopology(self.layout,
TARGET_SHAPE);
},
"True, if every vector register has a layout without jumps.\n"
"\n"
@@ -434,33 +445,35 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
"always leads to a contiguous traversal of the (second) minormost "
"dimension of data. This is only true for 32-bit types, since "
"narrower types use two level tiling.")
.def_property_readonly(
.def_prop_ro(
"has_native_tiling",
[](MlirTpuVectorLayout self) {
return mlirTpuVectorLayoutHasNativeTiling(self, TARGET_SHAPE);
[](const PyTpuVectorLayout& self) {
return mlirTpuVectorLayoutHasNativeTiling(self.layout,
TARGET_SHAPE);
},
"True, if every vector register has a natural \"packed\" topology.\n"
"\n"
"This is equivalent to has_natural_topology for 32-bit types, but "
"generalizes it to narrower values with packed layouts too.")
.def_property_readonly(
.def_prop_ro(
"tiles_per_vreg",
[](MlirTpuVectorLayout self) {
return mlirTpuVectorLayoutTilesPerVreg(self, TARGET_SHAPE);
[](const PyTpuVectorLayout& self) {
return mlirTpuVectorLayoutTilesPerVreg(self.layout, TARGET_SHAPE);
},
"How many tiles fit in each vector register.")
.def_property_readonly(
.def_prop_ro(
"sublanes_per_tile",
[](MlirTpuVectorLayout self) {
return mlirTpuVectorLayoutSublanesPerTile(self, TARGET_SHAPE);
[](const PyTpuVectorLayout& self) {
return mlirTpuVectorLayoutSublanesPerTile(self.layout,
TARGET_SHAPE);
},
"The number of sublanes necessary to store each tile.")
.def_property_readonly(
.def_prop_ro(
"vreg_slice",
[](MlirTpuVectorLayout self) {
[](const PyTpuVectorLayout& self) {
MlirTpuI64TargetTuple vreg_slice =
mlirTpuVectorLayoutVregSlice(self, TARGET_SHAPE);
return py::module_::import(LAYOUT_DEFS_MODULE)
mlirTpuVectorLayoutVregSlice(self.layout, TARGET_SHAPE);
return nb::module_::import_(LAYOUT_DEFS_MODULE)
.attr("TargetTuple")(vreg_slice.sublane, vreg_slice.lane);
},
"Returns the size of a window contained in a single vreg.\n"
@@ -469,34 +482,34 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
"rows, so only the minormost dimension can increase.")
.def(
"implicit_shape",
[](MlirTpuVectorLayout self, py::sequence shape) {
[](const PyTpuVectorLayout& self, nb::sequence shape) {
llvm::SmallVector<int64_t> implicit_shape_vec =
sequenceToSmallVector<int64_t>(shape);
MlirTpuI64ArrayRef implicit_shape =
mlirTpuVectorLayoutImplicitShape(
self,
self.layout,
{implicit_shape_vec.data(), implicit_shape_vec.size()});
py::tuple ret = toPyTuple(implicit_shape.ptr, implicit_shape.size);
nb::tuple ret = toPyTuple(implicit_shape.ptr, implicit_shape.size);
free(implicit_shape.ptr);
return ret;
},
py::arg("shape"))
nb::arg("shape"))
.def(
"tile_array_shape",
[](MlirTpuVectorLayout self, py::sequence shape) {
[](const PyTpuVectorLayout& self, nb::sequence shape) {
llvm::SmallVector<int64_t> tile_array_shape_vec =
sequenceToSmallVector<int64_t>(shape);
MlirTpuI64ArrayRef tile_array_shape =
mlirTpuVectorLayoutTileArrayShape(
self,
self.layout,
{tile_array_shape_vec.data(), tile_array_shape_vec.size()},
TARGET_SHAPE);
py::tuple ret =
nb::tuple ret =
toPyTuple(tile_array_shape.ptr, tile_array_shape.size);
free(tile_array_shape.ptr);
return ret;
},
py::arg("shape"),
nb::arg("shape"),
"Returns the shape of an ndarray of vregs needed to represent a "
"value.\n"
"\n"
@@ -511,34 +524,33 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
" shape: The shape of the ndarray to tile.")
.def(
"tile_data_bounds",
[](MlirTpuVectorLayout self, py::sequence shape, py::sequence ixs,
std::variant<bool, py::tuple> allow_replicated) {
[](const PyTpuVectorLayout& self, nb::sequence shape,
nb::sequence ixs, std::variant<bool, nb::tuple> allow_replicated) {
llvm::SmallVector<int64_t> shape_vec =
sequenceToSmallVector<int64_t>(shape);
llvm::SmallVector<int64_t> ixs_vec =
sequenceToSmallVector<int64_t>(ixs);
if (shape_vec.size() != ixs_vec.size()) {
throw py::value_error(
throw nb::value_error(
"Expected shape and ixs to have the same size");
}
return std::visit(
[&](auto ar) {
if constexpr (std::is_same_v<decltype(ar), bool>) {
return mlirTpuVectorLayoutTileDataBounds(
self, getDefaultContext(), shape_vec.data(),
self.layout, getDefaultContext(), shape_vec.data(),
ixs_vec.data(), shape_vec.size(), TARGET_SHAPE,
{ar, ar});
} else {
return mlirTpuVectorLayoutTileDataBounds(
self, getDefaultContext(), shape_vec.data(),
self.layout, getDefaultContext(), shape_vec.data(),
ixs_vec.data(), shape_vec.size(), TARGET_SHAPE,
{ar[0].template cast<bool>(),
ar[1].template cast<bool>()});
{nb::cast<bool>(ar[0]), nb::cast<bool>(ar[1])});
}
},
allow_replicated);
},
py::arg("shape"), py::arg("ixs"), py::arg("allow_replicated") = false,
nb::arg("shape"), nb::arg("ixs"), nb::arg("allow_replicated") = false,
"Returns the bounds of the given tile that hold useful data.\n"
"\n"
"Arguments:\n"
@@ -556,19 +568,19 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
"within the tile selected by idx.")
.def(
"generalizes",
[](MlirTpuVectorLayout self, MlirTpuVectorLayout other,
std::optional<py::sequence> shape) {
[](const PyTpuVectorLayout& self, const PyTpuVectorLayout& other,
std::optional<nb::sequence> shape) {
if (shape) {
llvm::SmallVector<int64_t> shape_vec =
sequenceToSmallVector<int64_t>(*shape);
return mlirTpuVectorLayoutGeneralizes(
self, other, {shape_vec.data(), shape_vec.size()},
TARGET_SHAPE);
self.layout, other.layout,
{shape_vec.data(), shape_vec.size()}, TARGET_SHAPE);
}
return mlirTpuVectorLayoutGeneralizes(self, other, {nullptr, 0},
TARGET_SHAPE);
return mlirTpuVectorLayoutGeneralizes(self.layout, other.layout,
{nullptr, 0}, TARGET_SHAPE);
},
py::arg("other"), py::arg("shape") = std::nullopt,
nb::arg("other"), nb::arg("shape").none() = std::nullopt,
"Returns True if the other layout is a special case of this one.\n"
"\n"
"In here, other is considered \"a special case\" when the set of "
@@ -595,19 +607,19 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
"does not hold the other way around for some shapes.")
.def(
"equivalent_to",
[](MlirTpuVectorLayout self, MlirTpuVectorLayout other,
std::optional<py::sequence> shape) {
[](const PyTpuVectorLayout& self, const PyTpuVectorLayout& other,
std::optional<nb::sequence> shape) {
if (shape) {
llvm::SmallVector<int64_t> shape_vec =
sequenceToSmallVector<int64_t>(*shape);
return mlirTpuVectorLayoutEquivalentTo(
self, other, {shape_vec.data(), shape_vec.size()},
TARGET_SHAPE);
self.layout, other.layout,
{shape_vec.data(), shape_vec.size()}, TARGET_SHAPE);
}
return mlirTpuVectorLayoutEquivalentTo(self, other, {nullptr, 0},
TARGET_SHAPE);
return mlirTpuVectorLayoutEquivalentTo(self.layout, other.layout,
{nullptr, 0}, TARGET_SHAPE);
},
py::arg("other"), py::arg("shape") = std::nullopt,
nb::arg("other"), nb::arg("shape").none() = std::nullopt,
"Returns True if the two layouts are equivalent.\n"
"\n"
"That is, when all potential vector entries where the value can be "
@@ -619,55 +631,65 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
" shape: An optional shape of the vector to which both layouts "
"apply. More layouts are considered equivalent when the shape is "
"specified. Also see the docstring of the generalizes method.")
.def("__eq__", mlirTpuVectorLayoutEquals)
.def("__repr__",
[](MlirTpuVectorLayout self) {
std::string str;
mlirTpuVectorLayoutPrint(self, printToString, &str);
return str;
});
.def("__eq__",
[](const PyTpuVectorLayout& self, const PyTpuVectorLayout& other) {
return mlirTpuVectorLayoutEquals(self.layout, other.layout);
})
.def("__repr__", [](const PyTpuVectorLayout& self) {
std::string str;
mlirTpuVectorLayoutPrint(self.layout, printToString, &str);
return str;
});
// TODO(tlongeri): Can we make the first parameter a VectorType?
m.def("assemble",
[](const MlirType ty, MlirTpuVectorLayout layout,
// TODO(tlongeri): Remove py::array::c_style, I only added it because
// I couldn't find a simple way to iterate over array data, but it
// causes yet another unnecessary copy.
py::array_t<PyObject*, py::array::c_style> np_arr) -> MlirOperation {
[](const MlirType ty, const PyTpuVectorLayout& layout,
nb::object np_arr_obj) -> MlirOperation {
// TODO(tlongeri): Remove nb::array::c_style, I only added it because
// I couldn't find a simple way to iterate over array data, but it
// causes yet another unnecessary copy.
xla::nb_numpy_ndarray np_arr = xla::nb_numpy_ndarray::ensure(
np_arr_obj, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_ALIGNED);
if (!mlirTypeIsAVector(ty)) {
throw py::type_error("Expected vector type");
throw nb::type_error("Expected vector type");
}
llvm::SmallVector<MlirValue> vals(np_arr.size());
for (int64_t i = 0; i < np_arr.size(); ++i) {
vals.data()[i] = py::cast<MlirValue>(py::handle(np_arr.data()[i]));
vals.data()[i] = nb::cast<MlirValue>(nb::handle(
reinterpret_cast<PyObject* const*>(np_arr.data())[i]));
}
llvm::SmallVector<int64_t> shape(np_arr.ndim());
for (int64_t i = 0; i < np_arr.ndim(); ++i) {
shape.data()[i] = np_arr.shape()[i];
}
return mlirTpuAssemble(
getDefaultInsertionPoint(), ty, layout,
getDefaultInsertionPoint(), ty, layout.layout,
MlirTpuValueArray{MlirTpuI64ArrayRef{shape.data(), shape.size()},
vals.data()},
TARGET_SHAPE);
});
m.def("disassemble", [](MlirTpuVectorLayout layout, MlirValue val) {
m.def("disassemble", [](const PyTpuVectorLayout& layout, MlirValue val) {
DiagnosticCapture diag_capture(getDefaultContext());
MlirTpuValueArray val_arr = mlirTpuDisassemble(getDefaultInsertionPoint(),
layout, val, TARGET_SHAPE);
MlirTpuValueArray val_arr = mlirTpuDisassemble(
getDefaultInsertionPoint(), layout.layout, val, TARGET_SHAPE);
if (val_arr.vals == nullptr) {
diag_capture.throwIfError();
throw py::value_error("Failed to disassemble");
throw nb::value_error("Failed to disassemble");
}
py::array_t<PyObject*> np_vals(
llvm::ArrayRef<int64_t>{val_arr.shape.ptr, val_arr.shape.size});
xla::nb_numpy_ndarray np_vals(
/*dtype=*/xla::nb_dtype("O"),
/*shape=*/
absl::Span<int64_t const>(val_arr.shape.ptr, val_arr.shape.size),
/*strides=*/std::nullopt);
for (ssize_t i = 0; i < np_vals.size(); ++i) {
np_vals.mutable_data()[i] = py::cast(val_arr.vals[i]).release().ptr();
reinterpret_cast<PyObject**>(np_vals.mutable_data())[i] =
nb::cast(val_arr.vals[i]).release().ptr();
}
free(val_arr.shape.ptr);
free(val_arr.vals);
return np_vals;
});
m.def("apply_layout_op",
[](int hardware_generation, const MlirOperation c_op) {
DiagnosticCapture diag_capture(getDefaultContext());
@@ -678,25 +700,27 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
throw std::runtime_error("applyLayoutOp failed");
}
});
m.def("relayout",
[](MlirValue v, MlirTpuVectorLayout src, MlirTpuVectorLayout dst,
MlirTpuApplyVectorLayoutContext apply_layout_ctx) {
DiagnosticCapture diag_capture(getDefaultContext());
MlirValue new_v = mlirTpuRelayout(getDefaultInsertionPoint(), v, src,
dst, apply_layout_ctx);
if (new_v.ptr == nullptr) {
diag_capture.throwIfError();
throw py::value_error("Failed to relayout");
}
return new_v;
});
py::register_exception_translator([](std::exception_ptr p) {
try {
if (p) std::rethrow_exception(p);
} catch (const NotImplementedException& e) {
PyErr_SetString(PyExc_NotImplementedError, e.what());
m.def("relayout", [](MlirValue v, const PyTpuVectorLayout& src,
const PyTpuVectorLayout& dst,
MlirTpuApplyVectorLayoutContext apply_layout_ctx) {
DiagnosticCapture diag_capture(getDefaultContext());
MlirValue new_v = mlirTpuRelayout(getDefaultInsertionPoint(), v, src.layout,
dst.layout, apply_layout_ctx);
if (new_v.ptr == nullptr) {
diag_capture.throwIfError();
throw nb::value_error("Failed to relayout");
}
return new_v;
});
nb::register_exception_translator(
[](const std::exception_ptr& p, void*) {
try {
if (p) std::rethrow_exception(p);
} catch (const NotImplementedException& e) {
PyErr_SetString(PyExc_NotImplementedError, e.what());
}
},
nullptr);
m.def(
"register_dialect",
@@ -707,21 +731,25 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
mlirDialectHandleLoadDialect(dialect, context);
}
},
py::arg("context"), py::arg("load") = true);
nb::arg("context"), nb::arg("load") = true);
m.def("private_is_tiled_layout", [](MlirAttribute attr) {
return mlirTPUAttributeIsATiledLayoutAttr(attr);
});
m.def("private_get_tiles", [](MlirAttribute attr) -> py::object {
m.def("private_get_tiles", [](MlirAttribute attr) -> nb::object {
MlirAttribute encoded_tiles = mlirTPUTiledLayoutAttrGetTiles(attr);
py::tuple py_tiles(mlirArrayAttrGetNumElements(encoded_tiles));
nb::tuple py_tiles = nb::steal<nb::tuple>(
PyTuple_New(mlirArrayAttrGetNumElements(encoded_tiles)));
for (intptr_t i = 0; i < mlirArrayAttrGetNumElements(encoded_tiles); ++i) {
MlirAttribute tile = mlirArrayAttrGetElement(encoded_tiles, i);
py::tuple py_tile(mlirDenseArrayGetNumElements(tile));
nb::tuple py_tile =
nb::steal<nb::tuple>(PyTuple_New(mlirDenseArrayGetNumElements(tile)));
for (intptr_t j = 0; j < mlirDenseArrayGetNumElements(tile); ++j) {
py_tile[j] = mlirDenseI64ArrayGetElement(tile, j);
PyTuple_SET_ITEM(
py_tile.ptr(), j,
nb::cast(mlirDenseI64ArrayGetElement(tile, j)).release().ptr());
}
py_tiles[i] = py_tile;
PyTuple_SET_ITEM(py_tiles.ptr(), i, py_tile.release().ptr());
}
return py_tiles;
});
@@ -737,7 +765,7 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
m.def("private_replace_all_uses_with", [](MlirOperation op,
std::vector<MlirValue> vals) {
if (vals.size() != mlirOperationGetNumResults(op)) {
throw py::value_error("length mismatch in replace_all_uses_with");
throw nb::value_error("length mismatch in replace_all_uses_with");
}
for (int i = 0; i < vals.size(); ++i) {
mlirValueReplaceAllUsesOfWith(mlirOperationGetResult(op, i), vals[i]);
@@ -747,7 +775,7 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
[](MlirValue old, MlirValue new_val, MlirOperation except) {
for (intptr_t i = 0; i < mlirOperationGetNumOperands(except); ++i) {
if (mlirValueEqual(mlirOperationGetOperand(except, i), new_val)) {
throw py::value_error("new val already used in except");
throw nb::value_error("new val already used in except");
}
}
mlirValueReplaceAllUsesOfWith(old, new_val);
@@ -785,7 +813,7 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
});
m.def("private_move_all_regions", [](MlirOperation src, MlirOperation dst) {
if (mlirOperationGetNumRegions(src) != mlirOperationGetNumRegions(dst)) {
throw py::value_error(
throw nb::value_error(
"Region counts do not match in src operation and dst operations");
}
for (intptr_t i = 0; i < mlirOperationGetNumRegions(src); ++i) {


@@ -16,13 +16,13 @@ limitations under the License.
#include <optional>
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "pybind11/detail/common.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
#include "jaxlib/triton/triton_dialect_capi.h"
namespace py = pybind11;
namespace nb = nanobind;
PYBIND11_MODULE(_triton_ext, m, py::mod_gil_not_used()) {
NB_MODULE(_triton_ext, m) {
//
// Dialects.
//
@@ -36,20 +36,20 @@ PYBIND11_MODULE(_triton_ext, m, py::mod_gil_not_used()) {
mlirDialectHandleLoadDialect(dialect, context);
}
},
py::arg("context"), py::arg("load") = true);
nb::arg("context"), nb::arg("load") = true);
//
// Types.
//
mlir::python::adaptors::mlir_type_subclass(m, "PointerType",
mlir::python::nanobind_adaptors::mlir_type_subclass(m, "PointerType",
mlirTritonIsAPointer)
.def_classmethod(
"get",
[](py::object cls, MlirType pointee_type, int64_t address_space) {
[](nb::object cls, MlirType pointee_type, int64_t address_space) {
return cls(mlirTritonPointerTypeGet(pointee_type, address_space));
},
py::arg("cls"), py::arg("pointee_type"), py::arg("address_space"),
nb::arg("cls"), nb::arg("pointee_type"), nb::arg("address_space"),
"Creates a PointerType type.")
.def_property_readonly("pointee_type", [](MlirType self) {
return mlirTritonPointerTypeGetPointeeType(self);