mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Move DeviceAssignmentMismatchError
exception catching code to def lower
method of Traced
so that all libraries calling traced.lower()
see a better error message
PiperOrigin-RevId: 674095608
This commit is contained in:
parent
3d1d5e94ab
commit
634fbb5bec
@ -474,16 +474,7 @@ def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
|
||||
|
||||
@api_boundary
|
||||
def lower(*args, **kwargs):
|
||||
traced = trace(*args, **kwargs)
|
||||
try:
|
||||
return traced.lower()
|
||||
except pxla.DeviceAssignmentMismatchError as e:
|
||||
fails, = e.args
|
||||
fun_name = getattr(fun, '__qualname__',
|
||||
getattr(fun, '__name__', str(fun)))
|
||||
msg = _device_assignment_mismatch_error(
|
||||
fun_name, fails, traced._args_flat, 'jit', traced._arg_names)
|
||||
raise ValueError(msg) from None
|
||||
return trace(*args, **kwargs).lower()
|
||||
|
||||
@api_boundary
|
||||
def eval_shape(*args, **kwargs):
|
||||
@ -503,7 +494,7 @@ def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
|
||||
lower_callable = partial(_resolve_and_lower, args_flat, **p.params,
|
||||
pgle_profiler=None)
|
||||
return stages.Traced(
|
||||
p.params['jaxpr'], args_info, p.params["name"],p.out_tree,
|
||||
p.params['jaxpr'], args_info, p.params["name"], p.out_tree,
|
||||
lower_callable, args_flat, p.arg_names, p.num_consts)
|
||||
|
||||
wrapped = _cpp_pjit(fun, jit_info)
|
||||
|
@ -734,12 +734,22 @@ class Traced(Stage):
|
||||
|
||||
def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
|
||||
_private_parameters: mlir.LoweringParameters | None = None):
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src import pjit
|
||||
|
||||
if _private_parameters is None:
|
||||
_private_parameters = mlir.LoweringParameters()
|
||||
new_callable = functools.partial(
|
||||
self._lower_callable, lowering_platforms=lowering_platforms,
|
||||
lowering_parameters=_private_parameters)
|
||||
return Lowered(new_callable(), self.args_info, self._out_tree)
|
||||
try:
|
||||
lowering = new_callable()
|
||||
except pxla.DeviceAssignmentMismatchError as e:
|
||||
fails, = e.args
|
||||
msg = pjit._device_assignment_mismatch_error(
|
||||
self.fun_name, fails, self._args_flat, 'jit', self._arg_names)
|
||||
raise ValueError(msg) from None
|
||||
return Lowered(lowering, self.args_info, self._out_tree)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
Loading…
x
Reference in New Issue
Block a user