From 5564521308bf240ce9835037389662bc62cc7dcd Mon Sep 17 00:00:00 2001 From: Sai-Suraj-27 Date: Mon, 8 Apr 2024 13:08:24 +0530 Subject: [PATCH] Prefer raising of TypeError for invalid types instead of ValueError. --- jax/_src/api.py | 4 ++-- jax/_src/checkify.py | 2 +- jax/_src/config.py | 2 +- jax/_src/layout.py | 4 ++-- jax/experimental/sparse/bcoo.py | 8 ++++---- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 084367f25..12610b011 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2559,7 +2559,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): # # TODO(jakevdp): provide a default for devices that considers both local # devices and pods if not isinstance(shards, Sequence): - raise ValueError("device_put_sharded `shards` input must be a sequence; " + raise TypeError("device_put_sharded `shards` input must be a sequence; " f"got {type(shards)}") if len(shards) != len(devices): raise ValueError(f"len(shards) = {len(shards)} must equal " @@ -2911,7 +2911,7 @@ def named_scope( ... return jax.nn.relu(logits) """ if not isinstance(name, str): - raise ValueError("named_scope name argument must be a string.") + raise TypeError("named_scope name argument must be a string.") with source_info_util.extend_name_stack(name): yield diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 69e3dd158..774a1591b 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -1300,6 +1300,6 @@ def check_error(error: Error) -> None: >>> error, _ = checkify.checkify(with_inner_jit)(-1) """ if not isinstance(error, Error): - raise ValueError('check_error takes an Error as argument, ' + raise TypeError('check_error takes an Error as argument, ' f'got type {type(error)} instead.') _check_error(error, debug=False) diff --git a/jax/_src/config.py b/jax/_src/config.py index fa5305262..03df5f231 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -628,7 +628,7 @@ def define_string_state( def validator(new_val): if not isinstance(new_val, str): - raise ValueError('new string config value must be of type str,' + raise TypeError('new string config value must be of type str,' f' got {new_val} of type {type(new_val)}.') return define_string_or_object_state( diff --git a/jax/_src/layout.py b/jax/_src/layout.py index bca436bc7..2071794a0 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -72,7 +72,7 @@ class Layout: ) if not isinstance( device_local_layout, (DeviceLocalLayout, type(None), AutoLayout)): - raise ValueError( + raise TypeError( 'Invalid value received for the device_local_layout argument.' ' Expected values are `None`, `DeviceLocalLayout.AUTO` or an' f' instance of `DeviceLocalLayout`. Got {device_local_layout} of' @@ -80,7 +80,7 @@ class Layout: ) if not isinstance( sharding, (Sharding, type(None), AutoSharding)): - raise ValueError( + raise TypeError( 'Invalid value received for the sharding argument. Expected values' ' are `None`, `pjit.AUTO` or an instance of `jax.Sharding`. Got' f' {sharding} of type {type(sharding)}') diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index b6572b01c..2bc977e7a 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -385,7 +385,7 @@ def bcoo_extract(sparr: BCOO, arr: ArrayLike, *, assume_unique: bool | None = No extracted : a BCOO array with the same sparsity pattern as self. """ if not isinstance(sparr, BCOO): - raise ValueError(f"First argument to bcoo_extract should be a BCOO array. Got {type(sparr)=}") + raise TypeError(f"First argument to bcoo_extract should be a BCOO array. Got {type(sparr)=}") a = jnp.asarray(arr) if a.shape != sparr.shape: raise ValueError(f"shape mismatch: {sparr.shape=} {a.shape=}") @@ -1951,7 +1951,7 @@ def bcoo_slice(mat: BCOO, *, start_indices: Sequence[int], limit_indices: Sequen out: BCOO array containing the slice. """ if not isinstance(mat, BCOO): - raise ValueError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}") + raise TypeError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}") start_indices = [operator.index(i) for i in start_indices] limit_indices = [operator.index(i) for i in limit_indices] if strides is not None: @@ -2030,7 +2030,7 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq jax.eval_shape(partial(lax.dynamic_slice, slice_sizes=slice_sizes), jax.ShapeDtypeStruct(mat.shape, mat.dtype), start_indices) if not isinstance(mat, BCOO): - raise ValueError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}") + raise TypeError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}") start_indices = tuple(jnp.asarray(i) for i in start_indices) assert all(jnp.issubdtype(i.dtype, np.integer) for i in start_indices) assert all(i.shape == () for i in start_indices) @@ -2379,7 +2379,7 @@ def _convert_to_1d_for_conv(mat, index_dtype): # zero-out data at OOB indices, otherwise strange things happen. data = jnp.where(lax.squeeze(indices, (1,)) < mat.shape[-1], data, 0) else: - raise ValueError(f"bcoo_conv_general_dilated: input of type {type(mat)} not recognized.") + raise TypeError(f"bcoo_conv_general_dilated: input of type {type(mat)} not recognized.") return BCOO((data, indices), shape=mat.shape[2:]) def _bcoo_conv_1d(lhs: BCOO, rhs: BCOO, padding: Sequence[int]) -> BCOO: