mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sparse] cleanup validation function
This commit is contained in:
parent
afe2d908ce
commit
98ce6c5892
@ -609,6 +609,8 @@ def _validate_bcoo(data, indices, shape):
|
||||
assert _compatible(indices.shape[:n_batch], shape[:n_batch])
|
||||
assert indices.shape[n_batch:] == (n_sparse, nse)
|
||||
|
||||
return n_batch, n_sparse, n_dense
|
||||
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
# bcoo_todense
|
||||
@ -632,9 +634,7 @@ def bcoo_todense(data, indices, *, shape):
|
||||
|
||||
@bcoo_todense_p.def_impl
|
||||
def _bcoo_todense_impl(data, indices, *, shape):
|
||||
_validate_bcoo(data, indices, shape)
|
||||
n_sparse = indices.shape[-2]
|
||||
n_batch = indices.ndim - 2
|
||||
n_batch, n_sparse, _ = _validate_bcoo(data, indices, shape)
|
||||
batch_slices = tuple(slice(s) for s in shape[:n_batch])
|
||||
sparse_ind = tuple(indices[tuple(np.mgrid[batch_slices]) + (i,)] for i in range(n_sparse))
|
||||
batch_ind = tuple(np.mgrid[batch_slices + (slice(1),)])[:-1]
|
||||
@ -907,9 +907,7 @@ def _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs
|
||||
@bcoo_dot_general_p.def_abstract_eval
|
||||
def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
|
||||
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
n_sparse = lhs_indices.shape[-2]
|
||||
n_batch = lhs_indices.ndim - 2
|
||||
_validate_bcoo(lhs_data, lhs_indices, lhs_shape)
|
||||
n_batch, n_sparse, _ = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
|
||||
|
||||
# Check for proper dimension_numbers
|
||||
for dims in [lhs_contracting, rhs_contracting, lhs_batch, rhs_batch]:
|
||||
|
Loading…
x
Reference in New Issue
Block a user