[shape_poly] Add another micro-benchmark

A very common operation is to compute linear combinations of expressions.

PiperOrigin-RevId: 605644781
This commit is contained in:
George Necula 2024-02-09 09:02:06 -08:00 committed by jax authors
parent 07d793d7b2
commit 1e7642279b

View File

@ -38,6 +38,21 @@ def builder_arith(state):
e2 = (b // a - a - c % a + 4)
_ = e1 + e2 + (e1 * e2)
@benchmark.register
def builder_linear_arith(state):
a, b, c = export.symbolic_shape("a, b, c")
while state:
left = [a, 3*a, a + 2*b, a + 3*b + 4*c]
right = [b, -1*a, a - 2*b, a - 3*b - 4*c]
for l in left:
for r in right:
for l_k in [1, 2, -2]:
for r_k in [1, 2, -2]:
comb = l * l_k + r * r_k
if not isinstance(comb, int):
_ = comb.leading_term # Ensure we actually materialize
@benchmark.register
def builder_min_max(state):
a, b = export.symbolic_shape("a, b")