From 2d01226b3bf400b7db86802a5114b4a4877acc13 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 12 Mar 2025 22:29:08 -0700 Subject: [PATCH] Rename some internal APIs (set_abstract_mesh -> use_abstract_mesh and set_concrete_mesh -> use_concrete_mesh) PiperOrigin-RevId: 736382641 --- docs/export/export.md | 2 +- jax/_src/array.py | 4 ++-- jax/_src/checkify.py | 2 +- jax/_src/interpreters/ad.py | 2 +- jax/_src/interpreters/batching.py | 2 +- jax/_src/mesh.py | 12 +++++------- jax/_src/pjit.py | 8 ++++---- jax/_src/sharding_impls.py | 4 ++-- jax/experimental/shard_map.py | 18 +++++++++--------- 9 files changed, 26 insertions(+), 28 deletions(-) diff --git a/docs/export/export.md b/docs/export/export.md index f1542d80d..18cdcc6c5 100644 --- a/docs/export/export.md +++ b/docs/export/export.md @@ -349,7 +349,7 @@ devices in the mesh are ignored for tracing and lowering: >>> from jax.sharding import PartitionSpec as P >>> >>> # Use an AbstractMesh for exporting ->>> export_mesh = AbstractMesh((("a", 4),)) +>>> export_mesh = AbstractMesh((4,), ("a",)) >>> def f(x): ... return x.T diff --git a/jax/_src/array.py b/jax/_src/array.py index e472dfcf3..a8e3ec3d0 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -33,7 +33,7 @@ from jax._src import errors from jax._src import profiler from jax._src import util from jax._src import xla_bridge -from jax._src.mesh import set_concrete_mesh +from jax._src.mesh import use_concrete_mesh from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.interpreters import xla @@ -1112,7 +1112,7 @@ def shard_device_array(x, devices, indices, sharding): else: # TODO(yashkatariya): Maybe this should be set when we call the handler in # InputsHandler.__call__? - with set_concrete_mesh(None): + with use_concrete_mesh(None): shards = x._multi_slice(start_indices, limit_indices, removed_dims) aval = core.shaped_abstractify(x) return pxla.batched_device_put(aval, sharding, shards, devices) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 74ea53714..8a797e1f3 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -972,7 +972,7 @@ def shard_map_error_check( in_avals[i] = sharder(mesh, auto, new_in_names[i], v) with (shard_map._extend_axis_env(mesh, auto), - mesh_lib.set_abstract_mesh(shard_map._as_manual_mesh(mesh, auto))): + mesh_lib.use_abstract_mesh(shard_map._as_manual_mesh(mesh, auto))): # jaxpr to checked_jaxpr checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr( pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 2f835ab83..c20108b08 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -1008,7 +1008,7 @@ def instantiate_zeros(tangent): if hasattr(tangent.aval, 'sharding'): # TODO(dougalm, yashkatariya): Delete this context manager once we figure # out how to ensure jaxpr arguments always have the context mesh. - with mesh_lib.set_abstract_mesh(tangent.aval.sharding.mesh): # type: ignore + with mesh_lib.use_abstract_mesh(tangent.aval.sharding.mesh): # type: ignore return zeros_like_aval(tangent.aval) return zeros_like_aval(tangent.aval) return tangent diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index edeef78c4..c7495d06e 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -1106,7 +1106,7 @@ def broadcast(x, sz, axis, mesh_axis=None): sharding = x_aval.sharding.with_spec(new_spec) # TODO(dougalm, yashkatariya): Delete this context manager once we figure # out how to ensure jaxpr arguments always have the context mesh. - with mesh_lib.set_abstract_mesh(sharding.mesh): + with mesh_lib.use_abstract_mesh(sharding.mesh): return jax.lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 4c456e969..e4ac74650 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -535,7 +535,9 @@ class AbstractMesh(_BaseMesh): def _raise_value_error(name): raise ValueError(f"AbstractMesh does not implement {name}") -class SetAbstractMeshContextManager: +empty_abstract_mesh = AbstractMesh((), ()) + +class UseAbstractMeshContextManager: __slots__ = ['mesh', 'prev'] def __init__(self, mesh: AbstractMesh): @@ -547,18 +549,14 @@ class SetAbstractMeshContextManager: def __exit__(self, exc_type, exc_value, traceback): jax_config.abstract_mesh_context_manager.set_local(self.prev) -set_abstract_mesh = SetAbstractMeshContextManager - - -empty_abstract_mesh = AbstractMesh((), ()) +use_abstract_mesh = UseAbstractMeshContextManager def get_abstract_mesh(): val = jax_config.abstract_mesh_context_manager.value return empty_abstract_mesh if val is None else val - @contextlib.contextmanager -def set_concrete_mesh(mesh: Mesh | None): +def use_concrete_mesh(mesh: Mesh | None): prev_val = jax_config.device_context.swap_local(mesh) try: yield diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index caa9a0d58..fbe8b1b27 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2850,7 +2850,7 @@ def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None, def decorator(*args, **kwargs): new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto, 'auto_axes', error_on_manual_to_auto_explict=True) - with mesh_lib.set_abstract_mesh(new_mesh): + with mesh_lib.use_abstract_mesh(new_mesh): in_specs = tree_map(lambda a: core.modify_spec_for_auto_manual( core.get_aval(a).sharding.spec, new_mesh), args) args = mesh_cast(args, in_specs) @@ -2861,7 +2861,7 @@ def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None, @contextlib.contextmanager def use_auto_axes(*axes): new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto, 'use_auto_axes') - with mesh_lib.set_abstract_mesh(new_mesh): + with mesh_lib.use_abstract_mesh(new_mesh): yield @@ -2870,7 +2870,7 @@ def explicit_axes(fun, *, axes: str | tuple[str, ...] | None = None, def decorator(*args, **kwargs): new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Explicit, 'explicit_axes', error_on_manual_to_auto_explict=True) - with mesh_lib.set_abstract_mesh(new_mesh): + with mesh_lib.use_abstract_mesh(new_mesh): args = mesh_cast(args, in_shardings) out = fun(*args, **kwargs) out_specs = tree_map(lambda o: core.modify_spec_for_auto_manual( @@ -2882,7 +2882,7 @@ def explicit_axes(fun, *, axes: str | tuple[str, ...] | None = None, def use_explicit_axes(*axes): new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Explicit, 'use_explicit_axes') - with mesh_lib.set_abstract_mesh(new_mesh): + with mesh_lib.use_abstract_mesh(new_mesh): yield # -------------------- helpers -------------------- diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 025145ac9..25a002efa 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1389,8 +1389,8 @@ def use_mesh(mesh: mesh_lib.Mesh): # if not core.trace_state_clean(): # raise ValueError('`use_mesh` can only be used outside of `jax.jit`') - with (mesh_lib.set_abstract_mesh(mesh.abstract_mesh), - mesh_lib.set_concrete_mesh(mesh)): + with (mesh_lib.use_abstract_mesh(mesh.abstract_mesh), + mesh_lib.use_concrete_mesh(mesh)): yield def set_mesh(mesh: mesh_lib.Mesh) -> None: diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index f43062326..bb6790b72 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -47,7 +47,7 @@ from jax._src import source_info_util from jax._src import traceback_util from jax._src import util from jax._src.core import Tracer -from jax._src.mesh import (AbstractMesh, Mesh, AxisTypes, set_abstract_mesh, +from jax._src.mesh import (AbstractMesh, Mesh, AxisTypes, use_abstract_mesh, get_abstract_mesh) from jax._src.api import _shared_code_pmap, _prepare_pmap from jax._src.lax import (lax, parallel as lax_parallel, slicing, @@ -523,7 +523,7 @@ def _shard_map_staging( in_avals = [t.aval for t in in_tracers] in_avals_ = map(partial(_shard_aval, mesh, auto), in_names, in_avals) manual_mesh = _as_manual_mesh(mesh, auto) - with _extend_axis_env(mesh, auto), set_abstract_mesh(manual_mesh): + with _extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh): jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_) _check_names(out_names_thunk(), out_avals_) if check_rep: @@ -539,7 +539,7 @@ def _shard_map_staging( constvars = map(trace.getvar, map(trace.to_jaxpr_tracer, consts)) outvars = map(trace.makevar, out_tracers) in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore - with _extend_axis_env(mesh, auto), set_abstract_mesh(manual_mesh): + with _extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh): jaxpr = pe.convert_constvars_jaxpr(jaxpr) params = dict(mesh=mesh, in_names=in_names_staged, out_names=tuple(out_names_thunk()), jaxpr=jaxpr, @@ -857,7 +857,7 @@ def _run_shmap(f, mesh, auto, args, reps, check_rep, context_mesh): in_tracers = map(partial(ShardMapTracer, trace), reps, args) manual_mesh = _as_manual_mesh(mesh, auto) with (core.set_current_trace(trace), _extend_axis_env(mesh, auto), - set_abstract_mesh(manual_mesh)): + use_abstract_mesh(manual_mesh)): ans = f.call_wrapped(*in_tracers) outs, out_rep = unzip2(map(trace.to_val_rep_pair, ans)) return outs, out_rep @@ -869,7 +869,7 @@ def _names_to_pspec(names: AxisNames) -> PartitionSpec: def _unmatch_spec(mesh: Mesh, src: AxisNames, x: JaxType, context_mesh) -> JaxType: with (core.eval_context(), jax.disable_jit(False), - set_abstract_mesh(context_mesh)): + use_abstract_mesh(context_mesh)): return jax.jit(HashablePartial(_unmatch, mesh, tuple(src.items())))(x) def _unmatch(mesh, src_tup, x): @@ -948,7 +948,7 @@ class ShardMapTrace(core.Trace): else: f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh) with (core.eval_context(), jax.disable_jit(False), jax.debug_nans(False), - jax.debug_infs(False), set_abstract_mesh(self.context_mesh)): + jax.debug_infs(False), use_abstract_mesh(self.context_mesh)): out_vals = jax.jit(f)(*in_vals) _maybe_check_special(out_vals) rep_rule = _check_rules.get(prim, partial(_rule_missing, prim)) @@ -1018,13 +1018,13 @@ class ShardMapTracer(core.Tracer): def to_concrete_value(self): if self.rep == set(self._trace.mesh.axis_names): - with core.eval_context(), set_abstract_mesh(self._trace.context_mesh): + with core.eval_context(), use_abstract_mesh(self._trace.context_mesh): return core.to_concrete_value(self.val[0]) else: return None def __str__(self) -> str: - with core.eval_context(), set_abstract_mesh(self._trace.context_mesh): + with core.eval_context(), use_abstract_mesh(self._trace.context_mesh): blocks = list(self.val) mesh = self._trace.mesh axis_names = f"({', '.join(map(str, mesh.axis_names))},)" @@ -1801,7 +1801,7 @@ def _partial_eval_jaxpr_custom_rule( which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)] mesh = eqn.params['mesh'] with (_extend_axis_env(mesh, auto), - set_abstract_mesh(_as_manual_mesh(mesh, auto))): + use_abstract_mesh(_as_manual_mesh(mesh, auto))): jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which) jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged) jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names)