mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Inline sigmoid, isfinite, and isnan in jaxprs.
In the common case (real values) these are all single-expression jaxprs themselves, so putting them out of line just makes things more verbose. There's no reason to include stuff like this in a jaxpr: ``` cxd:bool[8,16] = pjit[ jaxpr={ lambda ; cxe:f32[8,16]. let cxf:bool[8,16] = is_finite cxe in (cxf,) } name=isfinite ] cxc ``` PiperOrigin-RevId: 587047955
This commit is contained in:
parent
ada5fe5dc9
commit
95bc2ba1b9
@ -124,7 +124,7 @@ def soft_sign(x: ArrayLike) -> Array:
|
||||
x_arr = jnp.asarray(x)
|
||||
return x_arr / (jnp.abs(x_arr) + 1)
|
||||
|
||||
@jax.jit
|
||||
@partial(jax.jit, inline=True)
|
||||
def sigmoid(x: ArrayLike) -> Array:
|
||||
r"""Sigmoid activation function.
|
||||
|
||||
|
@ -651,7 +651,7 @@ def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]:
|
||||
|
||||
|
||||
@_wraps(np.isfinite, module='numpy')
|
||||
@jit
|
||||
@partial(jit, inline=True)
|
||||
def isfinite(x: ArrayLike, /) -> Array:
|
||||
check_arraylike("isfinite", x)
|
||||
dtype = dtypes.dtype(x)
|
||||
@ -702,7 +702,7 @@ isneginf: UnOp = _wraps(np.isneginf, skip_params=['out'])(
|
||||
|
||||
|
||||
@_wraps(np.isnan, module='numpy')
|
||||
@jit
|
||||
@partial(jit, inline=True)
|
||||
def isnan(x: ArrayLike, /) -> Array:
|
||||
check_arraylike("isnan", x)
|
||||
return lax.ne(x, x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user