Inline sigmoid, isfinite, and isnan in jaxprs.

In the common case (real values) these are all single-expression jaxprs themselves, so putting them out of line just makes things more verbose.

There's no reason to include stuff like this in a jaxpr:
```
          cxd:bool[8,16] = pjit[
            jaxpr={ lambda ; cxe:f32[8,16]. let
                cxf:bool[8,16] = is_finite cxe
              in (cxf,) }
            name=isfinite
          ] cxc
```

PiperOrigin-RevId: 587047955
This commit is contained in:
Peter Hawkins 2023-12-01 10:23:25 -08:00 committed by jax authors
parent ada5fe5dc9
commit 95bc2ba1b9
2 changed files with 3 additions and 3 deletions

View File

@ -124,7 +124,7 @@ def soft_sign(x: ArrayLike) -> Array:
x_arr = jnp.asarray(x)
return x_arr / (jnp.abs(x_arr) + 1)
@jax.jit
@partial(jax.jit, inline=True)
def sigmoid(x: ArrayLike) -> Array:
r"""Sigmoid activation function.

View File

@ -651,7 +651,7 @@ def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]:
@_wraps(np.isfinite, module='numpy')
@jit
@partial(jit, inline=True)
def isfinite(x: ArrayLike, /) -> Array:
check_arraylike("isfinite", x)
dtype = dtypes.dtype(x)
@ -702,7 +702,7 @@ isneginf: UnOp = _wraps(np.isneginf, skip_params=['out'])(
@_wraps(np.isnan, module='numpy')
@jit
@partial(jit, inline=True)
def isnan(x: ArrayLike, /) -> Array:
check_arraylike("isnan", x)
return lax.ne(x, x)