mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Set the mesh as manual during partial_eval_custom in shard_map so that _add_reshapes
happens under the correct mesh.
PiperOrigin-RevId: 723268798
This commit is contained in:
parent
02f4531310
commit
307006e194
@ -530,8 +530,8 @@ class AbstractMesh:
|
||||
|
||||
@staticmethod
|
||||
def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh):
|
||||
jax_config.abstract_mesh_context_manager.set_local(mesh)
|
||||
return
|
||||
prev = jax_config.abstract_mesh_context_manager.swap_local(mesh)
|
||||
return prev
|
||||
|
||||
|
||||
# Create this indirection because pytype fails to recognize a property if a
|
||||
|
@ -1719,7 +1719,9 @@ def _partial_eval_jaxpr_custom_rule(
|
||||
idx_map = {id(v): i for i, v in enumerate(out_vars)}
|
||||
out_fwd = [idx_map.get(id(v)) for v in res_vars]
|
||||
which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
|
||||
with core.extend_axis_env_nd(eqn.params['mesh'].shape.items()):
|
||||
mesh = eqn.params['mesh']
|
||||
with (core.extend_axis_env_nd(mesh.shape.items()),
|
||||
set_abstract_mesh(_as_manual_mesh(mesh))):
|
||||
jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
|
||||
jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
|
||||
jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names)
|
||||
|
Loading…
x
Reference in New Issue
Block a user