mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #11509 from jakevdp:sparse-lower
PiperOrigin-RevId: 461242221
This commit is contained in:
commit
f12b7fb0bb
@ -18,6 +18,7 @@ import numpy as np
|
||||
import jax
|
||||
from jax import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import stages
|
||||
import jax.numpy as jnp
|
||||
|
||||
class SparseEfficiencyError(ValueError):
|
||||
@ -54,12 +55,15 @@ def _is_pytree_placeholder(*args):
|
||||
def _is_aval(*args):
|
||||
return all(isinstance(arg, core.AbstractValue) for arg in args)
|
||||
|
||||
def _is_arginfo(*args):
|
||||
return all(isinstance(arg, stages.ArgInfo) for arg in args)
|
||||
|
||||
def _asarray_or_float0(arg):
|
||||
if isinstance(arg, np.ndarray) and arg.dtype == dtypes.float0:
|
||||
return arg
|
||||
return jnp.asarray(arg)
|
||||
|
||||
def _safe_asarray(args):
|
||||
if _is_pytree_placeholder(*args) or _is_aval(*args):
|
||||
if _is_pytree_placeholder(*args) or _is_aval(*args) or _is_arginfo(*args):
|
||||
return args
|
||||
return map(_asarray_or_float0, args)
|
||||
|
@ -2073,6 +2073,15 @@ class SparseGradTest(jtu.JaxTestCase):
|
||||
|
||||
class SparseObjectTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{cls.__name__}", "cls": cls}
|
||||
for cls in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])
|
||||
def test_jit_lower(self, cls):
|
||||
sparse_format = cls.__name__.lower()
|
||||
M = sparse.empty((2, 4), sparse_format=sparse_format)
|
||||
self.assertIsInstance(M, cls)
|
||||
jax.jit(lambda x: x).lower(M) # doesn't crash
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{cls.__name__}{shape}", "cls": cls, "shape": shape}
|
||||
for cls in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO]
|
||||
|
Loading…
x
Reference in New Issue
Block a user