diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 453137590..240796088 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -21,7 +21,6 @@ import operator import google_benchmark import jax from jax import lax -from jax.experimental import sparse from jax._src.api_util import shaped_abstractify # technically not an api fn from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation from jax._src.lib import xla_client as xc @@ -419,119 +418,6 @@ def sda_index_8(state): _run_sda_index_bench(state, 8) -def _sparse_bcoo_fromdense(state, jit: bool = False, compile: bool = False): - shape = (2000, 2000) - nse = 10000 - size = math.prod(shape) - rng = np.random.RandomState(1701) - data = rng.randn(nse) - indices = np.unravel_index( - rng.choice(size, size=nse, replace=False), shape=shape) - mat = jnp.zeros(shape).at[indices].set(data) - - f = sparse.BCOO.fromdense - if compile or jit: - # Note: nse must be specified for JIT. - f = jax.jit(partial(f, nse=nse)) - - if compile: - while state: - f.lower(mat).compile() - else: - f(mat).block_until_ready() - while state: - f(mat).block_until_ready() - - -@google_benchmark.register -def sparse_bcoo_fromdense(state): - return _sparse_bcoo_fromdense(state) - - -@google_benchmark.register -def sparse_bcoo_fromdense_jit(state): - return _sparse_bcoo_fromdense(state, jit=True) - - -@google_benchmark.register -def sparse_bcoo_fromdense_compile(state): - return _sparse_bcoo_fromdense(state, compile=True) - - -def _sparse_bcoo_todense(state, jit: bool = False, compile: bool = False): - shape = (2000, 2000) - nse = 10000 - size = math.prod(shape) - rng = np.random.RandomState(1701) - data = rng.randn(nse) - indices = np.unravel_index( - rng.choice(size, size=nse, replace=False), shape=shape) - mat = sparse.BCOO((jnp.array(data), jnp.column_stack(indices)), shape=shape) - - f = lambda mat: mat.todense() - if jit or compile: - f = jax.jit(f) - - if compile: - while state: - f.lower(mat).compile() - else: - f(mat).block_until_ready() - while state: - f(mat).block_until_ready() - - -@google_benchmark.register -def sparse_bcoo_todense(state): - return _sparse_bcoo_todense(state) - - -@google_benchmark.register -def sparse_bcoo_todense_jit(state): - return _sparse_bcoo_todense(state, jit=True) - - -@google_benchmark.register -def sparse_bcoo_todense_compile(state): - return _sparse_bcoo_todense(state, compile=True) - - -def _sparse_bcoo_matvec(state, jit: bool = False, compile: bool = False): - shape = (2000, 2000) - nse = 10000 - key = jax.random.PRNGKey(1701) - mat = sparse.random_bcoo(key, nse=nse, shape=shape, dtype=jnp.float32, - indices_dtype=jnp.int32, sorted_indices=True) - vec = jax.random.uniform(key, shape=(shape[1],), dtype=jnp.float32) - - f = lambda mat, vec: mat @ vec - if jit or compile: - f = jax.jit(f) - - if compile: - while state: - f.lower(mat, vec).compile() - else: - f(mat, vec).block_until_ready() - while state: - f(mat, vec).block_until_ready() - - -@google_benchmark.register -def sparse_bcoo_matvec(state): - return _sparse_bcoo_matvec(state) - - -@google_benchmark.register -def sparse_bcoo_matvec_jit(state): - return _sparse_bcoo_matvec(state, jit=True) - - -@google_benchmark.register -def sparse_bcoo_matvec_compile(state): - return _sparse_bcoo_matvec(state, compile=True) - - @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def bench_shaped_abstractify(state): diff --git a/benchmarks/sparse_benchmark.py b/benchmarks/sparse_benchmark.py new file mode 100644 index 000000000..65550b9cf --- /dev/null +++ b/benchmarks/sparse_benchmark.py @@ -0,0 +1,155 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Microbenchmarks for sparse JAX.""" + +from functools import partial +import jax.numpy as jnp +import numpy as np +import math +import google_benchmark +import jax +from jax.experimental import sparse + +def _sparse_bcoo_fromdense(state, jit: bool = False, compile: bool = False): + shape = (2000, 2000) + nse = 10000 + size = math.prod(shape) + rng = np.random.RandomState(1701) + data = rng.randn(nse) + indices = np.unravel_index( + rng.choice(size, size=nse, replace=False), shape=shape + ) + mat = jnp.zeros(shape).at[indices].set(data) + + f = sparse.BCOO.fromdense + if compile or jit: + # Note: nse must be specified for JIT. + f = jax.jit(partial(f, nse=nse)) + + if compile: + while state: + state.pause_timing() + jax.clear_caches() + state.resume_timing() + f.lower(mat).compile() + else: + f(mat).block_until_ready() + while state: + f(mat).block_until_ready() + + +@google_benchmark.register +def sparse_bcoo_fromdense(state): + return _sparse_bcoo_fromdense(state) + + +@google_benchmark.register +def sparse_bcoo_fromdense_jit(state): + return _sparse_bcoo_fromdense(state, jit=True) + + +@google_benchmark.register +def sparse_bcoo_fromdense_compile(state): + return _sparse_bcoo_fromdense(state, compile=True) + + +def _sparse_bcoo_todense(state, jit: bool = False, compile: bool = False): + shape = (2000, 2000) + nse = 10000 + size = math.prod(shape) + rng = np.random.RandomState(1701) + data = rng.randn(nse) + indices = np.unravel_index( + rng.choice(size, size=nse, replace=False), shape=shape + ) + mat = sparse.BCOO((jnp.array(data), jnp.column_stack(indices)), shape=shape) + + f = lambda mat: mat.todense() + if jit or compile: + f = jax.jit(f) + + if compile: + while state: + state.pause_timing() + jax.clear_caches() + state.resume_timing() + f.lower(mat).compile() + else: + f(mat).block_until_ready() + while state: + f(mat).block_until_ready() + + +@google_benchmark.register +def sparse_bcoo_todense(state): + return _sparse_bcoo_todense(state) + + +@google_benchmark.register +def sparse_bcoo_todense_jit(state): + return _sparse_bcoo_todense(state, jit=True) + + +@google_benchmark.register +def sparse_bcoo_todense_compile(state): + return _sparse_bcoo_todense(state, compile=True) + + +def _sparse_bcoo_matvec(state, jit: bool = False, compile: bool = False): + shape = (2000, 2000) + nse = 10000 + key = jax.random.PRNGKey(1701) + mat = sparse.random_bcoo( + key, + nse=nse, + shape=shape, + dtype=jnp.float32, + indices_dtype=jnp.int32, + sorted_indices=True, + ) + vec = jax.random.uniform(key, shape=(shape[1],), dtype=jnp.float32) + + f = lambda mat, vec: mat @ vec + if jit or compile: + f = jax.jit(f) + + if compile: + while state: + state.pause_timing() + jax.clear_caches() + state.resume_timing() + f.lower(mat, vec).compile() + else: + f(mat, vec).block_until_ready() + while state: + f(mat, vec).block_until_ready() + + +@google_benchmark.register +def sparse_bcoo_matvec(state): + return _sparse_bcoo_matvec(state) + + +@google_benchmark.register +def sparse_bcoo_matvec_jit(state): + return _sparse_bcoo_matvec(state, jit=True) + + +@google_benchmark.register +def sparse_bcoo_matvec_compile(state): + return _sparse_bcoo_matvec(state, compile=True) + + +if __name__ == "__main__": + google_benchmark.main()