Fix mypy error

This commit is contained in:
Jake VanderPlas 2024-07-07 19:45:50 -07:00
parent ceea8dc3fe
commit f851099649

View File

@ -980,7 +980,7 @@ def dot_product_attention(query: Array,
large_negative_number = get_large_negative_number(query.dtype)
mask = jnp.where(mask, jnp.asarray(0, query.dtype), large_negative_number)
# reshape mask to have 4D shape
mask = mask.reshape((1,) * (4 - len(mask.shape)) + mask.shape)
mask = mask.reshape((1,) * (4 - len(mask.shape)) + mask.shape) # type: ignore[union-attr]
# combine bias and mask
if bias is None: