mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Raise a better error message when an invalid input is passed to jit call.
Before: ``` TypeError: Argument 'ShapeDtypeStruct(shape=(4, 2), dtype=int32)' of type <class 'jax._src.api.ShapeDtypeStruct'> is not a valid JAX type. ``` After: ``` TypeError: Argument 'x['b']['c']' of shape int32[4,2] of type <class 'jax._src.api.ShapeDtypeStruct'> is not a valid JAX type. ``` The error is raised deep down the stack during `shard_arg`, so we raise an `InvalidInputException` and catch it in `_python_pjit_helper` where we have the `arg_names` information. PiperOrigin-RevId: 618014044
This commit is contained in:
parent
7f7e0c00df
commit
d57bb8c748
@ -110,8 +110,6 @@ def tuple_sharding_proto(elems):
|
||||
return proto
|
||||
|
||||
|
||||
|
||||
|
||||
### handlers
|
||||
|
||||
# JAX abstract values -> XLA shapes
|
||||
@ -132,6 +130,10 @@ _xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),)
|
||||
|
||||
# IR constants
|
||||
|
||||
class InvalidInputException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# TODO(mattjj): try to remove this canonicalize_dtype stuff
|
||||
def canonicalize_dtype(x):
|
||||
typ = type(x)
|
||||
@ -142,8 +144,8 @@ def canonicalize_dtype(x):
|
||||
if handler: return handler(x)
|
||||
if hasattr(x, '__jax_array__'):
|
||||
return canonicalize_dtype(x.__jax_array__())
|
||||
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid "
|
||||
"JAX type.")
|
||||
raise InvalidInputException(
|
||||
f"Argument '{x}' of type {type(x)} is not a valid JAX type.")
|
||||
|
||||
def _canonicalize_masked_array_dtype(x):
|
||||
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
|
||||
|
@ -167,9 +167,11 @@ def _python_pjit_helper(jit_info, *args, **kwargs):
|
||||
_infer_params(jit_info, args, kwargs)
|
||||
for arg in args_flat:
|
||||
dispatch.check_arg(arg)
|
||||
|
||||
if attrs_tracked:
|
||||
init_states = _get_states(attrs_tracked)
|
||||
args_flat = [*init_states, *args_flat]
|
||||
|
||||
try:
|
||||
out_flat = pjit_p.bind(*args_flat, **params)
|
||||
except pxla.DeviceAssignmentMismatchError as e:
|
||||
@ -180,12 +182,29 @@ def _python_pjit_helper(jit_info, *args, **kwargs):
|
||||
msg = _device_assignment_mismatch_error(
|
||||
fun_name, fails, args_flat, api_name, arg_names)
|
||||
raise ValueError(msg) from None
|
||||
except xla.InvalidInputException as e:
|
||||
arg_names = [''] * len(args_flat) if arg_names is None else arg_names
|
||||
# Run canonicalization again to figure out which arg failed.
|
||||
if params['jaxpr'].consts:
|
||||
raise TypeError(e.args[0]) from e
|
||||
else:
|
||||
for arg, name, aval in zip(args_flat, arg_names, params['jaxpr'].in_avals):
|
||||
try:
|
||||
xla.canonicalize_dtype(arg)
|
||||
except xla.InvalidInputException as _:
|
||||
# Reraise as TypeError with the new message.
|
||||
raise TypeError(
|
||||
f"Argument '{name}' of shape {aval.str_short()} of type"
|
||||
f' {type(arg)} is not a valid JAX type.') from e
|
||||
raise AssertionError("Unreachable") from e
|
||||
|
||||
if attrs_tracked:
|
||||
final_states, out_flat = split_list(out_flat, [len(attrs_tracked)])
|
||||
_set_states(attrs_tracked, final_states)
|
||||
outs = tree_unflatten(out_tree, out_flat)
|
||||
return outs, out_flat, out_tree, args_flat, params['jaxpr'], attrs_tracked
|
||||
|
||||
|
||||
def _set_states(attrs_tracked, vals):
|
||||
from jax.experimental.attrs import jax_setattr # type: ignore
|
||||
for ((obj, attr), val) in zip(attrs_tracked, vals):
|
||||
|
@ -2927,6 +2927,16 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
'out_shardings should not be specified.'):
|
||||
pjit(lambda x: x, out_shardings=s, device=jax.devices()[0])
|
||||
|
||||
def test_check_arg_error(self):
|
||||
sds = jax.ShapeDtypeStruct((4, 2), np.int32)
|
||||
inp = np.arange(8).reshape(4, 2)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
r"Argument 'x\['b'\]\['c'\]' of shape int32\[4,2\] of "
|
||||
"type.*ShapeDtypeStruct.*is not a valid JAX type."):
|
||||
jax.jit(lambda x: x)({'a': inp, 'b': {'c': sds}})
|
||||
|
||||
def test_pjit_device_backend_both_error(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "can't specify both a device and a backend for jit"):
|
||||
|
Loading…
x
Reference in New Issue
Block a user