mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix error message in dot_product_attention
PiperOrigin-RevId: 660960409
This commit is contained in:
parent
d8eafc8ee3
commit
d28d14917e
@ -937,11 +937,11 @@ def dot_product_attention(
|
||||
_check_has_shape(value_arr, [B, S, K, H], 'value')
|
||||
_check_has_shape(query_arr, [B, -1, -1, H], 'query')
|
||||
if query_arr.shape[-2] % K != 0:
|
||||
raise ValueError(f"The number of query heads must to a multiple of "
|
||||
raise ValueError(f"The number of query heads must be a multiple of "
|
||||
f"key/value heads, but got {query_arr.shape[-2]} vs {K}")
|
||||
if not (query_arr.dtype == key_arr.dtype == value_arr.dtype):
|
||||
raise ValueError(f"query/key/value should have the same shape, but got "
|
||||
f"{query_arr.shape} vs {key_arr.shape} vs {value_arr.shape}.")
|
||||
raise ValueError(f"query/key/value should have the same dtype, but got "
|
||||
f"{query_arr.dtype} vs {key_arr.dtype} vs {value_arr.dtype}.")
|
||||
if mask is not None and mask.dtype != jnp.bool_ and mask.ndim != 4:
|
||||
raise ValueError(f"Mask must be a 4D boolean tensor, but got "
|
||||
f"rank={mask.ndim}, dtype={mask.dtype}.")
|
||||
|
Loading…
x
Reference in New Issue
Block a user