mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Be stricter when it comes to handling dtypes in the splash_attention mask function
We previously took a `logical_and` of a mix of boolean and integer inputs, which isn't allowed under some of the strict dtype modes. This has been causing some JAX tests to fail.

PiperOrigin-RevId: 647669850
This commit is contained in:
parent
44071f8595
commit
648b9519cf
@ -604,7 +604,7 @@ def _apply_mask_and_soft_cap(
|
||||
mask = pl.load(mask_ref, (k_slice, slice(None)))
|
||||
|
||||
snm = jnp.where(should_not_mask, 1, 0)
|
||||
masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)))
|
||||
masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0)
|
||||
|
||||
if mask_function is not None:
|
||||
# Compute the mask using the given q_sequence indices.
|
||||
@ -634,7 +634,13 @@ def _apply_mask_and_soft_cap(
|
||||
q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq))
|
||||
|
||||
assert q_sequence.shape == k_sequence.shape
|
||||
masks.append(mask_function(q_sequence, k_sequence)) # pytype: disable=wrong-arg-count
|
||||
computed_mask = mask_function(q_sequence, k_sequence) # pytype: disable=wrong-arg-count
|
||||
if computed_mask.dtype != jnp.dtype(jnp.bool_):
|
||||
raise ValueError(
|
||||
"Mask function must return a boolean-valued array, but got:"
|
||||
f" {computed_mask.dtype}"
|
||||
)
|
||||
masks.append(computed_mask)
|
||||
|
||||
if q_segment_ids_ref is not None:
|
||||
if k_in_lanes:
|
||||
|
Loading…
x
Reference in New Issue
Block a user