Point to the exact primitive name nested under jit/pjit instead of mentioning all possible ones.

PiperOrigin-RevId: 508770290
This commit is contained in:
Yash Katariya 2023-02-10 15:36:04 -08:00 committed by jax authors
parent 568a93bcd1
commit 0d07372995
4 changed files with 39 additions and 26 deletions

View File

@ -22,7 +22,7 @@ import itertools
import time
from typing import (
Any, Callable, Dict, Iterable, Iterator, Optional, Protocol,
Sequence, Set, Tuple, List, Type, Union)
Sequence, Set, Tuple, List, Type, Union, NamedTuple)
import logging
import os
import re
@ -555,18 +555,28 @@ def jaxpr_has_primitive(jaxpr, prim_name: str):
return False
def jaxpr_shardings(jaxpr) -> Iterator[Tuple[jax.sharding.XLACompatibleSharding, str]]:
class SourceInfo(NamedTuple):
source_info: str
eqn_name: str
def jaxpr_shardings(
jaxpr) -> Iterator[Tuple[jax.sharding.XLACompatibleSharding, SourceInfo]]:
from jax.experimental import pjit, shard_map
for eqn in jaxpr.eqns:
if eqn.primitive is pjit.sharding_constraint_p:
yield (eqn.params['sharding'], source_info_util.summarize(eqn.source_info))
source_info = SourceInfo(source_info_util.summarize(eqn.source_info),
eqn.primitive.name)
yield (eqn.params['sharding'], source_info)
elif eqn.primitive is pjit.pjit_p:
source_info = source_info_util.summarize(eqn.source_info)
source_info = SourceInfo(source_info_util.summarize(eqn.source_info),
eqn.primitive.name)
yield from ((i, source_info) for i in eqn.params['in_shardings'])
yield from ((o, source_info) for o in eqn.params['out_shardings'])
elif eqn.primitive is shard_map.shard_map_p:
source_info = source_info_util.summarize(eqn.source_info)
source_info = SourceInfo(source_info_util.summarize(eqn.source_info),
eqn.primitive.name)
def _names_to_pspec(names):
ndmin = max(names) + 1 if names else 0
return PartitionSpec(*(names.get(i) for i in range(ndmin)))

View File

@ -2727,8 +2727,6 @@ class MismatchType(enum.Enum):
return 'explicit input sharding'
elif self.name == 'OUT_SHARDING':
return 'explicit output sharding'
elif self.name == 'SHARDING_INSIDE_COMPUTATION':
return 'with_sharding_constraint or nested pjit or shard_map'
elif self.name == 'CONTEXT_DEVICES':
return 'devices'
return f'{self.name}'
@ -2738,7 +2736,7 @@ class MismatchType(enum.Enum):
class DeviceAssignmentMismatch:
da: Sequence[xc.Device]
m_type: MismatchType
source_info: Optional[str]
source_info: Optional[dispatch.SourceInfo]
@property
def device_ids(self) -> Sequence[int]:
@ -2753,14 +2751,18 @@ class DeviceAssignmentMismatch:
@property
def source_info_str(self):
return "" if self.source_info is None else f" at {self.source_info}"
return "" if self.source_info is None else f" at {self.source_info.source_info}"
@property
def _dev_ids_plat_str(self):
return f"device ids {self.device_ids} on platform {self.platform}"
def m_type_str(self, api_name):
return (f'{self.source_info.eqn_name} inside {api_name}'
if self.m_type == MismatchType.SHARDING_INSIDE_COMPUTATION else self.m_type)
def _str(self, api_name):
return (f"{self._maybe_api_name(api_name)} {self.m_type} with "
return (f"{self._maybe_api_name(api_name)} {self.m_type_str(api_name)} with "
f"{self._dev_ids_plat_str}{self.source_info_str}")
@ -2768,9 +2770,10 @@ class DeviceAssignmentMismatchError(Exception):
pass
ShardingInfo = Tuple[Union[sharding_internal.XLACompatibleSharding,
UnspecifiedValue, AUTOAxisResource],
MismatchType, Optional[str]]
ShardingInfo = Tuple[
Union[sharding_internal.XLACompatibleSharding, UnspecifiedValue,
AUTOAxisResource],
MismatchType, Optional[Any]] # Any is dispatch.SourceInfo to avoid circular imports
def _get_and_check_device_assignment(
shardings: Iterable[ShardingInfo],

View File

@ -128,7 +128,7 @@ def _find_arg_mismatch(arg_list, fails, fun_name):
if first_err.m_type == pxla.MismatchType.ARG_SHARDING:
if first_err.da == inp_da:
mismatched_args_msg.append(
(f"argument {name} of {fun_name} with {aval.str_short()} and "
(f"argument {name} of {fun_name} with shape {aval.str_short()} and "
f"{first_err._dev_ids_plat_str}"))
break
@ -136,7 +136,7 @@ def _find_arg_mismatch(arg_list, fails, fun_name):
if second_err.m_type == pxla.MismatchType.ARG_SHARDING:
if second_err.da == inp_da:
mismatched_args_msg.append(
(f"argument {name} of {fun_name} with {aval.str_short()} and "
(f"argument {name} of {fun_name} with shape {aval.str_short()} and "
f"{second_err._dev_ids_plat_str}"))
break
return mismatched_args_msg

View File

@ -2354,8 +2354,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
"Received incompatible devices for pjitted computation. Got argument "
r"x of.*\<lambda\> with int.*\[3\] and device ids \[0\].*and argument "
r"y of.*\<lambda\> with int.*\[3\] and device ids \[1\].*"):
r"x of.*\<lambda\> with shape int.*\[3\] and device ids \[0\].*and "
r"argument y of.*\<lambda\> with shape int.*\[3\] and device ids \[1\].*"):
pjit(lambda x, y: (x, y))(a, b)
def test_pjit_pytree_inp_device_assignment_mismatch(self):
@ -2366,9 +2366,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
NamedSharding(mesh, P('x', 'y')))
msg = ("Received incompatible devices for pjitted computation. Got "
r"argument {} of.*<lambda> with int.*\[3\] and device ids \[0\].*and "
r"argument {} of.*<lambda> with int.*\[8,2\] and device ids "
r"\[0, 1, 2, 3\].*")
r"argument {} of.*<lambda> with shape int.*\[3\] and device ids "
r"\[0\].*and argument {} of.*<lambda> with shape int.*\[8,2\] and "
r"device ids \[0, 1, 2, 3\].*")
with self.assertRaisesRegex(
ValueError, msg.format(r'tuple_inp\[0\]', r'tuple_inp\[1\]\[0\]')):
@ -2513,8 +2513,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
"Received incompatible devices for jitted computation. Got argument "
r"inp of.*sharded_inp with bfloat16\[8,2\] and device ids \[0\].*"
r"with_sharding_constraint.*with device ids \[0, 1, 2, 3\].*"):
r"inp of.*sharded_inp with shape bfloat16\[8,2\] and device ids \[0\].*"
r"sharding_constraint inside jit with device ids \[0, 1, 2, 3\].*"):
sharded_inp(committed_inp)
@pjit
@ -2527,8 +2527,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
"Received incompatible devices for pjitted computation. Got argument "
r"inp1 of.*my_nested_pjit with bfloat16\[8,2\] and device ids \[0\].*"
r"nested pjit.*with device ids \[0, 1, 2, 3\].*"):
r"inp1 of.*my_nested_pjit with shape bfloat16\[8,2\] and device ids \[0\].*"
r"pjit inside pjit with device ids \[0, 1, 2, 3\].*"):
my_nested_pjit(committed_inp, committed_inp, committed_inp)
@jax_array(True)
@ -2546,8 +2546,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
"Received incompatible devices for jitted computation. Got explicit "
r"output sharding with device ids \[0\].*with_sharding_constraint.*with "
r"device ids \[0, 1, 2, 3\].*"):
r"output sharding with device ids \[0\].*sharding_constraint inside "
r"jit with device ids \[0, 1, 2, 3\].*"):
sharded_zeros((4096, 3072), P('x', 'y'))
@jax_array(True)