mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Point to the exact primitive name nested under jit/pjit instead of mentioning all possible ones.
PiperOrigin-RevId: 508770290
This commit is contained in:
parent
568a93bcd1
commit
0d07372995
@ -22,7 +22,7 @@ import itertools
|
||||
import time
|
||||
from typing import (
|
||||
Any, Callable, Dict, Iterable, Iterator, Optional, Protocol,
|
||||
Sequence, Set, Tuple, List, Type, Union)
|
||||
Sequence, Set, Tuple, List, Type, Union, NamedTuple)
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@ -555,18 +555,28 @@ def jaxpr_has_primitive(jaxpr, prim_name: str):
|
||||
return False
|
||||
|
||||
|
||||
def jaxpr_shardings(jaxpr) -> Iterator[Tuple[jax.sharding.XLACompatibleSharding, str]]:
|
||||
class SourceInfo(NamedTuple):
|
||||
source_info: str
|
||||
eqn_name: str
|
||||
|
||||
|
||||
def jaxpr_shardings(
|
||||
jaxpr) -> Iterator[Tuple[jax.sharding.XLACompatibleSharding, SourceInfo]]:
|
||||
from jax.experimental import pjit, shard_map
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
if eqn.primitive is pjit.sharding_constraint_p:
|
||||
yield (eqn.params['sharding'], source_info_util.summarize(eqn.source_info))
|
||||
source_info = SourceInfo(source_info_util.summarize(eqn.source_info),
|
||||
eqn.primitive.name)
|
||||
yield (eqn.params['sharding'], source_info)
|
||||
elif eqn.primitive is pjit.pjit_p:
|
||||
source_info = source_info_util.summarize(eqn.source_info)
|
||||
source_info = SourceInfo(source_info_util.summarize(eqn.source_info),
|
||||
eqn.primitive.name)
|
||||
yield from ((i, source_info) for i in eqn.params['in_shardings'])
|
||||
yield from ((o, source_info) for o in eqn.params['out_shardings'])
|
||||
elif eqn.primitive is shard_map.shard_map_p:
|
||||
source_info = source_info_util.summarize(eqn.source_info)
|
||||
source_info = SourceInfo(source_info_util.summarize(eqn.source_info),
|
||||
eqn.primitive.name)
|
||||
def _names_to_pspec(names):
|
||||
ndmin = max(names) + 1 if names else 0
|
||||
return PartitionSpec(*(names.get(i) for i in range(ndmin)))
|
||||
|
@ -2727,8 +2727,6 @@ class MismatchType(enum.Enum):
|
||||
return 'explicit input sharding'
|
||||
elif self.name == 'OUT_SHARDING':
|
||||
return 'explicit output sharding'
|
||||
elif self.name == 'SHARDING_INSIDE_COMPUTATION':
|
||||
return 'with_sharding_constraint or nested pjit or shard_map'
|
||||
elif self.name == 'CONTEXT_DEVICES':
|
||||
return 'devices'
|
||||
return f'{self.name}'
|
||||
@ -2738,7 +2736,7 @@ class MismatchType(enum.Enum):
|
||||
class DeviceAssignmentMismatch:
|
||||
da: Sequence[xc.Device]
|
||||
m_type: MismatchType
|
||||
source_info: Optional[str]
|
||||
source_info: Optional[dispatch.SourceInfo]
|
||||
|
||||
@property
|
||||
def device_ids(self) -> Sequence[int]:
|
||||
@ -2753,14 +2751,18 @@ class DeviceAssignmentMismatch:
|
||||
|
||||
@property
|
||||
def source_info_str(self):
|
||||
return "" if self.source_info is None else f" at {self.source_info}"
|
||||
return "" if self.source_info is None else f" at {self.source_info.source_info}"
|
||||
|
||||
@property
|
||||
def _dev_ids_plat_str(self):
|
||||
return f"device ids {self.device_ids} on platform {self.platform}"
|
||||
|
||||
def m_type_str(self, api_name):
|
||||
return (f'{self.source_info.eqn_name} inside {api_name}'
|
||||
if self.m_type == MismatchType.SHARDING_INSIDE_COMPUTATION else self.m_type)
|
||||
|
||||
def _str(self, api_name):
|
||||
return (f"{self._maybe_api_name(api_name)} {self.m_type} with "
|
||||
return (f"{self._maybe_api_name(api_name)} {self.m_type_str(api_name)} with "
|
||||
f"{self._dev_ids_plat_str}{self.source_info_str}")
|
||||
|
||||
|
||||
@ -2768,9 +2770,10 @@ class DeviceAssignmentMismatchError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
ShardingInfo = Tuple[Union[sharding_internal.XLACompatibleSharding,
|
||||
UnspecifiedValue, AUTOAxisResource],
|
||||
MismatchType, Optional[str]]
|
||||
ShardingInfo = Tuple[
|
||||
Union[sharding_internal.XLACompatibleSharding, UnspecifiedValue,
|
||||
AUTOAxisResource],
|
||||
MismatchType, Optional[Any]] # Any is dispatch.SourceInfo to avoid circular imports
|
||||
|
||||
def _get_and_check_device_assignment(
|
||||
shardings: Iterable[ShardingInfo],
|
||||
|
@ -128,7 +128,7 @@ def _find_arg_mismatch(arg_list, fails, fun_name):
|
||||
if first_err.m_type == pxla.MismatchType.ARG_SHARDING:
|
||||
if first_err.da == inp_da:
|
||||
mismatched_args_msg.append(
|
||||
(f"argument {name} of {fun_name} with {aval.str_short()} and "
|
||||
(f"argument {name} of {fun_name} with shape {aval.str_short()} and "
|
||||
f"{first_err._dev_ids_plat_str}"))
|
||||
break
|
||||
|
||||
@ -136,7 +136,7 @@ def _find_arg_mismatch(arg_list, fails, fun_name):
|
||||
if second_err.m_type == pxla.MismatchType.ARG_SHARDING:
|
||||
if second_err.da == inp_da:
|
||||
mismatched_args_msg.append(
|
||||
(f"argument {name} of {fun_name} with {aval.str_short()} and "
|
||||
(f"argument {name} of {fun_name} with shape {aval.str_short()} and "
|
||||
f"{second_err._dev_ids_plat_str}"))
|
||||
break
|
||||
return mismatched_args_msg
|
||||
|
@ -2354,8 +2354,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Received incompatible devices for pjitted computation. Got argument "
|
||||
r"x of.*\<lambda\> with int.*\[3\] and device ids \[0\].*and argument "
|
||||
r"y of.*\<lambda\> with int.*\[3\] and device ids \[1\].*"):
|
||||
r"x of.*\<lambda\> with shape int.*\[3\] and device ids \[0\].*and "
|
||||
r"argument y of.*\<lambda\> with shape int.*\[3\] and device ids \[1\].*"):
|
||||
pjit(lambda x, y: (x, y))(a, b)
|
||||
|
||||
def test_pjit_pytree_inp_device_assignment_mismatch(self):
|
||||
@ -2366,9 +2366,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
NamedSharding(mesh, P('x', 'y')))
|
||||
|
||||
msg = ("Received incompatible devices for pjitted computation. Got "
|
||||
r"argument {} of.*<lambda> with int.*\[3\] and device ids \[0\].*and "
|
||||
r"argument {} of.*<lambda> with int.*\[8,2\] and device ids "
|
||||
r"\[0, 1, 2, 3\].*")
|
||||
r"argument {} of.*<lambda> with shape int.*\[3\] and device ids "
|
||||
r"\[0\].*and argument {} of.*<lambda> with shape int.*\[8,2\] and "
|
||||
r"device ids \[0, 1, 2, 3\].*")
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, msg.format(r'tuple_inp\[0\]', r'tuple_inp\[1\]\[0\]')):
|
||||
@ -2513,8 +2513,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Received incompatible devices for jitted computation. Got argument "
|
||||
r"inp of.*sharded_inp with bfloat16\[8,2\] and device ids \[0\].*"
|
||||
r"with_sharding_constraint.*with device ids \[0, 1, 2, 3\].*"):
|
||||
r"inp of.*sharded_inp with shape bfloat16\[8,2\] and device ids \[0\].*"
|
||||
r"sharding_constraint inside jit with device ids \[0, 1, 2, 3\].*"):
|
||||
sharded_inp(committed_inp)
|
||||
|
||||
@pjit
|
||||
@ -2527,8 +2527,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Received incompatible devices for pjitted computation. Got argument "
|
||||
r"inp1 of.*my_nested_pjit with bfloat16\[8,2\] and device ids \[0\].*"
|
||||
r"nested pjit.*with device ids \[0, 1, 2, 3\].*"):
|
||||
r"inp1 of.*my_nested_pjit with shape bfloat16\[8,2\] and device ids \[0\].*"
|
||||
r"pjit inside pjit with device ids \[0, 1, 2, 3\].*"):
|
||||
my_nested_pjit(committed_inp, committed_inp, committed_inp)
|
||||
|
||||
@jax_array(True)
|
||||
@ -2546,8 +2546,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Received incompatible devices for jitted computation. Got explicit "
|
||||
r"output sharding with device ids \[0\].*with_sharding_constraint.*with "
|
||||
r"device ids \[0, 1, 2, 3\].*"):
|
||||
r"output sharding with device ids \[0\].*sharding_constraint inside "
|
||||
r"jit with device ids \[0, 1, 2, 3\].*"):
|
||||
sharded_zeros((4096, 3072), P('x', 'y'))
|
||||
|
||||
@jax_array(True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user