mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
custom vmap: support closure and staged constants
The `custom_vmap` primitive stages out its wrapped function at call time. It might extract closed-over or otherwise constant values ("consts") in doing so. To handle these, we can reduce back to the empty closure setting: convert the consts to formal arguments, both in the target function and in the custom vmap rule, and ignore them in the latter. We only need to play this trick once, on initial entry. After that, we can resume in assuming an empty closure.
This commit is contained in:
parent
7128bb4ac9
commit
ef9b2fe4a1
@ -19,6 +19,7 @@ from typing import Callable, Optional
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax import tree_util
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters.batching import not_mapped
|
||||
@ -66,11 +67,11 @@ class custom_vmap:
|
||||
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
|
||||
debug = pe.debug_info(self.fun, in_tree, False, "custom_vmap")
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
|
||||
assert not len(consts)
|
||||
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
|
||||
in_tree = treedef_tuple((tree_structure(consts), in_tree))
|
||||
out_flat = custom_vmap_p.bind(*consts, *args_flat,
|
||||
call=closed_call,
|
||||
rule=self.vmap_rule,
|
||||
rule=ClosedRule(self.vmap_rule),
|
||||
in_tree=in_tree,
|
||||
out_tree=out_tree())
|
||||
return tree_unflatten(out_tree(), out_flat)
|
||||
@ -78,6 +79,21 @@ class custom_vmap:
|
||||
|
||||
### utils
|
||||
|
||||
# Define a class, instead of making a function closing over `rule`, so
|
||||
# that we can override __str__
|
||||
class ClosedRule:
|
||||
def __init__(self, rule):
|
||||
functools.update_wrapper(self, rule)
|
||||
self.rule = rule
|
||||
|
||||
def __call__(self, axis_size, all_in_batched, *all_args):
|
||||
_, args = all_args
|
||||
consts_batched, in_batched = all_in_batched
|
||||
assert not any(tree_util.tree_leaves(consts_batched)), consts_batched
|
||||
return call_rule(self.rule, axis_size, in_batched, args)
|
||||
|
||||
def __str__(self):
|
||||
return str(self.rule)
|
||||
|
||||
def ensure_list(xs):
|
||||
return xs if type(xs) is list else list(xs)
|
||||
|
@ -8805,6 +8805,28 @@ class CustomVmapTest(jtu.JaxTestCase):
|
||||
ys = api.vmap(f)(xs)
|
||||
self.assertAllClose(ys, jnp.cos(xs))
|
||||
|
||||
def test_closure(self):
|
||||
z = jnp.array([2., 1., 3.])
|
||||
|
||||
@api.custom_vmap
|
||||
def f(x): return z + jnp.sin(x)
|
||||
|
||||
@f.def_vmap
|
||||
def rule(axis_size, in_batched, *args):
|
||||
self.assertEqual(len(in_batched), 1)
|
||||
self.assertEqual(len(args), 1)
|
||||
xs, = args
|
||||
xs_batched, = in_batched
|
||||
self.assertEqual(xs_batched, True)
|
||||
self.assertEqual(axis_size, xs.shape[0])
|
||||
return z + jnp.cos(xs), xs_batched
|
||||
|
||||
x, xs = jnp.array(1.), jnp.arange(3)
|
||||
y = f(x)
|
||||
self.assertAllClose(y, z + jnp.sin(x))
|
||||
ys = api.vmap(f)(xs)
|
||||
self.assertAllClose(ys, z + jnp.cos(xs))
|
||||
|
||||
def test_rule_multi_output(self):
|
||||
@api.custom_vmap
|
||||
def f(x): return jnp.sin(x), jnp.cos(x)
|
||||
@ -8962,6 +8984,36 @@ class CustomVmapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ys, jnp.cos(xs))
|
||||
self.assertAllClose(tys, -jnp.sin(xs) * txs)
|
||||
|
||||
def test_jvp_closure(self):
|
||||
z = jnp.array([2., 1., 3.])
|
||||
def bcast(x): return z + x - z
|
||||
|
||||
@api.custom_vmap
|
||||
def f(x): return z + jnp.sin(x)
|
||||
|
||||
@f.def_vmap
|
||||
def rule(axis_size, in_batched, xs):
|
||||
self.assertEqual(axis_size, 3)
|
||||
self.assertEqual(in_batched, [True])
|
||||
return z + jnp.cos(xs), in_batched[0]
|
||||
|
||||
f_jvp = lambda x, tx: api.jvp(f, [x], [tx])
|
||||
|
||||
x, tx = jnp.array(1.), jnp.array(2.)
|
||||
xs, txs = jnp.arange(3.), jnp.arange(3.) * 2.
|
||||
|
||||
y, ty = f_jvp(x, tx)
|
||||
self.assertAllClose(y, z + jnp.sin(x))
|
||||
self.assertAllClose(ty, bcast(jnp.cos(x)) * tx)
|
||||
|
||||
ys, tys = api.vmap(f_jvp)(xs, txs)
|
||||
self.assertAllClose(ys, z + jnp.cos(xs))
|
||||
self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs)
|
||||
|
||||
ys, tys = api.jvp(api.vmap(f), [xs], [txs])
|
||||
self.assertAllClose(ys, z + jnp.cos(xs))
|
||||
self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs)
|
||||
|
||||
def test_jvp_nary(self):
|
||||
@api.custom_vmap
|
||||
def f(x, y): return jnp.sin(x) + y
|
||||
|
Loading…
x
Reference in New Issue
Block a user