Use approximate cost estimates for flash attention instead of reference XLA estimates.

PiperOrigin-RevId: 691209201
Author: jax authors
Date: 2024-10-29 16:52:25 -07:00
Parent: 6d8950c04f
Commit: 249f0101b3


@@ -574,28 +574,26 @@ def _fwd_cost_estimate(
     q: jax.Array,
     k: jax.Array,
     v: jax.Array,
-    ab: jax.Array | None,
-    segment_ids: SegmentIds | None,
     *,
-    causal: bool,
-    sm_scale: jax.Array | None,
     kernel_inputs_specs,
     kernel_outputs_specs,
 ) -> pl.CostEstimate | None:
-  full_cost = (
-      mha_reference.lower(
-          q, k, v, ab, segment_ids, causal=causal, sm_scale=sm_scale
-      )
-      .compile()
-      .cost_analysis()
-  )
-  if not full_cost:
-    return None
+  b, h, tq, dqk = q.shape
+  tk = k.shape[-2]
+  dv = v.shape[-1]
+  # Simplify flop computation to include only matmul operations.
+  qk_flops = 2 * tq * tk * dqk
+  av_flops = 2 * tq * tk * dv
+  per_head_flops = qk_flops + av_flops
+  flops = b * h * per_head_flops
+  transcendentals = b * tq * tk * h
   input_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_inputs_specs))
   output_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_outputs_specs))
   return pl.CostEstimate(
-      flops=full_cost[0]["flops"],
-      transcendentals=full_cost[0]["transcendentals"],
+      flops=flops,
+      transcendentals=transcendentals,
       bytes_accessed=input_bytes + output_bytes,
   )
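
For reference (not part of the commit): the new estimate counts only the two matmuls per attention head, Q @ K^T costing 2 * tq * tk * dqk flops and softmax(QK^T) @ V costing 2 * tq * tk * dv flops, summed over the b * h heads, and budgets one exponential per attention logit for the transcendentals. Mask, bias, and softmax-normalization flops are deliberately left out, which is what makes the estimate approximate. A minimal standalone sketch of the same arithmetic, using hypothetical shapes chosen only for illustration:

    # Sketch of the approximate cost arithmetic above; shape values are hypothetical.
    b, h, tq, tk, dqk, dv = 4, 16, 1024, 1024, 128, 128
    qk_flops = 2 * tq * tk * dqk            # Q @ K^T matmul
    av_flops = 2 * tq * tk * dv             # softmax(QK^T) @ V matmul
    flops = b * h * (qk_flops + av_flops)   # 34,359,738,368 (~3.4e10)
    transcendentals = b * h * tq * tk       # 67,108,864 exp() calls (~6.7e7)
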
@@ -792,10 +790,6 @@ def _flash_attention_impl(
           q,
           k,
           v,
-          ab,
-          segment_ids,
-          causal=causal,
-          sm_scale=sm_scale,
           kernel_inputs_specs=(q, k, v, ab, q_segment_ids, kv_segment_ids),
           kernel_outputs_specs=out_shape,
       ),
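
For context (again not part of the diff), the estimate returned by _fwd_cost_estimate is presumably threaded into the pl.pallas_call that launches the flash-attention kernel so the compiler can schedule around it. A hedged sketch of attaching a pl.CostEstimate to an unrelated toy kernel, assuming pallas_call accepts a cost_estimate argument; the kernel and shapes below are placeholders, not the flash-attention code:

    # Hedged sketch, separate from the commit: attach a pl.CostEstimate to a
    # trivial copy kernel. Assumes pl.pallas_call takes a cost_estimate kwarg.
    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def copy_kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...]

    x = jnp.ones((256, 256), jnp.float32)
    out = pl.pallas_call(
        copy_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        cost_estimate=pl.CostEstimate(
            flops=0,
            transcendentals=0,
            # One read plus one write of the full array.
            bytes_accessed=2 * x.size * x.dtype.itemsize,
        ),
    )(x)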