diff --git a/jax/experimental/roofline/rooflines.py b/jax/experimental/roofline/rooflines.py
index 74edcb9cd..1edd1e064 100644
--- a/jax/experimental/roofline/rooflines.py
+++ b/jax/experimental/roofline/rooflines.py
@@ -177,6 +177,28 @@ def _dot_general_roofline(
       unfused_hbm_bytes=hbm_bytes,
   )
 
+@roofline.register_roofline(convolution.conv_general_dilated_p)
+def _conv_general_dilated_roofline(
+    ctx: roofline.RooflineRuleContext,
+    *args,
+    **kw,
+) -> roofline.RooflineResult:
+  """Roofline rule for `conv_general_dilated` reporting HBM bytes only.
+
+  The unfused HBM byte count is the sum of `dtype.itemsize * size` over the
+  two input operands and the first output. FLOPs are not computed yet.
+  """
+  lhs, rhs = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
+  out = roofline.RooflineShape.from_aval(ctx.avals_out[0])
+  # TODO(b/394648206): support computing unfused_flops for conv.
+  return roofline.RooflineResult(
+      unfused_hbm_bytes=(
+          lhs.dtype.itemsize * lhs.size
+          + rhs.dtype.itemsize * rhs.size
+          + out.dtype.itemsize * out.size
+      ),
+  )
+
 
 def _return_zeros_if_one_sized_axis(
     ctx: roofline.RooflineRuleContext, axes: tuple[str, ...]
diff --git a/tests/roofline_test.py b/tests/roofline_test.py
index a8c003321..564b4a9a1 100644
--- a/tests/roofline_test.py
+++ b/tests/roofline_test.py
@@ -14,6 +14,7 @@
 from __future__ import annotations
 
 from functools import partial
+from typing import Sequence
 
 from absl.testing import absltest
 import jax
@@ -571,6 +572,138 @@ class RooflineTest(jtu.JaxTestCase):
         result.unfused_hbm_bytes, self._bytes_per_word * (3 * 7 + 7 * 5 + 3 * 5)
     )
 
+  def get_conv_output_dim(self, i, k, pad_low, pad_high, stride):
+    """Returns `floor((i - k + pad_low + pad_high) / stride) + 1`.
+
+    Callers pass the input size as `i` and the kernel size as `k`. Integer
+    floor division keeps the result a Python int for int arguments.
+    """
+    return (i - k + pad_low + pad_high) // stride + 1
+
+  @jtu.parameterized.named_parameters(
+      dict(
+          testcase_name="simple",
+          window_strides=(1, 1),
+          padding=((0, 0), (0, 0)),
+      ),
+      dict(
+          testcase_name="padding",
+          window_strides=(1, 1),
+          padding=((1, 2), (3, 4)),
+      ),
+      dict(
+          testcase_name="window_strides",
+          window_strides=(2, 2),
+          padding=((0, 0), (0, 0)),
+      ),
+      dict(
+          testcase_name="window_strides_and_padding",
+          window_strides=(3, 3),
+          padding=((1, 2), (3, 4)),
+      ),
+  )
+  def test_conv_general_dilated_unfused_hbm_bytes(
+      self, window_strides: Sequence[int], padding: Sequence[tuple[int, int]]
+  ):
+    """Checks HBM bytes for explicit (low, high) padding and window strides."""
+    iw, ih = 100, 200
+    kw, kh = 7, 7
+    input_data = jnp.zeros((1, 1, iw, ih), dtype=int)
+    kernel_data = jnp.ones((1, 1, kw, kh), dtype=int)
+    conv = lambda a, b: lax.conv_general_dilated(
+        lhs=a, rhs=b, window_strides=window_strides, padding=padding
+    )
+
+    _, result = roofline.roofline(
+        conv,
+        mesh=mesh.AbstractMesh((), ()),
+        in_specs=(P(), P()),
+        out_specs=P(),
+    )(input_data, kernel_data)
+
+    expected_input_size = 1 * 1 * iw * ih
+    expected_kernel_size = 1 * 1 * kw * kh
+
+    ow = self.get_conv_output_dim(
+        iw, kw, padding[0][0], padding[0][1], window_strides[0]
+    )
+    oh = self.get_conv_output_dim(
+        ih, kh, padding[1][0], padding[1][1], window_strides[1]
+    )
+    expected_output_size = 1 * 1 * ow * oh
+    # Bytes accessed is sum of inputs and output.
+    expected_unfused_hbm_bytes = self._bytes_per_word * (
+        expected_input_size + expected_kernel_size + expected_output_size
+    )
+    # TODO(b/394648206): add subtest for unfused_flops once they are supported.
+    self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes)
+
+  @jtu.parameterized.named_parameters(
+      dict(
+          testcase_name="same",
+          padding="SAME",
+      ),
+      dict(
+          testcase_name="same_lower",
+          padding="SAME_LOWER",
+      ),
+  )
+  def test_conv_general_dilated_padding_string_unfused_hbm_bytes(self, padding: str):
+    """Checks HBM bytes for "SAME" and "SAME_LOWER" string padding."""
+    input_data = jnp.zeros((1, 1, 10, 20), dtype=int)
+    kernel_data = jnp.ones((1, 1, 3, 3), dtype=int)
+    conv = lambda a, b: lax.conv_general_dilated(
+        lhs=a, rhs=b, window_strides=(1, 1), padding=padding
+    )
+
+    _, result = roofline.roofline(
+        conv,
+        mesh=mesh.AbstractMesh((), ()),
+        in_specs=(P(), P()),
+        out_specs=P(),
+    )(input_data, kernel_data)
+
+    expected_input_size = 1 * 1 * 10 * 20
+    expected_kernel_size = 1 * 1 * 3 * 3
+    # With same{_lower} padding the output shape should equal the input shape.
+    # This may not be true for other `{feature, batch}_group_count`s.
+    expected_output_size = expected_input_size
+    # Bytes accessed is sum of inputs and output.
+    expected_unfused_hbm_bytes = self._bytes_per_word * (
+        expected_input_size + expected_kernel_size + expected_output_size
+    )
+    self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes)
+
+  def test_conv_general_dilated_padding_string_valid_unfused_hbm_bytes(self):
+    """Checks HBM bytes for "VALID" string padding."""
+    input_data = jnp.zeros((1, 1, 10, 20), dtype=int)
+    kernel_data = jnp.ones((1, 1, 3, 3), dtype=int)
+    conv = lambda a, b: lax.conv_general_dilated(
+        lhs=a, rhs=b, window_strides=(1, 1), padding="VALID"
+    )
+
+    _, result = roofline.roofline(
+        conv,
+        mesh=mesh.AbstractMesh((), ()),
+        in_specs=(P(), P()),
+        out_specs=P(),
+    )(input_data, kernel_data)
+
+    expected_input_size = 1 * 1 * 10 * 20
+    expected_kernel_size = 1 * 1 * 3 * 3
+    # Valid padding is same as 0 padding.
+    expected_output_size = (
+        1
+        * 1
+        * self.get_conv_output_dim(10, 3, 0, 0, 1)
+        * self.get_conv_output_dim(20, 3, 0, 0, 1)
+    )
+    # Bytes accessed is sum of inputs and output.
+    expected_unfused_hbm_bytes = self._bytes_per_word * (
+        expected_input_size + expected_kernel_size + expected_output_size
+    )
+    self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes)
+
   def test_reduce_sum_no_axis(self):
     _, result = roofline.roofline(
         lambda x: jnp.sum(x),