Mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 10:56:06 +00:00) — 86 lines, 2.4 KiB, Python.
# Copyright 2024 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Microbenchmarks for JAX shape polymorphism symbolic expressions."""
|
|
|
|
import google_benchmark as benchmark
|
|
|
|
import jax
|
|
from jax import core
|
|
from jax import export
|
|
|
|
# Parse JAX configuration flags through absl at import time (presumably so
# JAX options can be set on the command line used to launch these benchmarks).
jax.config.parse_flags_with_absl()
|
|
|
|
|
|
@benchmark.register
def parse(state):
  """Benchmarks parsing a symbolic shape spec with min/max/floordiv/mod terms.

  The spec string is fixed, so each timed iteration measures only the call to
  `export.symbolic_shape` on it; the returned value is discarded.
  """
  spec = ("a, b, max(a, b), min(max(a, b), b), "
          "floordiv(a, 2), mod(b, floordiv(a, 2))")
  while state:
    export.symbolic_shape(spec)
|
|
|
|
@benchmark.register
def builder_arith(state):
  """Benchmarks building arithmetic expressions over symbolic dimensions.

  Each timed iteration builds 1000 rounds of expressions that combine the
  symbolic dims `a`, `b`, `c` with `+`, `-`, `*`, `//` and `%`.
  """
  a, b, c = export.symbolic_shape("a, b, c")
  while state:
    for _ in range(1000):
      # `b // a` and `c % a` are deliberately recomputed in both expressions
      # so that every round performs the same amount of building work.
      lhs = a + b // a + c % a + 3
      rhs = b // a - a - c % a + 4
      _ = lhs + rhs + lhs * rhs
|
|
|
|
@benchmark.register
def builder_linear_arith(state):
  """Benchmarks building scaled linear combinations of symbolic expressions.

  For every (lhs, rhs) pair from `left` x `right` and every pair of integer
  scales, builds `lhs * lhs_scale + rhs * rhs_scale`. If the result is not a
  plain `int`, its `leading_term` is read so the expression is inspected.
  """
  a, b, c = export.symbolic_shape("a, b, c")
  while state:
    # The operand lists are rebuilt inside the timed loop on purpose; their
    # construction is part of the measured work.
    left = [a, 3*a, a + 2*b, a + 3*b + 4*c]
    right = [b, -1*a, a - 2*b, a - 3*b - 4*c]
    for lhs in left:
      for rhs in right:
        for lhs_scale in [1, 2, -2]:
          for rhs_scale in [1, 2, -2]:
            combo = lhs * lhs_scale + rhs * rhs_scale
            if isinstance(combo, int):
              # The combination can collapse to a plain int (e.g. when the
              # terms cancel), which has no `leading_term` to read.
              continue
            _ = combo.leading_term  # Ensure we actually materialize
|
|
|
|
|
|
@benchmark.register
def builder_min_max(state):
  """Benchmarks `core.max_dim` / `core.min_dim` on symbolic dimensions.

  The symbolic scope's caches are cleared before every step, presumably so
  each step measures the uncached cost of the min/max construction.
  """
  a, b = export.symbolic_shape("a, b")
  while state:
    for _ in range(100):
      a.scope._clear_caches()
      upper = core.max_dim(a, b)
      lower = core.min_dim(a, a + b)
      _ = upper + lower
|
|
|
|
@benchmark.register
def load_constraints(state):
  """Benchmarks creating symbolic dims `a, b, c` together with constraints.

  The constraint list is rebuilt on every timed iteration, and the result of
  `export.symbolic_shape` is discarded.
  """
  while state:
    constraints = ["a >= c",
                   "max(max(a, b), 2) >= max(a, b)"]
    export.symbolic_shape("a, b, c", constraints=constraints)
|
|
|
|
@benchmark.register
def inequalities_slice(state):
  """Benchmarks symbolic inequality checks on a canonicalized slice.

  Canonicalizes `slice(2, a, 4)` against an axis of symbolic size `b`, then
  evaluates comparisons on the resulting start and size. The scope's caches
  are cleared before every step.
  """
  a, b = export.symbolic_shape("a, b")
  while state:
    for _ in range(30):
      a.scope._clear_caches()
      # The middle element of the canonicalized result is not needed here.
      begin, _, size = core.canonicalize_slice(slice(2, a, 4), b)
      _ = 0 <= size <= b
      _ = begin >= 0
      _ = begin + size <= b
|
|
|
|
|
|
# When run as a script, hand control to google_benchmark's entry point, which
# presumably runs the functions registered above via `@benchmark.register`.
if __name__ == "__main__":
  benchmark.main()
|