Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)

commit 68ea651ae4

Merge pull request #16740 from jakevdp:spdot-general-args

PiperOrigin-RevId: 548773744
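The diff below threads `preferred_element_type` through every sparse `dot_general` code path: the BCOO and BCSR primitive signatures, their impl and abstract_eval rules, the GPU lowerings, and the JVP/transpose/batching rules. Previously the argument was accepted but deleted as unused; after this change sparse dot products honor the requested output dtype the same way dense `lax.dot_general` does. A minimal sketch of the behavior this enables, using the public API (the dtypes here are illustrative, not taken from the commit):

import jax.numpy as jnp
from jax.experimental import sparse

x = jnp.arange(4, dtype='bfloat16').reshape(2, 2)
y = jnp.ones((2, 2), dtype='bfloat16')
xsp = sparse.BCOO.fromdense(x)
dn = (((1,), (0,)), ((), ()))  # ordinary 2D matmul contraction

out16 = sparse.bcoo_dot_general(xsp, y, dimension_numbers=dn)
out32 = sparse.bcoo_dot_general(xsp, y, dimension_numbers=dn,
                                preferred_element_type=jnp.float32)
print(out16.dtype, out32.dtype)  # bfloat16 float32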
@@ -625,52 +625,65 @@ def bcoo_dot_general(lhs: Union[BCOO, Array], rhs: Union[BCOO, Array], *, dimens
     the result will be dense, of type ndarray.
   """
   # TODO(jakevdp) make use of these?
-  del precision, preferred_element_type  # unused
+  del precision  # unused
   if isinstance(lhs, BCOO) and isinstance(rhs, BCOO):
     shape = _dot_general_validated_shape(lhs.shape, rhs.shape,
                                          dimension_numbers)
     bufs = _bcoo_spdot_general(lhs.data, lhs.indices, rhs.data, rhs.indices,
                                lhs_spinfo=lhs._info, rhs_spinfo=rhs._info,
-                               dimension_numbers=dimension_numbers)
+                               dimension_numbers=dimension_numbers,
+                               preferred_element_type=preferred_element_type)
     return BCOO(bufs, shape=shape)
   elif isinstance(lhs, BCOO):
     return _bcoo_dot_general(lhs.data, lhs.indices, rhs, dimension_numbers=dimension_numbers,  # type: ignore[arg-type]
+                             preferred_element_type=preferred_element_type,
                              lhs_spinfo=lhs._info)
   elif isinstance(rhs, BCOO):
     return _bcoo_rdot_general(lhs, rhs.data, rhs.indices, dimension_numbers=dimension_numbers,  # type: ignore[arg-type]
+                              preferred_element_type=preferred_element_type,
                               rhs_spinfo=rhs._info)
   else:
-    return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
+    return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers,
+                           preferred_element_type=preferred_element_type)

 def _bcoo_dot_general(lhs_data: Array, lhs_indices: Array, rhs: Array, *,
-                      dimension_numbers: DotDimensionNumbers, lhs_spinfo: SparseInfo) -> Array:
+                      dimension_numbers: DotDimensionNumbers,
+                      preferred_element_type: Any,
+                      lhs_spinfo: SparseInfo) -> Array:
   (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
   cdims = (api_util._ensure_index_tuple(lhs_contract),
            api_util._ensure_index_tuple(rhs_contract))
   bdims = (api_util._ensure_index_tuple(lhs_batch),
            api_util._ensure_index_tuple(rhs_batch))
+  if preferred_element_type is not None:
+    preferred_element_type = np.dtype(preferred_element_type)
   return bcoo_dot_general_p.bind(jnp.asarray(lhs_data), jnp.asarray(lhs_indices), jnp.asarray(rhs),
                                  dimension_numbers=(cdims, bdims),
+                                 preferred_element_type=preferred_element_type,
                                  lhs_spinfo=lhs_spinfo)

 def _bcoo_rdot_general(lhs: Array, rhs_data: Array, rhs_indices: Array, *,
-                       dimension_numbers: DotDimensionNumbers, rhs_spinfo: SparseInfo) -> Array:
+                       dimension_numbers: DotDimensionNumbers,
+                       preferred_element_type: Any, rhs_spinfo: SparseInfo) -> Array:
   # TODO(jakevdp): perhaps this should be part of the bcoo_dot_general primitive?
   dimension_numbers_reversed: DotDimensionNumbers = tuple(d[::-1] for d in dimension_numbers)  # type: ignore[assignment]
   result = _bcoo_dot_general(rhs_data, rhs_indices, lhs, lhs_spinfo=rhs_spinfo,
-                             dimension_numbers=dimension_numbers_reversed)
+                             dimension_numbers=dimension_numbers_reversed,
+                             preferred_element_type=preferred_element_type)
   n_contract, n_batch = (len(d[0]) for d in dimension_numbers)
   n_swap = len(rhs_spinfo.shape) - n_contract
   permutation = tuple([*range(n_batch), *range(n_swap, result.ndim), *range(n_batch, n_swap)])
   return lax.transpose(result, permutation)

-def _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: SparseInfo):
+def _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs, *, dimension_numbers,
+                           preferred_element_type, lhs_spinfo: SparseInfo):
   lhs_data = jnp.asarray(lhs_data)
   lhs_indices = jnp.asarray(lhs_indices)
   rhs = jnp.asarray(rhs)
+  # Validate all inputs via abstract_eval
+  out_aval = _bcoo_dot_general_abstract_eval(lhs_data.aval, lhs_indices.aval, rhs.aval,
+                                             dimension_numbers=dimension_numbers,
+                                             preferred_element_type=preferred_element_type,
+                                             lhs_spinfo=lhs_spinfo)
   n_sparse = lhs_indices.shape[-1]
   n_batch = lhs_indices.ndim - 2
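Note how `_bcoo_dot_general` canonicalizes the new argument with `np.dtype(...)` before binding it as a primitive parameter: primitive parameters must compare and hash consistently for tracing and caching, and callers may spell the same dtype several ways. A quick illustration of the canonicalization (plain NumPy behavior, not specific to this commit):

import numpy as np

# 'float32', np.float32, and np.dtype('float32') all collapse to one
# canonical, hashable dtype object.
assert np.dtype('float32') == np.dtype(np.float32)
assert hash(np.dtype('float32')) == hash(np.dtype(np.float32))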
@@ -720,16 +733,17 @@ def _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs
   return result(out_array, lhs_data, lhs_indices, rhs)

 @bcoo_dot_general_p.def_abstract_eval
-def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: SparseInfo):
-  if lhs_data.dtype != rhs.dtype:
-    raise ValueError("bcoo_dot_general requires arguments to have matching dtypes; "
-                     f"got lhs.dtype={lhs_data.dtype}, rhs.dtype={rhs.dtype}")
+def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_numbers,
+                                    preferred_element_type, lhs_spinfo: SparseInfo):
+  out_aval = jax.eval_shape(
+      partial(lax.dot_general,
+              dimension_numbers=dimension_numbers,
+              preferred_element_type=preferred_element_type),
+      jax.ShapeDtypeStruct(lhs_spinfo.shape, lhs_data.dtype),
+      jax.ShapeDtypeStruct(rhs.shape, rhs.dtype))

   (lhs_contracting, _), (lhs_batch, _) = dimension_numbers
   n_batch, n_sparse, _, _ = _validate_bcoo(lhs_data, lhs_indices, lhs_spinfo.shape)
-  out_shape = _dot_general_validated_shape(lhs_spinfo.shape, rhs.shape,
-                                           dimension_numbers)

   if lhs_batch and max(lhs_batch) >= n_batch:
     raise NotImplementedError(
         "bcoo_dot_general batch dimensions must be among the batch dimensions in the sparse representtaion.\n"
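The rewritten abstract evaluation above no longer hand-computes the output shape or insists on matching input dtypes; it asks `jax.eval_shape` to run `lax.dot_general`'s own shape and dtype rules on stand-in `ShapeDtypeStruct`s, so the sparse primitive inherits exactly the dense promotion semantics. The pattern in isolation (shapes and dtypes are illustrative):

import jax
import jax.numpy as jnp
from functools import partial
from jax import lax

dn = (((1,), (0,)), ((), ()))
out = jax.eval_shape(
    partial(lax.dot_general, dimension_numbers=dn,
            preferred_element_type=jnp.float32),
    jax.ShapeDtypeStruct((2, 3), jnp.bfloat16),  # stands in for the sparse lhs
    jax.ShapeDtypeStruct((3, 5), jnp.bfloat16))  # dense rhs
print(out.shape, out.dtype)  # (2, 5) float32

No actual FLOPs run here: `eval_shape` traces the function abstractly and returns only the shape and dtype of the result.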
@@ -739,7 +753,7 @@ def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_num
   if any(d >= n_batch + n_sparse for d in lhs_contracting):
     raise NotImplementedError("bcoo_dot_general: contracting over dense dimensions.")

-  return core.ShapedArray(out_shape, lhs_data.dtype)
+  return core.ShapedArray(out_aval.shape, out_aval.dtype)

 _bcoo_dot_general_default_lowering = mlir.lower_fun(
     _bcoo_dot_general_impl, multiple_results=False)
@@ -766,16 +780,34 @@ def _bcoo_dot_general_fallback(data, indices, spinfo):
     return False

 def _bcoo_dot_general_gpu_impl(lhs_data, lhs_indices, rhs, *,
-                               dimension_numbers, lhs_spinfo):
+                               dimension_numbers, preferred_element_type,
+                               lhs_spinfo):
   if not config.jax_bcoo_cusparse_lowering:
     return _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs,
-                                  dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
+                                  dimension_numbers=dimension_numbers,
+                                  preferred_element_type=preferred_element_type,
+                                  lhs_spinfo=lhs_spinfo)

   (lhs_contract, rhs_contract), (lhs_batch, _) = dimension_numbers
   n_batch, n_sparse, n_dense, _ = _validate_bcoo(
       lhs_data, lhs_indices, lhs_spinfo.shape)
   coo_matmul_p = coo_spmv_p if rhs.ndim == 1 else coo_spmm_p

+  out_aval = _bcoo_dot_general_abstract_eval(
+      lhs_data, lhs_indices, rhs,
+      dimension_numbers=dimension_numbers,
+      preferred_element_type=preferred_element_type,
+      lhs_spinfo=lhs_spinfo)
+
+  if out_aval.dtype not in CUSPARSE_DATA_DTYPES:
+    return _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs,
+                                  dimension_numbers=dimension_numbers,
+                                  preferred_element_type=preferred_element_type,
+                                  lhs_spinfo=lhs_spinfo)
+
+  lhs_data = lhs_data.astype(out_aval.dtype)
+  rhs = rhs.astype(out_aval.dtype)
+
   # TODO(jakevdp, tianjianlu): add support for batched lowerings
   if (len(lhs_contract) == 1 and len(lhs_batch) == 0 and rhs.ndim in (1, 2)
       and (n_batch, n_sparse, n_dense) == (0, 1, 0)
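The GPU path now derives the output aval first and dispatches to cuSPARSE only when the resulting dtype is one the kernels support, casting the operands to the output dtype up front; otherwise it falls back to the default lowering. A sketch of that gate (assumption: `CUSPARSE_DATA_DTYPES` is internal to `jax.experimental.sparse`, and the concrete set shown here is a guess, not taken from the commit):

import numpy as np

# Assumed contents; the real constant lives in jax/experimental/sparse
# and may differ.
CUSPARSE_DATA_DTYPES = frozenset(
    np.dtype(d) for d in ('float32', 'float64', 'complex64', 'complex128'))

def can_use_cusparse(out_dtype) -> bool:
  # Dispatch to cuSPARSE only when it can produce the requested output
  # dtype directly; otherwise take the default (XLA) lowering.
  return np.dtype(out_dtype) in CUSPARSE_DATA_DTYPES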
@@ -801,18 +833,24 @@ def _bcoo_dot_general_gpu_impl(lhs_data, lhs_indices, rhs, *,
       return out[:-1]
   else:
     return _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs,
-                                  dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
+                                  dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo,
+                                  preferred_element_type=preferred_element_type)

 _bcoo_dot_general_gpu_lowering = mlir.lower_fun(
     _bcoo_dot_general_gpu_impl, multiple_results=False)

-def _bcoo_dot_general_jvp_lhs(lhs_data_dot, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: SparseInfo):
-  return _bcoo_dot_general(lhs_data_dot, lhs_indices, rhs, dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
+def _bcoo_dot_general_jvp_lhs(lhs_data_dot, lhs_data, lhs_indices, rhs, *, dimension_numbers,
+                              preferred_element_type, lhs_spinfo: SparseInfo):
+  return _bcoo_dot_general(lhs_data_dot, lhs_indices, rhs, dimension_numbers=dimension_numbers,
+                           preferred_element_type=preferred_element_type, lhs_spinfo=lhs_spinfo)

-def _bcoo_dot_general_jvp_rhs(rhs_dot, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: SparseInfo):
-  return _bcoo_dot_general(lhs_data, lhs_indices, rhs_dot, dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
+def _bcoo_dot_general_jvp_rhs(rhs_dot, lhs_data, lhs_indices, rhs, *, dimension_numbers,
+                              preferred_element_type, lhs_spinfo: SparseInfo):
+  return _bcoo_dot_general(lhs_data, lhs_indices, rhs_dot, dimension_numbers=dimension_numbers,
+                           preferred_element_type=preferred_element_type, lhs_spinfo=lhs_spinfo)

-def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: SparseInfo):
+def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_numbers,
+                                preferred_element_type, lhs_spinfo: SparseInfo):
   assert not ad.is_undefined_primal(lhs_indices)
   (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
   lhs_ndim = len(lhs_spinfo.shape)
@@ -857,10 +895,13 @@ def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_num
     dims = ((lhs_kept, ans_lhs), (lhs_batch, ans_batch))  # type: ignore[assignment]
     rhs_contract_sorted_by_lhs = list(np.take(rhs_contract, np.argsort(lhs_contract)))
     out_axes = list(np.argsort(list(rhs_batch) + rhs_contract_sorted_by_lhs + rhs_kept))
-    result = _bcoo_dot_general(lhs_data, lhs_indices, ct, lhs_spinfo=lhs_spinfo, dimension_numbers=dims)
+    result = _bcoo_dot_general(lhs_data, lhs_indices, ct, lhs_spinfo=lhs_spinfo,
+                               preferred_element_type=preferred_element_type,
+                               dimension_numbers=dims)
     return lhs_data, lhs_indices, lax.transpose(result, out_axes)

-def _bcoo_dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, lhs_spinfo: SparseInfo):
+def _bcoo_dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
+                                 preferred_element_type, lhs_spinfo: SparseInfo):
   _, _, rhs = batched_args
   _, _, rhs_bdim = batch_dims
   new_lhs_data, new_lhs_indices, new_lhs_spinfo = _bcoo_batch_dims_to_front(
@@ -869,6 +910,7 @@ def _bcoo_dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
   new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
       (len(lhs_spinfo.shape), rhs.ndim), (0, rhs_bdim), dimension_numbers)
   batched_out = _bcoo_dot_general(new_lhs_data, new_lhs_indices, rhs, lhs_spinfo=new_lhs_spinfo,
+                                  preferred_element_type=preferred_element_type,
                                   dimension_numbers=new_dimension_numbers)
   return batched_out, result_batch_dim
@@ -1041,7 +1083,8 @@ bcoo_spdot_general_p = core.Primitive('bcoo_spdot_general')
 bcoo_spdot_general_p.multiple_results = True

 def _bcoo_spdot_general(lhs_data: Array, lhs_indices: Array, rhs_data: Array, rhs_indices: Array, *,
-                        lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo, dimension_numbers: DotDimensionNumbers) -> tuple[Array, Array]:
+                        lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo, dimension_numbers: DotDimensionNumbers,
+                        preferred_element_type: Any) -> tuple[Array, Array]:
   (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
   cdims = (api_util._ensure_index_tuple(lhs_contract),
            api_util._ensure_index_tuple(rhs_contract))
@@ -1049,7 +1092,8 @@ def _bcoo_spdot_general(lhs_data: Array, lhs_indices: Array, rhs_data: Array, rh
            api_util._ensure_index_tuple(rhs_batch))
   return bcoo_spdot_general_p.bind(lhs_data, lhs_indices, rhs_data, rhs_indices,
                                    lhs_spinfo=lhs_spinfo, rhs_spinfo=rhs_spinfo,
-                                   dimension_numbers=(cdims, bdims))
+                                   dimension_numbers=(cdims, bdims),
+                                   preferred_element_type=preferred_element_type)

 def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo, rhs_spinfo, lhs_contracting, rhs_contracting, out_nse):
   lhs_shape = lhs_spinfo.shape
@@ -1098,7 +1142,8 @@ def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices,
   return _bcoo_sum_duplicates(out_data, out_indices, spinfo=SparseInfo(shape=out_shape), nse=out_nse)

 @bcoo_spdot_general_p.def_impl
-def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo, dimension_numbers):
+def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo,
+                             dimension_numbers, preferred_element_type):
   lhs_shape = lhs_spinfo.shape
   rhs_shape = rhs_spinfo.shape
@@ -1107,7 +1152,8 @@ def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lh
   assert lhs.n_dense == rhs.n_dense == 0
   data_aval, _ = _bcoo_spdot_general_abstract_eval(
       lhs_data.aval, lhs_indices.aval, rhs_data.aval, rhs_indices.aval,
-      lhs_spinfo=lhs_spinfo, rhs_spinfo=rhs_spinfo, dimension_numbers=dimension_numbers)
+      lhs_spinfo=lhs_spinfo, rhs_spinfo=rhs_spinfo, dimension_numbers=dimension_numbers,
+      preferred_element_type=preferred_element_type)
   out_nse = data_aval.shape[-1]

   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
@@ -1134,19 +1180,20 @@ def _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lh
   return func(lhs_data, lhs_indices, rhs_data, rhs_indices)

 @bcoo_spdot_general_p.def_abstract_eval
-def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo, dimension_numbers):
+def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo,
+                                      dimension_numbers, preferred_element_type):
   lhs_shape = lhs_spinfo.shape
   rhs_shape = rhs_spinfo.shape
-  out_shape = _dot_general_validated_shape(lhs_shape, rhs_shape,
-                                           dimension_numbers)
+  out_aval = jax.eval_shape(
+      partial(lax.dot_general,
+              dimension_numbers=dimension_numbers,
+              preferred_element_type=preferred_element_type),
+      jax.ShapeDtypeStruct(lhs_shape, lhs_data.dtype),
+      jax.ShapeDtypeStruct(rhs_shape, rhs_data.dtype))

-  if lhs_data.dtype != rhs_data.dtype:
-    raise ValueError("bcoo_spdot_general requires inputs to have matching dtypes; "
-                     f"got lhs.dtype={lhs_data.dtype}, rhs.dtype={rhs_data.dtype}")
   lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
   rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
+  _ = _dot_general_validated_shape(lhs_shape, rhs_shape, dimension_numbers)

   if lhs.n_dense or rhs.n_dense:
     # TODO(jakevdp): handle dense dimensions
@@ -1172,7 +1219,7 @@ def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indic
   # Ensure we're not storing more output elements than necessary.
   # TODO(jakevdp): should we warn here if output is effectively dense?
   out_n_batch = lhs.n_batch + rhs.n_batch - len(lhs_batch)
-  out_nse = min(out_nse, math.prod(out_shape[out_n_batch:]))
+  out_nse = min(out_nse, math.prod(out_aval.shape[out_n_batch:]))

   data_shape = (
       *(lhs_shape[dim] for dim in lhs_batch),
@@ -1185,13 +1232,14 @@ def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indic
       *(rhs_indices.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch),
       out_nse, lhs.n_sparse + rhs.n_sparse - 2 * len(lhs_contracting))

-  data_aval = core.ShapedArray(data_shape, lhs_data.dtype)
+  data_aval = core.ShapedArray(data_shape, out_aval.dtype)
   indices_aval = core.ShapedArray(indices_shape, lhs_indices.dtype)
-  _validate_bcoo(data_aval, indices_aval, out_shape)  # pytype: disable=wrong-arg-types  # always-use-return-annotations
+  _validate_bcoo(data_aval, indices_aval, out_aval.shape)  # pytype: disable=wrong-arg-types  # always-use-return-annotations

   return data_aval, indices_aval

-def _bcoo_spdot_general_batch_rule(batched_args, batch_dims, *, lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo, dimension_numbers):
+def _bcoo_spdot_general_batch_rule(batched_args, batch_dims, *, lhs_spinfo: SparseInfo, rhs_spinfo: SparseInfo,
+                                   preferred_element_type, dimension_numbers):
   lhs_ndim = len(lhs_spinfo.shape)
   rhs_ndim = len(rhs_spinfo.shape)
   batch_size = max(arg.shape[dim] for arg, dim in zip(batched_args, batch_dims) if dim is not None)
@@ -1203,7 +1251,8 @@ def _bcoo_spdot_general_batch_rule(batched_args, batch_dims, *, lhs_spinfo: Spar
       (lhs_ndim, rhs_ndim), (0, 0), dimension_numbers)
   batched_out = _bcoo_spdot_general(lhs_data, lhs_indices, rhs_data, rhs_indices,
                                     dimension_numbers=dimension_numbers,
-                                    lhs_spinfo=lhs_spinfo, rhs_spinfo=rhs_spinfo)
+                                    lhs_spinfo=lhs_spinfo, rhs_spinfo=rhs_spinfo,
+                                    preferred_element_type=preferred_element_type)
   return batched_out, (result_batch_dim, result_batch_dim)
@@ -19,7 +19,7 @@ from functools import partial
 import operator
 import warnings

-from typing import NamedTuple, Optional, Sequence, Union
+from typing import Any, NamedTuple, Optional, Sequence, Union

 import numpy as np
@@ -479,15 +479,17 @@ def bcsr_dot_general(lhs: Union[BCSR, Array], rhs: Array, *,
   are sparse, the result will be sparse, of type BCSR. If either input is
   dense, the result will be dense, of type ndarray.
   """
-  del precision, preferred_element_type  # unused
+  del precision  # unused
   if isinstance(rhs, (np.ndarray, jax.Array)):
     if isinstance(lhs, (np.ndarray, jax.Array)):
-      return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
+      return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers,
+                             preferred_element_type=preferred_element_type)

     if isinstance(lhs, BCSR):
       lhs_data, lhs_indices, lhs_indptr = lhs._bufs
       return _bcsr_dot_general(lhs_data, lhs_indices, lhs_indptr, rhs,
                                dimension_numbers=dimension_numbers,
+                               preferred_element_type=preferred_element_type,
                                lhs_spinfo=lhs._info)

   raise NotImplementedError("bcsr_dot_general currently implemented for BCSR "
@@ -497,6 +499,7 @@ def bcsr_dot_general(lhs: Union[BCSR, Array], rhs: Array, *,
 def _bcsr_dot_general(lhs_data: jax.Array, lhs_indices: jax.Array,
                       lhs_indptr: jax.Array, rhs: Array, *,
                       dimension_numbers: DotDimensionNumbers,
+                      preferred_element_type: Any,
                       lhs_spinfo: SparseInfo) -> Array:
   (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
   cdims = (api_util._ensure_index_tuple(lhs_contract),
@@ -507,11 +510,12 @@ def _bcsr_dot_general(lhs_data: jax.Array, lhs_indices: jax.Array,
                                   jnp.asarray(lhs_indices),
                                   jnp.asarray(lhs_indptr), jnp.asarray(rhs),
                                   dimension_numbers=(cdims, bdims),
+                                  preferred_element_type=preferred_element_type,
                                   lhs_spinfo=lhs_spinfo)

 def _bcsr_dot_general_impl(lhs_data, lhs_indices, lhs_indptr, rhs, *,
-                           dimension_numbers, lhs_spinfo):
+                           dimension_numbers, preferred_element_type, lhs_spinfo):
   lhs_data = jnp.asarray(lhs_data)
   lhs_bcsr_indices = jnp.asarray(lhs_indices)
   lhs_bcsr_indptr = jnp.asarray(lhs_indptr)
@@ -520,21 +524,21 @@ def _bcsr_dot_general_impl(lhs_data, lhs_indices, lhs_indptr, rhs, *,
                                shape=lhs_spinfo.shape)
   return bcoo._bcoo_dot_general_impl(lhs_data, lhs_bcoo_indices, rhs,
                                      dimension_numbers=dimension_numbers,
+                                     preferred_element_type=preferred_element_type,
                                      lhs_spinfo=lhs_spinfo)

 @bcsr_dot_general_p.def_abstract_eval
 def _bcsr_dot_general_abstract_eval(lhs_data, lhs_indices, lhs_indptr, rhs, *,
-                                    dimension_numbers, lhs_spinfo):
-  if lhs_data.dtype != rhs.dtype:
-    raise ValueError("bcsr_dot_general requires arguments to have matching "
-                     f"dtypes; got lhs.dtype={lhs_data.dtype}, "
-                     f"rhs.dtype={rhs.dtype}")
-
+                                    dimension_numbers, preferred_element_type, lhs_spinfo):
   (lhs_contracting, _), (lhs_batch, _) = dimension_numbers
   props = _validate_bcsr_indices(lhs_indices, lhs_indptr, lhs_spinfo.shape)
-  out_shape = _dot_general_validated_shape(lhs_spinfo.shape, rhs.shape,
-                                           dimension_numbers)
+  out_aval = jax.eval_shape(
+      partial(lax.dot_general,
+              dimension_numbers=dimension_numbers,
+              preferred_element_type=preferred_element_type),
+      jax.ShapeDtypeStruct(lhs_spinfo.shape, lhs_data.dtype),
+      jax.ShapeDtypeStruct(rhs.shape, rhs.dtype))
+
   if lhs_batch and max(lhs_batch) >= props.n_batch:
     raise NotImplementedError(
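As `_bcsr_dot_general_impl` above shows, the BCSR path still delegates to the BCOO implementation after expanding its compressed row pointers into explicit row indices, with `preferred_element_type` simply forwarded along. A sketch of that indptr-to-row expansion (the helper name here is hypothetical; the real conversion is `_bcsr_to_bcoo`):

import jax.numpy as jnp

def rows_from_indptr(indptr, nse):
  # Stored element k belongs to the row r with indptr[r] <= k < indptr[r+1];
  # searchsorted finds that row for all stored elements at once.
  return jnp.searchsorted(indptr, jnp.arange(nse), side='right') - 1

indptr = jnp.array([0, 2, 2, 5])    # 3 rows, 5 stored elements, row 1 empty
print(rows_from_indptr(indptr, 5))  # [0 0 2 2 2]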
@@ -545,38 +549,41 @@ def _bcsr_dot_general_abstract_eval(lhs_data, lhs_indices, lhs_indptr, rhs, *,
   if any(d >= props.n_batch + 2 for d in lhs_contracting):
     raise NotImplementedError("bcsr_dot_general: contracting over dense dimensions.")

-  return core.ShapedArray(out_shape, lhs_data.dtype)
+  return core.ShapedArray(out_aval.shape, out_aval.dtype)

 def _bcsr_dot_general_jvp_lhs(lhs_data_dot, lhs_data, lhs_indices, lhs_indptr, rhs, *,
-                              dimension_numbers, lhs_spinfo):
+                              dimension_numbers, preferred_element_type, lhs_spinfo):
   del lhs_data
   return _bcsr_dot_general(lhs_data_dot, lhs_indices, lhs_indptr, rhs,
                            dimension_numbers=dimension_numbers,
+                           preferred_element_type=preferred_element_type,
                            lhs_spinfo=lhs_spinfo)

 def _bcsr_dot_general_jvp_rhs(rhs_dot, lhs_data, lhs_indices, lhs_indptr, rhs, *,
-                              dimension_numbers, lhs_spinfo):
+                              dimension_numbers, preferred_element_type, lhs_spinfo):
   del rhs
   return _bcsr_dot_general(lhs_data, lhs_indices, lhs_indptr, rhs_dot,
                            dimension_numbers=dimension_numbers,
+                           preferred_element_type=preferred_element_type,
                            lhs_spinfo=lhs_spinfo)

 def _bcsr_dot_general_transpose(ct, lhs_data, lhs_indices, lhs_indptr, rhs, *,
-                                dimension_numbers, lhs_spinfo):
+                                dimension_numbers, preferred_element_type, lhs_spinfo):
   # TODO(jakevdp): implement this in terms of bcsr_dot_general
   lhs_bcoo_indices = _bcsr_to_bcoo(
       lhs_indices, lhs_indptr, shape=lhs_spinfo.shape)
   data_out, _, rhs_out = bcoo._bcoo_dot_general_transpose(
       ct, lhs_data, lhs_bcoo_indices, rhs, dimension_numbers=dimension_numbers,
-      lhs_spinfo=lhs_spinfo)
+      preferred_element_type=preferred_element_type, lhs_spinfo=lhs_spinfo)
   return data_out, lhs_indices, lhs_indptr, rhs_out

 def _bcsr_dot_general_batch_rule(batched_args, batch_dims, *,
-                                 dimension_numbers, lhs_spinfo):
+                                 dimension_numbers, preferred_element_type,
+                                 lhs_spinfo):
   *lhs_args, rhs = batched_args
   *lhs_dims, rhs_bdim = batch_dims
   *new_lhs_args, new_lhs_spinfo = _bcsr_batch_dims_to_front(
@@ -613,12 +621,14 @@ _bcsr_correct_out_of_bound_indices_lowered = mlir.lower_fun(
 def _bcsr_dot_general_gpu_lowering(
     csr_matvec_lowering, csr_matmat_lowering,
     ctx, lhs_data, lhs_indices, lhs_indptr, rhs, *, dimension_numbers,
-    lhs_spinfo: SparseInfo):
+    preferred_element_type, lhs_spinfo: SparseInfo):

   if not config.jax_bcoo_cusparse_lowering:
     return _bcsr_dot_general_default_lowering(
         ctx, lhs_data, lhs_indices, lhs_indptr, rhs,
-        dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
+        dimension_numbers=dimension_numbers,
+        preferred_element_type=preferred_element_type,
+        lhs_spinfo=lhs_spinfo)

   (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
   lhs_data_aval, lhs_indices_aval, lhs_indptr_aval, rhs_aval = ctx.avals_in
@@ -631,6 +641,10 @@ def _bcsr_dot_general_gpu_lowering(
   if lhs_batch or rhs_batch:
     # batch dimensions in dot_general are not supported
     use_default_lowering = True
+  elif (lhs_data_aval.dtype != rhs_aval.dtype):
+    use_default_lowering = True
+  elif preferred_element_type is not None and preferred_element_type != lhs_data_aval.dtype:
+    use_default_lowering = True
   elif len(lhs_spinfo.shape) != 2 or rhs_aval.ndim not in [1, 2]:
     # only matmat / matvec supported
     use_default_lowering = True
@@ -650,7 +664,9 @@ def _bcsr_dot_general_gpu_lowering(
   if use_default_lowering:
     return _bcsr_dot_general_default_lowering(
         ctx, lhs_data, lhs_indices, lhs_indptr, rhs,
-        dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
+        dimension_numbers=dimension_numbers,
+        preferred_element_type=preferred_element_type,
+        lhs_spinfo=lhs_spinfo)

   # Account for a bug in cusparse: it references indices and data beyond
   # the extent of indptr.
@@ -1289,7 +1289,7 @@ class BCOOTest(sptu.SparseTestCase):

     def f_sparse(data, indices, Y):
       return sparse_bcoo._bcoo_dot_general(data, indices, Y, lhs_spinfo=sparse_util.SparseInfo(X.shape),
-                                           dimension_numbers=dimension_numbers)
+                                           dimension_numbers=dimension_numbers, preferred_element_type=None)

     for data, indices in itertools.product([data, data[:1]], [indices, indices[:1]]):
       X = sparse_bcoo._bcoo_todense(data, indices, spinfo=sparse_util.SparseInfo(X.shape))
@@ -162,27 +162,29 @@ class SparsifyTest(jtu.JaxTestCase):

     self.assertAllClose(result_sparse.todense(), result_dense)

+  @jax.numpy_dtype_promotion('standard')
   def testSparseMatmul(self):
-    X = jnp.arange(16.0).reshape(4, 4)
+    X = jnp.arange(16.0, dtype='float32').reshape(4, 4)
     Xsp = BCOO.fromdense(X)
-    Y = jnp.ones(4)
+    Y = jnp.ones(4, dtype='int32')
     Ysp = BCOO.fromdense(Y)

-    func = self.sparsify(operator.matmul)
+    # Note: deliberately testing with mixed precision
+    assert Xsp.dtype != Ysp.dtype

     # dot_general
-    result_sparse = func(Xsp, Y)
-    result_dense = operator.matmul(X, Y)
+    result_sparse = self.sparsify(lax.dot)(Xsp, Y)
+    result_dense = lax.dot(X, Y)
     self.assertAllClose(result_sparse, result_dense)

     # rdot_general
-    result_sparse = func(Y, Xsp)
-    result_dense = operator.matmul(Y, X)
+    result_sparse = self.sparsify(lax.dot)(Y, Xsp)
+    result_dense = lax.dot(Y, X)
     self.assertAllClose(result_sparse, result_dense)

     # spdot_general
-    result_sparse = self.sparsify(operator.matmul)(Xsp, Ysp)
-    result_dense = operator.matmul(X, Y)
+    result_sparse = self.sparsify(lax.dot)(Xsp, Ysp)
+    result_dense = lax.dot(X, Y)
     self.assertAllClose(result_sparse.todense(), result_dense)

   def testSparseAdd(self):
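The updated `testSparseMatmul` deliberately mixes dtypes (float32 x int32) and switches from `operator.matmul` to `lax.dot`, so the test exercises the result-dtype promotion now handled by the dot_general rules themselves rather than `jnp.matmul`'s argument promotion. Under the 'standard' promotion mode enabled by the new decorator, the expected common result type can be checked directly (illustrative, not from the commit):

import jax.numpy as jnp

# float32 x int32 promotes to float32 under JAX's 'standard' promotion,
# the dtype on which the dense and sparse results are compared.
print(jnp.promote_types('float32', 'int32'))  # float32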