Merge pull request #13366 from froystig:custom-vmap-consts

PiperOrigin-RevId: 490412416
This commit is contained in:
jax authors 2022-11-22 22:09:57 -08:00
commit d045fe2f95
2 changed files with 70 additions and 2 deletions

View File

@ -19,6 +19,7 @@ from typing import Callable, Optional
import jax
from jax import core
from jax import linear_util as lu
from jax import tree_util
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters.batching import not_mapped
@ -66,11 +67,11 @@ class custom_vmap:
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
debug = pe.debug_info(self.fun, in_tree, False, "custom_vmap")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
assert not len(consts)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
in_tree = treedef_tuple((tree_structure(consts), in_tree))
out_flat = custom_vmap_p.bind(*consts, *args_flat,
call=closed_call,
rule=self.vmap_rule,
rule=ClosedRule(self.vmap_rule),
in_tree=in_tree,
out_tree=out_tree())
return tree_unflatten(out_tree(), out_flat)
@ -78,6 +79,21 @@ class custom_vmap:
### utils
# Define a class, instead of making a function closing over `rule`, so
# that we can override __str__
class ClosedRule:
def __init__(self, rule):
functools.update_wrapper(self, rule)
self.rule = rule
def __call__(self, axis_size, all_in_batched, *all_args):
_, args = all_args
consts_batched, in_batched = all_in_batched
assert not any(tree_util.tree_leaves(consts_batched)), consts_batched
return call_rule(self.rule, axis_size, in_batched, args)
def __str__(self):
return str(self.rule)
def ensure_list(xs):
return xs if type(xs) is list else list(xs)

View File

@ -8805,6 +8805,28 @@ class CustomVmapTest(jtu.JaxTestCase):
ys = api.vmap(f)(xs)
self.assertAllClose(ys, jnp.cos(xs))
def test_closure(self):
z = jnp.array([2., 1., 3.])
@api.custom_vmap
def f(x): return z + jnp.sin(x)
@f.def_vmap
def rule(axis_size, in_batched, *args):
self.assertEqual(len(in_batched), 1)
self.assertEqual(len(args), 1)
xs, = args
xs_batched, = in_batched
self.assertEqual(xs_batched, True)
self.assertEqual(axis_size, xs.shape[0])
return z + jnp.cos(xs), xs_batched
x, xs = jnp.array(1.), jnp.arange(3)
y = f(x)
self.assertAllClose(y, z + jnp.sin(x))
ys = api.vmap(f)(xs)
self.assertAllClose(ys, z + jnp.cos(xs))
def test_rule_multi_output(self):
@api.custom_vmap
def f(x): return jnp.sin(x), jnp.cos(x)
@ -8962,6 +8984,36 @@ class CustomVmapTest(jtu.JaxTestCase):
self.assertAllClose(ys, jnp.cos(xs))
self.assertAllClose(tys, -jnp.sin(xs) * txs)
def test_jvp_closure(self):
z = jnp.array([2., 1., 3.])
def bcast(x): return z + x - z
@api.custom_vmap
def f(x): return z + jnp.sin(x)
@f.def_vmap
def rule(axis_size, in_batched, xs):
self.assertEqual(axis_size, 3)
self.assertEqual(in_batched, [True])
return z + jnp.cos(xs), in_batched[0]
f_jvp = lambda x, tx: api.jvp(f, [x], [tx])
x, tx = jnp.array(1.), jnp.array(2.)
xs, txs = jnp.arange(3.), jnp.arange(3.) * 2.
y, ty = f_jvp(x, tx)
self.assertAllClose(y, z + jnp.sin(x))
self.assertAllClose(ty, bcast(jnp.cos(x)) * tx)
ys, tys = api.vmap(f_jvp)(xs, txs)
self.assertAllClose(ys, z + jnp.cos(xs))
self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs)
ys, tys = api.jvp(api.vmap(f), [xs], [txs])
self.assertAllClose(ys, z + jnp.cos(xs))
self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs)
def test_jvp_nary(self):
@api.custom_vmap
def f(x, y): return jnp.sin(x) + y