mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove pytest_benchmark for test-requirements.txt and move the benchmark file which was using that package to use google_benchmark.
PiperOrigin-RevId: 483736267
This commit is contained in:
parent
548d7f4599
commit
cf6b5097d0
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
"""Microbenchmarks for JAX `api` functions."""
|
||||
|
||||
import enum
|
||||
import functools
|
||||
import operator
|
||||
|
||||
@ -25,6 +26,7 @@ from jax.experimental import sparse
|
||||
from jax._src.api_util import shaped_abstractify # technically not an api fn
|
||||
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import pxla
|
||||
from jax._src import array
|
||||
from jax._src import sharding
|
||||
@ -47,10 +49,15 @@ def required_devices(num_devices_required):
|
||||
return helper2
|
||||
return helper1
|
||||
|
||||
|
||||
def swap(a, b):
|
||||
return b, a
|
||||
|
||||
|
||||
class AnEnum(enum.IntEnum):
|
||||
A = 123
|
||||
B = 456
|
||||
|
||||
@google_benchmark.register
|
||||
def eager_unary_dispatch(state):
|
||||
a = jax.device_put(1)
|
||||
@ -555,6 +562,30 @@ def bench_shaped_abstractify(state):
|
||||
_ = [shaped_abstractify(x) for x in args]
|
||||
|
||||
|
||||
def _run_benchmark_for_xla_abstractify(arg, state):
|
||||
while state:
|
||||
xla.abstractify(arg)
|
||||
|
||||
def bench_xla_abstractify():
|
||||
_abstractify_args = [
|
||||
(3, 'scalar_int'),
|
||||
(3.5, 'scalar_float'),
|
||||
(np.int32(3), 'scalar_numpy_int32'),
|
||||
(np.uint32(7), 'scalar_numpy_uint32'),
|
||||
(np.random.randn(3, 4, 5, 6), 'numpy_random'),
|
||||
(np.arange(100, dtype=np.float32), 'numpy_arange_100_float32'),
|
||||
(AnEnum.B, 'enum'),
|
||||
]
|
||||
benchmarks = []
|
||||
for a, name in _abstractify_args:
|
||||
benchmarks.extend([
|
||||
google_benchmark.register(
|
||||
partial(_run_benchmark_for_xla_abstractify, a),
|
||||
name=f'bench_xla_abstractify_{name}'),
|
||||
])
|
||||
bench_xla_abstractify()
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.unit(google_benchmark.kMicrosecond)
|
||||
def bench_are_op_shardings_equal(state):
|
||||
|
@ -3,7 +3,6 @@ cloudpickle
|
||||
colorama>=0.4.4
|
||||
matplotlib
|
||||
pillow>=9.1.0
|
||||
pytest-benchmark
|
||||
pytest-xdist
|
||||
wheel
|
||||
rich
|
||||
|
@ -1,43 +0,0 @@
|
||||
# Copyright 2019 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import enum
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from jax import numpy as jnp
|
||||
from jax.interpreters import xla
|
||||
|
||||
|
||||
class AnEnum(enum.IntEnum):
|
||||
A = 123
|
||||
B = 456
|
||||
|
||||
|
||||
_abstractify_args = [
|
||||
3,
|
||||
3.5,
|
||||
np.int32(3),
|
||||
np.uint32(7),
|
||||
np.random.randn(3, 4, 5, 6),
|
||||
np.arange(100, dtype=np.float32),
|
||||
jnp.int64(-3),
|
||||
jnp.array([1, 2, 3]),
|
||||
AnEnum.B,
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("arg", _abstractify_args)
|
||||
def test_abstractify(benchmark, arg):
|
||||
benchmark(xla.abstractify, arg)
|
Loading…
x
Reference in New Issue
Block a user