Remove the device assignment check in _resolve_in_shardings since that's historical and not needed anymore

PiperOrigin-RevId: 674091716
This commit is contained in:
Yash Katariya 2024-09-12 18:47:25 -07:00 committed by jax authors
parent dffac29e63
commit 3d1d5e94ab
2 changed files with 4 additions and 22 deletions

View File

@ -1340,7 +1340,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
When ``append = jnp.array([[3],[1]])``, it is appended to ``a`` along ``axis``
before computing the difference.
>>> jnp.diff(a, append=jnp.array([[3],[1]]))
Array([[ 4, -3, 7, -6],
[ 5, -1, -3, -3]], dtype=int32)

View File

@ -19,7 +19,6 @@ from collections.abc import Callable, Sequence, Iterable
import dataclasses
from functools import partial
import inspect
import itertools as it
import logging
import operator as op
import weakref
@ -1494,11 +1493,8 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
return tuple(resolved_in_layouts)
def _resolve_in_shardings(
args, pjit_in_shardings: Sequence[PjitSharding],
out_shardings: Sequence[PjitSharding],
pjit_mesh: pxla.Mesh | None,
check_device_assignment: bool = True) -> Sequence[PjitSharding]:
def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
) -> Sequence[PjitSharding]:
# If True, means that device or backend is set by the user on pjit and it
# has the same semantics as device_put i.e. doesn't matter which device the
# arg is on, reshard it to the device mentioned. So don't do any of the
@ -1521,18 +1517,6 @@ def _resolve_in_shardings(
if getattr(a, '_committed', True):
committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None))
# Check if the device_assignment across inputs, outputs and arguments is the
# same.
if check_device_assignment:
pxla._get_and_check_device_assignment(
it.chain(
util.stable_unique(committed_arg_shardings),
((i, pxla.MismatchType.IN_SHARDING, None)
for i in util.stable_unique(pjit_in_shardings)),
((o, pxla.MismatchType.OUT_SHARDING, None)
for o in util.stable_unique(out_shardings))),
(None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat)))
resolved_in_shardings = []
for arg, pjit_in_s in zip(args, pjit_in_shardings):
# arg sharding can be None in case of ShapeDtypeStruct. jax.Array does
@ -1602,9 +1586,7 @@ def _resolve_and_lower(
args, jaxpr, in_shardings, out_shardings, in_layouts,
out_layouts, resource_env, donated_invars, name, keep_unused, inline,
lowering_platforms, lowering_parameters, pgle_profiler):
in_shardings = _resolve_in_shardings(
args, in_shardings, out_shardings,
resource_env.physical_mesh if resource_env is not None else None)
in_shardings = _resolve_in_shardings(args, in_shardings)
in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings,
jaxpr.in_avals)
lowered = _pjit_lower(