Support convolution in roofline.

So far we support only `unfused_hbm_bytes` and don't account for `{feature, batch}_group_count`s due to complexity.

PiperOrigin-RevId: 736948528
This commit is contained in:
Zac Mustin 2025-03-14 12:25:35 -07:00 committed by jax authors
parent 88d4bc3d45
commit 0c8e601f90
2 changed files with 142 additions and 0 deletions

View File

@ -177,6 +177,23 @@ def _dot_general_roofline(
unfused_hbm_bytes=hbm_bytes,
)
@roofline.register_roofline(convolution.conv_general_dilated_p)
def _conv_general_dilated_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    **kw,
) -> roofline.RooflineResult:
  """Roofline estimate for `conv_general_dilated`.

  Only `unfused_hbm_bytes` is computed: the sum of the byte sizes of the two
  operands and the single output, i.e. the HBM traffic of an unfused conv.
  """
  operands = [roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in]
  result = roofline.RooflineShape.from_aval(ctx.avals_out[0])
  # Total bytes moved = bytes of lhs + rhs + output.
  total_bytes = sum(
      shape.dtype.itemsize * shape.size for shape in (*operands, result)
  )
  # TODO(b/394648206): support computing unfused_flops for conv.
  return roofline.RooflineResult(unfused_hbm_bytes=total_bytes)
def _return_zeros_if_one_sized_axis(
ctx: roofline.RooflineRuleContext, axes: tuple[str, ...]

View File

@ -14,6 +14,7 @@
from __future__ import annotations
from functools import partial
from typing import Sequence
from absl.testing import absltest
import jax
@ -571,6 +572,130 @@ class RooflineTest(jtu.JaxTestCase):
result.unfused_hbm_bytes, self._bytes_per_word * (3 * 7 + 7 * 5 + 3 * 5)
)
def get_conv_output_dim(self, i, k, pad_low, pad_high, stride):
  """Return the output size of one convolution spatial dimension.

  Standard convolution arithmetic: floor((i - k + pad_low + pad_high) / stride) + 1.

  Args:
    i: input spatial size.
    k: kernel spatial size.
    pad_low: padding added before the input along this dimension.
    pad_high: padding added after the input along this dimension.
    stride: window stride along this dimension.

  Returns:
    The number of output elements along this dimension, as an exact int.
  """
  # Integer floor division gives the exact integer result directly; the
  # previous `jnp.floor` on a float division produced a JAX float array for
  # what is inherently an integer size.
  return (i - k + pad_low + pad_high) // stride + 1
@jtu.parameterized.named_parameters(
    dict(
        testcase_name="simple",
        window_strides=(1, 1),
        padding=((0, 0), (0, 0)),
    ),
    dict(
        testcase_name="padding",
        window_strides=(1, 1),
        padding=((1, 2), (3, 4)),
    ),
    dict(
        testcase_name="window_strides",
        window_strides=(2, 2),
        padding=((0, 0), (0, 0)),
    ),
    dict(
        testcase_name="window_strides_and_padding",
        window_strides=(3, 3),
        padding=((1, 2), (3, 4)),
    ),
)
def test_conv_general_dilated_unfused_hbm_bytes(
    self, window_strides: Sequence[int], padding: Sequence[Sequence[int]]
):
  """Checks unfused_hbm_bytes for conv with explicit strides/padding pairs."""
  # Spatial input and kernel sizes; batch and feature dims are both 1.
  iw, ih = 100, 200
  kw, kh = 7, 7
  input_data = jnp.zeros((1, 1, iw, ih), dtype=int)
  kernel_data = jnp.ones((1, 1, kw, kh), dtype=int)
  conv = lambda a, b: lax.conv_general_dilated(
      lhs=a, rhs=b, window_strides=window_strides, padding=padding
  )
  _, result = roofline.roofline(
      conv,
      mesh=mesh.AbstractMesh((), ()),
      in_specs=(P(), P()),
      out_specs=P(),
  )(input_data, kernel_data)
  # Element counts of the operands; byte sizes follow from _bytes_per_word.
  expected_input_size = 1 * 1 * iw * ih
  expected_kernel_size = 1 * 1 * kw * kh
  # Output spatial dims follow standard conv output-size arithmetic.
  ow = self.get_conv_output_dim(
      iw, kw, padding[0][0], padding[0][1], window_strides[0]
  )
  oh = self.get_conv_output_dim(
      ih, kh, padding[1][0], padding[1][1], window_strides[1]
  )
  expected_output_size = 1 * 1 * ow * oh
  # Bytes accessed is sum of inputs and output.
  expected_unfused_hbm_bytes = self._bytes_per_word * (
      expected_input_size + expected_kernel_size + expected_output_size
  )
  # TODO(b/394648206): add subtest for unfused_flops once they are supported.
  self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes)
@jtu.parameterized.named_parameters(
    dict(
        testcase_name="same",
        padding="SAME",
    ),
    dict(
        testcase_name="same_lower",
        padding="SAME_LOWER",
    ),
)
def test_conv_general_dilated_padding_string_unfused_hbm_bytes(self, padding: str):
  """Checks unfused_hbm_bytes for conv with SAME/SAME_LOWER padding strings."""
  input_data = jnp.zeros((1, 1, 10, 20), dtype=int)
  kernel_data = jnp.ones((1, 1, 3, 3), dtype=int)
  conv = lambda a, b: lax.conv_general_dilated(
      lhs=a, rhs=b, window_strides=(1, 1), padding=padding
  )
  _, result = roofline.roofline(
      conv,
      mesh=mesh.AbstractMesh((), ()),
      in_specs=(P(), P()),
      out_specs=P(),
  )(input_data, kernel_data)
  expected_input_size = 1 * 1 * 10 * 20
  expected_kernel_size = 1 * 1 * 3 * 3
  # With SAME{_LOWER} padding and unit strides, the output shape equals the
  # input shape. This may not hold for other `{feature, batch}_group_count`s.
  expected_output_size = expected_input_size
  # Bytes accessed is sum of inputs and output.
  expected_unfused_hbm_bytes = self._bytes_per_word * (
      expected_input_size + expected_kernel_size + expected_output_size
  )
  self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes)
def test_conv_general_dilated_padding_string_valid_unfused_hbm_bytes(self):
  """Checks unfused_hbm_bytes for conv with the VALID padding string."""
  lhs = jnp.zeros((1, 1, 10, 20), dtype=int)
  rhs = jnp.ones((1, 1, 3, 3), dtype=int)

  def conv(a, b):
    return lax.conv_general_dilated(
        lhs=a, rhs=b, window_strides=(1, 1), padding="VALID"
    )

  _, result = roofline.roofline(
      conv,
      mesh=mesh.AbstractMesh((), ()),
      in_specs=(P(), P()),
      out_specs=P(),
  )(lhs, rhs)

  input_size = 1 * 1 * 10 * 20
  kernel_size = 1 * 1 * 3 * 3
  # "VALID" padding is equivalent to zero padding on every spatial dim.
  output_size = (
      1
      * 1
      * self.get_conv_output_dim(10, 3, 0, 0, 1)
      * self.get_conv_output_dim(20, 3, 0, 0, 1)
  )
  # Bytes accessed is the sum of both inputs and the output.
  self.assertEqual(
      result.unfused_hbm_bytes,
      self._bytes_per_word * (input_size + kernel_size + output_size),
  )
def test_reduce_sum_no_axis(self):
_, result = roofline.roofline(
lambda x: jnp.sum(x),