mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[shape_poly] Add another micro-benchmark
A very common operation is to compute linear combinations of expressions. PiperOrigin-RevId: 605644781
This commit is contained in:
parent
07d793d7b2
commit
1e7642279b
@ -38,6 +38,21 @@ def builder_arith(state):
|
||||
e2 = (b // a - a - c % a + 4)
|
||||
_ = e1 + e2 + (e1 * e2)
|
||||
|
||||
@benchmark.register
|
||||
def builder_linear_arith(state):
|
||||
a, b, c = export.symbolic_shape("a, b, c")
|
||||
while state:
|
||||
left = [a, 3*a, a + 2*b, a + 3*b + 4*c]
|
||||
right = [b, -1*a, a - 2*b, a - 3*b - 4*c]
|
||||
for l in left:
|
||||
for r in right:
|
||||
for l_k in [1, 2, -2]:
|
||||
for r_k in [1, 2, -2]:
|
||||
comb = l * l_k + r * r_k
|
||||
if not isinstance(comb, int):
|
||||
_ = comb.leading_term # Ensure we actually materialize
|
||||
|
||||
|
||||
@benchmark.register
|
||||
def builder_min_max(state):
|
||||
a, b = export.symbolic_shape("a, b")
|
||||
|
Loading…
x
Reference in New Issue
Block a user