Add a C++ safe_zip implementation.

Benchmark results on my workstation:
```
name                                 old cpu/op   new cpu/op   delta
safe_zip/arg_lengths:0/num_args:1    1.22µs ± 1%  0.28µs ± 8%  -77.33%  (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:1    1.28µs ± 1%  0.34µs ± 6%  -73.18%  (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:1    1.28µs ± 1%  0.38µs ± 5%  -70.26%  (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:1    1.38µs ± 1%  0.51µs ± 3%  -63.26%  (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:1   1.61µs ± 1%  0.69µs ± 3%  -56.93%  (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:1  5.39µs ± 1%  3.83µs ± 2%  -29.03%  (p=0.008 n=5+5)
safe_zip/arg_lengths:0/num_args:2    1.46µs ± 1%  0.32µs ± 4%  -78.30%  (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:2    1.52µs ± 1%  0.39µs ± 4%  -74.20%  (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:2    1.53µs ± 1%  0.44µs ± 4%  -71.38%  (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:2    1.66µs ± 2%  0.60µs ± 3%  -63.96%  (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:2   1.90µs ± 1%  0.82µs ± 3%  -56.66%  (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:2  6.51µs ± 1%  4.80µs ± 0%  -26.23%  (p=0.016 n=5+4)
safe_zip/arg_lengths:0/num_args:3    1.62µs ± 1%  0.36µs ± 4%  -77.95%  (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:3    1.68µs ± 1%  0.44µs ± 3%  -73.75%  (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:3    1.69µs ± 1%  0.50µs ± 3%  -70.48%  (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:3    1.83µs ± 1%  0.68µs ± 2%  -62.73%  (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:3   2.12µs ± 1%  0.96µs ± 1%  -54.71%  (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:3  7.34µs ± 2%  5.89µs ± 1%  -19.74%  (p=0.008 n=5+5)
```

In addition, improve the length mismatch error for safe_map and define __module__ on both functions.

PiperOrigin-RevId: 523475834
This commit is contained in:
Peter Hawkins 2023-04-11 12:42:30 -07:00 committed by jax authors
parent 466c9a282d
commit 74384e6a87
4 changed files with 210 additions and 37 deletions

View File

@ -919,5 +919,14 @@ def safe_map(state):
while state:
jax.util.safe_map(f, *args)
@google_benchmark.register
@google_benchmark.option.arg_names(['arg_lengths', 'num_args'])
@google_benchmark.option.args_product([[0, 1, 2, 5, 10, 100], [1, 2, 3]])
def safe_zip(state):
args = tuple(list(range(state.range(0))) for _ in range(state.range(1)))
while state:
jax.util.safe_zip(*args)
if __name__ == "__main__":
google_benchmark.main()

View File

@ -37,25 +37,41 @@ T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
# safe_zip cannot yet be fully annotated, so we use a strategy similar
# to that used for builtins.zip in python/typeshed. This supports
# return types matching input types for up to three arguments.
@overload
def safe_zip(__arg1: Iterable[T1]) -> List[Tuple[T1]]: ...
@overload
def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2]) -> List[Tuple[T1, T2]]: ...
@overload
def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> List[Tuple[T1, T2, T3]]: ...
@overload
def safe_zip(__arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> List[Tuple[Any, ...]]: ...
def safe_zip(*args):
if TYPE_CHECKING:
# safe_zip cannot yet be fully annotated, so we use a strategy similar
# to that used for builtins.zip in python/typeshed. This supports
# return types matching input types for up to three arguments.
@overload
def safe_zip(__arg1: Iterable[T1]) -> List[Tuple[T1]]: ...
@overload
def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2]) -> List[Tuple[T1, T2]]: ...
@overload
def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> List[Tuple[T1, T2, T3]]: ...
@overload
def safe_zip(__arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> List[Tuple[Any, ...]]: ...
def safe_zip(*args):
args = list(map(list, args))
n = len(args[0])
for arg in args[1:]:
assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
return list(zip(*args))
else:
# TODO(phawkins): remove the hasattr condition after jaxlib 0.4.9 is the
# minimum
if hasattr(jaxlib_utils, 'safe_zip'):
safe_zip = jaxlib_utils.safe_zip
else:
def safe_zip(*args):
args = list(map(list, args))
n = len(args[0])
for arg in args[1:]:
assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
return list(zip(*args))
if TYPE_CHECKING:
# safe_map cannot yet be fully annotated, so we use a strategy similar
# to that used for builtins.map in python/typeshed. This supports

View File

@ -21,6 +21,11 @@ limitations under the License.
namespace py = pybind11;
namespace {
// A variant of map(...) that:
// a) returns a list instead of an iterator, and
// b) checks that the input iterables are of equal length.
PyObject* SafeMap(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
if (nargs < 2) {
PyErr_SetString(PyExc_TypeError, "safe_map requires at least 2 arguments");
@ -66,8 +71,9 @@ PyObject* SafeMap(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
values[i + 1] = PyIter_Next(iterators[i].ptr());
if (PyErr_Occurred()) return nullptr;
if (!values[i + 1]) {
PyErr_SetString(PyExc_ValueError,
"Length mismatch for arguments to safe_map");
PyErr_Format(PyExc_ValueError,
"safe_map() argument %u is shorter than argument 1",
i + 1);
return nullptr;
}
}
@ -78,8 +84,9 @@ PyObject* SafeMap(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
values[i + 1] = PyIter_Next(iterators[i].ptr());
if (PyErr_Occurred()) return nullptr;
if (values[i + 1]) {
PyErr_SetString(PyExc_ValueError,
"Length mismatch for arguments to safe_map");
PyErr_Format(PyExc_ValueError,
"safe_map() argument %u is longer than argument 1",
i + 1);
return nullptr;
}
}
@ -118,7 +125,107 @@ PyMethodDef safe_map_def = {
METH_FASTCALL,
};
PYBIND11_MODULE(utils, m) {
m.attr("safe_map") = py::reinterpret_steal<py::object>(
PyCFunction_NewEx(&safe_map_def, nullptr, nullptr));
// A variant of zip(...) that:
// a) returns a list instead of an iterator, and
// b) checks that the input iterables are of equal length.
// TODO(phawkins): consider replacing this function with
// list(zip(..., strict=True)) once TensorFlow 2.13 is released, which should
// resolve an incompatibility with strict=True and jax2tf.
PyObject* SafeZip(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
if (nargs < 1) {
PyErr_SetString(PyExc_TypeError, "safe_zip requires at least 1 argument");
return nullptr;
}
absl::InlinedVector<py::object, 4> iterators;
iterators.reserve(nargs);
for (Py_ssize_t i = 0; i < nargs; ++i) {
PyObject* it = PyObject_GetIter(args[i]);
if (!it) return nullptr;
iterators.push_back(py::reinterpret_steal<py::object>(it));
}
// Try to use a length hint to estimate how large a list to allocate.
Py_ssize_t length_hint = PyObject_LengthHint(args[0], 2);
if (PyErr_Occurred()) {
PyErr_Clear();
}
if (length_hint < 0) {
length_hint = 2;
}
py::list list(length_hint);
int n = 0; // Current true size of the list
while (true) {
py::object tuple;
py::object v =
py::reinterpret_steal<py::object>(PyIter_Next(iterators[0].ptr()));
if (PyErr_Occurred()) return nullptr;
if (v.ptr()) {
tuple = py::reinterpret_steal<py::object>(PyTuple_New(nargs));
if (!tuple.ptr()) return nullptr;
PyTuple_SET_ITEM(tuple.ptr(), 0, v.release().ptr());
for (size_t i = 1; i < iterators.size(); ++i) {
v = py::reinterpret_steal<py::object>(PyIter_Next(iterators[i].ptr()));
if (PyErr_Occurred()) return nullptr;
if (!v.ptr()) {
PyErr_Format(PyExc_ValueError,
"safe_zip() argument %u is shorter than argument 1",
i + 1);
return nullptr;
}
PyTuple_SET_ITEM(tuple.ptr(), i, v.release().ptr());
}
} else {
// No more elements should be left. Checks the other iterators are
// exhausted.
for (size_t i = 1; i < iterators.size(); ++i) {
v = py::reinterpret_steal<py::object>(PyIter_Next(iterators[i].ptr()));
if (PyErr_Occurred()) return nullptr;
if (v.ptr()) {
PyErr_Format(PyExc_ValueError,
"safe_zip() argument %u is longer than argument 1",
i + 1);
return nullptr;
}
}
// If the length hint was too large, truncate the list to the true size.
if (n < length_hint) {
if (PyList_SetSlice(list.ptr(), n, length_hint, nullptr) < 0) {
return nullptr;
}
}
return list.release().ptr();
}
if (n < length_hint) {
PyList_SET_ITEM(list.ptr(), n, tuple.release().ptr());
} else {
if (PyList_Append(list.ptr(), tuple.ptr()) < 0) {
return nullptr;
}
tuple = py::object();
}
++n;
}
}
PyMethodDef safe_zip_def = {
"safe_zip",
reinterpret_cast<PyCFunction>(SafeZip),
METH_FASTCALL,
};
} // namespace
PYBIND11_MODULE(utils, m) {
py::object module_name = m.attr("__name__");
m.attr("safe_map") = py::reinterpret_steal<py::object>(
PyCFunction_NewEx(&safe_map_def, /*self=*/nullptr, module_name.ptr()));
m.attr("safe_zip") = py::reinterpret_steal<py::object>(
PyCFunction_NewEx(&safe_zip_def, /*self=*/nullptr, module_name.ptr()));
}

View File

@ -20,7 +20,6 @@ from absl.testing import absltest
from jax._src import linear_util as lu
from jax._src import test_util as jtu
from jax._src import util
from jax._src.lib import version as jaxlib_version
from jax.config import config
from jax._src.util import weakref_lru_cache
@ -123,44 +122,86 @@ class SafeMapTest(jtu.JaxTestCase):
util.safe_map(make_tuple, range(4), range(4, 8)),
)
@unittest.skipIf(jaxlib_version < (0, 4, 9) or
not hasattr(jaxlib_utils, 'safe_map'),
@unittest.skipIf(not hasattr(jaxlib_utils, 'safe_map'),
"requires jaxlib 0.4.9")
def test_safe_map_errors(self):
with self.assertRaises(
TypeError, msg="safe_map requires at least 2 arguments"
with self.assertRaisesRegex(
TypeError, "safe_map requires at least 2 arguments"
):
util.safe_map()
with self.assertRaises(
TypeError, msg="safe_map requires at least 2 arguments"
with self.assertRaisesRegex(
TypeError, "safe_map requires at least 2 arguments"
):
util.safe_map(lambda x: x)
with self.assertRaises(TypeError, msg="'int' object is not callable'"):
with self.assertRaisesRegex(TypeError, "'int' object is not callable"):
util.safe_map(7, range(6))
def error(*args, **kwargs):
raise RuntimeError("hello")
with self.assertRaises(RuntimeError, msg="hello"):
with self.assertRaisesRegex(RuntimeError, "hello"):
util.safe_map(error, range(6))
with self.assertRaises(
ValueError, msg="Length mismatch for arguments to safe_map"
with self.assertRaisesRegex(
ValueError, r"safe_map\(\) argument 2 is longer than argument 1"
):
util.safe_map(operator.add, range(3), range(4))
with self.assertRaises(
ValueError, msg="Length mismatch for arguments to safe_map"
with self.assertRaisesRegex(
ValueError, r"safe_map\(\) argument 2 is shorter than argument 1"
):
util.safe_map(operator.add, range(7), range(2))
with self.assertRaises(
ValueError, msg="Length mismatch for arguments to safe_map"
with self.assertRaisesRegex(
ValueError, r"safe_map\(\) argument 2 is longer than argument 1"
):
util.safe_map(operator.add, (), range(3))
class SafeZipTest(jtu.JaxTestCase):
def test_safe_zip(self):
self.assertEqual([], util.safe_zip([]))
self.assertEqual([], util.safe_zip((), []))
self.assertEqual([], util.safe_zip([], [], []))
self.assertEqual([], util.safe_zip([], iter([]), [], []))
self.assertEqual([(7,)], util.safe_zip((7,)))
self.assertEqual([(0,), (1,), (2,), (3,)], util.safe_zip(range(4)))
self.assertEqual(
[(0, 4), (1, 5), (2, 6), (3, 7)],
util.safe_zip(range(4), range(4, 8)),
)
@unittest.skipIf(not hasattr(jaxlib_utils, 'safe_zip'),
"requires jaxlib 0.4.9")
def test_safe_zip_errors(self):
with self.assertRaisesRegex(
TypeError, "safe_zip requires at least 1 argument"
):
util.safe_zip()
with self.assertRaisesRegex(
TypeError, "'function' object is not iterable"
):
util.safe_zip(lambda x: x)
with self.assertRaisesRegex(
ValueError, r"safe_zip\(\) argument 2 is longer than argument 1"
):
util.safe_zip(range(3), range(4))
with self.assertRaisesRegex(
ValueError, r"safe_zip\(\) argument 2 is shorter than argument 1"
):
util.safe_zip(range(7), range(2))
with self.assertRaisesRegex(
ValueError, r"safe_zip\(\) argument 2 is longer than argument 1"
):
util.safe_zip((), range(3))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())