mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Use approximate cost estimates for flash attention instead of reference XLA estimates.
PiperOrigin-RevId: 691209201
This commit is contained in:
parent
6d8950c04f
commit
249f0101b3
@ -574,28 +574,26 @@ def _fwd_cost_estimate(
|
||||
q: jax.Array,
|
||||
k: jax.Array,
|
||||
v: jax.Array,
|
||||
ab: jax.Array | None,
|
||||
segment_ids: SegmentIds | None,
|
||||
*,
|
||||
causal: bool,
|
||||
sm_scale: jax.Array | None,
|
||||
kernel_inputs_specs,
|
||||
kernel_outputs_specs,
|
||||
) -> pl.CostEstimate | None:
|
||||
full_cost = (
|
||||
mha_reference.lower(
|
||||
q, k, v, ab, segment_ids, causal=causal, sm_scale=sm_scale
|
||||
)
|
||||
.compile()
|
||||
.cost_analysis()
|
||||
)
|
||||
if not full_cost:
|
||||
return None
|
||||
b, h, tq, dqk = q.shape
|
||||
tk = k.shape[-2]
|
||||
dv = v.shape[-1]
|
||||
|
||||
# Simplify flop computation to include only matmul operations.
|
||||
qk_flops = 2 * tq * tk * dqk
|
||||
av_flops = 2 * tq * tk * dv
|
||||
per_head_flops = qk_flops + av_flops
|
||||
flops = b * h * per_head_flops
|
||||
|
||||
transcendentals = b * tq * tk * h
|
||||
input_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_inputs_specs))
|
||||
output_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_outputs_specs))
|
||||
return pl.CostEstimate(
|
||||
flops=full_cost[0]["flops"],
|
||||
transcendentals=full_cost[0]["transcendentals"],
|
||||
flops=flops,
|
||||
transcendentals=transcendentals,
|
||||
bytes_accessed=input_bytes + output_bytes,
|
||||
)
|
||||
|
||||
@ -792,10 +790,6 @@ def _flash_attention_impl(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
ab,
|
||||
segment_ids,
|
||||
causal=causal,
|
||||
sm_scale=sm_scale,
|
||||
kernel_inputs_specs=(q, k, v, ab, q_segment_ids, kv_segment_ids),
|
||||
kernel_outputs_specs=out_shape,
|
||||
),
|
||||
|
Loading…
x
Reference in New Issue
Block a user