Prefer raising of TypeError for invalid types instead of ValueError.

This commit is contained in:
Sai-Suraj-27 2024-04-08 13:08:24 +05:30
parent f2fc72d647
commit 5564521308
5 changed files with 10 additions and 10 deletions

View File

@ -2559,7 +2559,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
# TODO(jakevdp): provide a default for devices that considers both local
# devices and pods
if not isinstance(shards, Sequence):
raise ValueError("device_put_sharded `shards` input must be a sequence; "
raise TypeError("device_put_sharded `shards` input must be a sequence; "
f"got {type(shards)}")
if len(shards) != len(devices):
raise ValueError(f"len(shards) = {len(shards)} must equal "
@ -2911,7 +2911,7 @@ def named_scope(
... return jax.nn.relu(logits)
"""
if not isinstance(name, str):
raise ValueError("named_scope name argument must be a string.")
raise TypeError("named_scope name argument must be a string.")
with source_info_util.extend_name_stack(name):
yield

View File

@ -1300,6 +1300,6 @@ def check_error(error: Error) -> None:
>>> error, _ = checkify.checkify(with_inner_jit)(-1)
"""
if not isinstance(error, Error):
raise ValueError('check_error takes an Error as argument, '
raise TypeError('check_error takes an Error as argument, '
f'got type {type(error)} instead.')
_check_error(error, debug=False)

View File

@ -628,7 +628,7 @@ def define_string_state(
def validator(new_val):
if not isinstance(new_val, str):
raise ValueError('new string config value must be of type str,'
raise TypeError('new string config value must be of type str,'
f' got {new_val} of type {type(new_val)}.')
return define_string_or_object_state(

View File

@ -72,7 +72,7 @@ class Layout:
)
if not isinstance(
device_local_layout, (DeviceLocalLayout, type(None), AutoLayout)):
raise ValueError(
raise TypeError(
'Invalid value received for the device_local_layout argument.'
' Expected values are `None`, `DeviceLocalLayout.AUTO` or an'
f' instance of `DeviceLocalLayout`. Got {device_local_layout} of'
@ -80,7 +80,7 @@ class Layout:
)
if not isinstance(
sharding, (Sharding, type(None), AutoSharding)):
raise ValueError(
raise TypeError(
'Invalid value received for the sharding argument. Expected values'
' are `None`, `pjit.AUTO` or an instance of `jax.Sharding`. Got'
f' {sharding} of type {type(sharding)}')

View File

@ -385,7 +385,7 @@ def bcoo_extract(sparr: BCOO, arr: ArrayLike, *, assume_unique: bool | None = No
extracted : a BCOO array with the same sparsity pattern as self.
"""
if not isinstance(sparr, BCOO):
raise ValueError(f"First argument to bcoo_extract should be a BCOO array. Got {type(sparr)=}")
raise TypeError(f"First argument to bcoo_extract should be a BCOO array. Got {type(sparr)=}")
a = jnp.asarray(arr)
if a.shape != sparr.shape:
raise ValueError(f"shape mismatch: {sparr.shape=} {a.shape=}")
@ -1951,7 +1951,7 @@ def bcoo_slice(mat: BCOO, *, start_indices: Sequence[int], limit_indices: Sequen
out: BCOO array containing the slice.
"""
if not isinstance(mat, BCOO):
raise ValueError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}")
raise TypeError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}")
start_indices = [operator.index(i) for i in start_indices]
limit_indices = [operator.index(i) for i in limit_indices]
if strides is not None:
@ -2030,7 +2030,7 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq
jax.eval_shape(partial(lax.dynamic_slice, slice_sizes=slice_sizes),
jax.ShapeDtypeStruct(mat.shape, mat.dtype), start_indices)
if not isinstance(mat, BCOO):
raise ValueError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}")
raise TypeError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}")
start_indices = tuple(jnp.asarray(i) for i in start_indices)
assert all(jnp.issubdtype(i.dtype, np.integer) for i in start_indices)
assert all(i.shape == () for i in start_indices)
@ -2379,7 +2379,7 @@ def _convert_to_1d_for_conv(mat, index_dtype):
# zero-out data at OOB indices, otherwise strange things happen.
data = jnp.where(lax.squeeze(indices, (1,)) < mat.shape[-1], data, 0)
else:
raise ValueError(f"bcoo_conv_general_dilated: input of type {type(mat)} not recognized.")
raise TypeError(f"bcoo_conv_general_dilated: input of type {type(mat)} not recognized.")
return BCOO((data, indices), shape=mat.shape[2:])
def _bcoo_conv_1d(lhs: BCOO, rhs: BCOO, padding: Sequence[int]) -> BCOO: