mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add JVP of xmap
This commit is contained in:
parent
9d56552517
commit
c13efc1223
@ -35,6 +35,7 @@ from ..interpreters import partial_eval as pe
|
||||
from ..interpreters import pxla
|
||||
from ..interpreters import xla
|
||||
from ..interpreters import batching
|
||||
from ..interpreters import ad
|
||||
from ..lib import xla_bridge as xb
|
||||
from ..lib import xla_client as xc
|
||||
from .._src.util import safe_map, safe_zip, HashableFunction, as_hashable_function, unzip2
|
||||
@ -638,10 +639,7 @@ class EvaluationPlan(NamedTuple):
|
||||
# -------- xmap primitive and its transforms --------
|
||||
|
||||
# xmap has a different set of parameters than pmap, so we make it its own primitive type
|
||||
class XMapPrimitive(core.Primitive):
|
||||
multiple_results = True
|
||||
map_primitive = True # Not really, but it gives us a few good behaviors
|
||||
|
||||
class XMapPrimitive(core.MapPrimitive): # Not really a map, but it gives us a few good defaults
|
||||
def __init__(self):
|
||||
super().__init__('xmap')
|
||||
self.def_impl(xmap_impl)
|
||||
@ -670,6 +668,8 @@ def _xmap_axis_subst(params, subst):
|
||||
return dict(params, call_jaxpr=new_jaxpr)
|
||||
core.axis_substitution_rules[xmap_p] = _xmap_axis_subst
|
||||
|
||||
ad.JVPTrace.process_xmap = ad.JVPTrace.process_call # type: ignore
|
||||
ad.call_param_updaters[xmap_p] = ad.call_param_updaters[xla.xla_call_p]
|
||||
|
||||
# This is DynamicJaxprTrace.process_map with some very minor modifications
|
||||
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
|
||||
|
@ -543,6 +543,14 @@ class XMapTest(XMapTestCase):
|
||||
y = rng.randn(*yshape)
|
||||
self.assertAllClose(fm(x, y), fref(x, y))
|
||||
|
||||
def testJVP(self):
|
||||
f = xmap(lambda x, y: jnp.cos(lax.dot(x, jnp.sin(y),
|
||||
precision=lax.Precision.HIGHEST)),
|
||||
in_axes=[['i', ...], {}], out_axes=['i', ...])
|
||||
x = jnp.arange(12, dtype=jnp.float32).reshape((3, 4)) / 100
|
||||
y = jnp.arange(20, dtype=jnp.float32).reshape((4, 5)) / 100
|
||||
jtu.check_grads(f, (x, y), order=2, modes=['fwd'])
|
||||
|
||||
|
||||
class XMapTestSPMD(SPMDTestMixin, XMapTest):
|
||||
"""Re-executes all basic tests with the SPMD partitioner enabled"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user