Fix error message in dot_product_attention

PiperOrigin-RevId: 660960409
This commit is contained in:
Gleb Pobudzey 2024-08-08 13:29:38 -07:00 committed by jax authors
parent d8eafc8ee3
commit d28d14917e

View File

@ -937,11 +937,11 @@ def dot_product_attention(
_check_has_shape(value_arr, [B, S, K, H], 'value')
_check_has_shape(query_arr, [B, -1, -1, H], 'query')
if query_arr.shape[-2] % K != 0:
raise ValueError(f"The number of query heads must to a multiple of "
raise ValueError(f"The number of query heads must be a multiple of "
f"key/value heads, but got {query_arr.shape[-2]} vs {K}")
if not (query_arr.dtype == key_arr.dtype == value_arr.dtype):
raise ValueError(f"query/key/value should have the same shape, but got "
f"{query_arr.shape} vs {key_arr.shape} vs {value_arr.shape}.")
raise ValueError(f"query/key/value should have the same dtype, but got "
f"{query_arr.dtype} vs {key_arr.dtype} vs {value_arr.dtype}.")
if mask is not None and mask.dtype != jnp.bool_ and mask.ndim != 4:
raise ValueError(f"Mask must be a 4D boolean tensor, but got "
f"rank={mask.ndim}, dtype={mask.dtype}.")