Move DeviceAssignmentMismatchError exception catching code to def lower method of Traced so that all libraries calling traced.lower() see a better error message

PiperOrigin-RevId: 674095608
This commit is contained in:
Yash Katariya 2024-09-12 19:02:57 -07:00 committed by jax authors
parent 3d1d5e94ab
commit 634fbb5bec
2 changed files with 13 additions and 12 deletions

View File

@ -474,16 +474,7 @@ def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
@api_boundary
def lower(*args, **kwargs):
traced = trace(*args, **kwargs)
try:
return traced.lower()
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
fun_name = getattr(fun, '__qualname__',
getattr(fun, '__name__', str(fun)))
msg = _device_assignment_mismatch_error(
fun_name, fails, traced._args_flat, 'jit', traced._arg_names)
raise ValueError(msg) from None
return trace(*args, **kwargs).lower()
@api_boundary
def eval_shape(*args, **kwargs):
@ -503,7 +494,7 @@ def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
lower_callable = partial(_resolve_and_lower, args_flat, **p.params,
pgle_profiler=None)
return stages.Traced(
p.params['jaxpr'], args_info, p.params["name"],p.out_tree,
p.params['jaxpr'], args_info, p.params["name"], p.out_tree,
lower_callable, args_flat, p.arg_names, p.num_consts)
wrapped = _cpp_pjit(fun, jit_info)

View File

@ -734,12 +734,22 @@ class Traced(Stage):
def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
_private_parameters: mlir.LoweringParameters | None = None):
from jax._src.interpreters import pxla
from jax._src import pjit
if _private_parameters is None:
_private_parameters = mlir.LoweringParameters()
new_callable = functools.partial(
self._lower_callable, lowering_platforms=lowering_platforms,
lowering_parameters=_private_parameters)
return Lowered(new_callable(), self.args_info, self._out_tree)
try:
lowering = new_callable()
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
msg = pjit._device_assignment_mismatch_error(
self.fun_name, fails, self._args_flat, 'jit', self._arg_names)
raise ValueError(msg) from None
return Lowered(lowering, self.args_info, self._out_tree)
@runtime_checkable