rocm_jax/jaxlib/xla/py_values.cc

/* Copyright 2020 The JAX Authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "jaxlib/xla/py_values.h"
#include <Python.h>
#include <cstdint>
#include <exception>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "nanobind/nanobind.h"
#include "nanobind/stl/complex.h" // IWYU pragma: keep
#include "nanobind/stl/string_view.h" // IWYU pragma: keep
#include "jaxlib/xla/py_array.h"
#include "jaxlib/xla/python_ref_manager.h"
#include "jaxlib/xla/sharding.h"
#include "xla/primitive_util.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/device_list.h"
#include "xla/python/ifrt/dtype.h"
#include "xla/python/ifrt/memory.h"
#include "xla/python/ifrt/shape.h"
#include "xla/python/ifrt/sharding.h"
#include "xla/python/nb_numpy.h"
#include "xla/python/pjrt_ifrt/pjrt_dtype.h"
#include "xla/python/types.h"
#include "xla/shape.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/python/lib/core/numpy.h"
#include "xla/types.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/ml_dtypes.h"
#include "tsl/profiler/lib/traceme.h"
namespace nb = nanobind;
namespace xla {
namespace {
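
// Copies the elements of a numpy StringDType array into absl::Cords, which
// HandleStringNumpyArray below hands to MakeArrayFromHostBuffer as an ifrt
// kString host buffer.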
absl::StatusOr<std::vector<absl::Cord>> StringDTypeArrayToCords(
PyArrayObject* py_array_obj) {
if (PyArray_SIZE(py_array_obj) == 0) {
return absl::InvalidArgumentError("empty numpy array");
}
std::vector<absl::Cord> cords;
cords.reserve(PyArray_SIZE(py_array_obj));
auto iter =
nb::steal(PyArray_IterNew(reinterpret_cast<PyObject*>(py_array_obj)));
while (PyArray_ITER_NOTDONE(iter.ptr())) {
auto* iter_data = PyArray_ITER_DATA(iter.ptr());
PyObject* item = PyArray_GETITEM(py_array_obj, static_cast<char*>(iter_data));
if (!item) {
return absl::InternalError(
"Failed to get elements out of the ndarray iter.");
}
// PyArray_GETITEM returns a new reference; make sure it is released.
nb::object item_obj = nb::steal(item);
Py_ssize_t len;
const char* str = PyUnicode_AsUTF8AndSize(item, &len);
if (!str) {
return absl::InternalError(
"Failed to convert ndarray element to a UTF-8 string.");
}
// absl::Cord copies the bytes, so the unicode buffer need not outlive item.
cords.push_back(absl::Cord(absl::string_view(str, len)));
PyArray_ITER_NEXT(iter.ptr());
}
return cords;
}
using DevicePutFunc = std::function<absl::StatusOr<DevicePutResultFn>(
nb::handle, ifrt::Client*, ifrt::Device*, const DevicePutOptions& options,
ifrt::MemoryKind to_memory_kind)>;
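
// Converts a Python scalar (bool, float, or complex) into a rank-0 device
// array. When options.squash_64bit_types is set, the value is narrowed from
// T to SquashedT (e.g. double -> float) before the transfer.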
template <typename T, typename SquashedT>
absl::StatusOr<DevicePutResultFn> HandlePythonScalar(
nb::handle obj, ifrt::Client* client, ifrt::Device* to_device,
const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) {
T value;
try {
value = nb::cast<T>(obj);
} catch (const std::exception& e) {
return InvalidArgument(
"Unable to convert Python scalar to %s. This most likely means the "
"value (%s) overflows the range of the type.",
PrimitiveType_Name(primitive_util::NativeToPrimitiveType<T>()),
nb::cast<absl::string_view>(nb::repr(obj)));
}
std::variant<T, SquashedT> data;
Shape shape;
PrimitiveType type;
if (std::is_same<T, SquashedT>() || !options.squash_64bit_types) {
data.template emplace<0>(value);
type = primitive_util::NativeToPrimitiveType<T>();
} else {
// TODO(phawkins): we should check for overflow here, e.g., because of bugs
// like https://github.com/google/jax/issues/2006
data.template emplace<1>(static_cast<SquashedT>(value));
type = primitive_util::NativeToPrimitiveType<SquashedT>();
}
return [client, data, type, to_device, to_memory_kind,
options]() -> absl::StatusOr<DevicePutResult> {
const void* ptr = std::visit(
[](const auto& v) { return static_cast<const void*>(&v); }, data);
TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(type));
// TODO(yashkatariya): Plumb sharding or memory_kind here.
TF_ASSIGN_OR_RETURN(
auto ifrt_array,
client->MakeArrayFromHostBuffer(
ptr, ifrt_dtype, /*shape=*/ifrt::Shape({}), /*byte_strides=*/{},
ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind),
ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall,
/*on_done_with_host_buffer=*/{}, options.ifrt_user_context));
return DevicePutResult(std::move(ifrt_array), /*weak_type=*/true);
};
}
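
// Python ints get their own handler so that, when 64-bit types are squashed,
// the value is range-checked directly against int32 instead of first being
// converted through int64.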
absl::StatusOr<DevicePutResultFn> HandlePythonInt(
nb::handle obj, ifrt::Client* client, ifrt::Device* to_device,
const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) {
PrimitiveType type;
std::variant<int64_t, int32_t> data;
if (options.squash_64bit_types) {
try {
data.emplace<1>(nb::cast<int32_t>(obj));
} catch (const std::exception& e) {
return InvalidArgument(
"Unable to convert Python scalar to %s. This most likely means the "
"value (%s) overflows the range of the type.",
PrimitiveType_Name(primitive_util::NativeToPrimitiveType<int32_t>()),
nb::cast<absl::string_view>(nb::repr(obj)));
}
type = S32;
} else {
try {
data.emplace<0>(nb::cast<int64_t>(obj));
} catch (const std::exception& e) {
return InvalidArgument(
"Unable to convert Python scalar to %s. This most likely means the "
"value (%s) overflows the range of the type.",
PrimitiveType_Name(primitive_util::NativeToPrimitiveType<int64_t>()),
nb::cast<absl::string_view>(nb::repr(obj)));
}
type = S64;
}
return [client, data, type, to_device, to_memory_kind,
options]() -> absl::StatusOr<DevicePutResult> {
const void* ptr = std::visit(
[](const auto& v) { return static_cast<const void*>(&v); }, data);
TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(type));
// TODO(yashkatariya): Plumb sharding or memory_kind here.
TF_ASSIGN_OR_RETURN(
auto ifrt_array,
client->MakeArrayFromHostBuffer(
ptr, ifrt_dtype, /*shape=*/xla::ifrt::Shape({}),
/*byte_strides=*/{},
ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind),
ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall,
/*on_done_with_host_buffer=*/nullptr, options.ifrt_user_context));
return DevicePutResult(std::move(ifrt_array), /*weak_type=*/true);
};
}
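
// Converts a numpy scalar into a rank-0 device array. For the extension
// types below, PyArray_ScalarAsCtype yields a pointer into the Python
// object, so a managed reference keeps the scalar alive until the runtime is
// done with the host buffer.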
template <typename T, typename SquashedT = T>
absl::StatusOr<DevicePutResultFn> HandleNumpyScalar(
nb::handle h, ifrt::Client* client, ifrt::Device* to_device,
const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) {
std::variant<T, SquashedT, void*> data;
PrimitiveType type;
// For extension types, ScalarAsCtype returns a pointer to the data.
if (std::is_same<T, xla::s2>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = S2;
} else if (std::is_same<T, xla::s4>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = S4;
} else if (std::is_same<T, xla::u2>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = U2;
} else if (std::is_same<T, xla::u4>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = U4;
} else if (std::is_same<T, bfloat16>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = BF16;
} else if (std::is_same<T, tsl::float4_e2m1fn>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = F4E2M1FN;
} else if (std::is_same<T, tsl::float8_e3m4>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = F8E3M4;
} else if (std::is_same<T, tsl::float8_e4m3>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = F8E4M3;
} else if (std::is_same<T, tsl::float8_e4m3fn>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = F8E4M3FN;
} else if (std::is_same<T, tsl::float8_e4m3b11fnuz>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = F8E4M3B11FNUZ;
} else if (std::is_same<T, tsl::float8_e5m2>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = F8E5M2;
} else if (std::is_same<T, tsl::float8_e4m3fnuz>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = F8E4M3FNUZ;
} else if (std::is_same<T, tsl::float8_e5m2fnuz>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = F8E5M2FNUZ;
} else if (std::is_same<T, tsl::float8_e8m0fnu>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = F8E8M0FNU;
} else if (std::is_same<T, SquashedT>() || !options.squash_64bit_types) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<0>());
type = primitive_util::NativeToPrimitiveType<T>();
} else {
T value;
PyArray_ScalarAsCtype(h.ptr(), &value);
data.template emplace<1>(static_cast<SquashedT>(value));
type = primitive_util::NativeToPrimitiveType<SquashedT>();
}
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref;
if (data.index() == 2) {
py_buffer_ref =
GlobalPyRefManager()->ManageReference(nb::cast<nb::object>(h));
}
return [client, data, py_buffer_ref, type, to_device, options,
to_memory_kind]() mutable -> absl::StatusOr<DevicePutResult> {
const void* ptr = std::visit(
[](const auto& v) -> const void* {
if constexpr (std::is_same_v<std::decay_t<decltype(v)>, void*>) {
return v;
} else {
return static_cast<const void*>(&v);
}
},
data);
TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(type));
// TODO(yashkatariya): Plumb sharding or memory_kind here.
TF_ASSIGN_OR_RETURN(
auto ifrt_array,
client->MakeArrayFromHostBuffer(
ptr, ifrt_dtype, /*shape=*/xla::ifrt::Shape({}),
/*byte_strides=*/{},
ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind),
ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall,
/*on_done_with_host_buffer=*/
[py_buffer_ref = std::move(
py_buffer_ref)]() { /* keeps py_buffer_ref alive */ },
options.ifrt_user_context));
return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false);
};
}
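
// Converts a numpy StringDType array into a device array. The element
// contents are copied into Cords up front, and the Cords are kept alive by
// on_done_with_host_buffer until the transfer completes.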
absl::StatusOr<DevicePutResultFn> HandleStringNumpyArray(
nb::handle h, ifrt::Client* client, ifrt::Device* to_device,
const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) {
xla::nb_numpy_ndarray array = nb::cast<xla::nb_numpy_ndarray>(h);
auto py_array_obj = reinterpret_cast<PyArrayObject*>(array.ptr());
TF_ASSIGN_OR_RETURN(auto cords, StringDTypeArrayToCords(py_array_obj));
// Assemble the parameters for MakeArrayFromHostBuffer. `data` points at the
// vector's Cord elements; that buffer stays valid when the vector is moved
// into the cleanup lambda below.
void* data = cords.data();
// Copy the shape elements explicitly: npy_intp can be as narrow as 32 bits
// on some platforms (e.g. macos_arm64), so reinterpret-casting it to the
// required const int64_t* would run into width and endianness issues.
ifrt::Shape::Dimensions dims;
dims.reserve(array.ndim());
for (int i = 0; i < array.ndim(); ++i) {
dims.push_back(array.shape(i));
}
ifrt::Shape shape(std::move(dims));
std::shared_ptr<xla::ifrt::Sharding> sharding =
xla::ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind);
// This lambda's only purpose is to keep `cords` alive until the runtime is
// done with the host buffer.
auto on_done_with_host_buffer = [cords = std::move(cords)] {};
return [client, data = data, shape = std::move(shape),
sharding = std::move(sharding),
on_done_with_host_buffer = std::move(on_done_with_host_buffer),
options]() mutable -> absl::StatusOr<DevicePutResult> {
TF_ASSIGN_OR_RETURN(
auto ifrt_array,
client->MakeArrayFromHostBuffer(
data, ifrt::DType(ifrt::DType::kString), std::move(shape),
/*byte_strides=*/std::nullopt, std::move(sharding),
ifrt::Client::HostBufferSemantics::kImmutableUntilTransferCompletes,
std::move(on_done_with_host_buffer), options.ifrt_user_context));
return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false);
};
}
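
// Converts a numpy ndarray into a device array. With squashing enabled the
// array is first cast to the 32-bit dtype; with zero-copy allowed the device
// may alias the host buffer, and the ndarray is kept alive until the runtime
// releases it.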
absl::StatusOr<DevicePutResultFn> HandleNumpyArray(
nb::handle h, ifrt::Client* client, ifrt::Device* to_device,
const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) {
xla::nb_numpy_ndarray array = nb::cast<xla::nb_numpy_ndarray>(h);
// String numpy arrays require substantially different processing.
if (array.dtype().char_() == 'T' || array.dtype().kind() == 'T') {
return HandleStringNumpyArray(h, client, to_device, options,
to_memory_kind);
}
TF_ASSIGN_OR_RETURN(PrimitiveType type, DtypeToPrimitiveType(array.dtype()));
PrimitiveType squashed_type;
if (options.squash_64bit_types) {
squashed_type = Squash64BitTypes(type);
if (squashed_type != type) {
TF_ASSIGN_OR_RETURN(xla::nb_dtype squashed_dtype,
PrimitiveTypeToNbDtype(squashed_type));
array = nb::steal<xla::nb_numpy_ndarray>(PyArray_CastToType(
reinterpret_cast<PyArrayObject*>(array.ptr()),
reinterpret_cast<PyArray_Descr*>(squashed_dtype.release().ptr()),
/*fortran=*/0));
}
} else {
squashed_type = type;
}
absl::InlinedVector<int64_t, 4> dims(array.ndim());
absl::InlinedVector<int64_t, 4> byte_strides(array.ndim());
for (int i = 0; i < array.ndim(); ++i) {
dims[i] = array.shape(i);
byte_strides[i] = array.strides(i);
}
const void* data = array.data();
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
GlobalPyRefManager()->ManageReference(std::move(array));
return [client, data, squashed_type, dims = std::move(dims),
byte_strides = std::move(byte_strides),
py_buffer_ref = std::move(py_buffer_ref), options, to_device,
to_memory_kind]() mutable -> absl::StatusOr<DevicePutResult> {
TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(squashed_type));
ifrt::Client::HostBufferSemantics host_buffer_semantics =
ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall;
std::function<void()> on_done_with_host_buffer;
if (options.allow_zero_copy) {
on_done_with_host_buffer =
[py_buffer_ref{
std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ };
host_buffer_semantics =
ifrt::Client::HostBufferSemantics::kImmutableZeroCopy;
}
TF_ASSIGN_OR_RETURN(
auto ifrt_array,
client->MakeArrayFromHostBuffer(
data, ifrt_dtype, ifrt::Shape(dims), byte_strides,
xla::ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind),
host_buffer_semantics, std::move(on_done_with_host_buffer),
options.ifrt_user_context));
return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false);
};
}
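
// Handles an existing jax.Array. An array that is already on the target
// device (with a compatible memory kind) is reused as-is; otherwise it is
// copied via CopyArrays, or routed through the numpy path for pmap-sharded
// arrays and arrays owned by a different client.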
absl::StatusOr<DevicePutResultFn> HandlePyArray(
nb::handle obj, ifrt::Client* client, ifrt::Device* to_device,
const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) {
auto py_array = nb::borrow<PyArray>(obj);
// Only single-device PyArrays are supported in device_put.
if (py_array.num_shards() != 1) {
return InvalidArgument(
"device_put expects an array with exactly one shard, got an array "
"with %d shards.",
py_array.num_shards());
}
ifrt::Array* ifrt_array = py_array.ifrt_array();
if (ifrt_array == nullptr) {
return InvalidArgument("Array has been deleted.");
}
// Fall back to the numpy path for non-matching clients or pmap sharding.
if (py_array.sharding().type().ptr() == jax::PmapSharding::type().ptr() ||
ifrt_array->sharding().devices()->devices().front()->client() !=
to_device->client()) {
return HandleNumpyArray(obj.attr("_value"), client, to_device, options,
to_memory_kind);
}
if (ifrt_array->sharding().devices()->devices().front() == to_device &&
options.allow_zero_copy &&
(!to_memory_kind.memory_kind().has_value() ||
!ifrt_array->sharding().memory_kind().memory_kind().has_value() ||
ifrt_array->sharding().memory_kind() == to_memory_kind)) {
DevicePutResult result(tsl::FormRef(ifrt_array), py_array.weak_type(),
/*owning_pybuffer=*/nb::borrow<nb::object>(obj));
return [result = std::move(result)]() mutable { return std::move(result); };
} else {
return [ifrt_array = tsl::FormRef(ifrt_array), to_device, to_memory_kind,
owning_pybuffer = py_array.weak_type(),
allow_zero_copy = options.allow_zero_copy]() mutable
-> absl::StatusOr<DevicePutResult> {
auto* ifrt_client = ifrt_array->client();
TF_ASSIGN_OR_RETURN(
auto copied_ifrt_arrays,
ifrt_client->CopyArrays(
absl::MakeSpan(&ifrt_array, 1),
ifrt_client->MakeDeviceList({to_device}), to_memory_kind,
allow_zero_copy ? ifrt::ArrayCopySemantics::kReuseInput
: ifrt::ArrayCopySemantics::kAlwaysCopy));
return DevicePutResult(std::move(copied_ifrt_arrays[0]),
std::move(owning_pybuffer));
};
}
}
} // namespace
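
// Dispatches `arg` to a handler keyed on its Python type, falling back to
// the type's MRO for subclasses. A minimal usage sketch (the exact invocable
// type is declared in py_values.h; the returned function performs the actual
// host-to-device transfer when invoked):
//
//   TF_ASSIGN_OR_RETURN(auto transfer,
//                       DevicePut(obj, client, device, options, memory_kind));
//   TF_ASSIGN_OR_RETURN(DevicePutResult result, std::move(transfer)());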
absl::StatusOr<DevicePutResultFn> DevicePut(nb::handle arg,
ifrt::Client* client,
ifrt::Device* to_device,
const DevicePutOptions& options,
ifrt::MemoryKind to_memory_kind) {
tsl::profiler::TraceMe traceme("DevicePut");
static const absl::flat_hash_map<PyObject*, DevicePutFunc>* const handlers =
[] {
auto p = new absl::flat_hash_map<PyObject*, DevicePutFunc>();
const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();
// Python scalar types.
static_assert(sizeof(bool) == 1,
"Conversion code assumes bool is 1 byte");
(*p)[reinterpret_cast<PyObject*>(&PyBool_Type)] =
HandlePythonScalar<bool, bool>;
(*p)[reinterpret_cast<PyObject*>(&PyLong_Type)] = HandlePythonInt;
(*p)[reinterpret_cast<PyObject*>(&PyFloat_Type)] =
HandlePythonScalar<double, float>;
(*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] =
HandlePythonScalar<complex128, complex64>;
(*p)[reinterpret_cast<PyObject*>(&PyArray_Type)] = HandleNumpyArray;
// Numpy scalar types. For some of them, we share the handler with
// Python types (np_int64, np_float64, np_complex128).
(*p)[dtypes.np_bool.ptr()] = HandleNumpyScalar<bool>;
(*p)[dtypes.np_int4.ptr()] = HandleNumpyScalar<xla::s4>;
if (dtypes.np_int2.has_value()) {
(*p)[dtypes.np_int2->ptr()] = HandleNumpyScalar<xla::s2>;
}
(*p)[dtypes.np_int8.ptr()] = HandleNumpyScalar<int8_t>;
(*p)[dtypes.np_int16.ptr()] = HandleNumpyScalar<int16_t>;
(*p)[dtypes.np_int32.ptr()] = HandleNumpyScalar<int32_t>;
(*p)[dtypes.np_int64.ptr()] = HandleNumpyScalar<int64_t, int32_t>;
if (dtypes.np_uint2.has_value()) {
(*p)[dtypes.np_uint2->ptr()] = HandleNumpyScalar<xla::u2>;
}
(*p)[dtypes.np_uint4.ptr()] = HandleNumpyScalar<xla::u4>;
(*p)[dtypes.np_uint8.ptr()] = HandleNumpyScalar<uint8_t>;
(*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar<uint16_t>;
(*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar<uint32_t>;
(*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar<uint64_t, uint32_t>;
if (dtypes.np_float4_e2m1fn.has_value()) {
(*p)[dtypes.np_float4_e2m1fn->ptr()] =
HandleNumpyScalar<tsl::float4_e2m1fn>;
}
if (dtypes.np_float8_e3m4.has_value()) {
(*p)[dtypes.np_float8_e3m4->ptr()] =
HandleNumpyScalar<tsl::float8_e3m4>;
}
if (dtypes.np_float8_e4m3.has_value()) {
(*p)[dtypes.np_float8_e4m3->ptr()] =
HandleNumpyScalar<tsl::float8_e4m3>;
}
(*p)[dtypes.np_float8_e4m3fn.ptr()] =
HandleNumpyScalar<tsl::float8_e4m3fn>;
(*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] =
HandleNumpyScalar<tsl::float8_e4m3b11fnuz>;
(*p)[dtypes.np_float8_e5m2.ptr()] = HandleNumpyScalar<tsl::float8_e5m2>;
(*p)[dtypes.np_float8_e4m3fnuz.ptr()] =
HandleNumpyScalar<tsl::float8_e4m3fnuz>;
(*p)[dtypes.np_float8_e5m2fnuz.ptr()] =
HandleNumpyScalar<tsl::float8_e5m2fnuz>;
if (dtypes.np_float8_e8m0fnu.has_value()) {
(*p)[dtypes.np_float8_e8m0fnu->ptr()] =
HandleNumpyScalar<tsl::float8_e8m0fnu>;
}
(*p)[dtypes.np_bfloat16.ptr()] = HandleNumpyScalar<bfloat16>;
(*p)[dtypes.np_float16.ptr()] = HandleNumpyScalar<half>;
(*p)[dtypes.np_float32.ptr()] = HandleNumpyScalar<float>;
(*p)[dtypes.np_float64.ptr()] = HandleNumpyScalar<double, float>;
(*p)[dtypes.np_complex64.ptr()] = HandleNumpyScalar<complex64>;
(*p)[dtypes.np_complex128.ptr()] =
HandleNumpyScalar<complex128, complex64>;
static_assert(sizeof(long long) == sizeof(int64_t), // NOLINT
"long long must be the same size as int64_t");
(*p)[dtypes.np_longlong.ptr()] = HandleNumpyScalar<int64_t, int32_t>;
static_assert(sizeof(int) == sizeof(int32_t),
"int must be the same size as int32_t");
(*p)[dtypes.np_intc.ptr()] = HandleNumpyScalar<int32_t>;
return p;
}();
if (arg.type().ptr() == PyArray::type().ptr()) {
return HandlePyArray(arg, client, to_device, options, to_memory_kind);
}
auto res = handlers->find(arg.type().ptr());
if (res == handlers->end()) {
for (auto base_class : arg.type().attr("__mro__")) {
res = handlers->find(base_class.ptr());
if (res != handlers->end()) {
return res->second(arg, client, to_device, options, to_memory_kind);
}
}
return InvalidArgument(
"%s", absl::StrCat(
"Not supported: The C++ jax jit execution path only accepts "
"DeviceArray, numpy arrays, scalars of supported types "
"(see implementation), or Python scalars. Got type ",
nb::cast<absl::string_view>(nb::str(arg.type()))));
}
return res->second(arg, client, to_device, options, to_memory_kind);
}
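
// Returns true if `arg` is a numpy array whose dtype is jax.dtypes.float0.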
bool IsFloat0(xla::nb_numpy_ndarray arg) {
static const auto* dtypes_module =
new nb::module_(nb::module_::import_("jax.dtypes"));
static const auto* float0_dtype =
new nb::handle(dtypes_module->attr("float0"));
return float0_dtype->is(arg.attr("dtype"));
}
std::string PyArgSignature::DebugString() const {
std::string result;
if (weak_type) {
absl::StrAppend(&result, "weak_");
}
absl::StrAppend(&result, xla::PrimitiveType_Name(dtype));
absl::StrAppend(&result, "[", absl::StrJoin(shape, ","), "]");
return result;
}
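
// PyArgSignatureOfValue computes the (dtype, shape, weak_type) signature of
// a call argument, using the same type-keyed handler table and MRO fallback
// as DevicePut above.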
using ToPyArgSignatureHandler =
std::function<absl::StatusOr<PyArgSignature>(nb::handle, bool)>;
absl::StatusOr<PyArgSignature> PyArgSignatureOfValue(nb::handle arg,
bool jax_enable_x64) {
static const absl::flat_hash_map<PyObject*, ToPyArgSignatureHandler>* const
handlers = [] {
auto p = new absl::flat_hash_map<PyObject*, ToPyArgSignatureHandler>();
const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();
// The 4 Python native types.
ToPyArgSignatureHandler bool_handler =
[](nb::handle, bool) -> absl::StatusOr<PyArgSignature> {
return PyArgSignature(PrimitiveType::PRED, {}, true);
};
ToPyArgSignatureHandler int_handler =
[](nb::handle h,
bool jax_enable_x64) -> absl::StatusOr<PyArgSignature> {
// TODO(phawkins): we should consider checking for integer overflow.
if (jax_enable_x64) {
return PyArgSignature(PrimitiveType::S64, {}, true);
} else {
return PyArgSignature(PrimitiveType::S32, {}, true);
}
};
ToPyArgSignatureHandler float_handler =
[&dtypes](nb::handle h,
bool jax_enable_x64) -> absl::StatusOr<PyArgSignature> {
// Only Python native types have a true weak_type.
bool weak_type = !nb::isinstance(h, dtypes.np_float64);
if (jax_enable_x64) {
return PyArgSignature(PrimitiveType::F64, {}, weak_type);
} else {
return PyArgSignature(PrimitiveType::F32, {}, weak_type);
}
};
ToPyArgSignatureHandler complex_handler =
[&dtypes](nb::handle h,
bool jax_enable_x64) -> absl::StatusOr<PyArgSignature> {
// Note that this branch is also taken for np.complex128:
// isinstance(np.complex128(3), complex) returns True
// isinstance(np.complex64(3), complex) returns False
bool weak_type = !nb::isinstance(h, dtypes.np_complex128);
if (jax_enable_x64) {
return PyArgSignature(PrimitiveType::C128, {}, weak_type);
} else {
return PyArgSignature(PrimitiveType::C64, {}, weak_type);
}
};
(*p)[reinterpret_cast<PyObject*>(&PyBool_Type)] = bool_handler;
(*p)[reinterpret_cast<PyObject*>(&PyLong_Type)] = int_handler;
(*p)[reinterpret_cast<PyObject*>(&PyFloat_Type)] = float_handler;
(*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] = complex_handler;
ToPyArgSignatureHandler numpy_handler =
[](nb::handle h,
bool jax_enable_x64) -> absl::StatusOr<PyArgSignature> {
xla::nb_numpy_ndarray numpy_array =
nb::cast<xla::nb_numpy_ndarray>(h);
TF_ASSIGN_OR_RETURN(PrimitiveType dtype,
DtypeToPrimitiveType(numpy_array.dtype()));
if (!jax_enable_x64) {
dtype = Squash64BitTypes(dtype);
}
// We use reinterpret_cast<> to defend against environments where
// ssize_t may not be precisely the same type as int64_t, even if it
// is the same size (long vs long long).
static_assert(sizeof(int64_t) == sizeof(ssize_t),
"Code assumes ssize_t is the same as int64_t");
return PyArgSignature(
dtype,
absl::MakeConstSpan(
reinterpret_cast<const int64_t*>(numpy_array.shape()),
numpy_array.ndim()),
/*weak_type=*/false);
};
(*p)[reinterpret_cast<PyObject*>(&PyArray_Type)] = numpy_handler;
ToPyArgSignatureHandler np_uint64_handler =
[](nb::handle h,
bool jax_enable_x64) -> absl::StatusOr<PyArgSignature> {
if (jax_enable_x64) {
return PyArgSignature(PrimitiveType::U64, {}, /*weak_type=*/false);
} else {
return PyArgSignature(PrimitiveType::U32, {}, /*weak_type=*/false);
}
};
ToPyArgSignatureHandler np_int_handler =
[](nb::handle h,
bool jax_enable_x64) -> absl::StatusOr<PyArgSignature> {
if (jax_enable_x64) {
return PyArgSignature(PrimitiveType::S64, {}, /*weak_type=*/false);
} else {
return PyArgSignature(PrimitiveType::S32, {}, /*weak_type=*/false);
}
};
ToPyArgSignatureHandler numpy_array_handler =
[](nb::handle h,
bool jax_enable_x64) -> absl::StatusOr<PyArgSignature> {
// Handles all numpy scalar types except the 64-bit ones (int64, uint64,
// float64, complex128), which use the specialized handlers registered
// below.
TF_ASSIGN_OR_RETURN(auto dtype,
DtypeToPrimitiveType(h.attr("dtype")));
return PyArgSignature(dtype, {}, /*weak_type=*/false);
};
// Register handlers for the numpy scalar types: the 64-bit types use the
// specialized handlers defined above (which honor jax_enable_x64),
// everything else the generic one.
(*p)[dtypes.np_bool.ptr()] = numpy_array_handler;
(*p)[dtypes.np_int4.ptr()] = numpy_array_handler;
(*p)[dtypes.np_int8.ptr()] = numpy_array_handler;
(*p)[dtypes.np_int16.ptr()] = numpy_array_handler;
(*p)[dtypes.np_int32.ptr()] = numpy_array_handler;
(*p)[dtypes.np_int64.ptr()] = np_int_handler;
(*p)[dtypes.np_uint4.ptr()] = numpy_array_handler;
(*p)[dtypes.np_uint8.ptr()] = numpy_array_handler;
(*p)[dtypes.np_uint16.ptr()] = numpy_array_handler;
(*p)[dtypes.np_uint32.ptr()] = numpy_array_handler;
(*p)[dtypes.np_uint64.ptr()] = np_uint64_handler;
// TODO(upwind): Explore if we can remove std::optional for these types
// in xla/python/types.h and xla/python/types.cc
if (dtypes.np_float4_e2m1fn.has_value()) {
(*p)[dtypes.np_float4_e2m1fn->ptr()] = numpy_array_handler;
}
if (dtypes.np_float8_e3m4.has_value()) {
(*p)[dtypes.np_float8_e3m4->ptr()] = numpy_array_handler;
}
if (dtypes.np_float8_e4m3.has_value()) {
(*p)[dtypes.np_float8_e4m3->ptr()] = numpy_array_handler;
}
(*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler;
(*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler;
(*p)[dtypes.np_float8_e4m3fnuz.ptr()] = numpy_array_handler;
(*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler;
(*p)[dtypes.np_float8_e5m2fnuz.ptr()] = numpy_array_handler;
if (dtypes.np_float8_e8m0fnu.has_value()) {
(*p)[dtypes.np_float8_e8m0fnu->ptr()] = numpy_array_handler;
}
(*p)[dtypes.np_float16.ptr()] = numpy_array_handler;
(*p)[dtypes.np_bfloat16.ptr()] = numpy_array_handler;
(*p)[dtypes.np_float32.ptr()] = numpy_array_handler;
(*p)[dtypes.np_float64.ptr()] = float_handler;
(*p)[dtypes.np_complex64.ptr()] = numpy_array_handler;
(*p)[dtypes.np_complex128.ptr()] = complex_handler;
(*p)[dtypes.np_longlong.ptr()] = np_int_handler;
(*p)[dtypes.np_intc.ptr()] = numpy_array_handler;
return p;
}();
if (arg.type().ptr() == PyArray::type().ptr()) {
auto array = nb::borrow<PyArray>(arg);
ifrt::Array* ifrt_array = array.ifrt_array();
if (ifrt_array == nullptr) {
return xla::InvalidArgument("Array has been deleted.");
}
TF_ASSIGN_OR_RETURN(auto primitive_type,
ifrt::ToPrimitiveType(ifrt_array->dtype()));
return PyArgSignature(primitive_type, array.shape(), array.weak_type());
}
auto res = handlers->find(arg.type().ptr());
if (res == handlers->end()) {
// Fall back to the type's MRO so subclasses of supported types are handled.
for (auto base_class : arg.type().attr("__mro__")) {
res = handlers->find(base_class.ptr());
if (res != handlers->end()) {
return res->second(arg, jax_enable_x64);
}
}
return InvalidArgument(
"%s",
absl::StrCat("Not supported: The C++ ToPyArgSignature only accepts "
"Buffer/DeviceArray, numpy arrays, scalars of supported "
"types (see implementation), or Python scalars. Got type ",
nb::cast<absl::string_view>(nb::str(arg.type()))));
}
return res->second(arg, jax_enable_x64);
}
} // namespace xla