Fix deadlock when computing cached Sharding::type() values.

C++ static initialization acquires an internal mutex. It is unsafe to call into Python code while holding that mutex, e.g., see the deadlock in https://gist.github.com/vfdev-5/826ef16c6cbc9f4d85466e8a348c3b5a

However, in this case, there's a simpler thing we can do: eagerly initialize the ::type() values during module initialization, rather than on-demand.

PiperOrigin-RevId: 744508279
This commit is contained in:
Peter Hawkins 2025-04-06 13:35:12 -07:00 committed by jax authors
parent 7874d79f56
commit 8a6efa317d
3 changed files with 54 additions and 17 deletions

View File

@ -565,6 +565,7 @@ cc_library(
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <Python.h> #include <Python.h>
#include <algorithm>
#include <array> #include <array>
#include <cstdlib> #include <cstdlib>
#include <optional> #include <optional>
@ -28,6 +29,7 @@ limitations under the License.
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "nanobind/nanobind.h" #include "nanobind/nanobind.h"
#include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string.h" // IWYU pragma: keep
#include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep
@ -262,6 +264,13 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec,
(*check_pspec)(mesh_, spec_, manual_axes_); (*check_pspec)(mesh_, spec_, manual_axes_);
} }
/*static*/ PyObject* NamedSharding::type_ = nullptr;
/*static*/ void NamedSharding::InitializeType() {
// Intentionally leaks a reference.
type_ = nanobind::type<NamedSharding>().inc_ref().ptr();
}
SingleDeviceSharding::SingleDeviceSharding(nb::object device, SingleDeviceSharding::SingleDeviceSharding(nb::object device,
nb::object memory_kind) nb::object memory_kind)
: Sharding(/*num_devices=*/1), : Sharding(/*num_devices=*/1),
@ -273,6 +282,13 @@ SingleDeviceSharding::SingleDeviceSharding(nb::object device,
CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_);
} }
/*static*/ PyObject* SingleDeviceSharding::type_ = nullptr;
/*static*/ void SingleDeviceSharding::InitializeType() {
// Intentionally leaks a reference.
type_ = nanobind::type<SingleDeviceSharding>().inc_ref().ptr();
}
SingleDeviceSharding::SingleDeviceSharding( SingleDeviceSharding::SingleDeviceSharding(
xla::nb_class_ptr<xla::PyClient> client, xla::nb_class_ptr<xla::PyClient> client,
xla::ifrt::DeviceListRef device_list, nb::object memory_kind) xla::ifrt::DeviceListRef device_list, nb::object memory_kind)
@ -295,6 +311,15 @@ PmapSharding::PmapSharding(xla::nb_numpy_ndarray devices,
xla::make_nb_class<PyDeviceList>(nb::tuple(flat_devices)); xla::make_nb_class<PyDeviceList>(nb::tuple(flat_devices));
} }
/*static*/ PyObject* PmapSharding::type_ = nullptr;
// /*static*/ nanobind::handle PmapSharding::type() { return type_; }
/*static*/ void PmapSharding::InitializeType() {
// Intentionally leaks a reference.
type_ = nanobind::type<PmapSharding>().inc_ref().ptr();
}
GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding, GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding,
nb::object memory_kind, nb::object device_list) nb::object memory_kind, nb::object device_list)
: Sharding(/*num_devices=*/nb::len(devices.ptr())), : Sharding(/*num_devices=*/nb::len(devices.ptr())),
@ -316,6 +341,13 @@ GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding,
CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_);
} }
/*static*/ PyObject* GSPMDSharding::type_ = nullptr;
/*static*/ void GSPMDSharding::InitializeType() {
// Intentionally leaks a reference.
type_ = nanobind::type<GSPMDSharding>().inc_ref().ptr();
}
void RegisterSharding(nb::module_& m) { void RegisterSharding(nb::module_& m) {
nb::class_<Sharding>(m, "Sharding").def(nb::init<>()); nb::class_<Sharding>(m, "Sharding").def(nb::init<>());
@ -334,6 +366,7 @@ void RegisterSharding(nb::module_& m) {
.def_prop_ro("_internal_device_list", [](const NamedSharding& s) { .def_prop_ro("_internal_device_list", [](const NamedSharding& s) {
return xla::ValueOrThrow(s.internal_device_list()); return xla::ValueOrThrow(s.internal_device_list());
}); });
NamedSharding::InitializeType();
nb::class_<SingleDeviceSharding, Sharding>(m, "SingleDeviceSharding", nb::class_<SingleDeviceSharding, Sharding>(m, "SingleDeviceSharding",
nb::dynamic_attr()) nb::dynamic_attr())
@ -343,6 +376,7 @@ void RegisterSharding(nb::module_& m) {
.def_prop_ro("_memory_kind", &SingleDeviceSharding::memory_kind) .def_prop_ro("_memory_kind", &SingleDeviceSharding::memory_kind)
.def_prop_ro("_internal_device_list", .def_prop_ro("_internal_device_list",
&SingleDeviceSharding::internal_device_list); &SingleDeviceSharding::internal_device_list);
SingleDeviceSharding::InitializeType();
nb::class_<PmapSharding, Sharding>(m, "PmapSharding", nb::dynamic_attr()) nb::class_<PmapSharding, Sharding>(m, "PmapSharding", nb::dynamic_attr())
.def( .def(
@ -357,6 +391,7 @@ void RegisterSharding(nb::module_& m) {
.def_prop_ro("sharding_spec", &PmapSharding::sharding_spec) .def_prop_ro("sharding_spec", &PmapSharding::sharding_spec)
.def_prop_ro("_internal_device_list", .def_prop_ro("_internal_device_list",
&PmapSharding::internal_device_list); &PmapSharding::internal_device_list);
PmapSharding::InitializeType();
nb::class_<GSPMDSharding, Sharding>(m, "GSPMDSharding", nb::dynamic_attr()) nb::class_<GSPMDSharding, Sharding>(m, "GSPMDSharding", nb::dynamic_attr())
.def(nb::init<nb::sequence, xla::OpSharding, nb::object, nb::object>(), .def(nb::init<nb::sequence, xla::OpSharding, nb::object, nb::object>(),
@ -372,6 +407,7 @@ void RegisterSharding(nb::module_& m) {
.def_prop_ro("_memory_kind", &GSPMDSharding::memory_kind) .def_prop_ro("_memory_kind", &GSPMDSharding::memory_kind)
.def_prop_ro("_internal_device_list", .def_prop_ro("_internal_device_list",
&GSPMDSharding::internal_device_list); &GSPMDSharding::internal_device_list);
GSPMDSharding::InitializeType();
} }
} // namespace jax } // namespace jax

View File

@ -16,6 +16,8 @@ limitations under the License.
#ifndef JAXLIB_XLA_SHARDING_H_ #ifndef JAXLIB_XLA_SHARDING_H_
#define JAXLIB_XLA_SHARDING_H_ #define JAXLIB_XLA_SHARDING_H_
#include <Python.h>
#include <cstddef> #include <cstddef>
#include <optional> #include <optional>
#include <utility> #include <utility>
@ -84,10 +86,8 @@ class NamedSharding : public Sharding {
return logical_device_ids_; return logical_device_ids_;
} }
static nanobind::handle type() { static nanobind::handle type() { return type_; }
static auto type = nanobind::type<NamedSharding>(); static void InitializeType();
return type;
}
absl::StatusOr<xla::nb_class_ptr<PyDeviceList>> internal_device_list() const { absl::StatusOr<xla::nb_class_ptr<PyDeviceList>> internal_device_list() const {
if (internal_device_list_) { if (internal_device_list_) {
@ -105,6 +105,7 @@ class NamedSharding : public Sharding {
nanobind::object manual_axes_; nanobind::object manual_axes_;
nanobind::object logical_device_ids_; nanobind::object logical_device_ids_;
std::optional<xla::nb_class_ptr<PyDeviceList>> internal_device_list_; std::optional<xla::nb_class_ptr<PyDeviceList>> internal_device_list_;
static PyObject* type_;
}; };
class SingleDeviceSharding : public Sharding { class SingleDeviceSharding : public Sharding {
@ -120,10 +121,8 @@ class SingleDeviceSharding : public Sharding {
const nanobind::object& device() const { return device_; } const nanobind::object& device() const { return device_; }
const nanobind::object& memory_kind() const { return memory_kind_; } const nanobind::object& memory_kind() const { return memory_kind_; }
static nanobind::handle type() { static nanobind::handle type() { return type_; }
static auto type = nanobind::type<SingleDeviceSharding>(); static void InitializeType();
return type;
}
xla::nb_class_ptr<PyDeviceList> internal_device_list() const { xla::nb_class_ptr<PyDeviceList> internal_device_list() const {
return internal_device_list_; return internal_device_list_;
@ -133,6 +132,8 @@ class SingleDeviceSharding : public Sharding {
nanobind::object device_; nanobind::object device_;
nanobind::object memory_kind_; nanobind::object memory_kind_;
xla::nb_class_ptr<PyDeviceList> internal_device_list_; xla::nb_class_ptr<PyDeviceList> internal_device_list_;
static PyObject* type_;
}; };
// The C++ implementation of jax.PmapSharding in python. It contains a few key // The C++ implementation of jax.PmapSharding in python. It contains a few key
@ -147,10 +148,8 @@ class PmapSharding : public Sharding {
const ShardingSpec& sharding_spec() const { return sharding_spec_; } const ShardingSpec& sharding_spec() const { return sharding_spec_; }
static nanobind::handle type() { static nanobind::handle type() { return type_; }
static auto type = nanobind::type<PmapSharding>(); static void InitializeType();
return type;
}
xla::nb_class_ptr<PyDeviceList> internal_device_list() const { xla::nb_class_ptr<PyDeviceList> internal_device_list() const {
return internal_device_list_; return internal_device_list_;
@ -160,6 +159,7 @@ class PmapSharding : public Sharding {
xla::nb_numpy_ndarray devices_; xla::nb_numpy_ndarray devices_;
ShardingSpec sharding_spec_; ShardingSpec sharding_spec_;
xla::nb_class_ptr<PyDeviceList> internal_device_list_; xla::nb_class_ptr<PyDeviceList> internal_device_list_;
static PyObject* type_;
}; };
class GSPMDSharding : public Sharding { class GSPMDSharding : public Sharding {
@ -184,10 +184,8 @@ class GSPMDSharding : public Sharding {
return *hash_; return *hash_;
} }
static nanobind::handle type() { static nanobind::handle type() { return type_; }
static auto type = nanobind::type<GSPMDSharding>(); static void InitializeType();
return type;
}
const xla::HloSharding& hlo_sharding() const { return hlo_sharding_; } const xla::HloSharding& hlo_sharding() const { return hlo_sharding_; }
@ -234,6 +232,8 @@ class GSPMDSharding : public Sharding {
nanobind::object memory_kind_; nanobind::object memory_kind_;
std::optional<size_t> hash_; std::optional<size_t> hash_;
xla::nb_class_ptr<PyDeviceList> internal_device_list_; xla::nb_class_ptr<PyDeviceList> internal_device_list_;
static PyObject* type_;
}; };
void RegisterSharding(nanobind::module_& m); void RegisterSharding(nanobind::module_& m);