rocm_jax/tests/api_util_test.py
Dan Foreman-Mackey 69fc8bb419 Consolidate handling of input argument resolution in custom_* APIs.
This is a partial re-land of https://github.com/google/jax/pull/22869 with some updates to ensure that it doesn't break existing uses of `custom_vmap`.

Previously, using a `custom_jvp` or `custom_vjp` with a primal function that has keyword-only arguments would result in a type error, even if these arguments weren't passed by the caller. I believe that this check is actually slightly stricter than it needed to be, as discovered when adding a similar check to `custom_vmap`. Instead, I think that it is sufficient to check that the caller hasn't _passed_ any keyword-only arguments.

The previous behavior in `custom_vmap` was even harsher: it would error if any keyword arguments were passed.

In this change, I have moved `resolve_kwargs` into `api_util` so that the same function can be used in both `custom_derivatives` and `custom_batching`. I've also updated the logic to only throw a `TypeError` if the caller passes a keyword-only argument when calling a `custom_*`-decorated function. This changes the behavior of `custom_jvp` and `custom_vjp`, although users shouldn't see that effect, since previously having kwargs would have errored.

PiperOrigin-RevId: 662402158
2024-08-13 00:30:23 -07:00

90 lines
3.1 KiB
Python

# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax.api_util."""
import itertools as it
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import api_util
from jax import numpy as jnp
from jax._src import test_util as jtu
# Runs at import time, before any test is defined. Presumably this hooks JAX's
# configuration options into absl flag parsing so they can be supplied on the
# command line -- confirm against the jax.config documentation.
jax.config.parse_flags_with_absl()
class ApiUtilTest(jtu.JaxTestCase):
  """Tests for the argument-handling helpers exposed by ``api_util``."""

  def test_donation_vector(self):
    """Checks ``donation_vector`` for every ordering of donated positions.

    Three positional arguments are used: two dicts with two array leaves each
    and one bare array. Every permutation of every subset of their indices is
    tried, both with and without an extra keyword argument.
    """
    params = {"a": jnp.ones([]), "b": jnp.ones([])}
    state = {"c": jnp.ones([]), "d": jnp.ones([])}
    x = jnp.ones([])
    args = params, state, x
    for size in range(4):
      for donate_argnums in it.permutations((0, 1, 2), size):
        for kwargs in ({}, {"a": x}):
          # subTest makes a failure report which combination broke instead of
          # aborting the whole sweep on the first mismatch.
          with self.subTest(donate_argnums=donate_argnums,
                            has_kwargs=bool(kwargs)):
            # One flag per leaf: two for `params`, two for `state`, one for
            # `x`, in positional order.
            expected = ()
            expected += (True, True) if 0 in donate_argnums else (False, False)
            expected += (True, True) if 1 in donate_argnums else (False, False)
            expected += (True,) if 2 in donate_argnums else (False,)
            if kwargs:
              # The second argument to `donation_vector` is empty, so the
              # keyword leaf is expected to be reported as not donated.
              expected += (False,)
            self.assertEqual(
                expected,
                api_util.donation_vector(donate_argnums, (),
                                         jax.tree.structure((args, kwargs))))

  @parameterized.parameters(
      ((0,), (0,)),
      ((0, 1), (1, 2)),
      ((0, 1, 2), (0, 1, 2)),
  )
  def test_rebase_donate_argnums_rejects_overlapping(self, donate, static):
    """An index listed as both donated and static must raise ``ValueError``."""
    with self.assertRaisesRegex(ValueError, "cannot intersect"):
      api_util.rebase_donate_argnums(donate, static)

  @parameterized.parameters(
      ((), (), ()),
      ((), (1, 2, 3), ()),
      ((0,), (2, 3), (0,)),
      ((0, 1), (2, 3), (0, 1)),
      ((2, 3), (0, 1), (0, 1)),
      ((3, 2), (1, 0), (0, 1)),
      ((3,), (0, 1), (1,)),
      ((3, 3, 3,), (0, 1), (1,)),
  )
  def test_rebase_donate_argnums(self, donate, static, expected):
    """Checks the rebased indices for disjoint donate/static inputs.

    The cases expect each donated index to be lowered by the number of static
    indices below it, and the result to be sorted and free of duplicates.
    """
    self.assertEqual(expected,
                     api_util.rebase_donate_argnums(donate, static))

  def test_resolve_kwargs(self):
    """Keyword arguments and defaults resolve to one positional tuple."""
    def fun(x, y, z=3):
      return x, y, z

    # assertEqual (rather than a bare `assert`) still runs under `python -O`
    # and prints both values on failure.
    self.assertEqual((1, 2, 3), api_util.resolve_kwargs(fun, (1,), {"y": 2}))
    self.assertEqual((1, 2, 3),
                     api_util.resolve_kwargs(fun, (1, 2), {"z": 3}))
    self.assertEqual(
        (1, 2, 3),
        api_util.resolve_kwargs(fun, (), {"x": 1, "y": 2, "z": 3}))

  def test_resolve_kwargs_with_keyword(self):
    """A keyword-only parameter is an error only when the caller passes it."""
    def fun(x, y, z, *, kw=True):
      del kw
      return x, y, z

    # Declaring a keyword-only parameter is fine as long as it is not passed.
    self.assertEqual((1, 2, 3),
                     api_util.resolve_kwargs(fun, (1, 2), {"z": 3}))
    with self.assertRaisesRegex(TypeError, "keyword arguments"):
      api_util.resolve_kwargs(fun, (1, 2), {"z": 3, "kw": False})
# Script entry point: hand control to absl's test runner, using the test
# loader provided by jax's test utilities.
if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())