custom vmap: support closure and staged constants

The `custom_vmap` primitive stages out its wrapped function at call
time. It might extract closed-over or otherwise constant values
("consts") in doing so. To handle these, we can reduce back to the
empty closure setting: convert the consts to formal arguments, both in
the target function and in the custom vmap rule, and ignore them in
the latter.

We only need to play this trick once, on initial entry. After that, we
can resume in assuming an empty closure.
This commit is contained in:
Roy Frostig 2022-11-22 17:12:20 -08:00
parent 7128bb4ac9
commit ef9b2fe4a1
2 changed files with 70 additions and 2 deletions

View File

@ -19,6 +19,7 @@ from typing import Callable, Optional
import jax
from jax import core
from jax import linear_util as lu
from jax import tree_util
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters.batching import not_mapped
@ -66,11 +67,11 @@ class custom_vmap:
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
debug = pe.debug_info(self.fun, in_tree, False, "custom_vmap")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
assert not len(consts)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
in_tree = treedef_tuple((tree_structure(consts), in_tree))
out_flat = custom_vmap_p.bind(*consts, *args_flat,
call=closed_call,
rule=self.vmap_rule,
rule=ClosedRule(self.vmap_rule),
in_tree=in_tree,
out_tree=out_tree())
return tree_unflatten(out_tree(), out_flat)
@ -78,6 +79,21 @@ class custom_vmap:
### utils
# Define a class, instead of making a function closing over `rule`, so
# that we can override __str__
class ClosedRule:
def __init__(self, rule):
functools.update_wrapper(self, rule)
self.rule = rule
def __call__(self, axis_size, all_in_batched, *all_args):
_, args = all_args
consts_batched, in_batched = all_in_batched
assert not any(tree_util.tree_leaves(consts_batched)), consts_batched
return call_rule(self.rule, axis_size, in_batched, args)
def __str__(self):
return str(self.rule)
def ensure_list(xs):
return xs if type(xs) is list else list(xs)

View File

@ -8805,6 +8805,28 @@ class CustomVmapTest(jtu.JaxTestCase):
ys = api.vmap(f)(xs)
self.assertAllClose(ys, jnp.cos(xs))
def test_closure(self):
z = jnp.array([2., 1., 3.])
@api.custom_vmap
def f(x): return z + jnp.sin(x)
@f.def_vmap
def rule(axis_size, in_batched, *args):
self.assertEqual(len(in_batched), 1)
self.assertEqual(len(args), 1)
xs, = args
xs_batched, = in_batched
self.assertEqual(xs_batched, True)
self.assertEqual(axis_size, xs.shape[0])
return z + jnp.cos(xs), xs_batched
x, xs = jnp.array(1.), jnp.arange(3)
y = f(x)
self.assertAllClose(y, z + jnp.sin(x))
ys = api.vmap(f)(xs)
self.assertAllClose(ys, z + jnp.cos(xs))
def test_rule_multi_output(self):
@api.custom_vmap
def f(x): return jnp.sin(x), jnp.cos(x)
@ -8962,6 +8984,36 @@ class CustomVmapTest(jtu.JaxTestCase):
self.assertAllClose(ys, jnp.cos(xs))
self.assertAllClose(tys, -jnp.sin(xs) * txs)
def test_jvp_closure(self):
z = jnp.array([2., 1., 3.])
def bcast(x): return z + x - z
@api.custom_vmap
def f(x): return z + jnp.sin(x)
@f.def_vmap
def rule(axis_size, in_batched, xs):
self.assertEqual(axis_size, 3)
self.assertEqual(in_batched, [True])
return z + jnp.cos(xs), in_batched[0]
f_jvp = lambda x, tx: api.jvp(f, [x], [tx])
x, tx = jnp.array(1.), jnp.array(2.)
xs, txs = jnp.arange(3.), jnp.arange(3.) * 2.
y, ty = f_jvp(x, tx)
self.assertAllClose(y, z + jnp.sin(x))
self.assertAllClose(ty, bcast(jnp.cos(x)) * tx)
ys, tys = api.vmap(f_jvp)(xs, txs)
self.assertAllClose(ys, z + jnp.cos(xs))
self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs)
ys, tys = api.jvp(api.vmap(f), [xs], [txs])
self.assertAllClose(ys, z + jnp.cos(xs))
self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs)
def test_jvp_nary(self):
@api.custom_vmap
def f(x, y): return jnp.sin(x) + y