jax.nn.softmax: fix fill value when where is specified

This commit is contained in:
Jake VanderPlas 2023-06-01 10:18:05 -07:00
parent ae9d1498e5
commit c474de424a
2 changed files with 16 additions and 8 deletions

View File

@ -317,7 +317,10 @@ def log_softmax(x: Array,
shifted = x - lax.stop_gradient(x_max)
shifted_logsumexp = jnp.log(
jnp.sum(jnp.exp(shifted), axis, where=where, keepdims=True))
return shifted - shifted_logsumexp
result = shifted - shifted_logsumexp
if where is not None:
return jnp.where(where, result, -jnp.inf)
return result
# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
@ -357,7 +360,10 @@ def _softmax(
initial: Optional[Array] = None) -> Array:
x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
unnormalized = jnp.exp(x - x_max)
return unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
if where is not None:
result = jnp.where(where, result, 0)
return result
@_softmax.defjvp
def _softmax_jvp(axis, primals, tangents):
@ -368,7 +374,10 @@ def _softmax_jvp(axis, primals, tangents):
def _softmax_deprecated(x, axis, where, initial):
x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
return unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
if where is not None:
result = jnp.where(where, result, 0)
return result
@partial(jax.jit, static_argnames=("axis",))

View File

@ -133,13 +133,12 @@ class NNFunctionsTest(jtu.JaxTestCase):
def testSoftmaxWhereMask(self, fn):
x = jnp.array([5.5, 1.3, -4.2, 0.9])
m = jnp.array([True, False, True, True])
x_filtered = jnp.take(x, jnp.array([0, 2, 3]))
out_masked = jnp.take(
fn(x, where=m, initial=-jnp.inf), jnp.array([0, 2, 3]))
out_filtered = fn(x_filtered)
out = fn(x, where=m, initial=-jnp.inf)
self.assertAllClose(out[m], fn(x[m]))
self.assertAllClose(out_masked, out_filtered)
probs = out if fn is nn.softmax else jnp.exp(out)
self.assertAllClose(probs.sum(), 1.0)
# TODO(mattjj): include log_softmax in these extra tests if/when we add a
# custom_jvp rule for it (since otherwise it doesn't pass the numerical