Fix deadlock when computing cached Sharding::type() values.

C++ static initialization acquires an internal mutex. It is unsafe to call into Python code while holding that mutex, e.g., see the deadlock in https://gist.github.com/vfdev-5/826ef16c6cbc9f4d85466e8a348c3b5a

However, in this case, there's a simpler thing we can do: eagerly initialize the ::type() values during module initialization, rather than on-demand.

PiperOrigin-RevId: 744508279
This commit is contained in:
Peter Hawkins 2025-04-06 13:35:12 -07:00 committed by jax authors
parent 7874d79f56
commit 8a6efa317d
3 changed files with 54 additions and 17 deletions

View File

@@ -565,6 +565,7 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",

View File

@@ -17,6 +17,7 @@ limitations under the License.
#include <Python.h>
#include <algorithm>
#include <array>
#include <cstdlib>
#include <optional>
@@ -28,6 +29,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "nanobind/nanobind.h"
#include "nanobind/stl/string.h" // IWYU pragma: keep
#include "nanobind/stl/string_view.h" // IWYU pragma: keep
@@ -242,7 +244,7 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec,
// TODO(phawkins): this leaks a reference to the check_pspec function.
// A better way to fix this would be to move PartitionSpec and this check into
// C++.
nb::object* check_pspec = [](){
nb::object* check_pspec = []() {
static absl::Mutex mu;
static nb::object* output = nullptr;
{
@@ -262,6 +264,13 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec,
(*check_pspec)(mesh_, spec_, manual_axes_);
}
/*static*/ PyObject* NamedSharding::type_ = nullptr;
/*static*/ void NamedSharding::InitializeType() {
// Intentionally leaks a reference.
type_ = nanobind::type<NamedSharding>().inc_ref().ptr();
}
SingleDeviceSharding::SingleDeviceSharding(nb::object device,
nb::object memory_kind)
: Sharding(/*num_devices=*/1),
@@ -273,6 +282,13 @@ SingleDeviceSharding::SingleDeviceSharding(nb::object device,
CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_);
}
/*static*/ PyObject* SingleDeviceSharding::type_ = nullptr;
/*static*/ void SingleDeviceSharding::InitializeType() {
// Intentionally leaks a reference.
type_ = nanobind::type<SingleDeviceSharding>().inc_ref().ptr();
}
SingleDeviceSharding::SingleDeviceSharding(
xla::nb_class_ptr<xla::PyClient> client,
xla::ifrt::DeviceListRef device_list, nb::object memory_kind)
@@ -295,6 +311,15 @@ PmapSharding::PmapSharding(xla::nb_numpy_ndarray devices,
xla::make_nb_class<PyDeviceList>(nb::tuple(flat_devices));
}
/*static*/ PyObject* PmapSharding::type_ = nullptr;
/*static*/ void PmapSharding::InitializeType() {
  // Intentionally leaks a reference: the cached type object must remain valid
  // for the lifetime of the process so PmapSharding::type() never dangles.
  type_ = nanobind::type<PmapSharding>().inc_ref().ptr();
}
GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding,
nb::object memory_kind, nb::object device_list)
: Sharding(/*num_devices=*/nb::len(devices.ptr())),
@@ -316,6 +341,13 @@ GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding,
CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_);
}
/*static*/ PyObject* GSPMDSharding::type_ = nullptr;
/*static*/ void GSPMDSharding::InitializeType() {
// Intentionally leaks a reference.
type_ = nanobind::type<GSPMDSharding>().inc_ref().ptr();
}
void RegisterSharding(nb::module_& m) {
nb::class_<Sharding>(m, "Sharding").def(nb::init<>());
@@ -334,6 +366,7 @@ void RegisterSharding(nb::module_& m) {
.def_prop_ro("_internal_device_list", [](const NamedSharding& s) {
return xla::ValueOrThrow(s.internal_device_list());
});
NamedSharding::InitializeType();
nb::class_<SingleDeviceSharding, Sharding>(m, "SingleDeviceSharding",
nb::dynamic_attr())
@@ -343,6 +376,7 @@ void RegisterSharding(nb::module_& m) {
.def_prop_ro("_memory_kind", &SingleDeviceSharding::memory_kind)
.def_prop_ro("_internal_device_list",
&SingleDeviceSharding::internal_device_list);
SingleDeviceSharding::InitializeType();
nb::class_<PmapSharding, Sharding>(m, "PmapSharding", nb::dynamic_attr())
.def(
@@ -357,6 +391,7 @@ void RegisterSharding(nb::module_& m) {
.def_prop_ro("sharding_spec", &PmapSharding::sharding_spec)
.def_prop_ro("_internal_device_list",
&PmapSharding::internal_device_list);
PmapSharding::InitializeType();
nb::class_<GSPMDSharding, Sharding>(m, "GSPMDSharding", nb::dynamic_attr())
.def(nb::init<nb::sequence, xla::OpSharding, nb::object, nb::object>(),
@@ -372,6 +407,7 @@ void RegisterSharding(nb::module_& m) {
.def_prop_ro("_memory_kind", &GSPMDSharding::memory_kind)
.def_prop_ro("_internal_device_list",
&GSPMDSharding::internal_device_list);
GSPMDSharding::InitializeType();
}
} // namespace jax

View File

@@ -16,6 +16,8 @@ limitations under the License.
#ifndef JAXLIB_XLA_SHARDING_H_
#define JAXLIB_XLA_SHARDING_H_
#include <Python.h>
#include <cstddef>
#include <optional>
#include <utility>
@@ -84,10 +86,8 @@ class NamedSharding : public Sharding {
return logical_device_ids_;
}
static nanobind::handle type() {
static auto type = nanobind::type<NamedSharding>();
return type;
}
static nanobind::handle type() { return type_; }
static void InitializeType();
absl::StatusOr<xla::nb_class_ptr<PyDeviceList>> internal_device_list() const {
if (internal_device_list_) {
@@ -105,6 +105,7 @@ class NamedSharding : public Sharding {
nanobind::object manual_axes_;
nanobind::object logical_device_ids_;
std::optional<xla::nb_class_ptr<PyDeviceList>> internal_device_list_;
static PyObject* type_;
};
class SingleDeviceSharding : public Sharding {
@@ -120,10 +121,8 @@ class SingleDeviceSharding : public Sharding {
const nanobind::object& device() const { return device_; }
const nanobind::object& memory_kind() const { return memory_kind_; }
static nanobind::handle type() {
static auto type = nanobind::type<SingleDeviceSharding>();
return type;
}
static nanobind::handle type() { return type_; }
static void InitializeType();
xla::nb_class_ptr<PyDeviceList> internal_device_list() const {
return internal_device_list_;
@@ -133,6 +132,8 @@ class SingleDeviceSharding : public Sharding {
nanobind::object device_;
nanobind::object memory_kind_;
xla::nb_class_ptr<PyDeviceList> internal_device_list_;
static PyObject* type_;
};
// The C++ implementation of jax.PmapSharding in python. It contains a few key
@@ -147,10 +148,8 @@ class PmapSharding : public Sharding {
const ShardingSpec& sharding_spec() const { return sharding_spec_; }
static nanobind::handle type() {
static auto type = nanobind::type<PmapSharding>();
return type;
}
static nanobind::handle type() { return type_; }
static void InitializeType();
xla::nb_class_ptr<PyDeviceList> internal_device_list() const {
return internal_device_list_;
@@ -160,6 +159,7 @@ class PmapSharding : public Sharding {
xla::nb_numpy_ndarray devices_;
ShardingSpec sharding_spec_;
xla::nb_class_ptr<PyDeviceList> internal_device_list_;
static PyObject* type_;
};
class GSPMDSharding : public Sharding {
@@ -184,10 +184,8 @@ class GSPMDSharding : public Sharding {
return *hash_;
}
static nanobind::handle type() {
static auto type = nanobind::type<GSPMDSharding>();
return type;
}
static nanobind::handle type() { return type_; }
static void InitializeType();
const xla::HloSharding& hlo_sharding() const { return hlo_sharding_; }
@@ -234,6 +232,8 @@ class GSPMDSharding : public Sharding {
nanobind::object memory_kind_;
std::optional<size_t> hash_;
xla::nb_class_ptr<PyDeviceList> internal_device_list_;
static PyObject* type_;
};
void RegisterSharding(nanobind::module_& m);