mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Remove the device assignment check in _resolve_in_shardings since that's historical and not needed anymore
PiperOrigin-RevId: 674091716
This commit is contained in:
parent
dffac29e63
commit
3d1d5e94ab
@ -1340,7 +1340,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
|
||||
|
||||
When ``append = jnp.array([[3],[1]])``, it is appended to ``a`` along ``axis``
|
||||
before computing the difference.
|
||||
|
||||
|
||||
>>> jnp.diff(a, append=jnp.array([[3],[1]]))
|
||||
Array([[ 4, -3, 7, -6],
|
||||
[ 5, -1, -3, -3]], dtype=int32)
|
||||
|
@ -19,7 +19,6 @@ from collections.abc import Callable, Sequence, Iterable
|
||||
import dataclasses
|
||||
from functools import partial
|
||||
import inspect
|
||||
import itertools as it
|
||||
import logging
|
||||
import operator as op
|
||||
import weakref
|
||||
@ -1494,11 +1493,8 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
|
||||
return tuple(resolved_in_layouts)
|
||||
|
||||
|
||||
def _resolve_in_shardings(
|
||||
args, pjit_in_shardings: Sequence[PjitSharding],
|
||||
out_shardings: Sequence[PjitSharding],
|
||||
pjit_mesh: pxla.Mesh | None,
|
||||
check_device_assignment: bool = True) -> Sequence[PjitSharding]:
|
||||
def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
|
||||
) -> Sequence[PjitSharding]:
|
||||
# If True, means that device or backend is set by the user on pjit and it
|
||||
# has the same semantics as device_put i.e. doesn't matter which device the
|
||||
# arg is on, reshard it to the device mentioned. So don't do any of the
|
||||
@ -1521,18 +1517,6 @@ def _resolve_in_shardings(
|
||||
if getattr(a, '_committed', True):
|
||||
committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None))
|
||||
|
||||
# Check if the device_assignment across inputs, outputs and arguments is the
|
||||
# same.
|
||||
if check_device_assignment:
|
||||
pxla._get_and_check_device_assignment(
|
||||
it.chain(
|
||||
util.stable_unique(committed_arg_shardings),
|
||||
((i, pxla.MismatchType.IN_SHARDING, None)
|
||||
for i in util.stable_unique(pjit_in_shardings)),
|
||||
((o, pxla.MismatchType.OUT_SHARDING, None)
|
||||
for o in util.stable_unique(out_shardings))),
|
||||
(None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat)))
|
||||
|
||||
resolved_in_shardings = []
|
||||
for arg, pjit_in_s in zip(args, pjit_in_shardings):
|
||||
# arg sharding can be None in case of ShapeDtypeStruct. jax.Array does
|
||||
@ -1602,9 +1586,7 @@ def _resolve_and_lower(
|
||||
args, jaxpr, in_shardings, out_shardings, in_layouts,
|
||||
out_layouts, resource_env, donated_invars, name, keep_unused, inline,
|
||||
lowering_platforms, lowering_parameters, pgle_profiler):
|
||||
in_shardings = _resolve_in_shardings(
|
||||
args, in_shardings, out_shardings,
|
||||
resource_env.physical_mesh if resource_env is not None else None)
|
||||
in_shardings = _resolve_in_shardings(args, in_shardings)
|
||||
in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings,
|
||||
jaxpr.in_avals)
|
||||
lowered = _pjit_lower(
|
||||
|
Loading…
x
Reference in New Issue
Block a user