From 7c7c60eabf2bcbed285c7e1032fd9ac3686f16e2 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 17 Mar 2023 12:23:37 -0700 Subject: [PATCH] Remove in_positional_semantics and out_positional_semantics from xmap PiperOrigin-RevId: 517477866 --- jax/_src/maps.py | 156 ++++++++++++++--------------------------------- 1 file changed, 45 insertions(+), 111 deletions(-) diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 93dd396c8..19e4307bf 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -520,11 +520,8 @@ def xmap(fun: Callable, lambda: tuple(_flatten_axes("xmap out_axes", out_tree(), out_axes, tupled_args=False)), closure=(out_axes_entries, out_axes_treedef)) - in_positional_semantics = (_PositionalSemantics.GLOBAL,) * len(args_flat) - out_positional_semantics = _PositionalSemantics.GLOBAL - axis_resource_count = _get_axis_resource_count( - frozen_axis_resources, resource_env, in_positional_semantics) + frozen_axis_resources, resource_env) for axis, size in axis_sizes.items(): resources = axis_resource_count[axis] @@ -534,8 +531,7 @@ def xmap(fun: Callable, f"by the total number of resources assigned to this axis " f"({frozen_axis_resources[axis]}, {resources.nglobal} in total)") frozen_global_axis_sizes = _get_axis_sizes( - args_flat, in_axes_flat, axis_sizes, axis_resource_count, - in_positional_semantics) + args_flat, in_axes_flat, axis_sizes, axis_resource_count) missing_sizes = defined_names - set(frozen_global_axis_sizes.keys()) if missing_sizes: @@ -551,9 +547,9 @@ def xmap(fun: Callable, f"which asserts that it should be of rank {spec.expected_rank}, " f"but the argument has rank {arg.ndim} (and shape {arg.shape})") - _check_gda_or_array_xmap_partitioning(frozen_axis_resources, resource_env, - frozen_global_axis_sizes, in_axes_flat, - in_positional_semantics, args_flat) + _check_gda_or_array_xmap_partitioning( + frozen_axis_resources, resource_env, frozen_global_axis_sizes, + in_axes_flat, args_flat) params = dict( name=getattr(fun, '__name__', ''), @@ -565,9 +561,7 @@ def xmap(fun: Callable, resource_env=resource_env, backend=backend, spmd_in_axes=None, - spmd_out_axes_thunk=None, - in_positional_semantics=in_positional_semantics, - out_positional_semantics=out_positional_semantics) + spmd_out_axes_thunk=None) return fun_flat, args_flat, params, in_tree, out_tree def verify_outputs(out_flat, out_tree, params): @@ -600,8 +594,7 @@ def xmap(fun: Callable, fun_flat, params['name'], params['in_axes'], params['out_axes_thunk'], params['donated_invars'], params['global_axis_sizes'], params['axis_resources'], params['resource_env'], params['backend'], params['spmd_in_axes'], - params['spmd_out_axes_thunk'], params['in_positional_semantics'], - params['out_positional_semantics'], + params['spmd_out_axes_thunk'], _experimental_lowering_platform, *avals_flat) in_tree = treedef_tuple([in_tree, tree_flatten({})[1]]) @@ -615,13 +608,12 @@ def xmap(fun: Callable, def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes, axis_resources, resource_env, backend, - spmd_in_axes, spmd_out_axes_thunk, in_positional_semantics, - out_positional_semantics): + spmd_in_axes, spmd_out_axes_thunk): in_avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args] xmap_callable = make_xmap_callable( fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes, axis_resources, resource_env, backend, - spmd_in_axes, spmd_out_axes_thunk, in_positional_semantics, out_positional_semantics, + spmd_in_axes, spmd_out_axes_thunk, None, *in_avals).compile().unsafe_call distributed_debug_log(("Running xmapped function", name), ("python function", fun.f), @@ -634,12 +626,11 @@ def make_xmap_callable(fun: lu.WrappedFun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes, axis_resources, resource_env, backend, - spmd_in_axes, spmd_out_axes_thunk, in_positional_semantics, - out_positional_semantics, + spmd_in_axes, spmd_out_axes_thunk, lowering_platform: Optional[str], *in_avals): plan = EvaluationPlan.from_axis_resources( - axis_resources, resource_env, global_axis_sizes, in_positional_semantics) + axis_resources, resource_env, global_axis_sizes) # TODO: Making axis substitution final style would allow us to avoid # tracing to jaxpr here @@ -675,10 +666,6 @@ def make_xmap_callable(fun: lu.WrappedFun, assert spmd_in_axes is None and spmd_out_axes_thunk is None # No outer xmaps, so should be None mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes) mesh = resource_env.physical_mesh - global_in_avals = [ - av if ips == _PositionalSemantics.GLOBAL else pxla.mesh_local_to_global(mesh, ax, av) - for ax, av, ips in safe_zip(mesh_in_axes, in_avals, in_positional_semantics) - ] tiling_method: pxla.TilingMethod if config.experimental_xmap_spmd_lowering_manual: manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values())) @@ -692,7 +679,7 @@ def make_xmap_callable(fun: lu.WrappedFun, return pxla.lower_mesh_computation( f, 'xmap', name, mesh, in_shardings, out_shardings, donated_invars, - use_spmd_lowering, global_in_avals, + use_spmd_lowering, in_avals, tiling_method=tiling_method, lowering_platform=lowering_platform) else: @@ -727,12 +714,11 @@ class EvaluationPlan(NamedTuple): def from_axis_resources(cls, axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]], resource_env: ResourceEnv, - global_axis_sizes: Dict[AxisName, int], - in_positional_semantics: Sequence[bool]): + global_axis_sizes: Dict[AxisName, int]): physical_axis_resources, loop_axis_resources = _unzip_axis_resources( axis_resources, resource_env) axis_resource_count = _get_axis_resource_count( - axis_resources, resource_env, in_positional_semantics) + axis_resources, resource_env) axis_subst_dict = dict(axis_resources) axis_vmap_size: Dict[AxisName, Optional[int]] = {} for naxis, raxes in sorted(axis_resources.items(), key=lambda x: str(x[0])): @@ -893,10 +879,9 @@ def _xmap_transpose(params, call_jaxpr, args, cts_in, cts_in_avals, reduce_axes) arg_cts = tree_unflatten(out_tree(), out_flat) axis_resource_count = _get_axis_resource_count( - params['axis_resources'], params['resource_env'], - params['in_positional_semantics']) + params['axis_resources'], params['resource_env']) local_axis_sizes = { - axis: axis_resource_count[axis].to_local(params['out_positional_semantics'], global_size) + axis: axis_resource_count[axis].to_local(global_size) for axis, global_size in params['global_axis_sizes'].items() } def unmap_zero(zero, axes): @@ -909,13 +894,12 @@ ad.primitive_transposes[xmap_p] = _xmap_transpose def _typecheck_xmap( *in_atoms, call_jaxpr, name, in_axes, out_axes, donated_invars, global_axis_sizes, axis_resources, resource_env, backend, - spmd_in_axes, spmd_out_axes, in_positional_semantics, - out_positional_semantics): + spmd_in_axes, spmd_out_axes): in_avals = [x.aval for x in in_atoms] axis_resource_count = _get_axis_resource_count( - axis_resources, resource_env, in_positional_semantics) + axis_resources, resource_env) local_axis_sizes = { - axis: axis_resource_count[axis].to_local(out_positional_semantics, global_size) + axis: axis_resource_count[axis].to_local(global_size) for axis, global_size in global_axis_sizes.items() } binder_in_avals = [_insert_aval_axes(v.aval, a_in_axes, local_axis_sizes) @@ -1002,10 +986,9 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params): else: spmd_out_axes = None axis_resource_count = _get_axis_resource_count( - params['axis_resources'], params['resource_env'], - params['in_positional_semantics']) + params['axis_resources'], params['resource_env']) local_axis_sizes = { - axis: axis_resource_count[axis].to_local(params['out_positional_semantics'], global_size) + axis: axis_resource_count[axis].to_local(global_size) for axis, global_size in global_axis_sizes.items() } out_avals = [_insert_aval_axes(a, a_out_axes, local_axis_sizes) @@ -1028,9 +1011,6 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params): donated_invars=new_donated_invars, spmd_in_axes=new_spmd_in_axes, spmd_out_axes=spmd_out_axes, - in_positional_semantics=( - *( _PositionalSemantics.GLOBAL,) * len(constvars), - *params['in_positional_semantics']), call_jaxpr=call_jaxpr) del new_params['out_axes_thunk'] del new_params['spmd_out_axes_thunk'] @@ -1163,11 +1143,9 @@ def _jaxpr_trace_process_xmap(self, primitive, f: lu.WrappedFun, tracers, params unknown_arg_tracers = [t for t in tracers if not t.pval.is_known()] # Create output tracers for unknown part, adjusting avals. axis_resource_count = _get_axis_resource_count( - params['axis_resources'], params['resource_env'], - params['in_positional_semantics']) + params['axis_resources'], params['resource_env']) local_axis_sizes = { - ax: axis_resource_count[ax].to_local( - params['out_positional_semantics'], global_size) + ax: axis_resource_count[ax].to_local(global_size) for ax, global_size in global_axis_sizes.items()} out_pvals = [pe.PartialVal.unknown(_insert_aval_axes(a, ax, local_axis_sizes)) for a, ax in zip(out_avals, out_axes_unknown)] @@ -1310,17 +1288,16 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes, in_axes, out_axes, donated_invars, global_axis_sizes, spmd_in_axes, spmd_out_axes, - in_positional_semantics, out_positional_semantics, axis_resources, resource_env, backend): xla.check_backend_matches(backend, ctx.module_context.platform) # The only way for any of those two assertions to be violated is when xmap # is using the SPMD lowering, but then this rule shouldn't even trigger. assert spmd_in_axes is None and spmd_out_axes is None plan = EvaluationPlan.from_axis_resources( - axis_resources, resource_env, global_axis_sizes, in_positional_semantics) + axis_resources, resource_env, global_axis_sizes) axis_resource_count = _get_axis_resource_count( - axis_resources, resource_env, in_positional_semantics) + axis_resources, resource_env) if any(resource_count.distributed for resource_count in axis_resource_count.values()): raise NotImplementedError @@ -1387,12 +1364,11 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes, def _xmap_lowering_rule_spmd(ctx, *global_in_nodes, call_jaxpr, name, in_axes, out_axes, donated_invars, global_axis_sizes, spmd_in_axes, - spmd_out_axes, in_positional_semantics, - out_positional_semantics, axis_resources, + spmd_out_axes, axis_resources, resource_env, backend): xla.check_backend_matches(backend, ctx.module_context.platform) plan = EvaluationPlan.from_axis_resources( - axis_resources, resource_env, global_axis_sizes, in_positional_semantics) + axis_resources, resource_env, global_axis_sizes) resource_call_jaxpr = plan.subst_axes_with_resources(call_jaxpr) f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(resource_call_jaxpr, ()))) @@ -1450,14 +1426,13 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes, def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes, call_jaxpr, name, in_axes, out_axes, donated_invars, global_axis_sizes, spmd_in_axes, - spmd_out_axes, in_positional_semantics, - out_positional_semantics, axis_resources, + spmd_out_axes, axis_resources, resource_env, backend): assert spmd_in_axes is None and spmd_out_axes is None # This first part (up to vtile_manual) is shared with non-MANUAL SPMD rule. xla.check_backend_matches(backend, ctx.module_context.platform) plan = EvaluationPlan.from_axis_resources( - axis_resources, resource_env, global_axis_sizes, in_positional_semantics) + axis_resources, resource_env, global_axis_sizes) manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values())) resource_call_jaxpr = plan.subst_axes_with_resources(call_jaxpr) @@ -1574,35 +1549,14 @@ class ResourceCount(NamedTuple): nlocal: Optional[int] distributed: bool - def to_local(self, semantics, global_size): - if semantics == _PositionalSemantics.GLOBAL: - return global_size - elif semantics == _PositionalSemantics.LOCAL: - assert self.nlocal is not None - assert global_size % self.nglobal == 0, "Please report this issue!" - return (global_size // self.nglobal) * self.nlocal - else: - raise AssertionError(f"Unhandled case {_positional_semantics}") - - def to_global(self, semantics, local_size): - if semantics == _PositionalSemantics.GLOBAL: - return local_size - elif semantics == _PositionalSemantics.LOCAL: - assert self.nlocal is not None - assert local_size % self.nlocal == 0, "Please report this issue!" - return (local_size // self.nlocal) * self.nglobal - else: - raise AssertionError(f"Unhandled case {_positional_semantics}") + def to_local(self, global_size): + return global_size def _get_axis_resource_count( - axis_resources, resource_env, - in_positional_semantics) -> Dict[ResourceAxisName, ResourceCount]: + axis_resources, resource_env) -> Dict[ResourceAxisName, ResourceCount]: global_res_shape = resource_env.shape - if all(ips == _PositionalSemantics.GLOBAL for ips in in_positional_semantics): - local_res_shape = None - else: - local_res_shape = resource_env.local_shape + local_res_shape = None distributed = (False if resource_env.physical_mesh.empty else resource_env.physical_mesh.size != len(resource_env.physical_mesh.local_devices)) @@ -1621,13 +1575,10 @@ def _get_axis_resource_count( def _get_axis_sizes(args_flat: Iterable[Any], in_axes_flat: Iterable[AxisNamePos], global_axis_sizes: Dict[AxisName, int], - axis_resource_count: Dict[AxisName, ResourceCount], - in_positional_semantics: Sequence[_PositionalSemantics]): + axis_resource_count: Dict[AxisName, ResourceCount]): global_axis_sizes = dict(global_axis_sizes) - for arg, in_axes, ips in zip(args_flat, in_axes_flat, in_positional_semantics): + for arg, in_axes in zip(args_flat, in_axes_flat): for name, dim in in_axes.items(): - resources = axis_resource_count[name] - local_ = "local " if resources.distributed else "" try: dim_size = arg.shape[dim] except IndexError: @@ -1636,30 +1587,14 @@ def _get_axis_sizes(args_flat: Iterable[Any], f"{in_axes.user_repr}, which implies that it has at least " f"{max(in_axes.values()) + 1} dimensions, but the argument " f"has rank {arg.ndim}") - if ips == _PositionalSemantics.LOCAL: - local_dim_size = dim_size - if local_dim_size % resources.nlocal != 0: - raise ValueError(f"One of xmap arguments has an in_axes specification of " - f"{in_axes.user_repr}, which implies that its size in dimension " - f"{dim} ({local_dim_size}) should be divisible by the number of " - f"{local_}resources assigned to axis {name} ({resources.nlocal})") - global_dim_size = resources.to_global(ips, local_dim_size) - if name in global_axis_sizes: - expected_local_dim_size = resources.to_local(ips, global_axis_sizes[name]) - if local_dim_size != expected_local_dim_size: - raise ValueError(f"The {local_}size of axis {name} was previously inferred to be " - f"{expected_local_dim_size}, but found an argument of shape {arg.shape} " - f"with in_axes specification {in_axes.user_repr}. Shape mismatch " - f"occurs in dimension {dim}: {local_dim_size} != {expected_local_dim_size}") - elif ips == _PositionalSemantics.GLOBAL: - global_dim_size = dim_size - if name in global_axis_sizes: - expected_global_dim_size = global_axis_sizes[name] - if global_dim_size != expected_global_dim_size: - raise ValueError(f"The size of axis {name} was previously inferred to be " - f"{expected_global_dim_size}, but found an argument of shape {arg.shape} " - f"with in_axes specification {in_axes.user_repr}. Shape mismatch " - f"occurs in dimension {dim}: {global_dim_size} != {expected_global_dim_size}") + global_dim_size = dim_size + if name in global_axis_sizes: + expected_global_dim_size = global_axis_sizes[name] + if global_dim_size != expected_global_dim_size: + raise ValueError(f"The size of axis {name} was previously inferred to be " + f"{expected_global_dim_size}, but found an argument of shape {arg.shape} " + f"with in_axes specification {in_axes.user_repr}. Shape mismatch " + f"occurs in dimension {dim}: {global_dim_size} != {expected_global_dim_size}") global_axis_sizes[name] = global_dim_size return FrozenDict(global_axis_sizes) @@ -1792,7 +1727,7 @@ def _check_out_avals_vs_out_axes(out_avals: Sequence[core.AbstractValue], def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env, global_axis_sizes, in_axes_flat, - in_positional_semantics, args_flat): + args_flat): @lru_cache() def _check_sharding(in_sharding, xmap_sharding, ndim, arr_flavor): if not pxla.are_op_shardings_equal( @@ -1805,8 +1740,7 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env, f"xmap spec: {xmap_sharding.spec}") mesh_in_axes = EvaluationPlan.from_axis_resources( # pytype: disable=wrong-arg-types # always-use-return-annotations - axis_resources, resource_env, global_axis_sizes, - in_positional_semantics).to_mesh_axes(in_axes_flat) + axis_resources, resource_env, global_axis_sizes).to_mesh_axes(in_axes_flat) for arg, xmap_array_mapping in safe_zip(args_flat, mesh_in_axes): if isinstance(arg, ArrayImpl): if not isinstance(arg.sharding, NamedSharding):