mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Make error_test
a jax_test so that we can test other configs and fix it with jit
/pjit
merge.
PiperOrigin-RevId: 502743523
This commit is contained in:
parent
b58dd3cbe1
commit
05e1ddd4ea
@ -544,8 +544,9 @@ def _shaped_abstractify_slow(x):
|
||||
if hasattr(x, 'dtype'):
|
||||
dtype = dtypes.canonicalize_dtype(x.dtype, allow_opaque_dtype=True)
|
||||
else:
|
||||
raise ValueError(f"Cannot interpret value of type {type(x)} as an abstract array; "
|
||||
"it does not have a dtype attribute")
|
||||
raise TypeError(
|
||||
f"Cannot interpret value of type {type(x)} as an abstract array; it "
|
||||
"does not have a dtype attribute")
|
||||
return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
|
||||
named_shape=named_shape)
|
||||
|
||||
|
@ -34,6 +34,8 @@ from jax._src import array
|
||||
from jax._src.config import config
|
||||
from jax._src import dispatch
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.api_util import (argnums_partial_except, flatten_axes,
|
||||
flatten_fun, flatten_fun_nokwargs,
|
||||
donation_vector, shaped_abstractify,
|
||||
@ -61,6 +63,8 @@ from jax._src.util import (
|
||||
distributed_debug_log, split_list, tuple_insert, weakref_lru_cache,
|
||||
merge_lists)
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
class _FromGdaSingleton:
|
||||
pass
|
||||
FROM_GDA = _FromGdaSingleton()
|
||||
@ -116,6 +120,7 @@ def _python_pjit_helper(infer_params_fn, *args, **kwargs):
|
||||
def _python_pjit(fun: Callable, infer_params_fn):
|
||||
|
||||
@wraps(fun)
|
||||
@api_boundary
|
||||
def wrapped(*args, **kwargs):
|
||||
return _python_pjit_helper(infer_params_fn, *args, **kwargs)[0]
|
||||
|
||||
@ -136,6 +141,7 @@ def _read_most_recent_pjit_call_executable():
|
||||
|
||||
def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames):
|
||||
|
||||
@api_boundary
|
||||
def cache_miss(*args, **kwargs):
|
||||
outs, out_flat, out_tree, args_flat = _python_pjit_helper(
|
||||
infer_params_fn, *args, **kwargs)
|
||||
@ -238,6 +244,7 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
|
||||
else:
|
||||
wrapped = _python_pjit(fun, infer_params_fn)
|
||||
|
||||
@api_boundary
|
||||
def lower(*args, **kwargs):
|
||||
(args_flat, flat_local_in_avals, params, in_tree, out_tree,
|
||||
donate_argnums) = infer_params_fn(*args, **kwargs)
|
||||
|
@ -111,12 +111,13 @@ jax_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
jax_test(
|
||||
name = "errors_test",
|
||||
srcs = ["errors_test.py"],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:test_util",
|
||||
# No need to test all other configs.
|
||||
enable_configs = [
|
||||
"cpu",
|
||||
"cpu_jit_pjit_api_merge",
|
||||
],
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user