diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 627577f41..a65ce5268 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -919,5 +919,14 @@ def safe_map(state): while state: jax.util.safe_map(f, *args) +@google_benchmark.register +@google_benchmark.option.arg_names(['arg_lengths', 'num_args']) +@google_benchmark.option.args_product([[0, 1, 2, 5, 10, 100], [1, 2, 3]]) +def safe_zip(state): + args = tuple(list(range(state.range(0))) for _ in range(state.range(1))) + while state: + jax.util.safe_zip(*args) + + if __name__ == "__main__": google_benchmark.main() diff --git a/jax/_src/util.py b/jax/_src/util.py index aa1e804ed..b814ce308 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -37,24 +37,40 @@ T1 = TypeVar("T1") T2 = TypeVar("T2") T3 = TypeVar("T3") -# safe_zip cannot yet be fully annotated, so we use a strategy similar -# to that used for builtins.zip in python/typeshed. This supports -# return types matching input types for up to three arguments. -@overload -def safe_zip(__arg1: Iterable[T1]) -> List[Tuple[T1]]: ... -@overload -def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2]) -> List[Tuple[T1, T2]]: ... -@overload -def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> List[Tuple[T1, T2, T3]]: ... -@overload -def safe_zip(__arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> List[Tuple[Any, ...]]: ... -def safe_zip(*args): - args = list(map(list, args)) - n = len(args[0]) - for arg in args[1:]: - assert len(arg) == n, f'length mismatch: {list(map(len, args))}' - return list(zip(*args)) +if TYPE_CHECKING: + # safe_zip cannot yet be fully annotated, so we use a strategy similar + # to that used for builtins.zip in python/typeshed. This supports + # return types matching input types for up to three arguments. + @overload + def safe_zip(__arg1: Iterable[T1]) -> List[Tuple[T1]]: ... + @overload + def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2]) -> List[Tuple[T1, T2]]: ... + @overload + def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> List[Tuple[T1, T2, T3]]: ... + @overload + def safe_zip(__arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> List[Tuple[Any, ...]]: ... + + def safe_zip(*args): + args = list(map(list, args)) + n = len(args[0]) + for arg in args[1:]: + assert len(arg) == n, f'length mismatch: {list(map(len, args))}' + return list(zip(*args)) + +else: + # TODO(phawkins): remove the hasattr condition after jaxlib 0.4.9 is the + # minimum + if hasattr(jaxlib_utils, 'safe_zip'): + safe_zip = jaxlib_utils.safe_zip + else: + def safe_zip(*args): + args = list(map(list, args)) + n = len(args[0]) + for arg in args[1:]: + assert len(arg) == n, f'length mismatch: {list(map(len, args))}' + return list(zip(*args)) + if TYPE_CHECKING: # safe_map cannot yet be fully annotated, so we use a strategy similar diff --git a/jaxlib/utils.cc b/jaxlib/utils.cc index 54192d26f..c2a2414d6 100644 --- a/jaxlib/utils.cc +++ b/jaxlib/utils.cc @@ -21,6 +21,11 @@ limitations under the License. namespace py = pybind11; +namespace { + +// A variant of map(...) that: +// a) returns a list instead of an iterator, and +// b) checks that the input iterables are of equal length. PyObject* SafeMap(PyObject* self, PyObject* const* args, Py_ssize_t nargs) { if (nargs < 2) { PyErr_SetString(PyExc_TypeError, "safe_map requires at least 2 arguments"); @@ -66,8 +71,9 @@ PyObject* SafeMap(PyObject* self, PyObject* const* args, Py_ssize_t nargs) { values[i + 1] = PyIter_Next(iterators[i].ptr()); if (PyErr_Occurred()) return nullptr; if (!values[i + 1]) { - PyErr_SetString(PyExc_ValueError, - "Length mismatch for arguments to safe_map"); + PyErr_Format(PyExc_ValueError, + "safe_map() argument %u is shorter than argument 1", + i + 1); return nullptr; } } @@ -78,8 +84,9 @@ PyObject* SafeMap(PyObject* self, PyObject* const* args, Py_ssize_t nargs) { values[i + 1] = PyIter_Next(iterators[i].ptr()); if (PyErr_Occurred()) return nullptr; if (values[i + 1]) { - PyErr_SetString(PyExc_ValueError, - "Length mismatch for arguments to safe_map"); + PyErr_Format(PyExc_ValueError, + "safe_map() argument %u is longer than argument 1", + i + 1); return nullptr; } } @@ -118,7 +125,107 @@ PyMethodDef safe_map_def = { METH_FASTCALL, }; +// A variant of zip(...) that: +// a) returns a list instead of an iterator, and +// b) checks that the input iterables are of equal length. +// TODO(phawkins): consider replacing this function with +// list(zip(..., strict=True)) once TensorFlow 2.13 is released, which should +// resolve an incompatibility with strict=True and jax2tf. +PyObject* SafeZip(PyObject* self, PyObject* const* args, Py_ssize_t nargs) { + if (nargs < 1) { + PyErr_SetString(PyExc_TypeError, "safe_zip requires at least 1 argument"); + return nullptr; + } + absl::InlinedVector iterators; + iterators.reserve(nargs); + for (Py_ssize_t i = 0; i < nargs; ++i) { + PyObject* it = PyObject_GetIter(args[i]); + if (!it) return nullptr; + iterators.push_back(py::reinterpret_steal(it)); + } + + // Try to use a length hint to estimate how large a list to allocate. + Py_ssize_t length_hint = PyObject_LengthHint(args[0], 2); + if (PyErr_Occurred()) { + PyErr_Clear(); + } + if (length_hint < 0) { + length_hint = 2; + } + + py::list list(length_hint); + int n = 0; // Current true size of the list + + while (true) { + py::object tuple; + py::object v = + py::reinterpret_steal(PyIter_Next(iterators[0].ptr())); + if (PyErr_Occurred()) return nullptr; + + if (v.ptr()) { + tuple = py::reinterpret_steal(PyTuple_New(nargs)); + if (!tuple.ptr()) return nullptr; + + PyTuple_SET_ITEM(tuple.ptr(), 0, v.release().ptr()); + for (size_t i = 1; i < iterators.size(); ++i) { + v = py::reinterpret_steal(PyIter_Next(iterators[i].ptr())); + if (PyErr_Occurred()) return nullptr; + if (!v.ptr()) { + PyErr_Format(PyExc_ValueError, + "safe_zip() argument %u is shorter than argument 1", + i + 1); + return nullptr; + } + PyTuple_SET_ITEM(tuple.ptr(), i, v.release().ptr()); + } + } else { + // No more elements should be left. Checks the other iterators are + // exhausted. + for (size_t i = 1; i < iterators.size(); ++i) { + v = py::reinterpret_steal(PyIter_Next(iterators[i].ptr())); + if (PyErr_Occurred()) return nullptr; + if (v.ptr()) { + PyErr_Format(PyExc_ValueError, + "safe_zip() argument %u is longer than argument 1", + i + 1); + return nullptr; + } + } + + // If the length hint was too large, truncate the list to the true size. + if (n < length_hint) { + if (PyList_SetSlice(list.ptr(), n, length_hint, nullptr) < 0) { + return nullptr; + } + } + return list.release().ptr(); + } + + if (n < length_hint) { + PyList_SET_ITEM(list.ptr(), n, tuple.release().ptr()); + } else { + if (PyList_Append(list.ptr(), tuple.ptr()) < 0) { + return nullptr; + } + tuple = py::object(); + } + ++n; + } +} + +PyMethodDef safe_zip_def = { + "safe_zip", + reinterpret_cast(SafeZip), + METH_FASTCALL, +}; + +} // namespace + + PYBIND11_MODULE(utils, m) { + py::object module_name = m.attr("__name__"); m.attr("safe_map") = py::reinterpret_steal( - PyCFunction_NewEx(&safe_map_def, nullptr, nullptr)); + PyCFunction_NewEx(&safe_map_def, /*self=*/nullptr, module_name.ptr())); + m.attr("safe_zip") = py::reinterpret_steal( + PyCFunction_NewEx(&safe_zip_def, /*self=*/nullptr, module_name.ptr())); } \ No newline at end of file diff --git a/tests/util_test.py b/tests/util_test.py index 5b0960b82..c317b9498 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -20,7 +20,6 @@ from absl.testing import absltest from jax._src import linear_util as lu from jax._src import test_util as jtu from jax._src import util -from jax._src.lib import version as jaxlib_version from jax.config import config from jax._src.util import weakref_lru_cache @@ -123,44 +122,86 @@ class SafeMapTest(jtu.JaxTestCase): util.safe_map(make_tuple, range(4), range(4, 8)), ) - @unittest.skipIf(jaxlib_version < (0, 4, 9) or - not hasattr(jaxlib_utils, 'safe_map'), + @unittest.skipIf(not hasattr(jaxlib_utils, 'safe_map'), "requires jaxlib 0.4.9") def test_safe_map_errors(self): - with self.assertRaises( - TypeError, msg="safe_map requires at least 2 arguments" + with self.assertRaisesRegex( + TypeError, "safe_map requires at least 2 arguments" ): util.safe_map() - with self.assertRaises( - TypeError, msg="safe_map requires at least 2 arguments" + with self.assertRaisesRegex( + TypeError, "safe_map requires at least 2 arguments" ): util.safe_map(lambda x: x) - with self.assertRaises(TypeError, msg="'int' object is not callable'"): + with self.assertRaisesRegex(TypeError, "'int' object is not callable"): util.safe_map(7, range(6)) def error(*args, **kwargs): raise RuntimeError("hello") - with self.assertRaises(RuntimeError, msg="hello"): + with self.assertRaisesRegex(RuntimeError, "hello"): util.safe_map(error, range(6)) - with self.assertRaises( - ValueError, msg="Length mismatch for arguments to safe_map" + with self.assertRaisesRegex( + ValueError, r"safe_map\(\) argument 2 is longer than argument 1" ): util.safe_map(operator.add, range(3), range(4)) - with self.assertRaises( - ValueError, msg="Length mismatch for arguments to safe_map" + with self.assertRaisesRegex( + ValueError, r"safe_map\(\) argument 2 is shorter than argument 1" ): util.safe_map(operator.add, range(7), range(2)) - with self.assertRaises( - ValueError, msg="Length mismatch for arguments to safe_map" + with self.assertRaisesRegex( + ValueError, r"safe_map\(\) argument 2 is longer than argument 1" ): util.safe_map(operator.add, (), range(3)) +class SafeZipTest(jtu.JaxTestCase): + + def test_safe_zip(self): + self.assertEqual([], util.safe_zip([])) + self.assertEqual([], util.safe_zip((), [])) + self.assertEqual([], util.safe_zip([], [], [])) + self.assertEqual([], util.safe_zip([], iter([]), [], [])) + self.assertEqual([(7,)], util.safe_zip((7,))) + self.assertEqual([(0,), (1,), (2,), (3,)], util.safe_zip(range(4))) + self.assertEqual( + [(0, 4), (1, 5), (2, 6), (3, 7)], + util.safe_zip(range(4), range(4, 8)), + ) + + @unittest.skipIf(not hasattr(jaxlib_utils, 'safe_zip'), + "requires jaxlib 0.4.9") + def test_safe_zip_errors(self): + with self.assertRaisesRegex( + TypeError, "safe_zip requires at least 1 argument" + ): + util.safe_zip() + + with self.assertRaisesRegex( + TypeError, "'function' object is not iterable" + ): + util.safe_zip(lambda x: x) + + with self.assertRaisesRegex( + ValueError, r"safe_zip\(\) argument 2 is longer than argument 1" + ): + util.safe_zip(range(3), range(4)) + + with self.assertRaisesRegex( + ValueError, r"safe_zip\(\) argument 2 is shorter than argument 1" + ): + util.safe_zip(range(7), range(2)) + + with self.assertRaisesRegex( + ValueError, r"safe_zip\(\) argument 2 is longer than argument 1" + ): + util.safe_zip((), range(3)) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())