mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00

Before (argument names reversed, oops, fixed in code): ``` name time/op safe_map/num_args:0/arg_lengths:1 1.43µs ± 1% safe_map/num_args:1/arg_lengths:1 1.61µs ± 1% safe_map/num_args:2/arg_lengths:1 1.72µs ± 0% safe_map/num_args:5/arg_lengths:1 2.14µs ± 1% safe_map/num_args:10/arg_lengths:1 2.87µs ± 1% safe_map/num_args:100/arg_lengths:1 15.6µs ± 1% safe_map/num_args:0/arg_lengths:2 1.65µs ± 0% safe_map/num_args:1/arg_lengths:2 1.83µs ± 1% safe_map/num_args:2/arg_lengths:2 1.97µs ± 1% safe_map/num_args:5/arg_lengths:2 2.41µs ± 1% safe_map/num_args:10/arg_lengths:2 3.22µs ± 2% safe_map/num_args:100/arg_lengths:2 17.0µs ± 2% safe_map/num_args:0/arg_lengths:3 1.83µs ± 1% safe_map/num_args:1/arg_lengths:3 2.02µs ± 1% safe_map/num_args:2/arg_lengths:3 2.16µs ± 1% safe_map/num_args:5/arg_lengths:3 2.63µs ± 1% safe_map/num_args:10/arg_lengths:3 3.48µs ± 1% safe_map/num_args:100/arg_lengths:3 18.1µs ± 1% ``` After: ``` name time/op safe_map/num_args:0/arg_lengths:1 409ns ± 1% safe_map/num_args:1/arg_lengths:1 602ns ± 5% safe_map/num_args:2/arg_lengths:1 777ns ± 4% safe_map/num_args:5/arg_lengths:1 1.21µs ± 3% safe_map/num_args:10/arg_lengths:1 1.93µs ± 2% safe_map/num_args:100/arg_lengths:1 14.7µs ± 0% safe_map/num_args:0/arg_lengths:2 451ns ± 1% safe_map/num_args:1/arg_lengths:2 652ns ± 0% safe_map/num_args:2/arg_lengths:2 850ns ± 4% safe_map/num_args:5/arg_lengths:2 1.32µs ± 3% safe_map/num_args:10/arg_lengths:2 2.11µs ± 2% safe_map/num_args:100/arg_lengths:2 16.0µs ± 1% safe_map/num_args:0/arg_lengths:3 496ns ± 1% safe_map/num_args:1/arg_lengths:3 718ns ± 5% safe_map/num_args:2/arg_lengths:3 919ns ± 4% safe_map/num_args:5/arg_lengths:3 1.43µs ± 2% safe_map/num_args:10/arg_lengths:3 2.30µs ± 2% safe_map/num_args:100/arg_lengths:3 17.3µs ± 1% ``` PiperOrigin-RevId: 523263207
124 lines
3.9 KiB
C++
124 lines
3.9 KiB
C++
/* Copyright 2023 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
|
|
|
|
#include <Python.h>
|
|
|
|
#include "pybind11/pybind11.h"
|
|
#include "absl/cleanup/cleanup.h"
|
|
#include "absl/container/inlined_vector.h"
|
|
|
|
namespace py = pybind11;
|
|
|
|
// Implements `safe_map(fn, *iterables)` using the METH_FASTCALL calling
// convention: a `map` that builds a list eagerly and requires every iterable
// to yield the same number of elements.
//
// Arguments:
//   self:  unused.
//   args:  args[0] is the callable; args[1..nargs-1] are the iterables.
//   nargs: total number of positional arguments (callable included).
//
// Returns a new reference to a list holding fn(x0, x1, ...) for each tuple of
// elements drawn in lockstep from the iterables, or nullptr with a Python
// exception set:
//   * TypeError  if fewer than 2 arguments are passed,
//   * ValueError if the iterables do not all have the same length,
//   * whatever exception was raised while obtaining an iterator, advancing
//     one, calling `fn`, or allocating/growing the result list.
PyObject* SafeMap(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
  if (nargs < 2) {
    PyErr_SetString(PyExc_TypeError, "safe_map requires at least 2 arguments");
    return nullptr;
  }
  PyObject* fn = args[0];

  // Owning references to one iterator per iterable argument.
  absl::InlinedVector<py::object, 4> iterators;
  iterators.reserve(nargs - 1);
  for (Py_ssize_t i = 1; i < nargs; ++i) {
    PyObject* it = PyObject_GetIter(args[i]);
    if (!it) return nullptr;
    iterators.push_back(py::reinterpret_steal<py::object>(it));
  }

  // Try to use a length hint to estimate how large a list to allocate. The
  // hint is best-effort only, so any error raised while computing it is
  // discarded and we fall back to a small default.
  Py_ssize_t length_hint = PyObject_LengthHint(args[1], 2);
  if (PyErr_Occurred()) {
    PyErr_Clear();
  }
  if (length_hint < 0) {
    length_hint = 2;
  }

  // Allocate through the C API rather than `py::list(length_hint)`: the
  // pybind11 constructor throws a C++ exception when allocation fails (e.g.
  // for an absurdly large hint), and a C++ exception must not escape this raw
  // CPython entry point. PyList_New leaves a Python error set instead.
  PyObject* raw_list = PyList_New(length_hint);
  if (!raw_list) {
    return nullptr;
  }
  py::list list = py::reinterpret_steal<py::list>(raw_list);
  Py_ssize_t n = 0;  // Current true size of the list

  // The arguments we will pass to fn. We allocate space for one more argument
  // than we need at the start of the argument list so we can use
  // PY_VECTORCALL_ARGUMENTS_OFFSET which may speed up the callee.
  absl::InlinedVector<PyObject*, 4> values(nargs, nullptr);
  while (true) {
    // Releases the elements fetched during this iteration on every exit path.
    // The loop variable must be a reference: each slot has to be reset to
    // nullptr so that a pointer released here is never decref'd a second time
    // if a later iteration returns early before overwriting that slot.
    absl::Cleanup values_cleanup = [&values]() {
      for (PyObject*& v : values) {
        Py_XDECREF(v);
        v = nullptr;
      }
    };

    values[1] = PyIter_Next(iterators[0].ptr());
    if (PyErr_Occurred()) return nullptr;

    if (values[1]) {
      // The first iterable produced an element, so every other iterable must
      // produce one as well.
      for (size_t i = 1; i < iterators.size(); ++i) {
        values[i + 1] = PyIter_Next(iterators[i].ptr());
        if (PyErr_Occurred()) return nullptr;
        if (!values[i + 1]) {
          PyErr_SetString(PyExc_ValueError,
                          "Length mismatch for arguments to safe_map");
          return nullptr;
        }
      }
    } else {
      // No more elements should be left. Checks the other iterators are
      // exhausted.
      for (size_t i = 1; i < iterators.size(); ++i) {
        values[i + 1] = PyIter_Next(iterators[i].ptr());
        if (PyErr_Occurred()) return nullptr;
        if (values[i + 1]) {
          PyErr_SetString(PyExc_ValueError,
                          "Length mismatch for arguments to safe_map");
          return nullptr;
        }
      }

      // If the length hint was too large, truncate the list to the true size.
      if (n < length_hint) {
        if (PyList_SetSlice(list.ptr(), n, length_hint, nullptr) < 0) {
          return nullptr;
        }
      }
      return list.release().ptr();
    }

    // TODO(phawkins): use PyObject_Vectorcall after dropping Python 3.8 support
    py::object out = py::reinterpret_steal<py::object>(_PyObject_Vectorcall(
        fn, &values[1], (nargs - 1) | PY_VECTORCALL_ARGUMENTS_OFFSET,
        /*kwnames=*/nullptr));
    if (PyErr_Occurred()) {
      return nullptr;
    }

    if (n < length_hint) {
      // Slot `n` of the preallocated list is still empty; PyList_SET_ITEM
      // steals the reference we hand it.
      PyList_SET_ITEM(list.ptr(), n, out.release().ptr());
    } else {
      // The hint was too small: grow the list. PyList_Append takes its own
      // reference, so `out` keeps ownership and releases it on scope exit.
      if (PyList_Append(list.ptr(), out.ptr()) < 0) {
        return nullptr;
      }
    }
    ++n;
  }
}
|
|
|
|
// Method-table entry that exposes SafeMap to Python under the name
// "safe_map". METH_FASTCALL selects the (self, args, nargs) calling
// convention that SafeMap implements, hence the cast to the generic
// PyCFunction pointer type. No docstring is attached.
PyMethodDef safe_map_def = {
    /*ml_name=*/"safe_map",
    /*ml_meth=*/reinterpret_cast<PyCFunction>(SafeMap),
    /*ml_flags=*/METH_FASTCALL,
    /*ml_doc=*/nullptr,
};
|
|
|
|
// Module initializer for the `utils` extension module: publishes the
// fastcall-based SafeMap implementation as the module attribute `safe_map`.
PYBIND11_MODULE(utils, m) {
  // Build a builtin function object from the raw method definition, with no
  // bound `self` and no owning module, then hand the new reference to
  // pybind11 so it is released once the attribute assignment is done.
  PyObject* safe_map_fn =
      PyCFunction_NewEx(&safe_map_def, /*self=*/nullptr, /*module=*/nullptr);
  m.attr("safe_map") = py::reinterpret_steal<py::object>(safe_map_fn);
}