From 8a6efa317d2c104ca7905a6a4d6e521a9b9ebe4c Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Sun, 6 Apr 2025 13:35:12 -0700
Subject: [PATCH] Fix deadlock when computing cached Sharding::type() values.

C++ static initialization acquires an internal mutex. It is unsafe to
call into Python code while holding that mutex, e.g., see the deadlock
in https://gist.github.com/vfdev-5/826ef16c6cbc9f4d85466e8a348c3b5a

However, in this case, there's a simpler thing we can do: eagerly
initialize the ::type() values during module initialization, rather
than on-demand.

PiperOrigin-RevId: 744508279
---
 jaxlib/xla/BUILD       |  1 +
 jaxlib/xla/sharding.cc | 38 +++++++++++++++++++++++++++++++++++++-
 jaxlib/xla/sharding.h  | 32 ++++++++++++++++----------------
 3 files changed, 54 insertions(+), 17 deletions(-)

diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD
index 35f344046..8602652cb 100644
--- a/jaxlib/xla/BUILD
+++ b/jaxlib/xla/BUILD
@@ -565,6 +565,7 @@ cc_library(
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:cord",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
diff --git a/jaxlib/xla/sharding.cc b/jaxlib/xla/sharding.cc
index b6b58b060..ff1539764 100644
--- a/jaxlib/xla/sharding.cc
+++ b/jaxlib/xla/sharding.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include 
 
+#include 
 #include 
 #include 
 #include 
@@ -28,6 +29,7 @@ limitations under the License.
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
 #include "nanobind/nanobind.h"
 #include "nanobind/stl/string.h"  // IWYU pragma: keep
 #include "nanobind/stl/string_view.h"  // IWYU pragma: keep
@@ -242,7 +244,7 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec,
   // TODO(phawkins): this leaks a reference to the check_pspec function.
   // A better way to fix this would be to move PartitionSpec and this check into
   // C++.
-  nb::object* check_pspec = [](){
+  nb::object* check_pspec = []() {
     static absl::Mutex mu;
     static nb::object* output = nullptr;
     {
@@ -262,6 +264,13 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec,
   (*check_pspec)(mesh_, spec_, manual_axes_);
 }
 
+/*static*/ PyObject* NamedSharding::type_ = nullptr;
+
+/*static*/ void NamedSharding::InitializeType() {
+  // Intentionally leaks a reference.
+  type_ = nanobind::type<NamedSharding>().inc_ref().ptr();
+}
+
 SingleDeviceSharding::SingleDeviceSharding(nb::object device,
                                            nb::object memory_kind)
     : Sharding(/*num_devices=*/1),
@@ -273,6 +282,13 @@ SingleDeviceSharding::SingleDeviceSharding(nb::object device,
   CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_);
 }
 
+/*static*/ PyObject* SingleDeviceSharding::type_ = nullptr;
+
+/*static*/ void SingleDeviceSharding::InitializeType() {
+  // Intentionally leaks a reference.
+  type_ = nanobind::type<SingleDeviceSharding>().inc_ref().ptr();
+}
+
 SingleDeviceSharding::SingleDeviceSharding(
     xla::nb_class_ptr client, xla::ifrt::DeviceListRef device_list,
     nb::object memory_kind)
@@ -295,6 +311,15 @@ PmapSharding::PmapSharding(xla::nb_numpy_ndarray devices,
       xla::make_nb_class(nb::tuple(flat_devices));
 }
 
+/*static*/ PyObject* PmapSharding::type_ = nullptr;
+
+// /*static*/ nanobind::handle PmapSharding::type() { return type_; }
+
+/*static*/ void PmapSharding::InitializeType() {
+  // Intentionally leaks a reference.
+  type_ = nanobind::type<PmapSharding>().inc_ref().ptr();
+}
+
 GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding,
                              nb::object memory_kind, nb::object device_list)
     : Sharding(/*num_devices=*/nb::len(devices.ptr())),
@@ -316,6 +341,13 @@ GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding,
   CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_);
 }
 
+/*static*/ PyObject* GSPMDSharding::type_ = nullptr;
+
+/*static*/ void GSPMDSharding::InitializeType() {
+  // Intentionally leaks a reference.
+  type_ = nanobind::type<GSPMDSharding>().inc_ref().ptr();
+}
+
 void RegisterSharding(nb::module_& m) {
   nb::class_<Sharding>(m, "Sharding").def(nb::init<>());
 
@@ -334,6 +366,7 @@ void RegisterSharding(nb::module_& m) {
       .def_prop_ro("_internal_device_list", [](const NamedSharding& s) {
        return xla::ValueOrThrow(s.internal_device_list());
      });
+  NamedSharding::InitializeType();
 
   nb::class_<SingleDeviceSharding, Sharding>(m, "SingleDeviceSharding",
                                              nb::dynamic_attr())
@@ -343,6 +376,7 @@ void RegisterSharding(nb::module_& m) {
       .def_prop_ro("_memory_kind", &SingleDeviceSharding::memory_kind)
       .def_prop_ro("_internal_device_list",
                    &SingleDeviceSharding::internal_device_list);
+  SingleDeviceSharding::InitializeType();
 
   nb::class_<PmapSharding, Sharding>(m, "PmapSharding", nb::dynamic_attr())
       .def(
@@ -357,6 +391,7 @@ void RegisterSharding(nb::module_& m) {
       .def_prop_ro("sharding_spec", &PmapSharding::sharding_spec)
       .def_prop_ro("_internal_device_list",
                    &PmapSharding::internal_device_list);
+  PmapSharding::InitializeType();
 
   nb::class_<GSPMDSharding, Sharding>(m, "GSPMDSharding", nb::dynamic_attr())
       .def(nb::init(),
@@ -372,6 +407,7 @@ void RegisterSharding(nb::module_& m) {
       .def_prop_ro("_memory_kind", &GSPMDSharding::memory_kind)
      .def_prop_ro("_internal_device_list",
                   &GSPMDSharding::internal_device_list);
+  GSPMDSharding::InitializeType();
 }
 
 }  // namespace jax
diff --git a/jaxlib/xla/sharding.h b/jaxlib/xla/sharding.h
index 698ff2ca9..4b602bd14 100644
--- a/jaxlib/xla/sharding.h
+++ b/jaxlib/xla/sharding.h
@@ -16,6 +16,8 @@ limitations under the License.
 #ifndef JAXLIB_XLA_SHARDING_H_
 #define JAXLIB_XLA_SHARDING_H_
 
+#include 
+
 #include 
 #include 
 #include 
@@ -84,10 +86,8 @@ class NamedSharding : public Sharding {
     return logical_device_ids_;
   }
 
-  static nanobind::handle type() {
-    static auto type = nanobind::type<NamedSharding>();
-    return type;
-  }
+  static nanobind::handle type() { return type_; }
+  static void InitializeType();
 
   absl::StatusOr> internal_device_list() const {
     if (internal_device_list_) {
@@ -105,6 +105,7 @@
   nanobind::object manual_axes_;
   nanobind::object logical_device_ids_;
   std::optional> internal_device_list_;
+  static PyObject* type_;
 };
 
 class SingleDeviceSharding : public Sharding {
@@ -120,10 +121,8 @@ class SingleDeviceSharding : public Sharding {
   const nanobind::object& device() const { return device_; }
   const nanobind::object& memory_kind() const { return memory_kind_; }
 
-  static nanobind::handle type() {
-    static auto type = nanobind::type<SingleDeviceSharding>();
-    return type;
-  }
+  static nanobind::handle type() { return type_; }
+  static void InitializeType();
 
   xla::nb_class_ptr internal_device_list() const {
     return internal_device_list_;
@@ -133,6 +132,8 @@
   nanobind::object device_;
   nanobind::object memory_kind_;
   xla::nb_class_ptr internal_device_list_;
+
+  static PyObject* type_;
 };
 
 // The C++ implementation of jax.PmapSharding in python. It contains a few key
@@ -147,10 +148,8 @@ class PmapSharding : public Sharding {
 
   const ShardingSpec& sharding_spec() const { return sharding_spec_; }
 
-  static nanobind::handle type() {
-    static auto type = nanobind::type<PmapSharding>();
-    return type;
-  }
+  static nanobind::handle type() { return type_; }
+  static void InitializeType();
 
   xla::nb_class_ptr internal_device_list() const {
     return internal_device_list_;
@@ -160,6 +159,7 @@
   xla::nb_numpy_ndarray devices_;
   ShardingSpec sharding_spec_;
   xla::nb_class_ptr internal_device_list_;
+  static PyObject* type_;
 };
 
 class GSPMDSharding : public Sharding {
@@ -184,10 +184,8 @@ class GSPMDSharding : public Sharding {
     return *hash_;
   }
 
-  static nanobind::handle type() {
-    static auto type = nanobind::type<GSPMDSharding>();
-    return type;
-  }
+  static nanobind::handle type() { return type_; }
+  static void InitializeType();
 
   const xla::HloSharding& hlo_sharding() const { return hlo_sharding_; }
 
@@ -234,6 +232,8 @@
   nanobind::object memory_kind_;
   std::optional hash_;
   xla::nb_class_ptr internal_device_list_;
+
+  static PyObject* type_;
 };
 
 void RegisterSharding(nanobind::module_& m);
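
For readers who want the failure mode in isolation, here is a minimal nanobind sketch
of the hazard this patch removes and of the eager-initialization pattern it adopts. It
is illustrative only: the names Widget, widget_ext, WidgetTypeLazy, WidgetType, and
widget_type are hypothetical stand-ins, not jaxlib identifiers; only the structure
mirrors the change above.

#include <Python.h>

#include "nanobind/nanobind.h"

namespace nb = nanobind;

struct Widget {};

// Hazard (the old pattern): a function-local static. The first caller runs
// nb::type<Widget>() while the compiler-generated static-initialization guard
// mutex is held, and that call goes back into the Python/nanobind runtime. If
// it ever blocks on another thread that is itself waiting on the same guard,
// the two locks are taken in opposite orders and the process deadlocks.
nb::handle WidgetTypeLazy() {
  static nb::handle type = nb::type<Widget>();  // guard mutex held here
  return type;
}

// Fix (the pattern used by this patch): cache a raw PyObject* and fill it in
// eagerly during module initialization, when the GIL is already held and no
// static-initialization guard is involved. Later calls are a plain read.
PyObject* widget_type = nullptr;  // intentionally leaked, like type_ above

nb::handle WidgetType() { return widget_type; }

NB_MODULE(widget_ext, m) {
  nb::class_<Widget>(m, "Widget").def(nb::init<>());
  // Mirrors NamedSharding::InitializeType() etc. being invoked at the end of
  // RegisterSharding(): initialize the cached type exactly once, at import.
  widget_type = nb::type<Widget>().inc_ref().ptr();
}

The trade-off is the same one the patch accepts: the cached reference is deliberately
leaked, and correctness depends on the InitializeType()-style call running during
module registration, before any code consults type().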