mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
1501 lines
59 KiB
Python
1501 lines
59 KiB
Python
# Copyright 2021 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""JAX primitives related to sparse operations.
|
|
|
|
This is experimental work to explore sparse support in JAX.
|
|
|
|
The primitives defined here are deliberately low-level: each primitive implements
|
|
a common sparse operation (sparse to dense, dense to sparse, sparse matrix/vector
|
|
product, sparse matrix/matrix product) for two common sparse representations
|
|
(CSR and COO).
|
|
|
|
These routines have reference implementations defined via XLA scatter/gather
|
|
operations that will work on any backend, although they are not particularly
|
|
performant. On GPU runtimes built against CUDA 11.0 or newer, each operation is
|
|
computed efficiently via cusparse.
|
|
|
|
Further down are some examples of potential high-level wrappers for sparse objects.
|
|
(API should be considered unstable and subject to change).
|
|
"""
|
|
import functools
|
|
import operator
|
|
|
|
from typing import Any, Sequence, Tuple
|
|
|
|
from jax import api
|
|
from jax import core
|
|
from jax import jit
|
|
from jax import lax
|
|
from jax import tree_util
|
|
from jax import vmap
|
|
from jax.interpreters import batching
|
|
from jax.interpreters import partial_eval as pe
|
|
from jax.interpreters import xla
|
|
from jax.lib import cusparse
|
|
from jax.lib import xla_bridge
|
|
from jax.lib import xla_client
|
|
import jax.numpy as jnp
|
|
import numpy as np
|
|
from jax.interpreters import ad
|
|
from jax.util import safe_zip
|
|
from jax._src.lax.lax import ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_computation
|
|
|
|
xb = xla_bridge
|
|
xops = xla_client.ops
|
|
Dtype = Any
|
|
|
|
#--------------------------------------------------------------------
|
|
# utilities
|
|
# TODO: possibly make these utilities into primitives, targeting
|
|
# csr2coo/coo2csr/SPDDMM
|
|
@functools.partial(jit, static_argnums=1)
|
|
def _csr_to_coo(indptr, nse):
|
|
return jnp.cumsum(jnp.zeros_like(indptr, shape=nse).at[indptr].add(1)) - 1
|
|
|
|
@functools.partial(jit, static_argnums=1)
|
|
def _coo_to_csr(row, nrows):
|
|
indptr = jnp.zeros(nrows + 1, row.dtype)
|
|
return indptr.at[1:].set(jnp.cumsum(jnp.bincount(row, length=nrows)))
|
|
|
|
@jit
|
|
def _csr_extract(indices, indptr, mat):
|
|
"""Extract values of dense matrix mat at given CSR indices."""
|
|
return _coo_extract(_csr_to_coo(indptr, len(indices)), indices, mat)
|
|
|
|
@jit
|
|
def _coo_extract(row, col, mat):
|
|
"""Extract values of dense matrix mat at given COO indices."""
|
|
return mat[row, col]
|
|
|
|
#--------------------------------------------------------------------
|
|
# csr_todense
|
|
|
|
csr_todense_p = core.Primitive('csr_todense')
|
|
|
|
def csr_todense(data, indices, indptr, *, shape):
|
|
"""Convert CSR-format sparse matrix to a dense matrix.
|
|
|
|
Args:
|
|
data : array of shape ``(nse,)``.
|
|
indices : array of shape ``(nse,)``
|
|
indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
|
|
shape : length-2 tuple representing the matrix shape
|
|
|
|
Returns:
|
|
mat : array with specified shape and dtype matching ``data``
|
|
"""
|
|
return csr_todense_p.bind(data, indices, indptr, shape=shape)
|
|
|
|
@csr_todense_p.def_impl
|
|
def _csr_todense_impl(data, indices, indptr, *, shape):
|
|
return _coo_todense_impl(data, _csr_to_coo(indptr, len(indices)), indices, shape=shape)
|
|
|
|
@csr_todense_p.def_abstract_eval
|
|
def _csr_todense_abstract_eval(data, indices, indptr, *, shape):
|
|
assert data.ndim == indices.ndim == indptr.ndim == 1
|
|
assert indices.dtype == indptr.dtype
|
|
assert data.shape == indices.shape
|
|
assert indptr.shape[0] == shape[0] + 1
|
|
return core.ShapedArray(shape, data.dtype)
|
|
|
|
def _csr_todense_gpu_translation_rule(c, data, indices, indptr, *, shape):
|
|
return cusparse.csr_todense(c, data, indices, indptr, shape=shape)
|
|
|
|
xla.translations[csr_todense_p] = xla.lower_fun(
|
|
_csr_todense_impl, multiple_results=False)
|
|
if cusparse and cusparse.is_supported:
|
|
xla.backend_specific_translations['gpu'][
|
|
csr_todense_p] = _csr_todense_gpu_translation_rule
|
|
|
|
#--------------------------------------------------------------------
|
|
# csr_fromdense
|
|
|
|
csr_fromdense_p = core.Primitive('csr_fromdense')
|
|
csr_fromdense_p.multiple_results = True
|
|
|
|
def csr_fromdense(mat, *, nse, index_dtype=np.int32):
|
|
"""Create CSR-format sparse matrix from a dense matrix.
|
|
|
|
Args:
|
|
mat : array to be converted to CSR.
|
|
nse : number of specified entries in ``mat``
|
|
index_dtype : dtype of sparse indices
|
|
|
|
Returns:
|
|
data : array of shape ``(nse,)`` and dtype ``mat.dtype``.
|
|
indices : array of shape ``(nse,)`` and dtype ``index_dtype``
|
|
indptr : array of shape ``(mat.shape[0] + 1,)`` and dtype ``index_dtype``
|
|
"""
|
|
mat = jnp.asarray(mat)
|
|
nse = core.concrete_or_error(operator.index, nse, "nse argument of csr_fromdense()")
|
|
return csr_fromdense_p.bind(mat, nse=nse, index_dtype=np.dtype(index_dtype))
|
|
|
|
@csr_fromdense_p.def_impl
|
|
def _csr_fromdense_impl(mat, *, nse, index_dtype):
|
|
mat = jnp.asarray(mat)
|
|
assert mat.ndim == 2
|
|
m = mat.shape[0]
|
|
|
|
row, col = jnp.nonzero(mat, size=nse)
|
|
data = mat[row, col]
|
|
|
|
true_nonzeros = jnp.arange(nse) < (mat != 0).sum()
|
|
data = jnp.where(true_nonzeros, data, 0)
|
|
row = jnp.where(true_nonzeros, row, m)
|
|
indices = col.astype(index_dtype)
|
|
indptr = jnp.zeros(m + 1, dtype=index_dtype).at[1:].set(
|
|
jnp.cumsum(jnp.bincount(row, length=m)))
|
|
return data, indices, indptr
|
|
|
|
@csr_fromdense_p.def_abstract_eval
|
|
def _csr_fromdense_abstract_eval(mat, *, nse, index_dtype):
|
|
data = core.ShapedArray((nse,), mat.dtype)
|
|
indices = core.ShapedArray((nse,), index_dtype)
|
|
indptr = core.ShapedArray((mat.shape[0] + 1,), index_dtype)
|
|
return data, indices, indptr
|
|
|
|
def _csr_fromdense_gpu_translation_rule(c, mat, *, nse, index_dtype):
|
|
data, indices, indptr = cusparse.csr_fromdense(
|
|
c, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
|
|
return xops.Tuple(c, [data, indices, indptr])
|
|
|
|
xla.translations[csr_fromdense_p] = xla.lower_fun(
|
|
_csr_fromdense_impl, multiple_results=True)
|
|
if cusparse and cusparse.is_supported:
|
|
xla.backend_specific_translations['gpu'][
|
|
csr_fromdense_p] = _csr_fromdense_gpu_translation_rule
|
|
|
|
#--------------------------------------------------------------------
|
|
# csr_matvec
|
|
|
|
csr_matvec_p = core.Primitive('csr_matvec')
|
|
|
|
def csr_matvec(data, indices, indptr, v, *, shape, transpose=False):
|
|
"""Product of CSR sparse matrix and a dense vector.
|
|
|
|
Args:
|
|
data : array of shape ``(nse,)``.
|
|
indices : array of shape ``(nse,)``
|
|
indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
|
|
v : array of shape ``(shape[0] if transpose else shape[1],)``
|
|
and dtype ``data.dtype``
|
|
shape : length-2 tuple representing the matrix shape
|
|
transpose : boolean specifying whether to transpose the sparse matrix
|
|
before computing.
|
|
|
|
Returns:
|
|
y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
|
|
the matrix vector product.
|
|
"""
|
|
return csr_matvec_p.bind(data, indices, indptr, v, shape=shape, transpose=transpose)
|
|
|
|
@csr_matvec_p.def_impl
|
|
def _csr_matvec_impl(data, indices, indptr, v, *, shape, transpose):
|
|
row = _csr_to_coo(indptr, len(indices))
|
|
return _coo_matvec_impl(data, row, indices, v, shape=shape, transpose=transpose)
|
|
|
|
@csr_matvec_p.def_abstract_eval
|
|
def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose):
|
|
assert len(shape) == 2
|
|
assert v.ndim == data.ndim == indices.ndim == indptr.ndim == 1
|
|
assert data.shape == indices.shape
|
|
assert data.dtype == v.dtype
|
|
assert indices.dtype == indptr.dtype
|
|
assert len(indptr) == shape[0] + 1
|
|
out_shape = shape[1] if transpose else shape[0]
|
|
assert v.shape == (shape[0],) if transpose else (shape[1],)
|
|
return core.ShapedArray((out_shape,), data.dtype)
|
|
|
|
def _csr_matvec_gpu_translation_rule(c, data, indices, indptr, v, *, shape, transpose):
|
|
return cusparse.csr_matvec(c, data, indices, indptr, v, shape=shape, transpose=transpose)
|
|
|
|
xla.translations[csr_matvec_p] = xla.lower_fun(
|
|
_csr_matvec_impl, multiple_results=False)
|
|
if cusparse and cusparse.is_supported:
|
|
xla.backend_specific_translations['gpu'][
|
|
csr_matvec_p] = _csr_matvec_gpu_translation_rule
|
|
|
|
|
|
#--------------------------------------------------------------------
|
|
# csr_matmat
|
|
|
|
csr_matmat_p = core.Primitive('csr_matmat')
|
|
|
|
def csr_matmat(data, indices, indptr, B, *, shape, transpose=False):
|
|
"""Product of CSR sparse matrix and a dense matrix.
|
|
|
|
Args:
|
|
data : array of shape ``(nse,)``.
|
|
indices : array of shape ``(nse,)``
|
|
indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
|
|
B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
|
|
dtype ``data.dtype``
|
|
shape : length-2 tuple representing the matrix shape
|
|
transpose : boolean specifying whether to transpose the sparse matrix
|
|
before computing.
|
|
|
|
Returns:
|
|
C : array of shape ``(shape[1] if transpose else shape[0], cols)``
|
|
representing the matrix-matrix product product.
|
|
"""
|
|
return csr_matmat_p.bind(data, indices, indptr, B, shape=shape, transpose=transpose)
|
|
|
|
@csr_matmat_p.def_impl
|
|
def _csr_matmat_impl(data, indices, indptr, B, *, shape, transpose):
|
|
row = _csr_to_coo(indptr, len(indices))
|
|
return _coo_matmat_impl(data, row, indices, B, shape=shape, transpose=transpose)
|
|
|
|
@csr_matmat_p.def_abstract_eval
|
|
def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose):
|
|
assert data.ndim == indices.ndim == indptr.ndim == 1
|
|
assert B.ndim == 2
|
|
assert data.shape == indices.shape
|
|
assert data.dtype == B.dtype
|
|
assert indices.dtype == indptr.dtype
|
|
assert len(indptr) == shape[0] + 1
|
|
out_shape = shape[1] if transpose else shape[0]
|
|
assert B.shape[0] == shape[0] if transpose else shape[1]
|
|
return core.ShapedArray((out_shape, B.shape[1]), data.dtype)
|
|
|
|
def _csr_matmat_gpu_translation_rule(c, data, indices, indptr, B, *, shape, transpose):
|
|
return cusparse.csr_matmat(c, data, indices, indptr, B, shape=shape, transpose=transpose)
|
|
|
|
xla.translations[csr_matmat_p] = xla.lower_fun(
|
|
_csr_matmat_impl, multiple_results=False)
|
|
if cusparse and cusparse.is_supported:
|
|
xla.backend_specific_translations['gpu'][
|
|
csr_matmat_p] = _csr_matmat_gpu_translation_rule
|
|
|
|
|
|
#--------------------------------------------------------------------
|
|
# coo_todense
|
|
|
|
coo_todense_p = core.Primitive('coo_todense')
|
|
|
|
def coo_todense(data, row, col, *, shape):
|
|
"""Convert CSR-format sparse matrix to a dense matrix.
|
|
|
|
Args:
|
|
data : array of shape ``(nse,)``.
|
|
row : array of shape ``(nse,)``
|
|
col : array of shape ``(nse,)`` and dtype ``row.dtype``
|
|
shape : length-2 tuple representing the matrix shape
|
|
|
|
Returns:
|
|
mat : array with specified shape and dtype matching ``data``
|
|
"""
|
|
return coo_todense_p.bind(data, row, col, shape=shape)
|
|
|
|
@coo_todense_p.def_impl
|
|
def _coo_todense_impl(data, row, col, *, shape):
|
|
return jnp.zeros(shape, data.dtype).at[row, col].add(data)
|
|
|
|
@coo_todense_p.def_abstract_eval
|
|
def _coo_todense_abstract_eval(data, row, col, *, shape):
|
|
return core.ShapedArray(shape, data.dtype)
|
|
|
|
def _coo_todense_gpu_translation_rule(c, data, row, col, *, shape):
|
|
return cusparse.coo_todense(c, data, row, col, shape=shape)
|
|
|
|
def _coo_todense_jvp(data_dot, data, row, col, *, shape):
|
|
return coo_todense(data_dot, row, col, shape=shape)
|
|
|
|
def _coo_todense_transpose(ct, data, row, col, *, shape):
|
|
# Note: we assume that transpose has the same sparsity pattern.
|
|
# Can we check this?
|
|
assert ad.is_undefined_primal(data)
|
|
if ad.is_undefined_primal(row) or ad.is_undefined_primal(col):
|
|
raise ValueError("Cannot transpose with respect to sparse indices")
|
|
assert ct.shape == shape
|
|
assert row.aval.dtype == col.aval.dtype
|
|
assert ct.dtype == data.aval.dtype
|
|
return _coo_extract(row, col, ct), row, col
|
|
|
|
ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
|
|
ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
|
|
xla.translations[coo_todense_p] = xla.lower_fun(
|
|
_coo_todense_impl, multiple_results=False)
|
|
if cusparse and cusparse.is_supported:
|
|
xla.backend_specific_translations['gpu'][
|
|
coo_todense_p] = _coo_todense_gpu_translation_rule
|
|
|
|
#--------------------------------------------------------------------
|
|
# coo_fromdense
|
|
|
|
coo_fromdense_p = core.Primitive('coo_fromdense')
|
|
coo_fromdense_p.multiple_results = True
|
|
|
|
def coo_fromdense(mat, *, nse, index_dtype=jnp.int32):
|
|
"""Create COO-format sparse matrix from a dense matrix.
|
|
|
|
Args:
|
|
mat : array to be converted to COO.
|
|
nse : number of specified entries in ``mat``
|
|
index_dtype : dtype of sparse indices
|
|
|
|
Returns:
|
|
data : array of shape ``(nse,)`` and dtype ``mat.dtype``
|
|
row : array of shape ``(nse,)`` and dtype ``index_dtype``
|
|
col : array of shape ``(nse,)`` and dtype ``index_dtype``
|
|
"""
|
|
mat = jnp.asarray(mat)
|
|
nse = core.concrete_or_error(operator.index, nse, "nse argument of coo_fromdense()")
|
|
return coo_fromdense_p.bind(mat, nse=nse, index_dtype=index_dtype)
|
|
|
|
@coo_fromdense_p.def_impl
|
|
def _coo_fromdense_impl(mat, *, nse, index_dtype):
|
|
mat = jnp.asarray(mat)
|
|
assert mat.ndim == 2
|
|
|
|
row, col = jnp.nonzero(mat, size=nse)
|
|
data = mat[row, col]
|
|
|
|
true_nonzeros = jnp.arange(nse) < (mat != 0).sum()
|
|
data = jnp.where(true_nonzeros, data, 0)
|
|
|
|
return data, row.astype(index_dtype), col.astype(index_dtype)
|
|
|
|
@coo_fromdense_p.def_abstract_eval
|
|
def _coo_fromdense_abstract_eval(mat, *, nse, index_dtype):
|
|
data = core.ShapedArray((nse,), mat.dtype)
|
|
row = col = core.ShapedArray((nse,), index_dtype)
|
|
return data, row, col
|
|
|
|
def _coo_fromdense_gpu_translation_rule(c, mat, *, nse, index_dtype):
|
|
data, row, col = cusparse.coo_fromdense(
|
|
c, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
|
|
return xops.Tuple(c, [data, row, col])
|
|
|
|
def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype):
|
|
M, = primals
|
|
Mdot, = tangents
|
|
|
|
primals_out = coo_fromdense(M, nse=nse, index_dtype=index_dtype)
|
|
data, row, col = primals_out
|
|
|
|
if type(Mdot) is ad.Zero:
|
|
data_dot = ad.Zero.from_value(data)
|
|
else:
|
|
data_dot = _coo_extract(row, col, Mdot)
|
|
|
|
tangents_out = (data_dot, ad.Zero.from_value(row), ad.Zero.from_value(col))
|
|
|
|
return primals_out, tangents_out
|
|
|
|
def _coo_fromdense_transpose(ct, M, *, nse, index_dtype):
|
|
data, row, col = ct
|
|
assert len(data) == nse
|
|
assert row.dtype == col.dtype == index_dtype
|
|
if isinstance(row, ad.Zero) or isinstance(col, ad.Zero):
|
|
raise ValueError("Cannot transpose with respect to sparse indices")
|
|
assert ad.is_undefined_primal(M)
|
|
return coo_todense(data, row, col, shape=M.aval.shape)
|
|
|
|
ad.primitive_jvps[coo_fromdense_p] = _coo_fromdense_jvp
|
|
ad.primitive_transposes[coo_fromdense_p] = _coo_fromdense_transpose
|
|
|
|
xla.translations[coo_fromdense_p] = xla.lower_fun(
|
|
_coo_fromdense_impl, multiple_results=True)
|
|
if cusparse and cusparse.is_supported:
|
|
xla.backend_specific_translations['gpu'][
|
|
coo_fromdense_p] = _coo_fromdense_gpu_translation_rule
|
|
|
|
#--------------------------------------------------------------------
|
|
# coo_matvec
|
|
|
|
coo_matvec_p = core.Primitive('coo_matvec')
|
|
|
|
def coo_matvec(data, row, col, v, *, shape, transpose=False):
|
|
"""Product of COO sparse matrix and a dense vector.
|
|
|
|
Args:
|
|
data : array of shape ``(nse,)``.
|
|
row : array of shape ``(nse,)``
|
|
col : array of shape ``(nse,)`` and dtype ``row.dtype``
|
|
v : array of shape ``(shape[0] if transpose else shape[1],)`` and
|
|
dtype ``data.dtype``
|
|
shape : length-2 tuple representing the matrix shape
|
|
transpose : boolean specifying whether to transpose the sparse matrix
|
|
before computing.
|
|
|
|
Returns:
|
|
y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
|
|
the matrix vector product.
|
|
"""
|
|
return coo_matvec_p.bind(data, row, col, v, shape=shape, transpose=transpose)
|
|
|
|
@coo_matvec_p.def_impl
|
|
def _coo_matvec_impl(data, row, col, v, *, shape, transpose):
|
|
v = jnp.asarray(v)
|
|
if transpose:
|
|
row, col = col, row
|
|
out_shape = shape[1] if transpose else shape[0]
|
|
dv = data * v[col]
|
|
return jnp.zeros(out_shape, dv.dtype).at[row].add(dv)
|
|
|
|
@coo_matvec_p.def_abstract_eval
|
|
def _coo_matvec_abstract_eval(data, row, col, v, *, shape, transpose):
|
|
assert data.shape == row.shape == col.shape
|
|
assert data.dtype == v.dtype
|
|
assert row.dtype == col.dtype
|
|
assert len(shape) == 2
|
|
assert v.shape == (shape[0],) if transpose else (shape[1],)
|
|
out_shape = shape[1] if transpose else shape[0]
|
|
return core.ShapedArray((out_shape,), data.dtype)
|
|
|
|
def _coo_matvec_gpu_translation_rule(c, data, row, col, v, *, shape, transpose):
|
|
return cusparse.coo_matvec(c, data, row, col, v, shape=shape, transpose=transpose)
|
|
|
|
def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, shape, transpose):
|
|
return coo_matvec(data_dot, row, col, v, shape=shape, transpose=transpose)
|
|
|
|
def _coo_matvec_jvp_vec(v_dot, data, row, col, v, *, shape, transpose):
|
|
return coo_matvec(data, row, col, v_dot, shape=shape, transpose=transpose)
|
|
|
|
def _coo_matvec_transpose(ct, data, row, col, v, *, shape, transpose):
|
|
assert not ad.is_undefined_primal(row)
|
|
assert not ad.is_undefined_primal(col)
|
|
|
|
if ad.is_undefined_primal(v):
|
|
return data, row, col, coo_matvec(data, row, col, ct, shape=shape, transpose=not transpose)
|
|
else:
|
|
v = jnp.asarray(v)
|
|
# return _coo_extract(row, col, jnp.outer(ct, v)), row, col, v
|
|
return ct[row] * v[col], row, col, v
|
|
|
|
ad.defjvp(coo_matvec_p, _coo_matvec_jvp_mat, None, None, _coo_matvec_jvp_vec)
|
|
ad.primitive_transposes[coo_matvec_p] = _coo_matvec_transpose
|
|
xla.translations[coo_matvec_p] = xla.lower_fun(
|
|
_coo_matvec_impl, multiple_results=False)
|
|
if cusparse and cusparse.is_supported:
|
|
xla.backend_specific_translations['gpu'][
|
|
coo_matvec_p] = _coo_matvec_gpu_translation_rule
|
|
|
|
#--------------------------------------------------------------------
|
|
# coo_matmat
|
|
|
|
coo_matmat_p = core.Primitive('coo_matmat')
|
|
|
|
def coo_matmat(data, row, col, B, *, shape, transpose=False):
|
|
"""Product of COO sparse matrix and a dense matrix.
|
|
|
|
Args:
|
|
data : array of shape ``(nse,)``.
|
|
row : array of shape ``(nse,)``
|
|
col : array of shape ``(nse,)`` and dtype ``row.dtype``
|
|
B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
|
|
dtype ``data.dtype``
|
|
shape : length-2 tuple representing the matrix shape
|
|
transpose : boolean specifying whether to transpose the sparse matrix
|
|
before computing.
|
|
|
|
Returns:
|
|
C : array of shape ``(shape[1] if transpose else shape[0], cols)``
|
|
representing the matrix vector product.
|
|
"""
|
|
return coo_matmat_p.bind(data, row, col, B, shape=shape, transpose=transpose)
|
|
|
|
@coo_matmat_p.def_impl
|
|
def _coo_matmat_impl(data, row, col, B, *, shape, transpose):
|
|
B = jnp.asarray(B)
|
|
if transpose:
|
|
row, col = col, row
|
|
out_shape = shape[1] if transpose else shape[0]
|
|
dB = data[:, None] * B[col]
|
|
return jnp.zeros((out_shape, B.shape[1]), dB.dtype).at[row].add(dB)
|
|
|
|
@coo_matmat_p.def_abstract_eval
|
|
def _coo_matmat_abstract_eval(data, row, col, B, *, shape, transpose):
|
|
assert data.shape == row.shape == col.shape
|
|
assert data.dtype == B.dtype
|
|
assert len(shape) == 2
|
|
assert B.shape[0] == shape[0] if transpose else shape[1]
|
|
out_shape = shape[1] if transpose else shape[0]
|
|
return core.ShapedArray((out_shape, B.shape[1]), data.dtype)
|
|
|
|
def _coo_matmat_gpu_translation_rule(c, data, row, col, B, *, shape, transpose):
|
|
return cusparse.coo_matmat(c, data, row, col, B, shape=shape, transpose=transpose)
|
|
|
|
xla.translations[coo_matmat_p] = xla.lower_fun(
|
|
_coo_matmat_impl, multiple_results=False)
|
|
if cusparse and cusparse.is_supported:
|
|
xla.backend_specific_translations['gpu'][
|
|
coo_matmat_p] = _coo_matmat_gpu_translation_rule
|
|
|
|
def _coo_matmat_jvp_rule(primals_in, tangents_in, **params):
|
|
vals, rows, cols, mat = primals_in
|
|
sparse_mat_dot, rows_dot, cols_dot, mat_dot = tangents_in
|
|
assert type(rows_dot) is ad.Zero
|
|
assert type(cols_dot) is ad.Zero
|
|
|
|
primals_out = coo_matmat(vals, rows, cols, mat, **params)
|
|
_zero = lambda p, t: lax.zeros_like_array(p) if isinstance(t, ad.Zero) else t
|
|
_sparse_mat_dot = _zero(vals, sparse_mat_dot)
|
|
_mat_dot = _zero(mat, mat_dot)
|
|
|
|
tangents_out = coo_matmat(_sparse_mat_dot, rows, cols, mat, **params) + coo_matmat(vals, rows, cols, _mat_dot, **params)
|
|
return primals_out, tangents_out
|
|
ad.primitive_jvps[coo_matmat_p] = _coo_matmat_jvp_rule
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
# BCOO primitives: batched extension of COO.
|
|
|
|
def _bcoo_nse(mat, n_batch=0, n_dense=0):
|
|
mat = jnp.asarray(mat)
|
|
mask = (mat != 0)
|
|
if n_dense > 0:
|
|
mask = mask.any([-(i + 1) for i in range(n_dense)])
|
|
mask = mask.sum(list(range(n_batch, mask.ndim)))
|
|
return mask.max()
|
|
|
|
def _dedupe_bcoo(data, indices):
|
|
f = _dedupe_bcoo_one
|
|
n_batch = indices.ndim - 2
|
|
for s1, s2 in safe_zip(indices.shape[:n_batch], data.shape[:n_batch]):
|
|
if s1 != s2:
|
|
# TODO: handle broadcasted dimensions.
|
|
raise NotImplementedError("dedupe_bcoo for broadcasted dimensions.")
|
|
f = vmap(f)
|
|
return f(data, indices)
|
|
|
|
def _dedupe_bcoo_one(data, indices):
|
|
assert indices.ndim == 2
|
|
assert data.shape[:1] == indices.shape[1:]
|
|
|
|
if indices.shape[0] == 0:
|
|
return data, indices
|
|
|
|
# This is a fixed-size version of jnp.unique() with return_indices=True
|
|
# unique values are zero-filled at the end.
|
|
perm = jnp.lexsort(indices[::-1])
|
|
aux = indices[:, perm]
|
|
mask = jnp.ones(indices.shape[1], dtype=bool)
|
|
mask = mask.at[1:].set(jnp.any(aux[:, 1:] != aux[:, :-1], 0))
|
|
imask = jnp.cumsum(mask) - 1
|
|
indices_unique = jnp.where(mask, aux, 0)[:, jnp.argsort(~mask)]
|
|
inv_idx = jnp.zeros_like(imask).at[perm].set(imask)
|
|
|
|
# With the above, de-duping is easy.
|
|
data_unique = jnp.zeros_like(data).at[inv_idx].add(data)
|
|
return data_unique, indices_unique
|
|
|
|
|
|
def _validate_bcoo(data, indices, shape):
|
|
assert jnp.issubdtype(indices.dtype, jnp.integer)
|
|
|
|
n_sparse, nse = indices.shape[-2:]
|
|
n_batch = indices.ndim - 2
|
|
n_dense = len(shape) - n_batch - n_sparse
|
|
assert n_dense >= 0
|
|
|
|
def _compatible(shape1, shape2):
|
|
return all(s1 in (1, s2) for s1, s2 in safe_zip(shape1, shape2))
|
|
|
|
if not _compatible(data.shape[:n_batch], shape[:n_batch]):
|
|
raise ValueError("data batch dimensions not compatible for "
|
|
f"data.shape={data.shape}, shape={shape}")
|
|
if data.shape[-(n_dense + 1):] != (nse,) + shape[n_batch + n_sparse:]:
|
|
raise ValueError(f"Invalid data.shape={data.shape} for "
|
|
f"nse={nse}, n_batch={n_batch}, n_dense={n_dense}")
|
|
if not _compatible(indices.shape[:n_batch], shape[:n_batch]):
|
|
raise ValueError("indices batch dimensions not compatible for "
|
|
f"indices.shape={indices.shape}, shape={shape}")
|
|
if indices.shape[n_batch:] != (n_sparse, nse):
|
|
raise ValueError(f"Invalid indices.shape={indices.shape} for "
|
|
f"nse={nse}, n_batch={n_batch}, n_dense={n_dense}")
|
|
|
|
return n_batch, n_sparse, n_dense
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
# bcoo_todense
|
|
|
|
bcoo_todense_p = core.Primitive('bcoo_todense_p')
|
|
|
|
def bcoo_todense(data, indices, *, shape):
|
|
"""Convert batched sparse matrix to a dense matrix.
|
|
|
|
Args:
|
|
data : array of shape ``batch_dims + (nse,) + block_dims``.
|
|
indices : array of shape ``batch_dims + (n_sparse, nse)``
|
|
shape : tuple; the shape of the (batched) matrix. Equal to
|
|
``batch_dims + sparse_dims + block_dims``
|
|
where ``len(sparse_dims) == n_sparse``
|
|
|
|
Returns:
|
|
mat : array with specified shape and dtype matching ``data``
|
|
"""
|
|
return bcoo_todense_p.bind(jnp.asarray(data), jnp.asarray(indices), shape=tuple(shape))
|
|
|
|
@bcoo_todense_p.def_impl
|
|
def _bcoo_todense_impl(data, indices, *, shape):
|
|
n_batch, n_sparse, _ = _validate_bcoo(data, indices, shape)
|
|
batch_slices = tuple(slice(s) for s in shape[:n_batch])
|
|
sparse_ind = tuple(indices[tuple(np.mgrid[batch_slices]) + (i,)] for i in range(n_sparse))
|
|
batch_ind = tuple(np.mgrid[batch_slices + (slice(1),)])[:-1]
|
|
if not sparse_ind:
|
|
data = data.sum(n_batch, keepdims=bool(batch_ind))
|
|
return jnp.zeros(shape, data.dtype).at[batch_ind + sparse_ind].add(data)
|
|
|
|
@bcoo_todense_p.def_abstract_eval
|
|
def _bcoo_todense_abstract_eval(data, indices, *, shape):
|
|
_validate_bcoo(data, indices, shape)
|
|
return core.ShapedArray(shape, data.dtype)
|
|
|
|
def _bcoo_todense_jvp(data_dot, data, indices, *, shape):
|
|
return bcoo_todense(data_dot, indices, shape=shape)
|
|
|
|
def _bcoo_todense_transpose(ct, data, indices, *, shape):
|
|
assert ad.is_undefined_primal(data)
|
|
if ad.is_undefined_primal(indices):
|
|
raise ValueError("Cannot transpose with respect to sparse indices")
|
|
assert ct.shape == shape
|
|
assert ct.dtype == data.aval.dtype
|
|
return bcoo_extract(indices, ct), indices
|
|
|
|
def _bcoo_todense_batching_rule(batched_args, batch_dims, *, shape):
|
|
data, indices = batched_args
|
|
if any(b not in [0, None] for b in batch_dims):
|
|
raise NotImplementedError(f"batch_dims={batch_dims}. Only 0 and None are supported.")
|
|
if batch_dims[0] is None:
|
|
data = data[None, ...]
|
|
if batch_dims[1] is None:
|
|
indices = indices[None, ...]
|
|
return bcoo_todense(data, indices, shape=(max(data.shape[0], indices.shape[0]), *shape)), 0
|
|
|
|
ad.defjvp(bcoo_todense_p, _bcoo_todense_jvp, None)
|
|
ad.primitive_transposes[bcoo_todense_p] = _bcoo_todense_transpose
|
|
batching.primitive_batchers[bcoo_todense_p] = _bcoo_todense_batching_rule
|
|
xla.translations[bcoo_todense_p] = xla.lower_fun(
|
|
_bcoo_todense_impl, multiple_results=False)
|
|
|
|
#--------------------------------------------------------------------
|
|
# bcoo_fromdense
|
|
|
|
bcoo_fromdense_p = core.Primitive('bcoo_fromdense')
|
|
bcoo_fromdense_p.multiple_results = True
|
|
|
|
def bcoo_fromdense(mat, *, nse=None, n_batch=0, n_dense=0, index_dtype=jnp.int32):
|
|
"""Create COO-format sparse matrix from a dense matrix.
|
|
|
|
Args:
|
|
mat : array to be converted to COO, with ``ndim = n_batch + n_sparse + n_dense``.
|
|
nse : number of specified elements in each batch
|
|
n_batch : number of batch dimensions (default: 0)
|
|
n_dense : number of block_dimensions (default: 0)
|
|
index_dtype : dtype of sparse indices (default: int32)
|
|
|
|
Returns:
|
|
data : array of shape ``mat.shape[:n_batch] + (nse,) + mat.shape[mat.ndim - n_dense:]``
|
|
and dtype ``mat.dtype``
|
|
indices : array of shape ``mat.shape[:n_batch] + (n_sparse, nse)``
|
|
"""
|
|
mat = jnp.asarray(mat)
|
|
if nse is None:
|
|
nse = _bcoo_nse(mat, n_batch, n_dense)
|
|
nse = core.concrete_or_error(operator.index, nse, "nse argument of bcoo_fromdense")
|
|
return bcoo_fromdense_p.bind(mat, nse=nse, n_batch=n_batch, n_dense=n_dense,
|
|
index_dtype=index_dtype)
|
|
|
|
@bcoo_fromdense_p.def_impl
|
|
def _bcoo_fromdense_impl(mat, *, nse, n_batch, n_dense, index_dtype):
|
|
mat = jnp.asarray(mat)
|
|
mask = (mat != 0)
|
|
if n_dense > 0:
|
|
mask = mask.any([-(i + 1) for i in range(n_dense)])
|
|
nonzero = lambda a: jnp.nonzero(a, size=nse) if a.ndim else ()
|
|
for _ in range(n_batch):
|
|
nonzero = vmap(nonzero, 0)
|
|
indices = nonzero(mask)
|
|
if not indices:
|
|
indices = jnp.zeros(mask.shape[:n_batch] + (0, nse), index_dtype)
|
|
else:
|
|
indices = jnp.moveaxis(jnp.array(indices, index_dtype), 0, n_batch)
|
|
data = bcoo_extract(indices, mat)
|
|
|
|
true_nonzeros = jnp.arange(nse) < mask.sum(list(range(n_batch, mask.ndim)))[..., None]
|
|
true_nonzeros = true_nonzeros[(n_batch + 1) * (slice(None),) + n_dense * (None,)]
|
|
data = jnp.where(true_nonzeros, data, 0)
|
|
|
|
return data, indices
|
|
|
|
@bcoo_fromdense_p.def_abstract_eval
|
|
def _bcoo_fromdense_abstract_eval(mat, *, nse, n_batch, n_dense, index_dtype):
|
|
n_sparse = mat.ndim - n_batch - n_dense
|
|
data_shape = mat.shape[:n_batch] + (nse,) + mat.shape[n_batch + n_sparse:]
|
|
index_shape = mat.shape[:n_batch] + (n_sparse, nse)
|
|
return core.ShapedArray(data_shape, mat.dtype), core.ShapedArray(index_shape, index_dtype)
|
|
|
|
def _bcoo_fromdense_jvp(primals, tangents, *, nse, n_batch, n_dense, index_dtype):
|
|
M, = primals
|
|
Mdot, = tangents
|
|
|
|
primals_out = bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense, index_dtype=index_dtype)
|
|
data, indices = primals_out
|
|
|
|
if type(Mdot) is ad.Zero:
|
|
data_dot = ad.Zero.from_value(data)
|
|
else:
|
|
data_dot = bcoo_extract(indices, Mdot)
|
|
|
|
tangents_out = (data_dot, ad.Zero.from_value(indices))
|
|
|
|
return primals_out, tangents_out
|
|
|
|
def _bcoo_fromdense_transpose(ct, M, *, nse, n_batch, n_dense, index_dtype):
|
|
data, indices = ct
|
|
n_sparse = M.ndim = n_batch - n_dense
|
|
assert data.shape == M.shape[:n_batch] + (nse,) + M.shape[n_batch + n_sparse:]
|
|
assert indices.shape == M.shape[:n_batch] + (n_sparse, nse)
|
|
assert indices.dtype == index_dtype
|
|
if isinstance(indices, ad.Zero):
|
|
raise ValueError("Cannot transpose with respect to sparse indices")
|
|
assert ad.is_undefined_primal(M)
|
|
return bcoo_todense(data, indices, shape=M.aval.shape)
|
|
|
|
def _bcoo_fromdense_batching_rule(batched_args, batch_dims, *, nse, n_batch, n_dense, index_dtype):
|
|
M, = batched_args
|
|
if batch_dims != (0,):
|
|
raise NotImplementedError(f"batch_dims={batch_dims}")
|
|
return bcoo_fromdense(M, nse=nse, n_batch=n_batch + 1, n_dense=n_dense, index_dtype=index_dtype), (0, 0)
|
|
|
|
ad.primitive_jvps[bcoo_fromdense_p] = _bcoo_fromdense_jvp
|
|
ad.primitive_transposes[bcoo_fromdense_p] = _bcoo_fromdense_transpose
|
|
batching.primitive_batchers[bcoo_fromdense_p] = _bcoo_fromdense_batching_rule
|
|
xla.translations[bcoo_fromdense_p] = xla.lower_fun(
|
|
_bcoo_fromdense_impl, multiple_results=True)
|
|
|
|
#----------------------------------------------------------------------
|
|
# bcoo_extract
|
|
|
|
bcoo_extract_p = core.Primitive('bcoo_extract')
|
|
|
|
def bcoo_extract(indices, mat):
|
|
"""Extract BCOO values from dense matrix `mat` at given BCOO indices."""
|
|
return bcoo_extract_p.bind(indices, mat)
|
|
|
|
@bcoo_extract_p.def_impl
|
|
def _bcoo_extract_impl(indices, mat):
|
|
n_sparse, _ = indices.shape[-2:]
|
|
n_batch = indices.ndim - 2
|
|
batch_slices = tuple(slice(s) for s in mat.shape[:n_batch])
|
|
sparse_ind = tuple(indices[tuple(np.mgrid[batch_slices]) + (i,)] for i in range(n_sparse))
|
|
batch_ind = tuple(np.mgrid[batch_slices + (slice(1),)])[:-1]
|
|
if not sparse_ind + batch_ind:
|
|
return mat[None]
|
|
return mat[batch_ind + sparse_ind]
|
|
|
|
@bcoo_extract_p.def_abstract_eval
|
|
def _bcoo_extract_abstract_eval(indices, mat):
|
|
n_sparse, nse = indices.shape[-2:]
|
|
n_batch = indices.ndim - 2
|
|
n_dense = mat.ndim - n_sparse - n_batch
|
|
assert mat.shape[:n_batch] == indices.shape[:n_batch]
|
|
out_shape = mat.shape[:n_batch] + (nse,) + mat.shape[mat.ndim - n_dense:]
|
|
return core.ShapedArray(out_shape, mat.dtype)
|
|
|
|
def _bcoo_extract_jvp(mat_dot, indices, mat):
|
|
assert mat_dot.shape == mat.shape
|
|
return bcoo_extract(indices, mat_dot)
|
|
|
|
def _bcoo_extract_transpose(ct, indices, mat):
|
|
assert ad.is_undefined_primal(mat)
|
|
if ad.is_undefined_primal(indices):
|
|
raise ValueError("Cannot transpose with respect to sparse indices")
|
|
assert ct.dtype == mat.aval.dtype
|
|
return indices, bcoo_todense(ct, indices, shape=mat.aval.shape)
|
|
|
|
def _bcoo_extract_batching_rule(batched_args, batch_dims):
|
|
indices, mat = batched_args
|
|
assert any(b is not None for b in batch_dims)
|
|
if batch_dims[0] is None:
|
|
bdim = batch_dims[1]
|
|
indices = lax.expand_dims(indices, (bdim,))
|
|
elif batch_dims[1] is None:
|
|
bdim = batch_dims[0]
|
|
mat = lax.expand_dims(mat, (bdim,))
|
|
else:
|
|
assert batch_dims[0] == batch_dims[1]
|
|
bdim = batch_dims[0]
|
|
n_batch = indices.ndim - 2
|
|
if bdim >= n_batch:
|
|
raise ValueError(f"batch_dims={batch_dims} out of range for indices with n_batch={n_batch}")
|
|
return bcoo_extract(indices, mat), bdim
|
|
|
|
ad.defjvp(bcoo_extract_p, None, _bcoo_extract_jvp)
|
|
ad.primitive_transposes[bcoo_extract_p] = _bcoo_extract_transpose
|
|
batching.primitive_batchers[bcoo_extract_p] = _bcoo_extract_batching_rule
|
|
xla.translations[bcoo_extract_p] = xla.lower_fun(
|
|
_bcoo_extract_impl, multiple_results=False)
|
|
|
|
#----------------------------------------------------------------------
|
|
# bcoo_transpose
|
|
# transpose of a BCOO array
|
|
|
|
bcoo_transpose_p = core.Primitive('bcoo_transpose')
|
|
bcoo_transpose_p.multiple_results = True
|
|
|
|
def bcoo_transpose(data, indices, *, permutation, shape):
|
|
if tuple(permutation) == tuple(range(len(shape))):
|
|
return data, indices
|
|
else:
|
|
return bcoo_transpose_p.bind(data, indices, permutation=permutation, shape=shape)
|
|
|
|
def _validate_permutation(data, indices, permutation, shape):
|
|
if not isinstance(permutation, (tuple, list, np.ndarray)):
|
|
raise TypeError(f"transpose permutation must be a tuple/list/ndarray, got {type(permutation)}.")
|
|
if tuple(sorted(permutation)) != tuple(range(len(shape))):
|
|
raise TypeError("transpose permutation isn't a permutation of operand dimensions, "
|
|
f"got permutation {permutation} for shape {shape}.")
|
|
n_batch, n_sparse, n_dense = _validate_bcoo(data, indices, shape)
|
|
batch_perm = permutation[:n_batch]
|
|
sparse_perm = [p - n_batch for p in permutation[n_batch: n_batch + n_sparse]]
|
|
dense_perm = [p - n_sparse - n_batch for p in permutation[n_batch + n_sparse:]]
|
|
if n_batch and tuple(sorted(batch_perm)) != tuple(range(n_batch)):
|
|
raise NotImplementedError("transpose permutation cannot permute batch axes with non-batch axes; "
|
|
f"got permutation {permutation}, with n_batch={n_batch}.")
|
|
if n_dense and tuple(sorted(dense_perm)) != tuple(range(n_dense)):
|
|
raise NotImplementedError("transpose permutation cannot permute dense axes with non-dense axes; "
|
|
f"got permutation {permutation}, with n_dense={n_dense}.")
|
|
return batch_perm, sparse_perm, dense_perm
|
|
|
|
@bcoo_transpose_p.def_impl
|
|
def _bcoo_transpose_impl(data, indices, *, permutation: Sequence[int], shape: Tuple[int]):
|
|
batch_perm, sparse_perm, dense_perm = _validate_permutation(data, indices, permutation, shape)
|
|
n_batch = len(batch_perm)
|
|
indices = indices[..., sparse_perm, :].transpose(*batch_perm, n_batch, n_batch + 1)
|
|
data = data.transpose(*batch_perm, n_batch, *(d + n_batch + 1 for d in dense_perm))
|
|
return data, indices
|
|
|
|
@bcoo_transpose_p.def_abstract_eval
|
|
def _bcoo_transpose_abstract_eval(data, indices, *, permutation: Sequence[int], shape: Tuple[int]):
|
|
batch_perm, _, dense_perm = _validate_permutation(data, indices, permutation, shape)
|
|
n_batch = len(batch_perm)
|
|
indices_shape = np.array(indices.shape)[[*batch_perm, n_batch, n_batch + 1]]
|
|
data_shape = np.array(data.shape)[[*batch_perm, n_batch, *(d + n_batch + 1 for d in dense_perm)]]
|
|
return core.ShapedArray(data_shape, data.dtype), core.ShapedArray(indices_shape, indices.dtype)
|
|
|
|
def _bcoo_transpose_jvp(primals, tangents, *, permutation, shape):
|
|
data, indices = primals
|
|
data_dot, _ = tangents
|
|
primals_out = bcoo_transpose(data, indices, permutation=permutation, shape=shape)
|
|
data_dot_out, _ = bcoo_transpose(data_dot, indices, permutation=permutation, shape=shape)
|
|
return primals_out, (data_dot_out, ad.Zero.from_value(indices))
|
|
|
|
def _bcoo_transpose_transpose(ct, data, indices, *, permutation, shape):
|
|
data_ct, indices_ct = ct
|
|
assert isinstance(indices_ct, ad.Zero)
|
|
if ad.is_undefined_primal(indices):
|
|
raise ValueError("Cannot transpose with respect to sparse indices")
|
|
assert data_ct.dtype == data.aval.dtype
|
|
ct_shape = tuple(shape[p] for p in permutation)
|
|
rev_permutation = np.argsort(permutation)
|
|
# TODO(jakevdp) avoid dummy indices?
|
|
dummy_indices = jnp.zeros([1 for i in range(indices.ndim - 2)] + list(indices.shape[-2:]), dtype=int)
|
|
data_trans, _ = bcoo_transpose(data_ct, dummy_indices, permutation=rev_permutation, shape=ct_shape)
|
|
return data_trans, indices_ct
|
|
|
|
def _bcoo_transpose_batch_rule(batched_args, batch_dims, *, permutation, shape):
|
|
data, indices = batched_args
|
|
batch_dims = list(batch_dims)
|
|
batch_size = max(0 if dim is None else arg.shape[dim]
|
|
for arg, dim in zip(batched_args, batch_dims))
|
|
if batch_dims[0] is None:
|
|
data = data[None]
|
|
else:
|
|
assert batch_dims[0] == 0
|
|
if batch_dims[1] is None:
|
|
indices = indices[None]
|
|
else:
|
|
assert batch_dims[1] == 0
|
|
batched_shape = (batch_size, *shape)
|
|
batched_permutation = (0, *(p + 1 for p in permutation))
|
|
data, indices = bcoo_transpose(data, indices, permutation=batched_permutation, shape=batched_shape)
|
|
if batch_dims[0] is None:
|
|
data = data[0]
|
|
if batch_dims[1] is None:
|
|
indices = indices[0]
|
|
return (data, indices), batch_dims
|
|
|
|
ad.primitive_jvps[bcoo_transpose_p] = _bcoo_transpose_jvp
|
|
ad.primitive_transposes[bcoo_transpose_p] = _bcoo_transpose_transpose
|
|
batching.primitive_batchers[bcoo_transpose_p] = _bcoo_transpose_batch_rule
|
|
xla.translations[bcoo_transpose_p] = xla.lower_fun(
|
|
_bcoo_transpose_impl, multiple_results=True)
|
|
|
|
#----------------------------------------------------------------------
|
|
# bcoo_dot_general
|
|
# (batched) general dot product of a BCOO sparse ND array and a dense ND array,
|
|
# returning a dense ND array.
|
|
|
|
bcoo_dot_general_p = core.Primitive('bcoo_dot_general')
|
|
|
|
def bcoo_dot_general(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
|
|
return bcoo_dot_general_p.bind(jnp.asarray(lhs_data), jnp.asarray(lhs_indices), jnp.asarray(rhs),
|
|
dimension_numbers=dimension_numbers, lhs_shape=tuple(lhs_shape))
|
|
|
|
def bcoo_rdot_general(lhs, rhs_data, rhs_indices, *, dimension_numbers, rhs_shape):
|
|
# TODO(jakevdp): perhaps this should be part of the bcoo_dot_general primitive?
|
|
result = bcoo_dot_general(rhs_data, rhs_indices, lhs, lhs_shape=rhs_shape,
|
|
dimension_numbers=[d[::-1] for d in dimension_numbers])
|
|
n_contract, n_batch = (len(d[0]) for d in dimension_numbers)
|
|
n_swap = len(rhs_shape) - n_contract
|
|
permutation = tuple([*range(n_batch), *range(n_swap, result.ndim), *range(n_batch, n_swap)])
|
|
return lax.transpose(result, permutation)
|
|
|
|
@bcoo_dot_general_p.def_impl
|
|
def _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
|
|
lhs_data = jnp.asarray(lhs_data)
|
|
lhs_indices = jnp.asarray(lhs_indices)
|
|
rhs = jnp.asarray(rhs)
|
|
# Validate all inputs via abstract_eval
|
|
out_aval = _bcoo_dot_general_abstract_eval(lhs_data.aval, lhs_indices.aval, rhs.aval,
|
|
dimension_numbers=dimension_numbers,
|
|
lhs_shape=lhs_shape)
|
|
|
|
(lhs_contracting, rhs_contracting) , (lhs_batch, rhs_batch) = dimension_numbers
|
|
n_sparse = lhs_indices.shape[-2]
|
|
n_batch = lhs_indices.ndim - 2
|
|
|
|
# Move lhs batch dimensions to the front
|
|
if lhs_batch:
|
|
perm = list(lhs_batch) + remaining(range(n_batch), lhs_batch)
|
|
lhs_data = lhs_data.transpose(perm + list(range(n_batch, lhs_data.ndim)))
|
|
lhs_indices = lhs_indices.transpose(perm + list(range(n_batch, lhs_indices.ndim)))
|
|
|
|
# Move lhs contracting dimensions to the front of sparse dims, in order
|
|
n_contracting = len(lhs_contracting)
|
|
lhs_contracting = [d - n_batch for d in lhs_contracting]
|
|
perm = list(lhs_contracting) + remaining(range(n_sparse), lhs_contracting)
|
|
lhs_indices = lhs_indices[..., jnp.array(perm), :]
|
|
|
|
# Move rhs batch dimensions then contracting dimensions to the front, in order
|
|
perm = (list(rhs_batch) + list(rhs_contracting) +
|
|
remaining(range(rhs.ndim), rhs_batch, rhs_contracting))
|
|
rhs = rhs.transpose(perm)
|
|
|
|
out_array = jnp.zeros(out_aval.shape, out_aval.dtype)
|
|
def result(out_array, lhs_data, lhs_indices, rhs):
|
|
idx = tuple(lhs_indices)
|
|
idx_right, idx_out = idx[:n_contracting], idx[n_contracting:]
|
|
ctc = [0] if n_contracting else []
|
|
prod = lax.dot_general(lhs_data, rhs[idx_right], (([], []), (ctc, ctc)))
|
|
return out_array.at[idx_out].add(prod) if idx_out else prod.sum(0, dtype=out_array.dtype)
|
|
for i in range(n_batch)[::-1]:
|
|
axes_in = [0, 0, 0, 0]
|
|
if lhs_data.shape[i] == 1:
|
|
lhs_data = lax.squeeze(lhs_data, (i,))
|
|
axes_in[1] = None
|
|
if lhs_indices.shape[i] == 1:
|
|
lhs_indices = lax.squeeze(lhs_indices, (i,))
|
|
axes_in[2] = None
|
|
if i >= len(lhs_batch):
|
|
axes_in[3] = None
|
|
result = vmap(result, tuple(axes_in))
|
|
return result(out_array, lhs_data, lhs_indices, rhs)
|
|
|
|
@bcoo_dot_general_p.def_abstract_eval
|
|
def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
|
|
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
|
|
n_batch, n_sparse, _ = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
|
|
|
|
# Check for proper dimension_numbers
|
|
for dims in [lhs_contracting, rhs_contracting, lhs_batch, rhs_batch]:
|
|
assert len(dims) == len(set(dims))
|
|
assert not set(lhs_contracting).intersection(lhs_batch)
|
|
assert not set(rhs_contracting).intersection(rhs_batch)
|
|
assert [lhs_shape[d] for d in lhs_contracting] == [rhs.shape[d] for d in rhs_contracting]
|
|
assert [lhs_shape[d] for d in lhs_batch] == [rhs.shape[d] for d in rhs_batch]
|
|
|
|
if lhs_batch and max(lhs_batch) >= n_batch:
|
|
raise NotImplementedError(
|
|
"bcoo_dot_general batch dimensions must be among the batch dimensions in the sparse representtaion.\n"
|
|
f"got lhs_batch={lhs_batch}, n_batch={n_batch}")
|
|
|
|
# TODO: support constraction of batch dimensions?
|
|
if any(d < n_batch for d in lhs_contracting):
|
|
raise NotImplementedError("bcoo_dot_general: contracting over batch dimensions.")
|
|
|
|
# TODO: support contraction of dense dimensions?
|
|
if any(d >= n_batch + n_sparse for d in lhs_contracting):
|
|
raise NotImplementedError("bcoo_dot_general: contracting over dense dimensions.")
|
|
|
|
out_dtype = jnp.promote_types(lhs_data.dtype, rhs.dtype)
|
|
out_shape = (tuple(lhs_shape[i] for i in lhs_batch) +
|
|
tuple(s for i, s in enumerate(lhs_shape) if i not in {*lhs_contracting, *lhs_batch}) +
|
|
tuple(s for i, s in enumerate(rhs.shape) if i not in {*rhs_contracting, *rhs_batch}))
|
|
return core.ShapedArray(out_shape, out_dtype)
|
|
|
|
def _bcoo_dot_general_jvp_lhs(lhs_data_dot, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
|
|
return bcoo_dot_general(lhs_data_dot, lhs_indices, rhs, dimension_numbers=dimension_numbers, lhs_shape=lhs_shape)
|
|
|
|
def _bcoo_dot_general_jvp_rhs(rhs_dot, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
|
|
return bcoo_dot_general(lhs_data, lhs_indices, rhs_dot, dimension_numbers=dimension_numbers, lhs_shape=lhs_shape)
|
|
|
|
def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
|
|
assert not ad.is_undefined_primal(lhs_indices)
|
|
if type(ct) is ad.Zero:
|
|
return ad.Zero
|
|
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
|
|
lhs_ndim = len(lhs_shape)
|
|
rhs_ndim = rhs.aval.ndim if ad.is_undefined_primal(rhs) else rhs.ndim
|
|
lhs_kept = remaining(range(lhs_ndim), lhs_contract, lhs_batch)
|
|
rhs_kept = remaining(range(rhs_ndim), rhs_contract, rhs_batch)
|
|
ans_batch, ans_lhs, ans_rhs = map(list, ranges_like(lhs_batch, lhs_kept, rhs_kept))
|
|
if ad.is_undefined_primal(lhs_data):
|
|
dims = ((ans_rhs, rhs_kept), (ans_batch, rhs_batch))
|
|
lhs_contract_sorted_by_rhs = list(np.take(lhs_contract, np.argsort(rhs_contract)))
|
|
permutation = list(lhs_batch) + lhs_kept + lhs_contract_sorted_by_rhs
|
|
out_axes = np.argsort(permutation)
|
|
|
|
# What follows is essentially this, but computed in terms of dot_general_sampled:
|
|
# out_dense_T = lax.dot_general(ct, rhs, dimension_numbers=dims)
|
|
# out_dense = lax.transpose(out_dense_T, out_axes)
|
|
# result = bcoo_extract(lhs_indices, out_dense)
|
|
|
|
# Instead we (1) un-transpose indices, (2) compute SDDMM, (3) re-transpose result
|
|
dummy_data = jnp.ones([1 for i in range(lhs_indices.ndim - 2)] + [lhs_indices.shape[-1]])
|
|
dummy_shape = tuple(lhs_indices.shape[:-2]) + tuple(1 for i in range(lhs_indices.shape[-2]))
|
|
_, lhs_indices_T = bcoo_transpose(dummy_data, lhs_indices, permutation=permutation, shape=dummy_shape)
|
|
result_T = bcoo_dot_general_sampled(ct, rhs, lhs_indices_T, dimension_numbers=dims)
|
|
result, _ = bcoo_transpose(result_T, lhs_indices_T, permutation=out_axes, shape=dummy_shape)
|
|
|
|
return result, lhs_indices, rhs
|
|
else:
|
|
dims = ((lhs_kept, ans_lhs), (lhs_batch, ans_batch))
|
|
rhs_contract_sorted_by_lhs = list(np.take(rhs_contract, np.argsort(lhs_contract)))
|
|
out_axes = np.argsort(list(rhs_batch) + rhs_contract_sorted_by_lhs + rhs_kept)
|
|
result = bcoo_dot_general(lhs_data, lhs_indices, ct, lhs_shape=lhs_shape, dimension_numbers=dims)
|
|
return lhs_data, lhs_indices, lax.transpose(result, out_axes)
|
|
|
|
def _bcoo_dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, lhs_shape):
|
|
lhs_data, lhs_indices, rhs = batched_args
|
|
batch_dims = list(batch_dims)
|
|
batch_size = max(0 if dim is None else arg.shape[dim]
|
|
for arg, dim in zip(batched_args, batch_dims))
|
|
if batch_dims[0] is None:
|
|
lhs_data = lhs_data[None]
|
|
batch_dims[0] = 0
|
|
if batch_dims[1] is None:
|
|
lhs_indices = lhs_indices[None]
|
|
batch_dims[1] = 0
|
|
# TODO: handle different batchings between lhs_data and lhs_indices?
|
|
assert batch_dims[0] == batch_dims[1] == 0
|
|
new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
|
|
(len(lhs_shape), rhs.ndim), (batch_dims[0], batch_dims[2]), dimension_numbers)
|
|
new_shape = (batch_size, *lhs_shape)
|
|
batched_out = bcoo_dot_general(lhs_data, lhs_indices, rhs, lhs_shape=new_shape,
|
|
dimension_numbers=new_dimension_numbers)
|
|
return batched_out, result_batch_dim
|
|
|
|
ad.defjvp(bcoo_dot_general_p, _bcoo_dot_general_jvp_lhs, None, _bcoo_dot_general_jvp_rhs)
|
|
ad.primitive_transposes[bcoo_dot_general_p] = _bcoo_dot_general_transpose
|
|
batching.primitive_batchers[bcoo_dot_general_p] = _bcoo_dot_general_batch_rule
|
|
xla.translations[bcoo_dot_general_p] = xla.lower_fun(
|
|
_bcoo_dot_general_impl, multiple_results=False)
|
|
|
|
#----------------------------------------------------------------------
|
|
# bcoo_dot_general_sampled
|
|
# (batched) general sampled dot product of two dense ND arrays, with
|
|
# output computed only at a given set of sparse indices.
|
|
|
|
bcoo_dot_general_sampled_p = core.Primitive("bcoo_dot_general_sampled")
|
|
|
|
def bcoo_dot_general_sampled(A, B, indices, *, dimension_numbers):
|
|
return bcoo_dot_general_sampled_p.bind(A, B, indices, dimension_numbers=dimension_numbers)
|
|
|
|
@bcoo_dot_general_sampled_p.def_impl
|
|
def _bcoo_dot_general_sampled_impl(A, B, indices, *, dimension_numbers):
|
|
# TODO(jakevdp): use a more efficient implementation that avoids the full dot product.
|
|
dense_result = lax.dot_general(A, B, dimension_numbers=dimension_numbers)
|
|
return bcoo_extract(indices, dense_result)
|
|
|
|
@bcoo_dot_general_sampled_p.def_abstract_eval
|
|
def _bcoo_dot_general_sampled_abstract_eval(A, B, indices, *, dimension_numbers):
|
|
dense_result, = pe.abstract_eval_fun(lambda *args: [lax.dot_general(*args, dimension_numbers=dimension_numbers)], A, B)
|
|
sparse_result, = pe.abstract_eval_fun(lambda *args: [bcoo_extract(*args)], indices, dense_result)
|
|
return sparse_result
|
|
|
|
def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers):
|
|
A_shape = A.aval.shape if hasattr(A, 'aval') else A.shape
|
|
B_shape = B.aval.shape if hasattr(B, 'aval') else B.shape
|
|
mat_shape = _dot_general_shape_computation(
|
|
A_shape, B_shape, dimension_numbers=dimension_numbers)
|
|
mat = ad.UndefinedPrimal(core.ShapedArray(mat_shape, ct.dtype))
|
|
indices, ct = _bcoo_extract_transpose(ct, indices, mat)
|
|
kwds = {'dimension_numbers': dimension_numbers,
|
|
'precision': None,
|
|
'preferred_element_type': None}
|
|
A, B = ad.get_primitive_transpose(lax.dot_general_p)(ct, A, B, **kwds)
|
|
return A, B, indices
|
|
|
|
def _bcoo_dot_general_sampled_jvp_A(A_dot, A, B, indices, *, dimension_numbers):
|
|
return bcoo_dot_general_sampled(A_dot, B, indices, dimension_numbers=dimension_numbers)
|
|
|
|
def _bcoo_dot_general_sampled_jvp_B(B_dot, A, B, indices, *, dimension_numbers):
|
|
return bcoo_dot_general_sampled(A, B_dot, indices, dimension_numbers=dimension_numbers)
|
|
|
|
def _bcoo_dot_general_sampled_batch_rule(batched_args, batch_dims, *, dimension_numbers):
|
|
def impl(A, B, indices):
|
|
return _bcoo_dot_general_sampled_impl(A, B, indices, dimension_numbers=dimension_numbers)
|
|
return vmap(impl, in_axes=batch_dims, out_axes=0)(*batched_args), 0
|
|
|
|
ad.defjvp(bcoo_dot_general_sampled_p, _bcoo_dot_general_sampled_jvp_A,
|
|
_bcoo_dot_general_sampled_jvp_B, None)
|
|
ad.primitive_transposes[bcoo_dot_general_sampled_p] = _bcoo_dot_general_sampled_transpose
|
|
batching.primitive_batchers[bcoo_dot_general_sampled_p] = _bcoo_dot_general_sampled_batch_rule
|
|
xla.translations[bcoo_dot_general_sampled_p] = xla.lower_fun(
|
|
_bcoo_dot_general_sampled_impl, multiple_results=False)
|
|
|
|
#----------------------------------------------------------------------
|
|
# BCOO functions that maybe should be primitives?
|
|
|
|
def _tuple_replace(tup, ind, val):
|
|
return tuple(val if i == ind else t for i, t in enumerate(tup))
|
|
|
|
def bcoo_reduce_sum(data, indices, *, shape, axes):
|
|
assert all(0 <= a < len(shape) for a in axes)
|
|
axes = sorted(set(axes))
|
|
n_sparse, nse = indices.shape[-2:]
|
|
n_batch = indices.ndim - 2
|
|
|
|
# Sum over dense dimensions -> sum over data
|
|
dense_axes = tuple(ax - n_sparse + 1 for ax in axes if ax >= n_batch + n_sparse)
|
|
data = data.sum(dense_axes)
|
|
|
|
# Sum over sparse dimensions -> drop index; sum is implicit
|
|
sparse_idx = [i for i in range(n_sparse) if i + n_batch not in axes]
|
|
if not sparse_idx:
|
|
indices = jnp.zeros(_tuple_replace(indices.shape, n_batch, 0), indices.dtype)
|
|
else:
|
|
indices = indices[..., np.array(sparse_idx), :]
|
|
|
|
# Sum over batch dimensions -> reshape into nse
|
|
batch_axes = {ax for ax in axes if ax < n_batch}
|
|
|
|
# First handle broadcasted batch dimensions
|
|
for ax in batch_axes:
|
|
if data.shape[ax] == 1:
|
|
if indices.shape[ax] == 1:
|
|
data = data * shape[ax]
|
|
else:
|
|
data = lax.broadcast_in_dim(data, _tuple_replace(data.shape, ax, shape[ax]), tuple(range(data.ndim)))
|
|
else:
|
|
if indices.shape[ax] == 1:
|
|
data = data.sum(ax)
|
|
assert data.shape[ax] == indices.shape[ax]
|
|
|
|
new_batch_dims = tuple(sorted(set(range(n_batch)) - batch_axes))
|
|
new_batch_shape = tuple(data.shape[i] for i in new_batch_dims)
|
|
new_nse = int(nse * np.prod([data.shape[i] for i in batch_axes]))
|
|
|
|
data = lax.reshape(data,
|
|
new_batch_shape + (new_nse,) + data.shape[n_batch + 1:],
|
|
new_batch_dims + tuple(batch_axes) + tuple(range(n_batch, data.ndim)))
|
|
indices = lax.reshape(indices,
|
|
new_batch_shape + (indices.shape[n_batch], new_nse),
|
|
new_batch_dims + (n_batch,) + tuple(batch_axes) + tuple(range(n_batch + 1, indices.ndim)))
|
|
|
|
out_shape = tuple(shape[i] for i in range(len(shape)) if i not in axes)
|
|
return data, indices, out_shape
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
# Sparse objects (APIs subject to change)
|
|
class JAXSparse:
|
|
"""Base class for high-level JAX sparse objects."""
|
|
data: jnp.ndarray
|
|
shape: Tuple[int, int]
|
|
nse: property
|
|
dtype: property
|
|
|
|
@property
|
|
def ndim(self):
|
|
return len(self.shape)
|
|
|
|
def __init__(self, args, *, shape):
|
|
self.shape = shape
|
|
|
|
def __repr__(self):
|
|
repr_ = f"{self.__class__.__name__}({self.dtype}{list(self.shape)}, nse={self.nse})"
|
|
if isinstance(self.data, core.Tracer):
|
|
repr_ = f"{type(self.data).__name__}[{repr_}]"
|
|
return repr_
|
|
|
|
def tree_flatten(self):
|
|
raise NotImplementedError("tree_flatten")
|
|
|
|
@classmethod
|
|
def tree_unflatten(cls, aux_data, children):
|
|
return cls(children, **aux_data)
|
|
|
|
def matvec(self, v):
|
|
raise NotImplementedError("matvec")
|
|
|
|
def matmat(self, B):
|
|
raise NotImplementedError("matmat")
|
|
|
|
def transpose(self, axes=None):
|
|
raise NotImplementedError()
|
|
|
|
@property
|
|
def T(self):
|
|
return self.transpose()
|
|
|
|
def __matmul__(self, other):
|
|
if isinstance(other, JAXSparse):
|
|
raise NotImplementedError("matmul between two sparse objects.")
|
|
other = jnp.asarray(other)
|
|
if other.ndim == 1:
|
|
return self.matvec(other)
|
|
elif other.ndim == 2:
|
|
return self.matmat(other)
|
|
else:
|
|
raise NotImplementedError(f"matmul with object of shape {other.shape}")
|
|
|
|
|
|
@tree_util.register_pytree_node_class
|
|
class CSR(JAXSparse):
|
|
"""Experimental CSR matrix implemented in JAX; API subject to change."""
|
|
data: jnp.ndarray
|
|
indices: jnp.ndarray
|
|
indptr: jnp.ndarray
|
|
nse = property(lambda self: self.data.size)
|
|
dtype = property(lambda self: self.data.dtype)
|
|
|
|
def __init__(self, args, *, shape):
|
|
self.data, self.indices, self.indptr = map(jnp.asarray, args)
|
|
super().__init__(args, shape=shape)
|
|
|
|
@classmethod
|
|
def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
|
|
if nse is None:
|
|
nse = (mat != 0).sum()
|
|
return cls(csr_fromdense(mat, nse=nse, index_dtype=index_dtype), shape=mat.shape)
|
|
|
|
@api.jit
|
|
def todense(self):
|
|
return csr_todense(self.data, self.indices, self.indptr, shape=self.shape)
|
|
|
|
@api.jit
|
|
def matvec(self, v):
|
|
return csr_matvec(self.data, self.indices, self.indptr, v, shape=self.shape)
|
|
|
|
@api.jit
|
|
def matmat(self, B):
|
|
return csr_matmat(self.data, self.indices, self.indptr, B, shape=self.shape)
|
|
|
|
def transpose(self, axes=None):
|
|
assert axes is None
|
|
return CSC((self.data, self.indices, self.indptr), shape=self.shape[::-1])
|
|
|
|
def tree_flatten(self):
|
|
return (self.data, self.indices, self.indptr), {"shape": self.shape}
|
|
|
|
|
|
@tree_util.register_pytree_node_class
|
|
class CSC(JAXSparse):
|
|
"""Experimental CSC matrix implemented in JAX; API subject to change."""
|
|
data: jnp.ndarray
|
|
indices: jnp.ndarray
|
|
indptr: jnp.ndarray
|
|
nse = property(lambda self: self.data.size)
|
|
dtype = property(lambda self: self.data.dtype)
|
|
|
|
def __init__(self, args, *, shape):
|
|
self.data, self.indices, self.indptr = map(jnp.asarray, args)
|
|
super().__init__(args, shape=shape)
|
|
|
|
@classmethod
|
|
def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
|
|
if nse is None:
|
|
nse = (mat != 0).sum()
|
|
return cls(csr_fromdense(mat.T, nse=nse, index_dtype=index_dtype), shape=mat.shape)
|
|
|
|
@api.jit
|
|
def todense(self):
|
|
return csr_todense(self.data, self.indices, self.indptr, shape=self.shape[::-1]).T
|
|
|
|
@api.jit
|
|
def matvec(self, v):
|
|
return csr_matvec(self.data, self.indices, self.indptr, v, shape=self.shape[::-1], transpose=True)
|
|
|
|
@api.jit
|
|
def matmat(self, B):
|
|
return csr_matmat(self.data, self.indices, self.indptr, B, shape=self.shape[::-1], transpose=True)
|
|
|
|
def transpose(self, axes=None):
|
|
assert axes is None
|
|
return CSR((self.data, self.indices, self.indptr), shape=self.shape[::-1])
|
|
|
|
def tree_flatten(self):
|
|
return (self.data, self.indices, self.indptr), {"shape": self.shape}
|
|
|
|
|
|
@tree_util.register_pytree_node_class
|
|
class COO(JAXSparse):
|
|
"""Experimental COO matrix implemented in JAX; API subject to change."""
|
|
data: jnp.ndarray
|
|
row: jnp.ndarray
|
|
col: jnp.ndarray
|
|
nse = property(lambda self: self.data.size)
|
|
dtype = property(lambda self: self.data.dtype)
|
|
|
|
def __init__(self, args, *, shape):
|
|
self.data, self.row, self.col = map(jnp.asarray, args)
|
|
super().__init__(args, shape=shape)
|
|
|
|
@classmethod
|
|
def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
|
|
if nse is None:
|
|
nse = (mat != 0).sum()
|
|
return cls(coo_fromdense(mat, nse=nse, index_dtype=index_dtype), shape=mat.shape)
|
|
|
|
@api.jit
|
|
def todense(self):
|
|
return coo_todense(self.data, self.row, self.col, shape=self.shape)
|
|
|
|
@api.jit
|
|
def matvec(self, v):
|
|
return coo_matvec(self.data, self.row, self.col, v, shape=self.shape)
|
|
|
|
@api.jit
|
|
def matmat(self, B):
|
|
return coo_matmat(self.data, self.row, self.col, B, shape=self.shape)
|
|
|
|
def transpose(self, axes=None):
|
|
assert axes is None
|
|
return COO((self.data, self.col, self.row), shape=self.shape[::-1])
|
|
|
|
def tree_flatten(self):
|
|
return (self.data, self.row, self.col), {"shape": self.shape}
|
|
|
|
|
|
def _is_dummy(*args):
|
|
return all(type(arg) is object for arg in args) or all(arg is None for arg in args)
|
|
|
|
|
|
@tree_util.register_pytree_node_class
|
|
class BCOO(JAXSparse):
|
|
"""Experimental BCOO matrix implemented in JAX; API subject to change."""
|
|
data: jnp.ndarray
|
|
indices: jnp.ndarray
|
|
nse = property(lambda self: self.data.size)
|
|
dtype = property(lambda self: self.data.dtype)
|
|
n_batch = property(lambda self: self.indices.ndim - 2)
|
|
n_sparse = property(lambda self: self.indices.shape[-2])
|
|
n_dense = property(lambda self: self.data.ndim - 1 - self.n_batch)
|
|
shape = Tuple[int, ...]
|
|
|
|
@property
|
|
def _sparse_shape(self):
|
|
return tuple(self.shape[self.indices.ndim - 2:][:self.indices.shape[-2]])
|
|
|
|
def __init__(self, args, *, shape):
|
|
self.data, self.indices = args
|
|
super().__init__(args, shape=shape)
|
|
|
|
@classmethod
|
|
def fromdense(cls, mat, *, nse=None, index_dtype=np.int32, n_dense=0, n_batch=0):
|
|
return cls(bcoo_fromdense(mat, nse=nse, index_dtype=index_dtype, n_dense=n_dense, n_batch=n_batch), shape=mat.shape)
|
|
|
|
@api.jit
|
|
def todense(self):
|
|
return bcoo_todense(self.data, self.indices, shape=self.shape)
|
|
|
|
def __matmul__(self, other):
|
|
if isinstance(other, JAXSparse):
|
|
raise NotImplementedError("sparse-sparse matmul")
|
|
other = jnp.asarray(other)
|
|
if self.ndim == 0 or other.ndim == 0:
|
|
raise ValueError("matmul inputs cannot be zero-dimensional.")
|
|
if self.ndim > 2 or other.ndim > 2:
|
|
raise NotImplementedError("sparse matmul for dimensions larger than 2")
|
|
dtype = jnp.promote_types(self.dtype, other.dtype)
|
|
return bcoo_dot_general(self.data.astype(dtype), self.indices, other.astype(dtype),
|
|
lhs_shape=self.shape,
|
|
dimension_numbers=(([self.ndim - 1], [0]), ([], [])))
|
|
|
|
def __rmatmul__(self, other):
|
|
if isinstance(other, JAXSparse):
|
|
raise NotImplementedError("sparse-sparse matmul")
|
|
other = jnp.asarray(other)
|
|
if self.ndim == 0 or other.ndim == 0:
|
|
raise ValueError("matmul inputs cannot be zero-dimensional.")
|
|
if self.ndim > 2 or other.ndim > 2:
|
|
raise NotImplementedError("sparse matmul for dimensions larger than 2")
|
|
dtype = jnp.promote_types(self.dtype, other.dtype)
|
|
return bcoo_rdot_general(other.astype(dtype), self.data.astype(dtype), self.indices,
|
|
rhs_shape=self.shape,
|
|
dimension_numbers=(([other.ndim - 1], [0]), ([], [])))
|
|
|
|
def transpose(self, axes=None):
|
|
axes = np.arange(self.ndim)[::-1] if axes is None else axes
|
|
data_T, indices_T = bcoo_transpose(self.data, self.indices, shape=self.shape, permutation=axes)
|
|
shape_T = [self.shape[i] for i in axes]
|
|
return BCOO((data_T, indices_T), shape=shape_T)
|
|
|
|
def tree_flatten(self):
|
|
children = (self.data, self.indices)
|
|
# pytree sometimes creates dummy objects & we need to handle that.
|
|
sparse_shape = self.shape if _is_dummy(*children) else self._sparse_shape
|
|
# We serialize the sparse shape only to support batching.
|
|
return children, {"sparse_shape": sparse_shape}
|
|
|
|
@classmethod
|
|
def tree_unflatten(cls, aux_data, children):
|
|
data, indices = children
|
|
sparse_shape = aux_data["sparse_shape"]
|
|
# pytree sometimes creates dummy objects & we need to handle that.
|
|
if _is_dummy(data, indices):
|
|
shape = sparse_shape
|
|
else:
|
|
if np.ndim(indices) < 2 or len(sparse_shape) != np.shape(indices)[-2]:
|
|
raise ValueError(f"Invalid sparse representation: got indices.shape={np.shape(indices)}, "
|
|
f"data.shape={np.shape(data)}, sparse_shape={sparse_shape}")
|
|
n_batch = indices.ndim - 2
|
|
shape = (
|
|
tuple(np.maximum(data.shape[:n_batch], indices.shape[:n_batch]))
|
|
+ tuple(sparse_shape)
|
|
+ tuple(data.shape[n_batch + 1:]))
|
|
return cls(children, shape=shape)
|
|
|
|
# TODO(jakevdp): refactor to avoid circular imports - we can use the same strategy
|
|
# we use when adding methods to DeviceArray within lax_numpy.py
|
|
def __neg__(self):
|
|
from jax.experimental.sparse import sparsify
|
|
return sparsify(jnp.negative)(self)
|
|
|
|
def __mul__(self, other):
|
|
from jax.experimental.sparse import sparsify
|
|
return sparsify(jnp.multiply)(self, other)
|
|
|
|
def __rmul__(self, other):
|
|
from jax.experimental.sparse import sparsify
|
|
return sparsify(jnp.multiply)(other, self)
|
|
|
|
def __add__(self, other):
|
|
from jax.experimental.sparse import sparsify
|
|
return sparsify(jnp.add)(self, other)
|
|
|
|
def __radd__(self, other):
|
|
from jax.experimental.sparse import sparsify
|
|
return sparsify(jnp.add)(other, self)
|
|
|
|
def sum(self, *args, **kwargs):
|
|
from jax.experimental.sparse import sparsify
|
|
return sparsify(lambda x: x.sum(*args, **kwargs))(self)
|