Dedupe shardings before passing them to _get_and_check_device_assignment

In practice, the number of different shardings is usually much smaller than
the number of inputs/outputs.

PiperOrigin-RevId: 600558309
This commit is contained in:
Sergei Lebedev 2024-01-22 13:44:34 -08:00 committed by jax authors
parent 8226ff3880
commit 46f796b38d
3 changed files with 35 additions and 26 deletions

View File

@ -1985,10 +1985,11 @@ def lower_sharding_computation(
# should be the same.
jaxpr_sharding = list(dispatch.jaxpr_shardings(jaxpr))
backend, device_assignment = _get_and_check_device_assignment(
it.chain([(i, MismatchType.ARG_SHARDING, None) for i in in_shardings],
[(o, MismatchType.OUT_SHARDING, None) for o in out_shardings],
[(js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
for js, source_info in jaxpr_sharding]),
it.chain(
((i, MismatchType.ARG_SHARDING, None) for i in util.stable_unique(in_shardings)),
((o, MismatchType.OUT_SHARDING, None) for o in util.stable_unique(out_shardings)),
((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
for js, source_info in util.stable_unique(jaxpr_sharding))),
devices_from_context)
transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))

View File

@ -27,18 +27,19 @@ import warnings
import numpy as np
from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import stages
from jax._src import dispatch
from jax._src import mesh as mesh_lib
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import op_shardings
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import tree_util
from jax._src import stages
from jax._src import traceback_util
from jax._src import api
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
@ -1182,30 +1183,29 @@ def _resolve_in_shardings(
committed_arg_shardings = []
for a in args:
if hasattr(a, 'sharding'):
arg_s = a.sharding
# arg sharding can be None in case of ShapeDtypeStruct. jax.Array does
# not allow None as the sharding.
if arg_s is None:
continue
if not isinstance(arg_s, XLACompatibleSharding):
raise ValueError(f'One of the argument to pjit got sharding {arg_s} '
'which is not a subclass of XLACompatibleSharding.')
# Don't consider PmapSharding inputs as committed. They will get resharded
# unconditionally.
if isinstance(arg_s, PmapSharding):
continue
if getattr(a, '_committed', True):
committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None))
arg_s = getattr(a, 'sharding', None)
# arg sharding can be None in case of ShapeDtypeStruct. jax.Array does
# not allow None as the sharding.
if arg_s is None:
continue
if not isinstance(arg_s, XLACompatibleSharding):
raise ValueError(f'One of the argument to pjit got sharding {arg_s} '
'which is not a subclass of XLACompatibleSharding.')
# Don't consider PmapSharding inputs as committed. They will get resharded
# unconditionally.
if isinstance(arg_s, PmapSharding):
continue
if getattr(a, '_committed', True):
committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None))
# Check if the device_assignment across inputs, outputs and arguments is the
# same.
if check_device_assignment:
pxla._get_and_check_device_assignment(
it.chain(
committed_arg_shardings,
[(i, pxla.MismatchType.IN_SHARDING, None) for i in pjit_in_shardings],
[(o, pxla.MismatchType.OUT_SHARDING, None) for o in out_shardings]),
util.stable_unique(committed_arg_shardings),
((i, pxla.MismatchType.IN_SHARDING, None) for i in util.stable_unique(pjit_in_shardings)),
((o, pxla.MismatchType.OUT_SHARDING, None) for o in util.stable_unique(out_shardings))),
(None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat)))
resolved_in_shardings = []

View File

@ -510,6 +510,14 @@ def distributed_debug_log(*pairs):
logger.warning("\n".join(lines))
def stable_unique(it: Iterable[T]) -> Iterable[T]:
  """Deduplicates `it` while preserving first-occurrence order.

  Elements are used as dict keys, so each one has to be hashable. The
  result is the keys view of that dict, not a list.
  """
  # A dict remembers insertion order and ignores repeated keys, which gives
  # an order-preserving dedupe in a single pass.
  first_seen = dict.fromkeys(it)
  return first_seen.keys()
class OrderedSet(Generic[T]):
elts_set: set[T]
elts_list: list[T]