Make error_test a jax_test so that we can test other configs and fix it with jit/pjit merge.

PiperOrigin-RevId: 502743523
This commit is contained in:
Yash Katariya 2023-01-17 18:42:21 -08:00 committed by jax authors
parent b58dd3cbe1
commit 05e1ddd4ea
3 changed files with 15 additions and 6 deletions

View File

@ -544,8 +544,9 @@ def _shaped_abstractify_slow(x):
if hasattr(x, 'dtype'):
dtype = dtypes.canonicalize_dtype(x.dtype, allow_opaque_dtype=True)
else:
raise ValueError(f"Cannot interpret value of type {type(x)} as an abstract array; "
"it does not have a dtype attribute")
raise TypeError(
f"Cannot interpret value of type {type(x)} as an abstract array; it "
"does not have a dtype attribute")
return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
named_shape=named_shape)

View File

@ -34,6 +34,8 @@ from jax._src import array
from jax._src.config import config
from jax._src import dispatch
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.traceback_util import api_boundary
from jax._src.api_util import (argnums_partial_except, flatten_axes,
flatten_fun, flatten_fun_nokwargs,
donation_vector, shaped_abstractify,
@ -61,6 +63,8 @@ from jax._src.util import (
distributed_debug_log, split_list, tuple_insert, weakref_lru_cache,
merge_lists)
traceback_util.register_exclusion(__file__)
class _FromGdaSingleton:
pass
FROM_GDA = _FromGdaSingleton()
@ -116,6 +120,7 @@ def _python_pjit_helper(infer_params_fn, *args, **kwargs):
def _python_pjit(fun: Callable, infer_params_fn):
@wraps(fun)
@api_boundary
def wrapped(*args, **kwargs):
return _python_pjit_helper(infer_params_fn, *args, **kwargs)[0]
@ -136,6 +141,7 @@ def _read_most_recent_pjit_call_executable():
def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames):
@api_boundary
def cache_miss(*args, **kwargs):
outs, out_flat, out_tree, args_flat = _python_pjit_helper(
infer_params_fn, *args, **kwargs)
@ -238,6 +244,7 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
else:
wrapped = _python_pjit(fun, infer_params_fn)
@api_boundary
def lower(*args, **kwargs):
(args_flat, flat_local_in_avals, params, in_tree, out_tree,
donate_argnums) = infer_params_fn(*args, **kwargs)

View File

@ -111,12 +111,13 @@ jax_test(
],
)
py_test(
jax_test(
name = "errors_test",
srcs = ["errors_test.py"],
deps = [
"//jax",
"//jax:test_util",
# No need to test all other configs.
enable_configs = [
"cpu",
"cpu_jit_pjit_api_merge",
],
)