mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
CI: fix mypy error
This commit is contained in:
parent
06d2e489eb
commit
329ab036ee
@ -986,7 +986,7 @@ def dot_product_attention(query: Array,
|
||||
has_bias = bias is not None
|
||||
has_mask = mask is not None
|
||||
has_dbias = has_bias and is_training and \
|
||||
should_export_dbias(bias.shape, query.shape, layout)
|
||||
should_export_dbias(bias.shape, query.shape, layout) # type: ignore[union-attr]
|
||||
variadic_args = (has_bias, has_mask, has_dbias)
|
||||
if bias is None:
|
||||
bias = jnp.zeros(0, dtype=query.dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user