mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add tests for varying {batch, feature}_group_count
s for roofline conv
.
We'll need to use batch/feature when calculating flops, so it'll help reduce the size of the "calculating-flops" change if we can include them in our tests now. PiperOrigin-RevId: 739081930
This commit is contained in:
parent
c7d6b653ce
commit
7953e6d0f8
@ -572,40 +572,63 @@ class RooflineTest(jtu.JaxTestCase):
|
||||
result.unfused_hbm_bytes, self._bytes_per_word * (3 * 7 + 7 * 5 + 3 * 5)
|
||||
)
|
||||
|
||||
def get_conv_output_dim(self, i, k, pad_low, pad_high, stride):
|
||||
def get_conv_output_dim(self, i, k, pad_low, pad_high, stride) -> int:
|
||||
return jnp.floor((i - k + pad_low + pad_high) / stride) + 1
|
||||
|
||||
@jtu.parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name="simple",
|
||||
window_strides=(1, 1),
|
||||
padding=((0, 0), (0, 0)),
|
||||
),
|
||||
dict(
|
||||
testcase_name="padding",
|
||||
window_strides=(1, 1),
|
||||
padding=((1, 2), (3, 4)),
|
||||
),
|
||||
dict(
|
||||
testcase_name="window_strides",
|
||||
window_strides=(2, 2),
|
||||
padding=((0, 0), (0, 0)),
|
||||
),
|
||||
dict(
|
||||
testcase_name="window_strides_and_padding",
|
||||
window_strides=(3, 3),
|
||||
padding=((1, 2), (3, 4)),
|
||||
),
|
||||
def get_conv_num_output_channels(
|
||||
self, batch_group_count: int, feature_group_count: int
|
||||
) -> int:
|
||||
if batch_group_count > 1:
|
||||
return batch_group_count
|
||||
elif feature_group_count > 1:
|
||||
return feature_group_count
|
||||
else:
|
||||
return 1
|
||||
|
||||
@jtu.parameterized.product(
|
||||
window_strides=[(1, 1), (2, 2)],
|
||||
padding=[((0, 0), (0, 0)), ((1, 2), (3, 4))],
|
||||
# batch must be divisible by batch_group_count, so we only include factors
|
||||
# of batch_group_count.
|
||||
batch=[6, 12],
|
||||
batch_group_count=[1, 3],
|
||||
# num_input_channels must be divisible by feature_group_count, so we only
|
||||
# include factors of feature_group_count.
|
||||
num_input_channels=[6, 12],
|
||||
feature_group_count=[1, 3],
|
||||
)
|
||||
def test_conv_general_dilated_unfused_hbm_bytes(
|
||||
self, window_strides: Sequence[int, int], padding: Sequence[int, int]
|
||||
self,
|
||||
window_strides: Sequence[int, int],
|
||||
padding: Sequence[int, int],
|
||||
batch: int,
|
||||
batch_group_count: int,
|
||||
num_input_channels: int,
|
||||
feature_group_count: int,
|
||||
):
|
||||
if batch_group_count > 1 and feature_group_count > 1:
|
||||
self.skipTest(
|
||||
"batch_group_count and feature_group_count cannot both be > 1"
|
||||
)
|
||||
|
||||
num_output_channels = self.get_conv_num_output_channels(
|
||||
batch_group_count, feature_group_count
|
||||
)
|
||||
|
||||
num_input_features = int(num_input_channels / feature_group_count)
|
||||
iw, ih = 100, 200
|
||||
kw, kh = 7, 7
|
||||
input_data = jnp.zeros((1, 1, iw, ih), dtype=int)
|
||||
kernel_data = jnp.ones((1, 1, kw, kh), dtype=int)
|
||||
input_data = jnp.zeros((batch, num_input_channels, iw, ih), dtype=int)
|
||||
kernel_data = jnp.ones(
|
||||
(num_output_channels, num_input_features, kw, kh), dtype=int
|
||||
)
|
||||
conv = lambda a, b: lax.conv_general_dilated(
|
||||
lhs=a, rhs=b, window_strides=window_strides, padding=padding
|
||||
lhs=a,
|
||||
rhs=b,
|
||||
window_strides=window_strides,
|
||||
padding=padding,
|
||||
batch_group_count=batch_group_count,
|
||||
feature_group_count=feature_group_count,
|
||||
)
|
||||
|
||||
_, result = roofline.roofline(
|
||||
@ -615,8 +638,8 @@ class RooflineTest(jtu.JaxTestCase):
|
||||
out_specs=P(),
|
||||
)(input_data, kernel_data)
|
||||
|
||||
expected_input_size = 1 * 1 * iw * ih
|
||||
expected_kernel_size = 1 * 1 * kw * kh
|
||||
expected_input_size = batch * num_input_channels * iw * ih
|
||||
expected_kernel_size = num_output_channels * num_input_features * kw * kh
|
||||
|
||||
ow = self.get_conv_output_dim(
|
||||
iw, kw, padding[0][0], padding[0][1], window_strides[0]
|
||||
@ -624,7 +647,10 @@ class RooflineTest(jtu.JaxTestCase):
|
||||
oh = self.get_conv_output_dim(
|
||||
ih, kh, padding[1][0], padding[1][1], window_strides[1]
|
||||
)
|
||||
expected_output_size = 1 * 1 * ow * oh
|
||||
expected_output_shape = jnp.array(
|
||||
(batch / batch_group_count, num_output_channels, ow, oh)
|
||||
)
|
||||
expected_output_size = jnp.prod((expected_output_shape))
|
||||
# Bytes accessed is sum of inputs and output.
|
||||
expected_unfused_hbm_bytes = self._bytes_per_word * (
|
||||
expected_input_size + expected_kernel_size + expected_output_size
|
||||
@ -642,7 +668,9 @@ class RooflineTest(jtu.JaxTestCase):
|
||||
padding="SAME_LOWER",
|
||||
),
|
||||
)
|
||||
def test_conv_general_dilated_padding_string_unfused_hbm_bytes(self, padding: str):
|
||||
def test_conv_general_dilated_padding_string_unfused_hbm_bytes(
|
||||
self, padding: str
|
||||
):
|
||||
input_data = jnp.zeros((1, 1, 10, 20), dtype=int)
|
||||
kernel_data = jnp.ones((1, 1, 3, 3), dtype=int)
|
||||
conv = lambda a, b: lax.conv_general_dilated(
|
||||
|
Loading…
x
Reference in New Issue
Block a user