2021-10-18 13:52:42 -07:00

1078 lines
48 KiB
Python

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BCOO (Bached coordinate format) matrix object and associated primitives."""
import functools
import operator
from typing import Any, NamedTuple, Sequence, Tuple
import numpy as np
import jax
from jax import core
from jax import dtypes
from jax import lax
from jax import tree_util
from jax import vmap
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
import jax.numpy as jnp
from jax.interpreters import ad
from jax.util import safe_zip
from jax._src.lax.lax import (
ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule,
DotDimensionNumbers)
from . import ops
# Loose alias used where a dtype is expected in annotations.
Dtype = Any
# An array shape: a tuple of integer dimension sizes.
Shape = Tuple[int, ...]
#----------------------------------------------------------------------
# BCOO primitives: batched extension of COO.
def _bcoo_nse(mat, n_batch=0, n_dense=0):
mat = jnp.asarray(mat)
mask = (mat != 0)
if n_dense > 0:
mask = mask.any([-(i + 1) for i in range(n_dense)])
mask = mask.sum(list(range(n_batch, mask.ndim)))
return mask.max()
def _dedupe_bcoo(data, indices, shape):
  """Merge duplicate index entries of a BCOO array, batch by batch.

  The per-batch work is done by ``_dedupe_bcoo_one``, which is vmapped once
  for every batch dimension.

  Raises:
    NotImplementedError: if the batch dimensions of ``data`` and ``indices``
      differ (i.e. one of them is broadcast).
  """
  n_batch, n_sparse, _, _ = _validate_bcoo(data, indices, shape)
  if data.shape[:n_batch] != indices.shape[:n_batch]:
    # TODO: handle broadcasted dimensions.
    raise NotImplementedError("dedupe_bcoo for broadcasted dimensions.")
  sparse_shape = shape[n_batch:n_batch + n_sparse]
  dedupe = functools.partial(_dedupe_bcoo_one, shape=sparse_shape)
  for _ in range(n_batch):
    dedupe = vmap(dedupe)
  return dedupe(data, indices)
def _dedupe_bcoo_one(data, indices, *, shape):
nse, = data.shape
assert indices.shape == (nse, len(shape))
if indices.shape[1] == 0:
return data, indices
indices_unique, inv_idx = jnp.unique(indices, axis=0, return_inverse=True,
size=nse, fill_value=jnp.array(shape))
data_unique = jnp.zeros_like(data).at[inv_idx].add(data)
oob_mask = jnp.all(indices_unique == jnp.array(shape), 1)
data_unique = jnp.where(oob_mask, 0, data_unique)
return data_unique, indices_unique
def _unbatch_bcoo(data, indices, shape):
  """Rewrite a batched BCOO array so that its batch dimensions become sparse.

  Returns:
    ``(data, indices)`` with no batch dimensions: the batch and nse axes are
    flattened into one, and the batch coordinates are prepended to each index
    row. Inputs without batch dimensions are returned unchanged.
  """
  n_batch = _validate_bcoo(data, indices, shape).n_batch
  if n_batch == 0:
    return data, indices
  batch_shape = shape[:n_batch]
  # Materialize any broadcast (size-1) batch dimensions.
  data = jnp.broadcast_to(data, batch_shape + data.shape[n_batch:])
  indices = jnp.broadcast_to(indices, batch_shape + indices.shape[n_batch:])
  # Coordinates over (batch..., nse); the nse coordinate itself is dropped.
  grid_slices = tuple(slice(None, size) for size in indices.shape[:n_batch + 1])
  batch_indices = jnp.mgrid[grid_slices][:-1]
  batch_indices = batch_indices.reshape(n_batch, -1).T
  flat_data = data.reshape(np.prod(data.shape[:n_batch + 1]), *data.shape[n_batch + 1:])
  flat_indices = indices.reshape(np.prod(indices.shape[:n_batch + 1]), *indices.shape[n_batch + 1:])
  return flat_data, jnp.hstack([batch_indices, flat_indices])
class BCOOProperties(NamedTuple):
  """Layout of a BCOO array, as inferred by ``_validate_bcoo``."""
  n_batch: int   # number of leading batch dimensions
  n_sparse: int  # number of sparse dimensions (size of the last axis of indices)
  n_dense: int   # number of trailing dense dimensions
  nse: int       # number of specified elements (second-to-last axis of indices)
def _validate_bcoo(data: jnp.ndarray, indices: jnp.ndarray, shape: Sequence[int]) -> BCOOProperties:
  """Check that ``data`` and ``indices`` are a valid BCOO representation of ``shape``.

  Args:
    data: array of stored values, or ``None`` to validate ``indices`` alone.
    indices: integer array whose last two dimensions are ``(nse, n_sparse)``;
      any leading dimensions are batch dimensions.
    shape: shape of the array being represented.

  Returns:
    ``BCOOProperties(n_batch, n_sparse, n_dense, nse)`` derived from
    ``indices.shape`` and ``shape``.

  Raises:
    ValueError: if the batch or trailing dimensions of ``data`` or ``indices``
      are inconsistent with ``shape``.
    AssertionError: if ``indices`` does not have an integer dtype, or if
      ``shape`` has fewer than ``n_batch + n_sparse`` dimensions.
  """
  assert jnp.issubdtype(indices.dtype, jnp.integer)
  shape = tuple(shape)

  # The layout is read off the indices array: batch dims, then (nse, n_sparse).
  nse, n_sparse = indices.shape[-2:]
  n_batch = indices.ndim - 2
  # Dimensions of `shape` left over after batch and sparse are dense.
  n_dense = len(shape) - n_batch - n_sparse
  assert n_dense >= 0

  def _compatible(shape1, shape2):
    # Each size in shape1 must equal the matching size in shape2 or be 1.
    return all(s1 in (1, s2) for s1, s2 in safe_zip(shape1, shape2))

  if data is not None:
    if not _compatible(data.shape[:n_batch], shape[:n_batch]):
      raise ValueError("data batch dimensions not compatible for "
                       f"data.shape={data.shape}, shape={shape}")
    # After the batch dims, data must be (nse,) followed by the dense dims.
    if data.shape[-(n_dense + 1):] != (nse,) + shape[n_batch + n_sparse:]:
      raise ValueError(f"Invalid data.shape={data.shape} for "
                       f"nse={nse}, n_batch={n_batch}, n_dense={n_dense}")

  if not _compatible(indices.shape[:n_batch], shape[:n_batch]):
    raise ValueError("indices batch dimensions not compatible for "
                     f"indices.shape={indices.shape}, shape={shape}")
  # NOTE(review): nse, n_sparse and n_batch are all derived from indices.shape
  # above, so this comparison appears to hold by construction.
  if indices.shape[n_batch:] != (nse, n_sparse):
    raise ValueError(f"Invalid indices.shape={indices.shape} for "
                     f"nse={nse}, n_batch={n_batch}, n_dense={n_dense}")

  return BCOOProperties(n_batch=n_batch, n_sparse=n_sparse, n_dense=n_dense, nse=nse)
#----------------------------------------------------------------------
# bcoo_todense
# Primitive converting a BCOO (data, indices) pair into a dense array.
# NOTE(review): the registered name carries a '_p' suffix, unlike the other
# primitives declared in this file -- confirm whether that is intentional.
bcoo_todense_p = core.Primitive('bcoo_todense_p')
def bcoo_todense(data, indices, *, shape):
  """Convert batched sparse matrix to a dense matrix.

  Args:
    data : array of shape ``batch_dims + (nse,) + block_dims``.
    indices : integer array of shape ``batch_dims + (nse, n_sparse)``
    shape : tuple; the shape of the (batched) matrix. Equal to
      ``batch_dims + sparse_dims + block_dims``
      where ``len(sparse_dims) == n_sparse``

  Returns:
    mat : array with specified shape and dtype matching ``data``
  """
  return bcoo_todense_p.bind(jnp.asarray(data), jnp.asarray(indices), shape=tuple(shape))
@bcoo_todense_p.def_impl
def _bcoo_todense_impl(data, indices, *, shape):
  """Scatter-add ``data`` into a zero array of ``shape`` at the BCOO ``indices``."""
  n_batch, n_sparse, _, _ = _validate_bcoo(data, indices, shape)

  # Batch coordinates used to read `indices`: a size-1 (broadcast) batch
  # dimension of `indices` is always read at position 0.
  ind_slices = tuple(np.zeros(s, int) if i_s == 1 else np.arange(s)
                     for s, i_s in zip(shape[:n_batch], indices.shape[:n_batch]))
  grid = tuple(np.meshgrid(*ind_slices, indexing='ij', sparse=True))
  # One index array per sparse dimension.
  sparse_ind = tuple(indices[grid + (slice(None), i)] for i in range(n_sparse))

  # Batch coordinates of the output. The extra length-1 grid axis gives the
  # batch index arrays a trailing axis that broadcasts against the nse axis;
  # its own coordinate array is dropped.
  batch_slices = tuple(np.arange(s) for s in shape[:n_batch])
  grid = np.meshgrid(*batch_slices, np.arange(1), indexing='ij', sparse=True)
  batch_ind = tuple(grid)[:-1]

  if not sparse_ind:
    # Without sparse dimensions, reduce over the nse axis directly.
    data = data.sum(n_batch, keepdims=bool(batch_ind), dtype=data.dtype)
  # .add (rather than .set) accumulates entries that share an index.
  return jnp.zeros(shape, data.dtype).at[batch_ind + sparse_ind].add(data)
@bcoo_todense_p.def_abstract_eval
def _bcoo_todense_abstract_eval(data, indices, *, shape):
  """Abstract evaluation: a dense array of ``shape`` with the dtype of ``data``."""
  # Called only for its validation of the inputs.
  _validate_bcoo(data, indices, shape)
  out_dtype = data.dtype
  return core.ShapedArray(shape, out_dtype)
def _bcoo_todense_jvp(data_dot, data, indices, *, shape):
  """JVP with respect to ``data``: densify the tangent at the same indices."""
  tangent_out = bcoo_todense(data_dot, indices, shape=shape)
  return tangent_out
def _bcoo_todense_transpose(ct, data, indices, *, shape):
  """Transpose rule: extract the dense cotangent at the sparse indices.

  Raises:
    ValueError: if ``indices`` is an undefined primal.
  """
  assert ad.is_undefined_primal(data)
  if ad.is_undefined_primal(indices):
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert ct.shape == shape
  assert ct.dtype == data.aval.dtype
  data_ct = bcoo_extract(indices, ct)
  return data_ct, indices
def _bcoo_todense_batching_rule(batched_args, batch_dims, *, shape):
data, indices = batched_args
if any(b not in [0, None] for b in batch_dims):
raise NotImplementedError(f"batch_dims={batch_dims}. Only 0 and None are supported.")
if batch_dims[0] is None:
data = data[None, ...]
if batch_dims[1] is None:
indices = indices[None, ...]
return bcoo_todense(data, indices, shape=(max(data.shape[0], indices.shape[0]), *shape)), 0
# Register the JVP (data operand only), transpose, batching and XLA lowering
# rules for bcoo_todense_p. The lowering reuses the Python implementation.
ad.defjvp(bcoo_todense_p, _bcoo_todense_jvp, None)
ad.primitive_transposes[bcoo_todense_p] = _bcoo_todense_transpose
batching.primitive_batchers[bcoo_todense_p] = _bcoo_todense_batching_rule
xla.translations[bcoo_todense_p] = xla.lower_fun(
    _bcoo_todense_impl, multiple_results=False)
#--------------------------------------------------------------------
# bcoo_fromdense
# Primitive converting a dense array into a BCOO (data, indices) pair.
bcoo_fromdense_p = core.Primitive('bcoo_fromdense')
bcoo_fromdense_p.multiple_results = True
def bcoo_fromdense(mat, *, nse=None, n_batch=0, n_dense=0, index_dtype=jnp.int32):
  """Create a BCOO-format sparse matrix from a dense matrix.

  Args:
    mat : array to be converted to BCOO, with ``ndim = n_batch + n_sparse + n_dense``.
    nse : number of specified elements in each batch. If None, it is computed
      from ``mat``; either way it must be a concrete integer.
    n_batch : number of batch dimensions (default: 0)
    n_dense : number of block_dimensions (default: 0)
    index_dtype : dtype of sparse indices (default: int32)

  Returns:
    data : array of shape ``mat.shape[:n_batch] + (nse,) + mat.shape[mat.ndim - n_dense:]``
      and dtype ``mat.dtype``
    indices : array of shape ``mat.shape[:n_batch] + (nse, n_sparse)``
  """
  mat = jnp.asarray(mat)
  if nse is None:
    nse = _bcoo_nse(mat, n_batch, n_dense)
  # nse fixes output shapes, so it has to be a concrete value.
  nse = core.concrete_or_error(operator.index, nse, "nse argument of bcoo_fromdense")
  return bcoo_fromdense_p.bind(mat, nse=nse, n_batch=n_batch, n_dense=n_dense,
                               index_dtype=index_dtype)
@bcoo_fromdense_p.def_impl
def _bcoo_fromdense_impl(mat, *, nse, n_batch, n_dense, index_dtype):
  """Find up to ``nse`` nonzero positions per batch and gather their values."""
  mat = jnp.asarray(mat)
  n_sparse = mat.ndim - n_dense - n_batch
  mask = (mat != 0)
  if n_dense > 0:
    # A sparse position is kept if any entry of its dense block is nonzero.
    mask = mask.any([-(i + 1) for i in range(n_dense)])

  def _nonzero(a):
    # Fixed-size nonzero for a single batch; unused slots are padded with the
    # out-of-bound index a.shape[:n_sparse].
    if a.ndim:
      return jnp.nonzero(a, size=nse, fill_value=a.shape[:n_sparse])
    return ()

  # Map over each batch dimension in turn.
  for _ in range(n_batch):
    _nonzero = vmap(_nonzero, 0)
  indices = _nonzero(mask)
  if not indices:
    # No sparse dimensions: the index array has zero columns.
    indices = jnp.zeros(mask.shape[:n_batch] + (nse, 0), index_dtype)
  else:
    # Stack the per-dimension index arrays and move that new leading axis to
    # the end, giving batch_dims + (nse, n_sparse).
    indices = jnp.moveaxis(jnp.array(indices, index_dtype), 0, n_batch + 1)
  data = bcoo_extract(indices, mat)

  # Zero the data in padding slots (those beyond the true nonzero count).
  true_nonzeros = jnp.arange(nse) < mask.sum(list(range(n_batch, mask.ndim)))[..., None]
  true_nonzeros = true_nonzeros[(n_batch + 1) * (slice(None),) + n_dense * (None,)]
  data = jnp.where(true_nonzeros, data, 0)

  return data, indices
@bcoo_fromdense_p.def_abstract_eval
def _bcoo_fromdense_abstract_eval(mat, *, nse, n_batch, n_dense, index_dtype):
  """Abstract evaluation: shapes and dtypes of the resulting data and indices."""
  n_sparse = mat.ndim - n_batch - n_dense
  batch_shape = mat.shape[:n_batch]
  dense_shape = mat.shape[n_batch + n_sparse:]
  data_aval = core.ShapedArray(batch_shape + (nse,) + dense_shape, mat.dtype)
  indices_aval = core.ShapedArray(batch_shape + (nse, n_sparse), index_dtype)
  return data_aval, indices_aval
def _bcoo_fromdense_jvp(primals, tangents, *, nse, n_batch, n_dense, index_dtype):
  """JVP rule: the data tangent is the dense tangent sampled at the primal indices."""
  (M,) = primals
  (Mdot,) = tangents

  primals_out = bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense, index_dtype=index_dtype)
  data, indices = primals_out

  # A symbolic-zero input tangent yields a symbolic-zero data tangent.
  data_dot = ad.Zero.from_value(data) if type(Mdot) is ad.Zero else bcoo_extract(indices, Mdot)
  # The indices output always gets a zero tangent.
  return primals_out, (data_dot, ad.Zero.from_value(indices))
def _bcoo_fromdense_transpose(ct, M, *, nse, n_batch, n_dense, index_dtype):
  """Transpose rule for bcoo_fromdense: densify the data cotangent.

  Args:
    ct: pair ``(data, indices)`` of cotangents for the primitive's outputs.
    M: the undefined primal standing in for the dense input.
    nse, n_batch, n_dense, index_dtype: parameters of the primitive.

  Returns:
    The cotangent for ``M``, with shape ``M.aval.shape``.

  Raises:
    ValueError: if the indices cotangent is a symbolic zero.
  """
  data, indices = ct
  assert ad.is_undefined_primal(M)
  # Check for a symbolic zero before touching .shape/.dtype on the cotangent,
  # so that the intended error is the one reported.
  if isinstance(indices, ad.Zero):
    raise ValueError("Cannot transpose with respect to sparse indices")
  mat_shape = M.aval.shape
  # n_sparse is what remains of the rank after batch and dense dimensions.
  n_sparse = len(mat_shape) - n_batch - n_dense
  assert data.shape == mat_shape[:n_batch] + (nse,) + mat_shape[n_batch + n_sparse:]
  # Indices are laid out as batch_dims + (nse, n_sparse), matching the
  # primitive's abstract evaluation rule.
  assert indices.shape == mat_shape[:n_batch] + (nse, n_sparse)
  assert indices.dtype == index_dtype
  return bcoo_todense(data, indices, shape=mat_shape)
def _bcoo_fromdense_batching_rule(batched_args, batch_dims, *, nse, n_batch, n_dense, index_dtype):
M, = batched_args
if batch_dims != (0,):
raise NotImplementedError(f"batch_dims={batch_dims}")
return bcoo_fromdense(M, nse=nse, n_batch=n_batch + 1, n_dense=n_dense, index_dtype=index_dtype), (0, 0)
# Register the JVP, transpose, batching and XLA lowering rules for
# bcoo_fromdense_p. The lowering reuses the Python implementation.
ad.primitive_jvps[bcoo_fromdense_p] = _bcoo_fromdense_jvp
ad.primitive_transposes[bcoo_fromdense_p] = _bcoo_fromdense_transpose
batching.primitive_batchers[bcoo_fromdense_p] = _bcoo_fromdense_batching_rule
xla.translations[bcoo_fromdense_p] = xla.lower_fun(
    _bcoo_fromdense_impl, multiple_results=True)
#----------------------------------------------------------------------
# bcoo_extract
# Primitive reading the values of a dense array at given BCOO indices.
bcoo_extract_p = core.Primitive('bcoo_extract')
def bcoo_extract(indices, mat):
  """Return the BCOO data array obtained by reading ``mat`` at ``indices``."""
  extracted = bcoo_extract_p.bind(indices, mat)
  return extracted
@bcoo_extract_p.def_impl
def _bcoo_extract_impl(indices, mat):
  """Gather the entries of ``mat`` addressed by the BCOO ``indices``."""
  mat = jnp.asarray(mat)
  n_batch, n_sparse, _, _ = _validate_bcoo(None, indices, mat.shape)

  # Batch coordinates used to read `indices`: a size-1 (broadcast) batch
  # dimension of `indices` is always read at position 0.
  ind_slices = tuple(np.zeros(s, int) if i_s == 1 else np.arange(s)
                     for s, i_s in zip(mat.shape[:n_batch], indices.shape[:n_batch]))
  grid = tuple(np.meshgrid(*ind_slices, indexing='ij', sparse=True))
  # One index array per sparse dimension.
  sparse_ind = tuple(indices[grid + (slice(None), i)] for i in range(n_sparse))

  # Batch coordinates into `mat`. The extra length-1 grid axis gives them a
  # trailing axis that broadcasts against the nse axis; its own coordinate
  # array is dropped.
  batch_slices = tuple(np.arange(s) for s in mat.shape[:n_batch])
  grid = np.meshgrid(*batch_slices, np.arange(1), indexing='ij', sparse=True)
  batch_ind = tuple(grid)[:-1]

  if not sparse_ind + batch_ind:
    # Neither batch nor sparse dimensions: all of `mat` is one dense element.
    return mat[None]
  # Out-of-bound indices read as 0.
  return mat.at[batch_ind + sparse_ind].get(mode='fill', fill_value=0)
@bcoo_extract_p.def_abstract_eval
def _bcoo_extract_abstract_eval(indices, mat):
  """Abstract evaluation: ``batch_dims + (nse,) + dense_dims`` with the dtype of ``mat``."""
  props = _validate_bcoo(None, indices, mat.shape)
  batch_shape = mat.shape[:props.n_batch]
  dense_shape = mat.shape[mat.ndim - props.n_dense:]
  return core.ShapedArray(batch_shape + (props.nse,) + dense_shape, mat.dtype)
def _bcoo_extract_jvp(mat_dot, indices, mat):
  """JVP with respect to ``mat``: sample the tangent at the same indices."""
  assert mat_dot.shape == mat.shape
  tangent_out = bcoo_extract(indices, mat_dot)
  return tangent_out
def _bcoo_extract_transpose(ct, indices, mat):
  """Transpose rule: scatter the data cotangent back into the shape of ``mat``.

  Raises:
    ValueError: if ``indices`` is an undefined primal.
  """
  assert ad.is_undefined_primal(mat)
  if ad.is_undefined_primal(indices):
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert ct.dtype == mat.aval.dtype
  mat_ct = bcoo_todense(ct, indices, shape=mat.aval.shape)
  return indices, mat_ct
def _bcoo_extract_batching_rule(batched_args, batch_dims):
indices, mat = batched_args
assert any(b is not None for b in batch_dims)
if batch_dims[0] is None:
bdim = batch_dims[1]
indices = lax.expand_dims(indices, (bdim,))
elif batch_dims[1] is None:
bdim = batch_dims[0]
mat = lax.expand_dims(mat, (bdim,))
else:
assert batch_dims[0] == batch_dims[1]
bdim = batch_dims[0]
n_batch = indices.ndim - 2
if bdim >= n_batch:
raise ValueError(f"batch_dims={batch_dims} out of range for indices with n_batch={n_batch}")
return bcoo_extract(indices, mat), bdim
# Register the JVP (mat operand only), transpose, batching and XLA lowering
# rules for bcoo_extract_p. The lowering reuses the Python implementation.
ad.defjvp(bcoo_extract_p, None, _bcoo_extract_jvp)
ad.primitive_transposes[bcoo_extract_p] = _bcoo_extract_transpose
batching.primitive_batchers[bcoo_extract_p] = _bcoo_extract_batching_rule
xla.translations[bcoo_extract_p] = xla.lower_fun(
    _bcoo_extract_impl, multiple_results=False)
#----------------------------------------------------------------------
# bcoo_transpose
# transpose of a BCOO array
bcoo_transpose_p = core.Primitive('bcoo_transpose')
# The primitive returns the pair (data, indices).
bcoo_transpose_p.multiple_results = True
def bcoo_transpose(data, indices, *, permutation, shape):
  """Transpose a BCOO array given as ``(data, indices)`` with logical ``shape``.

  The identity permutation returns the inputs unchanged without binding the
  primitive.
  """
  is_identity = tuple(permutation) == tuple(range(len(shape)))
  if is_identity:
    return data, indices
  return bcoo_transpose_p.bind(data, indices, permutation=permutation, shape=shape)
def _validate_permutation(data, indices, permutation, shape):
if not isinstance(permutation, (tuple, list, np.ndarray)):
raise TypeError(f"transpose permutation must be a tuple/list/ndarray, got {type(permutation)}.")
if tuple(sorted(permutation)) != tuple(range(len(shape))):
raise TypeError("transpose permutation isn't a permutation of operand dimensions, "
f"got permutation {permutation} for shape {shape}.")
n_batch, n_sparse, n_dense, _ = _validate_bcoo(data, indices, shape)
batch_perm = permutation[:n_batch]
sparse_perm = [p - n_batch for p in permutation[n_batch: n_batch + n_sparse]]
dense_perm = [p - n_sparse - n_batch for p in permutation[n_batch + n_sparse:]]
if n_batch and tuple(sorted(batch_perm)) != tuple(range(n_batch)):
raise NotImplementedError("transpose permutation cannot permute batch axes with non-batch axes; "
f"got permutation {permutation}, with n_batch={n_batch}.")
if n_dense and tuple(sorted(dense_perm)) != tuple(range(n_dense)):
raise NotImplementedError("transpose permutation cannot permute dense axes with non-dense axes; "
f"got permutation {permutation}, with n_dense={n_dense}.")
return batch_perm, sparse_perm, dense_perm
@bcoo_transpose_p.def_impl
def _bcoo_transpose_impl(data, indices, *, permutation: Sequence[int], shape: Tuple[int]):
  """Permute the axes of a BCOO array by rearranging ``data`` and ``indices``."""
  batch_perm, sparse_perm, dense_perm = _validate_permutation(data, indices, permutation, shape)
  n_batch = len(batch_perm)
  # Sparse axes: reorder the index columns. Batch axes: transpose the leading
  # axes, keeping the trailing (nse, n_sparse) axes in place.
  indices = indices[..., sparse_perm].transpose(*batch_perm, n_batch, n_batch + 1)
  # Data: permute batch axes, keep the nse axis, then permute the dense axes
  # (which start at position n_batch + 1).
  data = data.transpose(*batch_perm, n_batch, *(d + n_batch + 1 for d in dense_perm))
  return data, indices
@bcoo_transpose_p.def_abstract_eval
def _bcoo_transpose_abstract_eval(data, indices, *, permutation: Sequence[int], shape: Tuple[int]):
  """Abstract evaluation: permute the shapes of ``data`` and ``indices``."""
  batch_perm, _, dense_perm = _validate_permutation(data, indices, permutation, shape)
  n_batch = len(batch_perm)
  # Same axis orders as used in the implementation rule, applied to shapes.
  indices_shape = np.array(indices.shape)[[*batch_perm, n_batch, n_batch + 1]]
  data_shape = np.array(data.shape)[[*batch_perm, n_batch, *(d + n_batch + 1 for d in dense_perm)]]
  return core.ShapedArray(data_shape, data.dtype), core.ShapedArray(indices_shape, indices.dtype)
def _bcoo_transpose_jvp(primals, tangents, *, permutation, shape):
  """JVP rule: transpose the data tangent the same way as the primal data."""
  data, indices = primals
  data_dot = tangents[0]
  primals_out = bcoo_transpose(data, indices, permutation=permutation, shape=shape)
  # Only the transposed data tangent is needed; the indices get a zero tangent.
  data_dot_out = bcoo_transpose(data_dot, indices, permutation=permutation, shape=shape)[0]
  tangents_out = (data_dot_out, ad.Zero.from_value(indices))
  return primals_out, tangents_out
def _bcoo_transpose_transpose(ct, data, indices, *, permutation, shape):
  """Transpose rule: apply the inverse permutation to the data cotangent.

  Raises:
    ValueError: if ``indices`` is an undefined primal.
  """
  data_ct, indices_ct = ct
  assert isinstance(indices_ct, ad.Zero)
  if ad.is_undefined_primal(indices):
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert data_ct.dtype == data.aval.dtype

  # Shape of the transposed array, and the permutation that undoes it.
  ct_shape = tuple(shape[p] for p in permutation)
  rev_permutation = np.argsort(permutation)
  # TODO(jakevdp) avoid dummy indices?
  # Placeholder indices with size-1 batch dimensions; only the data is used.
  dummy_shape = [1] * (indices.ndim - 2) + list(indices.shape[-2:])
  dummy_indices = jnp.zeros(dummy_shape, dtype=int)
  data_trans = bcoo_transpose(data_ct, dummy_indices, permutation=rev_permutation, shape=ct_shape)[0]
  return data_trans, indices_ct
def _bcoo_transpose_batch_rule(batched_args, batch_dims, *, permutation, shape):
  """Batching rule: transpose with the mapped axis as a fixed leading batch axis."""
  data, indices = batched_args
  batch_dims = list(batch_dims)
  # Size of the mapped axis, taken from whichever operands are batched.
  batch_size = max(0 if dim is None else arg.shape[dim]
                   for arg, dim in zip(batched_args, batch_dims))
  # An unbatched operand temporarily gets a size-1 leading axis; batched
  # operands must be batched along axis 0.
  if batch_dims[0] is None:
    data = data[None]
  else:
    assert batch_dims[0] == 0
  if batch_dims[1] is None:
    indices = indices[None]
  else:
    assert batch_dims[1] == 0
  batched_shape = (batch_size, *shape)
  # The new leading axis stays in place; every other axis shifts by one.
  batched_permutation = (0, *(p + 1 for p in permutation))
  data, indices = bcoo_transpose(data, indices, permutation=batched_permutation, shape=batched_shape)
  # Drop the temporary axis so outputs are batched exactly as the inputs were.
  if batch_dims[0] is None:
    data = data[0]
  if batch_dims[1] is None:
    indices = indices[0]
  return (data, indices), batch_dims
# Register the JVP, transpose, batching and XLA lowering rules for
# bcoo_transpose_p. The lowering reuses the Python implementation.
ad.primitive_jvps[bcoo_transpose_p] = _bcoo_transpose_jvp
ad.primitive_transposes[bcoo_transpose_p] = _bcoo_transpose_transpose
batching.primitive_batchers[bcoo_transpose_p] = _bcoo_transpose_batch_rule
xla.translations[bcoo_transpose_p] = xla.lower_fun(
    _bcoo_transpose_impl, multiple_results=True)
#----------------------------------------------------------------------
# bcoo_dot_general
# (batched) general dot product of a BCOO sparse ND array and a dense ND array,
# returning a dense ND array.
bcoo_dot_general_p = core.Primitive('bcoo_dot_general')
def _dot_general_validated_shape(lhs_shape: Shape, rhs_shape: Shape, dimension_numbers: DotDimensionNumbers) -> Shape:
  """Return the dot_general output shape, validating the arguments on the way.

  The operands are stood in for by float32 abstract arrays; only their shapes
  matter to the shape rule.
  """
  lhs_aval, rhs_aval = (core.ShapedArray(operand_shape, np.float32)
                        for operand_shape in (lhs_shape, rhs_shape))
  return _dot_general_shape_rule(
      lhs_aval, rhs_aval, dimension_numbers=dimension_numbers,
      precision=None, preferred_element_type=None)
def bcoo_dot_general(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
  """General dot product of a BCOO lhs (data, indices) with a dense rhs."""
  operands = (jnp.asarray(lhs_data), jnp.asarray(lhs_indices), jnp.asarray(rhs))
  return bcoo_dot_general_p.bind(*operands, dimension_numbers=dimension_numbers,
                                 lhs_shape=tuple(lhs_shape))
def bcoo_rdot_general(lhs, rhs_data, rhs_indices, *, dimension_numbers, rhs_shape):
  """General dot product of a dense lhs with a BCOO rhs (data, indices).

  Computed by calling ``bcoo_dot_general`` with the operands swapped, then
  transposing the result so the lhs free axes come before the rhs free axes.
  """
  # TODO(jakevdp): perhaps this should be part of the bcoo_dot_general primitive?
  swapped_dimension_numbers = [d[::-1] for d in dimension_numbers]
  result = bcoo_dot_general(rhs_data, rhs_indices, lhs, lhs_shape=rhs_shape,
                            dimension_numbers=swapped_dimension_numbers)
  n_contract, n_batch = (len(d[0]) for d in dimension_numbers)
  # In `result`, axes [n_batch, n_swap) come from the sparse operand and axes
  # [n_swap, ndim) from the dense one; exchange the two groups.
  n_swap = len(rhs_shape) - n_contract
  batch_axes = range(n_batch)
  dense_axes = range(n_swap, result.ndim)
  sparse_axes = range(n_batch, n_swap)
  permutation = (*batch_axes, *dense_axes, *sparse_axes)
  return lax.transpose(result, permutation)
@bcoo_dot_general_p.def_impl
def _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
  """Implementation of bcoo_dot_general: BCOO lhs times dense rhs, dense result.

  The operands are first rearranged so that batch dimensions lead, then an
  unbatched kernel is vmapped once per batch dimension of the lhs.
  """
  lhs_data = jnp.asarray(lhs_data)
  lhs_indices = jnp.asarray(lhs_indices)
  rhs = jnp.asarray(rhs)
  # Validate all inputs via abstract_eval
  out_aval = _bcoo_dot_general_abstract_eval(lhs_data.aval, lhs_indices.aval, rhs.aval,
                                             dimension_numbers=dimension_numbers,
                                             lhs_shape=lhs_shape)

  (lhs_contracting, rhs_contracting) , (lhs_batch, rhs_batch) = dimension_numbers
  n_sparse = lhs_indices.shape[-1]
  n_batch = lhs_indices.ndim - 2

  # Move lhs batch dimensions to the front
  if lhs_batch:
    perm = list(lhs_batch) + remaining(range(n_batch), lhs_batch)
    lhs_data = lhs_data.transpose(perm + list(range(n_batch, lhs_data.ndim)))
    lhs_indices = lhs_indices.transpose(perm + list(range(n_batch, lhs_indices.ndim)))

  # Move lhs contracting dimensions to the front of sparse dims, in order
  n_contracting = len(lhs_contracting)
  # Re-express contracting dims relative to the first sparse dimension.
  lhs_contracting = [d - n_batch for d in lhs_contracting]
  perm = list(lhs_contracting) + remaining(range(n_sparse), lhs_contracting)
  lhs_indices = lhs_indices[..., jnp.array(perm)]

  # Move rhs batch dimensions then contracting dimensions to the front, in order
  perm = (list(rhs_batch) + list(rhs_contracting) +
          remaining(range(rhs.ndim), rhs_batch, rhs_contracting))
  rhs = rhs.transpose(perm)

  out_array = jnp.zeros(out_aval.shape, out_aval.dtype)
  def result(out_array, lhs_data, lhs_indices, rhs):
    # Unbatched kernel. The leading index columns address rhs (contracted
    # dims); the remaining columns address the output.
    idx = tuple(lhs_indices.T)
    idx_right, idx_out = idx[:n_contracting], idx[n_contracting:]
    # With contraction, the nse axis is treated as a batch axis of the product.
    ctc = [0] if n_contracting else []
    prod = lax.dot_general(lhs_data, rhs.at[idx_right].get(mode='fill', fill_value=0),
                           (([], []), (ctc, ctc)))
    # Accumulate at the output indices; with no output indices left, sum over
    # the leading axis instead.
    return out_array.at[idx_out].add(prod) if idx_out else prod.sum(0, dtype=out_array.dtype)
  # Wrap the kernel in one vmap per lhs batch dimension, innermost first.
  for i in range(n_batch)[::-1]:
    # in_axes for (out_array, lhs_data, lhs_indices, rhs).
    axes_in = [0, 0, 0, 0]
    # Size-1 lhs batch dimensions are broadcast: squeeze them and do not map.
    if lhs_data.shape[i] == 1:
      lhs_data = lax.squeeze(lhs_data, (i,))
      axes_in[1] = None
    if lhs_indices.shape[i] == 1:
      lhs_indices = lax.squeeze(lhs_indices, (i,))
      axes_in[2] = None
    # rhs is mapped only over the batch dimensions named in dimension_numbers.
    if i >= len(lhs_batch):
      axes_in[3] = None
    result = vmap(result, tuple(axes_in))
  return result(out_array, lhs_data, lhs_indices, rhs)
@bcoo_dot_general_p.def_abstract_eval
def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
  """Abstract evaluation of bcoo_dot_general; also validates the arguments.

  Returns:
    A ShapedArray with the dot_general output shape and the promoted dtype of
    ``lhs_data`` and ``rhs``.

  Raises:
    NotImplementedError: if a batch dimension of the lhs is not a batch
      dimension of the sparse representation, or if a contracting dimension
      of the lhs is a batch or dense dimension.
  """
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  n_batch, n_sparse, _, _ = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
  out_shape = _dot_general_validated_shape(lhs_shape, rhs.shape, dimension_numbers)

  if lhs_batch and max(lhs_batch) >= n_batch:
    raise NotImplementedError(
      "bcoo_dot_general batch dimensions must be among the batch dimensions in the sparse representtaion.\n"
      f"got lhs_batch={lhs_batch}, n_batch={n_batch}")

  # TODO: support contraction of batch dimensions?
  if any(d < n_batch for d in lhs_contracting):
    raise NotImplementedError("bcoo_dot_general: contracting over batch dimensions.")

  # TODO: support contraction of dense dimensions?
  if any(d >= n_batch + n_sparse for d in lhs_contracting):
    raise NotImplementedError("bcoo_dot_general: contracting over dense dimensions.")

  out_dtype = jnp.promote_types(lhs_data.dtype, rhs.dtype)
  return core.ShapedArray(out_shape, out_dtype)
def _bcoo_dot_general_jvp_lhs(lhs_data_dot, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
  """JVP with respect to ``lhs_data``: the same product with the tangent as data."""
  tangent_out = bcoo_dot_general(lhs_data_dot, lhs_indices, rhs,
                                 dimension_numbers=dimension_numbers, lhs_shape=lhs_shape)
  return tangent_out
def _bcoo_dot_general_jvp_rhs(rhs_dot, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
  """JVP with respect to ``rhs``: the same product with the tangent as rhs."""
  tangent_out = bcoo_dot_general(lhs_data, lhs_indices, rhs_dot,
                                 dimension_numbers=dimension_numbers, lhs_shape=lhs_shape)
  return tangent_out
def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
  """Transpose rule for bcoo_dot_general with respect to ``lhs_data`` or ``rhs``.

  Returns a triple aligned with ``(lhs_data, lhs_indices, rhs)`` in which the
  entry for the undefined primal holds its cotangent and the other entries
  are passed through, or ``ad.Zero`` when ``ct`` is a symbolic zero.
  """
  assert not ad.is_undefined_primal(lhs_indices)
  if type(ct) is ad.Zero:
    return ad.Zero
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  lhs_ndim = len(lhs_shape)
  # An undefined primal carries only an abstract value.
  rhs_ndim = rhs.aval.ndim if ad.is_undefined_primal(rhs) else rhs.ndim
  # Free (neither contracted nor batch) dimensions of each operand.
  lhs_kept = remaining(range(lhs_ndim), lhs_contract, lhs_batch)
  rhs_kept = remaining(range(rhs_ndim), rhs_contract, rhs_batch)
  # Positions of the batch, lhs-free and rhs-free dimensions within ct.
  ans_batch, ans_lhs, ans_rhs = map(list, ranges_like(lhs_batch, lhs_kept, rhs_kept))
  if ad.is_undefined_primal(lhs_data):
    # Cotangent of lhs_data: contract ct with rhs over the rhs free dims.
    dims = ((ans_rhs, rhs_kept), (ans_batch, rhs_batch))
    lhs_contract_sorted_by_rhs = list(np.take(lhs_contract, np.argsort(rhs_contract)))
    permutation = list(lhs_batch) + lhs_kept + lhs_contract_sorted_by_rhs
    out_axes = np.argsort(permutation)

    # What follows is essentially this, but computed in terms of dot_general_sampled:
    # out_dense_T = lax.dot_general(ct, rhs, dimension_numbers=dims)
    # out_dense = lax.transpose(out_dense_T, out_axes)
    # result = bcoo_extract(lhs_indices, out_dense)

    # Instead we (1) un-transpose indices, (2) compute SDDMM, (3) re-transpose result
    # Placeholder data and shape: only the transposed indices are needed.
    dummy_data = jnp.ones([1 for i in range(lhs_indices.ndim - 2)] + [lhs_indices.shape[-2]])
    dummy_shape = tuple(lhs_indices.shape[:-2]) + tuple(1 for i in range(lhs_indices.shape[-1]))
    _, lhs_indices_T = bcoo_transpose(dummy_data, lhs_indices, permutation=permutation, shape=dummy_shape)
    result_T = bcoo_dot_general_sampled(ct, rhs, lhs_indices_T, dimension_numbers=dims)
    result, _ = bcoo_transpose(result_T, lhs_indices_T, permutation=out_axes, shape=dummy_shape)

    return result, lhs_indices, rhs
  else:
    # Cotangent of rhs: contract the sparse lhs with ct over the lhs free
    # dims, then move the axes back into rhs order.
    dims = ((lhs_kept, ans_lhs), (lhs_batch, ans_batch))
    rhs_contract_sorted_by_lhs = list(np.take(rhs_contract, np.argsort(lhs_contract)))
    out_axes = np.argsort(list(rhs_batch) + rhs_contract_sorted_by_lhs + rhs_kept)
    result = bcoo_dot_general(lhs_data, lhs_indices, ct, lhs_shape=lhs_shape, dimension_numbers=dims)
    return lhs_data, lhs_indices, lax.transpose(result, out_axes)
def _bcoo_dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, lhs_shape):
  """Batching rule: fold the mapped axis into the dot_general batch dimensions."""
  lhs_data, lhs_indices, rhs = batched_args
  batch_dims = list(batch_dims)
  # Size of the mapped axis, taken from whichever operands are batched.
  batch_size = max(0 if dim is None else arg.shape[dim]
                   for arg, dim in zip(batched_args, batch_dims))
  # An unbatched sparse component gets a size-1 leading batch axis.
  if batch_dims[0] is None:
    lhs_data = lhs_data[None]
    batch_dims[0] = 0
  if batch_dims[1] is None:
    lhs_indices = lhs_indices[None]
    batch_dims[1] = 0
  # TODO: handle different batchings between lhs_data and lhs_indices?
  assert batch_dims[0] == batch_dims[1] == 0
  # Reuse the dense dot_general helper to derive the batched dimension numbers
  # and the position of the batch axis in the result.
  new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
      (len(lhs_shape), rhs.ndim), (batch_dims[0], batch_dims[2]), dimension_numbers)
  new_shape = (batch_size, *lhs_shape)
  batched_out = bcoo_dot_general(lhs_data, lhs_indices, rhs, lhs_shape=new_shape,
                                 dimension_numbers=new_dimension_numbers)
  return batched_out, result_batch_dim
# Register the JVPs (lhs_data and rhs; none for lhs_indices), transpose,
# batching and XLA lowering rules for bcoo_dot_general_p. The lowering reuses
# the Python implementation.
ad.defjvp(bcoo_dot_general_p, _bcoo_dot_general_jvp_lhs, None, _bcoo_dot_general_jvp_rhs)
ad.primitive_transposes[bcoo_dot_general_p] = _bcoo_dot_general_transpose
batching.primitive_batchers[bcoo_dot_general_p] = _bcoo_dot_general_batch_rule
xla.translations[bcoo_dot_general_p] = xla.lower_fun(
    _bcoo_dot_general_impl, multiple_results=False)
#----------------------------------------------------------------------
# bcoo_dot_general_sampled
# (batched) general sampled dot product of two dense ND arrays, with
# output computed only at a given set of sparse indices.
bcoo_dot_general_sampled_p = core.Primitive("bcoo_dot_general_sampled")
def bcoo_dot_general_sampled(A, B, indices, *, dimension_numbers):
  """Dot product of dense ``A`` and ``B``, returned only at the BCOO ``indices``."""
  sampled = bcoo_dot_general_sampled_p.bind(A, B, indices, dimension_numbers=dimension_numbers)
  return sampled
@bcoo_dot_general_sampled_p.def_impl
def _bcoo_dot_general_sampled_impl(A, B, indices, *, dimension_numbers):
  """Compute the full dense product, then extract the entries at ``indices``."""
  # TODO(jakevdp): use a more efficient implementation that avoids the full dot product.
  return bcoo_extract(indices, lax.dot_general(A, B, dimension_numbers=dimension_numbers))
@bcoo_dot_general_sampled_p.def_abstract_eval
def _bcoo_dot_general_sampled_abstract_eval(A, B, indices, *, dimension_numbers):
  """Abstract evaluation: chain the abstract rules of dot_general and bcoo_extract."""
  def _dense_product(lhs, rhs):
    return [lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)]

  def _sample(idx, mat):
    return [bcoo_extract(idx, mat)]

  dense_aval, = pe.abstract_eval_fun(_dense_product, A, B)
  sparse_aval, = pe.abstract_eval_fun(_sample, indices, dense_aval)
  return sparse_aval
def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers):
  """Transpose rule: transpose the extraction, then the dense dot_general."""
  def _shape_of(operand):
    # Use the abstract value's shape when the operand has one.
    return operand.aval.shape if hasattr(operand, 'aval') else operand.shape

  mat_shape = _dot_general_validated_shape(_shape_of(A), _shape_of(B), dimension_numbers)
  # Stand-in for the dense product, which is never materialized here.
  mat = ad.UndefinedPrimal(core.ShapedArray(mat_shape, ct.dtype))
  indices, ct = _bcoo_extract_transpose(ct, indices, mat)
  dot_general_transpose = ad.get_primitive_transpose(lax.dot_general_p)
  A, B = dot_general_transpose(ct, A, B,
                               dimension_numbers=dimension_numbers,
                               precision=None,
                               preferred_element_type=None)
  return A, B, indices
def _bcoo_dot_general_sampled_jvp_A(A_dot, A, B, indices, *, dimension_numbers):
  """JVP with respect to ``A``: the sampled product with ``A`` replaced by its tangent."""
  tangent_out = bcoo_dot_general_sampled(A_dot, B, indices, dimension_numbers=dimension_numbers)
  return tangent_out
def _bcoo_dot_general_sampled_jvp_B(B_dot, A, B, indices, *, dimension_numbers):
  """JVP with respect to ``B``: the sampled product with ``B`` replaced by its tangent."""
  tangent_out = bcoo_dot_general_sampled(A, B_dot, indices, dimension_numbers=dimension_numbers)
  return tangent_out
def _bcoo_dot_general_sampled_batch_rule(batched_args, batch_dims, *, dimension_numbers):
  """Batching rule: vmap the implementation, placing the batch axis at 0."""
  def sampled_product(lhs, rhs, idx):
    return _bcoo_dot_general_sampled_impl(lhs, rhs, idx, dimension_numbers=dimension_numbers)

  batched_product = vmap(sampled_product, in_axes=batch_dims, out_axes=0)
  return batched_product(*batched_args), 0
# Register the JVPs (A and B; none for indices), transpose, batching and XLA
# lowering rules for bcoo_dot_general_sampled_p. The lowering reuses the
# Python implementation.
ad.defjvp(bcoo_dot_general_sampled_p, _bcoo_dot_general_sampled_jvp_A,
          _bcoo_dot_general_sampled_jvp_B, None)
ad.primitive_transposes[bcoo_dot_general_sampled_p] = _bcoo_dot_general_sampled_transpose
batching.primitive_batchers[bcoo_dot_general_sampled_p] = _bcoo_dot_general_sampled_batch_rule
xla.translations[bcoo_dot_general_sampled_p] = xla.lower_fun(
    _bcoo_dot_general_sampled_impl, multiple_results=False)
#----------------------------------------------------------------------
# bcoo_spdot_general
# (batched) general dot product of two BCOO sparse arrays, returning the
# result in sparse form as a (data, indices) pair.
bcoo_spdot_general_p = core.Primitive('bcoo_spdot_general')
bcoo_spdot_general_p.multiple_results = True
def bcoo_spdot_general(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_shape, rhs_shape, dimension_numbers):
  """General dot product of two BCOO arrays, each given as (data, indices)."""
  operands = (lhs_data, lhs_indices, rhs_data, rhs_indices)
  return bcoo_spdot_general_p.bind(*operands, lhs_shape=lhs_shape, rhs_shape=rhs_shape,
                                   dimension_numbers=dimension_numbers)
def _bcoo_Mv(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_shape, rhs_shape, dtype, lhs_contract):
  """Helper function to compute the dot product of a sparse array and a sparse vector.

  Each stored lhs element is scaled by the rhs value at its ``lhs_contract``
  coordinate, and that coordinate is dropped from the output indices.
  Products landing on the same output index are not summed here.

  Returns:
    ``(out_data, out_indices)`` with ``lhs.nse`` entries and one sparse
    dimension fewer than the lhs.
  """
  lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
  rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)

  # Inputs should be unbatched; batching is handled by vmapping at the call site.
  assert lhs.n_batch == rhs.n_batch == 0
  assert lhs.n_dense == rhs.n_dense == 0
  assert lhs.n_sparse >= 1
  assert rhs.n_sparse == 1
  assert (lhs_shape[lhs_contract],) == rhs_shape

  # Merge repeated rhs indices so each rhs position has a single value.
  rhs_data, rhs_indices = _dedupe_bcoo(rhs_data, rhs_indices, rhs_shape)

  lhs_i = lhs_indices[:, lhs_contract]
  rhs_i = rhs_indices[:, 0]

  # Which lhs entries have a matching stored rhs entry.
  # NOTE(review): lhs_i may contain repeated coordinates; confirm that passing
  # assume_unique=True is safe here.
  mask = jnp.isin(lhs_i, rhs_i, assume_unique=True)
  # For each lhs entry, the position in rhs_i of the first matching index.
  lhs_i_inv = (lhs_i[None, :] == rhs_i[:, None]).argmax(0)
  # Out-of-bound lhs coordinates are redirected to position rhs_shape[0].
  lhs_i_inv = jnp.where(lhs_i < rhs_shape[0], lhs_i_inv, rhs_shape[0])
  rhs_data_at_lhs_indices = jnp.where(mask, rhs_data.at[lhs_i_inv].get(mode='fill', fill_value=0), 0)
  out_data = lhs_data.at[jnp.arange(lhs.nse)].mul(rhs_data_at_lhs_indices)
  # Remove the contracted column from the lhs indices.
  out_indices = jnp.concatenate([lhs_indices[:, :lhs_contract], lhs_indices[:, lhs_contract + 1:]], axis=1)

  return out_data, out_indices
@bcoo_spdot_general_p.def_impl
def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_shape, rhs_shape, dimension_numbers):
  """Implementation of bcoo_spdot_general via a vmapped sparse matrix-vector kernel.

  Raises:
    NotImplementedError: if data and indices of an operand have different
      batch shapes, or if the shared batch dimensions of the two operands
      differ in size.
  """
  # Validate the inputs and the (data, indices) avals of the output.
  out_shape = _dot_general_validated_shape(lhs_shape, rhs_shape, dimension_numbers)
  data_aval, indices_aval = _bcoo_spdot_general_abstract_eval(
    lhs_data.aval, lhs_indices.aval, rhs_data.aval, rhs_indices.aval,
    lhs_shape=lhs_shape, rhs_shape=rhs_shape, dimension_numbers=dimension_numbers)
  _validate_bcoo(data_aval, indices_aval, out_shape)

  lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
  rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)

  # Move batch dimension to front
  (lhs_contracting, _), (lhs_batch, rhs_batch) = dimension_numbers
  lhs_perm = tuple(lhs_batch) + tuple(i for i in range(lhs.n_batch) if i not in lhs_batch)
  rhs_perm = tuple(rhs_batch) + tuple(i for i in range(rhs.n_batch) if i not in rhs_batch)
  # The trailing (nse, n_sparse) axes of indices and the nse axis of data stay put.
  lhs_indices = lhs_indices.transpose(lhs_perm + (lhs.n_batch, lhs.n_batch + 1))
  rhs_indices = rhs_indices.transpose(rhs_perm + (rhs.n_batch, rhs.n_batch + 1))
  lhs_data = lhs_data.transpose(lhs_perm + (lhs.n_batch,))
  rhs_data = rhs_data.transpose(rhs_perm + (rhs.n_batch,))

  # Implement batched dot product via vmap
  # The kernel sees unbatched shapes and a contracting dim relative to the
  # first sparse dimension.
  func = functools.partial(_bcoo_Mv,
      lhs_shape=lhs_shape[lhs.n_batch:], rhs_shape=rhs_shape[rhs.n_batch:],
      dtype=data_aval.dtype, lhs_contract=lhs_contracting[0] - lhs.n_batch)

  if rhs_data.shape[:rhs.n_batch] != rhs_indices.shape[:rhs.n_batch]:
    raise NotImplementedError("unequal batches in rhs")
  if lhs_data.shape[:lhs.n_batch] != lhs_indices.shape[:lhs.n_batch]:
    raise NotImplementedError("unequal batches in lhs")

  # vmaps are applied innermost first: rhs-only batch dims, then lhs-only
  # batch dims, then the batch dims shared by both operands.
  for dim in reversed(range(len(rhs_batch), rhs.n_batch)):
    func = vmap(func, in_axes=(None, None, 0, 0))
  for dim in reversed(range(len(lhs_batch), lhs.n_batch)):
    func = vmap(func, in_axes=(0, 0, None, None))
  for dim in range(len(lhs_batch)):
    if lhs_data.shape[dim] != rhs_data.shape[dim]:
      raise NotImplementedError("unequal batches in batched dims")
    func = vmap(func, in_axes=0)
  return func(lhs_data, lhs_indices, rhs_data, rhs_indices)
@bcoo_spdot_general_p.def_abstract_eval
def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_shape, rhs_shape, dimension_numbers):
  """Abstract evaluation rule for ``bcoo_spdot_general_p``.

  Checks that the requested contraction is one of the supported cases and
  computes the shape and dtype of the output data and indices.

  Returns:
    A pair of ``core.ShapedArray``: the abstract output data and indices.

  Raises:
    NotImplementedError: for dense dimensions, an rhs with ``n_sparse != 1``,
      batch dims that are not sparse-representation batch dims, or a
      contraction over anything other than sparse indices.
    ValueError: for unused rhs batch dims combined with unused lhs sparse dims.
  """
  lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
  rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  # Called only for its validation of the dimension numbers; result is unused.
  _dot_general_validated_shape(lhs_shape, rhs_shape, dimension_numbers)

  if lhs.n_dense != 0 or rhs.n_dense != 0:
    # TODO(jakevdp): handle dense dimensions
    raise NotImplementedError("bcoo_spdot_general with dense dimensions.")

  if rhs.n_sparse != 1:
    raise NotImplementedError("bcoo_spdot_general with n_sparse != 1 on the rhs.")

  if (max(lhs_batch, default=-1) >= lhs.n_batch
      or max(rhs_batch, default=-1) >= rhs.n_batch):
    raise NotImplementedError("bcoo_spdot_general: batch_dims must correspond to batch dimensions of the sparse representation.")

  lhs_sparse_dims = range(lhs.n_batch, lhs.n_batch + lhs.n_sparse)
  if (tuple(rhs_contracting) != (rhs.n_batch,)
      or lhs_contracting[0] not in lhs_sparse_dims):
    raise NotImplementedError("bcoo_spdot_general only supports contraction of sparse indices.")

  if rhs.n_batch > len(rhs_batch) and lhs.n_sparse > len(lhs_contracting):
    raise ValueError("Cannot have unused batch dims on rhs with unused sparse dims on lhs.")

  # Output batch layout: shared batch dims, then lhs-only, then rhs-only.
  shared_batch_shape = [lhs_shape[d] for d in lhs_batch]
  lhs_only = [d for d in range(lhs.n_batch) if d not in lhs_batch]
  rhs_only = [d for d in range(rhs.n_batch) if d not in rhs_batch]

  data_shape = (
      *shared_batch_shape,
      *[lhs_data.shape[d] for d in lhs_only],
      *[rhs_data.shape[d] for d in rhs_only],
      lhs.nse)
  indices_shape = (
      *shared_batch_shape,
      *[lhs_indices.shape[d] for d in lhs_only],
      *[rhs_indices.shape[d] for d in rhs_only],
      lhs.nse, lhs.n_sparse - len(lhs_contracting))

  out_dtype = jnp.promote_types(lhs_data.dtype, rhs_data.dtype)
  data_aval = core.ShapedArray(data_shape, out_dtype)
  indices_aval = core.ShapedArray(indices_shape, lhs_indices.dtype)
  return data_aval, indices_aval
def _bcoo_spdot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, lhs_shape, rhs_shape):
lhs_data, lhs_indices, rhs_data, rhs_indices = batched_args
batch_dims = list(batch_dims)
batch_size = max(0 if dim is None else arg.shape[dim]
for arg, dim in zip(batched_args, batch_dims))
if batch_dims[0] is None:
lhs_data = lhs_data[None]
batch_dims[0] = 0
if batch_dims[1] is None:
lhs_indices = lhs_indices[None]
batch_dims[1] = 0
assert batch_dims[0] == batch_dims[1] == 0
if batch_dims[2] is None:
rhs_data = rhs_data[None]
batch_dims[2] = 0
if batch_dims[3] is None:
rhs_indices = rhs_indices[None]
batch_dims[3] = 0
if any(dim != 0 for dim in batch_dims):
raise NotImplementedError("batching along non-leading dimension.")
assert all(dim == 0 for dim in batch_dims)
new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
(len(lhs_shape), len(rhs_shape)), (batch_dims[0], batch_dims[2]), dimension_numbers)
new_lhs_shape = (batch_size, *lhs_shape)
new_rhs_shape = (batch_size, *rhs_shape)
batched_out = bcoo_spdot_general(lhs_data, lhs_indices, rhs_data, rhs_indices,
dimension_numbers=new_dimension_numbers,
lhs_shape=new_lhs_shape, rhs_shape=new_rhs_shape)
return batched_out, (result_batch_dim, result_batch_dim)
# TODO(JVP): jvp, transpose
# Register the batching rule for the bcoo_spdot_general primitive.
batching.primitive_batchers[bcoo_spdot_general_p] = _bcoo_spdot_general_batch_rule
# The XLA translation is derived from the Python implementation; the
# implementation returns more than one result, hence multiple_results=True.
xla.translations[bcoo_spdot_general_p] = xla.lower_fun(
    _bcoo_spdot_general_impl, multiple_results=True)
#----------------------------------------------------------------------
# BCOO functions that maybe should be primitives?
def _tuple_replace(tup, ind, val):
return tuple(val if i == ind else t for i, t in enumerate(tup))
def bcoo_reduce_sum(data, indices, *, shape, axes):
  """Sum a BCOO array over the given axes, staying in BCOO format.

  Args:
    data: BCOO data array.
    indices: BCOO indices array.
    shape: logical shape of the sparse array.
    axes: axes of ``shape`` to sum over; each must satisfy
      ``0 <= axis < len(shape)``.

  Returns:
    A tuple ``(data, indices, out_shape)`` describing the reduced array, where
    ``out_shape`` is ``shape`` with the summed axes removed.
  """
  assert all(0 <= a < len(shape) for a in axes)
  n_batch, n_sparse, _, nse = _validate_bcoo(data, indices, shape)
  axes = sorted(set(axes))

  # Sum over dense dimensions -> sum over data.
  # Array axis `ax` lives at position `ax - n_sparse + 1` of data, because data
  # stores a single nse axis in place of the n_sparse sparse axes.
  dense_axes = tuple(ax - n_sparse + 1 for ax in axes if ax >= n_batch + n_sparse)
  data = data.sum(dense_axes)
  if n_sparse:
    # zero-out data corresponding to invalid indices.
    sparse_shape = jnp.array(shape[n_batch: n_batch + n_sparse])
    mask = jnp.all(indices < sparse_shape, -1)
    if data.ndim > mask.ndim:
      mask = lax.expand_dims(mask, tuple(range(mask.ndim, data.ndim)))
    data = jnp.where(mask, data, 0)

  # Sum over sparse dimensions -> drop index; sum is implicit
  sparse_idx = [i for i in range(n_sparse) if i + n_batch not in axes]
  if not sparse_idx:
    indices = jnp.zeros(_tuple_replace(indices.shape, n_batch + 1, 0), indices.dtype)
  else:
    indices = indices[..., np.array(sparse_idx)]

  # Sum over batch dimensions -> reshape into nse
  batch_axes = {ax for ax in axes if ax < n_batch}

  # First handle broadcasted batch dimensions
  for ax in batch_axes:
    if data.shape[ax] == 1:
      if indices.shape[ax] == 1:
        # Both broadcast along this axis: the sum is shape[ax] identical copies.
        data = data * shape[ax]
      else:
        data = lax.broadcast_in_dim(data, _tuple_replace(data.shape, ax, shape[ax]), tuple(range(data.ndim)))
    else:
      if indices.shape[ax] == 1:
        # Indices are shared along this axis, so summing the data directly is
        # enough. keepdims=True keeps the axis (as size 1) so that the axis
        # numbering used below and the check against indices remain valid.
        data = data.sum(ax, keepdims=True)
    assert data.shape[ax] == indices.shape[ax]

  new_batch_dims = tuple(sorted(set(range(n_batch)) - batch_axes))
  new_batch_shape = tuple(data.shape[i] for i in new_batch_dims)
  new_nse = int(nse * np.prod([data.shape[i] for i in batch_axes]))

  # Fold the summed batch axes into the nse axis; duplicates are implicitly summed.
  data = lax.reshape(data,
                     (*new_batch_shape, new_nse, *data.shape[n_batch + 1:]),
                     (*new_batch_dims, *batch_axes, *range(n_batch, data.ndim)))
  indices = lax.reshape(indices,
                        (*new_batch_shape, new_nse, *indices.shape[n_batch + 1:]),
                        (*new_batch_dims, *batch_axes, *range(n_batch, indices.ndim)))

  out_shape = tuple(shape[i] for i in range(len(shape)) if i not in axes)
  return data, indices, out_shape
def _is_placeholder(*args):
return all(type(arg) is object for arg in args) or all(arg is None for arg in args)
def _asarray_or_float0(arg):
if isinstance(arg, np.ndarray) and arg.dtype == dtypes.float0:
return arg
return jnp.asarray(arg)
@tree_util.register_pytree_node_class
class BCOO(ops.JAXSparse):
  """Experimental batched COO matrix implemented in JAX

  Args:
    (data, indices) : data and indices in batched COO format.
    shape : shape of sparse array.

  Attributes:
    data : ndarray of shape ``[*batch_dims, nse, *dense_dims]`` containing the
      explicitly stored data within the sparse matrix.
    indices : ndarray of shape ``[*batch_dims, nse, n_sparse]`` containing the
      indices of the explicitly stored data. Duplicate entries will be summed.

  Examples:
    Create a sparse array from a dense array:

    >>> M = jnp.array([[0., 2., 0.], [1., 0., 4.]])
    >>> M_sp = BCOO.fromdense(M)
    >>> M_sp
    BCOO(float32[2, 3], nse=3)

    Examine the internal representation:

    >>> M_sp.data
    DeviceArray([2., 1., 4.], dtype=float32)
    >>> M_sp.indices
    DeviceArray([[0, 1],
                 [1, 0],
                 [1, 2]], dtype=int32)

    Create a dense array from a sparse array:

    >>> M_sp.todense()
    DeviceArray([[0., 2., 0.],
                 [1., 0., 4.]], dtype=float32)

    Create a sparse array from COO data & indices:

    >>> data = jnp.array([1., 3., 5.])
    >>> indices = jnp.array([[0, 0],
    ...                      [1, 1],
    ...                      [2, 2]])
    >>> mat = BCOO((data, indices), shape=(3, 3))
    >>> mat
    BCOO(float32[3, 3], nse=3)
    >>> mat.todense()
    DeviceArray([[1., 0., 0.],
                 [0., 3., 0.],
                 [0., 0., 5.]], dtype=float32)
  """
  data: jnp.ndarray
  indices: jnp.ndarray
  shape: Shape

  # Structural properties, all derived from the shapes of data and indices.
  nse = property(lambda self: self.indices.shape[-2])
  dtype = property(lambda self: self.data.dtype)
  n_batch = property(lambda self: self.indices.ndim - 2)
  n_sparse = property(lambda self: self.indices.shape[-1])
  n_dense = property(lambda self: self.data.ndim - 1 - self.n_batch)

  @property
  def _sparse_shape(self):
    """The slice of ``shape`` covering only the sparse dimensions."""
    return tuple(self.shape[self.n_batch:self.n_batch + self.n_sparse])

  def __init__(self, args, *, shape):
    # JAX transforms will sometimes instantiate pytrees with null values, so we
    # must catch that in the initialization of inputs.
    self.data, self.indices = args if _is_placeholder(*args) else map(_asarray_or_float0, args)
    super().__init__(args, shape=shape)

  @classmethod
  def fromdense(cls, mat, *, nse=None, index_dtype=np.int32, n_dense=0, n_batch=0):
    """Create a BCOO array from a (dense) :class:`DeviceArray`."""
    return cls(bcoo_fromdense(mat, nse=nse, index_dtype=index_dtype, n_dense=n_dense, n_batch=n_batch), shape=mat.shape)

  @classmethod
  def from_scipy_sparse(cls, mat, *, index_dtype=None, n_dense=0, n_batch=0):
    """Create a BCOO array from a :mod:`scipy.sparse` array.

    Args:
      mat: the scipy sparse matrix; it is converted with ``mat.tocoo()``.
      index_dtype: dtype of the resulting indices. ``None`` selects
        ``np.int32``, the same default as :meth:`fromdense`.
      n_dense: must be 0.
      n_batch: must be 0.

    Raises:
      NotImplementedError: if ``n_dense`` or ``n_batch`` is nonzero.
    """
    if n_dense != 0 or n_batch != 0:
      raise NotImplementedError("BCOO.fromscipy with nonzero n_dense/n_batch")
    if index_dtype is None:
      # astype(None) would not produce an integer dtype, so pick the default
      # index dtype explicitly.
      index_dtype = np.int32
    mat = mat.tocoo()
    data = jnp.asarray(mat.data)
    indices = jnp.column_stack((mat.row, mat.col)).astype(index_dtype)
    return cls((data, indices), shape=mat.shape)

  def _unbatch(self):
    """Return an unbatched representation of the BCOO matrix."""
    return BCOO(_unbatch_bcoo(self.data, self.indices, self.shape), shape=self.shape)

  def _dedupe(self):
    """Return a de-duplicated representation of the BCOO matrix."""
    return BCOO(_dedupe_bcoo(self.data, self.indices, self.shape), shape=self.shape)

  @jax.jit
  def todense(self):
    """Create a dense version of the array."""
    return bcoo_todense(self.data, self.indices, shape=self.shape)

  def __matmul__(self, other):
    """Compute ``self @ other``, contracting the last axis of ``self`` with the first of ``other``.

    A BCOO ``other`` gives a BCOO result via ``bcoo_spdot_general``; any other
    ``JAXSparse`` is unsupported; everything else is treated as a dense array.
    """
    if isinstance(other, BCOO):
      dtype = jnp.promote_types(self.dtype, other.dtype)
      dimension_numbers = (([self.ndim - 1], [0]), ([], []))
      data, indices = bcoo_spdot_general(self.data.astype(dtype), self.indices,
                                         other.data.astype(dtype), other.indices,
                                         lhs_shape=self.shape, rhs_shape=other.shape,
                                         dimension_numbers=dimension_numbers)
      shape = _dot_general_validated_shape(self.shape, other.shape, dimension_numbers)
      return BCOO((data, indices), shape=shape)
    elif isinstance(other, ops.JAXSparse):
      raise NotImplementedError("sparse-sparse matmul")
    other = jnp.asarray(other)
    if self.ndim == 0 or other.ndim == 0:
      raise ValueError("matmul inputs cannot be zero-dimensional.")
    if self.ndim > 2 or other.ndim > 2:
      raise NotImplementedError("sparse matmul for dimensions larger than 2")
    dtype = jnp.promote_types(self.dtype, other.dtype)
    return bcoo_dot_general(self.data.astype(dtype), self.indices, other.astype(dtype),
                            lhs_shape=self.shape,
                            dimension_numbers=(([self.ndim - 1], [0]), ([], [])))

  def __rmatmul__(self, other):
    """Compute ``other @ self`` for a dense ``other``."""
    if isinstance(other, ops.JAXSparse):
      raise NotImplementedError("sparse-sparse matmul")
    other = jnp.asarray(other)
    if self.ndim == 0 or other.ndim == 0:
      raise ValueError("matmul inputs cannot be zero-dimensional.")
    if self.ndim > 2 or other.ndim > 2:
      raise NotImplementedError("sparse matmul for dimensions larger than 2")
    dtype = jnp.promote_types(self.dtype, other.dtype)
    return bcoo_rdot_general(other.astype(dtype), self.data.astype(dtype), self.indices,
                             rhs_shape=self.shape,
                             dimension_numbers=(([other.ndim - 1], [0]), ([], [])))

  def transpose(self, axes=None):
    """Create a new array containing the transpose.

    Args:
      axes: permutation of the array's axes; ``None`` reverses them.
    """
    axes = np.arange(self.ndim)[::-1] if axes is None else axes
    data_T, indices_T = bcoo_transpose(self.data, self.indices, shape=self.shape, permutation=axes)
    # Keep shape a tuple, like every other BCOO constructor call in this class.
    shape_T = tuple(self.shape[i] for i in axes)
    return BCOO((data_T, indices_T), shape=shape_T)

  def tree_flatten(self):
    """Flatten into ``(children, aux_data)`` for the pytree registry."""
    children = (self.data, self.indices)
    # pytree sometimes creates placeholder objects & we need to handle that.
    sparse_shape = self.shape if _is_placeholder(*children) else self._sparse_shape
    # We serialize the sparse shape only to support batching.
    return children, {"sparse_shape": sparse_shape}

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    """Rebuild a BCOO from ``aux_data`` and ``children`` as produced by :meth:`tree_flatten`.

    Raises:
      ValueError: if ``indices`` has fewer than two dimensions or its last
        dimension does not match the length of the stored sparse shape.
    """
    data, indices = children
    sparse_shape = aux_data["sparse_shape"]
    # pytree sometimes creates placeholder objects & we need to handle that.
    if _is_placeholder(data, indices):
      shape = sparse_shape
    else:
      if np.ndim(indices) < 2 or len(sparse_shape) != np.shape(indices)[-1]:
        raise ValueError(f"Invalid sparse representation: got indices.shape={np.shape(indices)}, "
                         f"data.shape={np.shape(data)}, sparse_shape={sparse_shape}")
      n_batch = indices.ndim - 2
      # Batch and dense extents are read back from the children; only the
      # sparse extents come from aux_data.
      shape = (
          tuple(np.maximum(data.shape[:n_batch], indices.shape[:n_batch]))
          + tuple(sparse_shape)
          + tuple(data.shape[n_batch + 1:]))
    return cls(children, shape=shape)

  # TODO(jakevdp): refactor to avoid circular imports - we can use the same strategy
  # we use when adding methods to DeviceArray within lax_numpy.py
  def __neg__(self):
    from jax.experimental.sparse import sparsify
    return sparsify(jnp.negative)(self)

  def __mul__(self, other):
    from jax.experimental.sparse import sparsify
    return sparsify(jnp.multiply)(self, other)

  def __rmul__(self, other):
    from jax.experimental.sparse import sparsify
    return sparsify(jnp.multiply)(other, self)

  def __add__(self, other):
    from jax.experimental.sparse import sparsify
    return sparsify(jnp.add)(self, other)

  def __radd__(self, other):
    from jax.experimental.sparse import sparsify
    return sparsify(jnp.add)(other, self)

  def sum(self, *args, **kwargs):
    """Sum array along axis."""
    from jax.experimental.sparse import sparsify
    return sparsify(lambda x: x.sum(*args, **kwargs))(self)