mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Fix deadlock when computing cached Sharding::type() values.
C++ static initialization acquires an internal mutex. It is unsafe to call into Python code while holding that mutex, e.g., see the deadlock in https://gist.github.com/vfdev-5/826ef16c6cbc9f4d85466e8a348c3b5a However, in this case, there's a simpler thing we can do: eagerly initialize the ::type() values during module initialization, rather than on-demand. PiperOrigin-RevId: 744508279
This commit is contained in:
parent
7874d79f56
commit
8a6efa317d
@ -565,6 +565,7 @@ cc_library(
|
|||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:cord",
|
"@com_google_absl//absl/strings:cord",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
"@com_google_absl//absl/synchronization",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <Python.h>
|
#include <Python.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
@ -28,6 +29,7 @@ limitations under the License.
|
|||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "absl/synchronization/mutex.h"
|
||||||
#include "nanobind/nanobind.h"
|
#include "nanobind/nanobind.h"
|
||||||
#include "nanobind/stl/string.h" // IWYU pragma: keep
|
#include "nanobind/stl/string.h" // IWYU pragma: keep
|
||||||
#include "nanobind/stl/string_view.h" // IWYU pragma: keep
|
#include "nanobind/stl/string_view.h" // IWYU pragma: keep
|
||||||
@ -262,6 +264,13 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec,
|
|||||||
(*check_pspec)(mesh_, spec_, manual_axes_);
|
(*check_pspec)(mesh_, spec_, manual_axes_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*static*/ PyObject* NamedSharding::type_ = nullptr;
|
||||||
|
|
||||||
|
/*static*/ void NamedSharding::InitializeType() {
|
||||||
|
// Intentionally leaks a reference.
|
||||||
|
type_ = nanobind::type<NamedSharding>().inc_ref().ptr();
|
||||||
|
}
|
||||||
|
|
||||||
SingleDeviceSharding::SingleDeviceSharding(nb::object device,
|
SingleDeviceSharding::SingleDeviceSharding(nb::object device,
|
||||||
nb::object memory_kind)
|
nb::object memory_kind)
|
||||||
: Sharding(/*num_devices=*/1),
|
: Sharding(/*num_devices=*/1),
|
||||||
@ -273,6 +282,13 @@ SingleDeviceSharding::SingleDeviceSharding(nb::object device,
|
|||||||
CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_);
|
CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*static*/ PyObject* SingleDeviceSharding::type_ = nullptr;
|
||||||
|
|
||||||
|
/*static*/ void SingleDeviceSharding::InitializeType() {
|
||||||
|
// Intentionally leaks a reference.
|
||||||
|
type_ = nanobind::type<SingleDeviceSharding>().inc_ref().ptr();
|
||||||
|
}
|
||||||
|
|
||||||
SingleDeviceSharding::SingleDeviceSharding(
|
SingleDeviceSharding::SingleDeviceSharding(
|
||||||
xla::nb_class_ptr<xla::PyClient> client,
|
xla::nb_class_ptr<xla::PyClient> client,
|
||||||
xla::ifrt::DeviceListRef device_list, nb::object memory_kind)
|
xla::ifrt::DeviceListRef device_list, nb::object memory_kind)
|
||||||
@ -295,6 +311,15 @@ PmapSharding::PmapSharding(xla::nb_numpy_ndarray devices,
|
|||||||
xla::make_nb_class<PyDeviceList>(nb::tuple(flat_devices));
|
xla::make_nb_class<PyDeviceList>(nb::tuple(flat_devices));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*static*/ PyObject* PmapSharding::type_ = nullptr;
|
||||||
|
|
||||||
|
// /*static*/ nanobind::handle PmapSharding::type() { return type_; }
|
||||||
|
|
||||||
|
/*static*/ void PmapSharding::InitializeType() {
|
||||||
|
// Intentionally leaks a reference.
|
||||||
|
type_ = nanobind::type<PmapSharding>().inc_ref().ptr();
|
||||||
|
}
|
||||||
|
|
||||||
GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding,
|
GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding,
|
||||||
nb::object memory_kind, nb::object device_list)
|
nb::object memory_kind, nb::object device_list)
|
||||||
: Sharding(/*num_devices=*/nb::len(devices.ptr())),
|
: Sharding(/*num_devices=*/nb::len(devices.ptr())),
|
||||||
@ -316,6 +341,13 @@ GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding,
|
|||||||
CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_);
|
CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*static*/ PyObject* GSPMDSharding::type_ = nullptr;
|
||||||
|
|
||||||
|
/*static*/ void GSPMDSharding::InitializeType() {
|
||||||
|
// Intentionally leaks a reference.
|
||||||
|
type_ = nanobind::type<GSPMDSharding>().inc_ref().ptr();
|
||||||
|
}
|
||||||
|
|
||||||
void RegisterSharding(nb::module_& m) {
|
void RegisterSharding(nb::module_& m) {
|
||||||
nb::class_<Sharding>(m, "Sharding").def(nb::init<>());
|
nb::class_<Sharding>(m, "Sharding").def(nb::init<>());
|
||||||
|
|
||||||
@ -334,6 +366,7 @@ void RegisterSharding(nb::module_& m) {
|
|||||||
.def_prop_ro("_internal_device_list", [](const NamedSharding& s) {
|
.def_prop_ro("_internal_device_list", [](const NamedSharding& s) {
|
||||||
return xla::ValueOrThrow(s.internal_device_list());
|
return xla::ValueOrThrow(s.internal_device_list());
|
||||||
});
|
});
|
||||||
|
NamedSharding::InitializeType();
|
||||||
|
|
||||||
nb::class_<SingleDeviceSharding, Sharding>(m, "SingleDeviceSharding",
|
nb::class_<SingleDeviceSharding, Sharding>(m, "SingleDeviceSharding",
|
||||||
nb::dynamic_attr())
|
nb::dynamic_attr())
|
||||||
@ -343,6 +376,7 @@ void RegisterSharding(nb::module_& m) {
|
|||||||
.def_prop_ro("_memory_kind", &SingleDeviceSharding::memory_kind)
|
.def_prop_ro("_memory_kind", &SingleDeviceSharding::memory_kind)
|
||||||
.def_prop_ro("_internal_device_list",
|
.def_prop_ro("_internal_device_list",
|
||||||
&SingleDeviceSharding::internal_device_list);
|
&SingleDeviceSharding::internal_device_list);
|
||||||
|
SingleDeviceSharding::InitializeType();
|
||||||
|
|
||||||
nb::class_<PmapSharding, Sharding>(m, "PmapSharding", nb::dynamic_attr())
|
nb::class_<PmapSharding, Sharding>(m, "PmapSharding", nb::dynamic_attr())
|
||||||
.def(
|
.def(
|
||||||
@ -357,6 +391,7 @@ void RegisterSharding(nb::module_& m) {
|
|||||||
.def_prop_ro("sharding_spec", &PmapSharding::sharding_spec)
|
.def_prop_ro("sharding_spec", &PmapSharding::sharding_spec)
|
||||||
.def_prop_ro("_internal_device_list",
|
.def_prop_ro("_internal_device_list",
|
||||||
&PmapSharding::internal_device_list);
|
&PmapSharding::internal_device_list);
|
||||||
|
PmapSharding::InitializeType();
|
||||||
|
|
||||||
nb::class_<GSPMDSharding, Sharding>(m, "GSPMDSharding", nb::dynamic_attr())
|
nb::class_<GSPMDSharding, Sharding>(m, "GSPMDSharding", nb::dynamic_attr())
|
||||||
.def(nb::init<nb::sequence, xla::OpSharding, nb::object, nb::object>(),
|
.def(nb::init<nb::sequence, xla::OpSharding, nb::object, nb::object>(),
|
||||||
@ -372,6 +407,7 @@ void RegisterSharding(nb::module_& m) {
|
|||||||
.def_prop_ro("_memory_kind", &GSPMDSharding::memory_kind)
|
.def_prop_ro("_memory_kind", &GSPMDSharding::memory_kind)
|
||||||
.def_prop_ro("_internal_device_list",
|
.def_prop_ro("_internal_device_list",
|
||||||
&GSPMDSharding::internal_device_list);
|
&GSPMDSharding::internal_device_list);
|
||||||
|
GSPMDSharding::InitializeType();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jax
|
} // namespace jax
|
||||||
|
@ -16,6 +16,8 @@ limitations under the License.
|
|||||||
#ifndef JAXLIB_XLA_SHARDING_H_
|
#ifndef JAXLIB_XLA_SHARDING_H_
|
||||||
#define JAXLIB_XLA_SHARDING_H_
|
#define JAXLIB_XLA_SHARDING_H_
|
||||||
|
|
||||||
|
#include <Python.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@ -84,10 +86,8 @@ class NamedSharding : public Sharding {
|
|||||||
return logical_device_ids_;
|
return logical_device_ids_;
|
||||||
}
|
}
|
||||||
|
|
||||||
static nanobind::handle type() {
|
static nanobind::handle type() { return type_; }
|
||||||
static auto type = nanobind::type<NamedSharding>();
|
static void InitializeType();
|
||||||
return type;
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::StatusOr<xla::nb_class_ptr<PyDeviceList>> internal_device_list() const {
|
absl::StatusOr<xla::nb_class_ptr<PyDeviceList>> internal_device_list() const {
|
||||||
if (internal_device_list_) {
|
if (internal_device_list_) {
|
||||||
@ -105,6 +105,7 @@ class NamedSharding : public Sharding {
|
|||||||
nanobind::object manual_axes_;
|
nanobind::object manual_axes_;
|
||||||
nanobind::object logical_device_ids_;
|
nanobind::object logical_device_ids_;
|
||||||
std::optional<xla::nb_class_ptr<PyDeviceList>> internal_device_list_;
|
std::optional<xla::nb_class_ptr<PyDeviceList>> internal_device_list_;
|
||||||
|
static PyObject* type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class SingleDeviceSharding : public Sharding {
|
class SingleDeviceSharding : public Sharding {
|
||||||
@ -120,10 +121,8 @@ class SingleDeviceSharding : public Sharding {
|
|||||||
const nanobind::object& device() const { return device_; }
|
const nanobind::object& device() const { return device_; }
|
||||||
const nanobind::object& memory_kind() const { return memory_kind_; }
|
const nanobind::object& memory_kind() const { return memory_kind_; }
|
||||||
|
|
||||||
static nanobind::handle type() {
|
static nanobind::handle type() { return type_; }
|
||||||
static auto type = nanobind::type<SingleDeviceSharding>();
|
static void InitializeType();
|
||||||
return type;
|
|
||||||
}
|
|
||||||
|
|
||||||
xla::nb_class_ptr<PyDeviceList> internal_device_list() const {
|
xla::nb_class_ptr<PyDeviceList> internal_device_list() const {
|
||||||
return internal_device_list_;
|
return internal_device_list_;
|
||||||
@ -133,6 +132,8 @@ class SingleDeviceSharding : public Sharding {
|
|||||||
nanobind::object device_;
|
nanobind::object device_;
|
||||||
nanobind::object memory_kind_;
|
nanobind::object memory_kind_;
|
||||||
xla::nb_class_ptr<PyDeviceList> internal_device_list_;
|
xla::nb_class_ptr<PyDeviceList> internal_device_list_;
|
||||||
|
|
||||||
|
static PyObject* type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// The C++ implementation of jax.PmapSharding in python. It contains a few key
|
// The C++ implementation of jax.PmapSharding in python. It contains a few key
|
||||||
@ -147,10 +148,8 @@ class PmapSharding : public Sharding {
|
|||||||
|
|
||||||
const ShardingSpec& sharding_spec() const { return sharding_spec_; }
|
const ShardingSpec& sharding_spec() const { return sharding_spec_; }
|
||||||
|
|
||||||
static nanobind::handle type() {
|
static nanobind::handle type() { return type_; }
|
||||||
static auto type = nanobind::type<PmapSharding>();
|
static void InitializeType();
|
||||||
return type;
|
|
||||||
}
|
|
||||||
|
|
||||||
xla::nb_class_ptr<PyDeviceList> internal_device_list() const {
|
xla::nb_class_ptr<PyDeviceList> internal_device_list() const {
|
||||||
return internal_device_list_;
|
return internal_device_list_;
|
||||||
@ -160,6 +159,7 @@ class PmapSharding : public Sharding {
|
|||||||
xla::nb_numpy_ndarray devices_;
|
xla::nb_numpy_ndarray devices_;
|
||||||
ShardingSpec sharding_spec_;
|
ShardingSpec sharding_spec_;
|
||||||
xla::nb_class_ptr<PyDeviceList> internal_device_list_;
|
xla::nb_class_ptr<PyDeviceList> internal_device_list_;
|
||||||
|
static PyObject* type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class GSPMDSharding : public Sharding {
|
class GSPMDSharding : public Sharding {
|
||||||
@ -184,10 +184,8 @@ class GSPMDSharding : public Sharding {
|
|||||||
return *hash_;
|
return *hash_;
|
||||||
}
|
}
|
||||||
|
|
||||||
static nanobind::handle type() {
|
static nanobind::handle type() { return type_; }
|
||||||
static auto type = nanobind::type<GSPMDSharding>();
|
static void InitializeType();
|
||||||
return type;
|
|
||||||
}
|
|
||||||
|
|
||||||
const xla::HloSharding& hlo_sharding() const { return hlo_sharding_; }
|
const xla::HloSharding& hlo_sharding() const { return hlo_sharding_; }
|
||||||
|
|
||||||
@ -234,6 +232,8 @@ class GSPMDSharding : public Sharding {
|
|||||||
nanobind::object memory_kind_;
|
nanobind::object memory_kind_;
|
||||||
std::optional<size_t> hash_;
|
std::optional<size_t> hash_;
|
||||||
xla::nb_class_ptr<PyDeviceList> internal_device_list_;
|
xla::nb_class_ptr<PyDeviceList> internal_device_list_;
|
||||||
|
|
||||||
|
static PyObject* type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
void RegisterSharding(nanobind::module_& m);
|
void RegisterSharding(nanobind::module_& m);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user