Require array-like inputs to sparse_plus

We should not silently convert non-array inputs to arrays, because this can lead to silent performance degredation. This brings the sparse_plus API in line with other APIs in this module.

PiperOrigin-RevId: 617190413
This commit is contained in:
Jake VanderPlas 2024-03-19 09:05:33 -07:00 committed by jax authors
parent 0b28a4b168
commit b48aec57ad

View File

@ -132,6 +132,7 @@ def sparse_plus(x: ArrayLike) -> Array:
Args:
x: input (float)
"""
numpy_util.check_arraylike("sparse_plus", x)
x = jnp.asarray(x)
return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0)**2/4))