mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Prefer raising of TypeError for invalid types instead of ValueError.
This commit is contained in:
parent
f2fc72d647
commit
5564521308
@ -2559,7 +2559,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
|
||||
# TODO(jakevdp): provide a default for devices that considers both local
|
||||
# devices and pods
|
||||
if not isinstance(shards, Sequence):
|
||||
raise ValueError("device_put_sharded `shards` input must be a sequence; "
|
||||
raise TypeError("device_put_sharded `shards` input must be a sequence; "
|
||||
f"got {type(shards)}")
|
||||
if len(shards) != len(devices):
|
||||
raise ValueError(f"len(shards) = {len(shards)} must equal "
|
||||
@ -2911,7 +2911,7 @@ def named_scope(
|
||||
... return jax.nn.relu(logits)
|
||||
"""
|
||||
if not isinstance(name, str):
|
||||
raise ValueError("named_scope name argument must be a string.")
|
||||
raise TypeError("named_scope name argument must be a string.")
|
||||
with source_info_util.extend_name_stack(name):
|
||||
yield
|
||||
|
||||
|
@ -1300,6 +1300,6 @@ def check_error(error: Error) -> None:
|
||||
>>> error, _ = checkify.checkify(with_inner_jit)(-1)
|
||||
"""
|
||||
if not isinstance(error, Error):
|
||||
raise ValueError('check_error takes an Error as argument, '
|
||||
raise TypeError('check_error takes an Error as argument, '
|
||||
f'got type {type(error)} instead.')
|
||||
_check_error(error, debug=False)
|
||||
|
@ -628,7 +628,7 @@ def define_string_state(
|
||||
|
||||
def validator(new_val):
|
||||
if not isinstance(new_val, str):
|
||||
raise ValueError('new string config value must be of type str,'
|
||||
raise TypeError('new string config value must be of type str,'
|
||||
f' got {new_val} of type {type(new_val)}.')
|
||||
|
||||
return define_string_or_object_state(
|
||||
|
@ -72,7 +72,7 @@ class Layout:
|
||||
)
|
||||
if not isinstance(
|
||||
device_local_layout, (DeviceLocalLayout, type(None), AutoLayout)):
|
||||
raise ValueError(
|
||||
raise TypeError(
|
||||
'Invalid value received for the device_local_layout argument.'
|
||||
' Expected values are `None`, `DeviceLocalLayout.AUTO` or an'
|
||||
f' instance of `DeviceLocalLayout`. Got {device_local_layout} of'
|
||||
@ -80,7 +80,7 @@ class Layout:
|
||||
)
|
||||
if not isinstance(
|
||||
sharding, (Sharding, type(None), AutoSharding)):
|
||||
raise ValueError(
|
||||
raise TypeError(
|
||||
'Invalid value received for the sharding argument. Expected values'
|
||||
' are `None`, `pjit.AUTO` or an instance of `jax.Sharding`. Got'
|
||||
f' {sharding} of type {type(sharding)}')
|
||||
|
@ -385,7 +385,7 @@ def bcoo_extract(sparr: BCOO, arr: ArrayLike, *, assume_unique: bool | None = No
|
||||
extracted : a BCOO array with the same sparsity pattern as self.
|
||||
"""
|
||||
if not isinstance(sparr, BCOO):
|
||||
raise ValueError(f"First argument to bcoo_extract should be a BCOO array. Got {type(sparr)=}")
|
||||
raise TypeError(f"First argument to bcoo_extract should be a BCOO array. Got {type(sparr)=}")
|
||||
a = jnp.asarray(arr)
|
||||
if a.shape != sparr.shape:
|
||||
raise ValueError(f"shape mismatch: {sparr.shape=} {a.shape=}")
|
||||
@ -1951,7 +1951,7 @@ def bcoo_slice(mat: BCOO, *, start_indices: Sequence[int], limit_indices: Sequen
|
||||
out: BCOO array containing the slice.
|
||||
"""
|
||||
if not isinstance(mat, BCOO):
|
||||
raise ValueError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}")
|
||||
raise TypeError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}")
|
||||
start_indices = [operator.index(i) for i in start_indices]
|
||||
limit_indices = [operator.index(i) for i in limit_indices]
|
||||
if strides is not None:
|
||||
@ -2030,7 +2030,7 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq
|
||||
jax.eval_shape(partial(lax.dynamic_slice, slice_sizes=slice_sizes),
|
||||
jax.ShapeDtypeStruct(mat.shape, mat.dtype), start_indices)
|
||||
if not isinstance(mat, BCOO):
|
||||
raise ValueError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}")
|
||||
raise TypeError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}")
|
||||
start_indices = tuple(jnp.asarray(i) for i in start_indices)
|
||||
assert all(jnp.issubdtype(i.dtype, np.integer) for i in start_indices)
|
||||
assert all(i.shape == () for i in start_indices)
|
||||
@ -2379,7 +2379,7 @@ def _convert_to_1d_for_conv(mat, index_dtype):
|
||||
# zero-out data at OOB indices, otherwise strange things happen.
|
||||
data = jnp.where(lax.squeeze(indices, (1,)) < mat.shape[-1], data, 0)
|
||||
else:
|
||||
raise ValueError(f"bcoo_conv_general_dilated: input of type {type(mat)} not recognized.")
|
||||
raise TypeError(f"bcoo_conv_general_dilated: input of type {type(mat)} not recognized.")
|
||||
return BCOO((data, indices), shape=mat.shape[2:])
|
||||
|
||||
def _bcoo_conv_1d(lhs: BCOO, rhs: BCOO, padding: Sequence[int]) -> BCOO:
|
||||
|
Loading…
x
Reference in New Issue
Block a user