pjit: add test for basic static_argnums

This commit is contained in:
Matthew Johnson 2022-01-19 18:44:31 -08:00
parent 04e6786277
commit bbaee9c54e
2 changed files with 14 additions and 4 deletions

View File

@ -40,7 +40,7 @@ from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters.sharded_jit import PartitionSpec
from jax._src.lib import xla_client as xc
from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves
from jax.tree_util import tree_map, tree_flatten, tree_unflatten
from jax._src.util import (extend_name_stack, HashableFunction, safe_zip,
wrap_name, wraps, distributed_debug_log,
split_list, cache, tuple_insert)
@ -220,8 +220,9 @@ def pjit(fun: Callable,
f, static_argnums, args, allow_invalid=False)
else:
dyn_args = args
del args
args_flat, in_tree = tree_flatten(args)
args_flat, in_tree = tree_flatten(dyn_args)
flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
if donate_argnums:
donated_invars = donation_vector(donate_argnums, dyn_args, ())
@ -259,9 +260,9 @@ def pjit(fun: Callable,
@wraps(fun)
def wrapped(*args, **kwargs):
for arg in tree_leaves(args):
_check_arg(arg)
args_flat, params, _, out_tree, _ = infer_params(*args, **kwargs)
for arg in args_flat:
_check_arg(arg)
out = pjit_p.bind(*args_flat, **params)
return tree_unflatten(out_tree, out)

View File

@ -622,6 +622,15 @@ class PJitTest(jtu.BufferDonationTestCase):
"called with:\n.*int32.*",
lambda: exe(x_i32, x_i32))
@jtu.with_mesh([('x', 2)])
def test_static_argnums(self):
@partial(pjit, in_axis_resources=None, out_axis_resources=None,
static_argnums=(1,))
def f(x, y):
return x + (3 if y == 'hi' else 4)
self.assertEqual(f(1, 'hi' ), 4)
self.assertEqual(f(1, 'bye'), 5)
class GDAPjitTest(jtu.JaxTestCase):