mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
pjit: add test for basic static_argnums
This commit is contained in:
parent
04e6786277
commit
bbaee9c54e
@ -40,7 +40,7 @@ from jax.interpreters import batching
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters.sharded_jit import PartitionSpec
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves
|
||||
from jax.tree_util import tree_map, tree_flatten, tree_unflatten
|
||||
from jax._src.util import (extend_name_stack, HashableFunction, safe_zip,
|
||||
wrap_name, wraps, distributed_debug_log,
|
||||
split_list, cache, tuple_insert)
|
||||
@ -220,8 +220,9 @@ def pjit(fun: Callable,
|
||||
f, static_argnums, args, allow_invalid=False)
|
||||
else:
|
||||
dyn_args = args
|
||||
del args
|
||||
|
||||
args_flat, in_tree = tree_flatten(args)
|
||||
args_flat, in_tree = tree_flatten(dyn_args)
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
|
||||
if donate_argnums:
|
||||
donated_invars = donation_vector(donate_argnums, dyn_args, ())
|
||||
@ -259,9 +260,9 @@ def pjit(fun: Callable,
|
||||
|
||||
@wraps(fun)
|
||||
def wrapped(*args, **kwargs):
|
||||
for arg in tree_leaves(args):
|
||||
_check_arg(arg)
|
||||
args_flat, params, _, out_tree, _ = infer_params(*args, **kwargs)
|
||||
for arg in args_flat:
|
||||
_check_arg(arg)
|
||||
out = pjit_p.bind(*args_flat, **params)
|
||||
return tree_unflatten(out_tree, out)
|
||||
|
||||
|
@ -622,6 +622,15 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
"called with:\n.*int32.*",
|
||||
lambda: exe(x_i32, x_i32))
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
def test_static_argnums(self):
|
||||
@partial(pjit, in_axis_resources=None, out_axis_resources=None,
|
||||
static_argnums=(1,))
|
||||
def f(x, y):
|
||||
return x + (3 if y == 'hi' else 4)
|
||||
|
||||
self.assertEqual(f(1, 'hi' ), 4)
|
||||
self.assertEqual(f(1, 'bye'), 5)
|
||||
|
||||
class GDAPjitTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user