# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Microbenchmarks for sparse JAX."""

from functools import partial
import math

import google_benchmark
import jax
import jax.numpy as jnp
from jax.experimental import sparse
import numpy as np


# Problem size and seed shared by every benchmark in this file.
_SHAPE = (2000, 2000)
_NSE = 10000
_SEED = 1701


def _random_coo_entries(shape=_SHAPE, nse=_NSE, seed=_SEED):
  """Returns ``(data, indices)`` describing ``nse`` random entries of ``shape``.

  ``data`` holds ``nse`` standard-normal samples. ``indices`` is the tuple of
  per-dimension index arrays produced by ``np.unravel_index``. The flat
  positions are drawn without replacement, so no position repeats.
  """
  rng = np.random.RandomState(seed)
  # Draw values before positions to keep the same RNG stream order.
  data = rng.randn(nse)
  flat_positions = rng.choice(math.prod(shape), size=nse, replace=False)
  return data, np.unravel_index(flat_positions, shape=shape)


def _run_benchmark(state, f, *args, compile: bool = False):
  """Drives the google_benchmark ``state`` loop for ``f(*args)``.

  Args:
    state: the google_benchmark state object that controls iteration/timing.
    f: the function under test. When ``compile`` is True it must be a
      ``jax.jit``-wrapped function, because ``f.lower`` is called.
    *args: positional arguments forwarded to ``f``.
    compile: if True, time lowering + compilation instead of execution.
  """
  if compile:
    while state:
      # Clear caches outside the timed region so that each iteration
      # compiles afresh instead of hitting the compilation cache.
      state.pause_timing()
      jax.clear_caches()
      state.resume_timing()
      f.lower(*args).compile()
  else:
    # Untimed warm-up call (it runs before the `while state` loop), so
    # one-time costs such as JIT compilation are not measured.
    f(*args).block_until_ready()
    while state:
      # Wait for the asynchronously dispatched result so the measured time
      # includes the computation itself.
      f(*args).block_until_ready()


def _sparse_bcoo_fromdense(state, jit: bool = False, compile: bool = False):
  """Benchmarks ``sparse.BCOO.fromdense`` on a dense matrix with random entries."""
  data, indices = _random_coo_entries()
  mat = jnp.zeros(_SHAPE).at[indices].set(data)

  f = sparse.BCOO.fromdense
  if compile or jit:
    # Note: nse must be specified for JIT, because the output size cannot
    # depend on traced values.
    f = jax.jit(partial(f, nse=_NSE))
  _run_benchmark(state, f, mat, compile=compile)


@google_benchmark.register
def sparse_bcoo_fromdense(state):
  return _sparse_bcoo_fromdense(state)


@google_benchmark.register
def sparse_bcoo_fromdense_jit(state):
  return _sparse_bcoo_fromdense(state, jit=True)


@google_benchmark.register
def sparse_bcoo_fromdense_compile(state):
  return _sparse_bcoo_fromdense(state, compile=True)


def _sparse_bcoo_todense(state, jit: bool = False, compile: bool = False):
  """Benchmarks ``BCOO.todense`` on a matrix built from random COO entries."""
  data, indices = _random_coo_entries()
  mat = sparse.BCOO((jnp.array(data), jnp.column_stack(indices)), shape=_SHAPE)

  def f(m):
    return m.todense()

  if jit or compile:
    f = jax.jit(f)
  _run_benchmark(state, f, mat, compile=compile)


@google_benchmark.register
def sparse_bcoo_todense(state):
  return _sparse_bcoo_todense(state)


@google_benchmark.register
def sparse_bcoo_todense_jit(state):
  return _sparse_bcoo_todense(state, jit=True)


@google_benchmark.register
def sparse_bcoo_todense_compile(state):
  return _sparse_bcoo_todense(state, compile=True)


def _sparse_bcoo_matvec(state, jit: bool = False, compile: bool = False):
  """Benchmarks a float32 BCOO matrix times a dense float32 vector."""
  # Use independent keys for the matrix and the vector rather than
  # reusing one key for two random draws.
  mat_key, vec_key = jax.random.split(jax.random.key(_SEED))
  mat = sparse.random_bcoo(
      mat_key,
      nse=_NSE,
      shape=_SHAPE,
      dtype=jnp.float32,
      indices_dtype=jnp.int32,
      sorted_indices=True,
  )
  vec = jax.random.uniform(vec_key, shape=(_SHAPE[1],), dtype=jnp.float32)

  def f(m, v):
    return m @ v

  if jit or compile:
    f = jax.jit(f)
  _run_benchmark(state, f, mat, vec, compile=compile)


@google_benchmark.register
def sparse_bcoo_matvec(state):
  return _sparse_bcoo_matvec(state)


@google_benchmark.register
def sparse_bcoo_matvec_jit(state):
  return _sparse_bcoo_matvec(state, jit=True)


@google_benchmark.register
def sparse_bcoo_matvec_compile(state):
  return _sparse_bcoo_matvec(state, compile=True)


if __name__ == "__main__":
  google_benchmark.main()