mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 05:46:06 +00:00
Under pjit, the `with mesh:` context will use `use_mesh(mesh): jit` instead of tracking the mesh separately using `resource_env`.
This would also make it easier to deprecate the `with mesh: pjit` path in the future from user code since the new path would be completely tested. This will also allow us to remove `resource_env` from JAX and the internal API access of `resource_env.physical_mesh` spread throughout codebases internally and externally.

PiperOrigin-RevId: 735602187
This commit is contained in:
parent
02505fa757
commit
76dec38286
jax
_src
experimental
tests
@ -901,7 +901,7 @@ error_checks[lax.while_p] = while_loop_error_check
|
||||
def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
|
||||
in_shardings, out_shardings,
|
||||
in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, inline, keep_unused,
|
||||
donated_invars, ctx_mesh, name, inline, keep_unused,
|
||||
compiler_options_kvs):
|
||||
# jaxpr to checked_jaxpr
|
||||
err_vals, err_tree = jtu.tree_flatten(error)
|
||||
@ -928,8 +928,8 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
|
||||
out_shardings=new_out_shardings,
|
||||
in_layouts=new_in_layouts,
|
||||
out_layouts=new_out_layouts,
|
||||
resource_env=resource_env,
|
||||
donated_invars=new_donated_invars,
|
||||
ctx_mesh=ctx_mesh,
|
||||
name=name,
|
||||
inline=inline,
|
||||
keep_unused=keep_unused,
|
||||
|
@ -181,7 +181,8 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
|
||||
closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))(
|
||||
*tiled_args
|
||||
)
|
||||
if closed_jaxpr.out_avals != tiled_results:
|
||||
if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] !=
|
||||
[(t.shape, t.dtype) for t in tiled_results]):
|
||||
raise ValueError(
|
||||
"Mismatch in result shapes. %s vs %s"
|
||||
% (repr(closed_jaxpr.out_avals), repr(tiled_results))
|
||||
|
@ -1663,7 +1663,7 @@ class MismatchType(enum.Enum):
|
||||
elif self.name == 'OUT_SHARDING':
|
||||
return 'explicit output sharding'
|
||||
elif self.name == 'CONTEXT_DEVICES':
|
||||
return 'devices'
|
||||
return 'context mesh'
|
||||
return f'{self.name}'
|
||||
|
||||
|
||||
@ -3060,7 +3060,6 @@ class JitGlobalCppCacheKeys:
|
||||
in_layouts_leaves: tuple[Any, ...] | None = None
|
||||
out_layouts_treedef: PyTreeDef | None = None
|
||||
out_layouts_leaves: tuple[Any, ...] | None = None
|
||||
use_resource_env: bool = False
|
||||
compiler_options_kvs: tuple[tuple[str, Any], ...] | None = None
|
||||
|
||||
@functools.cached_property
|
||||
|
129
jax/_src/pjit.py
129
jax/_src/pjit.py
@ -357,7 +357,6 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
|
||||
in_layouts_leaves=jit_info.in_layouts_leaves,
|
||||
out_layouts_treedef=jit_info.out_layouts_treedef,
|
||||
out_layouts_leaves=jit_info.out_layouts_leaves,
|
||||
use_resource_env=jit_info.use_resource_env,
|
||||
compiler_options_kvs=jit_info.compiler_options_kvs)
|
||||
cpp_pjit_f = xc._xla.pjit(
|
||||
fun_name(fun), fun, cache_miss, jit_info.static_argnums,
|
||||
@ -544,8 +543,7 @@ class PjitParams(NamedTuple):
|
||||
def _infer_params_impl(
|
||||
fun: Callable,
|
||||
ji: PjitInfo,
|
||||
pjit_mesh: mesh_lib.Mesh | None,
|
||||
resource_env: mesh_lib.ResourceEnv | None,
|
||||
ctx_mesh: mesh_lib.Mesh | None,
|
||||
dbg: core.DebugInfo,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
@ -557,8 +555,8 @@ def _infer_params_impl(
|
||||
raise ValueError(
|
||||
"pjit does not support kwargs when in_shardings is specified.")
|
||||
|
||||
if pjit_mesh is not None:
|
||||
if (ji.backend or ji.device) and not pjit_mesh.empty:
|
||||
if ctx_mesh is not None:
|
||||
if (ji.backend or ji.device) and not ctx_mesh.empty:
|
||||
raise ValueError(
|
||||
"Mesh context manager should not be used with jit when backend or "
|
||||
"device is also specified as an argument to jit.")
|
||||
@ -590,11 +588,11 @@ def _infer_params_impl(
|
||||
in_shardings_treedef = out_shardings_treedef = treedef
|
||||
else:
|
||||
in_shardings_leaves = tuple(
|
||||
_create_sharding_for_array(pjit_mesh, x, 'in_shardings', 'jit')
|
||||
_create_sharding_for_array(ctx_mesh, x, 'in_shardings', 'jit')
|
||||
for x in ji.in_shardings_leaves)
|
||||
in_shardings_treedef = ji.in_shardings_treedef
|
||||
out_shardings_leaves = tuple(
|
||||
_create_sharding_for_array(pjit_mesh, x, 'out_shardings', 'jit')
|
||||
_create_sharding_for_array(ctx_mesh, x, 'out_shardings', 'jit')
|
||||
for x in ji.out_shardings_leaves)
|
||||
out_shardings_treedef = ji.out_shardings_treedef
|
||||
|
||||
@ -652,8 +650,8 @@ def _infer_params_impl(
|
||||
out_shardings=out_shardings_flat,
|
||||
in_layouts=in_layouts_flat,
|
||||
out_layouts=out_layouts_flat,
|
||||
resource_env=resource_env,
|
||||
donated_invars=donated_invars,
|
||||
ctx_mesh=ctx_mesh,
|
||||
name=fun_qual_name(flat_fun),
|
||||
keep_unused=ji.keep_unused,
|
||||
inline=ji.inline,
|
||||
@ -683,38 +681,30 @@ def _infer_params_cached(
|
||||
jit_info: PjitInfo,
|
||||
signature: jax_jit.ArgumentSignature,
|
||||
in_avals: tuple[core.AbstractValue, ...],
|
||||
pjit_mesh: mesh_lib.Mesh | None,
|
||||
resource_env: mesh_lib.ResourceEnv | None,
|
||||
ctx_mesh: mesh_lib.Mesh | None,
|
||||
) -> InferParamsCacheEntry:
|
||||
return InferParamsCacheEntry()
|
||||
|
||||
def disallow_use_mesh_and_legacy_mesh_ctx_mgr_together():
|
||||
if (not mesh_lib.thread_resources.env.physical_mesh.empty and
|
||||
mesh_lib.get_concrete_mesh() is not None):
|
||||
raise ValueError(
|
||||
'Using `with mesh:` context manager and `jax.sharding.use_mesh`'
|
||||
' together is not allowed.')
|
||||
|
||||
def _infer_params(
|
||||
fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> tuple[PjitParams, list[Any]]:
|
||||
disallow_use_mesh_and_legacy_mesh_ctx_mgr_together()
|
||||
) -> tuple[PjitParams, list[Any]]:
|
||||
if ji.use_resource_env:
|
||||
# We need to fetch the mesh from inside the wrapped function, because
|
||||
# meshes are dynamically scoped (i.e., with a context manager).
|
||||
resource_env = mesh_lib.thread_resources.env
|
||||
pjit_mesh = resource_env.physical_mesh
|
||||
else:
|
||||
resource_env = None
|
||||
pjit_mesh = None
|
||||
with mesh_lib.use_mesh(mesh_lib.thread_resources.env.physical_mesh):
|
||||
return _infer_params_internal(fun, ji, args, kwargs)
|
||||
return _infer_params_internal(fun, ji, args, kwargs)
|
||||
|
||||
def _infer_params_internal(
|
||||
fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> tuple[PjitParams, list[Any]]:
|
||||
ctx_mesh = mesh_lib.get_concrete_mesh()
|
||||
dbg = debug_info(
|
||||
'jit', fun, args, kwargs, static_argnums=ji.static_argnums,
|
||||
static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo,
|
||||
signature=ji.fun_signature)
|
||||
|
||||
if config.dynamic_shapes.value: # if dynamic shapes, don't use the cache
|
||||
p, args_flat = _infer_params_impl(fun, ji, pjit_mesh, resource_env, dbg,
|
||||
p, args_flat = _infer_params_impl(fun, ji, ctx_mesh, dbg,
|
||||
args, kwargs, in_avals=None)
|
||||
return p, p.consts + args_flat
|
||||
|
||||
@ -722,10 +712,11 @@ def _infer_params(
|
||||
args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
|
||||
ji.static_argnames, tree_util.default_registry)
|
||||
avals = _infer_input_type(fun, dbg, dynargs)
|
||||
entry = _infer_params_cached(fun, ji, signature, avals, pjit_mesh, resource_env)
|
||||
entry = _infer_params_cached(fun, ji, signature, avals, ctx_mesh)
|
||||
|
||||
if entry.pjit_params is None:
|
||||
p, args_flat = _infer_params_impl(
|
||||
fun, ji, pjit_mesh, resource_env, dbg, args, kwargs, in_avals=avals)
|
||||
fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals)
|
||||
if p.attrs_tracked: # if attrs, don't popoulate the cache
|
||||
return p, p.consts + args_flat
|
||||
entry.pjit_params = p
|
||||
@ -1616,7 +1607,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
|
||||
|
||||
def _resolve_and_lower(
|
||||
args, jaxpr, in_shardings, out_shardings, in_layouts,
|
||||
out_layouts, resource_env, donated_invars, name, keep_unused, inline,
|
||||
out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline,
|
||||
lowering_platforms, lowering_parameters, pgle_profiler,
|
||||
compiler_options_kvs):
|
||||
in_shardings = _resolve_in_shardings(args, in_shardings)
|
||||
@ -1624,8 +1615,8 @@ def _resolve_and_lower(
|
||||
jaxpr.in_avals)
|
||||
out_layouts = _resolve_out_layouts(out_layouts, out_shardings, jaxpr.out_avals)
|
||||
return _pjit_lower(
|
||||
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env,
|
||||
donated_invars, name, keep_unused, inline, compiler_options_kvs,
|
||||
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs,
|
||||
lowering_platforms=lowering_platforms,
|
||||
lowering_parameters=lowering_parameters,
|
||||
pgle_profiler=pgle_profiler)
|
||||
@ -1634,7 +1625,7 @@ _pgle_profiler_dict = weakref.WeakKeyDictionary() # type: ignore
|
||||
|
||||
def _pjit_call_impl_python(
|
||||
*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline,
|
||||
donated_invars, ctx_mesh, name, keep_unused, inline,
|
||||
compiler_options_kvs):
|
||||
pgle_compile_options, pgle_profiler = {}, None
|
||||
if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
|
||||
@ -1659,8 +1650,8 @@ def _pjit_call_impl_python(
|
||||
compiled = _resolve_and_lower(
|
||||
args, jaxpr=jaxpr, in_shardings=in_shardings,
|
||||
out_shardings=out_shardings, in_layouts=in_layouts,
|
||||
out_layouts=out_layouts, resource_env=resource_env,
|
||||
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
|
||||
out_layouts=out_layouts, donated_invars=donated_invars,
|
||||
ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused,
|
||||
inline=inline, lowering_platforms=None,
|
||||
lowering_parameters=mlir.LoweringParameters(),
|
||||
pgle_profiler=pgle_profiler,
|
||||
@ -1691,7 +1682,7 @@ def _pjit_call_impl_python(
|
||||
|
||||
@weakref_lru_cache
|
||||
def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts,
|
||||
out_layouts, resource_env, donated_invars, name,
|
||||
out_layouts, donated_invars, ctx_mesh, name,
|
||||
keep_unused, inline, compiler_options_kvs):
|
||||
# The input jaxpr to `_get_jaxpr_as_fun` is under a weakref_lru_cache so
|
||||
# returning `core.jaxpr_as_fun(jaxpr)` directly creates a strong reference to
|
||||
@ -1705,14 +1696,14 @@ def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts,
|
||||
|
||||
def _pjit_call_impl(*args, jaxpr,
|
||||
in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline,
|
||||
donated_invars, ctx_mesh, name, keep_unused, inline,
|
||||
compiler_options_kvs):
|
||||
def call_impl_cache_miss(*args_, **kwargs_):
|
||||
out_flat, compiled, pgle_profiler = _pjit_call_impl_python(
|
||||
*args, jaxpr=jaxpr, in_shardings=in_shardings,
|
||||
out_shardings=out_shardings, in_layouts=in_layouts,
|
||||
out_layouts=out_layouts, resource_env=resource_env,
|
||||
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
|
||||
out_layouts=out_layouts, donated_invars=donated_invars,
|
||||
ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused,
|
||||
inline=inline, compiler_options_kvs=compiler_options_kvs)
|
||||
fastpath_data = _get_fastpath_data(
|
||||
compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects,
|
||||
@ -1721,7 +1712,7 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
|
||||
f = _get_jaxpr_as_fun(
|
||||
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline,
|
||||
donated_invars, ctx_mesh, name, keep_unused, inline,
|
||||
compiler_options_kvs)
|
||||
donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
|
||||
cache_key = pxla.JitGlobalCppCacheKeys(
|
||||
@ -1730,8 +1721,7 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
in_shardings_treedef=None, in_shardings_leaves=in_shardings,
|
||||
out_shardings_treedef=None, out_shardings_leaves=out_shardings,
|
||||
in_layouts_treedef=None, in_layouts_leaves=in_layouts,
|
||||
out_layouts_treedef=None, out_layouts_leaves=out_layouts,
|
||||
use_resource_env=resource_env is not None)
|
||||
out_layouts_treedef=None, out_layouts_leaves=out_layouts)
|
||||
return xc._xla.pjit(
|
||||
name, f, call_impl_cache_miss, [], [], cache_key,
|
||||
tree_util.dispatch_registry, pxla.cc_shard_arg,
|
||||
@ -1746,8 +1736,8 @@ def _pjit_lower(
|
||||
out_shardings,
|
||||
in_layouts: pxla.MaybeLayout,
|
||||
out_layouts: pxla.MaybeLayout,
|
||||
resource_env,
|
||||
donated_invars,
|
||||
ctx_mesh,
|
||||
name: str,
|
||||
keep_unused: bool,
|
||||
inline: bool,
|
||||
@ -1757,12 +1747,10 @@ def _pjit_lower(
|
||||
lowering_parameters: mlir.LoweringParameters,
|
||||
pgle_profiler: profiler.PGLEProfiler | None):
|
||||
util.test_event("pjit_lower")
|
||||
mesh = (resource_env.physical_mesh if resource_env is not None else
|
||||
mesh_lib.get_concrete_mesh())
|
||||
return pxla.lower_sharding_computation(
|
||||
jaxpr, 'jit', name, in_shardings, out_shardings,
|
||||
in_layouts, out_layouts, tuple(donated_invars),
|
||||
keep_unused=keep_unused, context_mesh=mesh,
|
||||
keep_unused=keep_unused, context_mesh=ctx_mesh,
|
||||
compiler_options_kvs=compiler_options_kvs,
|
||||
lowering_platforms=lowering_platforms,
|
||||
lowering_parameters=lowering_parameters,
|
||||
@ -1914,8 +1902,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext,
|
||||
|
||||
def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str,
|
||||
jaxpr: core.ClosedJaxpr, in_shardings,
|
||||
out_shardings, in_layouts, out_layouts, resource_env,
|
||||
donated_invars, keep_unused, inline, compiler_options_kvs):
|
||||
out_shardings, in_layouts, out_layouts, donated_invars,
|
||||
ctx_mesh, keep_unused, inline, compiler_options_kvs):
|
||||
effects = list(ctx.tokens_in.effects())
|
||||
output_types = map(mlir.aval_to_ir_type, ctx.avals_out)
|
||||
output_types = [mlir.token_type()] * len(effects) + output_types
|
||||
@ -1945,23 +1933,20 @@ def _pjit_batcher(axis_data, vals_in,
|
||||
dims_in: tuple[int, ...],
|
||||
jaxpr: core.ClosedJaxpr,
|
||||
in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline,
|
||||
donated_invars, ctx_mesh, name, keep_unused, inline,
|
||||
compiler_options_kvs):
|
||||
segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in)
|
||||
new_jaxpr, axes_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in)
|
||||
|
||||
if resource_env is not None:
|
||||
mesh = resource_env.physical_mesh
|
||||
else:
|
||||
mesh = None
|
||||
|
||||
# TODO(axch): prepend with Nones (?) to account for new segment_lens inputs
|
||||
in_shardings = tuple(
|
||||
_pjit_batcher_for_sharding(i, axis_in, axis_data.spmd_name, mesh, aval.ndim)
|
||||
_pjit_batcher_for_sharding(i, axis_in, axis_data.spmd_name, ctx_mesh,
|
||||
aval.ndim)
|
||||
if axis_in is not None else i
|
||||
for axis_in, i, aval in zip(dims_in, in_shardings, new_jaxpr.in_avals))
|
||||
out_shardings = tuple(
|
||||
_pjit_batcher_for_sharding(o, axis_out, axis_data.spmd_name, mesh, aval.ndim)
|
||||
_pjit_batcher_for_sharding(o, axis_out, axis_data.spmd_name, ctx_mesh,
|
||||
aval.ndim)
|
||||
if axis_out is not None else o
|
||||
for axis_out, o, aval in zip(axes_out, out_shardings, new_jaxpr.out_avals))
|
||||
# TODO(yashkatariya): Figure out layouts should change under vmap.
|
||||
@ -1977,8 +1962,8 @@ def _pjit_batcher(axis_data, vals_in,
|
||||
out_shardings=out_shardings,
|
||||
in_layouts=in_layouts,
|
||||
out_layouts=out_layouts,
|
||||
resource_env=resource_env,
|
||||
donated_invars=donated_invars,
|
||||
ctx_mesh=ctx_mesh,
|
||||
name=name,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
@ -2000,8 +1985,8 @@ def _insert_axis_partitions(spec, dim, val):
|
||||
|
||||
def _pjit_batcher_for_sharding(
|
||||
s: Sharding | UnspecifiedValue,
|
||||
dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None, mesh,
|
||||
ndim: int):
|
||||
dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None,
|
||||
mesh, ndim: int):
|
||||
if isinstance(s, UnspecifiedValue):
|
||||
return s
|
||||
hlo_s = s._to_xla_hlo_sharding(ndim)
|
||||
@ -2040,7 +2025,7 @@ def _pjit_batcher_for_sharding(
|
||||
|
||||
def _pjit_jvp(primals_in, tangents_in,
|
||||
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline,
|
||||
donated_invars, ctx_mesh, name, keep_unused, inline,
|
||||
compiler_options_kvs):
|
||||
is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in]
|
||||
jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
|
||||
@ -2057,8 +2042,8 @@ def _pjit_jvp(primals_in, tangents_in,
|
||||
out_shardings=(*out_shardings, *_filter_zeros_out(out_shardings)),
|
||||
in_layouts=(*in_layouts, *_filter_zeros_in(in_layouts)),
|
||||
out_layouts=(*out_layouts, *_filter_zeros_out(out_layouts)),
|
||||
resource_env=resource_env,
|
||||
donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)),
|
||||
ctx_mesh=ctx_mesh,
|
||||
name=name,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
@ -2074,7 +2059,7 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp
|
||||
|
||||
def _pjit_linearization(nzs, *primals_in, jaxpr,
|
||||
in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline,
|
||||
donated_invars, ctx_mesh, name, keep_unused, inline,
|
||||
compiler_options_kvs):
|
||||
primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs)
|
||||
# constvars will become residuals. Move them to the end of the ordinary args.
|
||||
@ -2090,8 +2075,8 @@ def _pjit_linearization(nzs, *primals_in, jaxpr,
|
||||
out_shardings=_filter_zeros(nzs_out, out_shardings),
|
||||
in_layouts=_filter_zeros(nzs, in_layouts) + res_layouts,
|
||||
out_layouts=_filter_zeros(nzs_out, out_layouts),
|
||||
resource_env=resource_env,
|
||||
donated_invars=_filter_zeros(nzs, donated_invars) + res_donated,
|
||||
ctx_mesh=ctx_mesh,
|
||||
name=name,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
@ -2110,8 +2095,8 @@ def _pjit_linearization(nzs, *primals_in, jaxpr,
|
||||
out_shardings=(*res_shardings, *out_shardings),
|
||||
in_layouts=in_layouts,
|
||||
out_layouts=(*res_layouts, *out_layouts),
|
||||
resource_env=resource_env,
|
||||
donated_invars=donated_invars,
|
||||
ctx_mesh=ctx_mesh,
|
||||
name=name,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
@ -2126,7 +2111,7 @@ ad.primitive_linearizations[pjit_p] = _pjit_linearization
|
||||
def _pjit_partial_eval(trace: pe.JaxprTrace,
|
||||
*in_tracers,
|
||||
jaxpr: core.ClosedJaxpr, in_shardings, out_shardings,
|
||||
in_layouts, out_layouts, resource_env, donated_invars,
|
||||
in_layouts, out_layouts, donated_invars, ctx_mesh,
|
||||
name, keep_unused, inline, compiler_options_kvs):
|
||||
in_pvals = [t.pval for t in in_tracers]
|
||||
|
||||
@ -2193,8 +2178,9 @@ def _pjit_partial_eval(trace: pe.JaxprTrace,
|
||||
jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins),
|
||||
out_shardings=known_out_shardings,
|
||||
in_layouts=keep_where(in_layouts, known_ins),
|
||||
out_layouts=known_out_layouts, resource_env=resource_env,
|
||||
out_layouts=known_out_layouts,
|
||||
donated_invars=keep_where(donated_invars, known_ins),
|
||||
ctx_mesh=ctx_mesh,
|
||||
name=name, keep_unused=keep_unused, inline=inline,
|
||||
compiler_options_kvs=compiler_options_kvs)
|
||||
assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals)
|
||||
@ -2225,9 +2211,9 @@ def _pjit_partial_eval(trace: pe.JaxprTrace,
|
||||
out_shardings=keep_where(out_shardings, unknown_outs),
|
||||
in_layouts=(keep_where(in_layouts, unknown_ins) + res_layouts),
|
||||
out_layouts=keep_where(out_layouts, unknown_outs),
|
||||
resource_env=resource_env,
|
||||
donated_invars=(keep_where(donated_invars, unknown_ins) +
|
||||
(False,) * num_residuals),
|
||||
ctx_mesh=ctx_mesh,
|
||||
name=name,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
@ -2313,7 +2299,7 @@ def _pjit_transpose_trace(fun: lu.WrappedFun,
|
||||
def _pjit_transpose(cts_in, *primals_in,
|
||||
jaxpr: core.ClosedJaxpr,
|
||||
in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline,
|
||||
donated_invars, ctx_mesh, name, keep_unused, inline,
|
||||
compiler_options_kvs):
|
||||
def prune_type(ty, xs, maybe_zeros):
|
||||
return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty)
|
||||
@ -2362,8 +2348,8 @@ def _pjit_transpose(cts_in, *primals_in,
|
||||
out_shardings=transpose_out_shardings,
|
||||
in_layouts=transpose_in_layouts,
|
||||
out_layouts=transpose_out_layouts,
|
||||
resource_env=resource_env,
|
||||
donated_invars=(False,) * len(primals_and_nz_cts_in),
|
||||
ctx_mesh=ctx_mesh,
|
||||
name=name,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
@ -2447,9 +2433,8 @@ def _pjit_pp_rule(eqn: core.JaxprEqn,
|
||||
del params['out_layouts']
|
||||
if not params['keep_unused']:
|
||||
del params['keep_unused']
|
||||
if (params['resource_env'] is None or
|
||||
params['resource_env'].physical_mesh.empty):
|
||||
del params['resource_env']
|
||||
if params['ctx_mesh'] is None or params['ctx_mesh'].empty:
|
||||
del params['ctx_mesh']
|
||||
if not params['compiler_options_kvs']:
|
||||
del params['compiler_options_kvs']
|
||||
|
||||
@ -2549,8 +2534,6 @@ def with_sharding_constraint(x, shardings):
|
||||
flatten_axes("with_sharding_constraint layouts", tree, layouts))
|
||||
del layouts
|
||||
|
||||
disallow_use_mesh_and_legacy_mesh_ctx_mgr_together()
|
||||
|
||||
context_mesh = (
|
||||
mesh_lib.get_abstract_mesh() if mesh_lib.get_concrete_mesh() is not None
|
||||
else mesh_lib.thread_resources.env.physical_mesh)
|
||||
|
@ -3573,8 +3573,8 @@ def _pjit(*args: TfVal,
|
||||
in_shardings: Sequence[sharding.Sharding],
|
||||
out_shardings: Sequence[sharding.Sharding],
|
||||
in_layouts, out_layouts,
|
||||
resource_env: mesh.ResourceEnv,
|
||||
donated_invars,
|
||||
ctx_mesh,
|
||||
name: str,
|
||||
keep_unused: bool,
|
||||
inline: bool,
|
||||
|
@ -775,7 +775,7 @@ sparse_rules_bcoo[lax.while_p] = _while_sparse
|
||||
|
||||
|
||||
def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
|
||||
in_layouts, out_layouts, resource_env, donated_invars, name,
|
||||
in_layouts, out_layouts, donated_invars, ctx_mesh, name,
|
||||
keep_unused, inline, compiler_options_kvs):
|
||||
if any(donated_invars):
|
||||
raise NotImplementedError("sparse xla_call with donated_invars")
|
||||
@ -808,8 +808,8 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
|
||||
out_shardings=out_shardings,
|
||||
in_layouts=in_layouts,
|
||||
out_layouts=out_layouts,
|
||||
resource_env=resource_env,
|
||||
donated_invars=donated_invars,
|
||||
ctx_mesh=ctx_mesh,
|
||||
name=name,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
|
@ -1205,8 +1205,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"One of with_sharding_constraint.*Sharding "
|
||||
r"NamedSharding\(mesh=Mesh\('replica': 1, 'data': 1, 'mdl': 2\), "
|
||||
r"spec=PartitionSpec\(None, 'mdl', None, None\).*\) is only "
|
||||
r"NamedSharding.*PartitionSpec\(None, 'mdl', None, None\).*\) is only "
|
||||
"valid for values of rank at least 4, but was applied to a value of rank 1"):
|
||||
pjit_f(jnp.array([1, 2, 3]))
|
||||
|
||||
@ -6873,31 +6872,6 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
' axis_types are `Auto`'):
|
||||
NamedSharding(mesh, P(P.UNCONSTRAINED))
|
||||
|
||||
def test_use_mesh_legacy_mesh_ctx_mgr_mix_error(self):
|
||||
mesh = jtu.create_mesh((1, 1), ('x', 'y'))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Using `with mesh:` context manager and `jax.sharding.use_mesh`'
|
||||
' together is not allowed'):
|
||||
with jax.sharding.use_mesh(mesh), mesh:
|
||||
jax.jit(lambda x: x)(jnp.arange(8))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Using `with mesh:` context manager and `jax.sharding.use_mesh`'
|
||||
' together is not allowed'):
|
||||
with jax.sharding.use_mesh(mesh), mesh:
|
||||
jnp.zeros((8, 2), dtype=jnp.int32)
|
||||
|
||||
x = jnp.arange(8)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Using `with mesh:` context manager and `jax.sharding.use_mesh`'
|
||||
' together is not allowed'):
|
||||
with jax.sharding.use_mesh(mesh), mesh:
|
||||
jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
|
||||
|
||||
def test_pspec_einsum_no_context_mesh(self):
|
||||
mesh = jtu.create_mesh((1, 1), ('x', 'y'),
|
||||
axis_types={AxisTypes.Explicit: ('x', 'y')})
|
||||
|
Loading…
x
Reference in New Issue
Block a user