mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
benchmarks: add JIT versions of sparse.BCOO benchmarks
PiperOrigin-RevId: 405696495
This commit is contained in:
parent
47210f045b
commit
c62452f2d2
@ -12,15 +12,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Microbenchmarks for JAX `api` functions."""
|
||||
|
||||
import functools
|
||||
import operator
|
||||
|
||||
import google_benchmark
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax import lax
|
||||
from jax.experimental import sparse
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
partial = functools.partial
|
||||
@ -323,50 +324,93 @@ def sda_index_8(state):
|
||||
_run_sda_index_bench(state, 8)
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def sparse_bcoo_fromdense(state):
|
||||
def _sparse_bcoo_fromdense(state, jit: bool):
|
||||
shape = (2000, 2000)
|
||||
nse = 10000
|
||||
size = np.prod(shape)
|
||||
rng = np.random.RandomState(1701)
|
||||
data = rng.randn(nse)
|
||||
indices = np.unravel_index(rng.choice(size, size=nse, replace=False), shape=shape)
|
||||
indices = np.unravel_index(
|
||||
rng.choice(size, size=nse, replace=False), shape=shape)
|
||||
mat = jnp.zeros(shape).at[indices].set(data)
|
||||
sparse.BCOO.fromdense(mat).block_until_ready() # warm-up
|
||||
|
||||
f = sparse.BCOO.fromdense
|
||||
if jit:
|
||||
# Note: nse must be specified for JIT.
|
||||
f = jax.jit(partial(f, nse=nse))
|
||||
f(mat).block_until_ready() # warm-up
|
||||
|
||||
while state:
|
||||
sparse.BCOO.fromdense(mat).block_until_ready()
|
||||
f(mat).block_until_ready()
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def sparse_bcoo_fromdense(state):
|
||||
return _sparse_bcoo_fromdense(state, jit=False)
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def sparse_bcoo_fromdense_jit(state):
|
||||
return _sparse_bcoo_fromdense(state, jit=True)
|
||||
|
||||
|
||||
def _sparse_bcoo_todense(state, jit: bool):
|
||||
shape = (2000, 2000)
|
||||
nse = 10000
|
||||
size = np.prod(shape)
|
||||
rng = np.random.RandomState(1701)
|
||||
data = rng.randn(nse)
|
||||
indices = np.unravel_index(
|
||||
rng.choice(size, size=nse, replace=False), shape=shape)
|
||||
mat = sparse.BCOO((jnp.array(data), jnp.column_stack(indices)), shape=shape)
|
||||
|
||||
f = lambda mat: mat.todense()
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
f(mat).block_until_ready() # warm-up
|
||||
|
||||
while state:
|
||||
f(mat).block_until_ready()
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def sparse_bcoo_todense(state):
|
||||
return _sparse_bcoo_todense(state, jit=False)
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def sparse_bcoo_todense_jit(state):
|
||||
return _sparse_bcoo_todense(state, jit=True)
|
||||
|
||||
|
||||
def _sparse_bcoo_matvec(state, jit: bool):
|
||||
shape = (2000, 2000)
|
||||
nse = 10000
|
||||
size = np.prod(shape)
|
||||
rng = np.random.RandomState(1701)
|
||||
data = rng.randn(nse)
|
||||
indices = np.unravel_index(rng.choice(size, size=nse, replace=False), shape=shape)
|
||||
indices = np.unravel_index(
|
||||
rng.choice(size, size=nse, replace=False), shape=shape)
|
||||
mat = sparse.BCOO((jnp.array(data), jnp.column_stack(indices)), shape=shape)
|
||||
mat.todense().block_until_ready() # warm-up
|
||||
vec = rng.randn(shape[1])
|
||||
|
||||
f = lambda mat, vec: mat @ vec
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
f(mat, vec).block_until_ready() # warm-up
|
||||
|
||||
while state:
|
||||
mat.todense().block_until_ready()
|
||||
f(mat, vec).block_until_ready()
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
def sparse_bcoo_matvec(state):
|
||||
shape = 2000, 2000
|
||||
nse = 10000
|
||||
size = np.prod(shape)
|
||||
rng = np.random.RandomState(1701)
|
||||
data = rng.randn(nse)
|
||||
indices = np.unravel_index(rng.choice(size, size=nse, replace=False), shape=shape)
|
||||
mat = sparse.BCOO((jnp.array(data), jnp.column_stack(indices)), shape=shape)
|
||||
vec = rng.randn(shape[1])
|
||||
(mat @ vec).block_until_ready() # warm-up
|
||||
return _sparse_bcoo_matvec(state, jit=False)
|
||||
|
||||
while state:
|
||||
(mat @ vec).block_until_ready()
|
||||
|
||||
@google_benchmark.register
|
||||
def sparse_bcoo_matvec_jit(state):
|
||||
return _sparse_bcoo_matvec(state, jit=True)
|
||||
|
||||
|
||||
def swap(a, b):
|
||||
|
@ -19,7 +19,6 @@ from typing import Any, NamedTuple, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
@ -983,7 +982,6 @@ class BCOO(ops.JAXSparse):
|
||||
"""Return a de-duplicated representation of the BCOO matrix."""
|
||||
return BCOO(_dedupe_bcoo(self.data, self.indices, self.shape), shape=self.shape)
|
||||
|
||||
@jax.jit
|
||||
def todense(self):
|
||||
"""Create a dense version of the array."""
|
||||
return bcoo_todense(self.data, self.indices, shape=self.shape)
|
||||
|
Loading…
x
Reference in New Issue
Block a user