mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Remove in_positional_semantics and out_positional_semantics from xmap
PiperOrigin-RevId: 517477866
This commit is contained in:
parent
d02f28199b
commit
7c7c60eabf
136
jax/_src/maps.py
136
jax/_src/maps.py
@ -520,11 +520,8 @@ def xmap(fun: Callable,
|
||||
lambda: tuple(_flatten_axes("xmap out_axes", out_tree(), out_axes, tupled_args=False)),
|
||||
closure=(out_axes_entries, out_axes_treedef))
|
||||
|
||||
in_positional_semantics = (_PositionalSemantics.GLOBAL,) * len(args_flat)
|
||||
out_positional_semantics = _PositionalSemantics.GLOBAL
|
||||
|
||||
axis_resource_count = _get_axis_resource_count(
|
||||
frozen_axis_resources, resource_env, in_positional_semantics)
|
||||
frozen_axis_resources, resource_env)
|
||||
|
||||
for axis, size in axis_sizes.items():
|
||||
resources = axis_resource_count[axis]
|
||||
@ -534,8 +531,7 @@ def xmap(fun: Callable,
|
||||
f"by the total number of resources assigned to this axis "
|
||||
f"({frozen_axis_resources[axis]}, {resources.nglobal} in total)")
|
||||
frozen_global_axis_sizes = _get_axis_sizes(
|
||||
args_flat, in_axes_flat, axis_sizes, axis_resource_count,
|
||||
in_positional_semantics)
|
||||
args_flat, in_axes_flat, axis_sizes, axis_resource_count)
|
||||
|
||||
missing_sizes = defined_names - set(frozen_global_axis_sizes.keys())
|
||||
if missing_sizes:
|
||||
@ -551,9 +547,9 @@ def xmap(fun: Callable,
|
||||
f"which asserts that it should be of rank {spec.expected_rank}, "
|
||||
f"but the argument has rank {arg.ndim} (and shape {arg.shape})")
|
||||
|
||||
_check_gda_or_array_xmap_partitioning(frozen_axis_resources, resource_env,
|
||||
frozen_global_axis_sizes, in_axes_flat,
|
||||
in_positional_semantics, args_flat)
|
||||
_check_gda_or_array_xmap_partitioning(
|
||||
frozen_axis_resources, resource_env, frozen_global_axis_sizes,
|
||||
in_axes_flat, args_flat)
|
||||
|
||||
params = dict(
|
||||
name=getattr(fun, '__name__', '<unnamed function>'),
|
||||
@ -565,9 +561,7 @@ def xmap(fun: Callable,
|
||||
resource_env=resource_env,
|
||||
backend=backend,
|
||||
spmd_in_axes=None,
|
||||
spmd_out_axes_thunk=None,
|
||||
in_positional_semantics=in_positional_semantics,
|
||||
out_positional_semantics=out_positional_semantics)
|
||||
spmd_out_axes_thunk=None)
|
||||
return fun_flat, args_flat, params, in_tree, out_tree
|
||||
|
||||
def verify_outputs(out_flat, out_tree, params):
|
||||
@ -600,8 +594,7 @@ def xmap(fun: Callable,
|
||||
fun_flat, params['name'], params['in_axes'], params['out_axes_thunk'],
|
||||
params['donated_invars'], params['global_axis_sizes'], params['axis_resources'],
|
||||
params['resource_env'], params['backend'], params['spmd_in_axes'],
|
||||
params['spmd_out_axes_thunk'], params['in_positional_semantics'],
|
||||
params['out_positional_semantics'],
|
||||
params['spmd_out_axes_thunk'],
|
||||
_experimental_lowering_platform, *avals_flat)
|
||||
|
||||
in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
|
||||
@ -615,13 +608,12 @@ def xmap(fun: Callable,
|
||||
|
||||
def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_invars,
|
||||
global_axis_sizes, axis_resources, resource_env, backend,
|
||||
spmd_in_axes, spmd_out_axes_thunk, in_positional_semantics,
|
||||
out_positional_semantics):
|
||||
spmd_in_axes, spmd_out_axes_thunk):
|
||||
in_avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args]
|
||||
xmap_callable = make_xmap_callable(
|
||||
fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
|
||||
axis_resources, resource_env, backend,
|
||||
spmd_in_axes, spmd_out_axes_thunk, in_positional_semantics, out_positional_semantics,
|
||||
spmd_in_axes, spmd_out_axes_thunk,
|
||||
None, *in_avals).compile().unsafe_call
|
||||
distributed_debug_log(("Running xmapped function", name),
|
||||
("python function", fun.f),
|
||||
@ -634,12 +626,11 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
name,
|
||||
in_axes, out_axes_thunk, donated_invars,
|
||||
global_axis_sizes, axis_resources, resource_env, backend,
|
||||
spmd_in_axes, spmd_out_axes_thunk, in_positional_semantics,
|
||||
out_positional_semantics,
|
||||
spmd_in_axes, spmd_out_axes_thunk,
|
||||
lowering_platform: Optional[str],
|
||||
*in_avals):
|
||||
plan = EvaluationPlan.from_axis_resources(
|
||||
axis_resources, resource_env, global_axis_sizes, in_positional_semantics)
|
||||
axis_resources, resource_env, global_axis_sizes)
|
||||
|
||||
# TODO: Making axis substitution final style would allow us to avoid
|
||||
# tracing to jaxpr here
|
||||
@ -675,10 +666,6 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
assert spmd_in_axes is None and spmd_out_axes_thunk is None # No outer xmaps, so should be None
|
||||
mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
|
||||
mesh = resource_env.physical_mesh
|
||||
global_in_avals = [
|
||||
av if ips == _PositionalSemantics.GLOBAL else pxla.mesh_local_to_global(mesh, ax, av)
|
||||
for ax, av, ips in safe_zip(mesh_in_axes, in_avals, in_positional_semantics)
|
||||
]
|
||||
tiling_method: pxla.TilingMethod
|
||||
if config.experimental_xmap_spmd_lowering_manual:
|
||||
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
|
||||
@ -692,7 +679,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
return pxla.lower_mesh_computation(
|
||||
f, 'xmap', name, mesh,
|
||||
in_shardings, out_shardings, donated_invars,
|
||||
use_spmd_lowering, global_in_avals,
|
||||
use_spmd_lowering, in_avals,
|
||||
tiling_method=tiling_method,
|
||||
lowering_platform=lowering_platform)
|
||||
else:
|
||||
@ -727,12 +714,11 @@ class EvaluationPlan(NamedTuple):
|
||||
def from_axis_resources(cls,
|
||||
axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]],
|
||||
resource_env: ResourceEnv,
|
||||
global_axis_sizes: Dict[AxisName, int],
|
||||
in_positional_semantics: Sequence[bool]):
|
||||
global_axis_sizes: Dict[AxisName, int]):
|
||||
physical_axis_resources, loop_axis_resources = _unzip_axis_resources(
|
||||
axis_resources, resource_env)
|
||||
axis_resource_count = _get_axis_resource_count(
|
||||
axis_resources, resource_env, in_positional_semantics)
|
||||
axis_resources, resource_env)
|
||||
axis_subst_dict = dict(axis_resources)
|
||||
axis_vmap_size: Dict[AxisName, Optional[int]] = {}
|
||||
for naxis, raxes in sorted(axis_resources.items(), key=lambda x: str(x[0])):
|
||||
@ -893,10 +879,9 @@ def _xmap_transpose(params, call_jaxpr, args, cts_in, cts_in_avals, reduce_axes)
|
||||
arg_cts = tree_unflatten(out_tree(), out_flat)
|
||||
|
||||
axis_resource_count = _get_axis_resource_count(
|
||||
params['axis_resources'], params['resource_env'],
|
||||
params['in_positional_semantics'])
|
||||
params['axis_resources'], params['resource_env'])
|
||||
local_axis_sizes = {
|
||||
axis: axis_resource_count[axis].to_local(params['out_positional_semantics'], global_size)
|
||||
axis: axis_resource_count[axis].to_local(global_size)
|
||||
for axis, global_size in params['global_axis_sizes'].items()
|
||||
}
|
||||
def unmap_zero(zero, axes):
|
||||
@ -909,13 +894,12 @@ ad.primitive_transposes[xmap_p] = _xmap_transpose
|
||||
def _typecheck_xmap(
|
||||
*in_atoms, call_jaxpr, name, in_axes, out_axes, donated_invars,
|
||||
global_axis_sizes, axis_resources, resource_env, backend,
|
||||
spmd_in_axes, spmd_out_axes, in_positional_semantics,
|
||||
out_positional_semantics):
|
||||
spmd_in_axes, spmd_out_axes):
|
||||
in_avals = [x.aval for x in in_atoms]
|
||||
axis_resource_count = _get_axis_resource_count(
|
||||
axis_resources, resource_env, in_positional_semantics)
|
||||
axis_resources, resource_env)
|
||||
local_axis_sizes = {
|
||||
axis: axis_resource_count[axis].to_local(out_positional_semantics, global_size)
|
||||
axis: axis_resource_count[axis].to_local(global_size)
|
||||
for axis, global_size in global_axis_sizes.items()
|
||||
}
|
||||
binder_in_avals = [_insert_aval_axes(v.aval, a_in_axes, local_axis_sizes)
|
||||
@ -1002,10 +986,9 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
|
||||
else:
|
||||
spmd_out_axes = None
|
||||
axis_resource_count = _get_axis_resource_count(
|
||||
params['axis_resources'], params['resource_env'],
|
||||
params['in_positional_semantics'])
|
||||
params['axis_resources'], params['resource_env'])
|
||||
local_axis_sizes = {
|
||||
axis: axis_resource_count[axis].to_local(params['out_positional_semantics'], global_size)
|
||||
axis: axis_resource_count[axis].to_local(global_size)
|
||||
for axis, global_size in global_axis_sizes.items()
|
||||
}
|
||||
out_avals = [_insert_aval_axes(a, a_out_axes, local_axis_sizes)
|
||||
@ -1028,9 +1011,6 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
|
||||
donated_invars=new_donated_invars,
|
||||
spmd_in_axes=new_spmd_in_axes,
|
||||
spmd_out_axes=spmd_out_axes,
|
||||
in_positional_semantics=(
|
||||
*( _PositionalSemantics.GLOBAL,) * len(constvars),
|
||||
*params['in_positional_semantics']),
|
||||
call_jaxpr=call_jaxpr)
|
||||
del new_params['out_axes_thunk']
|
||||
del new_params['spmd_out_axes_thunk']
|
||||
@ -1163,11 +1143,9 @@ def _jaxpr_trace_process_xmap(self, primitive, f: lu.WrappedFun, tracers, params
|
||||
unknown_arg_tracers = [t for t in tracers if not t.pval.is_known()]
|
||||
# Create output tracers for unknown part, adjusting avals.
|
||||
axis_resource_count = _get_axis_resource_count(
|
||||
params['axis_resources'], params['resource_env'],
|
||||
params['in_positional_semantics'])
|
||||
params['axis_resources'], params['resource_env'])
|
||||
local_axis_sizes = {
|
||||
ax: axis_resource_count[ax].to_local(
|
||||
params['out_positional_semantics'], global_size)
|
||||
ax: axis_resource_count[ax].to_local(global_size)
|
||||
for ax, global_size in global_axis_sizes.items()}
|
||||
out_pvals = [pe.PartialVal.unknown(_insert_aval_axes(a, ax, local_axis_sizes))
|
||||
for a, ax in zip(out_avals, out_axes_unknown)]
|
||||
@ -1310,17 +1288,16 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
|
||||
in_axes, out_axes, donated_invars,
|
||||
global_axis_sizes,
|
||||
spmd_in_axes, spmd_out_axes,
|
||||
in_positional_semantics, out_positional_semantics,
|
||||
axis_resources, resource_env, backend):
|
||||
xla.check_backend_matches(backend, ctx.module_context.platform)
|
||||
# The only way for any of those two assertions to be violated is when xmap
|
||||
# is using the SPMD lowering, but then this rule shouldn't even trigger.
|
||||
assert spmd_in_axes is None and spmd_out_axes is None
|
||||
plan = EvaluationPlan.from_axis_resources(
|
||||
axis_resources, resource_env, global_axis_sizes, in_positional_semantics)
|
||||
axis_resources, resource_env, global_axis_sizes)
|
||||
|
||||
axis_resource_count = _get_axis_resource_count(
|
||||
axis_resources, resource_env, in_positional_semantics)
|
||||
axis_resources, resource_env)
|
||||
if any(resource_count.distributed for resource_count in axis_resource_count.values()):
|
||||
raise NotImplementedError
|
||||
|
||||
@ -1387,12 +1364,11 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
|
||||
def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
|
||||
call_jaxpr, name, in_axes, out_axes,
|
||||
donated_invars, global_axis_sizes, spmd_in_axes,
|
||||
spmd_out_axes, in_positional_semantics,
|
||||
out_positional_semantics, axis_resources,
|
||||
spmd_out_axes, axis_resources,
|
||||
resource_env, backend):
|
||||
xla.check_backend_matches(backend, ctx.module_context.platform)
|
||||
plan = EvaluationPlan.from_axis_resources(
|
||||
axis_resources, resource_env, global_axis_sizes, in_positional_semantics)
|
||||
axis_resources, resource_env, global_axis_sizes)
|
||||
|
||||
resource_call_jaxpr = plan.subst_axes_with_resources(call_jaxpr)
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(resource_call_jaxpr, ())))
|
||||
@ -1450,14 +1426,13 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
|
||||
def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
|
||||
call_jaxpr, name, in_axes, out_axes,
|
||||
donated_invars, global_axis_sizes, spmd_in_axes,
|
||||
spmd_out_axes, in_positional_semantics,
|
||||
out_positional_semantics, axis_resources,
|
||||
spmd_out_axes, axis_resources,
|
||||
resource_env, backend):
|
||||
assert spmd_in_axes is None and spmd_out_axes is None
|
||||
# This first part (up to vtile_manual) is shared with non-MANUAL SPMD rule.
|
||||
xla.check_backend_matches(backend, ctx.module_context.platform)
|
||||
plan = EvaluationPlan.from_axis_resources(
|
||||
axis_resources, resource_env, global_axis_sizes, in_positional_semantics)
|
||||
axis_resources, resource_env, global_axis_sizes)
|
||||
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
|
||||
|
||||
resource_call_jaxpr = plan.subst_axes_with_resources(call_jaxpr)
|
||||
@ -1574,35 +1549,14 @@ class ResourceCount(NamedTuple):
|
||||
nlocal: Optional[int]
|
||||
distributed: bool
|
||||
|
||||
def to_local(self, semantics, global_size):
|
||||
if semantics == _PositionalSemantics.GLOBAL:
|
||||
def to_local(self, global_size):
|
||||
return global_size
|
||||
elif semantics == _PositionalSemantics.LOCAL:
|
||||
assert self.nlocal is not None
|
||||
assert global_size % self.nglobal == 0, "Please report this issue!"
|
||||
return (global_size // self.nglobal) * self.nlocal
|
||||
else:
|
||||
raise AssertionError(f"Unhandled case {_positional_semantics}")
|
||||
|
||||
def to_global(self, semantics, local_size):
|
||||
if semantics == _PositionalSemantics.GLOBAL:
|
||||
return local_size
|
||||
elif semantics == _PositionalSemantics.LOCAL:
|
||||
assert self.nlocal is not None
|
||||
assert local_size % self.nlocal == 0, "Please report this issue!"
|
||||
return (local_size // self.nlocal) * self.nglobal
|
||||
else:
|
||||
raise AssertionError(f"Unhandled case {_positional_semantics}")
|
||||
|
||||
|
||||
def _get_axis_resource_count(
|
||||
axis_resources, resource_env,
|
||||
in_positional_semantics) -> Dict[ResourceAxisName, ResourceCount]:
|
||||
axis_resources, resource_env) -> Dict[ResourceAxisName, ResourceCount]:
|
||||
global_res_shape = resource_env.shape
|
||||
if all(ips == _PositionalSemantics.GLOBAL for ips in in_positional_semantics):
|
||||
local_res_shape = None
|
||||
else:
|
||||
local_res_shape = resource_env.local_shape
|
||||
|
||||
distributed = (False if resource_env.physical_mesh.empty else
|
||||
resource_env.physical_mesh.size != len(resource_env.physical_mesh.local_devices))
|
||||
@ -1621,13 +1575,10 @@ def _get_axis_resource_count(
|
||||
def _get_axis_sizes(args_flat: Iterable[Any],
|
||||
in_axes_flat: Iterable[AxisNamePos],
|
||||
global_axis_sizes: Dict[AxisName, int],
|
||||
axis_resource_count: Dict[AxisName, ResourceCount],
|
||||
in_positional_semantics: Sequence[_PositionalSemantics]):
|
||||
axis_resource_count: Dict[AxisName, ResourceCount]):
|
||||
global_axis_sizes = dict(global_axis_sizes)
|
||||
for arg, in_axes, ips in zip(args_flat, in_axes_flat, in_positional_semantics):
|
||||
for arg, in_axes in zip(args_flat, in_axes_flat):
|
||||
for name, dim in in_axes.items():
|
||||
resources = axis_resource_count[name]
|
||||
local_ = "local " if resources.distributed else ""
|
||||
try:
|
||||
dim_size = arg.shape[dim]
|
||||
except IndexError:
|
||||
@ -1636,22 +1587,6 @@ def _get_axis_sizes(args_flat: Iterable[Any],
|
||||
f"{in_axes.user_repr}, which implies that it has at least "
|
||||
f"{max(in_axes.values()) + 1} dimensions, but the argument "
|
||||
f"has rank {arg.ndim}")
|
||||
if ips == _PositionalSemantics.LOCAL:
|
||||
local_dim_size = dim_size
|
||||
if local_dim_size % resources.nlocal != 0:
|
||||
raise ValueError(f"One of xmap arguments has an in_axes specification of "
|
||||
f"{in_axes.user_repr}, which implies that its size in dimension "
|
||||
f"{dim} ({local_dim_size}) should be divisible by the number of "
|
||||
f"{local_}resources assigned to axis {name} ({resources.nlocal})")
|
||||
global_dim_size = resources.to_global(ips, local_dim_size)
|
||||
if name in global_axis_sizes:
|
||||
expected_local_dim_size = resources.to_local(ips, global_axis_sizes[name])
|
||||
if local_dim_size != expected_local_dim_size:
|
||||
raise ValueError(f"The {local_}size of axis {name} was previously inferred to be "
|
||||
f"{expected_local_dim_size}, but found an argument of shape {arg.shape} "
|
||||
f"with in_axes specification {in_axes.user_repr}. Shape mismatch "
|
||||
f"occurs in dimension {dim}: {local_dim_size} != {expected_local_dim_size}")
|
||||
elif ips == _PositionalSemantics.GLOBAL:
|
||||
global_dim_size = dim_size
|
||||
if name in global_axis_sizes:
|
||||
expected_global_dim_size = global_axis_sizes[name]
|
||||
@ -1792,7 +1727,7 @@ def _check_out_avals_vs_out_axes(out_avals: Sequence[core.AbstractValue],
|
||||
|
||||
def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
|
||||
global_axis_sizes, in_axes_flat,
|
||||
in_positional_semantics, args_flat):
|
||||
args_flat):
|
||||
@lru_cache()
|
||||
def _check_sharding(in_sharding, xmap_sharding, ndim, arr_flavor):
|
||||
if not pxla.are_op_shardings_equal(
|
||||
@ -1805,8 +1740,7 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
|
||||
f"xmap spec: {xmap_sharding.spec}")
|
||||
|
||||
mesh_in_axes = EvaluationPlan.from_axis_resources( # pytype: disable=wrong-arg-types # always-use-return-annotations
|
||||
axis_resources, resource_env, global_axis_sizes,
|
||||
in_positional_semantics).to_mesh_axes(in_axes_flat)
|
||||
axis_resources, resource_env, global_axis_sizes).to_mesh_axes(in_axes_flat)
|
||||
for arg, xmap_array_mapping in safe_zip(args_flat, mesh_in_axes):
|
||||
if isinstance(arg, ArrayImpl):
|
||||
if not isinstance(arg.sharding, NamedSharding):
|
||||
|
Loading…
x
Reference in New Issue
Block a user