From 0d07372995f124c5ba6d1d88e6c2544494a4dc9b Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 10 Feb 2023 15:36:04 -0800 Subject: [PATCH] Point to the exact primitive name nested under jit/pjit instead of mentioning all possible ones. PiperOrigin-RevId: 508770290 --- jax/_src/dispatch.py | 20 +++++++++++++++----- jax/_src/interpreters/pxla.py | 19 +++++++++++-------- jax/_src/pjit.py | 4 ++-- tests/pjit_test.py | 22 +++++++++++----------- 4 files changed, 39 insertions(+), 26 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 6bb2d5515..d0844f6ae 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -22,7 +22,7 @@ import itertools import time from typing import ( Any, Callable, Dict, Iterable, Iterator, Optional, Protocol, - Sequence, Set, Tuple, List, Type, Union) + Sequence, Set, Tuple, List, Type, Union, NamedTuple) import logging import os import re @@ -555,18 +555,28 @@ def jaxpr_has_primitive(jaxpr, prim_name: str): return False -def jaxpr_shardings(jaxpr) -> Iterator[Tuple[jax.sharding.XLACompatibleSharding, str]]: +class SourceInfo(NamedTuple): + source_info: str + eqn_name: str + + +def jaxpr_shardings( + jaxpr) -> Iterator[Tuple[jax.sharding.XLACompatibleSharding, SourceInfo]]: from jax.experimental import pjit, shard_map for eqn in jaxpr.eqns: if eqn.primitive is pjit.sharding_constraint_p: - yield (eqn.params['sharding'], source_info_util.summarize(eqn.source_info)) + source_info = SourceInfo(source_info_util.summarize(eqn.source_info), + eqn.primitive.name) + yield (eqn.params['sharding'], source_info) elif eqn.primitive is pjit.pjit_p: - source_info = source_info_util.summarize(eqn.source_info) + source_info = SourceInfo(source_info_util.summarize(eqn.source_info), + eqn.primitive.name) yield from ((i, source_info) for i in eqn.params['in_shardings']) yield from ((o, source_info) for o in eqn.params['out_shardings']) elif eqn.primitive is shard_map.shard_map_p: - source_info = source_info_util.summarize(eqn.source_info) + source_info = SourceInfo(source_info_util.summarize(eqn.source_info), + eqn.primitive.name) def _names_to_pspec(names): ndmin = max(names) + 1 if names else 0 return PartitionSpec(*(names.get(i) for i in range(ndmin))) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 7e9a1ca13..3a1bb0d17 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2727,8 +2727,6 @@ class MismatchType(enum.Enum): return 'explicit input sharding' elif self.name == 'OUT_SHARDING': return 'explicit output sharding' - elif self.name == 'SHARDING_INSIDE_COMPUTATION': - return 'with_sharding_constraint or nested pjit or shard_map' elif self.name == 'CONTEXT_DEVICES': return 'devices' return f'{self.name}' @@ -2738,7 +2736,7 @@ class MismatchType(enum.Enum): class DeviceAssignmentMismatch: da: Sequence[xc.Device] m_type: MismatchType - source_info: Optional[str] + source_info: Optional[dispatch.SourceInfo] @property def device_ids(self) -> Sequence[int]: @@ -2753,14 +2751,18 @@ class DeviceAssignmentMismatch: @property def source_info_str(self): - return "" if self.source_info is None else f" at {self.source_info}" + return "" if self.source_info is None else f" at {self.source_info.source_info}" @property def _dev_ids_plat_str(self): return f"device ids {self.device_ids} on platform {self.platform}" + def m_type_str(self, api_name): + return (f'{self.source_info.eqn_name} inside {api_name}' + if self.m_type == MismatchType.SHARDING_INSIDE_COMPUTATION else self.m_type) + def _str(self, api_name): - return (f"{self._maybe_api_name(api_name)} {self.m_type} with " + return (f"{self._maybe_api_name(api_name)} {self.m_type_str(api_name)} with " f"{self._dev_ids_plat_str}{self.source_info_str}") @@ -2768,9 +2770,10 @@ class DeviceAssignmentMismatchError(Exception): pass -ShardingInfo = Tuple[Union[sharding_internal.XLACompatibleSharding, - UnspecifiedValue, AUTOAxisResource], - MismatchType, Optional[str]] +ShardingInfo = Tuple[ + Union[sharding_internal.XLACompatibleSharding, UnspecifiedValue, + AUTOAxisResource], + MismatchType, Optional[Any]] # Any is dispatch.SourceInfo to avoid circular imports def _get_and_check_device_assignment( shardings: Iterable[ShardingInfo], diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 7cdfc99e5..92b9dffe7 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -128,7 +128,7 @@ def _find_arg_mismatch(arg_list, fails, fun_name): if first_err.m_type == pxla.MismatchType.ARG_SHARDING: if first_err.da == inp_da: mismatched_args_msg.append( - (f"argument {name} of {fun_name} with {aval.str_short()} and " + (f"argument {name} of {fun_name} with shape {aval.str_short()} and " f"{first_err._dev_ids_plat_str}")) break @@ -136,7 +136,7 @@ def _find_arg_mismatch(arg_list, fails, fun_name): if second_err.m_type == pxla.MismatchType.ARG_SHARDING: if second_err.da == inp_da: mismatched_args_msg.append( - (f"argument {name} of {fun_name} with {aval.str_short()} and " + (f"argument {name} of {fun_name} with shape {aval.str_short()} and " f"{second_err._dev_ids_plat_str}")) break return mismatched_args_msg diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 7eccb8636..732ce5208 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -2354,8 +2354,8 @@ class ArrayPjitTest(jtu.JaxTestCase): with self.assertRaisesRegex( ValueError, "Received incompatible devices for pjitted computation. Got argument " - r"x of.*\ with int.*\[3\] and device ids \[0\].*and argument " - r"y of.*\ with int.*\[3\] and device ids \[1\].*"): + r"x of.*\ with shape int.*\[3\] and device ids \[0\].*and " + r"argument y of.*\ with shape int.*\[3\] and device ids \[1\].*"): pjit(lambda x, y: (x, y))(a, b) def test_pjit_pytree_inp_device_assignment_mismatch(self): @@ -2366,9 +2366,9 @@ class ArrayPjitTest(jtu.JaxTestCase): NamedSharding(mesh, P('x', 'y'))) msg = ("Received incompatible devices for pjitted computation. Got " - r"argument {} of.* with int.*\[3\] and device ids \[0\].*and " - r"argument {} of.* with int.*\[8,2\] and device ids " - r"\[0, 1, 2, 3\].*") + r"argument {} of.* with shape int.*\[3\] and device ids " + r"\[0\].*and argument {} of.* with shape int.*\[8,2\] and " + r"device ids \[0, 1, 2, 3\].*") with self.assertRaisesRegex( ValueError, msg.format(r'tuple_inp\[0\]', r'tuple_inp\[1\]\[0\]')): @@ -2513,8 +2513,8 @@ class ArrayPjitTest(jtu.JaxTestCase): with self.assertRaisesRegex( ValueError, "Received incompatible devices for jitted computation. Got argument " - r"inp of.*sharded_inp with bfloat16\[8,2\] and device ids \[0\].*" - r"with_sharding_constraint.*with device ids \[0, 1, 2, 3\].*"): + r"inp of.*sharded_inp with shape bfloat16\[8,2\] and device ids \[0\].*" + r"sharding_constraint inside jit with device ids \[0, 1, 2, 3\].*"): sharded_inp(committed_inp) @pjit @@ -2527,8 +2527,8 @@ class ArrayPjitTest(jtu.JaxTestCase): with self.assertRaisesRegex( ValueError, "Received incompatible devices for pjitted computation. Got argument " - r"inp1 of.*my_nested_pjit with bfloat16\[8,2\] and device ids \[0\].*" - r"nested pjit.*with device ids \[0, 1, 2, 3\].*"): + r"inp1 of.*my_nested_pjit with shape bfloat16\[8,2\] and device ids \[0\].*" + r"pjit inside pjit with device ids \[0, 1, 2, 3\].*"): my_nested_pjit(committed_inp, committed_inp, committed_inp) @jax_array(True) @@ -2546,8 +2546,8 @@ class ArrayPjitTest(jtu.JaxTestCase): with self.assertRaisesRegex( ValueError, "Received incompatible devices for jitted computation. Got explicit " - r"output sharding with device ids \[0\].*with_sharding_constraint.*with " - r"device ids \[0, 1, 2, 3\].*"): + r"output sharding with device ids \[0\].*sharding_constraint inside " + r"jit with device ids \[0, 1, 2, 3\].*"): sharded_zeros((4096, 3072), P('x', 'y')) @jax_array(True)