diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index e2376a457..1eb4b03fe 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -102,7 +102,7 @@ class LoweringRuleContext: @dataclasses.dataclass class LoweringResult: - """Keeps pybind11 objects alive.""" + """Keeps python objects alive.""" module: ir.Module grid: tuple[int, ...] diff --git a/jaxlib/absl_status_casters.h b/jaxlib/absl_status_casters.h index 1ed3c0a0a..39e4b6c35 100644 --- a/jaxlib/absl_status_casters.h +++ b/jaxlib/absl_status_casters.h @@ -27,11 +27,6 @@ namespace jax { // // Failing statuses become Python exceptions; OK Status() becomes None. // -// Given there can be only a single global pybind11 type_caster for the -// `absl::Status` type, and given XLA wants a custom exception being raised, -// we use a dedicated helper to implement this feature without relying on a -// global `type_caster`. -// // For example: // // - Functions without arguments: diff --git a/jaxlib/kernel_helpers.h b/jaxlib/kernel_helpers.h index 33eaf8a1d..dac0355fb 100644 --- a/jaxlib/kernel_helpers.h +++ b/jaxlib/kernel_helpers.h @@ -25,9 +25,9 @@ limitations under the License. namespace jax { -// See kernel_pybind11_helpers.h for info on descriptor objects. We separate out -// the functionality that doesn't require pybind11 for building CUDA libraries, -// since older versions nvcc don't seem to be able to compile pybind11. +// See kernel_nanobind_helpers.h for info on descriptor objects. We separate out +// the functionality that doesn't require nanobind for building CUDA libraries, +// since older versions nvcc don't seem to be able to compile nanobind. // Packs a descriptor object into a byte string. template diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index 0b94f9d1d..e8443b2ac 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -158,6 +158,7 @@ py_extension( "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi", "@llvm-project//mlir:CAPIIRHeaders", "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", + "@nanobind", ], ) @@ -177,8 +178,10 @@ py_extension( "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", - "@pybind11", + "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", + "@nanobind", + "@xla//xla/python:nb_numpy", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -203,8 +206,8 @@ pybind_extension( ":jaxlib_mlir_capi_shared_library", "//jaxlib/triton:triton_dialect_capi_headers", "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", - "@pybind11", + "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", + "@nanobind", ], ) @@ -266,9 +269,9 @@ py_extension( "@llvm-project//mlir:CAPISCFHeaders", "@llvm-project//mlir:CAPITransformsHeaders", "@llvm-project//mlir:CAPIVectorHeaders", - "@llvm-project//mlir:MLIRBindingsPythonHeaders", + "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", - "@pybind11", + "@nanobind", "@shardy//shardy/integrations/c:sdy_capi_headers", ], ) diff --git a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc index 7204bbaa1..d3009be21 100644 --- a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc @@ -13,17 +13,14 @@ See the License for the specific language governing permissions and 
limitations under the License. ==============================================================================*/ -// clang-format: off -// pybind11 must be included before mlir/Bindings/Python/PybindAdaptors.h, -// otherwise this code will not build on Windows. -#include "pybind11/pybind11.h" -// clang-format: on - #include "mlir-c/IR.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" // IWYU pragma: keep +#include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep +#include "nanobind/nanobind.h" #include "jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h" -PYBIND11_MODULE(_mosaic_gpu_ext, m, py::mod_gil_not_used()) { +namespace nb = nanobind; + +NB_MODULE(_mosaic_gpu_ext, m) { m.def( "register_dialect", [](MlirContext context, bool load) { @@ -33,5 +30,5 @@ PYBIND11_MODULE(_mosaic_gpu_ext, m, py::mod_gil_not_used()) { mlirDialectHandleLoadDialect(dialect, context); } }, - py::arg("context"), py::arg("load") = true); + nb::arg("context"), nb::arg("load") = true); } diff --git a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc index 06caabb30..715dfbc59 100644 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc +++ b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc @@ -1,6 +1,6 @@ // Registers MLIR dialects used by JAX. // This module is called by mlir/__init__.py during initialization. -#include +#include #include "mlir-c/Dialect/Arith.h" #include "mlir-c/Dialect/Func.h" @@ -13,17 +13,17 @@ #include "mlir-c/Dialect/SCF.h" #include "mlir-c/Dialect/Vector.h" #include "mlir-c/Transforms.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "shardy/integrations/c/passes.h" -namespace py = pybind11; +namespace nb = nanobind; #define REGISTER_DIALECT(name) \ MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \ mlirDialectHandleInsertDialect(name##_dialect, registry) -PYBIND11_MODULE(register_jax_dialects, m, py::mod_gil_not_used()) { +NB_MODULE(register_jax_dialects, m) { m.doc() = "Registers upstream MLIR dialects used by JAX."; m.def("register_dialects", [](MlirDialectRegistry registry) { diff --git a/jaxlib/mlir/_mlir_libs/tpu_ext.cc b/jaxlib/mlir/_mlir_libs/tpu_ext.cc index 3061cd399..a15d2f4f6 100644 --- a/jaxlib/mlir/_mlir_libs/tpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/tpu_ext.cc @@ -27,12 +27,6 @@ limitations under the License. #include #include -// clang-format: off -// pybind11 must be included before mlir/Bindings/Python/PybindAdaptors.h, -// otherwise this code will not build on Windows. -#include "pybind11/pybind11.h" -// clang-format: on - #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVectorExtras.h" @@ -46,19 +40,24 @@ limitations under the License. 
#include "mlir-c/Dialect/Func.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" -#include "pybind11/attr.h" -#include "pybind11/cast.h" -#include "pybind11/detail/common.h" -#include "pybind11/numpy.h" -#include "pybind11/pytypes.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep #include "absl/log/check.h" #include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h" +#include "xla/python/nb_numpy.h" +#include "xla/tsl/python/lib/core/numpy.h" // TODO(tlongeri): Can I add my own return type annotations to functions? // TODO(tlongeri): I don't understand why MLIR uses the C API to implement // Python bindings. Do we have a reason to do that? +namespace nb = nanobind; + namespace { constexpr const char LAYOUT_DEFS_MODULE[] = "jax.jaxlib.mosaic.python.layout_defs"; @@ -75,20 +74,21 @@ constexpr MlirTpuI64TargetTuple TARGET_SHAPE{8, 128}; class NotImplementedException : public std::runtime_error { using runtime_error::runtime_error; }; + } // namespace template <> -struct py::detail::type_caster { - PYBIND11_TYPE_CASTER(MlirTpuImplicitDim, const_name("ImplicitDim | None")); +struct nb::detail::type_caster { + NB_TYPE_CASTER(MlirTpuImplicitDim, const_name("ImplicitDim | None")); - bool load(handle src, bool) { + bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept { if (src.is_none()) { value = MlirTpuImplicitDimNone; return true; } auto implicit_dim_cls = - py::module_::import(LAYOUT_DEFS_MODULE).attr("ImplicitDim"); - if (!py::isinstance(src, implicit_dim_cls)) { + nb::module_::import_(LAYOUT_DEFS_MODULE).attr("ImplicitDim"); + if (!nb::isinstance(src, implicit_dim_cls)) { return false; } if (src.is(implicit_dim_cls.attr("MINOR"))) { @@ -96,36 +96,36 @@ struct py::detail::type_caster { } else if (src.is(implicit_dim_cls.attr("SECOND_MINOR"))) { value = MlirTpuImplicitDimSecondMinor; } else { - throw py::value_error(); + return false; } return true; } - static handle cast(MlirTpuImplicitDim implicit_dim, - return_value_policy /* policy */, handle /* parent */) { + static handle from_cpp(MlirTpuImplicitDim implicit_dim, rv_policy policy, + cleanup_list* cleanup) noexcept { auto implicit_dim_cls = - py::module_::import(LAYOUT_DEFS_MODULE).attr("ImplicitDim"); + nb::module_::import_(LAYOUT_DEFS_MODULE).attr("ImplicitDim"); switch (implicit_dim) { case MlirTpuImplicitDimNone: - return py::none().release(); + return nb::none().release(); case MlirTpuImplicitDimMinor: - return static_cast(implicit_dim_cls.attr("MINOR")) + return static_cast(implicit_dim_cls.attr("MINOR")) .release(); case MlirTpuImplicitDimSecondMinor: - return static_cast(implicit_dim_cls.attr("SECOND_MINOR")) + return static_cast(implicit_dim_cls.attr("SECOND_MINOR")) .release(); } } }; template <> -struct py::detail::type_caster { - PYBIND11_TYPE_CASTER(MlirTpuDirection, const_name("Direction")); +struct nb::detail::type_caster { + NB_TYPE_CASTER(MlirTpuDirection, const_name("Direction")); - bool load(handle src, bool) { + bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept { auto direction_cls = - py::module_::import(LAYOUT_DEFS_MODULE).attr("Direction"); - if (!py::isinstance(src, direction_cls)) 
{ + nb::module_::import_(LAYOUT_DEFS_MODULE).attr("Direction"); + if (!nb::isinstance(src, direction_cls)) { return false; } if (src.is(direction_cls.attr("LANES"))) { @@ -135,26 +135,28 @@ struct py::detail::type_caster { } else if (src.is(direction_cls.attr("SUBELEMENTS"))) { value = MlirTpuDirectionSubelements; } else { - throw py::value_error(); + return false; } return true; } - static handle cast(MlirTpuDirection direction, - return_value_policy /* policy */, handle /* parent */) { + static handle from_cpp(MlirTpuDirection direction, rv_policy /* policy */, + cleanup_list* /* cleanup */) noexcept { auto direction_cls = - py::module_::import(LAYOUT_DEFS_MODULE).attr("ImplicitDim"); + nb::module_::import_(LAYOUT_DEFS_MODULE).attr("ImplicitDim"); switch (direction) { case MlirTpuDirectionLanes: - return static_cast(direction_cls.attr("LANES")).release(); + return static_cast(direction_cls.attr("LANES")).release(); case MlirTpuDirectionSublanes: - return static_cast(direction_cls.attr("SUBLANES")) + return static_cast(direction_cls.attr("SUBLANES")) .release(); case MlirTpuDirectionSubelements: - return static_cast(direction_cls.attr("SUBELEMENTS")) + return static_cast(direction_cls.attr("SUBELEMENTS")) .release(); default: - throw py::value_error(); + PyErr_Format(PyExc_ValueError, "Invalid MlirTpuDirection: %d", + static_cast(direction)); + return nb::handle(); } } }; @@ -163,9 +165,9 @@ namespace { // Handler for use with MLIR C API print functions. The 2nd parameter is an // opaque pointer to "user data" that should always be a string. void printToString(MlirStringRef c_mlir_str, void* opaque_string) { - std::string* str = static_cast(opaque_string); - CHECK(str != nullptr); - str->append(c_mlir_str.data, c_mlir_str.length); + std::string* str = static_cast(opaque_string); + CHECK(str != nullptr); + str->append(c_mlir_str.data, c_mlir_str.length); } class DiagnosticCapture { @@ -217,141 +219,134 @@ class DiagnosticCapture { const MlirDiagnosticHandlerID id_; }; -template -class Holder {}; - -// Holder class for MlirTpuVectorLayout, to deal properly with destruction. -// TODO(tlongeri): It would be nice to not have a seemingly unnecessary -// "pointer-to-pointer" (MlirTpuVectorLayout is basically an opaque pointer). -// But I'm not sure if that's possible since pybind expects get() to return a -// true pointer type. -template <> -class Holder { - public: - Holder(MlirTpuVectorLayout layout) : ptr(new MlirTpuVectorLayout(layout)) {} - Holder(MlirTpuVectorLayout* layout) : ptr(layout) {} - Holder(Holder&& other) = default; - ~Holder() { mlirTpuVectorLayoutDestroy(*ptr); } - MlirTpuVectorLayout* get() { return ptr.get(); } - - private: - std::unique_ptr ptr; -}; } // namespace -PYBIND11_DECLARE_HOLDER_TYPE(T, Holder); - namespace { -py::object toPyLayoutOffset(int64_t offset) { +nb::object toPyLayoutOffset(int64_t offset) { CHECK_GE(offset, -1); if (offset == -1) { - return py::module_::import(LAYOUT_DEFS_MODULE).attr("REPLICATED"); + return nb::module_::import_(LAYOUT_DEFS_MODULE).attr("REPLICATED"); } else { - return py::int_(offset); + return nb::int_(offset); } } // TODO(tlongeri): Would `type_caster`s let me avoid defining all of these // to/from functions? 
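For reference, a minimal standalone sketch of the nanobind type_caster protocol that the casters above migrate to, using a hypothetical Color enum that is not part of jaxlib: pybind11's load()/cast() pair becomes from_python()/from_cpp(), both are noexcept, and a conversion failure is reported by returning false (or a null handle) rather than by throwing.

#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>  // std::string caster for nb::cast<std::string>

#include <cstdint>
#include <string>

namespace nb = nanobind;

enum class Color { kRed, kGreen };

namespace nanobind::detail {
template <>
struct type_caster<Color> {
  NB_TYPE_CASTER(Color, const_name("Color"));

  // Python -> C++. Returning false signals "not convertible" so nanobind can
  // try other overloads, instead of raising immediately.
  bool from_python(handle src, uint8_t /*flags*/,
                   cleanup_list* /*cleanup*/) noexcept {
    if (!nb::isinstance<nb::str>(src)) return false;
    std::string s = nb::cast<std::string>(src);
    if (s == "red") { value = Color::kRed; return true; }
    if (s == "green") { value = Color::kGreen; return true; }
    return false;
  }

  // C++ -> Python. A null handle reports failure; here conversion always
  // succeeds.
  static handle from_cpp(Color c, rv_policy /*policy*/,
                         cleanup_list* /*cleanup*/) noexcept {
    return nb::str(c == Color::kRed ? "red" : "green").release();
  }
};
}  // namespace nanobind::detail

Returning false on bad input (rather than throwing, as the pybind11 versions did) is what lets nanobind fall through to the next overload or produce its own TypeError.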
-int64_t offsetFromPyOffset(py::object py_offset) { - if (py::isinstance(py_offset)) { - int64_t offset = py::cast(py_offset); +int64_t offsetFromPyOffset(nb::object py_offset) { + if (nb::isinstance(py_offset)) { + int64_t offset = nb::cast(py_offset); if (offset < 0) { - throw py::value_error("Invalid py layout offset"); + throw nb::value_error("Invalid py layout offset"); } return offset; } else if (py_offset.equal( - py::module_::import(LAYOUT_DEFS_MODULE).attr("REPLICATED"))) { + nb::module_::import_(LAYOUT_DEFS_MODULE).attr("REPLICATED"))) { return -1; } else { - throw py::type_error("Invalid layout offset type"); + throw nb::type_error("Invalid layout offset type"); } } template -llvm::SmallVector sequenceToSmallVector(py::sequence seq) { - return llvm::map_to_vector( - seq, [](py::handle handle) { return py::cast(handle); }); +llvm::SmallVector sequenceToSmallVector(nb::sequence seq) { + llvm::SmallVector out; + out.reserve(nb::len(seq)); + for (nb::handle elem : seq) { + out.push_back(nb::cast(elem)); + } + return out; } -py::tuple toPyTuple(const int64_t* data, size_t count) { - py::tuple tuple(count); +nb::tuple toPyTuple(const int64_t* data, size_t count) { + nb::tuple tuple = nb::steal(PyTuple_New(count)); for (size_t i = 0; i < count; ++i) { - tuple[i] = data[i]; + PyTuple_SET_ITEM(tuple.ptr(), i, nb::int_(data[i]).release().ptr()); } return tuple; } -py::tuple toPyTuple(MlirTpuI64TargetTuple tuple) { - return py::make_tuple(tuple.sublane, tuple.lane); +nb::tuple toPyTuple(MlirTpuI64TargetTuple tuple) { + return nb::make_tuple(tuple.sublane, tuple.lane); } // Unwraps the current default insertion point // ValueError is raised if default insertion point is not set MlirTpuInsertionPoint getDefaultInsertionPoint() { - py::object insertion_point = - py::module_::import(IR_MODULE).attr("InsertionPoint").attr("current"); - py::object ref_operation = insertion_point.attr("ref_operation"); - return {py::cast(insertion_point.attr("block")), + nb::object insertion_point = + nb::module_::import_(IR_MODULE).attr("InsertionPoint").attr("current"); + nb::object ref_operation = insertion_point.attr("ref_operation"); + return {nb::cast(insertion_point.attr("block")), ref_operation.is_none() ? MlirOperation{nullptr} - : py::cast(insertion_point.attr("ref_operation"))}; + : nb::cast(insertion_point.attr("ref_operation"))}; } // Unwraps the current default location // ValueError is raised if default location is not set MlirLocation getDefaultLocation() { - return py::cast( - py::module_::import(IR_MODULE).attr("Location").attr("current")); + return nb::cast( + nb::module_::import_(IR_MODULE).attr("Location").attr("current")); } // Unwraps the current default MLIR context // ValueError is raised if default context is not set MlirContext getDefaultContext() { - return py::cast( - py::module_::import(IR_MODULE).attr("Context").attr("current")); + return nb::cast( + nb::module_::import_(IR_MODULE).attr("Context").attr("current")); } +struct PyTpuVectorLayout { + PyTpuVectorLayout(MlirTpuVectorLayout layout) : layout(layout) {} + ~PyTpuVectorLayout() { mlirTpuVectorLayoutDestroy(layout); } + PyTpuVectorLayout(const PyTpuVectorLayout&) = delete; + PyTpuVectorLayout& operator=(const PyTpuVectorLayout&) = delete; + + MlirTpuVectorLayout layout; +}; + } // namespace -PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) { +NB_MODULE(_tpu_ext, m) { + tsl::ImportNumpy(); mlirRegisterTPUPasses(); // Register all passes on load. 
mlirTpuRegisterMosaicSerdePass(); - py::class_(m, "ApplyVectorLayoutCtx", - py::module_local()) - .def(py::init([](int hardware_generation, py::tuple target_shape, - py::tuple mxu_shape, int max_sublanes_in_scratch) { - if (target_shape.size() != 2) { - throw py::value_error("target_shape should be of length 2"); - } - if (mxu_shape.size() != 2) { - throw py::value_error("mxu_shape should be of length 2"); - } - return MlirTpuApplyVectorLayoutContext{ - .hardware_generation = hardware_generation, - .target_shape = {target_shape[0].cast(), - target_shape[1].cast()}, - .mxu_shape = {mxu_shape[0].cast(), - mxu_shape[1].cast()}, - .max_sublanes_in_scratch = max_sublanes_in_scratch}; - }), - py::arg("hardware_generation") = -1, - py::arg("target_shape") = toPyTuple(TARGET_SHAPE), - py::arg("mxu_shape") = py::make_tuple(128, 128), - py::arg("max_sublanes_in_scratch") = 0); + nb::class_(m, "ApplyVectorLayoutCtx") + .def( + "__init__", + [](MlirTpuApplyVectorLayoutContext* self, int hardware_generation, + nb::tuple target_shape, nb::tuple mxu_shape, + int max_sublanes_in_scratch) { + if (target_shape.size() != 2) { + throw nb::value_error("target_shape should be of length 2"); + } + if (mxu_shape.size() != 2) { + throw nb::value_error("mxu_shape should be of length 2"); + } + new (self) MlirTpuApplyVectorLayoutContext{ + .hardware_generation = hardware_generation, + .target_shape = {nb::cast(target_shape[0]), + nb::cast(target_shape[1])}, + .mxu_shape = {nb::cast(mxu_shape[0]), + nb::cast(mxu_shape[1])}, + .max_sublanes_in_scratch = max_sublanes_in_scratch}; + }, + nb::arg("hardware_generation") = -1, + nb::arg("target_shape") = toPyTuple(TARGET_SHAPE), + nb::arg("mxu_shape") = nb::make_tuple(128, 128), + nb::arg("max_sublanes_in_scratch") = 0); - py::class_(m, "VRegDataBounds", py::module_local()) + nb::class_(m, "VRegDataBounds") .def("mask_varies_along", [](MlirTpuVregDataBounds self, MlirTpuDirection direction) { return mlirTpuVregDataBoundsMaskVariesAlong(self, direction, TARGET_SHAPE); }) - .def_property_readonly("complete", - [](MlirTpuVregDataBounds self) { - return mlirTpuVregDataBoundsIsComplete( - self, TARGET_SHAPE); - }) + .def_prop_ro("complete", + [](MlirTpuVregDataBounds self) { + return mlirTpuVregDataBoundsIsComplete(self, TARGET_SHAPE); + }) .def("get_vector_mask", [](MlirTpuVregDataBounds self, int generation) { // TODO: Does this work? Test in Python @@ -370,63 +365,79 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) { // TODO(tlongeri): More precise argument type annotations. There currently // seems to be no way to define your own? 
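A minimal standalone sketch of the class-binding idioms this hunk switches to, using a hypothetical Span type rather than anything from jaxlib: nanobind constructs the C++ object in place inside a bound "__init__" (replacing pybind11's py::init factory lambdas), and def_property_readonly becomes def_prop_ro.

#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>  // lets __repr__ return std::string

#include <cstdint>
#include <new>
#include <string>

namespace nb = nanobind;

struct Span {
  int64_t lo;
  int64_t hi;
};

NB_MODULE(_span_demo, m) {
  nb::class_<Span>(m, "Span")
      .def(
          "__init__",
          [](Span* self, nb::tuple bounds) {
            if (bounds.size() != 2) {
              throw nb::value_error("bounds should be of length 2");
            }
            // nanobind hands the lambda uninitialized storage; construct the
            // object in place instead of returning it from a factory.
            new (self) Span{nb::cast<int64_t>(bounds[0]),
                            nb::cast<int64_t>(bounds[1])};
          },
          nb::arg("bounds"))
      .def_prop_ro(
          "size", [](const Span& self) { return self.hi - self.lo; },
          "Number of elements covered by the span.")
      .def("__repr__", [](const Span& self) {
        return "Span(" + std::to_string(self.lo) + ", " +
               std::to_string(self.hi) + ")";
      });
}

The in-place "__init__" avoids the extra move/copy a factory would incur and is why the bindings above cast tuple elements with nb::cast<int64_t>(...) before placement-new-ing the struct.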
- py::class_>( - m, "VectorLayout", py::module_local()) - .def(py::init([](int bitwidth, py::tuple offsets, py::tuple tiling, - MlirTpuImplicitDim implicit_dim) { - if (offsets.size() != 2) { - throw py::value_error("Offsets should be of length 2"); - } - if (tiling.size() != 2) { - throw py::value_error("Tiling should be of length 2"); - } - MlirTpuVectorLayout layout = mlirTpuVectorLayoutCreate( - bitwidth, - {offsetFromPyOffset(offsets[0]), - offsetFromPyOffset(offsets[1])}, - {tiling[0].cast(), tiling[1].cast()}, - implicit_dim); - if (!mlirTpuVectorLayoutIsValid(layout, TARGET_SHAPE)) { - throw py::value_error("Layout not valid for target shape"); + nb::class_(m, "VectorLayout") + .def( + "__init__", + [](PyTpuVectorLayout* self, int bitwidth, nb::tuple offsets, + nb::tuple tiling, MlirTpuImplicitDim implicit_dim) { + if (offsets.size() != 2) { + throw nb::value_error("Offsets should be of length 2"); } - return layout; - }), - py::arg("bitwidth"), py::arg("offsets"), py::arg("tiling"), - py::arg("implicit_dim")) - .def_property_readonly("bitwidth", mlirTpuVectorLayoutGetBitwidth, - "The bitwidth of the stored values.") - .def_property_readonly( + if (tiling.size() != 2) { + throw nb::value_error("Tiling should be of length 2"); + } + MlirTpuVectorLayout layout = mlirTpuVectorLayoutCreate( + bitwidth, + {offsetFromPyOffset(offsets[0]), + offsetFromPyOffset(offsets[1])}, + {nb::cast(tiling[0]), nb::cast(tiling[1])}, + implicit_dim); + if (!mlirTpuVectorLayoutIsValid(layout, TARGET_SHAPE)) { + throw nb::value_error("Layout not valid for target shape"); + } + new (self) PyTpuVectorLayout(layout); + }, + nb::arg("bitwidth"), nb::arg("offsets"), nb::arg("tiling"), + nb::arg("implicit_dim").none()) + .def_prop_ro( + "bitwidth", + [](const PyTpuVectorLayout& self) { + return mlirTpuVectorLayoutGetBitwidth(self.layout); + }, + "The bitwidth of the stored values.") + .def_prop_ro( "offsets", - [](MlirTpuVectorLayout self) { - MlirTpuLayoutOffsets offsets = mlirTpuVectorLayoutGetOffsets(self); - return py::make_tuple(toPyLayoutOffset(offsets.sublane), + [](const PyTpuVectorLayout& self) { + MlirTpuLayoutOffsets offsets = + mlirTpuVectorLayoutGetOffsets(self.layout); + return nb::make_tuple(toPyLayoutOffset(offsets.sublane), toPyLayoutOffset(offsets.lane)); }, "The coordinates of the first valid element. If an offset is " "REPLICATED, then any offset is valid as the value does not vary " "across sublanes or lanes respectively.") - .def_property_readonly( + .def_prop_ro( "tiling", - [](MlirTpuVectorLayout self) { - return toPyTuple(mlirTpuVectorLayoutGetTiling(self)); + [](const PyTpuVectorLayout& self) { + return toPyTuple(mlirTpuVectorLayoutGetTiling(self.layout)); }, "The tiling used to lay out values (see the XLA docs). 
For values of " "bitwidth < 32, an implicit (32 // bitwidth, 1) tiling is appended " "to the one specified as an attribute.") - .def_property_readonly( - "implicit_dim", mlirTpuVectorLayoutGetImplicitDim, + .def_prop_ro( + "implicit_dim", + [](const PyTpuVectorLayout& self) { + return mlirTpuVectorLayoutGetImplicitDim(self.layout); + }, "If specified, the value has an implicit dim inserted in either " "minormost or second minormost position.") - .def_property_readonly( - "packing", mlirTpuVectorLayoutGetPacking, + .def_prop_ro( + "packing", + [](const PyTpuVectorLayout& self) { + return mlirTpuVectorLayoutGetPacking(self.layout); + }, "Returns the number of values stored in a vreg entry.") - .def_property_readonly( - "layout_rank", mlirTpuVectorLayoutGetLayoutRank, + .def_prop_ro( + "layout_rank", + [](const PyTpuVectorLayout& self) { + return mlirTpuVectorLayoutGetLayoutRank(self.layout); + }, "The number of minormost dimensions tiled by this layout.") - .def_property_readonly( + .def_prop_ro( "has_natural_topology", - [](MlirTpuVectorLayout self) { - return mlirTpuVectorLayoutHasNaturalTopology(self, TARGET_SHAPE); + [](const PyTpuVectorLayout& self) { + return mlirTpuVectorLayoutHasNaturalTopology(self.layout, + TARGET_SHAPE); }, "True, if every vector register has a layout without jumps.\n" "\n" @@ -434,33 +445,35 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) { "always leads to a contiguous traversal of the (second) minormost " "dimension of data. This is only true for 32-bit types, since " "narrower types use two level tiling.") - .def_property_readonly( + .def_prop_ro( "has_native_tiling", - [](MlirTpuVectorLayout self) { - return mlirTpuVectorLayoutHasNativeTiling(self, TARGET_SHAPE); + [](const PyTpuVectorLayout& self) { + return mlirTpuVectorLayoutHasNativeTiling(self.layout, + TARGET_SHAPE); }, "True, if every vector register has a natural \"packed\" topology.\n" "\n" "This is equivalent to has_natural_topology for 32-bit types, but " "generalizes it to narrower values with packed layouts too.") - .def_property_readonly( + .def_prop_ro( "tiles_per_vreg", - [](MlirTpuVectorLayout self) { - return mlirTpuVectorLayoutTilesPerVreg(self, TARGET_SHAPE); + [](const PyTpuVectorLayout& self) { + return mlirTpuVectorLayoutTilesPerVreg(self.layout, TARGET_SHAPE); }, "How many tiles fit in each vector register.") - .def_property_readonly( + .def_prop_ro( "sublanes_per_tile", - [](MlirTpuVectorLayout self) { - return mlirTpuVectorLayoutSublanesPerTile(self, TARGET_SHAPE); + [](const PyTpuVectorLayout& self) { + return mlirTpuVectorLayoutSublanesPerTile(self.layout, + TARGET_SHAPE); }, "The number of sublanes necessary to store each tile.") - .def_property_readonly( + .def_prop_ro( "vreg_slice", - [](MlirTpuVectorLayout self) { + [](const PyTpuVectorLayout& self) { MlirTpuI64TargetTuple vreg_slice = - mlirTpuVectorLayoutVregSlice(self, TARGET_SHAPE); - return py::module_::import(LAYOUT_DEFS_MODULE) + mlirTpuVectorLayoutVregSlice(self.layout, TARGET_SHAPE); + return nb::module_::import_(LAYOUT_DEFS_MODULE) .attr("TargetTuple")(vreg_slice.sublane, vreg_slice.lane); }, "Returns the size of a window contained in a single vreg.\n" @@ -469,34 +482,34 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) { "rows, so only the minormost dimension can increase.") .def( "implicit_shape", - [](MlirTpuVectorLayout self, py::sequence shape) { + [](const PyTpuVectorLayout& self, nb::sequence shape) { llvm::SmallVector implicit_shape_vec = sequenceToSmallVector(shape); MlirTpuI64ArrayRef 
implicit_shape = mlirTpuVectorLayoutImplicitShape( - self, + self.layout, {implicit_shape_vec.data(), implicit_shape_vec.size()}); - py::tuple ret = toPyTuple(implicit_shape.ptr, implicit_shape.size); + nb::tuple ret = toPyTuple(implicit_shape.ptr, implicit_shape.size); free(implicit_shape.ptr); return ret; }, - py::arg("shape")) + nb::arg("shape")) .def( "tile_array_shape", - [](MlirTpuVectorLayout self, py::sequence shape) { + [](const PyTpuVectorLayout& self, nb::sequence shape) { llvm::SmallVector tile_array_shape_vec = sequenceToSmallVector(shape); MlirTpuI64ArrayRef tile_array_shape = mlirTpuVectorLayoutTileArrayShape( - self, + self.layout, {tile_array_shape_vec.data(), tile_array_shape_vec.size()}, TARGET_SHAPE); - py::tuple ret = + nb::tuple ret = toPyTuple(tile_array_shape.ptr, tile_array_shape.size); free(tile_array_shape.ptr); return ret; }, - py::arg("shape"), + nb::arg("shape"), "Returns the shape of an ndarray of vregs needed to represent a " "value.\n" "\n" @@ -511,34 +524,33 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) { " shape: The shape of the ndarray to tile.") .def( "tile_data_bounds", - [](MlirTpuVectorLayout self, py::sequence shape, py::sequence ixs, - std::variant allow_replicated) { + [](const PyTpuVectorLayout& self, nb::sequence shape, + nb::sequence ixs, std::variant allow_replicated) { llvm::SmallVector shape_vec = sequenceToSmallVector(shape); llvm::SmallVector ixs_vec = sequenceToSmallVector(ixs); if (shape_vec.size() != ixs_vec.size()) { - throw py::value_error( + throw nb::value_error( "Expected shape and ixs to have the same size"); } return std::visit( [&](auto ar) { if constexpr (std::is_same_v) { return mlirTpuVectorLayoutTileDataBounds( - self, getDefaultContext(), shape_vec.data(), + self.layout, getDefaultContext(), shape_vec.data(), ixs_vec.data(), shape_vec.size(), TARGET_SHAPE, {ar, ar}); } else { return mlirTpuVectorLayoutTileDataBounds( - self, getDefaultContext(), shape_vec.data(), + self.layout, getDefaultContext(), shape_vec.data(), ixs_vec.data(), shape_vec.size(), TARGET_SHAPE, - {ar[0].template cast(), - ar[1].template cast()}); + {nb::cast(ar[0]), nb::cast(ar[1])}); } }, allow_replicated); }, - py::arg("shape"), py::arg("ixs"), py::arg("allow_replicated") = false, + nb::arg("shape"), nb::arg("ixs"), nb::arg("allow_replicated") = false, "Returns the bounds of the given tile that hold useful data.\n" "\n" "Arguments:\n" @@ -556,19 +568,19 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) { "within the tile selected by idx.") .def( "generalizes", - [](MlirTpuVectorLayout self, MlirTpuVectorLayout other, - std::optional shape) { + [](const PyTpuVectorLayout& self, const PyTpuVectorLayout& other, + std::optional shape) { if (shape) { llvm::SmallVector shape_vec = sequenceToSmallVector(*shape); return mlirTpuVectorLayoutGeneralizes( - self, other, {shape_vec.data(), shape_vec.size()}, - TARGET_SHAPE); + self.layout, other.layout, + {shape_vec.data(), shape_vec.size()}, TARGET_SHAPE); } - return mlirTpuVectorLayoutGeneralizes(self, other, {nullptr, 0}, - TARGET_SHAPE); + return mlirTpuVectorLayoutGeneralizes(self.layout, other.layout, + {nullptr, 0}, TARGET_SHAPE); }, - py::arg("other"), py::arg("shape") = std::nullopt, + nb::arg("other"), nb::arg("shape").none() = std::nullopt, "Returns True if the other layout is a special case of this one.\n" "\n" "In here, other is considered \"a special case\" when the set of " @@ -595,19 +607,19 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) { "does not hold the other way around 
for some shapes.") .def( "equivalent_to", - [](MlirTpuVectorLayout self, MlirTpuVectorLayout other, - std::optional shape) { + [](const PyTpuVectorLayout& self, const PyTpuVectorLayout& other, + std::optional shape) { if (shape) { llvm::SmallVector shape_vec = sequenceToSmallVector(*shape); return mlirTpuVectorLayoutEquivalentTo( - self, other, {shape_vec.data(), shape_vec.size()}, - TARGET_SHAPE); + self.layout, other.layout, + {shape_vec.data(), shape_vec.size()}, TARGET_SHAPE); } - return mlirTpuVectorLayoutEquivalentTo(self, other, {nullptr, 0}, - TARGET_SHAPE); + return mlirTpuVectorLayoutEquivalentTo(self.layout, other.layout, + {nullptr, 0}, TARGET_SHAPE); }, - py::arg("other"), py::arg("shape") = std::nullopt, + nb::arg("other"), nb::arg("shape").none() = std::nullopt, "Returns True if the two layouts are equivalent.\n" "\n" "That is, when all potential vector entries where the value can be " @@ -619,55 +631,65 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) { " shape: An optional shape of the vector to which both layouts " "apply. More layouts are considered equivalent when the shape is " "specified. Also see the docstring of the generalizes method.") - .def("__eq__", mlirTpuVectorLayoutEquals) - .def("__repr__", - [](MlirTpuVectorLayout self) { - std::string str; - mlirTpuVectorLayoutPrint(self, printToString, &str); - return str; - }); + .def("__eq__", + [](const PyTpuVectorLayout& self, const PyTpuVectorLayout& other) { + return mlirTpuVectorLayoutEquals(self.layout, other.layout); + }) + .def("__repr__", [](const PyTpuVectorLayout& self) { + std::string str; + mlirTpuVectorLayoutPrint(self.layout, printToString, &str); + return str; + }); // TODO(tlongeri): Can we make the first parameter a VectorType? m.def("assemble", - [](const MlirType ty, MlirTpuVectorLayout layout, - // TODO(tlongeri): Remove py::array::c_style, I only added it because - // I couldn't find a simple way to iterate over array data, but it - // causes yet another unnecessary copy. - py::array_t np_arr) -> MlirOperation { + [](const MlirType ty, const PyTpuVectorLayout& layout, + nb::object np_arr_obj) -> MlirOperation { + // TODO(tlongeri): Remove nb::array::c_style, I only added it because + // I couldn't find a simple way to iterate over array data, but it + // causes yet another unnecessary copy. 
+ xla::nb_numpy_ndarray np_arr = xla::nb_numpy_ndarray::ensure( + np_arr_obj, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_ALIGNED); if (!mlirTypeIsAVector(ty)) { - throw py::type_error("Expected vector type"); + throw nb::type_error("Expected vector type"); } llvm::SmallVector vals(np_arr.size()); for (int64_t i = 0; i < np_arr.size(); ++i) { - vals.data()[i] = py::cast(py::handle(np_arr.data()[i])); + vals.data()[i] = nb::cast(nb::handle( + reinterpret_cast(np_arr.data())[i])); } llvm::SmallVector shape(np_arr.ndim()); for (int64_t i = 0; i < np_arr.ndim(); ++i) { shape.data()[i] = np_arr.shape()[i]; } return mlirTpuAssemble( - getDefaultInsertionPoint(), ty, layout, + getDefaultInsertionPoint(), ty, layout.layout, MlirTpuValueArray{MlirTpuI64ArrayRef{shape.data(), shape.size()}, vals.data()}, TARGET_SHAPE); }); - m.def("disassemble", [](MlirTpuVectorLayout layout, MlirValue val) { + m.def("disassemble", [](const PyTpuVectorLayout& layout, MlirValue val) { DiagnosticCapture diag_capture(getDefaultContext()); - MlirTpuValueArray val_arr = mlirTpuDisassemble(getDefaultInsertionPoint(), - layout, val, TARGET_SHAPE); + MlirTpuValueArray val_arr = mlirTpuDisassemble( + getDefaultInsertionPoint(), layout.layout, val, TARGET_SHAPE); if (val_arr.vals == nullptr) { diag_capture.throwIfError(); - throw py::value_error("Failed to disassemble"); + throw nb::value_error("Failed to disassemble"); } - py::array_t np_vals( - llvm::ArrayRef{val_arr.shape.ptr, val_arr.shape.size}); + xla::nb_numpy_ndarray np_vals( + /*dtype=*/xla::nb_dtype("O"), + /*shape=*/ + absl::Span(val_arr.shape.ptr, val_arr.shape.size), + /*strides=*/std::nullopt); for (ssize_t i = 0; i < np_vals.size(); ++i) { - np_vals.mutable_data()[i] = py::cast(val_arr.vals[i]).release().ptr(); + reinterpret_cast(np_vals.mutable_data())[i] = + nb::cast(val_arr.vals[i]).release().ptr(); } free(val_arr.shape.ptr); free(val_arr.vals); return np_vals; }); + m.def("apply_layout_op", [](int hardware_generation, const MlirOperation c_op) { DiagnosticCapture diag_capture(getDefaultContext()); @@ -678,25 +700,27 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) { throw std::runtime_error("applyLayoutOp failed"); } }); - m.def("relayout", - [](MlirValue v, MlirTpuVectorLayout src, MlirTpuVectorLayout dst, - MlirTpuApplyVectorLayoutContext apply_layout_ctx) { - DiagnosticCapture diag_capture(getDefaultContext()); - MlirValue new_v = mlirTpuRelayout(getDefaultInsertionPoint(), v, src, - dst, apply_layout_ctx); - if (new_v.ptr == nullptr) { - diag_capture.throwIfError(); - throw py::value_error("Failed to relayout"); - } - return new_v; - }); - py::register_exception_translator([](std::exception_ptr p) { - try { - if (p) std::rethrow_exception(p); - } catch (const NotImplementedException& e) { - PyErr_SetString(PyExc_NotImplementedError, e.what()); + m.def("relayout", [](MlirValue v, const PyTpuVectorLayout& src, + const PyTpuVectorLayout& dst, + MlirTpuApplyVectorLayoutContext apply_layout_ctx) { + DiagnosticCapture diag_capture(getDefaultContext()); + MlirValue new_v = mlirTpuRelayout(getDefaultInsertionPoint(), v, src.layout, + dst.layout, apply_layout_ctx); + if (new_v.ptr == nullptr) { + diag_capture.throwIfError(); + throw nb::value_error("Failed to relayout"); } + return new_v; }); + nb::register_exception_translator( + [](const std::exception_ptr& p, void*) { + try { + if (p) std::rethrow_exception(p); + } catch (const NotImplementedException& e) { + PyErr_SetString(PyExc_NotImplementedError, e.what()); + } + }, + nullptr); m.def( "register_dialect", 
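A standalone sketch of the nanobind exception-translator registration used in the hunk above (the module name _translator_demo and the trivial not_done binding are illustrative only): the translator callback now receives the exception_ptr plus an opaque payload pointer, whereas pybind11's callback took only the exception_ptr.

#include <nanobind/nanobind.h>

#include <exception>
#include <stdexcept>

namespace nb = nanobind;

namespace {
class NotImplementedException : public std::runtime_error {
  using runtime_error::runtime_error;
};
}  // namespace

NB_MODULE(_translator_demo, m) {
  // The captureless lambda converts to the function-pointer type nanobind
  // expects; the second argument is the payload forwarded to the callback.
  nb::register_exception_translator(
      [](const std::exception_ptr& p, void* /*payload*/) {
        try {
          if (p) std::rethrow_exception(p);
        } catch (const NotImplementedException& e) {
          PyErr_SetString(PyExc_NotImplementedError, e.what());
        }
      },
      nullptr);

  m.def("not_done",
        [] { throw NotImplementedException("not implemented yet"); });
}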
@@ -707,21 +731,25 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) { mlirDialectHandleLoadDialect(dialect, context); } }, - py::arg("context"), py::arg("load") = true); + nb::arg("context"), nb::arg("load") = true); m.def("private_is_tiled_layout", [](MlirAttribute attr) { return mlirTPUAttributeIsATiledLayoutAttr(attr); }); - m.def("private_get_tiles", [](MlirAttribute attr) -> py::object { + m.def("private_get_tiles", [](MlirAttribute attr) -> nb::object { MlirAttribute encoded_tiles = mlirTPUTiledLayoutAttrGetTiles(attr); - py::tuple py_tiles(mlirArrayAttrGetNumElements(encoded_tiles)); + nb::tuple py_tiles = nb::steal( + PyTuple_New(mlirArrayAttrGetNumElements(encoded_tiles))); for (intptr_t i = 0; i < mlirArrayAttrGetNumElements(encoded_tiles); ++i) { MlirAttribute tile = mlirArrayAttrGetElement(encoded_tiles, i); - py::tuple py_tile(mlirDenseArrayGetNumElements(tile)); + nb::tuple py_tile = + nb::steal(PyTuple_New(mlirDenseArrayGetNumElements(tile))); for (intptr_t j = 0; j < mlirDenseArrayGetNumElements(tile); ++j) { - py_tile[j] = mlirDenseI64ArrayGetElement(tile, j); + PyTuple_SET_ITEM( + py_tile.ptr(), j, + nb::cast(mlirDenseI64ArrayGetElement(tile, j)).release().ptr()); } - py_tiles[i] = py_tile; + PyTuple_SET_ITEM(py_tiles.ptr(), i, py_tile.release().ptr()); } return py_tiles; }); @@ -737,7 +765,7 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) { m.def("private_replace_all_uses_with", [](MlirOperation op, std::vector vals) { if (vals.size() != mlirOperationGetNumResults(op)) { - throw py::value_error("length mismatch in replace_all_uses_with"); + throw nb::value_error("length mismatch in replace_all_uses_with"); } for (int i = 0; i < vals.size(); ++i) { mlirValueReplaceAllUsesOfWith(mlirOperationGetResult(op, i), vals[i]); @@ -747,7 +775,7 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) { [](MlirValue old, MlirValue new_val, MlirOperation except) { for (intptr_t i = 0; i < mlirOperationGetNumOperands(except); ++i) { if (mlirValueEqual(mlirOperationGetOperand(except, i), new_val)) { - throw py::value_error("new val already used in except"); + throw nb::value_error("new val already used in except"); } } mlirValueReplaceAllUsesOfWith(old, new_val); @@ -785,7 +813,7 @@ PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) { }); m.def("private_move_all_regions", [](MlirOperation src, MlirOperation dst) { if (mlirOperationGetNumRegions(src) != mlirOperationGetNumRegions(dst)) { - throw py::value_error( + throw nb::value_error( "Region counts do not match in src operation and dst operations"); } for (intptr_t i = 0; i < mlirOperationGetNumRegions(src); ++i) { diff --git a/jaxlib/mlir/_mlir_libs/triton_ext.cc b/jaxlib/mlir/_mlir_libs/triton_ext.cc index e02e4f3d8..2a13c40d9 100644 --- a/jaxlib/mlir/_mlir_libs/triton_ext.cc +++ b/jaxlib/mlir/_mlir_libs/triton_ext.cc @@ -16,13 +16,13 @@ limitations under the License. #include #include "mlir-c/IR.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" -#include "pybind11/detail/common.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "nanobind/nanobind.h" #include "jaxlib/triton/triton_dialect_capi.h" -namespace py = pybind11; +namespace nb = nanobind; -PYBIND11_MODULE(_triton_ext, m, py::mod_gil_not_used()) { +NB_MODULE(_triton_ext, m) { // // Dialects. // @@ -36,20 +36,20 @@ PYBIND11_MODULE(_triton_ext, m, py::mod_gil_not_used()) { mlirDialectHandleLoadDialect(dialect, context); } }, - py::arg("context"), py::arg("load") = true); + nb::arg("context"), nb::arg("load") = true); // // Types. 
// - mlir::python::adaptors::mlir_type_subclass(m, "PointerType", + mlir::python::nanobind_adaptors::mlir_type_subclass(m, "PointerType", mlirTritonIsAPointer) .def_classmethod( "get", - [](py::object cls, MlirType pointee_type, int64_t address_space) { + [](nb::object cls, MlirType pointee_type, int64_t address_space) { return cls(mlirTritonPointerTypeGet(pointee_type, address_space)); }, - py::arg("cls"), py::arg("pointee_type"), py::arg("address_space"), + nb::arg("cls"), nb::arg("pointee_type"), nb::arg("address_space"), "Creates a PointerType type.") .def_property_readonly("pointee_type", [](MlirType self) { return mlirTritonPointerTypeGetPointeeType(self);
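Finally, a standalone sketch of the NB_MODULE boilerplate these extensions converge on (the module name _demo_ext and the no-op callback are illustrative, not part of this patch): the py::mod_gil_not_used() marker from PYBIND11_MODULE is dropped in this migration, and keyword arguments with defaults are spelled nb::arg("name") = value.

#include <nanobind/nanobind.h>

namespace nb = nanobind;

NB_MODULE(_demo_ext, m) {
  m.doc() = "Demonstrates the nanobind module boilerplate used by these extensions.";
  m.def(
      "register_dialect",
      [](bool load) {
        // A real extension would insert a dialect handle into an MLIR context
        // here; this stub only mirrors the calling convention.
        (void)load;
      },
      nb::arg("load") = true);
}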