Remove in_positional_semantics and out_positional_semantics from xmap

PiperOrigin-RevId: 517477866
This commit is contained in:
Yash Katariya 2023-03-17 12:23:37 -07:00 committed by jax authors
parent d02f28199b
commit 7c7c60eabf

View File

@ -520,11 +520,8 @@ def xmap(fun: Callable,
lambda: tuple(_flatten_axes("xmap out_axes", out_tree(), out_axes, tupled_args=False)),
closure=(out_axes_entries, out_axes_treedef))
in_positional_semantics = (_PositionalSemantics.GLOBAL,) * len(args_flat)
out_positional_semantics = _PositionalSemantics.GLOBAL
axis_resource_count = _get_axis_resource_count(
frozen_axis_resources, resource_env, in_positional_semantics)
frozen_axis_resources, resource_env)
for axis, size in axis_sizes.items():
resources = axis_resource_count[axis]
@ -534,8 +531,7 @@ def xmap(fun: Callable,
f"by the total number of resources assigned to this axis "
f"({frozen_axis_resources[axis]}, {resources.nglobal} in total)")
frozen_global_axis_sizes = _get_axis_sizes(
args_flat, in_axes_flat, axis_sizes, axis_resource_count,
in_positional_semantics)
args_flat, in_axes_flat, axis_sizes, axis_resource_count)
missing_sizes = defined_names - set(frozen_global_axis_sizes.keys())
if missing_sizes:
@ -551,9 +547,9 @@ def xmap(fun: Callable,
f"which asserts that it should be of rank {spec.expected_rank}, "
f"but the argument has rank {arg.ndim} (and shape {arg.shape})")
_check_gda_or_array_xmap_partitioning(frozen_axis_resources, resource_env,
frozen_global_axis_sizes, in_axes_flat,
in_positional_semantics, args_flat)
_check_gda_or_array_xmap_partitioning(
frozen_axis_resources, resource_env, frozen_global_axis_sizes,
in_axes_flat, args_flat)
params = dict(
name=getattr(fun, '__name__', '<unnamed function>'),
@ -565,9 +561,7 @@ def xmap(fun: Callable,
resource_env=resource_env,
backend=backend,
spmd_in_axes=None,
spmd_out_axes_thunk=None,
in_positional_semantics=in_positional_semantics,
out_positional_semantics=out_positional_semantics)
spmd_out_axes_thunk=None)
return fun_flat, args_flat, params, in_tree, out_tree
def verify_outputs(out_flat, out_tree, params):
@ -600,8 +594,7 @@ def xmap(fun: Callable,
fun_flat, params['name'], params['in_axes'], params['out_axes_thunk'],
params['donated_invars'], params['global_axis_sizes'], params['axis_resources'],
params['resource_env'], params['backend'], params['spmd_in_axes'],
params['spmd_out_axes_thunk'], params['in_positional_semantics'],
params['out_positional_semantics'],
params['spmd_out_axes_thunk'],
_experimental_lowering_platform, *avals_flat)
in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
@ -615,13 +608,12 @@ def xmap(fun: Callable,
def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_invars,
global_axis_sizes, axis_resources, resource_env, backend,
spmd_in_axes, spmd_out_axes_thunk, in_positional_semantics,
out_positional_semantics):
spmd_in_axes, spmd_out_axes_thunk):
in_avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args]
xmap_callable = make_xmap_callable(
fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
axis_resources, resource_env, backend,
spmd_in_axes, spmd_out_axes_thunk, in_positional_semantics, out_positional_semantics,
spmd_in_axes, spmd_out_axes_thunk,
None, *in_avals).compile().unsafe_call
distributed_debug_log(("Running xmapped function", name),
("python function", fun.f),
@ -634,12 +626,11 @@ def make_xmap_callable(fun: lu.WrappedFun,
name,
in_axes, out_axes_thunk, donated_invars,
global_axis_sizes, axis_resources, resource_env, backend,
spmd_in_axes, spmd_out_axes_thunk, in_positional_semantics,
out_positional_semantics,
spmd_in_axes, spmd_out_axes_thunk,
lowering_platform: Optional[str],
*in_avals):
plan = EvaluationPlan.from_axis_resources(
axis_resources, resource_env, global_axis_sizes, in_positional_semantics)
axis_resources, resource_env, global_axis_sizes)
# TODO: Making axis substitution final style would allow us to avoid
# tracing to jaxpr here
@ -675,10 +666,6 @@ def make_xmap_callable(fun: lu.WrappedFun,
assert spmd_in_axes is None and spmd_out_axes_thunk is None # No outer xmaps, so should be None
mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
mesh = resource_env.physical_mesh
global_in_avals = [
av if ips == _PositionalSemantics.GLOBAL else pxla.mesh_local_to_global(mesh, ax, av)
for ax, av, ips in safe_zip(mesh_in_axes, in_avals, in_positional_semantics)
]
tiling_method: pxla.TilingMethod
if config.experimental_xmap_spmd_lowering_manual:
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
@ -692,7 +679,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
return pxla.lower_mesh_computation(
f, 'xmap', name, mesh,
in_shardings, out_shardings, donated_invars,
use_spmd_lowering, global_in_avals,
use_spmd_lowering, in_avals,
tiling_method=tiling_method,
lowering_platform=lowering_platform)
else:
@ -727,12 +714,11 @@ class EvaluationPlan(NamedTuple):
def from_axis_resources(cls,
axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]],
resource_env: ResourceEnv,
global_axis_sizes: Dict[AxisName, int],
in_positional_semantics: Sequence[bool]):
global_axis_sizes: Dict[AxisName, int]):
physical_axis_resources, loop_axis_resources = _unzip_axis_resources(
axis_resources, resource_env)
axis_resource_count = _get_axis_resource_count(
axis_resources, resource_env, in_positional_semantics)
axis_resources, resource_env)
axis_subst_dict = dict(axis_resources)
axis_vmap_size: Dict[AxisName, Optional[int]] = {}
for naxis, raxes in sorted(axis_resources.items(), key=lambda x: str(x[0])):
@ -893,10 +879,9 @@ def _xmap_transpose(params, call_jaxpr, args, cts_in, cts_in_avals, reduce_axes)
arg_cts = tree_unflatten(out_tree(), out_flat)
axis_resource_count = _get_axis_resource_count(
params['axis_resources'], params['resource_env'],
params['in_positional_semantics'])
params['axis_resources'], params['resource_env'])
local_axis_sizes = {
axis: axis_resource_count[axis].to_local(params['out_positional_semantics'], global_size)
axis: axis_resource_count[axis].to_local(global_size)
for axis, global_size in params['global_axis_sizes'].items()
}
def unmap_zero(zero, axes):
@ -909,13 +894,12 @@ ad.primitive_transposes[xmap_p] = _xmap_transpose
def _typecheck_xmap(
*in_atoms, call_jaxpr, name, in_axes, out_axes, donated_invars,
global_axis_sizes, axis_resources, resource_env, backend,
spmd_in_axes, spmd_out_axes, in_positional_semantics,
out_positional_semantics):
spmd_in_axes, spmd_out_axes):
in_avals = [x.aval for x in in_atoms]
axis_resource_count = _get_axis_resource_count(
axis_resources, resource_env, in_positional_semantics)
axis_resources, resource_env)
local_axis_sizes = {
axis: axis_resource_count[axis].to_local(out_positional_semantics, global_size)
axis: axis_resource_count[axis].to_local(global_size)
for axis, global_size in global_axis_sizes.items()
}
binder_in_avals = [_insert_aval_axes(v.aval, a_in_axes, local_axis_sizes)
@ -1002,10 +986,9 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
else:
spmd_out_axes = None
axis_resource_count = _get_axis_resource_count(
params['axis_resources'], params['resource_env'],
params['in_positional_semantics'])
params['axis_resources'], params['resource_env'])
local_axis_sizes = {
axis: axis_resource_count[axis].to_local(params['out_positional_semantics'], global_size)
axis: axis_resource_count[axis].to_local(global_size)
for axis, global_size in global_axis_sizes.items()
}
out_avals = [_insert_aval_axes(a, a_out_axes, local_axis_sizes)
@ -1028,9 +1011,6 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
donated_invars=new_donated_invars,
spmd_in_axes=new_spmd_in_axes,
spmd_out_axes=spmd_out_axes,
in_positional_semantics=(
*( _PositionalSemantics.GLOBAL,) * len(constvars),
*params['in_positional_semantics']),
call_jaxpr=call_jaxpr)
del new_params['out_axes_thunk']
del new_params['spmd_out_axes_thunk']
@ -1163,11 +1143,9 @@ def _jaxpr_trace_process_xmap(self, primitive, f: lu.WrappedFun, tracers, params
unknown_arg_tracers = [t for t in tracers if not t.pval.is_known()]
# Create output tracers for unknown part, adjusting avals.
axis_resource_count = _get_axis_resource_count(
params['axis_resources'], params['resource_env'],
params['in_positional_semantics'])
params['axis_resources'], params['resource_env'])
local_axis_sizes = {
ax: axis_resource_count[ax].to_local(
params['out_positional_semantics'], global_size)
ax: axis_resource_count[ax].to_local(global_size)
for ax, global_size in global_axis_sizes.items()}
out_pvals = [pe.PartialVal.unknown(_insert_aval_axes(a, ax, local_axis_sizes))
for a, ax in zip(out_avals, out_axes_unknown)]
@ -1310,17 +1288,16 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
in_axes, out_axes, donated_invars,
global_axis_sizes,
spmd_in_axes, spmd_out_axes,
in_positional_semantics, out_positional_semantics,
axis_resources, resource_env, backend):
xla.check_backend_matches(backend, ctx.module_context.platform)
# The only way for any of those two assertions to be violated is when xmap
# is using the SPMD lowering, but then this rule shouldn't even trigger.
assert spmd_in_axes is None and spmd_out_axes is None
plan = EvaluationPlan.from_axis_resources(
axis_resources, resource_env, global_axis_sizes, in_positional_semantics)
axis_resources, resource_env, global_axis_sizes)
axis_resource_count = _get_axis_resource_count(
axis_resources, resource_env, in_positional_semantics)
axis_resources, resource_env)
if any(resource_count.distributed for resource_count in axis_resource_count.values()):
raise NotImplementedError
@ -1387,12 +1364,11 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
call_jaxpr, name, in_axes, out_axes,
donated_invars, global_axis_sizes, spmd_in_axes,
spmd_out_axes, in_positional_semantics,
out_positional_semantics, axis_resources,
spmd_out_axes, axis_resources,
resource_env, backend):
xla.check_backend_matches(backend, ctx.module_context.platform)
plan = EvaluationPlan.from_axis_resources(
axis_resources, resource_env, global_axis_sizes, in_positional_semantics)
axis_resources, resource_env, global_axis_sizes)
resource_call_jaxpr = plan.subst_axes_with_resources(call_jaxpr)
f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(resource_call_jaxpr, ())))
@ -1450,14 +1426,13 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
call_jaxpr, name, in_axes, out_axes,
donated_invars, global_axis_sizes, spmd_in_axes,
spmd_out_axes, in_positional_semantics,
out_positional_semantics, axis_resources,
spmd_out_axes, axis_resources,
resource_env, backend):
assert spmd_in_axes is None and spmd_out_axes is None
# This first part (up to vtile_manual) is shared with non-MANUAL SPMD rule.
xla.check_backend_matches(backend, ctx.module_context.platform)
plan = EvaluationPlan.from_axis_resources(
axis_resources, resource_env, global_axis_sizes, in_positional_semantics)
axis_resources, resource_env, global_axis_sizes)
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
resource_call_jaxpr = plan.subst_axes_with_resources(call_jaxpr)
@ -1574,35 +1549,14 @@ class ResourceCount(NamedTuple):
nlocal: Optional[int]
distributed: bool
def to_local(self, semantics, global_size):
if semantics == _PositionalSemantics.GLOBAL:
def to_local(self, global_size):
return global_size
elif semantics == _PositionalSemantics.LOCAL:
assert self.nlocal is not None
assert global_size % self.nglobal == 0, "Please report this issue!"
return (global_size // self.nglobal) * self.nlocal
else:
raise AssertionError(f"Unhandled case {_positional_semantics}")
def to_global(self, semantics, local_size):
if semantics == _PositionalSemantics.GLOBAL:
return local_size
elif semantics == _PositionalSemantics.LOCAL:
assert self.nlocal is not None
assert local_size % self.nlocal == 0, "Please report this issue!"
return (local_size // self.nlocal) * self.nglobal
else:
raise AssertionError(f"Unhandled case {_positional_semantics}")
def _get_axis_resource_count(
axis_resources, resource_env,
in_positional_semantics) -> Dict[ResourceAxisName, ResourceCount]:
axis_resources, resource_env) -> Dict[ResourceAxisName, ResourceCount]:
global_res_shape = resource_env.shape
if all(ips == _PositionalSemantics.GLOBAL for ips in in_positional_semantics):
local_res_shape = None
else:
local_res_shape = resource_env.local_shape
distributed = (False if resource_env.physical_mesh.empty else
resource_env.physical_mesh.size != len(resource_env.physical_mesh.local_devices))
@ -1621,13 +1575,10 @@ def _get_axis_resource_count(
def _get_axis_sizes(args_flat: Iterable[Any],
in_axes_flat: Iterable[AxisNamePos],
global_axis_sizes: Dict[AxisName, int],
axis_resource_count: Dict[AxisName, ResourceCount],
in_positional_semantics: Sequence[_PositionalSemantics]):
axis_resource_count: Dict[AxisName, ResourceCount]):
global_axis_sizes = dict(global_axis_sizes)
for arg, in_axes, ips in zip(args_flat, in_axes_flat, in_positional_semantics):
for arg, in_axes in zip(args_flat, in_axes_flat):
for name, dim in in_axes.items():
resources = axis_resource_count[name]
local_ = "local " if resources.distributed else ""
try:
dim_size = arg.shape[dim]
except IndexError:
@ -1636,22 +1587,6 @@ def _get_axis_sizes(args_flat: Iterable[Any],
f"{in_axes.user_repr}, which implies that it has at least "
f"{max(in_axes.values()) + 1} dimensions, but the argument "
f"has rank {arg.ndim}")
if ips == _PositionalSemantics.LOCAL:
local_dim_size = dim_size
if local_dim_size % resources.nlocal != 0:
raise ValueError(f"One of xmap arguments has an in_axes specification of "
f"{in_axes.user_repr}, which implies that its size in dimension "
f"{dim} ({local_dim_size}) should be divisible by the number of "
f"{local_}resources assigned to axis {name} ({resources.nlocal})")
global_dim_size = resources.to_global(ips, local_dim_size)
if name in global_axis_sizes:
expected_local_dim_size = resources.to_local(ips, global_axis_sizes[name])
if local_dim_size != expected_local_dim_size:
raise ValueError(f"The {local_}size of axis {name} was previously inferred to be "
f"{expected_local_dim_size}, but found an argument of shape {arg.shape} "
f"with in_axes specification {in_axes.user_repr}. Shape mismatch "
f"occurs in dimension {dim}: {local_dim_size} != {expected_local_dim_size}")
elif ips == _PositionalSemantics.GLOBAL:
global_dim_size = dim_size
if name in global_axis_sizes:
expected_global_dim_size = global_axis_sizes[name]
@ -1792,7 +1727,7 @@ def _check_out_avals_vs_out_axes(out_avals: Sequence[core.AbstractValue],
def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
global_axis_sizes, in_axes_flat,
in_positional_semantics, args_flat):
args_flat):
@lru_cache()
def _check_sharding(in_sharding, xmap_sharding, ndim, arr_flavor):
if not pxla.are_op_shardings_equal(
@ -1805,8 +1740,7 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
f"xmap spec: {xmap_sharding.spec}")
mesh_in_axes = EvaluationPlan.from_axis_resources( # pytype: disable=wrong-arg-types # always-use-return-annotations
axis_resources, resource_env, global_axis_sizes,
in_positional_semantics).to_mesh_axes(in_axes_flat)
axis_resources, resource_env, global_axis_sizes).to_mesh_axes(in_axes_flat)
for arg, xmap_array_mapping in safe_zip(args_flat, mesh_in_axes):
if isinstance(arg, ArrayImpl):
if not isinstance(arg.sharding, NamedSharding):