Merge pull request #11509 from jakevdp:sparse-lower

PiperOrigin-RevId: 461242221
This commit is contained in:
jax authors 2022-07-15 14:50:02 -07:00
commit f12b7fb0bb
2 changed files with 14 additions and 1 deletions

View File

@ -18,6 +18,7 @@ import numpy as np
import jax
from jax import core
from jax._src import dtypes
from jax._src import stages
import jax.numpy as jnp
class SparseEfficiencyError(ValueError):
@ -54,12 +55,15 @@ def _is_pytree_placeholder(*args):
def _is_aval(*args):
return all(isinstance(arg, core.AbstractValue) for arg in args)
def _is_arginfo(*args):
return all(isinstance(arg, stages.ArgInfo) for arg in args)
def _asarray_or_float0(arg):
if isinstance(arg, np.ndarray) and arg.dtype == dtypes.float0:
return arg
return jnp.asarray(arg)
def _safe_asarray(args):
if _is_pytree_placeholder(*args) or _is_aval(*args):
if _is_pytree_placeholder(*args) or _is_aval(*args) or _is_arginfo(*args):
return args
return map(_asarray_or_float0, args)

View File

@ -2073,6 +2073,15 @@ class SparseGradTest(jtu.JaxTestCase):
class SparseObjectTest(jtu.JaxTestCase):
@parameterized.named_parameters(
{"testcase_name": f"_{cls.__name__}", "cls": cls}
for cls in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])
def test_jit_lower(self, cls):
sparse_format = cls.__name__.lower()
M = sparse.empty((2, 4), sparse_format=sparse_format)
self.assertIsInstance(M, cls)
jax.jit(lambda x: x).lower(M) # doesn't crash
@parameterized.named_parameters(
{"testcase_name": f"_{cls.__name__}{shape}", "cls": cls, "shape": shape}
for cls in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO]