mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Support convolution in roofline.
So far we support only `unfused_hmb_bytes` and don't account for `{feature, batch}_group_count`s due to complexity. PiperOrigin-RevId: 736948528
This commit is contained in:
parent
88d4bc3d45
commit
0c8e601f90
@ -177,6 +177,23 @@ def _dot_general_roofline(
|
||||
unfused_hbm_bytes=hbm_bytes,
|
||||
)
|
||||
|
||||
@roofline.register_roofline(convolution.conv_general_dilated_p)
|
||||
def _conv_general_dilated_roofline(
|
||||
ctx: roofline.RooflineRuleContext,
|
||||
*args,
|
||||
**kw,
|
||||
) -> roofline.RooflineResult:
|
||||
lhs, rhs = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
|
||||
out = roofline.RooflineShape.from_aval(ctx.avals_out[0])
|
||||
# TODO(b/394648206): support computing unfused_flops for conv.
|
||||
return roofline.RooflineResult(
|
||||
unfused_hbm_bytes=(
|
||||
lhs.dtype.itemsize * lhs.size
|
||||
+ rhs.dtype.itemsize * rhs.size
|
||||
+ out.dtype.itemsize * out.size
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _return_zeros_if_one_sized_axis(
|
||||
ctx: roofline.RooflineRuleContext, axes: tuple[str, ...]
|
||||
|
@ -14,6 +14,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import Sequence
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
@ -571,6 +572,130 @@ class RooflineTest(jtu.JaxTestCase):
|
||||
result.unfused_hbm_bytes, self._bytes_per_word * (3 * 7 + 7 * 5 + 3 * 5)
|
||||
)
|
||||
|
||||
def get_conv_output_dim(self, i, k, pad_low, pad_high, stride):
|
||||
return jnp.floor((i - k + pad_low + pad_high) / stride) + 1
|
||||
|
||||
@jtu.parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name="simple",
|
||||
window_strides=(1, 1),
|
||||
padding=((0, 0), (0, 0)),
|
||||
),
|
||||
dict(
|
||||
testcase_name="padding",
|
||||
window_strides=(1, 1),
|
||||
padding=((1, 2), (3, 4)),
|
||||
),
|
||||
dict(
|
||||
testcase_name="window_strides",
|
||||
window_strides=(2, 2),
|
||||
padding=((0, 0), (0, 0)),
|
||||
),
|
||||
dict(
|
||||
testcase_name="window_strides_and_padding",
|
||||
window_strides=(3, 3),
|
||||
padding=((1, 2), (3, 4)),
|
||||
),
|
||||
)
|
||||
def test_conv_general_dilated_unfused_hbm_bytes(
|
||||
self, window_strides: Sequence[int, int], padding: Sequence[int, int]
|
||||
):
|
||||
iw, ih = 100, 200
|
||||
kw, kh = 7, 7
|
||||
input_data = jnp.zeros((1, 1, iw, ih), dtype=int)
|
||||
kernel_data = jnp.ones((1, 1, kw, kh), dtype=int)
|
||||
conv = lambda a, b: lax.conv_general_dilated(
|
||||
lhs=a, rhs=b, window_strides=window_strides, padding=padding
|
||||
)
|
||||
|
||||
_, result = roofline.roofline(
|
||||
conv,
|
||||
mesh=mesh.AbstractMesh((), ()),
|
||||
in_specs=(P(), P()),
|
||||
out_specs=P(),
|
||||
)(input_data, kernel_data)
|
||||
|
||||
expected_input_size = 1 * 1 * iw * ih
|
||||
expected_kernel_size = 1 * 1 * kw * kh
|
||||
|
||||
ow = self.get_conv_output_dim(
|
||||
iw, kw, padding[0][0], padding[0][1], window_strides[0]
|
||||
)
|
||||
oh = self.get_conv_output_dim(
|
||||
ih, kh, padding[1][0], padding[1][1], window_strides[1]
|
||||
)
|
||||
expected_output_size = 1 * 1 * ow * oh
|
||||
# Bytes accessed is sum of inputs and output.
|
||||
expected_unfused_hbm_bytes = self._bytes_per_word * (
|
||||
expected_input_size + expected_kernel_size + expected_output_size
|
||||
)
|
||||
# TODO(b/394648206): add subtest for unfused_flops once they are supported.
|
||||
self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes)
|
||||
|
||||
@jtu.parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name="same",
|
||||
padding="SAME",
|
||||
),
|
||||
dict(
|
||||
testcase_name="same_lower",
|
||||
padding="SAME_LOWER",
|
||||
),
|
||||
)
|
||||
def test_conv_general_dilated_padding_string_unfused_hbm_bytes(self, padding: str):
|
||||
input_data = jnp.zeros((1, 1, 10, 20), dtype=int)
|
||||
kernel_data = jnp.ones((1, 1, 3, 3), dtype=int)
|
||||
conv = lambda a, b: lax.conv_general_dilated(
|
||||
lhs=a, rhs=b, window_strides=(1, 1), padding=padding
|
||||
)
|
||||
|
||||
_, result = roofline.roofline(
|
||||
conv,
|
||||
mesh=mesh.AbstractMesh((), ()),
|
||||
in_specs=(P(), P()),
|
||||
out_specs=P(),
|
||||
)(input_data, kernel_data)
|
||||
|
||||
expected_input_size = 1 * 1 * 10 * 20
|
||||
expected_kernel_size = 1 * 1 * 3 * 3
|
||||
# Because of same{_lower} padding, output shape should equal to input shape.
|
||||
# This may not be true for other `{feature, batch}`_group_count`s.c
|
||||
expected_output_size = expected_input_size
|
||||
# Bytes accessed is sum of inputs and output.
|
||||
expected_unfused_hbm_bytes = self._bytes_per_word * (
|
||||
expected_input_size + expected_kernel_size + expected_output_size
|
||||
)
|
||||
self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes)
|
||||
|
||||
def test_conv_general_dilated_padding_string_valid_unfused_hbm_bytes(self):
|
||||
input_data = jnp.zeros((1, 1, 10, 20), dtype=int)
|
||||
kernel_data = jnp.ones((1, 1, 3, 3), dtype=int)
|
||||
conv = lambda a, b: lax.conv_general_dilated(
|
||||
lhs=a, rhs=b, window_strides=(1, 1), padding="VALID"
|
||||
)
|
||||
|
||||
_, result = roofline.roofline(
|
||||
conv,
|
||||
mesh=mesh.AbstractMesh((), ()),
|
||||
in_specs=(P(), P()),
|
||||
out_specs=P(),
|
||||
)(input_data, kernel_data)
|
||||
|
||||
expected_input_size = 1 * 1 * 10 * 20
|
||||
expected_kernel_size = 1 * 1 * 3 * 3
|
||||
# Valid padding is same as 0 padding.
|
||||
expected_output_size = (
|
||||
1
|
||||
* 1
|
||||
* self.get_conv_output_dim(10, 3, 0, 0, 1)
|
||||
* self.get_conv_output_dim(20, 3, 0, 0, 1)
|
||||
)
|
||||
# Bytes accessed is sum of inputs and output.
|
||||
expected_unfused_hbm_bytes = self._bytes_per_word * (
|
||||
expected_input_size + expected_kernel_size + expected_output_size
|
||||
)
|
||||
self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes)
|
||||
|
||||
def test_reduce_sum_no_axis(self):
|
||||
_, result = roofline.roofline(
|
||||
lambda x: jnp.sum(x),
|
||||
|
Loading…
x
Reference in New Issue
Block a user