mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
jnp.vectorize: support None arguments
This commit is contained in:
parent
ba77626934
commit
a9452b98a3
@ -263,7 +263,6 @@ def vectorize(pyfunc, *, excluded=frozenset(), signature=None):
|
||||
error_context = ("on vectorized function with excluded={!r} and "
|
||||
"signature={!r}".format(excluded, signature))
|
||||
excluded_func, args = _apply_excluded(pyfunc, excluded, args)
|
||||
args = tuple(map(jnp.asarray, args))
|
||||
|
||||
if signature is not None:
|
||||
input_core_dims, output_core_dims = _parse_gufunc_signature(signature)
|
||||
@ -271,6 +270,15 @@ def vectorize(pyfunc, *, excluded=frozenset(), signature=None):
|
||||
input_core_dims = [()] * len(args)
|
||||
output_core_dims = None
|
||||
|
||||
none_args = {i for i, arg in enumerate(args) if arg is None}
|
||||
if any(none_args):
|
||||
if any(input_core_dims[i] != () for i in none_args):
|
||||
raise ValueError(f"Cannot pass None at locations {none_args} with {signature=}")
|
||||
excluded_func, args = _apply_excluded(excluded_func, none_args, args)
|
||||
input_core_dims = [dim for i, dim in enumerate(input_core_dims) if i not in none_args]
|
||||
|
||||
args = tuple(map(jnp.asarray, args))
|
||||
|
||||
broadcast_shape, dim_sizes = _parse_input_dimensions(
|
||||
args, input_core_dims, error_context)
|
||||
|
||||
|
@ -225,6 +225,22 @@ class VectorizeTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(xx[1], x)
|
||||
self.assertIsInstance(xx, tuple)
|
||||
|
||||
def test_none_arg(self):
|
||||
f = jnp.vectorize(lambda x, y: x if y is None else x + y)
|
||||
x = jnp.arange(10)
|
||||
self.assertAllClose(f(x, None), x)
|
||||
|
||||
y = jnp.arange(10, 20)
|
||||
self.assertAllClose(f(x, y), x + y)
|
||||
|
||||
def test_none_arg_bad_signature(self):
|
||||
f = jnp.vectorize(lambda x, y: x if y is None else x + y,
|
||||
signature='(k),(k)->(k)')
|
||||
args = jnp.arange(10), None
|
||||
msg = r"Cannot pass None at locations \{1\} with signature='\(k\),\(k\)->\(k\)'"
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
f(*args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user