Merge pull request #15592 from mattjj:shmap-post-process-scalar-res

PiperOrigin-RevId: 524121962
This commit is contained in:
jax authors 2023-04-13 15:40:07 -07:00
commit 468d7720bf
2 changed files with 19 additions and 10 deletions

View File

@ -947,6 +947,8 @@ def _shard_map_partial_eval_post_process(
del check_rep
unk_tracers = [t for t in tracers if not t.is_known()]
jaxpr, res, env = pe.tracers_to_jaxpr([], unk_tracers)
jaxpr, res = _promote_scalar_residuals_jaxpr(jaxpr, res)
out_knowns, out_avals_, consts = pe.partition_pvals([t.pval for t in tracers])
out = [*consts, *res]
main = trace.main
@ -985,19 +987,20 @@ pe.JaxprTrace.post_process_shard_map = _shard_map_partial_eval_post_process
@lu.transformation
def _promote_scalar_residuals(*args, **kwargs):
jaxpr, (out_pvals, out_consts, env) = yield args, kwargs
which_scalar = [isinstance(v.aval, core.ShapedArray) and not v.aval.shape
for v in jaxpr.constvars]
out_consts_ = [jax.lax.broadcast(x, (1,)) if scalar else x
for x, scalar in zip(out_consts, which_scalar)]
jaxpr, out_consts = _promote_scalar_residuals_jaxpr(jaxpr, out_consts)
yield jaxpr, (out_pvals, out_consts, env)
def _promote_scalar_residuals_jaxpr(jaxpr, res):
which = [isinstance(v.aval, core.ShapedArray) and not v.aval.shape
for v in jaxpr.constvars]
res_ = [jax.lax.broadcast(x, (1,)) if s else x for x, s in zip(res, which)]
@lu.wrap_init
def fun(*args):
out_consts = [x.reshape(*x.shape[1:]) if scalar else x
for x, scalar in zip(out_consts_, which_scalar)]
return core.eval_jaxpr(jaxpr, out_consts, *args)
in_avals = [v.aval for v in jaxpr.invars]
jaxpr, _, out_consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
yield jaxpr, (out_pvals, out_consts, env)
res = [_rem_singleton(x) if s else x for x, s in zip(res_, which)]
return core.eval_jaxpr(jaxpr, res, *args)
jaxpr, _, res = pe.trace_to_jaxpr_dynamic(fun, [v.aval for v in jaxpr.invars])
return jaxpr, res
def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
check_rep):

View File

@ -812,6 +812,12 @@ class ShardMapTest(jtu.JaxTestCase):
expected_num_eqns=1 + 1, # one outer eqn, two remain in body
check_diff=False)
def test_post_process_partial_eval_with_scalar_res(self):
mesh = jtu.create_global_mesh((4, 2), ('i', 'j'))
g = jax.grad(lambda x: shard_map(lambda: jnp.sin(x), mesh=mesh,
in_specs=P(), out_specs=P())())(2.0)
self.assertAllClose(g, jnp.cos(2.0), check_dtypes=False)
class FunSpec(NamedTuple):
name: str