mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Require array-like inputs to sparse_plus
We should not silently convert non-array inputs to arrays, because this can lead to silent performance degredation. This brings the sparse_plus API in line with other APIs in this module. PiperOrigin-RevId: 617190413
This commit is contained in:
parent
0b28a4b168
commit
b48aec57ad
@ -132,6 +132,7 @@ def sparse_plus(x: ArrayLike) -> Array:
|
||||
Args:
|
||||
x: input (float)
|
||||
"""
|
||||
numpy_util.check_arraylike("sparse_plus", x)
|
||||
x = jnp.asarray(x)
|
||||
return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0)**2/4))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user