Add tests for varying {batch, feature}_group_counts for roofline conv.

We'll need to use batch/feature group counts when calculating flops, so including them in our tests now will reduce the size of the upcoming "calculating-flops" change.

PiperOrigin-RevId: 739081930
Zac Mustin 2025-03-21 00:35:54 -07:00 committed by jax authors
parent c7d6b653ce
commit 7953e6d0f8
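
A minimal sketch (not part of this commit) of the `lax.conv_general_dilated` shape rules that the new test parameterization in the diff below relies on, using the default NCHW/OIHW layout; all concrete sizes and variable names here are illustrative.

# Sketch only: how the group counts constrain conv shapes.
import jax.numpy as jnp
from jax import lax

batch, num_input_channels, iw, ih = 6, 6, 10, 20
feature_group_count, batch_group_count = 3, 1
num_output_channels = 3  # must be divisible by both group counts

lhs = jnp.zeros((batch, num_input_channels, iw, ih))
# The kernel's input-channel dim is num_input_channels / feature_group_count.
rhs = jnp.ones(
    (num_output_channels, num_input_channels // feature_group_count, 3, 3)
)

out = lax.conv_general_dilated(
    lhs=lhs,
    rhs=rhs,
    window_strides=(1, 1),
    padding=((0, 0), (0, 0)),
    feature_group_count=feature_group_count,
    batch_group_count=batch_group_count,
)
# Output batch shrinks by batch_group_count; channel dim is num_output_channels.
assert out.shape == (batch // batch_group_count, num_output_channels, 8, 18)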


@@ -572,40 +572,63 @@ class RooflineTest(jtu.JaxTestCase):
         result.unfused_hbm_bytes, self._bytes_per_word * (3 * 7 + 7 * 5 + 3 * 5)
     )
-  def get_conv_output_dim(self, i, k, pad_low, pad_high, stride):
+  def get_conv_output_dim(self, i, k, pad_low, pad_high, stride) -> int:
     return jnp.floor((i - k + pad_low + pad_high) / stride) + 1
-  @jtu.parameterized.named_parameters(
-      dict(
-          testcase_name="simple",
-          window_strides=(1, 1),
-          padding=((0, 0), (0, 0)),
-      ),
-      dict(
-          testcase_name="padding",
-          window_strides=(1, 1),
-          padding=((1, 2), (3, 4)),
-      ),
-      dict(
-          testcase_name="window_strides",
-          window_strides=(2, 2),
-          padding=((0, 0), (0, 0)),
-      ),
-      dict(
-          testcase_name="window_strides_and_padding",
-          window_strides=(3, 3),
-          padding=((1, 2), (3, 4)),
-      ),
+  def get_conv_num_output_channels(
+      self, batch_group_count: int, feature_group_count: int
+  ) -> int:
+    if batch_group_count > 1:
+      return batch_group_count
+    elif feature_group_count > 1:
+      return feature_group_count
+    else:
+      return 1
+  @jtu.parameterized.product(
+      window_strides=[(1, 1), (2, 2)],
+      padding=[((0, 0), (0, 0)), ((1, 2), (3, 4))],
+      # batch must be divisible by batch_group_count, so we only include
+      # multiples of batch_group_count.
+      batch=[6, 12],
+      batch_group_count=[1, 3],
+      # num_input_channels must be divisible by feature_group_count, so we
+      # only include multiples of feature_group_count.
+      num_input_channels=[6, 12],
+      feature_group_count=[1, 3],
   )
   def test_conv_general_dilated_unfused_hbm_bytes(
-      self, window_strides: Sequence[int, int], padding: Sequence[int, int]
+      self,
+      window_strides: Sequence[int, int],
+      padding: Sequence[int, int],
+      batch: int,
+      batch_group_count: int,
+      num_input_channels: int,
+      feature_group_count: int,
   ):
+    if batch_group_count > 1 and feature_group_count > 1:
+      self.skipTest(
+          "batch_group_count and feature_group_count cannot both be > 1"
+      )
+    num_output_channels = self.get_conv_num_output_channels(
+        batch_group_count, feature_group_count
+    )
+    num_input_features = int(num_input_channels / feature_group_count)
     iw, ih = 100, 200
     kw, kh = 7, 7
-    input_data = jnp.zeros((1, 1, iw, ih), dtype=int)
-    kernel_data = jnp.ones((1, 1, kw, kh), dtype=int)
+    input_data = jnp.zeros((batch, num_input_channels, iw, ih), dtype=int)
+    kernel_data = jnp.ones(
+        (num_output_channels, num_input_features, kw, kh), dtype=int
+    )
     conv = lambda a, b: lax.conv_general_dilated(
-        lhs=a, rhs=b, window_strides=window_strides, padding=padding
+        lhs=a,
+        rhs=b,
+        window_strides=window_strides,
+        padding=padding,
+        batch_group_count=batch_group_count,
+        feature_group_count=feature_group_count,
     )
     _, result = roofline.roofline(
@@ -615,8 +638,8 @@ class RooflineTest(jtu.JaxTestCase):
         out_specs=P(),
     )(input_data, kernel_data)
-    expected_input_size = 1 * 1 * iw * ih
-    expected_kernel_size = 1 * 1 * kw * kh
+    expected_input_size = batch * num_input_channels * iw * ih
+    expected_kernel_size = num_output_channels * num_input_features * kw * kh
     ow = self.get_conv_output_dim(
         iw, kw, padding[0][0], padding[0][1], window_strides[0]
@@ -624,7 +647,10 @@ class RooflineTest(jtu.JaxTestCase):
     oh = self.get_conv_output_dim(
         ih, kh, padding[1][0], padding[1][1], window_strides[1]
     )
-    expected_output_size = 1 * 1 * ow * oh
+    expected_output_shape = jnp.array(
+        (batch / batch_group_count, num_output_channels, ow, oh)
+    )
+    expected_output_size = jnp.prod((expected_output_shape))
     # Bytes accessed is sum of inputs and output.
     expected_unfused_hbm_bytes = self._bytes_per_word * (
         expected_input_size + expected_kernel_size + expected_output_size
@@ -642,7 +668,9 @@ class RooflineTest(jtu.JaxTestCase):
           padding="SAME_LOWER",
       ),
   )
-  def test_conv_general_dilated_padding_string_unfused_hbm_bytes(self, padding: str):
+  def test_conv_general_dilated_padding_string_unfused_hbm_bytes(
+      self, padding: str
+  ):
     input_data = jnp.zeros((1, 1, 10, 20), dtype=int)
     kernel_data = jnp.ones((1, 1, 3, 3), dtype=int)
     conv = lambda a, b: lax.conv_general_dilated(