mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Rename some internal APIs (set_abstract_mesh -> use_abstract_mesh and set_concrete_mesh -> use_concrete_mesh)
PiperOrigin-RevId: 736382641
This commit is contained in:
parent
a4ca0dbc6c
commit
2d01226b3b
@ -349,7 +349,7 @@ devices in the mesh are ignored for tracing and lowering:
|
||||
>>> from jax.sharding import PartitionSpec as P
|
||||
>>>
|
||||
>>> # Use an AbstractMesh for exporting
|
||||
>>> export_mesh = AbstractMesh((("a", 4),))
|
||||
>>> export_mesh = AbstractMesh((4,), ("a",))
|
||||
|
||||
>>> def f(x):
|
||||
... return x.T
|
||||
|
@ -33,7 +33,7 @@ from jax._src import errors
|
||||
from jax._src import profiler
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.mesh import set_concrete_mesh
|
||||
from jax._src.mesh import use_concrete_mesh
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.interpreters import xla
|
||||
@ -1112,7 +1112,7 @@ def shard_device_array(x, devices, indices, sharding):
|
||||
else:
|
||||
# TODO(yashkatariya): Maybe this should be set when we call the handler in
|
||||
# InputsHandler.__call__?
|
||||
with set_concrete_mesh(None):
|
||||
with use_concrete_mesh(None):
|
||||
shards = x._multi_slice(start_indices, limit_indices, removed_dims)
|
||||
aval = core.shaped_abstractify(x)
|
||||
return pxla.batched_device_put(aval, sharding, shards, devices)
|
||||
|
@ -972,7 +972,7 @@ def shard_map_error_check(
|
||||
in_avals[i] = sharder(mesh, auto, new_in_names[i], v)
|
||||
|
||||
with (shard_map._extend_axis_env(mesh, auto),
|
||||
mesh_lib.set_abstract_mesh(shard_map._as_manual_mesh(mesh, auto))):
|
||||
mesh_lib.use_abstract_mesh(shard_map._as_manual_mesh(mesh, auto))):
|
||||
# jaxpr to checked_jaxpr
|
||||
checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(
|
||||
pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals
|
||||
|
@ -1008,7 +1008,7 @@ def instantiate_zeros(tangent):
|
||||
if hasattr(tangent.aval, 'sharding'):
|
||||
# TODO(dougalm, yashkatariya): Delete this context manager once we figure
|
||||
# out how to ensure jaxpr arguments always have the context mesh.
|
||||
with mesh_lib.set_abstract_mesh(tangent.aval.sharding.mesh): # type: ignore
|
||||
with mesh_lib.use_abstract_mesh(tangent.aval.sharding.mesh): # type: ignore
|
||||
return zeros_like_aval(tangent.aval)
|
||||
return zeros_like_aval(tangent.aval)
|
||||
return tangent
|
||||
|
@ -1106,7 +1106,7 @@ def broadcast(x, sz, axis, mesh_axis=None):
|
||||
sharding = x_aval.sharding.with_spec(new_spec)
|
||||
# TODO(dougalm, yashkatariya): Delete this context manager once we figure
|
||||
# out how to ensure jaxpr arguments always have the context mesh.
|
||||
with mesh_lib.set_abstract_mesh(sharding.mesh):
|
||||
with mesh_lib.use_abstract_mesh(sharding.mesh):
|
||||
return jax.lax.broadcast_in_dim(x, shape, broadcast_dims,
|
||||
out_sharding=sharding)
|
||||
|
||||
|
@ -535,7 +535,9 @@ class AbstractMesh(_BaseMesh):
|
||||
def _raise_value_error(name):
|
||||
raise ValueError(f"AbstractMesh does not implement {name}")
|
||||
|
||||
class SetAbstractMeshContextManager:
|
||||
empty_abstract_mesh = AbstractMesh((), ())
|
||||
|
||||
class UseAbstractMeshContextManager:
|
||||
__slots__ = ['mesh', 'prev']
|
||||
|
||||
def __init__(self, mesh: AbstractMesh):
|
||||
@ -547,18 +549,14 @@ class SetAbstractMeshContextManager:
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
jax_config.abstract_mesh_context_manager.set_local(self.prev)
|
||||
|
||||
set_abstract_mesh = SetAbstractMeshContextManager
|
||||
|
||||
|
||||
empty_abstract_mesh = AbstractMesh((), ())
|
||||
use_abstract_mesh = UseAbstractMeshContextManager
|
||||
|
||||
def get_abstract_mesh():
|
||||
val = jax_config.abstract_mesh_context_manager.value
|
||||
return empty_abstract_mesh if val is None else val
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_concrete_mesh(mesh: Mesh | None):
|
||||
def use_concrete_mesh(mesh: Mesh | None):
|
||||
prev_val = jax_config.device_context.swap_local(mesh)
|
||||
try:
|
||||
yield
|
||||
|
@ -2850,7 +2850,7 @@ def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None,
|
||||
def decorator(*args, **kwargs):
|
||||
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto, 'auto_axes',
|
||||
error_on_manual_to_auto_explict=True)
|
||||
with mesh_lib.set_abstract_mesh(new_mesh):
|
||||
with mesh_lib.use_abstract_mesh(new_mesh):
|
||||
in_specs = tree_map(lambda a: core.modify_spec_for_auto_manual(
|
||||
core.get_aval(a).sharding.spec, new_mesh), args)
|
||||
args = mesh_cast(args, in_specs)
|
||||
@ -2861,7 +2861,7 @@ def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None,
|
||||
@contextlib.contextmanager
|
||||
def use_auto_axes(*axes):
|
||||
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto, 'use_auto_axes')
|
||||
with mesh_lib.set_abstract_mesh(new_mesh):
|
||||
with mesh_lib.use_abstract_mesh(new_mesh):
|
||||
yield
|
||||
|
||||
|
||||
@ -2870,7 +2870,7 @@ def explicit_axes(fun, *, axes: str | tuple[str, ...] | None = None,
|
||||
def decorator(*args, **kwargs):
|
||||
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Explicit, 'explicit_axes',
|
||||
error_on_manual_to_auto_explict=True)
|
||||
with mesh_lib.set_abstract_mesh(new_mesh):
|
||||
with mesh_lib.use_abstract_mesh(new_mesh):
|
||||
args = mesh_cast(args, in_shardings)
|
||||
out = fun(*args, **kwargs)
|
||||
out_specs = tree_map(lambda o: core.modify_spec_for_auto_manual(
|
||||
@ -2882,7 +2882,7 @@ def explicit_axes(fun, *, axes: str | tuple[str, ...] | None = None,
|
||||
def use_explicit_axes(*axes):
|
||||
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Explicit,
|
||||
'use_explicit_axes')
|
||||
with mesh_lib.set_abstract_mesh(new_mesh):
|
||||
with mesh_lib.use_abstract_mesh(new_mesh):
|
||||
yield
|
||||
|
||||
# -------------------- helpers --------------------
|
||||
|
@ -1389,8 +1389,8 @@ def use_mesh(mesh: mesh_lib.Mesh):
|
||||
# if not core.trace_state_clean():
|
||||
# raise ValueError('`use_mesh` can only be used outside of `jax.jit`')
|
||||
|
||||
with (mesh_lib.set_abstract_mesh(mesh.abstract_mesh),
|
||||
mesh_lib.set_concrete_mesh(mesh)):
|
||||
with (mesh_lib.use_abstract_mesh(mesh.abstract_mesh),
|
||||
mesh_lib.use_concrete_mesh(mesh)):
|
||||
yield
|
||||
|
||||
def set_mesh(mesh: mesh_lib.Mesh) -> None:
|
||||
|
@ -47,7 +47,7 @@ from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src.core import Tracer
|
||||
from jax._src.mesh import (AbstractMesh, Mesh, AxisTypes, set_abstract_mesh,
|
||||
from jax._src.mesh import (AbstractMesh, Mesh, AxisTypes, use_abstract_mesh,
|
||||
get_abstract_mesh)
|
||||
from jax._src.api import _shared_code_pmap, _prepare_pmap
|
||||
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
|
||||
@ -523,7 +523,7 @@ def _shard_map_staging(
|
||||
in_avals = [t.aval for t in in_tracers]
|
||||
in_avals_ = map(partial(_shard_aval, mesh, auto), in_names, in_avals)
|
||||
manual_mesh = _as_manual_mesh(mesh, auto)
|
||||
with _extend_axis_env(mesh, auto), set_abstract_mesh(manual_mesh):
|
||||
with _extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh):
|
||||
jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_)
|
||||
_check_names(out_names_thunk(), out_avals_)
|
||||
if check_rep:
|
||||
@ -539,7 +539,7 @@ def _shard_map_staging(
|
||||
constvars = map(trace.getvar, map(trace.to_jaxpr_tracer, consts))
|
||||
outvars = map(trace.makevar, out_tracers)
|
||||
in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore
|
||||
with _extend_axis_env(mesh, auto), set_abstract_mesh(manual_mesh):
|
||||
with _extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh):
|
||||
jaxpr = pe.convert_constvars_jaxpr(jaxpr)
|
||||
params = dict(mesh=mesh, in_names=in_names_staged,
|
||||
out_names=tuple(out_names_thunk()), jaxpr=jaxpr,
|
||||
@ -857,7 +857,7 @@ def _run_shmap(f, mesh, auto, args, reps, check_rep, context_mesh):
|
||||
in_tracers = map(partial(ShardMapTracer, trace), reps, args)
|
||||
manual_mesh = _as_manual_mesh(mesh, auto)
|
||||
with (core.set_current_trace(trace), _extend_axis_env(mesh, auto),
|
||||
set_abstract_mesh(manual_mesh)):
|
||||
use_abstract_mesh(manual_mesh)):
|
||||
ans = f.call_wrapped(*in_tracers)
|
||||
outs, out_rep = unzip2(map(trace.to_val_rep_pair, ans))
|
||||
return outs, out_rep
|
||||
@ -869,7 +869,7 @@ def _names_to_pspec(names: AxisNames) -> PartitionSpec:
|
||||
|
||||
def _unmatch_spec(mesh: Mesh, src: AxisNames, x: JaxType, context_mesh) -> JaxType:
|
||||
with (core.eval_context(), jax.disable_jit(False),
|
||||
set_abstract_mesh(context_mesh)):
|
||||
use_abstract_mesh(context_mesh)):
|
||||
return jax.jit(HashablePartial(_unmatch, mesh, tuple(src.items())))(x)
|
||||
|
||||
def _unmatch(mesh, src_tup, x):
|
||||
@ -948,7 +948,7 @@ class ShardMapTrace(core.Trace):
|
||||
else:
|
||||
f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh)
|
||||
with (core.eval_context(), jax.disable_jit(False), jax.debug_nans(False),
|
||||
jax.debug_infs(False), set_abstract_mesh(self.context_mesh)):
|
||||
jax.debug_infs(False), use_abstract_mesh(self.context_mesh)):
|
||||
out_vals = jax.jit(f)(*in_vals)
|
||||
_maybe_check_special(out_vals)
|
||||
rep_rule = _check_rules.get(prim, partial(_rule_missing, prim))
|
||||
@ -1018,13 +1018,13 @@ class ShardMapTracer(core.Tracer):
|
||||
|
||||
def to_concrete_value(self):
|
||||
if self.rep == set(self._trace.mesh.axis_names):
|
||||
with core.eval_context(), set_abstract_mesh(self._trace.context_mesh):
|
||||
with core.eval_context(), use_abstract_mesh(self._trace.context_mesh):
|
||||
return core.to_concrete_value(self.val[0])
|
||||
else:
|
||||
return None
|
||||
|
||||
def __str__(self) -> str:
|
||||
with core.eval_context(), set_abstract_mesh(self._trace.context_mesh):
|
||||
with core.eval_context(), use_abstract_mesh(self._trace.context_mesh):
|
||||
blocks = list(self.val)
|
||||
mesh = self._trace.mesh
|
||||
axis_names = f"({', '.join(map(str, mesh.axis_names))},)"
|
||||
@ -1801,7 +1801,7 @@ def _partial_eval_jaxpr_custom_rule(
|
||||
which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
|
||||
mesh = eqn.params['mesh']
|
||||
with (_extend_axis_env(mesh, auto),
|
||||
set_abstract_mesh(_as_manual_mesh(mesh, auto))):
|
||||
use_abstract_mesh(_as_manual_mesh(mesh, auto))):
|
||||
jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
|
||||
jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
|
||||
jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names)
|
||||
|
Loading…
x
Reference in New Issue
Block a user