mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Dedupe shardings before passing them to _get_and_check_device_assignment
In practice, the number of different shardings is usually much smaller than the number of inputs/outputs. PiperOrigin-RevId: 600558309
This commit is contained in:
parent
8226ff3880
commit
46f796b38d
@ -1985,10 +1985,11 @@ def lower_sharding_computation(
|
||||
# should be the same.
|
||||
jaxpr_sharding = list(dispatch.jaxpr_shardings(jaxpr))
|
||||
backend, device_assignment = _get_and_check_device_assignment(
|
||||
it.chain([(i, MismatchType.ARG_SHARDING, None) for i in in_shardings],
|
||||
[(o, MismatchType.OUT_SHARDING, None) for o in out_shardings],
|
||||
[(js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
|
||||
for js, source_info in jaxpr_sharding]),
|
||||
it.chain(
|
||||
((i, MismatchType.ARG_SHARDING, None) for i in util.stable_unique(in_shardings)),
|
||||
((o, MismatchType.OUT_SHARDING, None) for o in util.stable_unique(out_shardings)),
|
||||
((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
|
||||
for js, source_info in util.stable_unique(jaxpr_sharding))),
|
||||
devices_from_context)
|
||||
|
||||
transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))
|
||||
|
@ -27,18 +27,19 @@ import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import stages
|
||||
from jax._src import dispatch
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src import op_shardings
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import source_info_util
|
||||
from jax._src import tree_util
|
||||
from jax._src import stages
|
||||
from jax._src import traceback_util
|
||||
from jax._src import api
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.api_util import (
|
||||
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
|
||||
@ -1182,30 +1183,29 @@ def _resolve_in_shardings(
|
||||
|
||||
committed_arg_shardings = []
|
||||
for a in args:
|
||||
if hasattr(a, 'sharding'):
|
||||
arg_s = a.sharding
|
||||
# arg sharding can be None in case of ShapeDtypeStruct. jax.Array does
|
||||
# not allow None as the sharding.
|
||||
if arg_s is None:
|
||||
continue
|
||||
if not isinstance(arg_s, XLACompatibleSharding):
|
||||
raise ValueError(f'One of the argument to pjit got sharding {arg_s} '
|
||||
'which is not a subclass of XLACompatibleSharding.')
|
||||
# Don't consider PmapSharding inputs as committed. They will get resharded
|
||||
# unconditionally.
|
||||
if isinstance(arg_s, PmapSharding):
|
||||
continue
|
||||
if getattr(a, '_committed', True):
|
||||
committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None))
|
||||
arg_s = getattr(a, 'sharding', None)
|
||||
# arg sharding can be None in case of ShapeDtypeStruct. jax.Array does
|
||||
# not allow None as the sharding.
|
||||
if arg_s is None:
|
||||
continue
|
||||
if not isinstance(arg_s, XLACompatibleSharding):
|
||||
raise ValueError(f'One of the argument to pjit got sharding {arg_s} '
|
||||
'which is not a subclass of XLACompatibleSharding.')
|
||||
# Don't consider PmapSharding inputs as committed. They will get resharded
|
||||
# unconditionally.
|
||||
if isinstance(arg_s, PmapSharding):
|
||||
continue
|
||||
if getattr(a, '_committed', True):
|
||||
committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None))
|
||||
|
||||
# Check if the device_assignment across inputs, outputs and arguments is the
|
||||
# same.
|
||||
if check_device_assignment:
|
||||
pxla._get_and_check_device_assignment(
|
||||
it.chain(
|
||||
committed_arg_shardings,
|
||||
[(i, pxla.MismatchType.IN_SHARDING, None) for i in pjit_in_shardings],
|
||||
[(o, pxla.MismatchType.OUT_SHARDING, None) for o in out_shardings]),
|
||||
util.stable_unique(committed_arg_shardings),
|
||||
((i, pxla.MismatchType.IN_SHARDING, None) for i in util.stable_unique(pjit_in_shardings)),
|
||||
((o, pxla.MismatchType.OUT_SHARDING, None) for o in util.stable_unique(out_shardings))),
|
||||
(None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat)))
|
||||
|
||||
resolved_in_shardings = []
|
||||
|
@ -510,6 +510,14 @@ def distributed_debug_log(*pairs):
|
||||
logger.warning("\n".join(lines))
|
||||
|
||||
|
||||
def stable_unique(it: Iterable[T]) -> Iterable[T]:
  """Deduplicates `it`, keeping each element at its first occurrence.

  Elements are stored as dict keys, so every element must be hashable. The
  returned object is the keys view of that dict.
  """
  # dict.fromkeys keeps only the first insertion of each key, which yields
  # an order-preserving dedupe.
  first_seen = dict.fromkeys(it)
  return first_seen.keys()
|
||||
|
||||
|
||||
class OrderedSet(Generic[T]):
|
||||
elts_set: set[T]
|
||||
elts_list: list[T]
|
||||
|
Loading…
x
Reference in New Issue
Block a user