benchmarks: add JIT versions of sparse.BCOO benchmarks

PiperOrigin-RevId: 405696495
Jake VanderPlas 2021-10-26 11:38:24 -07:00 committed by jax authors
parent 47210f045b
commit c62452f2d2
2 changed files with 65 additions and 23 deletions

@@ -12,15 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Microbenchmarks for JAX `api` functions."""
 import functools
 import operator
 import google_benchmark
 import jax
-import jax.numpy as jnp
-import numpy as np
 from jax import lax
 from jax.experimental import sparse
+import jax.numpy as jnp
+import numpy as np
+partial = functools.partial
@@ -323,50 +324,93 @@ def sda_index_8(state):
   _run_sda_index_bench(state, 8)
-@google_benchmark.register
-def sparse_bcoo_fromdense(state):
+def _sparse_bcoo_fromdense(state, jit: bool):
   shape = (2000, 2000)
   nse = 10000
   size = np.prod(shape)
   rng = np.random.RandomState(1701)
   data = rng.randn(nse)
-  indices = np.unravel_index(rng.choice(size, size=nse, replace=False), shape=shape)
+  indices = np.unravel_index(
+      rng.choice(size, size=nse, replace=False), shape=shape)
   mat = jnp.zeros(shape).at[indices].set(data)
-  sparse.BCOO.fromdense(mat).block_until_ready()  # warm-up
+  f = sparse.BCOO.fromdense
+  if jit:
+    # Note: nse must be specified for JIT.
+    f = jax.jit(partial(f, nse=nse))
+  f(mat).block_until_ready()  # warm-up
   while state:
-    sparse.BCOO.fromdense(mat).block_until_ready()
+    f(mat).block_until_ready()
 @google_benchmark.register
+def sparse_bcoo_fromdense(state):
+  return _sparse_bcoo_fromdense(state, jit=False)
+@google_benchmark.register
+def sparse_bcoo_fromdense_jit(state):
+  return _sparse_bcoo_fromdense(state, jit=True)
+def _sparse_bcoo_todense(state, jit: bool):
+  shape = (2000, 2000)
+  nse = 10000
+  size = np.prod(shape)
+  rng = np.random.RandomState(1701)
+  data = rng.randn(nse)
+  indices = np.unravel_index(
+      rng.choice(size, size=nse, replace=False), shape=shape)
+  mat = sparse.BCOO((jnp.array(data), jnp.column_stack(indices)), shape=shape)
+  f = lambda mat: mat.todense()
+  if jit:
+    f = jax.jit(f)
+  f(mat).block_until_ready()  # warm-up
+  while state:
+    f(mat).block_until_ready()
+@google_benchmark.register
 def sparse_bcoo_todense(state):
+  return _sparse_bcoo_todense(state, jit=False)
+@google_benchmark.register
+def sparse_bcoo_todense_jit(state):
+  return _sparse_bcoo_todense(state, jit=True)
+def _sparse_bcoo_matvec(state, jit: bool):
   shape = (2000, 2000)
   nse = 10000
   size = np.prod(shape)
   rng = np.random.RandomState(1701)
   data = rng.randn(nse)
-  indices = np.unravel_index(rng.choice(size, size=nse, replace=False), shape=shape)
+  indices = np.unravel_index(
+      rng.choice(size, size=nse, replace=False), shape=shape)
   mat = sparse.BCOO((jnp.array(data), jnp.column_stack(indices)), shape=shape)
-  mat.todense().block_until_ready()  # warm-up
+  vec = rng.randn(shape[1])
+  f = lambda mat, vec: mat @ vec
+  if jit:
+    f = jax.jit(f)
+  f(mat, vec).block_until_ready()  # warm-up
   while state:
-    mat.todense().block_until_ready()
+    f(mat, vec).block_until_ready()
 @google_benchmark.register
 def sparse_bcoo_matvec(state):
-  shape = 2000, 2000
-  nse = 10000
-  size = np.prod(shape)
-  rng = np.random.RandomState(1701)
-  data = rng.randn(nse)
-  indices = np.unravel_index(rng.choice(size, size=nse, replace=False), shape=shape)
-  mat = sparse.BCOO((jnp.array(data), jnp.column_stack(indices)), shape=shape)
-  vec = rng.randn(shape[1])
-  (mat @ vec).block_until_ready()  # warm-up
+  return _sparse_bcoo_matvec(state, jit=False)
-  while state:
-    (mat @ vec).block_until_ready()
+@google_benchmark.register
+def sparse_bcoo_matvec_jit(state):
+  return _sparse_bcoo_matvec(state, jit=True)
 def swap(a, b):
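The `_jit` variants above wrap the benchmarked call in jax.jit. For BCOO.fromdense this requires passing nse explicitly (the "Note: nse must be specified for JIT" comment): under jit, nse fixes the shapes of the output data and indices buffers, and shapes cannot depend on traced values. A minimal sketch of the eager vs. jitted call pattern, assuming only sparse.BCOO.fromdense, jax.jit, and functools.partial from the diff; the matrix size, seed, and nse=10 are illustrative:

import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import sparse

# Dense 100x100 matrix with (at most) 10 nonzero entries.
rng = np.random.RandomState(0)
dense = jnp.zeros((100, 100)).at[rng.choice(100, 10), rng.choice(100, 10)].set(1.0)

mat_eager = sparse.BCOO.fromdense(dense)  # eager: nse is inferred from the data

# Jitted: nse is supplied statically, closed over via functools.partial so the
# traced function only sees the dense operand.
fromdense_jit = jax.jit(functools.partial(sparse.BCOO.fromdense, nse=10))
mat_jit = fromdense_jit(dense)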

@@ -19,7 +19,6 @@ from typing import Any, NamedTuple, Sequence, Tuple
 import numpy as np
-import jax
 from jax import core
 from jax import dtypes
 from jax import lax
@@ -983,7 +982,6 @@ class BCOO(ops.JAXSparse):
     """Return a de-duplicated representation of the BCOO matrix."""
     return BCOO(_dedupe_bcoo(self.data, self.indices, self.shape), shape=self.shape)
-  @jax.jit
   def todense(self):
     """Create a dense version of the array."""
     return bcoo_todense(self.data, self.indices, shape=self.shape)
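Dropping the @jax.jit decorator means BCOO.todense no longer compiles implicitly on every call, which is what lets the plain sparse_bcoo_todense benchmark measure the op-by-op path while sparse_bcoo_todense_jit applies jax.jit explicitly. A minimal sketch of the two call styles; the 4x4 identity matrix is purely illustrative, and BCOO instances can be passed straight into a jitted function because the class is registered as a pytree:

import jax
import jax.numpy as jnp
from jax.experimental import sparse

mat = sparse.BCOO.fromdense(jnp.eye(4))

dense_eager = mat.todense()                           # runs op-by-op now that the decorator is gone
dense_compiled = jax.jit(lambda m: m.todense())(mat)  # explicit jit, mirroring the new *_jit benchmark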