Rename some internal APIs (set_abstract_mesh -> use_abstract_mesh and set_concrete_mesh -> use_concrete_mesh)

PiperOrigin-RevId: 736382641
This commit is contained in:
Yash Katariya 2025-03-12 22:29:08 -07:00 committed by jax authors
parent a4ca0dbc6c
commit 2d01226b3b
9 changed files with 26 additions and 28 deletions

View File

@ -349,7 +349,7 @@ devices in the mesh are ignored for tracing and lowering:
>>> from jax.sharding import PartitionSpec as P
>>>
>>> # Use an AbstractMesh for exporting
>>> export_mesh = AbstractMesh((("a", 4),))
>>> export_mesh = AbstractMesh((4,), ("a",))
>>> def f(x):
... return x.T

View File

@ -33,7 +33,7 @@ from jax._src import errors
from jax._src import profiler
from jax._src import util
from jax._src import xla_bridge
from jax._src.mesh import set_concrete_mesh
from jax._src.mesh import use_concrete_mesh
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
@ -1112,7 +1112,7 @@ def shard_device_array(x, devices, indices, sharding):
else:
# TODO(yashkatariya): Maybe this should be set when we call the handler in
# InputsHandler.__call__?
with set_concrete_mesh(None):
with use_concrete_mesh(None):
shards = x._multi_slice(start_indices, limit_indices, removed_dims)
aval = core.shaped_abstractify(x)
return pxla.batched_device_put(aval, sharding, shards, devices)

View File

@ -972,7 +972,7 @@ def shard_map_error_check(
in_avals[i] = sharder(mesh, auto, new_in_names[i], v)
with (shard_map._extend_axis_env(mesh, auto),
mesh_lib.set_abstract_mesh(shard_map._as_manual_mesh(mesh, auto))):
mesh_lib.use_abstract_mesh(shard_map._as_manual_mesh(mesh, auto))):
# jaxpr to checked_jaxpr
checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(
pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals

View File

@ -1008,7 +1008,7 @@ def instantiate_zeros(tangent):
if hasattr(tangent.aval, 'sharding'):
# TODO(dougalm, yashkatariya): Delete this context manager once we figure
# out how to ensure jaxpr arguments always have the context mesh.
with mesh_lib.set_abstract_mesh(tangent.aval.sharding.mesh): # type: ignore
with mesh_lib.use_abstract_mesh(tangent.aval.sharding.mesh): # type: ignore
return zeros_like_aval(tangent.aval)
return zeros_like_aval(tangent.aval)
return tangent

View File

@ -1106,7 +1106,7 @@ def broadcast(x, sz, axis, mesh_axis=None):
sharding = x_aval.sharding.with_spec(new_spec)
# TODO(dougalm, yashkatariya): Delete this context manager once we figure
# out how to ensure jaxpr arguments always have the context mesh.
with mesh_lib.set_abstract_mesh(sharding.mesh):
with mesh_lib.use_abstract_mesh(sharding.mesh):
return jax.lax.broadcast_in_dim(x, shape, broadcast_dims,
out_sharding=sharding)

View File

@ -535,7 +535,9 @@ class AbstractMesh(_BaseMesh):
def _raise_value_error(name):
raise ValueError(f"AbstractMesh does not implement {name}")
class SetAbstractMeshContextManager:
empty_abstract_mesh = AbstractMesh((), ())
class UseAbstractMeshContextManager:
__slots__ = ['mesh', 'prev']
def __init__(self, mesh: AbstractMesh):
@ -547,18 +549,14 @@ class SetAbstractMeshContextManager:
def __exit__(self, exc_type, exc_value, traceback):
jax_config.abstract_mesh_context_manager.set_local(self.prev)
set_abstract_mesh = SetAbstractMeshContextManager
empty_abstract_mesh = AbstractMesh((), ())
use_abstract_mesh = UseAbstractMeshContextManager
def get_abstract_mesh():
val = jax_config.abstract_mesh_context_manager.value
return empty_abstract_mesh if val is None else val
@contextlib.contextmanager
def set_concrete_mesh(mesh: Mesh | None):
def use_concrete_mesh(mesh: Mesh | None):
prev_val = jax_config.device_context.swap_local(mesh)
try:
yield

View File

@ -2850,7 +2850,7 @@ def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None,
def decorator(*args, **kwargs):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto, 'auto_axes',
error_on_manual_to_auto_explict=True)
with mesh_lib.set_abstract_mesh(new_mesh):
with mesh_lib.use_abstract_mesh(new_mesh):
in_specs = tree_map(lambda a: core.modify_spec_for_auto_manual(
core.get_aval(a).sharding.spec, new_mesh), args)
args = mesh_cast(args, in_specs)
@ -2861,7 +2861,7 @@ def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None,
@contextlib.contextmanager
def use_auto_axes(*axes):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto, 'use_auto_axes')
with mesh_lib.set_abstract_mesh(new_mesh):
with mesh_lib.use_abstract_mesh(new_mesh):
yield
@ -2870,7 +2870,7 @@ def explicit_axes(fun, *, axes: str | tuple[str, ...] | None = None,
def decorator(*args, **kwargs):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Explicit, 'explicit_axes',
error_on_manual_to_auto_explict=True)
with mesh_lib.set_abstract_mesh(new_mesh):
with mesh_lib.use_abstract_mesh(new_mesh):
args = mesh_cast(args, in_shardings)
out = fun(*args, **kwargs)
out_specs = tree_map(lambda o: core.modify_spec_for_auto_manual(
@ -2882,7 +2882,7 @@ def explicit_axes(fun, *, axes: str | tuple[str, ...] | None = None,
def use_explicit_axes(*axes):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Explicit,
'use_explicit_axes')
with mesh_lib.set_abstract_mesh(new_mesh):
with mesh_lib.use_abstract_mesh(new_mesh):
yield
# -------------------- helpers --------------------

View File

@ -1389,8 +1389,8 @@ def use_mesh(mesh: mesh_lib.Mesh):
# if not core.trace_state_clean():
# raise ValueError('`use_mesh` can only be used outside of `jax.jit`')
with (mesh_lib.set_abstract_mesh(mesh.abstract_mesh),
mesh_lib.set_concrete_mesh(mesh)):
with (mesh_lib.use_abstract_mesh(mesh.abstract_mesh),
mesh_lib.use_concrete_mesh(mesh)):
yield
def set_mesh(mesh: mesh_lib.Mesh) -> None:

View File

@ -47,7 +47,7 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.core import Tracer
from jax._src.mesh import (AbstractMesh, Mesh, AxisTypes, set_abstract_mesh,
from jax._src.mesh import (AbstractMesh, Mesh, AxisTypes, use_abstract_mesh,
get_abstract_mesh)
from jax._src.api import _shared_code_pmap, _prepare_pmap
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
@ -523,7 +523,7 @@ def _shard_map_staging(
in_avals = [t.aval for t in in_tracers]
in_avals_ = map(partial(_shard_aval, mesh, auto), in_names, in_avals)
manual_mesh = _as_manual_mesh(mesh, auto)
with _extend_axis_env(mesh, auto), set_abstract_mesh(manual_mesh):
with _extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh):
jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_)
_check_names(out_names_thunk(), out_avals_)
if check_rep:
@ -539,7 +539,7 @@ def _shard_map_staging(
constvars = map(trace.getvar, map(trace.to_jaxpr_tracer, consts))
outvars = map(trace.makevar, out_tracers)
in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore
with _extend_axis_env(mesh, auto), set_abstract_mesh(manual_mesh):
with _extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh):
jaxpr = pe.convert_constvars_jaxpr(jaxpr)
params = dict(mesh=mesh, in_names=in_names_staged,
out_names=tuple(out_names_thunk()), jaxpr=jaxpr,
@ -857,7 +857,7 @@ def _run_shmap(f, mesh, auto, args, reps, check_rep, context_mesh):
in_tracers = map(partial(ShardMapTracer, trace), reps, args)
manual_mesh = _as_manual_mesh(mesh, auto)
with (core.set_current_trace(trace), _extend_axis_env(mesh, auto),
set_abstract_mesh(manual_mesh)):
use_abstract_mesh(manual_mesh)):
ans = f.call_wrapped(*in_tracers)
outs, out_rep = unzip2(map(trace.to_val_rep_pair, ans))
return outs, out_rep
@ -869,7 +869,7 @@ def _names_to_pspec(names: AxisNames) -> PartitionSpec:
def _unmatch_spec(mesh: Mesh, src: AxisNames, x: JaxType, context_mesh) -> JaxType:
with (core.eval_context(), jax.disable_jit(False),
set_abstract_mesh(context_mesh)):
use_abstract_mesh(context_mesh)):
return jax.jit(HashablePartial(_unmatch, mesh, tuple(src.items())))(x)
def _unmatch(mesh, src_tup, x):
@ -948,7 +948,7 @@ class ShardMapTrace(core.Trace):
else:
f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh)
with (core.eval_context(), jax.disable_jit(False), jax.debug_nans(False),
jax.debug_infs(False), set_abstract_mesh(self.context_mesh)):
jax.debug_infs(False), use_abstract_mesh(self.context_mesh)):
out_vals = jax.jit(f)(*in_vals)
_maybe_check_special(out_vals)
rep_rule = _check_rules.get(prim, partial(_rule_missing, prim))
@ -1018,13 +1018,13 @@ class ShardMapTracer(core.Tracer):
def to_concrete_value(self):
if self.rep == set(self._trace.mesh.axis_names):
with core.eval_context(), set_abstract_mesh(self._trace.context_mesh):
with core.eval_context(), use_abstract_mesh(self._trace.context_mesh):
return core.to_concrete_value(self.val[0])
else:
return None
def __str__(self) -> str:
with core.eval_context(), set_abstract_mesh(self._trace.context_mesh):
with core.eval_context(), use_abstract_mesh(self._trace.context_mesh):
blocks = list(self.val)
mesh = self._trace.mesh
axis_names = f"({', '.join(map(str, mesh.axis_names))},)"
@ -1801,7 +1801,7 @@ def _partial_eval_jaxpr_custom_rule(
which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
mesh = eqn.params['mesh']
with (_extend_axis_env(mesh, auto),
set_abstract_mesh(_as_manual_mesh(mesh, auto))):
use_abstract_mesh(_as_manual_mesh(mesh, auto))):
jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names)