Merge pull request #16740 from jakevdp:spdot-general-args

PiperOrigin-RevId: 548773744
jax authors 2023-07-17 12:52:33 -07:00
commit 68ea651ae4
4 changed files with 141 additions and 74 deletions
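
Summary: the sparse dot_general wrappers previously dropped preferred_element_type ("del precision, preferred_element_type  # unused"); this change threads it through the BCOO/BCSR primitives, their JVP/transpose/batching rules, and the GPU lowerings. A minimal sketch of the user-visible effect (illustrative code, not part of the diff; uses the public jax.experimental.sparse API):

import jax.numpy as jnp
from jax.experimental import sparse

lhs = sparse.BCOO.fromdense(jnp.eye(4, dtype='int8'))
rhs = jnp.ones((4, 2), dtype='int8')
# Before this change, preferred_element_type was silently ignored on the
# sparse path; it is now forwarded to the underlying primitives.
out = sparse.bcoo_dot_general(
    lhs, rhs,
    dimension_numbers=(((1,), (0,)), ((), ())),
    preferred_element_type=jnp.int32)
print(out.dtype)  # int32, matching lax.dot_general on dense operands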


@@ -625,52 +625,65 @@ def bcoo_dot_general(lhs: Union[BCOO, Array], rhs: Union[BCOO, Array], *, dimens
the result will be dense, of type ndarray.
"""
# TODO(jakevdp) make use of these?
del precision, preferred_element_type # unused
del precision # unused
if isinstance(lhs, BCOO) and isinstance(rhs, BCOO):
shape = _dot_general_validated_shape(lhs.shape, rhs.shape,
dimension_numbers)
bufs = _bcoo_spdot_general(lhs.data, lhs.indices, rhs.data, rhs.indices,
lhs_spinfo=lhs._info, rhs_spinfo=rhs._info,
dimension_numbers=dimension_numbers)
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type)
return BCOO(bufs, shape=shape)
elif isinstance(lhs, BCOO):
return _bcoo_dot_general(lhs.data, lhs.indices, rhs, dimension_numbers=dimension_numbers, # type: ignore[arg-type]
preferred_element_type=preferred_element_type,
lhs_spinfo=lhs._info)
elif isinstance(rhs, BCOO):
return _bcoo_rdot_general(lhs, rhs.data, rhs.indices, dimension_numbers=dimension_numbers, # type: ignore[arg-type]
preferred_element_type=preferred_element_type,
rhs_spinfo=rhs._info)
else:
return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type)
def _bcoo_dot_general(lhs_data: Array, lhs_indices: Array, rhs: Array, *,
dimension_numbers: DotDimensionNumbers, lhs_spinfo: SparseInfo) -> Array:
dimension_numbers: DotDimensionNumbers,
preferred_element_type: Any,
lhs_spinfo: SparseInfo) -> Array:
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
cdims = (api_util._ensure_index_tuple(lhs_contract),
api_util._ensure_index_tuple(rhs_contract))
bdims = (api_util._ensure_index_tuple(lhs_batch),
api_util._ensure_index_tuple(rhs_batch))
if preferred_element_type is not None:
preferred_element_type = np.dtype(preferred_element_type)
return bcoo_dot_general_p.bind(jnp.asarray(lhs_data), jnp.asarray(lhs_indices), jnp.asarray(rhs),
dimension_numbers=(cdims, bdims),
preferred_element_type=preferred_element_type,
lhs_spinfo=lhs_spinfo)
def _bcoo_rdot_general(lhs: Array, rhs_data: Array, rhs_indices: Array, *,
dimension_numbers: DotDimensionNumbers, rhs_spinfo: SparseInfo) -> Array:
dimension_numbers: DotDimensionNumbers,
preferred_element_type: Any, rhs_spinfo: SparseInfo) -> Array:
# TODO(jakevdp): perhaps this should be part of the bcoo_dot_general primitive?
dimension_numbers_reversed: DotDimensionNumbers = tuple(d[::-1] for d in dimension_numbers) # type: ignore[assignment]
result = _bcoo_dot_general(rhs_data, rhs_indices, lhs, lhs_spinfo=rhs_spinfo,
dimension_numbers=dimension_numbers_reversed)
dimension_numbers=dimension_numbers_reversed,
preferred_element_type=preferred_element_type)
n_contract, n_batch = (len(d[0]) for d in dimension_numbers)
n_swap = len(rhs_spinfo.shape) - n_contract
permutation = tuple([*range(n_batch), *range(n_swap, result.ndim), *range(n_batch, n_swap)])
return lax.transpose(result, permutation)
def _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: SparseInfo):
def _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs, *, dimension_numbers,
preferred_element_type, lhs_spinfo: SparseInfo):
lhs_data = jnp.asarray(lhs_data)
lhs_indices = jnp.asarray(lhs_indices)
rhs = jnp.asarray(rhs)
# Validate all inputs via abstract_eval
out_aval = _bcoo_dot_general_abstract_eval(lhs_data.aval, lhs_indices.aval, rhs.aval,
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type,
lhs_spinfo=lhs_spinfo)
n_sparse = lhs_indices.shape[-1]
n_batch = lhs_indices.ndim - 2
@@ -720,16 +733,17 @@ def _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs
return result(out_array, lhs_data, lhs_indices, rhs)
@bcoo_dot_general_p.def_abstract_eval
def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: SparseInfo):
if lhs_data.dtype != rhs.dtype:
raise ValueError("bcoo_dot_general requires arguments to have matching dtypes; "
f"got lhs.dtype={lhs_data.dtype}, rhs.dtype={rhs.dtype}")
def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_numbers,
preferred_element_type, lhs_spinfo: SparseInfo):
out_aval = jax.eval_shape(
partial(lax.dot_general,
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type),
jax.ShapeDtypeStruct(lhs_spinfo.shape, lhs_data.dtype),
jax.ShapeDtypeStruct(rhs.shape, rhs.dtype))
(lhs_contracting, _), (lhs_batch, _) = dimension_numbers
n_batch, n_sparse, _, _ = _validate_bcoo(lhs_data, lhs_indices, lhs_spinfo.shape)
out_shape = _dot_general_validated_shape(lhs_spinfo.shape, rhs.shape,
dimension_numbers)
if lhs_batch and max(lhs_batch) >= n_batch:
raise NotImplementedError(
"bcoo_dot_general batch dimensions must be among the batch dimensions in the sparse representtaion.\n"
@@ -739,7 +753,7 @@ def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_num
if any(d >= n_batch + n_sparse for d in lhs_contracting):
raise NotImplementedError("bcoo_dot_general: contracting over dense dimensions.")
return core.ShapedArray(out_shape, lhs_data.dtype)
return core.ShapedArray(out_aval.shape, out_aval.dtype)
_bcoo_dot_general_default_lowering = mlir.lower_fun(
_bcoo_dot_general_impl, multiple_results=False)
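
The abstract_eval rewrite above replaces the hand-written dtype check with inference delegated to lax.dot_general via jax.eval_shape, so the sparse primitive inherits exactly the dense promotion rules. A standalone sketch of that idiom (the helper name is illustrative, not from this commit):

from functools import partial
import jax
import jax.numpy as jnp
from jax import lax

def infer_dot_general_aval(lhs, rhs, *, dimension_numbers,
                           preferred_element_type=None):
  # Abstractly evaluate lax.dot_general on shape/dtype pairs, reusing its
  # promotion logic instead of re-implementing a dtype rule.
  return jax.eval_shape(
      partial(lax.dot_general,
              dimension_numbers=dimension_numbers,
              preferred_element_type=preferred_element_type),
      jax.ShapeDtypeStruct(lhs.shape, lhs.dtype),
      jax.ShapeDtypeStruct(rhs.shape, rhs.dtype))

out = infer_dot_general_aval(
    jnp.zeros((3, 4), 'int8'), jnp.zeros((4, 5), 'int8'),
    dimension_numbers=(((1,), (0,)), ((), ())),
    preferred_element_type=jnp.int32)
print(out.shape, out.dtype)  # (3, 5) int32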
@@ -766,16 +780,34 @@ def _bcoo_dot_general_fallback(data, indices, spinfo):
return False
def _bcoo_dot_general_gpu_impl(lhs_data, lhs_indices, rhs, *,
dimension_numbers, lhs_spinfo):
dimension_numbers, preferred_element_type,
lhs_spinfo):
if not config.jax_bcoo_cusparse_lowering:
return _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type,
lhs_spinfo=lhs_spinfo)
(lhs_contract, rhs_contract), (lhs_batch, _) = dimension_numbers
n_batch, n_sparse, n_dense, _ = _validate_bcoo(
lhs_data, lhs_indices, lhs_spinfo.shape)
coo_matmul_p = coo_spmv_p if rhs.ndim == 1 else coo_spmm_p
out_aval = _bcoo_dot_general_abstract_eval(
lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type,
lhs_spinfo=lhs_spinfo)
if out_aval.dtype not in CUSPARSE_DATA_DTYPES:
return _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type,
lhs_spinfo=lhs_spinfo)
lhs_data = lhs_data.astype(out_aval.dtype)
rhs = rhs.astype(out_aval.dtype)
# TODO(jakevdp, tianjianlu): add support for batched lowerings
if (len(lhs_contract) == 1 and len(lhs_batch) == 0 and rhs.ndim in (1, 2)
and (n_batch, n_sparse, n_dense) == (0, 1, 0)
@@ -801,18 +833,24 @@ def _bcoo_dot_general_gpu_impl(lhs_data, lhs_indices, rhs, *,
return out[:-1]
else:
return _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo,
preferred_element_type=preferred_element_type)
_bcoo_dot_general_gpu_lowering = mlir.lower_fun(
_bcoo_dot_general_gpu_impl, multiple_results=False)
def _bcoo_dot_general_jvp_lhs(lhs_data_dot, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: SparseInfo):
return _bcoo_dot_general(lhs_data_dot, lhs_indices, rhs, dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
def _bcoo_dot_general_jvp_lhs(lhs_data_dot, lhs_data, lhs_indices, rhs, *, dimension_numbers,
preferred_element_type, lhs_spinfo: SparseInfo):
return _bcoo_dot_general(lhs_data_dot, lhs_indices, rhs, dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type, lhs_spinfo=lhs_spinfo)
def _bcoo_dot_general_jvp_rhs(rhs_dot, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: SparseInfo):
return _bcoo_dot_general(lhs_data, lhs_indices, rhs_dot, dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
def _bcoo_dot_general_jvp_rhs(rhs_dot, lhs_data, lhs_indices, rhs, *, dimension_numbers,
preferred_element_type, lhs_spinfo: SparseInfo):
return _bcoo_dot_general(lhs_data, lhs_indices, rhs_dot, dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type, lhs_spinfo=lhs_spinfo)
def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: SparseInfo):
def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_numbers,
preferred_element_type, lhs_spinfo: SparseInfo):
assert not ad.is_undefined_primal(lhs_indices)
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
lhs_ndim = len(lhs_spinfo.shape)
@@ -857,10 +895,13 @@ def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_num
dims = ((lhs_kept, ans_lhs), (lhs_batch, ans_batch)) # type: ignore[assignment]
rhs_contract_sorted_by_lhs = list(np.take(rhs_contract, np.argsort(lhs_contract)))
out_axes = list(np.argsort(list(rhs_batch) + rhs_contract_sorted_by_lhs + rhs_kept))
result = _bcoo_dot_general(lhs_data, lhs_indices, ct, lhs_spinfo=lhs_spinfo, dimension_numbers=dims)
result = _bcoo_dot_general(lhs_data, lhs_indices, ct, lhs_spinfo=lhs_spinfo,
preferred_element_type=preferred_element_type,
dimension_numbers=dims)
return lhs_data, lhs_indices, lax.transpose(result, out_axes)
def _bcoo_dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, lhs_spinfo: SparseInfo):
def _bcoo_dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
preferred_element_type, lhs_spinfo: SparseInfo):
_, _, rhs = batched_args
_, _, rhs_bdim = batch_dims
new_lhs_data, new_lhs_indices, new_lhs_spinfo = _bcoo_batch_dims_to_front(
@@ -869,6 +910,7 @@ def _bcoo_dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
(len(lhs_spinfo.shape), rhs.ndim), (0, rhs_bdim), dimension_numbers)
batched_out = _bcoo_dot_general(new_lhs_data, new_lhs_indices, rhs, lhs_spinfo=new_lhs_spinfo,
preferred_element_type=preferred_element_type,
dimension_numbers=new_dimension_numbers)
return batched_out, result_batch_dim
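
Because preferred_element_type is now a parameter of the bound primitive, each of the rules above (JVP, transpose, batching) must forward it explicitly, or grad/vmap of a sparse dot_general would fail with the parameter set. A quick sanity-check sketch (not from the test suite):

import jax
import jax.numpy as jnp
from jax.experimental import sparse

mat = sparse.BCOO.fromdense(jnp.arange(16.0, dtype='float32').reshape(4, 4))

def f(vec):
  return sparse.bcoo_dot_general(
      mat, vec,
      dimension_numbers=(((1,), (0,)), ((), ())),
      preferred_element_type=jnp.float32).sum()

print(jax.grad(f)(jnp.ones(4, dtype='float32')))       # exercises the transpose rule
print(jax.vmap(f)(jnp.ones((3, 4), dtype='float32')))  # exercises the batch rule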
@@ -1041,7 +1083,8 @@ bcoo_spdot_general_p = core.Primitive('bcoo_spdot_general')
bcoo_spdot_general_p.multiple_results = True
def _bcoo_spdot_general(lhs_data: Array, lhs_indices: Array, rhs_data: Array, rhs_indices: Array, *,
lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo, dimension_numbers: DotDimensionNumbers) -> tuple[Array, Array]:
lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo, dimension_numbers: DotDimensionNumbers,
preferred_element_type: Any) -> tuple[Array, Array]:
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
cdims = (api_util._ensure_index_tuple(lhs_contract),
api_util._ensure_index_tuple(rhs_contract))
@@ -1049,7 +1092,8 @@ def _bcoo_spdot_general(lhs_data: Array, lhs_indices: Array, rhs_data: Array, rh
api_util._ensure_index_tuple(rhs_batch))
return bcoo_spdot_general_p.bind(lhs_data, lhs_indices, rhs_data, rhs_indices,
lhs_spinfo=lhs_spinfo, rhs_spinfo=rhs_spinfo,
dimension_numbers=(cdims, bdims))
dimension_numbers=(cdims, bdims),
preferred_element_type=preferred_element_type)
def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo, rhs_spinfo, lhs_contracting, rhs_contracting, out_nse):
lhs_shape = lhs_spinfo.shape
@@ -1098,7 +1142,8 @@ def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices,
return _bcoo_sum_duplicates(out_data, out_indices, spinfo=SparseInfo(shape=out_shape), nse=out_nse)
@bcoo_spdot_general_p.def_impl
def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo, dimension_numbers):
def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo,
dimension_numbers, preferred_element_type):
lhs_shape = lhs_spinfo.shape
rhs_shape = rhs_spinfo.shape
@@ -1107,7 +1152,8 @@ def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lh
assert lhs.n_dense == rhs.n_dense == 0
data_aval, _ = _bcoo_spdot_general_abstract_eval(
lhs_data.aval, lhs_indices.aval, rhs_data.aval, rhs_indices.aval,
lhs_spinfo=lhs_spinfo, rhs_spinfo=rhs_spinfo, dimension_numbers=dimension_numbers)
lhs_spinfo=lhs_spinfo, rhs_spinfo=rhs_spinfo, dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type)
out_nse = data_aval.shape[-1]
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
@@ -1134,19 +1180,20 @@ def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lh
return func(lhs_data, lhs_indices, rhs_data, rhs_indices)
@bcoo_spdot_general_p.def_abstract_eval
def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo, dimension_numbers):
def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo,
dimension_numbers, preferred_element_type):
lhs_shape = lhs_spinfo.shape
rhs_shape = rhs_spinfo.shape
out_shape = _dot_general_validated_shape(lhs_shape, rhs_shape,
dimension_numbers)
out_aval = jax.eval_shape(
partial(lax.dot_general,
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type),
jax.ShapeDtypeStruct(lhs_shape, lhs_data.dtype),
jax.ShapeDtypeStruct(rhs_shape, rhs_data.dtype))
if lhs_data.dtype != rhs_data.dtype:
raise ValueError("bcoo_spdot_general requires inputs to have matching dtypes; "
f"got lhs.dtype={lhs_data.dtype}, rhs.dtype={rhs_data.dtype}")
lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
_ = _dot_general_validated_shape(lhs_shape, rhs_shape, dimension_numbers)
if lhs.n_dense or rhs.n_dense:
# TODO(jakevdp): handle dense dimensions
@@ -1172,7 +1219,7 @@ def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indic
# Ensure we're not storing more output elements than necessary.
# TODO(jakevdp): should we warn here if output is effectively dense?
out_n_batch = lhs.n_batch + rhs.n_batch - len(lhs_batch)
out_nse = min(out_nse, math.prod(out_shape[out_n_batch:]))
out_nse = min(out_nse, math.prod(out_aval.shape[out_n_batch:]))
data_shape = (
*(lhs_shape[dim] for dim in lhs_batch),
@@ -1185,13 +1232,14 @@ def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indic
*(rhs_indices.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch),
out_nse, lhs.n_sparse + rhs.n_sparse - 2 * len(lhs_contracting))
data_aval = core.ShapedArray(data_shape, lhs_data.dtype)
data_aval = core.ShapedArray(data_shape, out_aval.dtype)
indices_aval = core.ShapedArray(indices_shape, lhs_indices.dtype)
_validate_bcoo(data_aval, indices_aval, out_shape) # pytype: disable=wrong-arg-types # always-use-return-annotations
_validate_bcoo(data_aval, indices_aval, out_aval.shape) # pytype: disable=wrong-arg-types # always-use-return-annotations
return data_aval, indices_aval
def _bcoo_spdot_general_batch_rule(batched_args, batch_dims, *, lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo, dimension_numbers):
def _bcoo_spdot_general_batch_rule(batched_args, batch_dims, *, lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo,
preferred_element_type, dimension_numbers):
lhs_ndim = len(lhs_spinfo.shape)
rhs_ndim = len(rhs_spinfo.shape)
batch_size = max(arg.shape[dim] for arg, dim in zip(batched_args, batch_dims) if dim is not None)
@@ -1203,7 +1251,8 @@ def _bcoo_spdot_general_batch_rule(batched_args, batch_dims, *, lhs_spinfo: Spar
(lhs_ndim, rhs_ndim), (0, 0), dimension_numbers)
batched_out = _bcoo_spdot_general(lhs_data, lhs_indices, rhs_data, rhs_indices,
dimension_numbers=dimension_numbers,
lhs_spinfo=lhs_spinfo, rhs_spinfo=rhs_spinfo)
lhs_spinfo=lhs_spinfo, rhs_spinfo=rhs_spinfo,
preferred_element_type=preferred_element_type)
return batched_out, (result_batch_dim, result_batch_dim)
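
On the sparse-sparse path, the old ValueError requiring matching input dtypes is gone, and the output data buffer now takes the dtype computed by the new abstract_eval rather than unconditionally inheriting lhs_data.dtype. A sketch of the effect (illustrative, via the public API):

import jax.numpy as jnp
from jax.experimental import sparse

a = sparse.BCOO.fromdense(jnp.eye(3, dtype='int8'))
b = sparse.BCOO.fromdense(jnp.eye(3, dtype='int8'))
# Both operands sparse: dispatches to bcoo_spdot_general and stays sparse.
out = sparse.bcoo_dot_general(
    a, b,
    dimension_numbers=(((1,), (0,)), ((), ())),
    preferred_element_type=jnp.int32)
print(type(out).__name__, out.dtype)  # BCOO int32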


@@ -19,7 +19,7 @@ from functools import partial
import operator
import warnings
from typing import NamedTuple, Optional, Sequence, Union
from typing import Any, NamedTuple, Optional, Sequence, Union
import numpy as np
@@ -479,15 +479,17 @@ def bcsr_dot_general(lhs: Union[BCSR, Array], rhs: Array, *,
are sparse, the result will be sparse, of type BCSR. If either input is
dense, the result will be dense, of type ndarray.
"""
del precision, preferred_element_type # unused
del precision # unused
if isinstance(rhs, (np.ndarray, jax.Array)):
if isinstance(lhs, (np.ndarray, jax.Array)):
return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type)
if isinstance(lhs, BCSR):
lhs_data, lhs_indices, lhs_indptr = lhs._bufs
return _bcsr_dot_general(lhs_data, lhs_indices, lhs_indptr, rhs,
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type,
lhs_spinfo=lhs._info)
raise NotImplementedError("bcsr_dot_general currently implemented for BCSR "
@@ -497,6 +499,7 @@ def bcsr_dot_general(lhs: Union[BCSR, Array], rhs: Array, *,
def _bcsr_dot_general(lhs_data: jax.Array, lhs_indices: jax.Array,
lhs_indptr: jax.Array, rhs: Array, *,
dimension_numbers: DotDimensionNumbers,
preferred_element_type: Any,
lhs_spinfo: SparseInfo) -> Array:
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
cdims = (api_util._ensure_index_tuple(lhs_contract),
@@ -507,11 +510,12 @@ def _bcsr_dot_general(lhs_data: jax.Array, lhs_indices: jax.Array,
jnp.asarray(lhs_indices),
jnp.asarray(lhs_indptr), jnp.asarray(rhs),
dimension_numbers=(cdims, bdims),
preferred_element_type=preferred_element_type,
lhs_spinfo=lhs_spinfo)
def _bcsr_dot_general_impl(lhs_data, lhs_indices, lhs_indptr, rhs, *,
dimension_numbers, lhs_spinfo):
dimension_numbers, preferred_element_type, lhs_spinfo):
lhs_data = jnp.asarray(lhs_data)
lhs_bcsr_indices = jnp.asarray(lhs_indices)
lhs_bcsr_indptr = jnp.asarray(lhs_indptr)
@@ -520,21 +524,21 @@ def _bcsr_dot_general_impl(lhs_data, lhs_indices, lhs_indptr, rhs, *,
shape=lhs_spinfo.shape)
return bcoo._bcoo_dot_general_impl(lhs_data, lhs_bcoo_indices, rhs,
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type,
lhs_spinfo=lhs_spinfo)
@bcsr_dot_general_p.def_abstract_eval
def _bcsr_dot_general_abstract_eval(lhs_data, lhs_indices, lhs_indptr, rhs, *,
dimension_numbers, lhs_spinfo):
if lhs_data.dtype != rhs.dtype:
raise ValueError("bcsr_dot_general requires arguments to have matching "
f"dtypes; got lhs.dtype={lhs_data.dtype}, "
f"rhs.dtype={rhs.dtype}")
dimension_numbers, preferred_element_type, lhs_spinfo):
(lhs_contracting, _), (lhs_batch, _) = dimension_numbers
props = _validate_bcsr_indices(lhs_indices, lhs_indptr, lhs_spinfo.shape)
out_shape = _dot_general_validated_shape(lhs_spinfo.shape, rhs.shape,
dimension_numbers)
out_aval = jax.eval_shape(
partial(lax.dot_general,
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type),
jax.ShapeDtypeStruct(lhs_spinfo.shape, lhs_data.dtype),
jax.ShapeDtypeStruct(rhs.shape, rhs.dtype))
if lhs_batch and max(lhs_batch) >= props.n_batch:
raise NotImplementedError(
@@ -545,38 +549,41 @@ def _bcsr_dot_general_abstract_eval(lhs_data, lhs_indices, lhs_indptr, rhs, *,
if any(d >= props.n_batch + 2 for d in lhs_contracting):
raise NotImplementedError("bcsr_dot_general: contracting over dense dimensions.")
return core.ShapedArray(out_shape, lhs_data.dtype)
return core.ShapedArray(out_aval.shape, out_aval.dtype)
def _bcsr_dot_general_jvp_lhs(lhs_data_dot, lhs_data, lhs_indices, lhs_indptr, rhs, *,
dimension_numbers, lhs_spinfo):
dimension_numbers, preferred_element_type, lhs_spinfo):
del lhs_data
return _bcsr_dot_general(lhs_data_dot, lhs_indices, lhs_indptr, rhs,
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type,
lhs_spinfo=lhs_spinfo)
def _bcsr_dot_general_jvp_rhs(rhs_dot, lhs_data, lhs_indices, lhs_indptr, rhs, *,
dimension_numbers, lhs_spinfo):
dimension_numbers, preferred_element_type, lhs_spinfo):
del rhs
return _bcsr_dot_general(lhs_data, lhs_indices, lhs_indptr, rhs_dot,
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type,
lhs_spinfo=lhs_spinfo)
def _bcsr_dot_general_transpose(ct, lhs_data, lhs_indices, lhs_indptr, rhs, *,
dimension_numbers, lhs_spinfo):
dimension_numbers, preferred_element_type, lhs_spinfo):
# TODO(jakevdp): implement this in terms of bcsr_dot_general
lhs_bcoo_indices = _bcsr_to_bcoo(
lhs_indices, lhs_indptr, shape=lhs_spinfo.shape)
data_out, _, rhs_out = bcoo._bcoo_dot_general_transpose(
ct, lhs_data, lhs_bcoo_indices, rhs, dimension_numbers=dimension_numbers,
lhs_spinfo=lhs_spinfo)
preferred_element_type=preferred_element_type, lhs_spinfo=lhs_spinfo)
return data_out, lhs_indices, lhs_indptr, rhs_out
def _bcsr_dot_general_batch_rule(batched_args, batch_dims, *,
dimension_numbers, lhs_spinfo):
dimension_numbers, preferred_element_type,
lhs_spinfo):
*lhs_args, rhs = batched_args
*lhs_dims, rhs_bdim = batch_dims
*new_lhs_args, new_lhs_spinfo = _bcsr_batch_dims_to_front(
@@ -585,7 +592,8 @@ def _bcsr_dot_general_batch_rule(batched_args, batch_dims, *,
new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
(len(lhs_spinfo.shape), rhs.ndim), (0, rhs_bdim), dimension_numbers)
batched_out = _bcsr_dot_general(*new_lhs_args, rhs, lhs_spinfo=new_lhs_spinfo,
dimension_numbers=new_dimension_numbers)
dimension_numbers=new_dimension_numbers,
preferred_element_type=preferred_element_type)
return batched_out, result_batch_dim
@@ -613,12 +621,14 @@ _bcsr_correct_out_of_bound_indices_lowered = mlir.lower_fun(
def _bcsr_dot_general_gpu_lowering(
csr_matvec_lowering, csr_matmat_lowering,
ctx, lhs_data, lhs_indices, lhs_indptr, rhs, *, dimension_numbers,
lhs_spinfo: SparseInfo):
preferred_element_type, lhs_spinfo: SparseInfo):
if not config.jax_bcoo_cusparse_lowering:
return _bcsr_dot_general_default_lowering(
ctx, lhs_data, lhs_indices, lhs_indptr, rhs,
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type,
lhs_spinfo=lhs_spinfo)
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
lhs_data_aval, lhs_indices_aval, lhs_indptr_aval, rhs_aval = ctx.avals_in
@@ -631,6 +641,10 @@ def _bcsr_dot_general_gpu_lowering(
if lhs_batch or rhs_batch:
# batch dimensions in dot_general are not supported
use_default_lowering = True
elif (lhs_data_aval.dtype != rhs_aval.dtype):
use_default_lowering = True
elif preferred_element_type is not None and preferred_element_type != lhs_data_aval.dtype:
use_default_lowering = True
elif len(lhs_spinfo.shape) != 2 or rhs_aval.ndim not in [1, 2]:
# only matmat / matvec supported
use_default_lowering = True
@@ -650,7 +664,9 @@ def _bcsr_dot_general_gpu_lowering(
if use_default_lowering:
return _bcsr_dot_general_default_lowering(
ctx, lhs_data, lhs_indices, lhs_indptr, rhs,
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type,
lhs_spinfo=lhs_spinfo)
# Account for a bug in cusparse: it references indices and data beyond
# the extent of indptr.
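
As the new checks above show, the cusparse route is only taken when the input dtypes match and preferred_element_type agrees with them; any other request falls back to the default lowering, so results stay consistent across backends. A sketch of a case that exercises the fallback on GPU (and runs everywhere; illustrative, not from the commit):

import jax.numpy as jnp
from jax.experimental import sparse

m = sparse.BCSR.fromdense(jnp.eye(4, dtype='float16'))
v = jnp.ones(4, dtype='float16')
# Requesting a wider accumulator dtype than the inputs: per the checks
# above, this takes the default lowering rather than cusparse.
out = sparse.bcsr_dot_general(
    m, v,
    dimension_numbers=(((1,), (0,)), ((), ())),
    preferred_element_type=jnp.float32)
print(out.dtype)  # float32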


@@ -1289,7 +1289,7 @@ class BCOOTest(sptu.SparseTestCase):
def f_sparse(data, indices, Y):
return sparse_bcoo._bcoo_dot_general(data, indices, Y, lhs_spinfo=sparse_util.SparseInfo(X.shape),
dimension_numbers=dimension_numbers)
dimension_numbers=dimension_numbers, preferred_element_type=None)
for data, indices in itertools.product([data, data[:1]], [indices, indices[:1]]):
X = sparse_bcoo._bcoo_todense(data, indices, spinfo=sparse_util.SparseInfo(X.shape))


@@ -162,27 +162,29 @@ class SparsifyTest(jtu.JaxTestCase):
self.assertAllClose(result_sparse.todense(), result_dense)
@jax.numpy_dtype_promotion('standard')
def testSparseMatmul(self):
X = jnp.arange(16.0).reshape(4, 4)
X = jnp.arange(16.0, dtype='float32').reshape(4, 4)
Xsp = BCOO.fromdense(X)
Y = jnp.ones(4)
Y = jnp.ones(4, dtype='int32')
Ysp = BCOO.fromdense(Y)
func = self.sparsify(operator.matmul)
# Note: deliberately testing with mixed precision
assert Xsp.dtype != Ysp.dtype
# dot_general
result_sparse = func(Xsp, Y)
result_dense = operator.matmul(X, Y)
result_sparse = self.sparsify(lax.dot)(Xsp, Y)
result_dense = lax.dot(X, Y)
self.assertAllClose(result_sparse, result_dense)
# rdot_general
result_sparse = func(Y, Xsp)
result_dense = operator.matmul(Y, X)
result_sparse = self.sparsify(lax.dot)(Y, Xsp)
result_dense = lax.dot(Y, X)
self.assertAllClose(result_sparse, result_dense)
# spdot_general
result_sparse = self.sparsify(operator.matmul)(Xsp, Ysp)
result_dense = operator.matmul(X, Y)
# spdot_general
result_sparse = self.sparsify(lax.dot)(Xsp, Ysp)
result_dense = lax.dot(X, Y)
self.assertAllClose(result_sparse.todense(), result_dense)
def testSparseAdd(self):
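
The rewritten testSparseMatmul above deliberately mixes input dtypes, which previously raised the "matching dtypes" ValueError; the sparse path now promotes exactly like dense lax.dot_general, which is why the test opts into 'standard' numpy-style promotion. A condensed sketch of what it verifies (not a replacement for the test itself):

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import sparse

with jax.numpy_dtype_promotion('standard'):
  X = jnp.arange(16.0, dtype='float32').reshape(4, 4)
  Y = jnp.ones(4, dtype='int32')
  Xsp = sparse.BCOO.fromdense(X)
  dense = lax.dot(X, Y)                      # mixed dtypes promote on the dense path
  result = sparse.sparsify(lax.dot)(Xsp, Y)  # the sparse path now matches it
  assert result.dtype == dense.dtype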