Raise a better error message when an invalid input is passed to jit call.

Before:

```
TypeError: Argument 'ShapeDtypeStruct(shape=(4, 2), dtype=int32)' of type <class 'jax._src.api.ShapeDtypeStruct'> is not a valid JAX type.

```

After:

```
TypeError: Argument 'x['b']['c']' of shape int32[4,2] of type <class 'jax._src.api.ShapeDtypeStruct'> is not a valid JAX type.

```

The error is raised deep down the stack during `shard_arg`, so we raise an `InvalidInputException` and catch it in `_python_pjit_helper` where we have the `arg_names` information.

PiperOrigin-RevId: 618014044
This commit is contained in:
Yash Katariya 2024-03-21 17:45:44 -07:00 committed by jax authors
parent 7f7e0c00df
commit d57bb8c748
3 changed files with 35 additions and 4 deletions

View File

@ -110,8 +110,6 @@ def tuple_sharding_proto(elems):
return proto
### handlers
# JAX abstract values -> XLA shapes
@ -132,6 +130,10 @@ _xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),)
# IR constants
class InvalidInputException(Exception):
pass
# TODO(mattjj): try to remove this canonicalize_dtype stuff
def canonicalize_dtype(x):
typ = type(x)
@ -142,8 +144,8 @@ def canonicalize_dtype(x):
if handler: return handler(x)
if hasattr(x, '__jax_array__'):
return canonicalize_dtype(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid "
"JAX type.")
raise InvalidInputException(
f"Argument '{x}' of type {type(x)} is not a valid JAX type.")
def _canonicalize_masked_array_dtype(x):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "

View File

@ -167,9 +167,11 @@ def _python_pjit_helper(jit_info, *args, **kwargs):
_infer_params(jit_info, args, kwargs)
for arg in args_flat:
dispatch.check_arg(arg)
if attrs_tracked:
init_states = _get_states(attrs_tracked)
args_flat = [*init_states, *args_flat]
try:
out_flat = pjit_p.bind(*args_flat, **params)
except pxla.DeviceAssignmentMismatchError as e:
@ -180,12 +182,29 @@ def _python_pjit_helper(jit_info, *args, **kwargs):
msg = _device_assignment_mismatch_error(
fun_name, fails, args_flat, api_name, arg_names)
raise ValueError(msg) from None
except xla.InvalidInputException as e:
arg_names = [''] * len(args_flat) if arg_names is None else arg_names
# Run canonicalization again to figure out which arg failed.
if params['jaxpr'].consts:
raise TypeError(e.args[0]) from e
else:
for arg, name, aval in zip(args_flat, arg_names, params['jaxpr'].in_avals):
try:
xla.canonicalize_dtype(arg)
except xla.InvalidInputException as _:
# Reraise as TypeError with the new message.
raise TypeError(
f"Argument '{name}' of shape {aval.str_short()} of type"
f' {type(arg)} is not a valid JAX type.') from e
raise AssertionError("Unreachable") from e
if attrs_tracked:
final_states, out_flat = split_list(out_flat, [len(attrs_tracked)])
_set_states(attrs_tracked, final_states)
outs = tree_unflatten(out_tree, out_flat)
return outs, out_flat, out_tree, args_flat, params['jaxpr'], attrs_tracked
def _set_states(attrs_tracked, vals):
from jax.experimental.attrs import jax_setattr # type: ignore
for ((obj, attr), val) in zip(attrs_tracked, vals):

View File

@ -2927,6 +2927,16 @@ class ArrayPjitTest(jtu.JaxTestCase):
'out_shardings should not be specified.'):
pjit(lambda x: x, out_shardings=s, device=jax.devices()[0])
def test_check_arg_error(self):
sds = jax.ShapeDtypeStruct((4, 2), np.int32)
inp = np.arange(8).reshape(4, 2)
with self.assertRaisesRegex(
TypeError,
r"Argument 'x\['b'\]\['c'\]' of shape int32\[4,2\] of "
"type.*ShapeDtypeStruct.*is not a valid JAX type."):
jax.jit(lambda x: x)({'a': inp, 'b': {'c': sds}})
def test_pjit_device_backend_both_error(self):
with self.assertRaisesRegex(
ValueError, "can't specify both a device and a backend for jit"):