rocm_jax/benchmarks/shape_poly_benchmark.py

86 lines
2.4 KiB
Python

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Microbenchmarks for JAX shape polymorphism symbolic expressions."""
import google_benchmark as benchmark
import jax
from jax import core
from jax import export
jax.config.parse_flags_with_absl()
@benchmark.register
def parse(state):
while state:
export.symbolic_shape("a, b, max(a, b), min(max(a, b), b), "
"floordiv(a, 2), mod(b, floordiv(a, 2))")
@benchmark.register
def builder_arith(state):
a, b, c = export.symbolic_shape("a, b, c")
while state:
for _ in range(1000):
e1 = (a + b // a + c % a + 3)
e2 = (b // a - a - c % a + 4)
_ = e1 + e2 + (e1 * e2)
@benchmark.register
def builder_linear_arith(state):
a, b, c = export.symbolic_shape("a, b, c")
while state:
left = [a, 3*a, a + 2*b, a + 3*b + 4*c]
right = [b, -1*a, a - 2*b, a - 3*b - 4*c]
for l in left:
for r in right:
for l_k in [1, 2, -2]:
for r_k in [1, 2, -2]:
comb = l * l_k + r * r_k
if not isinstance(comb, int):
_ = comb.leading_term # Ensure we actually materialize
@benchmark.register
def builder_min_max(state):
a, b = export.symbolic_shape("a, b")
while state:
for _ in range(100):
a.scope._clear_caches()
_ = core.max_dim(a, b) + core.min_dim(a, a + b)
@benchmark.register
def load_constraints(state):
while state:
export.symbolic_shape(
"a, b, c",
constraints=["a >= c",
"max(max(a, b), 2) >= max(a, b)"])
@benchmark.register
def inequalities_slice(state):
a, b = export.symbolic_shape("a, b")
while state:
for _ in range(30):
a.scope._clear_caches()
start, _, slice_size = core.canonicalize_slice(slice(2, a, 4), b)
_ = 0 <= slice_size <= b
_ = start >= 0
_ = start + slice_size <= b
if __name__ == "__main__":
benchmark.main()