[sparse] cleanup validation function

This commit is contained in:
Jake VanderPlas 2021-06-22 10:39:34 -07:00
parent afe2d908ce
commit 98ce6c5892

View File

@ -609,6 +609,8 @@ def _validate_bcoo(data, indices, shape):
assert _compatible(indices.shape[:n_batch], shape[:n_batch])
assert indices.shape[n_batch:] == (n_sparse, nse)
return n_batch, n_sparse, n_dense
#----------------------------------------------------------------------
# bcoo_todense
@ -632,9 +634,7 @@ def bcoo_todense(data, indices, *, shape):
@bcoo_todense_p.def_impl
def _bcoo_todense_impl(data, indices, *, shape):
_validate_bcoo(data, indices, shape)
n_sparse = indices.shape[-2]
n_batch = indices.ndim - 2
n_batch, n_sparse, _ = _validate_bcoo(data, indices, shape)
batch_slices = tuple(slice(s) for s in shape[:n_batch])
sparse_ind = tuple(indices[tuple(np.mgrid[batch_slices]) + (i,)] for i in range(n_sparse))
batch_ind = tuple(np.mgrid[batch_slices + (slice(1),)])[:-1]
@ -907,9 +907,7 @@ def _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs
@bcoo_dot_general_p.def_abstract_eval
def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
n_sparse = lhs_indices.shape[-2]
n_batch = lhs_indices.ndim - 2
_validate_bcoo(lhs_data, lhs_indices, lhs_shape)
n_batch, n_sparse, _ = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
# Check for proper dimension_numbers
for dims in [lhs_contracting, rhs_contracting, lhs_batch, rhs_batch]: