mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Merge pull request #15592 from mattjj:shmap-post-process-scalar-res
PiperOrigin-RevId: 524121962
This commit is contained in:
commit
468d7720bf
@ -947,6 +947,8 @@ def _shard_map_partial_eval_post_process(
|
||||
del check_rep
|
||||
unk_tracers = [t for t in tracers if not t.is_known()]
|
||||
jaxpr, res, env = pe.tracers_to_jaxpr([], unk_tracers)
|
||||
jaxpr, res = _promote_scalar_residuals_jaxpr(jaxpr, res)
|
||||
|
||||
out_knowns, out_avals_, consts = pe.partition_pvals([t.pval for t in tracers])
|
||||
out = [*consts, *res]
|
||||
main = trace.main
|
||||
@ -985,19 +987,20 @@ pe.JaxprTrace.post_process_shard_map = _shard_map_partial_eval_post_process
|
||||
@lu.transformation
|
||||
def _promote_scalar_residuals(*args, **kwargs):
|
||||
jaxpr, (out_pvals, out_consts, env) = yield args, kwargs
|
||||
which_scalar = [isinstance(v.aval, core.ShapedArray) and not v.aval.shape
|
||||
for v in jaxpr.constvars]
|
||||
out_consts_ = [jax.lax.broadcast(x, (1,)) if scalar else x
|
||||
for x, scalar in zip(out_consts, which_scalar)]
|
||||
jaxpr, out_consts = _promote_scalar_residuals_jaxpr(jaxpr, out_consts)
|
||||
yield jaxpr, (out_pvals, out_consts, env)
|
||||
|
||||
def _promote_scalar_residuals_jaxpr(jaxpr, res):
|
||||
which = [isinstance(v.aval, core.ShapedArray) and not v.aval.shape
|
||||
for v in jaxpr.constvars]
|
||||
res_ = [jax.lax.broadcast(x, (1,)) if s else x for x, s in zip(res, which)]
|
||||
|
||||
@lu.wrap_init
|
||||
def fun(*args):
|
||||
out_consts = [x.reshape(*x.shape[1:]) if scalar else x
|
||||
for x, scalar in zip(out_consts_, which_scalar)]
|
||||
return core.eval_jaxpr(jaxpr, out_consts, *args)
|
||||
in_avals = [v.aval for v in jaxpr.invars]
|
||||
jaxpr, _, out_consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
|
||||
yield jaxpr, (out_pvals, out_consts, env)
|
||||
res = [_rem_singleton(x) if s else x for x, s in zip(res_, which)]
|
||||
return core.eval_jaxpr(jaxpr, res, *args)
|
||||
jaxpr, _, res = pe.trace_to_jaxpr_dynamic(fun, [v.aval for v in jaxpr.invars])
|
||||
return jaxpr, res
|
||||
|
||||
def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
|
||||
check_rep):
|
||||
|
@ -812,6 +812,12 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
expected_num_eqns=1 + 1, # one outer eqn, two remain in body
|
||||
check_diff=False)
|
||||
|
||||
def test_post_process_partial_eval_with_scalar_res(self):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('i', 'j'))
|
||||
g = jax.grad(lambda x: shard_map(lambda: jnp.sin(x), mesh=mesh,
|
||||
in_specs=P(), out_specs=P())())(2.0)
|
||||
self.assertAllClose(g, jnp.cos(2.0), check_dtypes=False)
|
||||
|
||||
|
||||
class FunSpec(NamedTuple):
|
||||
name: str
|
||||
|
Loading…
x
Reference in New Issue
Block a user