mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add a C++ safe_zip implementation.
Benchmark results on my workstation:

```
name                                   old cpu/op   new cpu/op   delta
safe_zip/arg_lengths:0/num_args:1      1.22µs ± 1%  0.28µs ± 8%  -77.33%  (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:1      1.28µs ± 1%  0.34µs ± 6%  -73.18%  (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:1      1.28µs ± 1%  0.38µs ± 5%  -70.26%  (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:1      1.38µs ± 1%  0.51µs ± 3%  -63.26%  (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:1     1.61µs ± 1%  0.69µs ± 3%  -56.93%  (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:1    5.39µs ± 1%  3.83µs ± 2%  -29.03%  (p=0.008 n=5+5)
safe_zip/arg_lengths:0/num_args:2      1.46µs ± 1%  0.32µs ± 4%  -78.30%  (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:2      1.52µs ± 1%  0.39µs ± 4%  -74.20%  (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:2      1.53µs ± 1%  0.44µs ± 4%  -71.38%  (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:2      1.66µs ± 2%  0.60µs ± 3%  -63.96%  (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:2     1.90µs ± 1%  0.82µs ± 3%  -56.66%  (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:2    6.51µs ± 1%  4.80µs ± 0%  -26.23%  (p=0.016 n=5+4)
safe_zip/arg_lengths:0/num_args:3      1.62µs ± 1%  0.36µs ± 4%  -77.95%  (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:3      1.68µs ± 1%  0.44µs ± 3%  -73.75%  (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:3      1.69µs ± 1%  0.50µs ± 3%  -70.48%  (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:3      1.83µs ± 1%  0.68µs ± 2%  -62.73%  (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:3     2.12µs ± 1%  0.96µs ± 1%  -54.71%  (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:3    7.34µs ± 2%  5.89µs ± 1%  -19.74%  (p=0.008 n=5+5)
```

In addition, improve the length mismatch error for safe_map and define __module__ on both functions.

PiperOrigin-RevId: 523475834
This commit is contained in:
parent
466c9a282d
commit
74384e6a87
@ -919,5 +919,14 @@ def safe_map(state):
|
||||
while state:
|
||||
jax.util.safe_map(f, *args)
|
||||
|
||||
@google_benchmark.register
@google_benchmark.option.arg_names(['arg_lengths', 'num_args'])
@google_benchmark.option.args_product([[0, 1, 2, 5, 10, 100], [1, 2, 3]])
def safe_zip(state):
  """Benchmark jax.util.safe_zip on `num_args` lists of `arg_lengths` ints."""
  # Benchmark parameters arrive in the order declared by arg_names above.
  arg_length = state.range(0)
  num_args = state.range(1)
  # Build the inputs once, outside the timed loop.
  inputs = tuple([*range(arg_length)] for _ in range(num_args))
  while state:
    jax.util.safe_zip(*inputs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
google_benchmark.main()
|
||||
|
@ -37,25 +37,41 @@ T1 = TypeVar("T1")
|
||||
T2 = TypeVar("T2")
|
||||
T3 = TypeVar("T3")
|
||||
|
||||
# safe_zip cannot yet be fully annotated, so we use a strategy similar
|
||||
# to that used for builtins.zip in python/typeshed. This supports
|
||||
# return types matching input types for up to three arguments.
|
||||
@overload
|
||||
def safe_zip(__arg1: Iterable[T1]) -> List[Tuple[T1]]: ...
|
||||
@overload
|
||||
def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2]) -> List[Tuple[T1, T2]]: ...
|
||||
@overload
|
||||
def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> List[Tuple[T1, T2, T3]]: ...
|
||||
@overload
|
||||
def safe_zip(__arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> List[Tuple[Any, ...]]: ...
|
||||
|
||||
def safe_zip(*args):
|
||||
if TYPE_CHECKING:
|
||||
# safe_zip cannot yet be fully annotated, so we use a strategy similar
|
||||
# to that used for builtins.zip in python/typeshed. This supports
|
||||
# return types matching input types for up to three arguments.
|
||||
@overload
|
||||
def safe_zip(__arg1: Iterable[T1]) -> List[Tuple[T1]]: ...
|
||||
@overload
|
||||
def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2]) -> List[Tuple[T1, T2]]: ...
|
||||
@overload
|
||||
def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> List[Tuple[T1, T2, T3]]: ...
|
||||
@overload
|
||||
def safe_zip(__arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> List[Tuple[Any, ...]]: ...
|
||||
|
||||
def safe_zip(*args):
|
||||
args = list(map(list, args))
|
||||
n = len(args[0])
|
||||
for arg in args[1:]:
|
||||
assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
|
||||
return list(zip(*args))
|
||||
|
||||
else:
|
||||
# TODO(phawkins): remove the hasattr condition after jaxlib 0.4.9 is the
|
||||
# minimum
|
||||
if hasattr(jaxlib_utils, 'safe_zip'):
|
||||
safe_zip = jaxlib_utils.safe_zip
|
||||
else:
|
||||
def safe_zip(*args):
  """Zip the arguments into a list of tuples, requiring equal lengths.

  Pure-Python fallback used when jaxlib does not provide `safe_zip`. Errors
  are raised explicitly (rather than via `assert`, which is stripped under
  `python -O`) and use the same exception types and message format as the
  jaxlib implementation.

  Args:
    *args: one or more iterables.

  Returns:
    A list of tuples, where the i-th tuple holds the i-th element of every
    argument.

  Raises:
    TypeError: if called with no arguments, or an argument is not iterable.
    ValueError: if any argument differs in length from the first argument.
  """
  if not args:
    raise TypeError("safe_zip requires at least 1 argument")
  # Materialize every argument so that one-shot iterators can be measured.
  args = [list(arg) for arg in args]
  n = len(args[0])
  # Positions are 1-based, so the second argument is reported as "argument 2".
  for position, arg in enumerate(args[1:], start=2):
    if len(arg) != n:
      relation = "shorter" if len(arg) < n else "longer"
      raise ValueError(
          f"safe_zip() argument {position} is {relation} than argument 1")
  return list(zip(*args))
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# safe_map cannot yet be fully annotated, so we use a strategy similar
|
||||
# to that used for builtins.map in python/typeshed. This supports
|
||||
|
121
jaxlib/utils.cc
121
jaxlib/utils.cc
@ -21,6 +21,11 @@ limitations under the License.
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace {
|
||||
|
||||
// A variant of map(...) that:
|
||||
// a) returns a list instead of an iterator, and
|
||||
// b) checks that the input iterables are of equal length.
|
||||
PyObject* SafeMap(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
|
||||
if (nargs < 2) {
|
||||
PyErr_SetString(PyExc_TypeError, "safe_map requires at least 2 arguments");
|
||||
@ -66,8 +71,9 @@ PyObject* SafeMap(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
|
||||
values[i + 1] = PyIter_Next(iterators[i].ptr());
|
||||
if (PyErr_Occurred()) return nullptr;
|
||||
if (!values[i + 1]) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"Length mismatch for arguments to safe_map");
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"safe_map() argument %u is shorter than argument 1",
|
||||
i + 1);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
@ -78,8 +84,9 @@ PyObject* SafeMap(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
|
||||
values[i + 1] = PyIter_Next(iterators[i].ptr());
|
||||
if (PyErr_Occurred()) return nullptr;
|
||||
if (values[i + 1]) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"Length mismatch for arguments to safe_map");
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"safe_map() argument %u is longer than argument 1",
|
||||
i + 1);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
@ -118,7 +125,107 @@ PyMethodDef safe_map_def = {
|
||||
METH_FASTCALL,
|
||||
};
|
||||
|
||||
PYBIND11_MODULE(utils, m) {
|
||||
m.attr("safe_map") = py::reinterpret_steal<py::object>(
|
||||
PyCFunction_NewEx(&safe_map_def, nullptr, nullptr));
|
||||
// A variant of zip(...) that:
|
||||
// a) returns a list instead of an iterator, and
|
||||
// b) checks that the input iterables are of equal length.
|
||||
// TODO(phawkins): consider replacing this function with
|
||||
// list(zip(..., strict=True)) once TensorFlow 2.13 is released, which should
|
||||
// resolve an incompatibility with strict=True and jax2tf.
|
||||
PyObject* SafeZip(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
|
||||
if (nargs < 1) {
|
||||
PyErr_SetString(PyExc_TypeError, "safe_zip requires at least 1 argument");
|
||||
return nullptr;
|
||||
}
|
||||
absl::InlinedVector<py::object, 4> iterators;
|
||||
iterators.reserve(nargs);
|
||||
for (Py_ssize_t i = 0; i < nargs; ++i) {
|
||||
PyObject* it = PyObject_GetIter(args[i]);
|
||||
if (!it) return nullptr;
|
||||
iterators.push_back(py::reinterpret_steal<py::object>(it));
|
||||
}
|
||||
|
||||
// Try to use a length hint to estimate how large a list to allocate.
|
||||
Py_ssize_t length_hint = PyObject_LengthHint(args[0], 2);
|
||||
if (PyErr_Occurred()) {
|
||||
PyErr_Clear();
|
||||
}
|
||||
if (length_hint < 0) {
|
||||
length_hint = 2;
|
||||
}
|
||||
|
||||
py::list list(length_hint);
|
||||
int n = 0; // Current true size of the list
|
||||
|
||||
while (true) {
|
||||
py::object tuple;
|
||||
py::object v =
|
||||
py::reinterpret_steal<py::object>(PyIter_Next(iterators[0].ptr()));
|
||||
if (PyErr_Occurred()) return nullptr;
|
||||
|
||||
if (v.ptr()) {
|
||||
tuple = py::reinterpret_steal<py::object>(PyTuple_New(nargs));
|
||||
if (!tuple.ptr()) return nullptr;
|
||||
|
||||
PyTuple_SET_ITEM(tuple.ptr(), 0, v.release().ptr());
|
||||
for (size_t i = 1; i < iterators.size(); ++i) {
|
||||
v = py::reinterpret_steal<py::object>(PyIter_Next(iterators[i].ptr()));
|
||||
if (PyErr_Occurred()) return nullptr;
|
||||
if (!v.ptr()) {
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"safe_zip() argument %u is shorter than argument 1",
|
||||
i + 1);
|
||||
return nullptr;
|
||||
}
|
||||
PyTuple_SET_ITEM(tuple.ptr(), i, v.release().ptr());
|
||||
}
|
||||
} else {
|
||||
// No more elements should be left. Checks the other iterators are
|
||||
// exhausted.
|
||||
for (size_t i = 1; i < iterators.size(); ++i) {
|
||||
v = py::reinterpret_steal<py::object>(PyIter_Next(iterators[i].ptr()));
|
||||
if (PyErr_Occurred()) return nullptr;
|
||||
if (v.ptr()) {
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"safe_zip() argument %u is longer than argument 1",
|
||||
i + 1);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// If the length hint was too large, truncate the list to the true size.
|
||||
if (n < length_hint) {
|
||||
if (PyList_SetSlice(list.ptr(), n, length_hint, nullptr) < 0) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return list.release().ptr();
|
||||
}
|
||||
|
||||
if (n < length_hint) {
|
||||
PyList_SET_ITEM(list.ptr(), n, tuple.release().ptr());
|
||||
} else {
|
||||
if (PyList_Append(list.ptr(), tuple.ptr()) < 0) {
|
||||
return nullptr;
|
||||
}
|
||||
tuple = py::object();
|
||||
}
|
||||
++n;
|
||||
}
|
||||
}
|
||||
|
||||
// Method table entry that exposes SafeZip to Python under the name "safe_zip".
PyMethodDef safe_zip_def = {
    "safe_zip",
    // SafeZip takes (self, args, nargs), so it is cast to the generic
    // PyCFunction pointer type and flagged as METH_FASTCALL below.
    reinterpret_cast<PyCFunction>(SafeZip),
    METH_FASTCALL,
};
|
||||
|
||||
} // namespace
|
||||
|
||||
|
||||
PYBIND11_MODULE(utils, m) {
  // The module name is passed to PyCFunction_NewEx as the `module` argument
  // of each function object created below.
  py::object name = m.attr("__name__");

  // Wraps a PyMethodDef in a Python function object bound to this module.
  auto make_function = [&name](PyMethodDef* def) {
    return py::reinterpret_steal<py::object>(
        PyCFunction_NewEx(def, /*self=*/nullptr, name.ptr()));
  };

  m.attr("safe_map") = make_function(&safe_map_def);
  m.attr("safe_zip") = make_function(&safe_zip_def);
}
|
@ -20,7 +20,6 @@ from absl.testing import absltest
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
|
||||
from jax.config import config
|
||||
from jax._src.util import weakref_lru_cache
|
||||
@ -123,44 +122,86 @@ class SafeMapTest(jtu.JaxTestCase):
|
||||
util.safe_map(make_tuple, range(4), range(4, 8)),
|
||||
)
|
||||
|
||||
@unittest.skipIf(jaxlib_version < (0, 4, 9) or
|
||||
not hasattr(jaxlib_utils, 'safe_map'),
|
||||
@unittest.skipIf(not hasattr(jaxlib_utils, 'safe_map'),
|
||||
"requires jaxlib 0.4.9")
|
||||
def test_safe_map_errors(self):
|
||||
with self.assertRaises(
|
||||
TypeError, msg="safe_map requires at least 2 arguments"
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "safe_map requires at least 2 arguments"
|
||||
):
|
||||
util.safe_map()
|
||||
|
||||
with self.assertRaises(
|
||||
TypeError, msg="safe_map requires at least 2 arguments"
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "safe_map requires at least 2 arguments"
|
||||
):
|
||||
util.safe_map(lambda x: x)
|
||||
|
||||
with self.assertRaises(TypeError, msg="'int' object is not callable'"):
|
||||
with self.assertRaisesRegex(TypeError, "'int' object is not callable"):
|
||||
util.safe_map(7, range(6))
|
||||
|
||||
def error(*args, **kwargs):
|
||||
raise RuntimeError("hello")
|
||||
|
||||
with self.assertRaises(RuntimeError, msg="hello"):
|
||||
with self.assertRaisesRegex(RuntimeError, "hello"):
|
||||
util.safe_map(error, range(6))
|
||||
|
||||
with self.assertRaises(
|
||||
ValueError, msg="Length mismatch for arguments to safe_map"
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"safe_map\(\) argument 2 is longer than argument 1"
|
||||
):
|
||||
util.safe_map(operator.add, range(3), range(4))
|
||||
|
||||
with self.assertRaises(
|
||||
ValueError, msg="Length mismatch for arguments to safe_map"
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"safe_map\(\) argument 2 is shorter than argument 1"
|
||||
):
|
||||
util.safe_map(operator.add, range(7), range(2))
|
||||
|
||||
with self.assertRaises(
|
||||
ValueError, msg="Length mismatch for arguments to safe_map"
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"safe_map\(\) argument 2 is longer than argument 1"
|
||||
):
|
||||
util.safe_map(operator.add, (), range(3))
|
||||
|
||||
|
||||
class SafeZipTest(jtu.JaxTestCase):
  """Checks the results and the error reporting of util.safe_zip."""

  def test_safe_zip(self):
    # (expected result, positional arguments) pairs, checked in order.
    cases = [
        ([], ([],)),
        ([], ((), [])),
        ([], ([], [], [])),
        ([], ([], iter([]), [], [])),
        ([(7,)], ((7,),)),
        ([(0,), (1,), (2,), (3,)], (range(4),)),
        ([(0, 4), (1, 5), (2, 6), (3, 7)], (range(4), range(4, 8))),
    ]
    for expected, zip_args in cases:
      self.assertEqual(expected, util.safe_zip(*zip_args))

  @unittest.skipIf(not hasattr(jaxlib_utils, 'safe_zip'),
                   "requires jaxlib 0.4.9")
  def test_safe_zip_errors(self):
    # (exception type, message regex, positional arguments), checked in order.
    cases = [
        (TypeError, "safe_zip requires at least 1 argument", ()),
        (TypeError, "'function' object is not iterable", (lambda x: x,)),
        (
            ValueError,
            r"safe_zip\(\) argument 2 is longer than argument 1",
            (range(3), range(4)),
        ),
        (
            ValueError,
            r"safe_zip\(\) argument 2 is shorter than argument 1",
            (range(7), range(2)),
        ),
        (
            ValueError,
            r"safe_zip\(\) argument 2 is longer than argument 1",
            ((), range(3)),
        ),
    ]
    for exc_type, pattern, zip_args in cases:
      with self.assertRaisesRegex(exc_type, pattern):
        util.safe_zip(*zip_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user