CI: fix mypy error

This commit is contained in:
Jake VanderPlas 2024-05-20 13:23:15 -07:00
parent 06d2e489eb
commit 329ab036ee

View File

@ -986,7 +986,7 @@ def dot_product_attention(query: Array,
has_bias = bias is not None
has_mask = mask is not None
has_dbias = has_bias and is_training and \
should_export_dbias(bias.shape, query.shape, layout)
should_export_dbias(bias.shape, query.shape, layout) # type: ignore[union-attr]
variadic_args = (has_bias, has_mask, has_dbias)
if bias is None:
bias = jnp.zeros(0, dtype=query.dtype)