Add JVP of xmap

This commit is contained in:
Adam Paszke 2021-04-14 14:01:28 +00:00
parent 9d56552517
commit c13efc1223
2 changed files with 12 additions and 4 deletions

View File

@ -35,6 +35,7 @@ from ..interpreters import partial_eval as pe
from ..interpreters import pxla
from ..interpreters import xla
from ..interpreters import batching
from ..interpreters import ad
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from .._src.util import safe_map, safe_zip, HashableFunction, as_hashable_function, unzip2
@ -638,10 +639,7 @@ class EvaluationPlan(NamedTuple):
# -------- xmap primitive and its transforms --------
# xmap has a different set of parameters than pmap, so we make it its own primitive type
class XMapPrimitive(core.Primitive):
multiple_results = True
map_primitive = True # Not really, but it gives us a few good behaviors
class XMapPrimitive(core.MapPrimitive): # Not really a map, but it gives us a few good defaults
def __init__(self):
super().__init__('xmap')
self.def_impl(xmap_impl)
@ -670,6 +668,8 @@ def _xmap_axis_subst(params, subst):
return dict(params, call_jaxpr=new_jaxpr)
core.axis_substitution_rules[xmap_p] = _xmap_axis_subst
ad.JVPTrace.process_xmap = ad.JVPTrace.process_call # type: ignore
ad.call_param_updaters[xmap_p] = ad.call_param_updaters[xla.xla_call_p]
# This is DynamicJaxprTrace.process_map with some very minor modifications
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):

View File

@ -543,6 +543,14 @@ class XMapTest(XMapTestCase):
y = rng.randn(*yshape)
self.assertAllClose(fm(x, y), fref(x, y))
def testJVP(self):
f = xmap(lambda x, y: jnp.cos(lax.dot(x, jnp.sin(y),
precision=lax.Precision.HIGHEST)),
in_axes=[['i', ...], {}], out_axes=['i', ...])
x = jnp.arange(12, dtype=jnp.float32).reshape((3, 4)) / 100
y = jnp.arange(20, dtype=jnp.float32).reshape((4, 5)) / 100
jtu.check_grads(f, (x, y), order=2, modes=['fwd'])
class XMapTestSPMD(SPMDTestMixin, XMapTest):
"""Re-executes all basic tests with the SPMD partitioner enabled"""