Be stricter when it comes to handling dtypes in splash_attention mask function

We previously took a logical_and of a mix of boolean and integer inputs, which isn't allowed
under some of the strict dtype modes. This has been causing some JAX tests to fail.

PiperOrigin-RevId: 647669850
This commit is contained in:
Adam Paszke 2024-06-28 07:22:25 -07:00 committed by jax authors
parent 44071f8595
commit 648b9519cf

View File

@@ -604,7 +604,7 @@ def _apply_mask_and_soft_cap(
mask = pl.load(mask_ref, (k_slice, slice(None)))
snm = jnp.where(should_not_mask, 1, 0)
masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)))
masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0)
if mask_function is not None:
# Compute the mask using the given q_sequence indices.
@@ -634,7 +634,13 @@ def _apply_mask_and_soft_cap(
q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq))
assert q_sequence.shape == k_sequence.shape
masks.append(mask_function(q_sequence, k_sequence)) # pytype: disable=wrong-arg-count
computed_mask = mask_function(q_sequence, k_sequence) # pytype: disable=wrong-arg-count
if computed_mask.dtype != jnp.dtype(jnp.bool_):
raise ValueError(
"Mask function must return a boolean-valued array, but got:"
f" {computed_mask.dtype}"
)
masks.append(computed_mask)
if q_segment_ids_ref is not None:
if k_in_lanes: