Peter Hawkins 0150d15cb2 Increase minimum jaxlib version to 0.3.7.
Drop backwards compatibility with older jaxlib versions.
2022-04-18 08:09:50 -04:00

1889 lines
82 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BCOO (Bached coordinate format) matrix object and associated primitives."""
import functools
import operator
from typing import Any, NamedTuple, Sequence, Tuple
import warnings
import numpy as np
from jax import core
from jax import lax
from jax import tree_util
from jax import vmap
from jax.config import config
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.util import _safe_asarray, CuSparseEfficiencyWarning
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import xla
import jax.numpy as jnp
from jax.interpreters import ad
from jax.util import safe_zip, unzip2, split_list
from jax._src import api_util
from jax._src.api_util import flatten_axes
from jax._src.lax.lax import (
ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule,
DotDimensionNumbers)
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib import xla_client as xc
from jax._src.numpy.setops import _unique
from jax._src.lib import sparse_apis
# Raw XLA client ops; used below to build the CUDA translation rules by hand.
xops = xc._xla.ops
# Type aliases used in the annotations of this module.
Dtype = Any
Shape = Tuple[int, ...]
#----------------------------------------------------------------------
# General utilities...
def broadcasting_vmap(fun, in_axes=0, out_axes=0):
  """Like ``jax.vmap``, but mapped axes of length 1 broadcast.

  Every argument with a non-None entry in ``in_axes`` must have a mapped axis
  whose length is either the common batch size or 1. When the batch size is
  larger than 1, arguments whose mapped axis has length 1 are squeezed and
  passed unmapped, so they are broadcast against the other arguments.

  Args:
    fun: function to be mapped.
    in_axes: pytree prefix of mapped input axes, as for ``jax.vmap``.
    out_axes: output axes, as for ``jax.vmap``.

  Returns:
    The wrapped function. When called, it raises ``ValueError`` if the mapped
    axis lengths are neither 1 nor the common batch size.
  """
  @functools.wraps(fun)
  def batched_fun(*args):
    leaves, treedef = tree_util.tree_flatten(args)
    leaf_axes = flatten_axes("vmap in_axes", treedef, in_axes, kws=False)
    mapped = [(leaf, ax) for leaf, ax in safe_zip(leaves, leaf_axes) if ax is not None]
    size = max(leaf.shape[ax] for leaf, ax in mapped)
    if size > 1:
      if any(leaf.shape[ax] not in (1, size) for leaf, ax in mapped):
        raise ValueError("broadcasting_vmap: mismatched input shapes")
      new_leaves, new_axes = [], []
      for leaf, ax in zip(leaves, leaf_axes):
        if ax is not None and leaf.shape[ax] == 1:
          # Length-1 mapped axis: drop it and let vmap broadcast the value.
          leaf, ax = lax.squeeze(leaf, (ax,)), None
        new_leaves.append(leaf)
        new_axes.append(ax)
      leaves, leaf_axes = new_leaves, new_axes
    new_args = tree_util.tree_unflatten(treedef, leaves)
    new_in_axes = tree_util.tree_unflatten(treedef, leaf_axes)
    return vmap(fun, in_axes=new_in_axes, out_axes=out_axes)(*new_args)
  return batched_fun
#----------------------------------------------------------------------
# BCOO primitives: batched extension of COO.
def _bcoo_nse(mat, n_batch=0, n_dense=0):
mat = jnp.asarray(mat)
mask = (mat != 0)
if n_dense > 0:
mask = mask.any([-(i + 1) for i in range(n_dense)])
mask = mask.sum(list(range(n_batch, mask.ndim)))
return mask.max()
# TODO(jakevdp): add a custom autodiff rule that errors if remove_zeros=True, because
# it produces wrong values. See https://github.com/google/jax/issues/10163
def _bcoo_sum_duplicates(data, indices, shape, nse=None, remove_zeros=True):
  """Sum entries of a BCOO representation that share the same index.

  Args:
    data: BCOO data buffer.
    indices: BCOO indices buffer.
    shape: shape of the full sparse array.
    nse: number of specified elements of the result. If None, it is derived
      from the data, which is not allowed under a JAX transformation.
    remove_zeros: if True, entries whose data block is entirely zero are
      dropped together with out-of-bounds entries.

  Returns:
    ``(data, indices)`` of the deduplicated representation, truncated to
    ``nse`` entries along the nse axis.

  Raises:
    ValueError: if ``nse`` is None while tracing under a transformation.
  """
  # A freshly created constant is used as a probe for "are we being traced?"
  # (see the error message below).
  if nse is None and isinstance(jnp.array(0), core.Tracer):
    raise ValueError("When used with JIT, vmap, or another transform, sum_duplicates() "
                     "requires passing a non-None value for the nse argument.")
  n_batch = _validate_bcoo(data, indices, shape).n_batch
  dedupe = functools.partial(_bcoo_sum_duplicates_unbatched, shape=shape[n_batch:],
                             nse=nse, remove_zeros=remove_zeros)
  # Map the unbatched kernel over every batch dimension; size-1 batch dims broadcast.
  for _ in range(n_batch):
    dedupe = broadcasting_vmap(dedupe)
  data_out, indices_out, nse_found = dedupe(data, indices)
  if nse is None:
    nse = jnp.max(nse_found)
  trim = functools.partial(lax.slice_in_dim, start_index=0, limit_index=nse, axis=n_batch)
  return trim(data_out), trim(indices_out)
def _bcoo_sum_duplicates_unbatched(data, indices, *, shape, nse, remove_zeros):
  """Unbatched kernel for ``_bcoo_sum_duplicates``.

  Args:
    data: data buffer without batch dims, shape ``(nse_in, *dense_dims)``.
    indices: indices buffer without batch dims, shape ``(nse_in, n_sparse)``.
    shape: shape of the (unbatched) sparse array.
    nse: requested output nse, or None to size the output from the number of
      unique indices.
    remove_zeros: if True, entries whose data block is all zero are treated
      like out-of-bounds entries and dropped.

  Returns:
    ``(data_unique, indices_unique, nse)``. When ``nse`` was None it is the
    number of unique in-bounds indices; otherwise it is the value passed in.
  """
  props = _validate_bcoo(data, indices, shape)
  assert props.n_batch == 0
  if not props.n_sparse:
    # No sparse dims: every entry refers to the same (empty) index, so all
    # data is summed into slot 0 and the remaining slots are zero.
    nse = 1 if nse is None else nse
    data_unique = jnp.zeros_like(data, shape=(nse, *data.shape[1:])).at[0].set(data.sum(0))
    indices_unique = jnp.zeros_like(indices, shape=(nse, 0))
    return data_unique, indices_unique, nse
  # Sentinel index equal to the shape (one past the end in every sparse dim),
  # broadcastable against `indices`; out-of-bounds entries are mapped onto it.
  fill_value = jnp.expand_dims(jnp.array(shape[:props.n_sparse], dtype=indices.dtype),
                               range(indices.ndim - 1))
  out_of_bounds = (indices >= fill_value).any(-1, keepdims=True)
  if remove_zeros:
    data_all_zero = (data == 0).all(range(props.n_batch + 1, data.ndim))[:, None]
    out_of_bounds = out_of_bounds | data_all_zero
  indices = jnp.where(out_of_bounds, fill_value, indices)
  if nse is None:
    indices_unique, inv_idx, nse = _unique(
        indices, axis=0, return_inverse=True, return_true_size=True,
        size=props.nse, fill_value=fill_value)
    # Don't count the sentinel row. Any coordinate matching fill_value implies
    # a sentinel row, since in-bounds coordinates are strictly below it
    # (assuming non-negative indices).
    nse = nse - (indices == fill_value).any()
  else:
    indices_unique, inv_idx = jnp.unique(
        indices, axis=0, return_inverse=True,
        size=nse, fill_value=fill_value)
  # Scatter-add each input entry into the slot of its unique index.
  data_shape = [indices_unique.shape[0], *data.shape[1:]]
  data_unique = jnp.zeros(data_shape, data.dtype).at[inv_idx].add(data)
  # Zero out whatever was accumulated at the sentinel (dropped) index.
  oob_mask = jnp.all(indices_unique == fill_value, 1)
  data_unique = jnp.where(oob_mask[(...,) + props.n_dense * (None,)], 0, data_unique)
  return data_unique, indices_unique, nse
def _unbatch_bcoo(data, indices, shape):
  """Fold the batch dimensions of a BCOO representation into sparse dimensions.

  Args:
    data: BCOO data buffer, shape ``batch_dims + (nse,) + dense_dims`` (batch
      dims may be 1 and are broadcast).
    indices: BCOO indices buffer, shape ``batch_dims + (nse, n_sparse)``.
    shape: shape of the full sparse array; any sequence of ints.

  Returns:
    ``(data, indices)`` with no batch dimensions. ``data`` has shape
    ``(prod(batch_shape) * nse, *dense_dims)`` and ``indices`` has shape
    ``(prod(batch_shape) * nse, n_batch + n_sparse)``; the first ``n_batch``
    index columns hold the batch coordinates.
  """
  n_batch = _validate_bcoo(data, indices, shape).n_batch
  if n_batch == 0:
    return data, indices
  # `shape` may be a list; normalize it before concatenating with the
  # tuple-valued array shapes below.
  batch_shape = tuple(shape[:n_batch])
  data = jnp.broadcast_to(data, batch_shape + data.shape[n_batch:])
  indices = jnp.broadcast_to(indices, batch_shape + indices.shape[n_batch:])
  # Grid over (batch..., nse); drop the nse component to keep batch coordinates.
  batch_indices = jnp.mgrid[tuple(slice(None, d) for d in indices.shape[:n_batch + 1])][:-1]
  batch_indices = batch_indices.reshape(n_batch, -1).T
  data = data.reshape(np.prod(data.shape[:n_batch + 1]), *data.shape[n_batch + 1:])
  indices = indices.reshape(np.prod(indices.shape[:n_batch + 1]), *indices.shape[n_batch + 1:])
  return data, jnp.hstack([batch_indices, indices])
# Structural properties of a BCOO representation, as returned by
# _validate_bcoo / _validate_bcoo_indices:
#   n_batch:  number of leading batch dimensions
#   n_sparse: number of sparse dimensions (columns of the indices buffer)
#   n_dense:  number of trailing dense (block) dimensions
#   nse:      number of specified elements per batch
BCOOProperties = NamedTuple("BCOOProperties", [
    ("n_batch", int),
    ("n_sparse", int),
    ("n_dense", int),
    ("nse", int),
])
# Static sparse-format metadata passed as the `spinfo` parameter of the BCOO
# primitives; currently just the full (dense) shape of the array.
BCOOInfo = NamedTuple("BCOOInfo", [("shape", Shape)])
def _validate_bcoo(data: jnp.ndarray, indices: jnp.ndarray, shape: Sequence[int]) -> BCOOProperties:
  """Check that ``data`` and ``indices`` form a valid BCOO array of ``shape``.

  Args:
    data: data buffer (array or aval), shape ``batch_dims + (nse,) + dense_dims``.
    indices: indices buffer (array or aval), shape ``batch_dims + (nse, n_sparse)``.
    shape: shape of the full sparse array.

  Returns:
    The BCOOProperties inferred from ``indices``.

  Raises:
    ValueError: if the data batch dims are neither 1 nor the array's batch
      dims, or the remaining data dims don't equal ``(nse,) + dense_dims``.
  """
  props = _validate_bcoo_indices(indices, shape)
  shape = tuple(shape)
  n_batch, n_sparse, n_dense, nse = props
  batch_ok = all(d in (1, s) for d, s in safe_zip(data.shape[:n_batch], shape[:n_batch]))
  if not batch_ok:
    raise ValueError("data batch dimensions not compatible for "
                     f"data.shape={data.shape}, shape={shape}")
  expected_tail = (nse,) + shape[n_batch + n_sparse:]
  if data.shape[n_batch:] != expected_tail:
    raise ValueError(f"Invalid data.shape={data.shape} for "
                     f"nse={nse}, n_batch={n_batch}, n_dense={n_dense}")
  return props
def _validate_bcoo_indices(indices: jnp.ndarray, shape: Sequence[int]) -> BCOOProperties:
  """Validate a BCOO indices buffer against the full array ``shape``.

  Args:
    indices: integer array (or aval) of shape ``batch_dims + (nse, n_sparse)``.
    shape: shape of the full sparse array.

  Returns:
    BCOOProperties(n_batch, n_sparse, n_dense, nse) inferred from ``indices``.

  Raises:
    ValueError: if ``indices`` has fewer than two dimensions or its batch
      dimensions are neither 1 nor the corresponding entries of ``shape``.
  """
  assert jnp.issubdtype(indices.dtype, jnp.integer)
  shape = tuple(shape)
  # Without this check, the unpacking below fails with an opaque
  # "not enough values to unpack" ValueError.
  if indices.ndim < 2:
    raise ValueError("BCOO indices must have at least two dimensions (nse, n_sparse); "
                     f"got indices.shape={indices.shape}")
  nse, n_sparse = indices.shape[-2:]
  n_batch = indices.ndim - 2
  n_dense = len(shape) - n_batch - n_sparse
  assert n_dense >= 0, (f"indices.shape={indices.shape} implies more batch and sparse "
                        f"dimensions than shape={shape} has")
  if any(s1 not in (1, s2) for s1, s2 in safe_zip(indices.shape[:n_batch], shape[:n_batch])):
    raise ValueError("indices batch dimensions not compatible for "
                     f"indices.shape={indices.shape}, shape={shape}")
  if indices.shape[n_batch:] != (nse, n_sparse):
    raise ValueError(f"Invalid indices.shape={indices.shape} for "
                     f"nse={nse}, n_batch={n_batch}, n_dense={n_dense}")
  return BCOOProperties(n_batch=n_batch, n_sparse=n_sparse, n_dense=n_dense, nse=nse)
#----------------------------------------------------------------------
# bcoo_todense
bcoo_todense_p = core.Primitive('bcoo_todense')

def bcoo_todense(mat):
  """Convert a BCOO sparse array into a dense array.

  Args:
    mat: BCOO array.

  Returns:
    A dense array equal to ``mat``; entries with duplicate indices are summed.
  """
  buffers = (mat.data, mat.indices)
  return _bcoo_todense(*buffers, spinfo=mat._info)
def _bcoo_todense(data, indices, *, spinfo):
  """Convert BCOO buffers into a dense array.

  Args:
    data: array of shape ``batch_dims + (nse,) + block_dims``.
    indices: integer array of shape ``batch_dims + (nse, n_sparse)``.
    spinfo: BCOOInfo. Its ``shape`` is the shape of the dense result,
      ``batch_dims + sparse_dims + block_dims`` with
      ``len(sparse_dims) == n_sparse``.

  Returns:
    Array of shape ``spinfo.shape`` whose dtype matches ``data``.
  """
  data_arr = jnp.asarray(data)
  indices_arr = jnp.asarray(indices)
  return bcoo_todense_p.bind(data_arr, indices_arr, spinfo=spinfo)
@bcoo_todense_p.def_impl
def _bcoo_todense_impl(data, indices, *, spinfo):
  """Reference implementation of ``bcoo_todense_p`` (also used for lowering).

  Scatter-adds every entry of ``data`` into a zero array of shape
  ``spinfo.shape`` at its batch and sparse coordinates, so duplicate indices
  are summed.
  """
  shape = spinfo.shape
  n_batch, n_sparse, _, _ = _validate_bcoo(data, indices, shape)
  # Open-mesh index arrays over the batch dims of `indices`; batch dims where
  # `indices` has size 1 (broadcast) are always indexed at 0.
  ind_slices = tuple(np.zeros(s, int) if i_s == 1 else np.arange(s)
                     for s, i_s in zip(shape[:n_batch], indices.shape[:n_batch]))
  grid = tuple(np.meshgrid(*ind_slices, indexing='ij', sparse=True))
  # One integer array per sparse dimension, of shape batch_shape + (nse,).
  sparse_ind = tuple(indices[grid + (slice(None), i)] for i in range(n_sparse))
  # Batch coordinates with a trailing length-1 axis so they broadcast over nse.
  batch_slices = tuple(np.arange(s) for s in shape[:n_batch])
  grid = np.meshgrid(*batch_slices, np.arange(1), indexing='ij', sparse=True)
  batch_ind = tuple(grid)[:-1]
  if not sparse_ind:
    # No sparse dims: all nse entries land on the same location, so sum them here.
    data = data.sum(n_batch, keepdims=bool(batch_ind), dtype=data.dtype)
  return jnp.zeros(shape, data.dtype).at[batch_ind + sparse_ind].add(data)
@bcoo_todense_p.def_abstract_eval
def _bcoo_todense_abstract_eval(data, indices, *, spinfo):
  """Shape rule: validate the buffers and return an aval of shape ``spinfo.shape``."""
  dense_shape = spinfo.shape
  _validate_bcoo(data, indices, dense_shape)
  return core.ShapedArray(dense_shape, data.dtype)
def _bcoo_todense_jvp(data_dot, data, indices, *, spinfo):
  """JVP with respect to ``data``: todense is linear in ``data``."""
  del data  # Unused: the tangent map does not depend on the primal data.
  tangent_out = _bcoo_todense(data_dot, indices, spinfo=spinfo)
  return tangent_out
def _bcoo_todense_transpose(ct, data, indices, *, spinfo):
  """Transpose with respect to ``data``: gather ``ct`` at the sparse indices.

  Raises:
    ValueError: if asked to transpose with respect to ``indices``.
  """
  assert ad.is_undefined_primal(data)
  if ad.is_undefined_primal(indices):
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert ct.shape == spinfo.shape
  assert ct.dtype == data.aval.dtype
  data_ct = bcoo_extract(indices, ct)
  return data_ct, indices
def _bcoo_todense_batching_rule(batched_args, batch_dims, *, spinfo):
  """vmap rule for ``bcoo_todense_p``: the mapped axis becomes a new leading batch dim.

  A mapped operand has its mapped axis moved to the front. An unmapped operand
  gets a size-1 leading axis, which the BCOO format treats as broadcasting.

  Returns:
    ``(dense_result, 0)``.
  """
  data, indices = batched_args
  data_bdim, indices_bdim = batch_dims
  # Any mapped axis can be moved to position 0: per-example slices are
  # unchanged, and position 0 becomes the new leading BCOO batch dimension.
  data = data[None] if data_bdim is None else jnp.moveaxis(data, data_bdim, 0)
  indices = indices[None] if indices_bdim is None else jnp.moveaxis(indices, indices_bdim, 0)
  new_spinfo = BCOOInfo(
      shape=(max(data.shape[0], indices.shape[0]), *spinfo.shape))
  return _bcoo_todense(data, indices, spinfo=new_spinfo), 0
# Register autodiff, batching and lowering rules for bcoo_todense_p.
# todense is linear in `data`; no JVP rule is given for `indices`.
ad.defjvp(bcoo_todense_p, _bcoo_todense_jvp, None)
ad.primitive_transposes[bcoo_todense_p] = _bcoo_todense_transpose
batching.primitive_batchers[bcoo_todense_p] = _bcoo_todense_batching_rule
# Both the XLA translation and the MLIR lowering trace the Python impl.
xla.register_translation(bcoo_todense_p, xla.lower_fun(
    _bcoo_todense_impl, multiple_results=False, new_style=True))
mlir.register_lowering(bcoo_todense_p, mlir.lower_fun(
    _bcoo_todense_impl, multiple_results=False))
#--------------------------------------------------------------------
# bcoo_fromdense
bcoo_fromdense_p = core.Primitive('bcoo_fromdense')
# The primitive has two outputs: (data, indices).
bcoo_fromdense_p.multiple_results = True

# Message passed to core.concrete_or_error when `nse` is not concrete.
_TRACED_NSE_ERROR = """
The error arose for the nse argument of bcoo_fromdense. In order for BCOO.fromdense()
to be used in traced/compiled code, you must pass a concrete value to the nse
(number of specified elements) argument.
"""
def bcoo_fromdense(mat, *, nse=None, n_batch=0, n_dense=0, index_dtype=jnp.int32):
  """Create a BCOO-format sparse array from a dense array.

  Args:
    mat: array to be converted to BCOO.
    nse: number of specified elements in each batch. If None, it is computed
      from ``mat``; it must end up concrete (not traced).
    n_batch: number of batch dimensions (default: 0).
    n_dense: number of block dimensions (default: 0).
    index_dtype: dtype of the sparse indices (default: int32).

  Returns:
    mat_bcoo: BCOO representation of ``mat``.
  """
  dense = jnp.asarray(mat)
  nse = _bcoo_nse(dense, n_batch, n_dense) if nse is None else nse
  nse = core.concrete_or_error(operator.index, nse, _TRACED_NSE_ERROR)
  buffers = _bcoo_fromdense(dense, nse=nse, n_batch=n_batch, n_dense=n_dense,
                            index_dtype=index_dtype)
  return BCOO(buffers, shape=dense.shape)
def _bcoo_fromdense(mat, *, nse, n_batch=0, n_dense=0, index_dtype=jnp.int32):
  """Create BCOO buffers from a dense array.

  Args:
    mat: array to convert, with ``ndim == n_batch + n_sparse + n_dense``.
    nse: number of specified elements in each batch; must be concrete.
    n_batch: number of batch dimensions (default: 0).
    n_dense: number of block dimensions (default: 0).
    index_dtype: dtype of the sparse indices (default: int32).

  Returns:
    data: array of shape ``mat.shape[:n_batch] + (nse,) + mat.shape[mat.ndim - n_dense:]``
      with dtype ``mat.dtype``.
    indices: array of shape ``mat.shape[:n_batch] + (nse, n_sparse)`` with
      dtype ``index_dtype``.
  """
  dense = jnp.asarray(mat)
  static_nse = core.concrete_or_error(operator.index, nse, _TRACED_NSE_ERROR)
  return bcoo_fromdense_p.bind(dense, nse=static_nse, n_batch=n_batch,
                               n_dense=n_dense, index_dtype=index_dtype)
@bcoo_fromdense_p.def_impl
def _bcoo_fromdense_impl(mat, *, nse, n_batch, n_dense, index_dtype):
  """Reference implementation of ``bcoo_fromdense_p`` (also used for lowering).

  Finds up to ``nse`` nonzero sparse positions per batch, gathers their
  values, and zeroes the data in the padding slots.
  """
  mat = jnp.asarray(mat)
  n_sparse = mat.ndim - n_dense - n_batch
  # A sparse position is nonzero if any entry of its dense block is nonzero.
  mask = (mat != 0)
  if n_dense > 0:
    mask = mask.any([-(i + 1) for i in range(n_dense)])
  def _nonzero(a):
    # Fixed-size nonzero; padding slots get the index a.shape[:n_sparse],
    # which is one past the end in every sparse dim.
    if a.ndim:
      return jnp.nonzero(a, size=nse, fill_value=a.shape[:n_sparse])
    return ()
  # Map over the leading batch dimensions.
  for _ in range(n_batch):
    _nonzero = vmap(_nonzero, 0)
  indices = _nonzero(mask)
  if not indices:
    # No sparse dimensions: indices have a trailing axis of size 0.
    indices = jnp.zeros(mask.shape[:n_batch] + (nse, 0), index_dtype)
  else:
    # Stack the per-dimension index arrays into a trailing n_sparse axis.
    indices = jnp.moveaxis(jnp.array(indices, index_dtype), 0, n_batch + 1)
  data = bcoo_extract(indices, mat)
  # Only the first (number of nonzeros) slots of each batch are real; zero the rest.
  true_nonzeros = (lax.broadcasted_iota(jnp.int32, (1,) * n_batch + (nse,), n_batch) <
                   mask.sum(list(range(n_batch, mask.ndim)))[..., None])
  true_nonzeros = true_nonzeros[(n_batch + 1) * (slice(None),) + n_dense * (None,)]
  data = jnp.where(true_nonzeros, data, 0)
  return data, indices
@bcoo_fromdense_p.def_abstract_eval
def _bcoo_fromdense_abstract_eval(mat, *, nse, n_batch, n_dense, index_dtype):
  """Shape rule: avals for the ``(data, indices)`` buffers built from ``mat``."""
  batch_shape = mat.shape[:n_batch]
  n_sparse = mat.ndim - n_batch - n_dense
  dense_shape = mat.shape[n_batch + n_sparse:]
  data_aval = core.ShapedArray(batch_shape + (nse,) + dense_shape, mat.dtype)
  indices_aval = core.ShapedArray(batch_shape + (nse, n_sparse), index_dtype)
  return data_aval, indices_aval
def _bcoo_fromdense_jvp(primals, tangents, *, nse, n_batch, n_dense, index_dtype):
  """JVP of fromdense: the data tangent is ``mat_dot`` gathered at the primal indices.

  The indices output is integer-valued, so its tangent is always zero.
  """
  (M,), (Mdot,) = primals, tangents
  primals_out = _bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense, index_dtype=index_dtype)
  data, indices = primals_out
  data_dot = (ad.Zero.from_value(data) if type(Mdot) is ad.Zero
              else bcoo_extract(indices, Mdot))
  return primals_out, (data_dot, ad.Zero.from_value(indices))
def _bcoo_fromdense_transpose(ct, M, *, nse, n_batch, n_dense, index_dtype):
  """Transpose rule for ``bcoo_fromdense_p`` with respect to ``M``.

  The sparsity pattern is needed to scatter the data cotangent back, so this
  only works when the cotangent for ``indices`` is an actual array rather
  than a symbolic zero.

  Returns:
    A 1-tuple holding the cotangent for ``M`` (one entry per primitive input).

  Raises:
    ValueError: if the indices cotangent is a symbolic zero.
  """
  data, indices = ct
  assert ad.is_undefined_primal(M)
  # Check for a symbolic zero before inspecting `indices` attributes.
  if isinstance(indices, ad.Zero):
    raise ValueError("Cannot transpose with respect to sparse indices")
  # `M` is an undefined primal, so its shape comes from its aval.
  M_shape = M.aval.shape
  n_sparse = len(M_shape) - n_batch - n_dense
  assert indices.shape == M_shape[:n_batch] + (nse, n_sparse)
  assert indices.dtype == index_dtype
  if isinstance(data, ad.Zero):
    return (ad.Zero(M.aval),)
  assert data.shape == M_shape[:n_batch] + (nse,) + M_shape[n_batch + n_sparse:]
  return (_bcoo_todense(data, indices, spinfo=BCOOInfo(M_shape)),)
def _bcoo_fromdense_batching_rule(batched_args, batch_dims, *, nse, n_batch, n_dense, index_dtype):
  """vmap rule: the mapped axis of ``mat`` becomes an extra leading BCOO batch dim.

  Returns:
    The ``(data, indices)`` buffers and their batch dims ``(0, 0)``.

  Raises:
    NotImplementedError: if the operand is not mapped.
  """
  M, = batched_args
  bdim, = batch_dims
  if bdim is None:
    raise NotImplementedError(f"batch_dims={batch_dims}")
  # Any mapped axis can be moved to the front, where it becomes a new batch dim.
  M = jnp.moveaxis(M, bdim, 0)
  return _bcoo_fromdense(M, nse=nse, n_batch=n_batch + 1, n_dense=n_dense, index_dtype=index_dtype), (0, 0)
# Register autodiff, batching and lowering rules for bcoo_fromdense_p.
ad.primitive_jvps[bcoo_fromdense_p] = _bcoo_fromdense_jvp
ad.primitive_transposes[bcoo_fromdense_p] = _bcoo_fromdense_transpose
batching.primitive_batchers[bcoo_fromdense_p] = _bcoo_fromdense_batching_rule
# Both lowering paths trace the Python impl, which returns (data, indices).
xla.register_translation(bcoo_fromdense_p, xla.lower_fun(
    _bcoo_fromdense_impl, multiple_results=True, new_style=True))
mlir.register_lowering(bcoo_fromdense_p, mlir.lower_fun(
    _bcoo_fromdense_impl, multiple_results=True))
#----------------------------------------------------------------------
# bcoo_extract
bcoo_extract_p = core.Primitive('bcoo_extract')

def bcoo_extract(indices, mat):
  """Gather the values of a dense array at BCOO indices.

  Args:
    indices: BCOO indices buffer, shape ``batch_dims + (nse, n_sparse)``.
    mat: dense array.

  Returns:
    A BCOO data buffer of shape
    ``mat.shape[:n_batch] + (nse,) + mat.shape[mat.ndim - n_dense:]``.
    Out-of-bounds indices read as zero.
  """
  data = bcoo_extract_p.bind(indices, mat)
  return data
@bcoo_extract_p.def_impl
def _bcoo_extract_impl(indices, mat):
  """Reference implementation of ``bcoo_extract_p`` (also used for lowering).

  Gathers ``mat`` at each entry's batch and sparse coordinates. Out-of-bounds
  indices yield 0. The result shape matches the abstract eval rule:
  ``mat.shape[:n_batch] + (nse,) + dense_dims``.
  """
  mat = jnp.asarray(mat)
  n_batch, n_sparse, _, nse = _validate_bcoo_indices(indices, mat.shape)
  # Batch dims where `indices` has size 1 (broadcast) are always indexed at 0.
  ind_slices = tuple(np.zeros(s, int) if i_s == 1 else np.arange(s)
                     for s, i_s in zip(mat.shape[:n_batch], indices.shape[:n_batch]))
  grid = tuple(np.meshgrid(*ind_slices, indexing='ij', sparse=True))
  sparse_ind = tuple(indices[grid + (slice(None), i)] for i in range(n_sparse))
  # Batch coordinates with a trailing length-1 axis so they broadcast over nse.
  batch_slices = tuple(np.arange(s) for s in mat.shape[:n_batch])
  grid = np.meshgrid(*batch_slices, np.arange(1), indexing='ij', sparse=True)
  batch_ind = tuple(grid)[:-1]
  if not sparse_ind + batch_ind:
    result = mat[None]
  else:
    result = mat.at[batch_ind + sparse_ind].get(mode='fill', fill_value=0)
  if n_sparse == 0 and nse != 1:
    # With no sparse dims the gather above yields a length-1 nse axis, but
    # each of the nse entries addresses the whole dense block. Broadcast so
    # the shape agrees with the abstract eval rule.
    out_shape = result.shape[:n_batch] + (nse,) + result.shape[n_batch + 1:]
    result = jnp.broadcast_to(result, out_shape)
  return result
@bcoo_extract_p.def_abstract_eval
def _bcoo_extract_abstract_eval(indices, mat):
  """Shape rule: data aval of shape ``batch + (nse,) + dense`` with ``mat``'s dtype."""
  props = _validate_bcoo_indices(indices, mat.shape)
  batch_shape = mat.shape[:props.n_batch]
  dense_shape = mat.shape[mat.ndim - props.n_dense:]
  return core.ShapedArray(batch_shape + (props.nse,) + dense_shape, mat.dtype)
def _bcoo_extract_jvp(mat_dot, indices, mat):
  """JVP with respect to ``mat``: extraction is linear, so apply it to the tangent."""
  assert mat_dot.shape == mat.shape
  tangent_out = bcoo_extract(indices, mat_dot)
  return tangent_out
def _bcoo_extract_transpose(ct, indices, mat):
  """Transpose with respect to ``mat``: scatter ``ct`` back into a dense array.

  Raises:
    ValueError: if asked to transpose with respect to ``indices``.
  """
  assert ad.is_undefined_primal(mat)
  if ad.is_undefined_primal(indices):
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert ct.dtype == mat.aval.dtype
  mat_ct = _bcoo_todense(ct, indices, spinfo=BCOOInfo(mat.aval.shape))
  return indices, mat_ct
def _bcoo_extract_batching_rule(batched_args, batch_dims):
  """vmap rule for ``bcoo_extract_p``.

  The mapped axis must be (or become) a batch dimension of the sparse
  representation. An unmapped ``indices`` gets a size-1 batch axis, and an
  unmapped ``mat`` is broadcast along the mapped axis.

  Raises:
    NotImplementedError: if both operands are mapped along different axes.
    ValueError: if the mapped axis is not a batch dimension of ``indices``.
  """
  indices, mat = batched_args
  assert any(b is not None for b in batch_dims)
  if batch_dims[0] is None:
    bdim = batch_dims[1]
    # Size-1 batch axis on indices broadcasts against mat's mapped axis.
    indices = lax.expand_dims(indices, (bdim,))
  elif batch_dims[1] is None:
    # TODO(jakevdp) can we handle this case without explicit broadcasting?
    bdim = batch_dims[0]
    result_shape = list(mat.shape)
    result_shape.insert(bdim, indices.shape[bdim])
    # broadcast_in_dim needs one output axis per operand axis: every existing
    # axis of `mat` keeps its position, skipping the inserted axis `bdim`.
    broadcast_dims = tuple(d for d in range(len(result_shape)) if d != bdim)
    mat = lax.broadcast_in_dim(mat, result_shape, broadcast_dims)
  else:
    if batch_dims[0] != batch_dims[1]:
      raise NotImplementedError("bcoo_extract with unequal batch dimensions.")
    bdim = batch_dims[0]
  n_batch = indices.ndim - 2
  if bdim >= n_batch:
    raise ValueError(f"batch_dims={batch_dims} out of range for indices with n_batch={n_batch}")
  return bcoo_extract(indices, mat), bdim
# Register autodiff, batching and lowering rules for bcoo_extract_p.
# Extraction is linear in `mat`; no JVP rule is given for `indices`.
ad.defjvp(bcoo_extract_p, None, _bcoo_extract_jvp)
ad.primitive_transposes[bcoo_extract_p] = _bcoo_extract_transpose
batching.primitive_batchers[bcoo_extract_p] = _bcoo_extract_batching_rule
# Both lowering paths trace the Python impl.
xla.register_translation(bcoo_extract_p, xla.lower_fun(
    _bcoo_extract_impl, multiple_results=False, new_style=True))
mlir.register_lowering(bcoo_extract_p, mlir.lower_fun(
    _bcoo_extract_impl, multiple_results=False))
#----------------------------------------------------------------------
# bcoo_transpose
# transpose of a BCOO array
bcoo_transpose_p = core.Primitive('bcoo_transpose')
bcoo_transpose_p.multiple_results = True

def bcoo_transpose(mat, *, permutation: Sequence[int]):
  """Transpose a BCOO-format array.

  Args:
    mat: A BCOO-format array.
    permutation: A tuple, list or ndarray containing a permutation of
      ``[0, 1, ..., N-1]``, where ``N`` is the number of axes of ``mat`` in
      the order batch, sparse, dense. The i-th axis of the result corresponds
      to axis ``permutation[i]`` of ``mat``. Permuting batch axes with
      non-batch axes, or dense axes with non-dense axes, is not currently
      supported.

  Returns:
    A BCOO-format array of shape ``tuple(mat.shape[p] for p in permutation)``.
  """
  # Buffers first: this validates `permutation` before it is used below.
  buffers = _bcoo_transpose(mat.data, mat.indices, permutation=permutation, spinfo=mat._info)
  # The result carries the permuted shape, not the input's shape.
  out_shape = tuple(mat._info.shape[p] for p in permutation)
  return BCOO(buffers, shape=out_shape)
def _bcoo_transpose(data, indices, *, permutation: Sequence[int], spinfo: BCOOInfo):
  """Transpose BCOO buffers; the identity permutation returns them unchanged.

  Returns:
    ``(data, indices)`` of the transposed array.
  """
  permutation = tuple(permutation)
  if permutation != tuple(range(len(spinfo.shape))):
    return bcoo_transpose_p.bind(data, indices, permutation=permutation,
                                 spinfo=spinfo)
  return data, indices
def _validate_permutation(data, indices, permutation, shape):
  """Check a BCOO transpose permutation and split it by dimension kind.

  Returns:
    ``(batch_perm, sparse_perm, dense_perm)``; the sparse and dense parts are
    re-based so that they start at zero.

  Raises:
    TypeError: if ``permutation`` is not a tuple/list/ndarray, or is not a
      permutation of ``range(len(shape))``.
    NotImplementedError: if batch axes or dense axes would be mixed with
      axes of another kind.
  """
  if not isinstance(permutation, (tuple, list, np.ndarray)):
    raise TypeError(f"transpose permutation must be a tuple/list/ndarray, got {type(permutation)}.")
  if tuple(sorted(permutation)) != tuple(range(len(shape))):
    raise TypeError("transpose permutation isn't a permutation of operand dimensions, "
                    f"got permutation {permutation} for shape {shape}.")
  n_batch, n_sparse, n_dense, _ = _validate_bcoo(data, indices, shape)
  sparse_start, dense_start = n_batch, n_batch + n_sparse
  batch_perm = permutation[:sparse_start]
  sparse_perm = [p - sparse_start for p in permutation[sparse_start:dense_start]]
  dense_perm = [p - dense_start for p in permutation[dense_start:]]
  if n_batch and tuple(sorted(batch_perm)) != tuple(range(n_batch)):
    raise NotImplementedError("transpose permutation cannot permute batch axes with non-batch axes; "
                              f"got permutation {permutation}, with n_batch={n_batch}.")
  if n_dense and tuple(sorted(dense_perm)) != tuple(range(n_dense)):
    raise NotImplementedError("transpose permutation cannot permute dense axes with non-dense axes; "
                              f"got permutation {permutation}, with n_dense={n_dense}.")
  return batch_perm, sparse_perm, dense_perm
@bcoo_transpose_p.def_impl
def _bcoo_transpose_impl(data, indices, *, permutation: Sequence[int], spinfo: BCOOInfo):
  """Reference implementation of ``bcoo_transpose_p`` (also used for lowering)."""
  batch_perm, sparse_perm, dense_perm = _validate_permutation(data, indices, permutation, spinfo.shape)
  n_batch = len(batch_perm)
  # Reorder the index columns (one per sparse dim) and permute the batch axes;
  # the trailing (nse, n_sparse) axes stay in place.
  new_indices = indices[..., sparse_perm].transpose(*batch_perm, n_batch, n_batch + 1)
  # Batch axes are permuted, the nse axis stays put, and dense axes follow it.
  data_axes = (*batch_perm, n_batch, *(d + n_batch + 1 for d in dense_perm))
  new_data = data.transpose(*data_axes)
  return new_data, new_indices
@bcoo_transpose_p.def_abstract_eval
def _bcoo_transpose_abstract_eval(data, indices, *, permutation: Sequence[int], spinfo: BCOOInfo):
  """Shape rule: permute the batch and dense axes of the buffer avals."""
  batch_perm, _, dense_perm = _validate_permutation(data, indices, permutation, spinfo.shape)
  n_batch = len(batch_perm)
  indices_order = [*batch_perm, n_batch, n_batch + 1]
  data_order = [*batch_perm, n_batch, *(d + n_batch + 1 for d in dense_perm)]
  indices_shape = np.array(indices.shape)[indices_order]
  data_shape = np.array(data.shape)[data_order]
  return core.ShapedArray(data_shape, data.dtype), core.ShapedArray(indices_shape, indices.dtype)
def _bcoo_transpose_jvp(primals, tangents, *, permutation: Sequence[int], spinfo: BCOOInfo):
  """JVP of bcoo_transpose: transpose the data tangent using the primal indices.

  The transpose is linear in ``data``. ``indices`` are integer-valued, so
  their output tangent is always zero.
  """
  data, indices = primals
  data_dot, _ = tangents
  primals_out = _bcoo_transpose(data, indices, permutation=permutation, spinfo=spinfo)
  if type(data_dot) is ad.Zero:
    # Symbolic-zero input tangent -> symbolic-zero output tangent, instead of
    # binding the primitive on an ad.Zero (mirrors _bcoo_fromdense_jvp).
    data_dot_out = ad.Zero.from_value(primals_out[0])
  else:
    data_dot_out, _ = _bcoo_transpose(data_dot, indices, permutation=permutation, spinfo=spinfo)
  return primals_out, (data_dot_out, ad.Zero.from_value(indices))
def _bcoo_transpose_transpose(ct, data, indices, *, permutation: Sequence[int], spinfo: BCOOInfo):
  """Transpose rule for ``bcoo_transpose_p`` with respect to ``data``.

  Applies the inverse permutation to the data cotangent.

  Raises:
    ValueError: if asked to transpose with respect to ``indices``.
  """
  data_ct, indices_ct = ct
  assert isinstance(indices_ct, ad.Zero)
  if ad.is_undefined_primal(indices):
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert data_ct.dtype == data.aval.dtype
  # The cotangent is laid out like the transposed output, so its shape is the permuted shape.
  ct_spinfo = BCOOInfo(tuple(spinfo.shape[p] for p in permutation))
  rev_permutation = list(np.argsort(permutation))
  # TODO(jakevdp) avoid dummy indices?
  # Only the data is transposed here. The dummy indices have size-1 batch dims
  # (which broadcast against any batch shape) and the real trailing
  # (nse, n_sparse) dims, so validation passes.
  dummy_indices = jnp.zeros([1 for i in range(indices.ndim - 2)] + list(indices.shape[-2:]), dtype=int)
  data_trans, _ = _bcoo_transpose(data_ct, dummy_indices, permutation=rev_permutation, spinfo=ct_spinfo)
  return data_trans, indices_ct
def _bcoo_transpose_batch_rule(batched_args, batch_dims, *, permutation: Sequence[int], spinfo: BCOOInfo):
  """vmap rule for ``bcoo_transpose_p``.

  The mapped axis, which must be axis 0 when present (asserted), is treated
  as an extra leading batch dimension that the permutation leaves in place.
  """
  data, indices = batched_args
  batch_dims = list(batch_dims)
  batch_size = max(0 if dim is None else arg.shape[dim]
                   for arg, dim in zip(batched_args, batch_dims))
  # Give unmapped operands a size-1 leading axis so both buffers share the new
  # batch dimension (size 1 broadcasts).
  if batch_dims[0] is None:
    data = data[None]
  else:
    assert batch_dims[0] == 0
  if batch_dims[1] is None:
    indices = indices[None]
  else:
    assert batch_dims[1] == 0
  # Prepend the new batch axis to the shape and keep it fixed in the permutation.
  batched_spinfo = BCOOInfo((batch_size, *spinfo.shape))
  batched_permutation = (0, *(p + 1 for p in permutation))
  data, indices = _bcoo_transpose(data, indices, permutation=batched_permutation, spinfo=batched_spinfo)
  # Drop the temporary size-1 axis from operands that were not mapped.
  if batch_dims[0] is None:
    data = data[0]
  if batch_dims[1] is None:
    indices = indices[0]
  return (data, indices), batch_dims
# Register autodiff, batching and lowering rules for bcoo_transpose_p.
ad.primitive_jvps[bcoo_transpose_p] = _bcoo_transpose_jvp
ad.primitive_transposes[bcoo_transpose_p] = _bcoo_transpose_transpose
batching.primitive_batchers[bcoo_transpose_p] = _bcoo_transpose_batch_rule
# Both lowering paths trace the Python impl, which returns (data, indices).
xla.register_translation(bcoo_transpose_p, xla.lower_fun(
    _bcoo_transpose_impl, multiple_results=True, new_style=True))
mlir.register_lowering(bcoo_transpose_p, mlir.lower_fun(
    _bcoo_transpose_impl, multiple_results=True))
#----------------------------------------------------------------------
# bcoo_dot_general
# (batched) general dot product of a BCOO sparse ND array and a dense ND array,
# returning a dense ND array.
# Primitive for a BCOO-sparse lhs contracted with a dense rhs (dense result).
bcoo_dot_general_p = core.Primitive('bcoo_dot_general')
def _dot_general_validated_shape(lhs_shape: Shape, rhs_shape: Shape, dimension_numbers: DotDimensionNumbers) -> Shape:
  """Validate the inputs and return the output shape.

  Runs ``lax.dot_general``'s shape rule on float32 placeholder avals of the
  given shapes, so invalid dimension numbers fail exactly as they would for
  a dense ``dot_general``.
  """
  lhs_aval, rhs_aval = (core.ShapedArray(s, np.float32) for s in (lhs_shape, rhs_shape))
  return _dot_general_shape_rule(lhs_aval, rhs_aval, dimension_numbers=dimension_numbers,
                                 precision=None, preferred_element_type=None)
def bcoo_dot_general(lhs, rhs, *, dimension_numbers):
  """A general contraction operation.

  Args:
    lhs: An ndarray or BCOO-format sparse array.
    rhs: An ndarray or BCOO-format sparse array.
    dimension_numbers: a tuple of tuples of the form
      `((lhs_contracting_dims, rhs_contracting_dims),
      (lhs_batch_dims, rhs_batch_dims))`.

  Returns:
    An ndarray or BCOO-format sparse array containing the result. If both
    inputs are sparse, the result is sparse, of type BCOO. If either input is
    dense, the result is dense, of type ndarray.
  """
  lhs_sparse = isinstance(lhs, BCOO)
  rhs_sparse = isinstance(rhs, BCOO)
  if lhs_sparse and rhs_sparse:
    shape = _dot_general_validated_shape(lhs.shape, rhs.shape, dimension_numbers)
    bufs = _bcoo_spdot_general(lhs.data, lhs.indices, rhs.data, rhs.indices,
                               lhs_spinfo=lhs._info, rhs_spinfo=rhs._info,
                               dimension_numbers=dimension_numbers)
    return BCOO(bufs, shape=shape)
  if lhs_sparse:
    return _bcoo_dot_general(*lhs._bufs, rhs, dimension_numbers=dimension_numbers,
                             lhs_spinfo=lhs._info)
  if rhs_sparse:
    return _bcoo_rdot_general(lhs, *rhs._bufs, dimension_numbers=dimension_numbers,
                              rhs_spinfo=rhs._info)
  return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
def _bcoo_dot_general(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: BCOOInfo):
  """Bind ``bcoo_dot_general_p`` after normalizing dimension numbers to int tuples."""
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  as_tuple = api_util._ensure_index_tuple
  cdims = (as_tuple(lhs_contract), as_tuple(rhs_contract))
  bdims = (as_tuple(lhs_batch), as_tuple(rhs_batch))
  operands = (jnp.asarray(lhs_data), jnp.asarray(lhs_indices), jnp.asarray(rhs))
  return bcoo_dot_general_p.bind(*operands, dimension_numbers=(cdims, bdims),
                                 lhs_spinfo=lhs_spinfo)
def _bcoo_rdot_general(lhs, rhs_data, rhs_indices, *, dimension_numbers: DotDimensionNumbers, rhs_spinfo: BCOOInfo):
  """``dot_general`` of a dense ``lhs`` with a BCOO ``rhs``.

  Computed as a sparse-lhs product with the operands swapped, followed by a
  transpose that moves the ``lhs`` free dimensions in front of the ``rhs`` ones.
  """
  # TODO(jakevdp): perhaps this should be part of the bcoo_dot_general primitive?
  swapped_dims = [d[::-1] for d in dimension_numbers]
  result = _bcoo_dot_general(rhs_data, rhs_indices, lhs, lhs_spinfo=rhs_spinfo,
                             dimension_numbers=swapped_dims)
  n_contract, n_batch = (len(d[0]) for d in dimension_numbers)
  # Non-contracted rhs dims (batch + free); in the swapped result these come
  # first, followed by the free lhs dims.
  n_swap = len(rhs_spinfo.shape) - n_contract
  batch_axes = range(n_batch)
  lhs_free_axes = range(n_swap, result.ndim)
  rhs_free_axes = range(n_batch, n_swap)
  permutation = tuple([*batch_axes, *lhs_free_axes, *rhs_free_axes])
  return lax.transpose(result, permutation)
@bcoo_dot_general_p.def_impl
def _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: BCOOInfo):
  """Reference implementation of ``bcoo_dot_general_p``.

  Reorders the lhs batch/sparse dims and the rhs dims so that dot-batch and
  contracting dims come first. Then, for each remaining lhs batch element,
  it multiplies the lhs data by the rhs values gathered at the contracted
  sparse coordinates and scatter-adds (or sums) the products into a zero
  output of the abstract-eval shape.
  """
  lhs_data = jnp.asarray(lhs_data)
  lhs_indices = jnp.asarray(lhs_indices)
  rhs = jnp.asarray(rhs)
  # Validate all inputs via abstract_eval
  out_aval = _bcoo_dot_general_abstract_eval(lhs_data.aval, lhs_indices.aval, rhs.aval,
                                             dimension_numbers=dimension_numbers,
                                             lhs_spinfo=lhs_spinfo)
  n_sparse = lhs_indices.shape[-1]
  n_batch = lhs_indices.ndim - 2
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  # Split the contracting dims into those that are lhs batch dims (_b) and
  # those that are lhs sparse dims (_s).
  lhs_contracting_b, rhs_contracting_b = unzip2([
    (l, r) for l, r in safe_zip(lhs_contracting, rhs_contracting) if l < n_batch])
  lhs_contracting_s, rhs_contracting_s = unzip2([
    (l, r) for l, r in safe_zip(lhs_contracting, rhs_contracting) if l >= n_batch])

  # Reorder lhs batch dimensions
  # Order: dot-batch dims, then free batch dims, then contracted batch dims.
  if lhs_batch or lhs_contracting_b:
    batch_perm = [*lhs_batch, *remaining(range(n_batch), lhs_batch, lhs_contracting_b), *lhs_contracting_b]
    lhs_data = lhs_data.transpose([*batch_perm, *range(n_batch, lhs_data.ndim)])
    lhs_indices = lhs_indices.transpose([*batch_perm, *range(n_batch, lhs_indices.ndim)])

  # Reorder lhs sparse dimensions
  # Index columns of contracted sparse dims first, then the free ones.
  if lhs_contracting_s:
    lhs_contracting_s = [d - n_batch for d in lhs_contracting_s]
    sparse_perm = jnp.array([*lhs_contracting_s, *remaining(range(n_sparse), lhs_contracting_s)])
    lhs_indices = lhs_indices[..., sparse_perm]

  # Reorder rhs dimensions
  # Order: dot-batch, contracted against lhs batch dims, contracted against
  # lhs sparse dims, then free rhs dims.
  rhs_perm = [*rhs_batch, *rhs_contracting_b, *rhs_contracting_s,
              *remaining(range(rhs.ndim), rhs_batch, rhs_contracting)]
  rhs = rhs.transpose(rhs_perm)

  def result(out_array, lhs_data, lhs_indices, rhs):
    # Per-entry sparse coordinates: contracted ones gather from rhs, free ones
    # scatter into the output.
    idx = tuple(lhs_indices[..., i] for i in range(n_sparse))
    idx_right = idx[:len(lhs_contracting_s)]
    idx_out = idx[len(lhs_contracting_s):]
    if idx_right and lhs_indices.ndim > 2:
      # Batch axes still present in lhs_indices also index the leading rhs axes.
      idx_batch = jnp.meshgrid(
          *(jnp.arange(n) for n in lhs_indices.shape[:-1]),
          indexing='ij')[:lhs_indices.ndim - 2]
      idx_right = (*idx_batch, *idx_right)
    batch_dims = list(range(len(lhs_contracting_b) + bool(lhs_contracting_s)))
    prod = lax.dot_general(lhs_data, rhs.at[idx_right].get(mode='fill', fill_value=0),
                           (([], []), (batch_dims, batch_dims)))
    if idx_out:
      return out_array.at[idx_out].add(prod)
    else:
      # No free sparse dims: reduce the extra leading axes of `prod` instead.
      return prod.sum(tuple(range(prod.ndim - out_array.ndim)), dtype=out_array.dtype)

  # Map `result` over the free lhs batch dims; rhs gets matching size-1 axes,
  # which broadcasting_vmap broadcasts.
  for _ in range(n_batch - len(lhs_contracting_b)):
    result = broadcasting_vmap(result)
  rhs = lax.expand_dims(rhs, range(len(rhs_batch), n_batch - len(lhs_contracting_b)))
  out_array = jnp.zeros(out_aval.shape, out_aval.dtype)
  return result(out_array, lhs_data, lhs_indices, rhs)
@bcoo_dot_general_p.def_abstract_eval
def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: BCOOInfo):
  """Shape rule for ``bcoo_dot_general_p``.

  Returns:
    ShapedArray with the dense ``dot_general`` output shape and ``lhs_data``'s dtype.

  Raises:
    ValueError: if ``lhs_data`` and ``rhs`` dtypes differ. Errors from
      ``_validate_bcoo`` and the dot_general shape rule also propagate.
    NotImplementedError: if a dot batch dimension is not a BCOO batch
      dimension, or a dense (block) dimension is contracted.
  """
  if lhs_data.dtype != rhs.dtype:
    raise ValueError("bcoo_dot_general requires arguments to have matching dtypes; "
                     f"got lhs.dtype={lhs_data.dtype}, rhs.dtype={rhs.dtype}")
  (lhs_contracting, _), (lhs_batch, _) = dimension_numbers
  n_batch, n_sparse, _, _ = _validate_bcoo(lhs_data, lhs_indices, lhs_spinfo.shape)
  out_shape = _dot_general_validated_shape(lhs_spinfo.shape, rhs.shape, dimension_numbers)
  if lhs_batch and max(lhs_batch) >= n_batch:
    raise NotImplementedError(
        "bcoo_dot_general batch dimensions must be among the batch dimensions in the sparse representation.\n"
        f"got lhs_batch={lhs_batch}, n_batch={n_batch}")
  # TODO: support contraction of dense dimensions?
  if any(d >= n_batch + n_sparse for d in lhs_contracting):
    raise NotImplementedError("bcoo_dot_general: contracting over dense dimensions.")
  return core.ShapedArray(out_shape, lhs_data.dtype)
# Default XLA translation: traces the Python impl. Presumably the fallback
# when the CUDA-specific rule below does not apply — confirm at registration.
_bcoo_dot_general_default_translation_rule = xla.lower_fun(
    _bcoo_dot_general_impl, multiple_results=False, new_style=True)
def _bcoo_dot_general_cuda_translation_rule(
    ctx, avals_in, avals_out, lhs_data, lhs_indices, rhs, *, dimension_numbers,
    lhs_spinfo: BCOOInfo):
  """Legacy XLA translation of bcoo_dot_general onto ``sparse_apis`` COO kernels.

  Only the case enforced by the asserts below is handled: the sparse operand
  has no batch or dense dimensions and 1 or 2 sparse dimensions, ``rhs`` is 1D
  or 2D, there are no dot_general batch dimensions, and there is exactly one
  contracting dimension. ``_bcoo_dot_general_gpu_translation_rule`` checks these
  conditions and sorts the indices before calling this function.

  Returns:
    A single-element list holding the XLA op for the result.
  """
  c = ctx.builder
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  lhs_data_aval, lhs_indices_aval, rhs_aval, = avals_in
  props = _validate_bcoo_indices(lhs_indices_aval, lhs_spinfo.shape)

  rhs_ndim = len(c.get_shape(rhs).dimensions())

  # Checks the shapes of lhs and rhs.
  assert props.n_dense == 0
  assert props.n_batch == 0
  assert props.n_sparse in [1, 2]
  assert rhs_ndim in [1, 2]

  # Checks the operation dimensions.
  assert len(lhs_batch) == 0
  assert len(rhs_batch) == 0
  assert len(lhs_contract) == 1

  # Checks the dtype.
  assert lhs_data_aval.dtype in [np.float32, np.float64, np.complex64, np.complex128]
  assert lhs_data_aval.dtype == rhs_aval.dtype
  assert lhs_indices_aval.dtype == np.int32

  assert sparse_apis is not None

  # Pick the matrix-vector or matrix-matrix kernel from the rank of rhs.
  if rhs_ndim == 1:
    bcoo_dot_general_fn = sparse_apis.coo_matvec
  elif rhs_ndim == 2:
    bcoo_dot_general_fn = sparse_apis.coo_matmat
    # If rhs contracts along its second axis, transpose it so that the
    # contraction is along its first axis.
    if rhs_contract[0] == 1:
      rhs = xops.Transpose(rhs, permutation=[1, 0])
  else:
    raise ValueError(f"rhs has to be 1d or 2d; get {rhs_ndim}d.")

  lhs_transpose = False
  if props.n_sparse == 1:
    # Converts lhs to a row vector.
    col = xops.Collapse(lhs_indices, dimensions=[0, 1])
    row = xops.Broadcast(xops.Constant(c, jnp.array(0, dtype=jnp.int32)),
                         c.get_shape(col).dimensions())
    lhs_shape = (1, lhs_spinfo.shape[0])
    dot_product = bcoo_dot_general_fn(
        c, lhs_data, row, col, rhs, shape=lhs_shape, transpose=lhs_transpose)

    if rhs_ndim == 1:
      # Transforms a single-element array to a scalar.
      return [xops.Reshape(dot_product, dimensions=[0], new_sizes=[])]
    else:
      # Drop the leading length-1 row axis introduced above.
      return [xops.Collapse(dot_product, dimensions=[0, 1])]
  elif props.n_sparse == 2:
    # Split the index array into its row (column 0) and column (column 1)
    # index vectors.
    row = xops.Collapse(
        xops.Slice(lhs_indices,
                   start_indices=[0, 0],
                   limit_indices=[c.get_shape(lhs_indices).dimensions()[0], 1],
                   strides=[1, 1]),
        dimensions=[0, 1])
    col = xops.Collapse(
        xops.Slice(lhs_indices,
                   start_indices=[0, 1],
                   limit_indices=[c.get_shape(lhs_indices).dimensions()[0], 2],
                   strides=[1, 1]),
        dimensions=[0, 1])

    # Contracting over lhs axis 0 is requested from the kernel via its
    # transpose flag.
    if lhs_contract[0] == 0:
      lhs_transpose = True

    return [bcoo_dot_general_fn(
        c, lhs_data, row, col, rhs, shape=lhs_spinfo.shape,
        transpose=lhs_transpose)]
  else:
    raise ValueError(f"lhs has to be 1d or 2d; get {props.n_sparse}d.")
def _bcoo_dot_general_gpu_translation_rule(
    ctx, avals_in, avals_out, lhs_data, lhs_indices, rhs, *, dimension_numbers,
    lhs_spinfo: BCOOInfo):
  """GPU XLA translation for bcoo_dot_general.

  Dispatches to the cusparse/hipsparse-backed rule when
  ``config.jax_bcoo_cusparse_lowering`` is set, the dtype is supported and the
  operands have the restricted form that rule handles. In every other case it
  falls back to the default translation, warning only when the dtype is the
  reason.
  """
  def fallback():
    return _bcoo_dot_general_default_translation_rule(
        ctx, avals_in, avals_out, lhs_data, lhs_indices, rhs,
        dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)

  if not config.jax_bcoo_cusparse_lowering:
    return fallback()

  (lhs_contract, _), (lhs_batch, rhs_batch) = dimension_numbers
  data_aval, indices_aval, rhs_aval = avals_in
  n_batch, n_sparse, n_dense, _ = _validate_bcoo(
      data_aval, indices_aval, lhs_spinfo.shape)

  dtype = data_aval.dtype
  if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
    warnings.warn(f'bcoo_dot_general cusparse/hipsparse lowering not available for '
                  f'dtype={dtype}. Falling back to default implementation.',
                  CuSparseEfficiencyWarning)
    return fallback()

  kernel_supported = (not n_batch and not n_dense
                      and n_sparse in [1, 2] and rhs_aval.ndim in [1, 2]
                      and not lhs_batch and not rhs_batch
                      and len(lhs_contract) == 1)
  if not kernel_supported:
    return fallback()

  # Sorts lhs by row indices.
  sorted_data, sorted_indices = _bcoo_sort_indices_rule(
      ctx, avals_in[:2], avals_in[:2], lhs_data, lhs_indices, spinfo=lhs_spinfo)
  return _bcoo_dot_general_cuda_translation_rule(
      ctx, avals_in, avals_out, sorted_data, sorted_indices, rhs,
      dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
# Platform-independent MLIR lowering, produced by tracing the reference impl.
_bcoo_dot_general_default_lowering = mlir.lower_fun(
    _bcoo_dot_general_impl,
    multiple_results=False,
)
def _collapse_mhlo(x, start, end):
  """Reshape ``x`` so that dimensions ``start`` through ``end`` (inclusive) become one.

  Args:
    x: an MLIR value of ranked tensor type.
    start: index of the first dimension to merge.
    end: index of the last dimension to merge (inclusive).

  Returns:
    The result value of an ``mhlo.ReshapeOp`` whose merged dimension has size
    equal to the product of the merged input dimensions.
  """
  operand_type = ir.RankedTensorType(x.type)
  dims = operand_type.shape
  merged_size = functools.reduce(operator.mul, dims[start:end + 1])
  new_dims = dims[:start] + [merged_size] + dims[end + 1:]
  result_type = ir.RankedTensorType.get(new_dims, operand_type.element_type)
  return mhlo.ReshapeOp(result_type, x).result
def _bcoo_dot_general_cuda_lowering(
    ctx, lhs_data, lhs_indices, rhs, *, dimension_numbers,
    lhs_spinfo: BCOOInfo):
  """MLIR lowering of bcoo_dot_general onto ``sparse_apis`` COO kernels.

  Uses ``coo_matvec_mhlo`` for a 1D ``rhs`` and ``coo_matmat_mhlo`` for a 2D
  ``rhs``. Only the case enforced by the asserts below is handled: no batch or
  dense dimensions in the sparse operand, 1 or 2 sparse dimensions, no
  dot_general batch dimensions, and exactly one contracting dimension.
  ``_bcoo_dot_general_gpu_lowering`` checks these conditions and sorts the
  indices before calling this function.

  Returns:
    A single-element list holding the MLIR value of the result.
  """
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  lhs_data_aval, lhs_indices_aval, rhs_aval, = ctx.avals_in
  props = _validate_bcoo_indices(lhs_indices_aval, lhs_spinfo.shape)

  rhs_ndim = len(ir.RankedTensorType(rhs.type).shape)

  # Checks the shapes of lhs and rhs.
  assert props.n_dense == 0
  assert props.n_batch == 0
  assert props.n_sparse in [1, 2]
  assert rhs_ndim in [1, 2]

  # Checks the operation dimensions.
  assert len(lhs_batch) == 0
  assert len(rhs_batch) == 0
  assert len(lhs_contract) == 1

  # Checks the dtype.
  assert lhs_data_aval.dtype in [np.float32, np.float64, np.complex64,
                                 np.complex128]
  assert lhs_data_aval.dtype == rhs_aval.dtype
  assert lhs_indices_aval.dtype == np.int32

  assert sparse_apis is not None

  if rhs_ndim == 1:
    bcoo_dot_general_fn = sparse_apis.coo_matvec_mhlo
  elif rhs_ndim == 2:
    bcoo_dot_general_fn = sparse_apis.coo_matmat_mhlo
    # If rhs contracts along its second axis, transpose it so that the
    # contraction is along its first axis.
    if rhs_contract[0] == 1:
      rhs = mhlo.TransposeOp(
          rhs, permutation=mlir.dense_int_elements([1, 0])).result
  else:
    raise ValueError(f"rhs has to be 1d or 2d; get {rhs_ndim}d.")

  lhs_transpose = False
  if props.n_sparse == 1:
    # Converts lhs to a row vector.
    col = _collapse_mhlo(lhs_indices, start=0, end=1)
    row = mlir.full_like_aval(
        0, core.ShapedArray(ir.RankedTensorType(col.type).shape,
                            np.dtype(np.int32)))
    lhs_shape = (1, lhs_spinfo.shape[0])
    dot_product = bcoo_dot_general_fn(
        lhs_data, row, col, rhs, shape=lhs_shape, transpose=lhs_transpose,
        data_dtype=lhs_data_aval.dtype, index_dtype=lhs_indices_aval.dtype,
        x_dtype=rhs_aval.dtype)

    if rhs_ndim == 1:
      # Transforms a single-element array to a scalar. The rank-0 result type
      # must be *constructed* with ``ir.RankedTensorType.get``; calling
      # ``ir.RankedTensorType(...)`` directly only casts an existing ``ir.Type``
      # and does not accept a shape and an element type.
      element_type = ir.RankedTensorType(dot_product.type).element_type
      scalar_type = ir.RankedTensorType.get([], element_type)
      return [mhlo.ReshapeOp(scalar_type, dot_product).result]
    else:
      # Drop the leading length-1 row axis introduced above.
      return [_collapse_mhlo(dot_product, start=0, end=1)]
  elif props.n_sparse == 2:
    # Split the index array into its row (column 0) and column (column 1)
    # index vectors.
    lhs_indices_shape = ir.RankedTensorType(lhs_indices.type).shape
    row = _collapse_mhlo(
        mhlo.SliceOp(
            lhs_indices,
            start_indices=mlir.dense_int_elements([0, 0]),
            limit_indices=mlir.dense_int_elements([lhs_indices_shape[0], 1]),
            strides=mlir.dense_int_elements([1, 1])).result,
        start=0, end=1)
    col = _collapse_mhlo(
        mhlo.SliceOp(
            lhs_indices,
            start_indices=mlir.dense_int_elements([0, 1]),
            limit_indices=mlir.dense_int_elements([lhs_indices_shape[0], 2]),
            strides=mlir.dense_int_elements([1, 1])).result,
        start=0, end=1)

    # Contracting over lhs axis 0 is requested from the kernel via its
    # transpose flag.
    if lhs_contract[0] == 0:
      lhs_transpose = True

    return [bcoo_dot_general_fn(
        lhs_data, row, col, rhs, shape=lhs_spinfo.shape,
        transpose=lhs_transpose, data_dtype=lhs_data_aval.dtype,
        index_dtype=lhs_indices_aval.dtype,
        x_dtype=rhs_aval.dtype)]
  else:
    raise ValueError(f"lhs has to be 1d or 2d; get {props.n_sparse}d.")
def _bcoo_dot_general_gpu_lowering(
    ctx, lhs_data, lhs_indices, rhs, *, dimension_numbers,
    lhs_spinfo: BCOOInfo):
  """GPU MLIR lowering for bcoo_dot_general.

  Dispatches to the cusparse/hipsparse-backed lowering when
  ``config.jax_bcoo_cusparse_lowering`` is set, the dtype is supported and the
  operands have the restricted form that lowering handles. In every other case
  it falls back to the default lowering, warning only when the dtype is the
  reason.
  """
  def fallback():
    return _bcoo_dot_general_default_lowering(
        ctx, lhs_data, lhs_indices, rhs,
        dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)

  if not config.jax_bcoo_cusparse_lowering:
    return fallback()

  (lhs_contract, _), (lhs_batch, rhs_batch) = dimension_numbers
  data_aval, indices_aval, rhs_aval = ctx.avals_in
  n_batch, n_sparse, n_dense, _ = _validate_bcoo(
      data_aval, indices_aval, lhs_spinfo.shape)

  dtype = data_aval.dtype
  if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
    warnings.warn(f'bcoo_dot_general cusparse/hipsparse lowering not available '
                  f'for dtype={dtype}. Falling back to default implementation.',
                  CuSparseEfficiencyWarning)
    return fallback()

  kernel_supported = (not n_batch and not n_dense
                      and n_sparse in [1, 2] and rhs_aval.ndim in [1, 2]
                      and not lhs_batch and not rhs_batch
                      and len(lhs_contract) == 1)
  if not kernel_supported:
    return fallback()

  # Sorts lhs by row indices, lowering the sort with a context whose avals
  # describe only (lhs_data, lhs_indices).
  sort_ctx = mlir.LoweringRuleContext(module_context=ctx.module_context,
                                      primitive=None,
                                      avals_in=ctx.avals_in[:2],
                                      avals_out=ctx.avals_in[:2])
  (sorted_data,), (sorted_indices,) = _bcoo_sort_indices_mhlo(
      sort_ctx, lhs_data, lhs_indices, spinfo=lhs_spinfo)
  return _bcoo_dot_general_cuda_lowering(
      ctx, sorted_data, sorted_indices, rhs,
      dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
def _bcoo_dot_general_jvp_lhs(lhs_data_dot, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: BCOOInfo):
  """JVP w.r.t. ``lhs_data``: the product is linear in the data, so apply it to the tangent."""
  del lhs_data  # only the tangent enters the linear term
  return _bcoo_dot_general(lhs_data_dot, lhs_indices, rhs,
                           dimension_numbers=dimension_numbers,
                           lhs_spinfo=lhs_spinfo)
def _bcoo_dot_general_jvp_rhs(rhs_dot, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: BCOOInfo):
  """JVP w.r.t. ``rhs``: the product is linear in ``rhs``, so apply it to the tangent."""
  del rhs  # only the tangent enters the linear term
  return _bcoo_dot_general(lhs_data, lhs_indices, rhs_dot,
                           dimension_numbers=dimension_numbers,
                           lhs_spinfo=lhs_spinfo)
def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: BCOOInfo):
  """Transpose rule for bcoo_dot_general.

  Exactly one of ``lhs_data`` / ``rhs`` is expected to be an
  ``ad.UndefinedPrimal`` (the linear input); ``lhs_indices`` must be known.

  Returns:
    ``ad.Zero`` (the class itself) when ``ct`` is a symbolic zero. Otherwise a
    3-tuple aligned with the primal inputs: the entry for the undefined primal
    holds its cotangent, and the other entries pass the inputs through.
  """
  assert not ad.is_undefined_primal(lhs_indices)
  if type(ct) is ad.Zero:
    return ad.Zero
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  lhs_ndim = len(lhs_spinfo.shape)
  rhs_ndim = rhs.aval.ndim if ad.is_undefined_primal(rhs) else rhs.ndim
  # Dimensions of each operand that are neither contracted nor batched.
  lhs_kept = remaining(range(lhs_ndim), lhs_contract, lhs_batch)
  rhs_kept = remaining(range(rhs_ndim), rhs_contract, rhs_batch)
  # Positions in the output (and hence in ct) of the batch, lhs-kept and
  # rhs-kept dimensions, which dot_general lays out in that order.
  ans_batch, ans_lhs, ans_rhs = map(list, ranges_like(lhs_batch, lhs_kept, rhs_kept))
  if ad.is_undefined_primal(lhs_data):
    # Cotangent w.r.t. the sparse data: contract ct with rhs, evaluated only at
    # the positions given by lhs_indices.
    dims = ((ans_rhs, rhs_kept), (ans_batch, rhs_batch))
    lhs_contract_sorted_by_rhs = list(np.take(lhs_contract, np.argsort(rhs_contract)))
    permutation = list(lhs_batch) + lhs_kept + lhs_contract_sorted_by_rhs
    out_axes = list(np.argsort(permutation))

    # What follows is essentially this, but computed in terms of dot_general_sampled:
    # out_dense_T = lax.dot_general(ct, rhs, dimension_numbers=dims)
    # out_dense = lax.transpose(out_dense_T, out_axes)
    # result = bcoo_extract(lhs_indices, out_dense)

    # Instead we (1) un-transpose indices, (2) compute SDDMM, (3) re-transpose result
    # Placeholder data (ones) and a shape with size-1 sparse dimensions let
    # _bcoo_transpose be used to permute the index columns.
    dummy_data = jnp.ones([1 for i in range(lhs_indices.ndim - 2)] + [lhs_indices.shape[-2]])
    dummy_spinfo = BCOOInfo(tuple(lhs_indices.shape[:-2]) + tuple(1 for i in range(lhs_indices.shape[-1])))
    _, lhs_indices_T = _bcoo_transpose(dummy_data, lhs_indices, permutation=permutation, spinfo=dummy_spinfo)
    result_T = bcoo_dot_general_sampled(ct, rhs, lhs_indices_T, dimension_numbers=dims)
    result, _ = _bcoo_transpose(result_T, lhs_indices_T, permutation=out_axes, spinfo=dummy_spinfo)

    return result, lhs_indices, rhs
  else:
    # Cotangent w.r.t. rhs: contract the sparse lhs with ct over the lhs-kept
    # dimensions, then permute the result back into rhs's dimension order.
    dims = ((lhs_kept, ans_lhs), (lhs_batch, ans_batch))
    rhs_contract_sorted_by_lhs = list(np.take(rhs_contract, np.argsort(lhs_contract)))
    out_axes = list(np.argsort(list(rhs_batch) + rhs_contract_sorted_by_lhs + rhs_kept))
    result = _bcoo_dot_general(lhs_data, lhs_indices, ct, lhs_spinfo=lhs_spinfo, dimension_numbers=dims)
    return lhs_data, lhs_indices, lax.transpose(result, out_axes)
def _bcoo_dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, lhs_spinfo: BCOOInfo):
  """Batching rule for bcoo_dot_general.

  The vmapped axis of the sparse operand becomes a new leading BCOO batch
  dimension. A size-1 leading axis is inserted into whichever of
  ``lhs_data`` / ``lhs_indices`` is unbatched, and the dot_general dimension
  numbers are shifted to account for it.
  """
  lhs_data, lhs_indices, rhs = batched_args
  data_bdim, indices_bdim, rhs_bdim = batch_dims
  batch_size = max(arg.shape[bdim] if bdim is not None else 0
                   for arg, bdim in zip(batched_args, batch_dims))

  if data_bdim is None:
    lhs_data, data_bdim = lhs_data[None], 0
  if indices_bdim is None:
    lhs_indices, indices_bdim = lhs_indices[None], 0

  # TODO: handle different batchings between lhs_data and lhs_indices?
  assert data_bdim == indices_bdim == 0

  new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
      (len(lhs_spinfo.shape), rhs.ndim), (data_bdim, rhs_bdim), dimension_numbers)
  batched_spinfo = BCOOInfo((batch_size, *lhs_spinfo.shape))
  batched_out = _bcoo_dot_general(lhs_data, lhs_indices, rhs, lhs_spinfo=batched_spinfo,
                                  dimension_numbers=new_dimension_numbers)
  return batched_out, result_batch_dim
# Autodiff and batching: the JVP is given for lhs_data and rhs; there is no JVP
# rule for lhs_indices.
ad.defjvp(bcoo_dot_general_p, _bcoo_dot_general_jvp_lhs, None, _bcoo_dot_general_jvp_rhs)
ad.primitive_transposes[bcoo_dot_general_p] = _bcoo_dot_general_transpose
batching.primitive_batchers[bcoo_dot_general_p] = _bcoo_dot_general_batch_rule

# Default lowerings for all platforms.
xla.register_translation(
    bcoo_dot_general_p, _bcoo_dot_general_default_translation_rule)
mlir.register_lowering(
    bcoo_dot_general_p, _bcoo_dot_general_default_lowering)

# GPU-specific rules, registered only when the sparse kernel bindings are
# available. NOTE(review): `sparse_apis` is presumably None when the
# cusparse/hipsparse bindings failed to import — confirm where it is defined.
if sparse_apis and sparse_apis.is_supported:
  xla.register_translation(bcoo_dot_general_p,
                           _bcoo_dot_general_gpu_translation_rule,
                           platform='gpu')
  mlir.register_lowering(bcoo_dot_general_p,
                         _bcoo_dot_general_gpu_lowering,
                         platform='gpu')
#----------------------------------------------------------------------
# bcoo_dot_general_sampled
# (batched) general sampled dot product of two dense ND arrays, with
# output computed only at a given set of sparse indices.
# Primitive for a dense dot_general evaluated only at given BCOO indices
# (see bcoo_dot_general_sampled).
bcoo_dot_general_sampled_p = core.Primitive("bcoo_dot_general_sampled")
def bcoo_dot_general_sampled(A, B, indices, *, dimension_numbers):
  """A contraction operation with output computed at given sparse indices.

  Args:
    A: An ndarray, the left-hand operand of the contraction.
    B: An ndarray, the right-hand operand of the contraction.
    indices: BCOO indices selecting which entries of the dense result
      ``lax.dot_general(A, B, dimension_numbers)`` to compute.
    dimension_numbers: a tuple of tuples of the form
      `((lhs_contracting_dims, rhs_contracting_dims),
      (lhs_batch_dims, rhs_batch_dims))`, where "lhs" refers to ``A`` and
      "rhs" refers to ``B``.

  Returns:
    BCOO data, an ndarray containing the result.
  """
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  # Normalize each group of dimensions to a tuple of ints before binding.
  cdims = (api_util._ensure_index_tuple(lhs_contract),
           api_util._ensure_index_tuple(rhs_contract))
  bdims = (api_util._ensure_index_tuple(lhs_batch),
           api_util._ensure_index_tuple(rhs_batch))
  return bcoo_dot_general_sampled_p.bind(A, B, indices,
                                         dimension_numbers=(cdims, bdims))
@bcoo_dot_general_sampled_p.def_impl
def _bcoo_dot_general_sampled_impl(A, B, indices, *, dimension_numbers):
  """Reference implementation: full dense dot_general, then extract at ``indices``."""
  # TODO(jakevdp): use a more efficient implementation that avoids the full dot product.
  return bcoo_extract(indices,
                      lax.dot_general(A, B, dimension_numbers=dimension_numbers))
@bcoo_dot_general_sampled_p.def_abstract_eval
def _bcoo_dot_general_sampled_abstract_eval(A, B, indices, *, dimension_numbers):
  """Abstract eval: trace the dense dot_general, then bcoo_extract, on abstract values."""
  def dense_product(lhs, rhs):
    return [lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)]

  def extract_at(idx, mat):
    return [bcoo_extract(idx, mat)]

  (dense_aval,) = pe.abstract_eval_fun(dense_product, A, B)
  (sampled_aval,) = pe.abstract_eval_fun(extract_at, indices, dense_aval)
  return sampled_aval
def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers):
  """Transpose rule for bcoo_dot_general_sampled.

  Scatters the sparse cotangent ``ct`` into a dense cotangent for the full
  dot_general output using ``_bcoo_extract_transpose``, then applies the
  ``lax.dot_general`` transpose rule to get the cotangents for ``A`` / ``B``.
  """
  def shape_of(x):
    return x.aval.shape if hasattr(x, 'aval') else x.shape

  dense_shape = _dot_general_validated_shape(shape_of(A), shape_of(B), dimension_numbers)
  dense_primal = ad.UndefinedPrimal(core.ShapedArray(dense_shape, ct.dtype))
  indices, dense_ct = _bcoo_extract_transpose(ct, indices, dense_primal)

  dot_general_transpose = ad.get_primitive_transpose(lax.dot_general_p)
  A, B = dot_general_transpose(dense_ct, A, B,
                               dimension_numbers=dimension_numbers,
                               precision=None,
                               preferred_element_type=None)
  return A, B, indices
def _bcoo_dot_general_sampled_jvp_A(A_dot, A, B, indices, *, dimension_numbers):
  """JVP w.r.t. ``A``: the operation is linear in ``A``, so apply it to the tangent."""
  del A  # only the tangent enters the linear term
  return bcoo_dot_general_sampled(A_dot, B, indices,
                                  dimension_numbers=dimension_numbers)
def _bcoo_dot_general_sampled_jvp_B(B_dot, A, B, indices, *, dimension_numbers):
  """JVP w.r.t. ``B``: the operation is linear in ``B``, so apply it to the tangent."""
  del B  # only the tangent enters the linear term
  return bcoo_dot_general_sampled(A, B_dot, indices,
                                  dimension_numbers=dimension_numbers)
def _bcoo_dot_general_sampled_batch_rule(batched_args, batch_dims, *, dimension_numbers):
  """Batching rule: vmap the reference implementation; the output is batched on axis 0."""
  def per_example(A, B, indices):
    return _bcoo_dot_general_sampled_impl(A, B, indices,
                                          dimension_numbers=dimension_numbers)

  vectorized = vmap(per_example, in_axes=batch_dims, out_axes=0)
  return vectorized(*batched_args), 0
# Autodiff and batching: JVP rules are given for A and B; there is no JVP rule
# for the indices.
ad.defjvp(bcoo_dot_general_sampled_p, _bcoo_dot_general_sampled_jvp_A,
          _bcoo_dot_general_sampled_jvp_B, None)
ad.primitive_transposes[bcoo_dot_general_sampled_p] = _bcoo_dot_general_sampled_transpose
batching.primitive_batchers[bcoo_dot_general_sampled_p] = _bcoo_dot_general_sampled_batch_rule
# Lower on all platforms by tracing the reference implementation.
xla.register_translation(bcoo_dot_general_sampled_p, xla.lower_fun(
    _bcoo_dot_general_sampled_impl, multiple_results=False, new_style=True))
mlir.register_lowering(
    bcoo_dot_general_sampled_p,
    mlir.lower_fun(_bcoo_dot_general_sampled_impl, multiple_results=False))
#----------------------------------------------------------------------
# bcoo_spdot_general
# (batched) general dot product of two BCOO sparse arrays returning a
# BCOO sparse array (its data and indices buffers).
# Sparse x sparse dot_general primitive. It produces the result's BCOO
# (data, indices) buffers, hence multiple_results.
bcoo_spdot_general_p = core.Primitive('bcoo_spdot_general')
bcoo_spdot_general_p.multiple_results = True
def _bcoo_spdot_general(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: BCOOInfo, rhs_spinfo: BCOOInfo, dimension_numbers: DotDimensionNumbers):
  """Bind ``bcoo_spdot_general_p`` after normalizing the dimension numbers to int tuples.

  Returns:
    The primitive's ``[data, indices]`` outputs for the BCOO result.
  """
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  as_index_tuple = api_util._ensure_index_tuple
  normalized_dims = ((as_index_tuple(lhs_contract), as_index_tuple(rhs_contract)),
                     (as_index_tuple(lhs_batch), as_index_tuple(rhs_batch)))
  return bcoo_spdot_general_p.bind(lhs_data, lhs_indices, rhs_data, rhs_indices,
                                   lhs_spinfo=lhs_spinfo, rhs_spinfo=rhs_spinfo,
                                   dimension_numbers=normalized_dims)
def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo, rhs_spinfo, lhs_contracting, rhs_contracting):
  """Sparse x sparse dot product of two BCOO arrays without batch or dense dims.

  Forms every (lhs entry, rhs entry) pair. Each pair gets the product of the
  two data values when the contracting indices match and both entries are in
  range, and zero otherwise. Its output index is the concatenation of the two
  entries' non-contracting indices. Duplicate output indices are then summed.

  Args:
    lhs_data, lhs_indices, rhs_data, rhs_indices: BCOO buffers; the asserts
      below require ``n_batch == n_dense == 0`` for both operands.
    lhs_spinfo, rhs_spinfo: ``BCOOInfo`` holding each operand's shape.
    lhs_contracting, rhs_contracting: contracted dimensions, all of which must
      be sparse dimensions of the respective operand.

  Returns:
    The result of ``_bcoo_sum_duplicates`` on the pairwise data/indices.
  """
  lhs_shape = lhs_spinfo.shape
  rhs_shape = rhs_spinfo.shape

  lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
  rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)

  assert lhs.n_batch == rhs.n_batch == 0
  assert lhs.n_dense == rhs.n_dense == 0
  assert [lhs_shape[d] for d in lhs_contracting] == [rhs_shape[d] for d in rhs_contracting]
  assert max(lhs_contracting, default=-1) < lhs.n_sparse
  assert max(rhs_contracting, default=-1) < rhs.n_sparse

  # Output shape: non-contracted lhs dims followed by non-contracted rhs dims.
  out_shape = (
    [s for i, s in enumerate(lhs_shape) if i not in lhs_contracting] +
    [s for i, s in enumerate(rhs_shape) if i not in rhs_contracting])

  # Split each operand's index columns into contracting (*_i) and remaining
  # (*_j) columns.
  lhs_i = lhs_indices[:, jnp.array(lhs_contracting, dtype=int)]
  rhs_i = rhs_indices[:, jnp.array(rhs_contracting, dtype=int)]
  lhs_j = lhs_indices[:, jnp.array(remaining(range(lhs.n_sparse), lhs_contracting), dtype=int)]
  rhs_j = rhs_indices[:, jnp.array(remaining(range(rhs.n_sparse), rhs_contracting), dtype=int)]

  # TODO(jakevdp): can we do this more efficiently than using an outer product? Note that
  #   jnp.isin() currently doesn't help much, because it also does all() over an outer
  #   comparison.
  # overlap[a, b] is True when lhs entry a and rhs entry b have equal
  # contracting indices.
  overlap = (lhs_i[:, None] == rhs_i[None, :]).all(-1)
  # An entry takes part only if every contracting index is strictly below the
  # corresponding dimension size; larger (out-of-range) indices are excluded.
  lhs_fill_value = jnp.expand_dims(
    jnp.array([lhs_shape[d] for d in lhs_contracting]), range(lhs_i.ndim - 1))
  rhs_fill_value = jnp.expand_dims(
    jnp.array([rhs_shape[d] for d in rhs_contracting]), range(rhs_i.ndim - 1))
  lhs_valid = (lhs_i < lhs_fill_value).all(-1)
  rhs_valid = (rhs_i < rhs_fill_value).all(-1)
  # Pairwise products, zeroed for non-matching or out-of-range pairs and
  # flattened to lhs.nse * rhs.nse entries.
  out_data = jnp.where(overlap & lhs_valid[:, None] & rhs_valid[None, :],
                       lhs_data[:, None] * rhs_data[None, :], 0).ravel()

  # Output index of pair (a, b): lhs_j[a] followed by rhs_j[b], flattened in
  # the same order as out_data.
  out_indices = jnp.empty([lhs.nse, rhs.nse, lhs_j.shape[-1] + rhs_j.shape[-1]],
                          dtype=jnp.result_type(lhs_indices, rhs_indices))
  out_indices = out_indices.at[:, :, :lhs_j.shape[-1]].set(lhs_j[:, None])
  out_indices = out_indices.at[:, :, lhs_j.shape[-1]:].set(rhs_j[None, :])
  out_indices = out_indices.reshape(len(out_data), out_indices.shape[-1])
  # A side with no remaining sparse dims maps all of its entries to the same
  # output index, so it contributes a factor of 1 rather than its nse.
  out_nse = (lhs.nse if lhs_j.shape[1] else 1) * (rhs.nse if rhs_j.shape[1] else 1)
  # Note: remove_zeros=True is incompatible with autodiff.
  return _bcoo_sum_duplicates(out_data, out_indices, out_shape, nse=out_nse, remove_zeros=False)
@bcoo_spdot_general_p.def_impl
def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: BCOOInfo, rhs_spinfo: BCOOInfo, dimension_numbers):
  """Reference implementation of bcoo_spdot_general.

  Moves the dot_general batch dimensions to the front of each operand's
  buffers. The batched product is then built from
  ``_bcoo_spdot_general_unbatched`` by wrapping it in ``broadcasting_vmap``
  once per BCOO batch dimension.
  """
  lhs_shape = lhs_spinfo.shape
  rhs_shape = rhs_spinfo.shape

  lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
  rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)
  assert lhs.n_dense == rhs.n_dense == 0
  # Run the abstract-eval rule and check that its output avals form a valid
  # BCOO encoding of the dense output shape. These values are used only for
  # this validation.
  data_aval, indices_aval = _bcoo_spdot_general_abstract_eval(
    lhs_data.aval, lhs_indices.aval, rhs_data.aval, rhs_indices.aval,
    lhs_spinfo=lhs_spinfo, rhs_spinfo=rhs_spinfo, dimension_numbers=dimension_numbers)
  out_shape = _dot_general_validated_shape(lhs_shape, rhs_shape, dimension_numbers)
  _validate_bcoo(data_aval, indices_aval, out_shape)

  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers

  # Move batch dimensions to front of each array.
  lhs_batch_perm = [*lhs_batch, *remaining(range(lhs.n_batch), lhs_batch)]
  rhs_batch_perm = [*rhs_batch, *remaining(range(rhs.n_batch), rhs_batch)]
  lhs_data = lhs_data.transpose([*lhs_batch_perm, *range(lhs.n_batch, lhs_data.ndim)])
  rhs_data = rhs_data.transpose([*rhs_batch_perm, *range(rhs.n_batch, rhs_data.ndim)])
  lhs_indices = lhs_indices.transpose([*lhs_batch_perm, *range(lhs.n_batch, lhs_indices.ndim)])
  rhs_indices = rhs_indices.transpose([*rhs_batch_perm, *range(rhs.n_batch, rhs_indices.ndim)])

  # Implement batched dot product via vmap
  func = functools.partial(_bcoo_spdot_general_unbatched,
      lhs_spinfo=BCOOInfo(lhs_shape[lhs.n_batch:]),
      rhs_spinfo=BCOOInfo(rhs_shape[rhs.n_batch:]),
      lhs_contracting=[d - lhs.n_batch for d in lhs_contracting],
      rhs_contracting=[d - rhs.n_batch for d in rhs_contracting])

  # Wrap from the inside out, so the output batch axes are ordered as
  # (shared dot_general batch dims, lhs-only batch dims, rhs-only batch dims).
  # Innermost: rhs-only batch dims map over the rhs buffers, lhs held fixed.
  for _ in reversed(range(len(rhs_batch), rhs.n_batch)):
    func = broadcasting_vmap(func, in_axes=(None, None, 0, 0))
  # Next: lhs-only batch dims map over the lhs buffers, rhs held fixed.
  for _ in reversed(range(len(lhs_batch), lhs.n_batch)):
    func = broadcasting_vmap(func, in_axes=(0, 0, None, None))
  # Outermost: shared batch dims map over all four buffers.
  for _ in range(len(lhs_batch)):
    func = broadcasting_vmap(func, in_axes=0)
  return func(lhs_data, lhs_indices, rhs_data, rhs_indices)
@bcoo_spdot_general_p.def_abstract_eval
def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: BCOOInfo, rhs_spinfo: BCOOInfo, dimension_numbers):
  """Abstract eval for bcoo_spdot_general.

  Validates the operands and dimension numbers, and returns ``ShapedArray``s
  for the BCOO ``(data, indices)`` buffers of the result.

  Raises:
    ValueError: if the data dtypes differ, or if rhs has batch dimensions that
      are not dot_general batch dims while lhs has non-contracted sparse dims.
    NotImplementedError: if either operand has dense dimensions, if a
      dot_general batch dim is not a BCOO batch dim, or if a contracting dim is
      not a sparse dimension.
  """
  lhs_shape = lhs_spinfo.shape
  rhs_shape = rhs_spinfo.shape

  if lhs_data.dtype != rhs_data.dtype:
    raise ValueError("bcoo_spdot_general requires inputs to have matching dtypes; "
                     f"got lhs.dtype={lhs_data.dtype}, rhs.dtype={rhs_data.dtype}")
  lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
  rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  # Evaluated for its validation of the dimension numbers; the shape is unused.
  _ = _dot_general_validated_shape(lhs_shape, rhs_shape, dimension_numbers)

  if lhs.n_dense or rhs.n_dense:
    # TODO(jakevdp): handle dense dimensions
    raise NotImplementedError("bcoo_spdot_general with dense dimensions.")

  if (lhs_batch and max(lhs_batch) >= lhs.n_batch) or (rhs_batch and max(rhs_batch) >= rhs.n_batch):
    raise NotImplementedError("bcoo_spdot_general: batch_dims must correspond to batch dimensions of the sparse representation.")

  if lhs_contracting and (min(lhs_contracting) < lhs.n_batch or max(lhs_contracting) >= lhs.n_batch + lhs.n_sparse):
    raise NotImplementedError("bcoo_spdot_general only supports contraction of sparse indices.")

  if rhs_contracting and (min(rhs_contracting) < rhs.n_batch or max(rhs_contracting) >= rhs.n_batch + rhs.n_sparse):
    raise NotImplementedError("bcoo_spdot_general only supports contraction of sparse indices.")

  if rhs.n_batch > len(rhs_batch) and lhs.n_sparse > len(lhs_contracting):
    raise ValueError("bcoo_spdot_general: cannot have unused batch dims on rhs with unused sparse dims on lhs.")

  # Each side contributes a factor of its nse only if it keeps some sparse
  # dimensions after contraction.
  out_nse = (
    (lhs.nse if lhs.n_sparse > len(lhs_contracting) else 1) *
    (rhs.nse if rhs.n_sparse > len(rhs_contracting) else 1)
  )

  # Output batch axes: shared dot_general batch dims, then lhs-only batch dims,
  # then rhs-only batch dims.
  data_shape = (
    *(lhs_shape[dim] for dim in lhs_batch),
    *(lhs_data.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch),
    *(rhs_data.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch),
    out_nse)
  # The last axis holds one column per remaining (non-contracted) sparse dim
  # of lhs and rhs.
  indices_shape = (
    *(lhs_shape[dim] for dim in lhs_batch),
    *(lhs_indices.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch),
    *(rhs_indices.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch),
    out_nse, lhs.n_sparse + rhs.n_sparse - 2 * len(lhs_contracting))
  return core.ShapedArray(data_shape, lhs_data.dtype), core.ShapedArray(indices_shape, lhs_indices.dtype)
def _bcoo_spdot_general_batch_rule(batched_args, batch_dims, *, lhs_spinfo: BCOOInfo, rhs_spinfo: BCOOInfo, dimension_numbers):
  """Batching rule for bcoo_spdot_general.

  Each operand's vmapped axis becomes a new leading BCOO batch dimension. A
  size-1 leading axis is inserted into any unbatched buffer. Only batching
  along axis 0 is supported.
  """
  lhs_shape = lhs_spinfo.shape
  rhs_shape = rhs_spinfo.shape
  args = list(batched_args)
  dims = list(batch_dims)
  batch_size = max(arg.shape[d] if d is not None else 0
                   for arg, d in zip(batched_args, batch_dims))

  def ensure_leading_axis(pos):
    if dims[pos] is None:
      args[pos] = args[pos][None]
      dims[pos] = 0

  # lhs buffers (data, indices).
  ensure_leading_axis(0)
  ensure_leading_axis(1)
  assert dims[0] == dims[1] == 0
  # rhs buffers (data, indices).
  ensure_leading_axis(2)
  ensure_leading_axis(3)
  if any(d != 0 for d in dims):
    raise NotImplementedError("batching along non-leading dimension.")

  lhs_data, lhs_indices, rhs_data, rhs_indices = args
  new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
      (len(lhs_shape), len(rhs_shape)), (dims[0], dims[2]), dimension_numbers)
  batched_out = _bcoo_spdot_general(
      lhs_data, lhs_indices, rhs_data, rhs_indices,
      dimension_numbers=new_dimension_numbers,
      lhs_spinfo=BCOOInfo((batch_size, *lhs_shape)),
      rhs_spinfo=BCOOInfo((batch_size, *rhs_shape)))
  return batched_out, (result_batch_dim, result_batch_dim)
def _bcoo_spdot_general_jvp(primals, tangents, **kwds):
  """JVP rule for bcoo_spdot_general.

  The output data is bilinear in ``(lhs_data, rhs_data)``, and the rule treats
  the output indices as non-differentiable:

    d(out_data)    = spdot(lhs_data_dot, rhs) + spdot(lhs, rhs_data_dot)
    d(out_indices) = 0

  Index tangents must be symbolic zeros.

  Returns:
    ``(primals_out, [data_dot_out, indices_dot_out])``. ``indices_dot_out`` is
    always a symbolic ``ad.Zero``. ``data_dot_out`` is also a symbolic
    ``ad.Zero`` matching the output data when neither data tangent is nonzero,
    so the tangent always has the output's shape and dtype.
  """
  lhs_data, lhs_indices, rhs_data, rhs_indices = primals
  lhs_data_dot, lhs_indices_dot, rhs_data_dot, rhs_indices_dot = tangents
  primals_out = _bcoo_spdot_general(*primals, **kwds)
  assert type(lhs_indices_dot) is ad.Zero
  assert type(rhs_indices_dot) is ad.Zero

  # Collect the non-zero terms of the product rule.
  terms = []
  if type(lhs_data_dot) is not ad.Zero:
    terms.append(_bcoo_spdot_general(lhs_data_dot, lhs_indices, rhs_data, rhs_indices, **kwds)[0])
  if type(rhs_data_dot) is not ad.Zero:
    terms.append(_bcoo_spdot_general(lhs_data, lhs_indices, rhs_data_dot, rhs_indices, **kwds)[0])

  if terms:
    data_dot_out = functools.reduce(operator.add, terms)
  else:
    # No data tangent flows in (e.g. only the integer indices are traced):
    # return a symbolic zero with the output data's aval rather than a bare 0.
    data_dot_out = ad.Zero.from_value(primals_out[0])
  return primals_out, [data_dot_out, ad.Zero.from_value(primals_out[1])]
# TODO(JVP): transpose rule
# Batching and JVP rules; both lowerings trace the reference implementation
# on all platforms.
batching.primitive_batchers[bcoo_spdot_general_p] = _bcoo_spdot_general_batch_rule
ad.primitive_jvps[bcoo_spdot_general_p] = _bcoo_spdot_general_jvp
xla.register_translation(bcoo_spdot_general_p, xla.lower_fun(
    _bcoo_spdot_general_impl, multiple_results=True, new_style=True))
mlir.register_lowering(bcoo_spdot_general_p, mlir.lower_fun(
    _bcoo_spdot_general_impl, multiple_results=True))
#----------------------------------------------------------------------
# bcoo_sort_indices
# Utility to sort the indices of a BCOO representation. This primitive
# does not support deduplication or removing of zeros; see bcoo_sum_duplicates.
# Primitive returning the sorted (data, indices) buffers, hence multiple_results.
bcoo_sort_indices_p = core.Primitive("bcoo_sort_indices")
bcoo_sort_indices_p.multiple_results = True
def bcoo_sort_indices(mat):
  """Sort indices of a BCOO array.

  Within each batch, stored entries are reordered so that their sparse
  indices are in lexicographic order, and the data is permuted to match.
  Duplicate indices are not summed and explicit zeros are not removed; use
  ``bcoo_sum_duplicates`` for that.

  Args:
    mat : BCOO array

  Returns:
    mat_out : BCOO array with sorted indices.
  """
  data, indices = bcoo_sort_indices_p.bind(*mat._bufs, spinfo=mat._info)
  return BCOO((data, indices), shape=mat.shape)
@bcoo_sort_indices_p.def_impl
def _bcoo_sort_indices_impl(data, indices, *, spinfo):
  """Reference implementation of bcoo_sort_indices.

  Sorts the index rows of each batch and gathers the data with the resulting
  permutation. With no sparse dimensions the buffers are returned unchanged.
  """
  props = _validate_bcoo(data, indices, spinfo.shape)
  if not props.n_sparse:
    return data, indices

  sort_rows = _bcoo_sort_indices_unbatched
  gather = lambda d, p: d[p]
  for _ in range(props.n_batch):
    sort_rows = vmap(sort_rows)
    gather = broadcasting_vmap(gather)

  sorted_indices, perm = sort_rows(indices)
  return gather(data, perm), sorted_indices
def _bcoo_sort_indices_unbatched(indices):
# sort indices without summing duplicates
nse, N = indices.shape
idx_cols = (indices[:, i] for i in range(N))
*indices, perm = lax.sort((*idx_cols, lax.iota(indices.dtype, nse)), num_keys=N)
return jnp.column_stack(indices), perm
@bcoo_sort_indices_p.def_abstract_eval
def _bcoo_sort_indices_abstract_eval(data, indices, *, spinfo):
  """Abstract eval for bcoo_sort_indices.

  The indices keep their aval. The output data's batch shape is the
  elementwise maximum of the data and indices batch shapes, since the impl
  broadcasts the data gather over the batch dimensions.
  """
  props = _validate_bcoo(data, indices, spinfo.shape)
  if props.n_sparse == 0:
    return data, indices

  n_batch = props.n_batch
  batch_shape = tuple(max(i, d) for i, d in zip(indices.shape[:n_batch],
                                                 data.shape[:n_batch]))
  dense_shape = data.shape[n_batch + 1:]
  data_out = core.ShapedArray((*batch_shape, props.nse, *dense_shape),
                              data.dtype, weak_type=data.weak_type)
  return data_out, indices
def _bcoo_sort_indices_batching_rule(batched_args, batch_dims, *, spinfo):
  """Batching rule for bcoo_sort_indices.

  The vmapped axis becomes a new leading BCOO batch dimension; an unbatched
  operand gets a size-1 leading axis for the primitive. The data output is
  always batched along axis 0. If ``indices`` was unbatched, the sorted indices
  are the same for every batch element, so they are returned unbatched.

  Raises:
    NotImplementedError: if an operand is batched along a non-leading axis.
  """
  data, indices = batched_args
  data_bdim, indices_bdim = batch_dims
  if any(b not in [0, None] for b in batch_dims):
    raise NotImplementedError(f"batch_dims={batch_dims}. Only 0 and None are supported.")
  if data_bdim is None:
    data = data[None, ...]
  if indices_bdim is None:
    indices = indices[None, ...]
  new_spinfo = BCOOInfo(shape=(max(data.shape[0], indices.shape[0]), *spinfo.shape))
  data_out, indices_out = bcoo_sort_indices_p.bind(data, indices, spinfo=new_spinfo)
  if indices_bdim is None:
    # The abstract eval keeps the indices aval, so indices_out still has the
    # size-1 leading axis added above. Reporting it as batched along axis 0
    # would give a length-1 axis instead of the vmap batch size, so drop the
    # axis and mark the indices as unbatched.
    return (data_out, indices_out[0]), (0, None)
  return (data_out, indices_out), (0, 0)
def _bcoo_sort_indices_jvp(primals, tangents, *, spinfo):
  """JVP for bcoo_sort_indices: permute the data tangent like the primal data.

  The index tangent is a symbolic zero. With no sparse dimensions, primals and
  tangents are returned unchanged.
  """
  props = _validate_bcoo(*primals, spinfo.shape)
  if props.n_sparse == 0:
    return primals, tangents

  data, indices = primals
  data_dot, _ = tangents

  def _take(d, p):
    return d[p]

  sort_rows = _bcoo_sort_indices_unbatched
  gather = _take
  for _ in range(props.n_batch):
    sort_rows = broadcasting_vmap(sort_rows)
    gather = broadcasting_vmap(gather)

  indices_out, perm = sort_rows(indices)
  data_out = gather(data, perm)

  indices_dot_out = ad.Zero.from_value(indices)
  if type(data_dot) is ad.Zero:
    data_dot_out = ad.Zero.from_value(data_out)
  else:
    data_dot_out = gather(data_dot, perm)
  return (data_out, indices_out), (data_dot_out, indices_dot_out)
# Lowerings produced by tracing the reference impl. They are also kept under
# module-level names because the GPU bcoo_dot_general rules call them directly
# to sort the indices before invoking the sparse kernels.
_bcoo_sort_indices_rule = xla.lower_fun(
    _bcoo_sort_indices_impl, multiple_results=True, new_style=True)
_bcoo_sort_indices_mhlo = mlir.lower_fun(
    _bcoo_sort_indices_impl, multiple_results=True)

ad.primitive_jvps[bcoo_sort_indices_p] = _bcoo_sort_indices_jvp
batching.primitive_batchers[bcoo_sort_indices_p] = _bcoo_sort_indices_batching_rule
xla.register_translation(bcoo_sort_indices_p, _bcoo_sort_indices_rule)
mlir.register_lowering(bcoo_sort_indices_p, _bcoo_sort_indices_mhlo)
#----------------------------------------------------------------------
# BCOO functions that maybe should be primitives?
def bcoo_add_batch_dim(M):
  """Convert the first sparse dimension of a BCOO matrix into a batch dimension.

  Note that this may produce a far less efficient storage scheme. For an
  unbatched matrix, the data and indices buffers grow from ``M.nse`` entries
  to ``M.shape[0] * M.nse`` entries. This utility is provided for convenience,
  e.g. to allow vmapping over non-batched matrices.

  Args:
    M: BCOO matrix

  Returns:
    M2: BCOO matrix with n_batch = M.n_batch + 1 and n_sparse = M.n_sparse - 1

  Raises:
    ValueError: if ``M`` has no sparse dimensions.
  """
  # TODO(jakevdp): allow user-specified nse?
  if M.n_sparse == 0:
    raise ValueError("Cannot add a batch dimension to a matrix with n_sparse=0")
  # Apply the unbatched conversion once per existing batch dimension via vmap.
  convert = functools.reduce(lambda fn, _: vmap(fn), range(M.n_batch), _add_batch_dim)
  return convert(M)
def _add_batch_dim(M):
  """Unbatched helper for ``bcoo_add_batch_dim``.

  Each stored entry is scattered into row ``M.indices[:, 0]`` of a new leading
  axis of length ``M.shape[0]``. Its remaining sparse index columns become the
  per-batch indices. Slots that receive no entry hold zero data, and their
  indices are set to each dimension's size (an out-of-range index).
  """
  assert M.n_batch == 0
  assert M.n_sparse > 0
  rows = M.indices[:, 0]
  slots = jnp.arange(M.nse)

  new_data = (jnp.zeros_like(M.data, shape=(M.shape[0], *M.data.shape))
              .at[rows, slots].set(M.data))

  new_indices_shape = (M.shape[0], M.nse, M.n_sparse - 1)
  if M.n_sparse <= 1:
    # No sparse dimensions remain, so the index arrays have a zero-width last axis.
    new_indices = jnp.empty_like(M.indices, shape=new_indices_shape)
  else:
    dim_sizes = jnp.array(M.shape[M.n_batch + 1: M.n_batch + M.n_sparse])
    new_indices = (jnp.full_like(M.indices, shape=new_indices_shape, fill_value=dim_sizes)
                   .at[rows, slots].set(M.indices[:, 1:]))

  return BCOO((new_data, new_indices), shape=M.shape)
def bcoo_broadcast_in_dim(mat, *, shape, broadcast_dimensions):
  """Expand the size and rank of a BCOO array by duplicating the data.

  A BCOO equivalent of ``jax.lax.broadcast_in_dim``.

  Args:
    mat: A BCOO-format array.
    shape: The shape of the target array.
    broadcast_dimensions: For each dimension of the operand (``mat``), the
      dimension of the target shape it corresponds to.

  Returns:
    A BCOO-format array containing the target array.
  """
  new_buffers = _bcoo_broadcast_in_dim(mat.data, mat.indices, spinfo=mat._info,
                                       shape=shape,
                                       broadcast_dimensions=broadcast_dimensions)
  return BCOO(new_buffers, shape=shape)
def _bcoo_broadcast_in_dim(data, indices, *, spinfo, shape, broadcast_dimensions):
  """BCOO equivalent of lax.broadcast_in_dim

  Operates on the raw ``(data, indices)`` buffers of a BCOO array described
  by ``spinfo`` and returns the new ``(data, indices)`` buffers for an array
  of shape ``shape``.

  Raises:
    ValueError: if ``broadcast_dimensions`` does not have one entry per operand
      dimension, or if it would interleave batch/sparse or sparse/dense
      dimensions.
    NotImplementedError: if the product of the sparse dimension sizes would
      change (only size-1 sparse dimensions can be added).
  """
  if len(spinfo.shape) != len(broadcast_dimensions):
    raise ValueError(f"spinfo.shape={spinfo.shape} and broadcast_dimensions={broadcast_dimensions} must have the same length")
  props = _validate_bcoo(data, indices, spinfo.shape)
  # Output positions of the operand's batch, sparse and dense dimensions.
  batch_dims, sparse_dims, dense_dims = split_list(broadcast_dimensions, [props.n_batch, props.n_sparse])

  if max(batch_dims, default=0) > min(sparse_dims, default=len(shape)):
    raise ValueError("Cannot mix batch and sparse dimensions during broadcast_in_dim")
  if max(sparse_dims, default=0) > min(dense_dims, default=len(shape)):
    raise ValueError("Cannot mix sparse and dense dimensions during broadcast_in_dim")

  # Output dims 0..(last mapped batch dim) are batch dims (none if the operand
  # has none); output dims from the first mapped dense dim onward are dense
  # (none if the operand has none); everything in between is sparse.
  new_n_batch = props.n_batch and 1 + max(broadcast_dimensions[:props.n_batch])
  new_n_dense = props.n_dense and len(shape) - min(broadcast_dimensions[-props.n_dense:])
  new_n_sparse = len(shape) - new_n_batch - new_n_dense

  if np.prod(spinfo.shape[props.n_batch: props.n_batch + props.n_sparse]) != np.prod(shape[new_n_batch:new_n_batch + new_n_sparse]):
    raise NotImplementedError("Adding sparse dimensions with lengths != 1")
  nse = props.nse

  # batch & dense dimensions
  # data goes from (*batch, nse, *dense) to (*new_batch, nse, *new_dense); the
  # new_n_sparse sparse output dims collapse into the single nse axis, hence
  # the `b + 1 - new_n_sparse` shift for dense dims.
  new_data = lax.broadcast_in_dim(data,
      shape=(*shape[:new_n_batch], nse, *shape[new_n_batch + new_n_sparse:]),
      broadcast_dimensions=(*batch_dims, new_n_batch, *(b + 1 - new_n_sparse for b in dense_dims)))
  new_indices = lax.broadcast_in_dim(indices,
      shape=(*shape[:new_n_batch], nse, props.n_sparse),
      broadcast_dimensions=(*batch_dims, new_n_batch, new_n_batch + 1))

  # sparse dimensions
  # Scatter each existing index column into its output sparse position; the
  # newly added (size-1) sparse dimensions get index 0.
  new_indices = (jnp.zeros_like(new_indices, shape=(*shape[:new_n_batch], nse, new_n_sparse))
                  .at[..., jnp.array(sparse_dims, int) - new_n_batch].set(new_indices))

  return new_data, new_indices
def _tuple_replace(tup, ind, val):
  """Return the elements of ``tup`` as a tuple, with position ``ind`` set to ``val``.

  Each position is compared to ``ind`` with ``==``. If no position matches
  (e.g. ``ind`` is negative or past the end), the elements are returned
  unchanged.
  """
  replaced = []
  for position, item in enumerate(tup):
    replaced.append(val if position == ind else item)
  return tuple(replaced)
def bcoo_reduce_sum(mat, *, axes):
  """Sum array elements over the given axes.

  Args:
    mat: A BCOO-format array.
    axes: A tuple, list, or ndarray containing the axes of ``mat`` over which
      the sum is performed. Negative axes are interpreted relative to
      ``mat.ndim``.

  Returns:
    A BCOO-format array containing the result.
  """
  # Normalize negative axes; the buffer-level implementation only accepts
  # axes in [0, mat.ndim).
  axes = [ax + mat.ndim if ax < 0 else ax for ax in axes]
  out_data, out_indices, out_shape = _bcoo_reduce_sum(
      mat.data, mat.indices, spinfo=mat._info, axes=axes)
  return BCOO((out_data, out_indices), shape=out_shape)
def _bcoo_reduce_sum(data, indices, *, spinfo, axes):
  """Sum the BCOO array given by ``(data, indices)`` over ``axes``.

  Args:
    data: data buffer, shape ``[*batch_dims, nse, *dense_dims]``.
    indices: index buffer, shape ``[*batch_dims, nse, n_sparse]``.
    spinfo: BCOO info whose ``shape`` is the shape of the full array.
    axes: axes of the full array to sum over; each must be in ``[0, ndim)``.

  Returns:
    A tuple ``(data, indices, shape)`` describing the reduced BCOO array.
  """
  shape = spinfo.shape
  assert all(0 <= a < len(shape) for a in axes)
  n_batch, n_sparse, _, nse = _validate_bcoo(data, indices, shape)
  axes = sorted(set(axes))

  # Sum over dense dimensions -> sum over data. Full-array axis `ax` is data
  # axis `ax - n_sparse + 1`, because the sparse dims collapse into nse.
  dense_axes = tuple(ax - n_sparse + 1 for ax in axes if ax >= n_batch + n_sparse)
  data = data.sum(dense_axes)
  if n_sparse:
    # zero-out data corresponding to invalid (out-of-bounds) indices.
    fill_value = jnp.expand_dims(
      jnp.array(shape[n_batch: n_batch + n_sparse]), range(indices.ndim - 1))
    mask = jnp.all(indices < fill_value, -1)
    if data.ndim > mask.ndim:
      mask = lax.expand_dims(mask, tuple(range(mask.ndim, data.ndim)))
    data = jnp.where(mask, data, 0)

  # Sum over sparse dimensions -> drop index; sum is implicit
  sparse_idx = [i for i in range(n_sparse) if i + n_batch not in axes]
  if not sparse_idx:
    indices = jnp.zeros(_tuple_replace(indices.shape, n_batch + 1, 0), indices.dtype)
  else:
    indices = indices[..., np.array(sparse_idx)]

  # Sum over batch dimensions -> reshape into nse
  batch_axes = {ax for ax in axes if ax < n_batch}

  # First handle broadcasted batch dimensions, so that data and indices have
  # equal sizes along every summed batch axis.
  for ax in batch_axes:
    if data.shape[ax] == 1:
      if indices.shape[ax] == 1:
        # Both broadcast: all shape[ax] batches hold identical entries.
        data = data * shape[ax]
      else:
        data = lax.broadcast_in_dim(data, _tuple_replace(data.shape, ax, shape[ax]), tuple(range(data.ndim)))
    else:
      if indices.shape[ax] == 1:
        # Indices are shared along this axis, so the values can be summed
        # directly. keepdims preserves the rank that the assert below and the
        # reshapes rely on.
        data = data.sum(ax, keepdims=True)
      assert data.shape[ax] == indices.shape[ax]

  new_batch_dims = tuple(sorted(set(range(n_batch)) - batch_axes))
  summed_batch_dims = tuple(sorted(batch_axes))
  new_nse = int(nse * np.prod([data.shape[i] for i in summed_batch_dims]))
  # data and indices may broadcast differently along the kept batch axes, so
  # their output batch shapes are computed separately.
  data_batch_shape = tuple(data.shape[i] for i in new_batch_dims)
  indices_batch_shape = tuple(indices.shape[i] for i in new_batch_dims)
  # Move the summed batch axes next to nse and merge them into it.
  data = lax.reshape(data,
                     (*data_batch_shape, new_nse, *data.shape[n_batch + 1:]),
                     (*new_batch_dims, *summed_batch_dims, *range(n_batch, data.ndim)))
  indices = lax.reshape(indices,
                        (*indices_batch_shape, new_nse, *indices.shape[n_batch + 1:]),
                        (*new_batch_dims, *summed_batch_dims, *range(n_batch, indices.ndim)))

  out_shape = tuple(shape[i] for i in range(len(shape)) if i not in axes)
  return data, indices, out_shape
def bcoo_multiply_sparse(lhs, rhs):
  """Element-wise multiplication of two sparse arrays.

  Args:
    lhs: A BCOO-format array.
    rhs: A BCOO-format array.

  Returns:
    A BCOO-format array holding the product, whose shape is the broadcast of
    ``lhs.shape`` and ``rhs.shape``.
  """
  product_data, product_indices, product_shape = _bcoo_multiply_sparse(
      *lhs._bufs, *rhs._bufs, lhs_spinfo=lhs._info, rhs_spinfo=rhs._info)
  return BCOO((product_data, product_indices), shape=product_shape)
def _bcoo_multiply_sparse(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo, rhs_spinfo):
  """Buffer-level element-wise product of two BCOO arrays.

  The batch dimensions shared by both operands are mapped over with
  ``broadcasting_vmap``. Any batching left on one side is handled inside
  ``_bcoo_multiply_sparse_unbatched``.

  Returns:
    ``(data, indices, shape)`` of the product; ``shape`` is the broadcast of
    the two input shapes.

  Raises:
    TypeError: if the operands have different ranks.
    NotImplementedError: if the operands have different numbers of dense
      dimensions.
  """
  lhs_shape, rhs_shape = lhs_spinfo.shape, rhs_spinfo.shape
  lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
  rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)

  if len(lhs_shape) != len(rhs_shape):
    # Similar requirement as lax.mul:
    raise TypeError("bcoo_multiply_sparse: arrays must have same number of dimensions, "
                    f"got {lhs_shape}, {rhs_shape}")
  if lhs.n_dense != rhs.n_dense:
    raise NotImplementedError("bcoo_multiply_sparse: arrays with differing numbers of "
                              f"dense dimensions: {lhs}, {rhs}")

  shared_batch = min(lhs.n_batch, rhs.n_batch)
  kernel = functools.partial(_bcoo_multiply_sparse_unbatched,
                             lhs_shape=lhs_shape[shared_batch:],
                             rhs_shape=rhs_shape[shared_batch:])
  # Wrap the kernel once per shared batch dimension.
  kernel = functools.reduce(lambda fn, _: broadcasting_vmap(fn), range(shared_batch), kernel)
  out_data, out_indices = kernel(lhs_data, lhs_indices, rhs_data, rhs_indices)
  return out_data, out_indices, jnp.broadcast_shapes(lhs_shape, rhs_shape)
def _bcoo_multiply_sparse_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_shape, rhs_shape):
  """Element-wise product of two BCOO operands, at most one of which is batched.

  A batched operand is first flattened with ``_unbatch_bcoo``. Every pair of
  stored elements whose indices agree on all non-broadcast sparse dimensions
  then contributes one entry of the result. Returns the ``(data, indices)``
  buffers of the product, with ``lhs.nse * rhs.nse`` stored elements
  (unmatched slots are padding).
  """
  lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
  rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)
  assert (lhs.n_batch == 0) or (rhs.n_batch == 0) # Ensured at call site above
  # TODO(jakevdp): this can be made more efficient by utilizing batch structure.
  if lhs.n_batch:
    lhs_data, lhs_indices = _unbatch_bcoo(lhs_data, lhs_indices, lhs_shape)
    lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
  elif rhs.n_batch:
    rhs_data, rhs_indices = _unbatch_bcoo(rhs_data, rhs_indices, rhs_shape)
    rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)
  # Sparse dimensions where neither operand has size 1. Only these are compared
  # when matching entries; size-1 dimensions broadcast.
  dims = jnp.array([i for i, (s1, s2) in enumerate(safe_zip(lhs_shape[:lhs.n_sparse], rhs_shape[:rhs.n_sparse]))
                    if s1 != 1 and s2 != 1], dtype=int)

  # TODO(jakevdp): this nse can be tightened to min(lhs.nse, rhs.nse) if there
  # is no broadcasting and indices are unique.
  nse = lhs.nse * rhs.nse

  # TODO(jakevdp): this is pretty inefficient. Can we do this membership check
  # without constructing the full (lhs.nse, rhs.nse) masking matrix?
  # mask[i, j] is True when lhs entry i and rhs entry j agree on every dim in `dims`.
  mask = jnp.all(lhs_indices[:, None, dims] == rhs_indices[None, :, dims], -1)
  # Index pairs of matches, padded to length nse with the out-of-range pair
  # (lhs.nse, rhs.nse).
  i_lhs, i_rhs = jnp.nonzero(mask, size=nse, fill_value=(lhs.nse, rhs.nse))
  # Padding pairs gather fill_value=0, so their product is zero.
  data = (lhs_data.at[i_lhs].get(mode='fill', fill_value=0) *
          rhs_data.at[i_rhs].get(mode='fill', fill_value=0))
  # Padding pairs gather max(shape), which is an out-of-bounds index. For
  # matched pairs, the elementwise maximum selects the non-broadcast side's
  # index along size-1 dims (valid indices along a size-1 dim can only be 0).
  indices = jnp.maximum(
      lhs_indices.at[i_lhs].get(mode='fill', fill_value=max(lhs_shape, default=0)),
      rhs_indices.at[i_rhs].get(mode='fill', fill_value=max(rhs_shape, default=0)))
  return data, indices
def bcoo_multiply_dense(sp_mat, v):
  """An element-wise multiplication between a sparse and a dense array.

  Args:
    sp_mat: A BCOO-format array.
    v: An ndarray. It must be a scalar, have the same shape as ``sp_mat``, or
      broadcast to ``sp_mat.shape`` without changing it.

  Returns:
    The data buffer of the product: the stored values of ``sp_mat`` multiplied
    by the matching values of ``v``. The product's sparsity pattern is that of
    ``sp_mat``, so the sparse result is
    ``BCOO((result, sp_mat.indices), shape=sp_mat.shape)``.

  Raises:
    NotImplementedError: if broadcasting ``v`` against ``sp_mat`` would change
      the shape of the result.
  """
  return _bcoo_multiply_dense(*sp_mat._bufs, v, spinfo=sp_mat._info)
def _bcoo_multiply_dense(data, indices, v, *, spinfo):
  """Broadcasted elementwise multiplication between a BCOO array and a dense array.

  Returns the new data buffer only; the product keeps ``indices`` as its
  sparsity pattern.
  """
  # TODO(jakevdp): the logic here is similar to bcoo_extract... can we reuse that?
  shape = spinfo.shape
  if v.ndim == 0:
    # Scalar operand: scale every stored value.
    return lax.mul(data, v)
  if shape == v.shape:
    # Same shape: multiply by the entries of v extracted at the stored indices.
    # Note: due to distributive property, no deduplication necessary!
    return lax.mul(data, bcoo_extract(indices, v))

  if lax.broadcast_shapes(v.shape, shape) != shape:
    raise NotImplementedError(
      "multiplication between sparse and dense is only implemented for cases "
      "where the output shape matches the sparse matrix shape. Got "
      f"shape={shape}, v.shape={v.shape}")
  # Prepend size-1 axes so that v has the same rank as the sparse array.
  v = lax.expand_dims(v, range(len(shape) - v.ndim))

  props = _validate_bcoo(data, indices, shape)

  def _mul(data, indices, v):
    # Unbatched kernel: gather v at each stored element's sparse index, using
    # index 0 along size-1 (broadcast) axes of v, then multiply.
    assert indices.shape[1] == v.ndim - props.n_dense
    ind = tuple(indices[:, i] for i in range(indices.shape[1]))
    ind = tuple(i if s != 1 else 0 for i, s in zip(ind, v.shape))
    return data * v[ind]
  # Map the kernel over each batch dimension.
  for _ in range(props.n_batch):
    _mul = broadcasting_vmap(_mul)
  return _mul(data, indices, v)
@tree_util.register_pytree_node_class
class BCOO(JAXSparse):
  """Experimental batched COO matrix implemented in JAX

  Args:
    (data, indices) : data and indices in batched COO format.
    shape : shape of sparse array.

  Attributes:
    data : ndarray of shape ``[*batch_dims, nse, *dense_dims]`` containing the
      explicitly stored data within the sparse matrix.
    indices : ndarray of shape ``[*batch_dims, nse, n_sparse]`` containing the
      indices of the explicitly stored data. Duplicate entries will be summed.

  Examples:
    Create a sparse array from a dense array:

    >>> M = jnp.array([[0., 2., 0.], [1., 0., 4.]])
    >>> M_sp = BCOO.fromdense(M)
    >>> M_sp
    BCOO(float32[2, 3], nse=3)

    Examine the internal representation:

    >>> M_sp.data
    DeviceArray([2., 1., 4.], dtype=float32)
    >>> M_sp.indices
    DeviceArray([[0, 1],
                 [1, 0],
                 [1, 2]], dtype=int32)

    Create a dense array from a sparse array:

    >>> M_sp.todense()
    DeviceArray([[0., 2., 0.],
                 [1., 0., 4.]], dtype=float32)

    Create a sparse array from COO data & indices:

    >>> data = jnp.array([1., 3., 5.])
    >>> indices = jnp.array([[0, 0],
    ...                      [1, 1],
    ...                      [2, 2]])
    >>> mat = BCOO((data, indices), shape=(3, 3))
    >>> mat
    BCOO(float32[3, 3], nse=3)
    >>> mat.todense()
    DeviceArray([[1., 0., 0.],
                 [0., 3., 0.],
                 [0., 0., 5.]], dtype=float32)
  """
  # Note: additional BCOO methods are defined in transform.py

  data: jnp.ndarray
  indices: jnp.ndarray
  shape: Shape

  # Structural properties, read off the buffer shapes
  # (indices: [*batch, nse, n_sparse]; data: [*batch, nse, *dense]).
  nse = property(lambda self: self.indices.shape[-2])
  dtype = property(lambda self: self.data.dtype)
  n_batch = property(lambda self: self.indices.ndim - 2)
  n_sparse = property(lambda self: self.indices.shape[-1])
  n_dense = property(lambda self: self.data.ndim - 1 - self.n_batch)
  # Static metadata and raw buffers, in the form the buffer-level helpers in
  # this module take (e.g. ``_bcoo_multiply_dense(*mat._bufs, spinfo=mat._info)``).
  _info = property(lambda self: BCOOInfo(self.shape))
  _bufs = property(lambda self: (self.data, self.indices))

  def __init__(self, args, *, shape):
    # JAX transforms will sometimes instantiate pytrees with null values, so we
    # must catch that in the initialization of inputs.
    self.data, self.indices = _safe_asarray(args)
    super().__init__(args, shape=shape)

  @classmethod
  def fromdense(cls, mat, *, nse=None, index_dtype=np.int32, n_dense=0, n_batch=0):
    """Create a BCOO array from a (dense) :class:`DeviceArray`."""
    return bcoo_fromdense(
      mat, nse=nse, index_dtype=index_dtype, n_dense=n_dense, n_batch=n_batch)

  @classmethod
  def from_scipy_sparse(cls, mat, *, index_dtype=None, n_dense=0, n_batch=0):
    """Create a BCOO array from a :mod:`scipy.sparse` array.

    Only ``n_dense=0`` and ``n_batch=0`` are supported; other values raise
    NotImplementedError. If ``index_dtype`` is not given, int32 is used.
    """
    if n_dense != 0 or n_batch != 0:
      raise NotImplementedError("BCOO.from_scipy_sparse with nonzero n_dense/n_batch")
    mat = mat.tocoo()
    data = jnp.asarray(mat.data)
    indices = jnp.column_stack((mat.row, mat.col)).astype(index_dtype or jnp.int32)
    return cls((data, indices), shape=mat.shape)

  @classmethod
  def _empty(cls, shape, *, dtype=None, index_dtype='int32', n_dense=0, n_batch=0, nse=0):
    """Create an empty BCOO instance. Public method is sparse.empty()."""
    shape = tuple(shape)
    n_sparse = len(shape) - n_dense - n_batch
    if n_sparse < 0 or n_dense < 0 or n_batch < 0 or nse < 0:
      raise ValueError(f"Invalid inputs: shape={shape}, n_dense={n_dense}, n_batch={n_batch}, nse={nse}")
    batch_shape, sparse_shape, dense_shape = split_list(shape, [n_batch, n_sparse])
    data = jnp.zeros((*batch_shape, nse, *dense_shape), dtype)
    # Every slot is padding: indices equal to the sparse dimension sizes are
    # out of bounds.
    indices = jnp.full((*batch_shape, nse, n_sparse), jnp.array(sparse_shape), index_dtype)
    return cls((data, indices), shape=shape)

  def _unbatch(self):
    """Return an unbatched representation of the BCOO matrix."""
    return BCOO(_unbatch_bcoo(self.data, self.indices, self.shape), shape=self.shape)

  def _dedupe(self):
    """Deprecated alias for ``sum_duplicates(nse=self.nse)``."""
    warnings.warn("_dedupe() is deprecated. Use sum_duplicates() instead.", FutureWarning)
    return self.sum_duplicates(nse=self.nse)

  def sum_duplicates(self, nse=None, remove_zeros=True):
    """Return a copy of the array with duplicate indices summed.

    Additionally, this operation will result in explicit zero entries removed, and
    indices being sorted in lexicographic order.

    Because the size of the resulting representation depends on the values in the
    arrays, this operation is not compatible with JIT or other transforms. To use
    ``sum_duplicates`` in such cases, you may pass a value to `nse` to specify the
    desired size of the output representation.

    Args:
      nse : integer (optional), if specified, gives the number of specified elements in
        the output sparse representation; if it is larger than the number required, data
        will be padded with zeros and indices will be padded with out-of-bounds values.
        If it is smaller than the number required, data will be silently discarded.
      remove_zeros : bool (default=True). If True, remove explicit zeros from the data
        as part of summing duplicates. If False, then explicit zeros at unique indices
        will remain among the specified elements. Note: remove_zeros=True is incompatible
        with autodiff.
    """
    data, indices = _bcoo_sum_duplicates(self.data, self.indices, self.shape,
                                         nse=nse, remove_zeros=remove_zeros)
    return BCOO((data, indices), shape=self.shape)

  def sort_indices(self):
    """Return a copy of the matrix with indices sorted."""
    return bcoo_sort_indices(self)

  def todense(self):
    """Create a dense version of the array."""
    return bcoo_todense(self)

  def transpose(self, axes=None):
    """Create a new array containing the transpose.

    ``axes`` is a permutation of the dimensions; by default all dimensions
    are reversed.
    """
    axes = np.arange(self.ndim)[::-1] if axes is None else axes
    mat_T = bcoo_transpose(self, permutation=axes)
    shape_T = tuple(self.shape[i] for i in axes)
    return BCOO((mat_T.data, mat_T.indices), shape=shape_T)

  def tree_flatten(self):
    """Pytree flattening: the buffers are children; the ``_info`` fields (as a dict) are aux data."""
    return (self.data, self.indices), self._info._asdict()
# vmappable handlers
def _bcoo_to_elt(cont, _, val, axis):
  """Map ``cont`` over axis ``axis`` of the BCOO array ``val`` (the vmap ``to_elt`` handler).

  If ``axis`` is None, ``val`` is returned untouched. Otherwise ``axis`` must be
  one of ``val``'s batch dimensions; ``cont`` is applied to both buffers along
  that axis, and the axis is dropped from the shape.
  """
  if axis is None:
    return val
  if axis < val.n_batch:
    mapped_bufs = (cont(val.data, axis), cont(val.indices, axis))
    return BCOO(mapped_bufs, shape=val.shape[:axis] + val.shape[axis + 1:])
  raise ValueError(f"Cannot map in_axis={axis} for BCOO array with n_batch={val.n_batch}. "
                   "in_axes for batched BCOO operations must correspond to a batch dimension.")
def _bcoo_from_elt(cont, axis_size, elt, axis):
  """Rebuild a BCOO array from per-element results (the vmap ``from_elt`` handler).

  ``cont`` inserts a new axis of length ``axis_size`` at ``axis`` in both
  buffers, and the same axis is inserted into the shape. ``axis`` may be at
  most ``elt.n_batch``, so that the batch axes remain a leading block.
  """
  if axis <= elt.n_batch:
    stacked_data = cont(axis_size, elt.data, axis)
    stacked_indices = cont(axis_size, elt.indices, axis)
    return BCOO((stacked_data, stacked_indices),
                shape=elt.shape[:axis] + (axis_size,) + elt.shape[axis:])
  raise ValueError(f"BCOO: cannot add out_axis={axis} for BCOO array with n_batch={elt.n_batch}. "
                   "BCOO batch axes must be a contiguous block of leading dimensions.")
# Register BCOO with JAX's batching machinery so that vmap can map BCOO arrays
# over their batch dimensions, using _bcoo_to_elt / _bcoo_from_elt as handlers.
# NOTE(review): the `int, int` and trailing `None` arguments are presumably the
# axis-spec type, axis-size type and an optional make_iota handler -- confirm
# against batching.register_vmappable's signature.
batching.register_vmappable(BCOO, int, int, _bcoo_to_elt, _bcoo_from_elt, None)