mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix mypy error
This commit is contained in:
parent
ceea8dc3fe
commit
f851099649
@ -980,7 +980,7 @@ def dot_product_attention(query: Array,
|
||||
large_negative_number = get_large_negative_number(query.dtype)
|
||||
mask = jnp.where(mask, jnp.asarray(0, query.dtype), large_negative_number)
|
||||
# reshape mask to have 4D shape
|
||||
mask = mask.reshape((1,) * (4 - len(mask.shape)) + mask.shape)
|
||||
mask = mask.reshape((1,) * (4 - len(mask.shape)) + mask.shape) # type: ignore[union-attr]
|
||||
|
||||
# combine bias and mask
|
||||
if bias is None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user