mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Fix deadlock when computing cached Sharding::type() values.
C++ static initialization acquires an internal mutex. It is unsafe to call into Python code while holding that mutex, e.g., see the deadlock in https://gist.github.com/vfdev-5/826ef16c6cbc9f4d85466e8a348c3b5a However, in this case, there's a simpler thing we can do: eagerly initialize the ::type() values during module initialization, rather than on-demand. PiperOrigin-RevId: 744508279
This commit is contained in:
parent
7874d79f56
commit
8a6efa317d
@ -565,6 +565,7 @@ cc_library(
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:cord",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cstdlib>
|
||||
#include <optional>
|
||||
@ -28,6 +29,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "nanobind/nanobind.h"
|
||||
#include "nanobind/stl/string.h" // IWYU pragma: keep
|
||||
#include "nanobind/stl/string_view.h" // IWYU pragma: keep
|
||||
@ -242,7 +244,7 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec,
|
||||
// TODO(phawkins): this leaks a reference to the check_pspec function.
|
||||
// A better way to fix this would be to move PartitionSpec and this check into
|
||||
// C++.
|
||||
nb::object* check_pspec = [](){
|
||||
nb::object* check_pspec = []() {
|
||||
static absl::Mutex mu;
|
||||
static nb::object* output = nullptr;
|
||||
{
|
||||
@ -262,6 +264,13 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec,
|
||||
(*check_pspec)(mesh_, spec_, manual_axes_);
|
||||
}
|
||||
|
||||
/*static*/ PyObject* NamedSharding::type_ = nullptr;
|
||||
|
||||
/*static*/ void NamedSharding::InitializeType() {
|
||||
// Intentionally leaks a reference.
|
||||
type_ = nanobind::type<NamedSharding>().inc_ref().ptr();
|
||||
}
|
||||
|
||||
SingleDeviceSharding::SingleDeviceSharding(nb::object device,
|
||||
nb::object memory_kind)
|
||||
: Sharding(/*num_devices=*/1),
|
||||
@ -273,6 +282,13 @@ SingleDeviceSharding::SingleDeviceSharding(nb::object device,
|
||||
CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_);
|
||||
}
|
||||
|
||||
/*static*/ PyObject* SingleDeviceSharding::type_ = nullptr;
|
||||
|
||||
/*static*/ void SingleDeviceSharding::InitializeType() {
|
||||
// Intentionally leaks a reference.
|
||||
type_ = nanobind::type<SingleDeviceSharding>().inc_ref().ptr();
|
||||
}
|
||||
|
||||
SingleDeviceSharding::SingleDeviceSharding(
|
||||
xla::nb_class_ptr<xla::PyClient> client,
|
||||
xla::ifrt::DeviceListRef device_list, nb::object memory_kind)
|
||||
@ -295,6 +311,15 @@ PmapSharding::PmapSharding(xla::nb_numpy_ndarray devices,
|
||||
xla::make_nb_class<PyDeviceList>(nb::tuple(flat_devices));
|
||||
}
|
||||
|
||||
/*static*/ PyObject* PmapSharding::type_ = nullptr;
|
||||
|
||||
// /*static*/ nanobind::handle PmapSharding::type() { return type_; }
|
||||
|
||||
/*static*/ void PmapSharding::InitializeType() {
|
||||
// Intentionally leaks a reference.
|
||||
type_ = nanobind::type<PmapSharding>().inc_ref().ptr();
|
||||
}
|
||||
|
||||
GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding,
|
||||
nb::object memory_kind, nb::object device_list)
|
||||
: Sharding(/*num_devices=*/nb::len(devices.ptr())),
|
||||
@ -316,6 +341,13 @@ GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding,
|
||||
CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_);
|
||||
}
|
||||
|
||||
/*static*/ PyObject* GSPMDSharding::type_ = nullptr;
|
||||
|
||||
/*static*/ void GSPMDSharding::InitializeType() {
|
||||
// Intentionally leaks a reference.
|
||||
type_ = nanobind::type<GSPMDSharding>().inc_ref().ptr();
|
||||
}
|
||||
|
||||
void RegisterSharding(nb::module_& m) {
|
||||
nb::class_<Sharding>(m, "Sharding").def(nb::init<>());
|
||||
|
||||
@ -334,6 +366,7 @@ void RegisterSharding(nb::module_& m) {
|
||||
.def_prop_ro("_internal_device_list", [](const NamedSharding& s) {
|
||||
return xla::ValueOrThrow(s.internal_device_list());
|
||||
});
|
||||
NamedSharding::InitializeType();
|
||||
|
||||
nb::class_<SingleDeviceSharding, Sharding>(m, "SingleDeviceSharding",
|
||||
nb::dynamic_attr())
|
||||
@ -343,6 +376,7 @@ void RegisterSharding(nb::module_& m) {
|
||||
.def_prop_ro("_memory_kind", &SingleDeviceSharding::memory_kind)
|
||||
.def_prop_ro("_internal_device_list",
|
||||
&SingleDeviceSharding::internal_device_list);
|
||||
SingleDeviceSharding::InitializeType();
|
||||
|
||||
nb::class_<PmapSharding, Sharding>(m, "PmapSharding", nb::dynamic_attr())
|
||||
.def(
|
||||
@ -357,6 +391,7 @@ void RegisterSharding(nb::module_& m) {
|
||||
.def_prop_ro("sharding_spec", &PmapSharding::sharding_spec)
|
||||
.def_prop_ro("_internal_device_list",
|
||||
&PmapSharding::internal_device_list);
|
||||
PmapSharding::InitializeType();
|
||||
|
||||
nb::class_<GSPMDSharding, Sharding>(m, "GSPMDSharding", nb::dynamic_attr())
|
||||
.def(nb::init<nb::sequence, xla::OpSharding, nb::object, nb::object>(),
|
||||
@ -372,6 +407,7 @@ void RegisterSharding(nb::module_& m) {
|
||||
.def_prop_ro("_memory_kind", &GSPMDSharding::memory_kind)
|
||||
.def_prop_ro("_internal_device_list",
|
||||
&GSPMDSharding::internal_device_list);
|
||||
GSPMDSharding::InitializeType();
|
||||
}
|
||||
|
||||
} // namespace jax
|
||||
|
@ -16,6 +16,8 @@ limitations under the License.
|
||||
#ifndef JAXLIB_XLA_SHARDING_H_
|
||||
#define JAXLIB_XLA_SHARDING_H_
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
@ -84,10 +86,8 @@ class NamedSharding : public Sharding {
|
||||
return logical_device_ids_;
|
||||
}
|
||||
|
||||
static nanobind::handle type() {
|
||||
static auto type = nanobind::type<NamedSharding>();
|
||||
return type;
|
||||
}
|
||||
static nanobind::handle type() { return type_; }
|
||||
static void InitializeType();
|
||||
|
||||
absl::StatusOr<xla::nb_class_ptr<PyDeviceList>> internal_device_list() const {
|
||||
if (internal_device_list_) {
|
||||
@ -105,6 +105,7 @@ class NamedSharding : public Sharding {
|
||||
nanobind::object manual_axes_;
|
||||
nanobind::object logical_device_ids_;
|
||||
std::optional<xla::nb_class_ptr<PyDeviceList>> internal_device_list_;
|
||||
static PyObject* type_;
|
||||
};
|
||||
|
||||
class SingleDeviceSharding : public Sharding {
|
||||
@ -120,10 +121,8 @@ class SingleDeviceSharding : public Sharding {
|
||||
const nanobind::object& device() const { return device_; }
|
||||
const nanobind::object& memory_kind() const { return memory_kind_; }
|
||||
|
||||
static nanobind::handle type() {
|
||||
static auto type = nanobind::type<SingleDeviceSharding>();
|
||||
return type;
|
||||
}
|
||||
static nanobind::handle type() { return type_; }
|
||||
static void InitializeType();
|
||||
|
||||
xla::nb_class_ptr<PyDeviceList> internal_device_list() const {
|
||||
return internal_device_list_;
|
||||
@ -133,6 +132,8 @@ class SingleDeviceSharding : public Sharding {
|
||||
nanobind::object device_;
|
||||
nanobind::object memory_kind_;
|
||||
xla::nb_class_ptr<PyDeviceList> internal_device_list_;
|
||||
|
||||
static PyObject* type_;
|
||||
};
|
||||
|
||||
// The C++ implementation of jax.PmapSharding in python. It contains a few key
|
||||
@ -147,10 +148,8 @@ class PmapSharding : public Sharding {
|
||||
|
||||
const ShardingSpec& sharding_spec() const { return sharding_spec_; }
|
||||
|
||||
static nanobind::handle type() {
|
||||
static auto type = nanobind::type<PmapSharding>();
|
||||
return type;
|
||||
}
|
||||
static nanobind::handle type() { return type_; }
|
||||
static void InitializeType();
|
||||
|
||||
xla::nb_class_ptr<PyDeviceList> internal_device_list() const {
|
||||
return internal_device_list_;
|
||||
@ -160,6 +159,7 @@ class PmapSharding : public Sharding {
|
||||
xla::nb_numpy_ndarray devices_;
|
||||
ShardingSpec sharding_spec_;
|
||||
xla::nb_class_ptr<PyDeviceList> internal_device_list_;
|
||||
static PyObject* type_;
|
||||
};
|
||||
|
||||
class GSPMDSharding : public Sharding {
|
||||
@ -184,10 +184,8 @@ class GSPMDSharding : public Sharding {
|
||||
return *hash_;
|
||||
}
|
||||
|
||||
static nanobind::handle type() {
|
||||
static auto type = nanobind::type<GSPMDSharding>();
|
||||
return type;
|
||||
}
|
||||
static nanobind::handle type() { return type_; }
|
||||
static void InitializeType();
|
||||
|
||||
const xla::HloSharding& hlo_sharding() const { return hlo_sharding_; }
|
||||
|
||||
@ -234,6 +232,8 @@ class GSPMDSharding : public Sharding {
|
||||
nanobind::object memory_kind_;
|
||||
std::optional<size_t> hash_;
|
||||
xla::nb_class_ptr<PyDeviceList> internal_device_list_;
|
||||
|
||||
static PyObject* type_;
|
||||
};
|
||||
|
||||
void RegisterSharding(nanobind::module_& m);
|
||||
|
Loading…
x
Reference in New Issue
Block a user