mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00

The functionality comes from the jax.experimental.export module, which will be deprecated. The following APIs are introduced: ``` from jax import export def f(...): ... ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs) blob: bytearray = ex.serialize() rehydrated: export.Export = export.deserialize(blob) def caller(...): ... rehydrated.call(*args, **kwargs) ``` Module documentation will follow shortly. There are no changes for now in the jax.experimental.export APIs. Most of the changes in this PR are in tests due to some differences in the new jax.export APIs compared to jax.experimental.export: * Instead of `jax.experimental.export.call(exp)` we now write `exp.call` * The `jax.experimental.export.export` allowed the function argument to be any Python callable and it would wrap it with a `jax.jit`. This is not supported anymore by export, and instead the user must use `jax.jit`.
87 lines
2.4 KiB
Python
87 lines
2.4 KiB
Python
# Copyright 2024 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Microbenchmarks for JAX shape polymorphism symbolic expressions."""
|
|
|
|
import google_benchmark as benchmark
|
|
|
|
import jax
|
|
from jax import core
|
|
from jax._src.numpy import lax_numpy
|
|
from jax import export
|
|
|
|
jax.config.parse_flags_with_absl()
|
|
|
|
|
|
@benchmark.register
|
|
def parse(state):
|
|
while state:
|
|
export.symbolic_shape("a, b, max(a, b), min(max(a, b), b), "
|
|
"floordiv(a, 2), mod(b, floordiv(a, 2))")
|
|
|
|
@benchmark.register
|
|
def builder_arith(state):
|
|
a, b, c = export.symbolic_shape("a, b, c")
|
|
while state:
|
|
for _ in range(1000):
|
|
e1 = (a + b // a + c % a + 3)
|
|
e2 = (b // a - a - c % a + 4)
|
|
_ = e1 + e2 + (e1 * e2)
|
|
|
|
@benchmark.register
|
|
def builder_linear_arith(state):
|
|
a, b, c = export.symbolic_shape("a, b, c")
|
|
while state:
|
|
left = [a, 3*a, a + 2*b, a + 3*b + 4*c]
|
|
right = [b, -1*a, a - 2*b, a - 3*b - 4*c]
|
|
for l in left:
|
|
for r in right:
|
|
for l_k in [1, 2, -2]:
|
|
for r_k in [1, 2, -2]:
|
|
comb = l * l_k + r * r_k
|
|
if not isinstance(comb, int):
|
|
_ = comb.leading_term # Ensure we actually materialize
|
|
|
|
|
|
@benchmark.register
|
|
def builder_min_max(state):
|
|
a, b = export.symbolic_shape("a, b")
|
|
while state:
|
|
for _ in range(100):
|
|
a.scope._clear_caches()
|
|
_ = core.max_dim(a, b) + core.min_dim(a, a + b)
|
|
|
|
@benchmark.register
|
|
def load_constraints(state):
|
|
while state:
|
|
export.symbolic_shape(
|
|
"a, b, c",
|
|
constraints=["a >= c",
|
|
"max(max(a, b), 2) >= max(a, b)"])
|
|
|
|
@benchmark.register
|
|
def inequalities_slice(state):
|
|
|
|
a, b = export.symbolic_shape("a, b")
|
|
while state:
|
|
for _ in range(30):
|
|
a.scope._clear_caches()
|
|
start, _, slice_size = lax_numpy._preprocess_slice(slice(2, a, 4), b)
|
|
_ = 0 <= slice_size <= b
|
|
_ = start >= 0
|
|
_ = start + slice_size <= b
|
|
|
|
|
|
if __name__ == "__main__":
|
|
benchmark.main()
|