mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sparse] fix concretization error for nse
This commit is contained in:
parent
5f4d4797b2
commit
68e6bc0556
@ -20,7 +20,7 @@ import functools
|
||||
from functools import partial
|
||||
import math
|
||||
import operator
|
||||
from typing import Any, NamedTuple, Optional, Protocol, Union
|
||||
from typing import Any, NamedTuple, Protocol
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
@ -262,9 +262,10 @@ def bcoo_fromdense(mat: Array, *, nse: int | None = None, n_batch: int = 0,
|
||||
mat_bcoo: BCOO representation of the matrix.
|
||||
"""
|
||||
mat = jnp.asarray(mat)
|
||||
if nse is None:
|
||||
nse = _count_stored_elements(mat, n_batch, n_dense)
|
||||
nse_int = core.concrete_or_error(operator.index, nse, _TRACED_NSE_ERROR)
|
||||
nse_arr: int | Array | None = nse
|
||||
if nse_arr is None:
|
||||
nse_arr = _count_stored_elements(mat, n_batch, n_dense)
|
||||
nse_int = core.concrete_or_error(operator.index, nse_arr, _TRACED_NSE_ERROR)
|
||||
return BCOO(_bcoo_fromdense(mat, nse=nse_int, n_batch=n_batch, n_dense=n_dense,
|
||||
index_dtype=index_dtype),
|
||||
shape=mat.shape, indices_sorted=True, unique_indices=True)
|
||||
|
@ -194,9 +194,10 @@ def bcsr_fromdense(mat: ArrayLike, *, nse: int | None = None, n_batch: int = 0,
|
||||
mat_bcsr: BCSR representation of the matrix.
|
||||
"""
|
||||
mat_array = jnp.asarray(mat)
|
||||
if nse is None:
|
||||
nse = _count_stored_elements(mat_array, n_batch, n_dense)
|
||||
nse_int: int = core.concrete_or_error(operator.index, nse, _TRACED_NSE_ERROR)
|
||||
nse_arr: int | Array | None = nse
|
||||
if nse_arr is None:
|
||||
nse_arr = _count_stored_elements(mat_array, n_batch, n_dense)
|
||||
nse_int: int = core.concrete_or_error(operator.index, nse_arr, _TRACED_NSE_ERROR)
|
||||
return BCSR(_bcsr_fromdense(mat_array, nse=nse_int, n_batch=n_batch,
|
||||
n_dense=n_dense, index_dtype=index_dtype),
|
||||
shape=mat_array.shape)
|
||||
|
@ -101,9 +101,9 @@ def _count_stored_elements_per_batch(mat: Array, n_batch: int = 0, n_dense: int
|
||||
mask = mask.sum(tuple(range(n_batch, mask.ndim)))
|
||||
return mask
|
||||
|
||||
def _count_stored_elements(mat: Array, n_batch: int = 0, n_dense: int = 0) -> int:
|
||||
def _count_stored_elements(mat: Array, n_batch: int = 0, n_dense: int = 0) -> Array:
|
||||
"""Return the number of stored elements (nse) of the given dense matrix."""
|
||||
return int(_count_stored_elements_per_batch(mat, n_batch, n_dense).max(initial=0))
|
||||
return _count_stored_elements_per_batch(mat, n_batch, n_dense).max(initial=0)
|
||||
|
||||
def _dot_general_validated_shape(
|
||||
lhs_shape: tuple[int, ...], rhs_shape: tuple[int, ...],
|
||||
|
Loading…
x
Reference in New Issue
Block a user