mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #13366 from froystig:custom-vmap-consts
PiperOrigin-RevId: 490412416
This commit is contained in:
commit
d045fe2f95
@ -19,6 +19,7 @@ from typing import Callable, Optional
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax import tree_util
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters.batching import not_mapped
|
||||
@ -66,11 +67,11 @@ class custom_vmap:
|
||||
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
|
||||
debug = pe.debug_info(self.fun, in_tree, False, "custom_vmap")
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
|
||||
assert not len(consts)
|
||||
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
|
||||
in_tree = treedef_tuple((tree_structure(consts), in_tree))
|
||||
out_flat = custom_vmap_p.bind(*consts, *args_flat,
|
||||
call=closed_call,
|
||||
rule=self.vmap_rule,
|
||||
rule=ClosedRule(self.vmap_rule),
|
||||
in_tree=in_tree,
|
||||
out_tree=out_tree())
|
||||
return tree_unflatten(out_tree(), out_flat)
|
||||
@ -78,6 +79,21 @@ class custom_vmap:
|
||||
|
||||
### utils
|
||||
|
||||
# Define a class, instead of making a function closing over `rule`, so
|
||||
# that we can override __str__
|
||||
class ClosedRule:
|
||||
def __init__(self, rule):
|
||||
functools.update_wrapper(self, rule)
|
||||
self.rule = rule
|
||||
|
||||
def __call__(self, axis_size, all_in_batched, *all_args):
|
||||
_, args = all_args
|
||||
consts_batched, in_batched = all_in_batched
|
||||
assert not any(tree_util.tree_leaves(consts_batched)), consts_batched
|
||||
return call_rule(self.rule, axis_size, in_batched, args)
|
||||
|
||||
def __str__(self):
|
||||
return str(self.rule)
|
||||
|
||||
def ensure_list(xs):
|
||||
return xs if type(xs) is list else list(xs)
|
||||
|
@ -8805,6 +8805,28 @@ class CustomVmapTest(jtu.JaxTestCase):
|
||||
ys = api.vmap(f)(xs)
|
||||
self.assertAllClose(ys, jnp.cos(xs))
|
||||
|
||||
def test_closure(self):
|
||||
z = jnp.array([2., 1., 3.])
|
||||
|
||||
@api.custom_vmap
|
||||
def f(x): return z + jnp.sin(x)
|
||||
|
||||
@f.def_vmap
|
||||
def rule(axis_size, in_batched, *args):
|
||||
self.assertEqual(len(in_batched), 1)
|
||||
self.assertEqual(len(args), 1)
|
||||
xs, = args
|
||||
xs_batched, = in_batched
|
||||
self.assertEqual(xs_batched, True)
|
||||
self.assertEqual(axis_size, xs.shape[0])
|
||||
return z + jnp.cos(xs), xs_batched
|
||||
|
||||
x, xs = jnp.array(1.), jnp.arange(3)
|
||||
y = f(x)
|
||||
self.assertAllClose(y, z + jnp.sin(x))
|
||||
ys = api.vmap(f)(xs)
|
||||
self.assertAllClose(ys, z + jnp.cos(xs))
|
||||
|
||||
def test_rule_multi_output(self):
|
||||
@api.custom_vmap
|
||||
def f(x): return jnp.sin(x), jnp.cos(x)
|
||||
@ -8962,6 +8984,36 @@ class CustomVmapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ys, jnp.cos(xs))
|
||||
self.assertAllClose(tys, -jnp.sin(xs) * txs)
|
||||
|
||||
def test_jvp_closure(self):
|
||||
z = jnp.array([2., 1., 3.])
|
||||
def bcast(x): return z + x - z
|
||||
|
||||
@api.custom_vmap
|
||||
def f(x): return z + jnp.sin(x)
|
||||
|
||||
@f.def_vmap
|
||||
def rule(axis_size, in_batched, xs):
|
||||
self.assertEqual(axis_size, 3)
|
||||
self.assertEqual(in_batched, [True])
|
||||
return z + jnp.cos(xs), in_batched[0]
|
||||
|
||||
f_jvp = lambda x, tx: api.jvp(f, [x], [tx])
|
||||
|
||||
x, tx = jnp.array(1.), jnp.array(2.)
|
||||
xs, txs = jnp.arange(3.), jnp.arange(3.) * 2.
|
||||
|
||||
y, ty = f_jvp(x, tx)
|
||||
self.assertAllClose(y, z + jnp.sin(x))
|
||||
self.assertAllClose(ty, bcast(jnp.cos(x)) * tx)
|
||||
|
||||
ys, tys = api.vmap(f_jvp)(xs, txs)
|
||||
self.assertAllClose(ys, z + jnp.cos(xs))
|
||||
self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs)
|
||||
|
||||
ys, tys = api.jvp(api.vmap(f), [xs], [txs])
|
||||
self.assertAllClose(ys, z + jnp.cos(xs))
|
||||
self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs)
|
||||
|
||||
def test_jvp_nary(self):
|
||||
@api.custom_vmap
|
||||
def f(x, y): return jnp.sin(x) + y
|
||||
|
Loading…
x
Reference in New Issue
Block a user