mirror of
https://github.com/ROCm/jax.git
synced 2025-04-27 00:36:05 +00:00
70 lines
2.2 KiB
Python
70 lines
2.2 KiB
Python
# Copyright 2021 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Sparse utilities."""
|
|
|
|
import numpy as np
|
|
import jax
|
|
from jax import core
|
|
from jax._src import dtypes
|
|
from jax._src import stages
|
|
import jax.numpy as jnp
|
|
|
|
class SparseEfficiencyError(ValueError):
  """ValueError subclass used to signal a sparse-efficiency error."""
|
|
|
|
class SparseEfficiencyWarning(UserWarning):
  """UserWarning subclass used to signal a sparse-efficiency warning."""
|
|
|
|
class CuSparseEfficiencyWarning(SparseEfficiencyWarning):
  """SparseEfficiencyWarning subclass for cuSPARSE-related warnings."""
|
|
|
|
#--------------------------------------------------------------------
|
|
# utilities
|
|
# TODO: possibly make these primitives, targeting cusparse routines
|
|
# csr2coo/coo2csr/SPDDMM
|
|
@jax.jit
def _csr_to_coo(indices, indptr):
  """Given CSR (indices, indptr) return COO (row, col).

  Args:
    indices: CSR column-index array; returned unchanged as ``col``.
    indptr: CSR row-pointer array.

  Returns:
    A ``(row, col)`` pair, where ``row`` has the same shape as ``indices``.
  """
  # Put a 1 at every position named by indptr, in a zero array shaped like
  # `indices`.
  # NOTE(review): the final indptr entry normally equals len(indices); this
  # presumably relies on JAX scatter ignoring out-of-bounds indices -- confirm.
  row_markers = jnp.zeros_like(indices).at[indptr].add(1)
  # The running count of markers, shifted to be zero-based, is the row index
  # of each stored entry.
  row = jnp.cumsum(row_markers) - 1
  return row, indices
|
|
|
|
def _csr_extract(indices, indptr, mat):
  """Extract values of dense matrix mat at given CSR indices.

  The CSR index pair is first converted to COO form with ``_csr_to_coo``;
  the lookup itself is delegated to ``_coo_extract``.
  """
  coo_indices = _csr_to_coo(indices, indptr)
  return _coo_extract(*coo_indices, mat)
|
|
|
|
def _coo_extract(row, col, mat):
  """Extract values of dense matrix mat at given COO indices.

  Args:
    row: row index (or indices) into ``mat``.
    col: column index (or indices) into ``mat``.
    mat: object indexed as ``mat[row, col]``.

  Returns:
    The result of indexing ``mat`` with the ``(row, col)`` pair.
  """
  coords = (row, col)
  return mat[coords]
|
|
|
|
def _is_pytree_placeholder(*args):
  """Return True if args are consistent with pytree-validation placeholders.

  Arguments qualify when every one of them is ``None``, or when every one of
  them is a bare ``object()`` instance (exact type ``object``, not a
  subclass). With no arguments the result is True.
  """
  if all(arg is None for arg in args):
    return True
  return all(type(arg) is object for arg in args)
|
|
|
|
def _is_aval(*args):
  """Return True if every argument is a ``core.AbstractValue`` instance.

  With no arguments the result is True.
  """
  for candidate in args:
    if not isinstance(candidate, core.AbstractValue):
      return False
  return True
|
|
|
|
def _is_arginfo(*args):
  """Return True if every argument is a ``stages.ArgInfo`` instance.

  With no arguments the result is True.
  """
  return not any(not isinstance(arg, stages.ArgInfo) for arg in args)
|
|
|
|
def _asarray_or_float0(arg):
  """Convert ``arg`` with ``jnp.asarray``, passing float0 NumPy arrays through.

  Args:
    arg: value to convert.

  Returns:
    ``arg`` itself when it is a ``np.ndarray`` whose dtype equals
    ``dtypes.float0``; otherwise ``jnp.asarray(arg)``.
  """
  is_float0_ndarray = isinstance(arg, np.ndarray) and arg.dtype == dtypes.float0
  return arg if is_float0_ndarray else jnp.asarray(arg)
|
|
|
|
def _safe_asarray(args):
  """Convert every element of ``args`` to an array unless they are placeholders.

  Args:
    args: a sequence of values (it is star-unpacked into the checks below).

  Returns:
    ``args`` unchanged when all elements are pytree placeholders, abstract
    values, or ``ArgInfo`` objects; otherwise a tuple holding the result of
    ``_asarray_or_float0`` applied to each element.
  """
  if _is_pytree_placeholder(*args) or _is_aval(*args) or _is_arginfo(*args):
    return args
  # Materialize the result: a bare `map` object is a one-shot iterator, so it
  # would come back empty on a second pass and has no len() or indexing,
  # unlike the sequence returned by the placeholder branch above.
  return tuple(map(_asarray_or_float0, args))
|