
Under pjit, the `with mesh:` context will now use the `use_mesh(mesh): jit` path (sketched in the example below) instead of tracking the mesh separately via `resource_env`.

This also makes it easier to deprecate the `with mesh: pjit` path from user code in the future, since the new path will be fully tested.
It will also let us remove `resource_env` from JAX, along with the accesses of `resource_env.physical_mesh` that are spread throughout internal and external codebases.
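
A minimal sketch of the two spellings involved (illustrative only, not code from this commit; it assumes a single-device mesh and uses the public `jax.sharding.use_mesh` API referenced in this diff):

    # Hypothetical example: the legacy `with mesh:` spelling and the
    # `use_mesh` spelling it now defers to internally under jit.
    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh

    mesh = Mesh(np.array(jax.devices()[:1]), ('x',))

    @jax.jit
    def f(x):
      return x * 2

    # Legacy: mesh supplied via the `with mesh:` context manager.
    with mesh:
      y_legacy = f(jnp.arange(8))

    # New path: the same mesh supplied via `jax.sharding.use_mesh`.
    with jax.sharding.use_mesh(mesh):
      y_new = f(jnp.arange(8))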

PiperOrigin-RevId: 735602187
Yash Katariya 2025-03-10 20:20:20 -07:00 committed by jax authors
parent 02505fa757
commit 76dec38286
7 changed files with 65 additions and 108 deletions

@ -901,7 +901,7 @@ error_checks[lax.while_p] = while_loop_error_check
def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
in_shardings, out_shardings,
in_layouts, out_layouts,
resource_env, donated_invars, name, inline, keep_unused,
donated_invars, ctx_mesh, name, inline, keep_unused,
compiler_options_kvs):
# jaxpr to checked_jaxpr
err_vals, err_tree = jtu.tree_flatten(error)
@ -928,8 +928,8 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
out_shardings=new_out_shardings,
in_layouts=new_in_layouts,
out_layouts=new_out_layouts,
resource_env=resource_env,
donated_invars=new_donated_invars,
ctx_mesh=ctx_mesh,
name=name,
inline=inline,
keep_unused=keep_unused,

@ -181,7 +181,8 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))(
*tiled_args
)
if closed_jaxpr.out_avals != tiled_results:
if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] !=
[(t.shape, t.dtype) for t in tiled_results]):
raise ValueError(
"Mismatch in result shapes. %s vs %s"
% (repr(closed_jaxpr.out_avals), repr(tiled_results))

@ -1663,7 +1663,7 @@ class MismatchType(enum.Enum):
elif self.name == 'OUT_SHARDING':
return 'explicit output sharding'
elif self.name == 'CONTEXT_DEVICES':
return 'devices'
return 'context mesh'
return f'{self.name}'
@ -3060,7 +3060,6 @@ class JitGlobalCppCacheKeys:
in_layouts_leaves: tuple[Any, ...] | None = None
out_layouts_treedef: PyTreeDef | None = None
out_layouts_leaves: tuple[Any, ...] | None = None
use_resource_env: bool = False
compiler_options_kvs: tuple[tuple[str, Any], ...] | None = None
@functools.cached_property

@ -357,7 +357,6 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
in_layouts_leaves=jit_info.in_layouts_leaves,
out_layouts_treedef=jit_info.out_layouts_treedef,
out_layouts_leaves=jit_info.out_layouts_leaves,
use_resource_env=jit_info.use_resource_env,
compiler_options_kvs=jit_info.compiler_options_kvs)
cpp_pjit_f = xc._xla.pjit(
fun_name(fun), fun, cache_miss, jit_info.static_argnums,
@ -544,8 +543,7 @@ class PjitParams(NamedTuple):
def _infer_params_impl(
fun: Callable,
ji: PjitInfo,
pjit_mesh: mesh_lib.Mesh | None,
resource_env: mesh_lib.ResourceEnv | None,
ctx_mesh: mesh_lib.Mesh | None,
dbg: core.DebugInfo,
args: tuple[Any, ...],
kwargs: dict[str, Any],
@ -557,8 +555,8 @@ def _infer_params_impl(
raise ValueError(
"pjit does not support kwargs when in_shardings is specified.")
if pjit_mesh is not None:
if (ji.backend or ji.device) and not pjit_mesh.empty:
if ctx_mesh is not None:
if (ji.backend or ji.device) and not ctx_mesh.empty:
raise ValueError(
"Mesh context manager should not be used with jit when backend or "
"device is also specified as an argument to jit.")
@ -590,11 +588,11 @@ def _infer_params_impl(
in_shardings_treedef = out_shardings_treedef = treedef
else:
in_shardings_leaves = tuple(
_create_sharding_for_array(pjit_mesh, x, 'in_shardings', 'jit')
_create_sharding_for_array(ctx_mesh, x, 'in_shardings', 'jit')
for x in ji.in_shardings_leaves)
in_shardings_treedef = ji.in_shardings_treedef
out_shardings_leaves = tuple(
_create_sharding_for_array(pjit_mesh, x, 'out_shardings', 'jit')
_create_sharding_for_array(ctx_mesh, x, 'out_shardings', 'jit')
for x in ji.out_shardings_leaves)
out_shardings_treedef = ji.out_shardings_treedef
@ -652,8 +650,8 @@ def _infer_params_impl(
out_shardings=out_shardings_flat,
in_layouts=in_layouts_flat,
out_layouts=out_layouts_flat,
resource_env=resource_env,
donated_invars=donated_invars,
ctx_mesh=ctx_mesh,
name=fun_qual_name(flat_fun),
keep_unused=ji.keep_unused,
inline=ji.inline,
@ -683,38 +681,30 @@ def _infer_params_cached(
jit_info: PjitInfo,
signature: jax_jit.ArgumentSignature,
in_avals: tuple[core.AbstractValue, ...],
pjit_mesh: mesh_lib.Mesh | None,
resource_env: mesh_lib.ResourceEnv | None,
ctx_mesh: mesh_lib.Mesh | None,
) -> InferParamsCacheEntry:
return InferParamsCacheEntry()
def disallow_use_mesh_and_legacy_mesh_ctx_mgr_together():
if (not mesh_lib.thread_resources.env.physical_mesh.empty and
mesh_lib.get_concrete_mesh() is not None):
raise ValueError(
'Using `with mesh:` context manager and `jax.sharding.use_mesh`'
' together is not allowed.')
def _infer_params(
fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[PjitParams, list[Any]]:
disallow_use_mesh_and_legacy_mesh_ctx_mgr_together()
) -> tuple[PjitParams, list[Any]]:
if ji.use_resource_env:
# We need to fetch the mesh from inside the wrapped function, because
# meshes are dynamically scoped (i.e., with a context manager).
resource_env = mesh_lib.thread_resources.env
pjit_mesh = resource_env.physical_mesh
else:
resource_env = None
pjit_mesh = None
with mesh_lib.use_mesh(mesh_lib.thread_resources.env.physical_mesh):
return _infer_params_internal(fun, ji, args, kwargs)
return _infer_params_internal(fun, ji, args, kwargs)
def _infer_params_internal(
fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[PjitParams, list[Any]]:
ctx_mesh = mesh_lib.get_concrete_mesh()
dbg = debug_info(
'jit', fun, args, kwargs, static_argnums=ji.static_argnums,
static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo,
signature=ji.fun_signature)
if config.dynamic_shapes.value: # if dynamic shapes, don't use the cache
p, args_flat = _infer_params_impl(fun, ji, pjit_mesh, resource_env, dbg,
p, args_flat = _infer_params_impl(fun, ji, ctx_mesh, dbg,
args, kwargs, in_avals=None)
return p, p.consts + args_flat
@ -722,10 +712,11 @@ def _infer_params(
args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
ji.static_argnames, tree_util.default_registry)
avals = _infer_input_type(fun, dbg, dynargs)
entry = _infer_params_cached(fun, ji, signature, avals, pjit_mesh, resource_env)
entry = _infer_params_cached(fun, ji, signature, avals, ctx_mesh)
if entry.pjit_params is None:
p, args_flat = _infer_params_impl(
fun, ji, pjit_mesh, resource_env, dbg, args, kwargs, in_avals=avals)
fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals)
if p.attrs_tracked: # if attrs, don't populate the cache
return p, p.consts + args_flat
entry.pjit_params = p
@ -1616,7 +1607,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
def _resolve_and_lower(
args, jaxpr, in_shardings, out_shardings, in_layouts,
out_layouts, resource_env, donated_invars, name, keep_unused, inline,
out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline,
lowering_platforms, lowering_parameters, pgle_profiler,
compiler_options_kvs):
in_shardings = _resolve_in_shardings(args, in_shardings)
@ -1624,8 +1615,8 @@ def _resolve_and_lower(
jaxpr.in_avals)
out_layouts = _resolve_out_layouts(out_layouts, out_shardings, jaxpr.out_avals)
return _pjit_lower(
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env,
donated_invars, name, keep_unused, inline, compiler_options_kvs,
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs,
lowering_platforms=lowering_platforms,
lowering_parameters=lowering_parameters,
pgle_profiler=pgle_profiler)
@ -1634,7 +1625,7 @@ _pgle_profiler_dict = weakref.WeakKeyDictionary() # type: ignore
def _pjit_call_impl_python(
*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
donated_invars, ctx_mesh, name, keep_unused, inline,
compiler_options_kvs):
pgle_compile_options, pgle_profiler = {}, None
if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
@ -1659,8 +1650,8 @@ def _pjit_call_impl_python(
compiled = _resolve_and_lower(
args, jaxpr=jaxpr, in_shardings=in_shardings,
out_shardings=out_shardings, in_layouts=in_layouts,
out_layouts=out_layouts, resource_env=resource_env,
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
out_layouts=out_layouts, donated_invars=donated_invars,
ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused,
inline=inline, lowering_platforms=None,
lowering_parameters=mlir.LoweringParameters(),
pgle_profiler=pgle_profiler,
@ -1691,7 +1682,7 @@ def _pjit_call_impl_python(
@weakref_lru_cache
def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts,
out_layouts, resource_env, donated_invars, name,
out_layouts, donated_invars, ctx_mesh, name,
keep_unused, inline, compiler_options_kvs):
# The input jaxpr to `_get_jaxpr_as_fun` is under a weakref_lru_cache so
# returning `core.jaxpr_as_fun(jaxpr)` directly creates a strong reference to
@ -1705,14 +1696,14 @@ def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts,
def _pjit_call_impl(*args, jaxpr,
in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
donated_invars, ctx_mesh, name, keep_unused, inline,
compiler_options_kvs):
def call_impl_cache_miss(*args_, **kwargs_):
out_flat, compiled, pgle_profiler = _pjit_call_impl_python(
*args, jaxpr=jaxpr, in_shardings=in_shardings,
out_shardings=out_shardings, in_layouts=in_layouts,
out_layouts=out_layouts, resource_env=resource_env,
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
out_layouts=out_layouts, donated_invars=donated_invars,
ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused,
inline=inline, compiler_options_kvs=compiler_options_kvs)
fastpath_data = _get_fastpath_data(
compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects,
@ -1721,7 +1712,7 @@ def _pjit_call_impl(*args, jaxpr,
f = _get_jaxpr_as_fun(
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
donated_invars, ctx_mesh, name, keep_unused, inline,
compiler_options_kvs)
donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
cache_key = pxla.JitGlobalCppCacheKeys(
@ -1730,8 +1721,7 @@ def _pjit_call_impl(*args, jaxpr,
in_shardings_treedef=None, in_shardings_leaves=in_shardings,
out_shardings_treedef=None, out_shardings_leaves=out_shardings,
in_layouts_treedef=None, in_layouts_leaves=in_layouts,
out_layouts_treedef=None, out_layouts_leaves=out_layouts,
use_resource_env=resource_env is not None)
out_layouts_treedef=None, out_layouts_leaves=out_layouts)
return xc._xla.pjit(
name, f, call_impl_cache_miss, [], [], cache_key,
tree_util.dispatch_registry, pxla.cc_shard_arg,
@ -1746,8 +1736,8 @@ def _pjit_lower(
out_shardings,
in_layouts: pxla.MaybeLayout,
out_layouts: pxla.MaybeLayout,
resource_env,
donated_invars,
ctx_mesh,
name: str,
keep_unused: bool,
inline: bool,
@ -1757,12 +1747,10 @@ def _pjit_lower(
lowering_parameters: mlir.LoweringParameters,
pgle_profiler: profiler.PGLEProfiler | None):
util.test_event("pjit_lower")
mesh = (resource_env.physical_mesh if resource_env is not None else
mesh_lib.get_concrete_mesh())
return pxla.lower_sharding_computation(
jaxpr, 'jit', name, in_shardings, out_shardings,
in_layouts, out_layouts, tuple(donated_invars),
keep_unused=keep_unused, context_mesh=mesh,
keep_unused=keep_unused, context_mesh=ctx_mesh,
compiler_options_kvs=compiler_options_kvs,
lowering_platforms=lowering_platforms,
lowering_parameters=lowering_parameters,
@ -1914,8 +1902,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext,
def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str,
jaxpr: core.ClosedJaxpr, in_shardings,
out_shardings, in_layouts, out_layouts, resource_env,
donated_invars, keep_unused, inline, compiler_options_kvs):
out_shardings, in_layouts, out_layouts, donated_invars,
ctx_mesh, keep_unused, inline, compiler_options_kvs):
effects = list(ctx.tokens_in.effects())
output_types = map(mlir.aval_to_ir_type, ctx.avals_out)
output_types = [mlir.token_type()] * len(effects) + output_types
@ -1945,23 +1933,20 @@ def _pjit_batcher(axis_data, vals_in,
dims_in: tuple[int, ...],
jaxpr: core.ClosedJaxpr,
in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
donated_invars, ctx_mesh, name, keep_unused, inline,
compiler_options_kvs):
segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in)
new_jaxpr, axes_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in)
if resource_env is not None:
mesh = resource_env.physical_mesh
else:
mesh = None
# TODO(axch): prepend with Nones (?) to account for new segment_lens inputs
in_shardings = tuple(
_pjit_batcher_for_sharding(i, axis_in, axis_data.spmd_name, mesh, aval.ndim)
_pjit_batcher_for_sharding(i, axis_in, axis_data.spmd_name, ctx_mesh,
aval.ndim)
if axis_in is not None else i
for axis_in, i, aval in zip(dims_in, in_shardings, new_jaxpr.in_avals))
out_shardings = tuple(
_pjit_batcher_for_sharding(o, axis_out, axis_data.spmd_name, mesh, aval.ndim)
_pjit_batcher_for_sharding(o, axis_out, axis_data.spmd_name, ctx_mesh,
aval.ndim)
if axis_out is not None else o
for axis_out, o, aval in zip(axes_out, out_shardings, new_jaxpr.out_avals))
# TODO(yashkatariya): Figure out layouts should change under vmap.
@ -1977,8 +1962,8 @@ def _pjit_batcher(axis_data, vals_in,
out_shardings=out_shardings,
in_layouts=in_layouts,
out_layouts=out_layouts,
resource_env=resource_env,
donated_invars=donated_invars,
ctx_mesh=ctx_mesh,
name=name,
keep_unused=keep_unused,
inline=inline,
@ -2000,8 +1985,8 @@ def _insert_axis_partitions(spec, dim, val):
def _pjit_batcher_for_sharding(
s: Sharding | UnspecifiedValue,
dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None, mesh,
ndim: int):
dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None,
mesh, ndim: int):
if isinstance(s, UnspecifiedValue):
return s
hlo_s = s._to_xla_hlo_sharding(ndim)
@ -2040,7 +2025,7 @@ def _pjit_batcher_for_sharding(
def _pjit_jvp(primals_in, tangents_in,
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
donated_invars, ctx_mesh, name, keep_unused, inline,
compiler_options_kvs):
is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in]
jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
@ -2057,8 +2042,8 @@ def _pjit_jvp(primals_in, tangents_in,
out_shardings=(*out_shardings, *_filter_zeros_out(out_shardings)),
in_layouts=(*in_layouts, *_filter_zeros_in(in_layouts)),
out_layouts=(*out_layouts, *_filter_zeros_out(out_layouts)),
resource_env=resource_env,
donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)),
ctx_mesh=ctx_mesh,
name=name,
keep_unused=keep_unused,
inline=inline,
@ -2074,7 +2059,7 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp
def _pjit_linearization(nzs, *primals_in, jaxpr,
in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
donated_invars, ctx_mesh, name, keep_unused, inline,
compiler_options_kvs):
primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs)
# constvars will become residuals. Move them to the end of the ordinary args.
@ -2090,8 +2075,8 @@ def _pjit_linearization(nzs, *primals_in, jaxpr,
out_shardings=_filter_zeros(nzs_out, out_shardings),
in_layouts=_filter_zeros(nzs, in_layouts) + res_layouts,
out_layouts=_filter_zeros(nzs_out, out_layouts),
resource_env=resource_env,
donated_invars=_filter_zeros(nzs, donated_invars) + res_donated,
ctx_mesh=ctx_mesh,
name=name,
keep_unused=keep_unused,
inline=inline,
@ -2110,8 +2095,8 @@ def _pjit_linearization(nzs, *primals_in, jaxpr,
out_shardings=(*res_shardings, *out_shardings),
in_layouts=in_layouts,
out_layouts=(*res_layouts, *out_layouts),
resource_env=resource_env,
donated_invars=donated_invars,
ctx_mesh=ctx_mesh,
name=name,
keep_unused=keep_unused,
inline=inline,
@ -2126,7 +2111,7 @@ ad.primitive_linearizations[pjit_p] = _pjit_linearization
def _pjit_partial_eval(trace: pe.JaxprTrace,
*in_tracers,
jaxpr: core.ClosedJaxpr, in_shardings, out_shardings,
in_layouts, out_layouts, resource_env, donated_invars,
in_layouts, out_layouts, donated_invars, ctx_mesh,
name, keep_unused, inline, compiler_options_kvs):
in_pvals = [t.pval for t in in_tracers]
@ -2193,8 +2178,9 @@ def _pjit_partial_eval(trace: pe.JaxprTrace,
jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins),
out_shardings=known_out_shardings,
in_layouts=keep_where(in_layouts, known_ins),
out_layouts=known_out_layouts, resource_env=resource_env,
out_layouts=known_out_layouts,
donated_invars=keep_where(donated_invars, known_ins),
ctx_mesh=ctx_mesh,
name=name, keep_unused=keep_unused, inline=inline,
compiler_options_kvs=compiler_options_kvs)
assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals)
@ -2225,9 +2211,9 @@ def _pjit_partial_eval(trace: pe.JaxprTrace,
out_shardings=keep_where(out_shardings, unknown_outs),
in_layouts=(keep_where(in_layouts, unknown_ins) + res_layouts),
out_layouts=keep_where(out_layouts, unknown_outs),
resource_env=resource_env,
donated_invars=(keep_where(donated_invars, unknown_ins) +
(False,) * num_residuals),
ctx_mesh=ctx_mesh,
name=name,
keep_unused=keep_unused,
inline=inline,
@ -2313,7 +2299,7 @@ def _pjit_transpose_trace(fun: lu.WrappedFun,
def _pjit_transpose(cts_in, *primals_in,
jaxpr: core.ClosedJaxpr,
in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
donated_invars, ctx_mesh, name, keep_unused, inline,
compiler_options_kvs):
def prune_type(ty, xs, maybe_zeros):
return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty)
@ -2362,8 +2348,8 @@ def _pjit_transpose(cts_in, *primals_in,
out_shardings=transpose_out_shardings,
in_layouts=transpose_in_layouts,
out_layouts=transpose_out_layouts,
resource_env=resource_env,
donated_invars=(False,) * len(primals_and_nz_cts_in),
ctx_mesh=ctx_mesh,
name=name,
keep_unused=keep_unused,
inline=inline,
@ -2447,9 +2433,8 @@ def _pjit_pp_rule(eqn: core.JaxprEqn,
del params['out_layouts']
if not params['keep_unused']:
del params['keep_unused']
if (params['resource_env'] is None or
params['resource_env'].physical_mesh.empty):
del params['resource_env']
if params['ctx_mesh'] is None or params['ctx_mesh'].empty:
del params['ctx_mesh']
if not params['compiler_options_kvs']:
del params['compiler_options_kvs']
@ -2549,8 +2534,6 @@ def with_sharding_constraint(x, shardings):
flatten_axes("with_sharding_constraint layouts", tree, layouts))
del layouts
disallow_use_mesh_and_legacy_mesh_ctx_mgr_together()
context_mesh = (
mesh_lib.get_abstract_mesh() if mesh_lib.get_concrete_mesh() is not None
else mesh_lib.thread_resources.env.physical_mesh)

@ -3573,8 +3573,8 @@ def _pjit(*args: TfVal,
in_shardings: Sequence[sharding.Sharding],
out_shardings: Sequence[sharding.Sharding],
in_layouts, out_layouts,
resource_env: mesh.ResourceEnv,
donated_invars,
ctx_mesh,
name: str,
keep_unused: bool,
inline: bool,

@ -775,7 +775,7 @@ sparse_rules_bcoo[lax.while_p] = _while_sparse
def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
in_layouts, out_layouts, resource_env, donated_invars, name,
in_layouts, out_layouts, donated_invars, ctx_mesh, name,
keep_unused, inline, compiler_options_kvs):
if any(donated_invars):
raise NotImplementedError("sparse xla_call with donated_invars")
@ -808,8 +808,8 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
out_shardings=out_shardings,
in_layouts=in_layouts,
out_layouts=out_layouts,
resource_env=resource_env,
donated_invars=donated_invars,
ctx_mesh=ctx_mesh,
name=name,
keep_unused=keep_unused,
inline=inline,

@ -1205,8 +1205,7 @@ class PJitTest(jtu.BufferDonationTestCase):
with self.assertRaisesRegex(
ValueError,
r"One of with_sharding_constraint.*Sharding "
r"NamedSharding\(mesh=Mesh\('replica': 1, 'data': 1, 'mdl': 2\), "
r"spec=PartitionSpec\(None, 'mdl', None, None\).*\) is only "
r"NamedSharding.*PartitionSpec\(None, 'mdl', None, None\).*\) is only "
"valid for values of rank at least 4, but was applied to a value of rank 1"):
pjit_f(jnp.array([1, 2, 3]))
@ -6873,31 +6872,6 @@ class ShardingInTypesTest(jtu.JaxTestCase):
' axis_types are `Auto`'):
NamedSharding(mesh, P(P.UNCONSTRAINED))
def test_use_mesh_legacy_mesh_ctx_mgr_mix_error(self):
mesh = jtu.create_mesh((1, 1), ('x', 'y'))
with self.assertRaisesRegex(
ValueError,
'Using `with mesh:` context manager and `jax.sharding.use_mesh`'
' together is not allowed'):
with jax.sharding.use_mesh(mesh), mesh:
jax.jit(lambda x: x)(jnp.arange(8))
with self.assertRaisesRegex(
ValueError,
'Using `with mesh:` context manager and `jax.sharding.use_mesh`'
' together is not allowed'):
with jax.sharding.use_mesh(mesh), mesh:
jnp.zeros((8, 2), dtype=jnp.int32)
x = jnp.arange(8)
with self.assertRaisesRegex(
ValueError,
'Using `with mesh:` context manager and `jax.sharding.use_mesh`'
' together is not allowed'):
with jax.sharding.use_mesh(mesh), mesh:
jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
def test_pspec_einsum_no_context_mesh(self):
mesh = jtu.create_mesh((1, 1), ('x', 'y'),
axis_types={AxisTypes.Explicit: ('x', 'y')})