jnp.vectorize: support None arguments

This commit is contained in:
Jake VanderPlas 2023-11-08 11:46:47 -08:00
parent ba77626934
commit a9452b98a3
2 changed files with 25 additions and 1 deletions

View File

@ -263,7 +263,6 @@ def vectorize(pyfunc, *, excluded=frozenset(), signature=None):
error_context = ("on vectorized function with excluded={!r} and "
"signature={!r}".format(excluded, signature))
excluded_func, args = _apply_excluded(pyfunc, excluded, args)
args = tuple(map(jnp.asarray, args))
if signature is not None:
input_core_dims, output_core_dims = _parse_gufunc_signature(signature)
@ -271,6 +270,15 @@ def vectorize(pyfunc, *, excluded=frozenset(), signature=None):
input_core_dims = [()] * len(args)
output_core_dims = None
none_args = {i for i, arg in enumerate(args) if arg is None}
if any(none_args):
if any(input_core_dims[i] != () for i in none_args):
raise ValueError(f"Cannot pass None at locations {none_args} with {signature=}")
excluded_func, args = _apply_excluded(excluded_func, none_args, args)
input_core_dims = [dim for i, dim in enumerate(input_core_dims) if i not in none_args]
args = tuple(map(jnp.asarray, args))
broadcast_shape, dim_sizes = _parse_input_dimensions(
args, input_core_dims, error_context)

View File

@ -225,6 +225,22 @@ class VectorizeTest(jtu.JaxTestCase):
self.assertAllClose(xx[1], x)
self.assertIsInstance(xx, tuple)
def test_none_arg(self):
f = jnp.vectorize(lambda x, y: x if y is None else x + y)
x = jnp.arange(10)
self.assertAllClose(f(x, None), x)
y = jnp.arange(10, 20)
self.assertAllClose(f(x, y), x + y)
def test_none_arg_bad_signature(self):
f = jnp.vectorize(lambda x, y: x if y is None else x + y,
signature='(k),(k)->(k)')
args = jnp.arange(10), None
msg = r"Cannot pass None at locations \{1\} with signature='\(k\),\(k\)->\(k\)'"
with self.assertRaisesRegex(ValueError, msg):
f(*args)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())