Set the mesh as manual during partial_eval_custom in shard_map so that _add_reshapes happens under the correct mesh.

PiperOrigin-RevId: 723268798
This commit is contained in:
Yash Katariya 2025-02-04 16:35:32 -08:00 committed by jax authors
parent 02f4531310
commit 307006e194
2 changed files with 5 additions and 3 deletions

View File

@@ -530,8 +530,8 @@ class AbstractMesh:
@staticmethod
def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh):
-jax_config.abstract_mesh_context_manager.set_local(mesh)
-return
+prev = jax_config.abstract_mesh_context_manager.swap_local(mesh)
+return prev
# Create this indirection because pytype fails to recognize a property if a

View File

@@ -1719,7 +1719,9 @@ def _partial_eval_jaxpr_custom_rule(
idx_map = {id(v): i for i, v in enumerate(out_vars)}
out_fwd = [idx_map.get(id(v)) for v in res_vars]
which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
-with core.extend_axis_env_nd(eqn.params['mesh'].shape.items()):
+mesh = eqn.params['mesh']
+with (core.extend_axis_env_nd(mesh.shape.items()),
+set_abstract_mesh(_as_manual_mesh(mesh))):
jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names)