jax.nn.softmax: fix fill value when where is specified
commit c474de424a (parent ae9d1498e5)
@@ -317,7 +317,10 @@ def log_softmax(x: Array,
   shifted = x - lax.stop_gradient(x_max)
   shifted_logsumexp = jnp.log(
       jnp.sum(jnp.exp(shifted), axis, where=where, keepdims=True))
-  return shifted - shifted_logsumexp
+  result = shifted - shifted_logsumexp
+  if where is not None:
+    return jnp.where(where, result, -jnp.inf)
+  return result
 
 
 # TODO(phawkins): this jit was found to change numerics in a test. Debug this.
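With this change, entries excluded by `where` come back as `-inf`, i.e. the log of zero probability, rather than values computed from the masked lanes. A minimal sketch of the new behavior (the inputs are illustrative; `initial=-jnp.inf` follows the calling convention used in the test further down):

    import jax.numpy as jnp
    from jax import nn

    x = jnp.array([1.0, 2.0, 3.0])
    mask = jnp.array([True, True, False])

    out = nn.log_softmax(x, where=mask, initial=-jnp.inf)
    # Masked positions are now filled with -inf = log(0):
    # out ≈ [-1.3133, -0.3133, -inf]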
@@ -357,7 +360,10 @@ def _softmax(
     initial: Optional[Array] = None) -> Array:
   x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
   unnormalized = jnp.exp(x - x_max)
-  return unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
+  result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
+  if where is not None:
+    result = jnp.where(where, result, 0)
+  return result
 
 @_softmax.defjvp
 def _softmax_jvp(axis, primals, tangents):
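For `softmax` the natural fill value is 0: masked entries carry no probability mass, and the unmasked entries still normalize to 1. A quick illustration, reusing the array from the test further down:

    import jax.numpy as jnp
    from jax import nn

    x = jnp.array([5.5, 1.3, -4.2, 0.9])
    m = jnp.array([True, False, True, True])

    probs = nn.softmax(x, where=m, initial=-jnp.inf)
    # probs[1] == 0.0 exactly, and probs.sum() == 1.0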
@@ -368,7 +374,10 @@ def _softmax_jvp(axis, primals, tangents):
 def _softmax_deprecated(x, axis, where, initial):
   x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
   unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
-  return unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
+  result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
+  if where is not None:
+    result = jnp.where(where, result, 0)
+  return result
 
 
 @partial(jax.jit, static_argnames=("axis",))
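The same zero fill is mirrored in `_softmax_deprecated`, which, judging by its name and the `@_softmax.defjvp` context above, is the legacy path without the custom JVP, so both implementations agree. Because the fill is applied with `jnp.where`, differentiation through a masked softmax stays well defined; a hedged sketch reusing `x` and `m` from above:

    import jax
    import jax.numpy as jnp
    from jax import nn

    x = jnp.array([5.5, 1.3, -4.2, 0.9])
    m = jnp.array([True, False, True, True])

    # The masked lane's output is pinned to 0, so its gradient is zero too.
    g = jax.grad(lambda v: nn.softmax(v, where=m, initial=-jnp.inf)[0])(x)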
@@ -133,13 +133,12 @@ class NNFunctionsTest(jtu.JaxTestCase):
   def testSoftmaxWhereMask(self, fn):
     x = jnp.array([5.5, 1.3, -4.2, 0.9])
     m = jnp.array([True, False, True, True])
-    x_filtered = jnp.take(x, jnp.array([0, 2, 3]))
 
-    out_masked = jnp.take(
-        fn(x, where=m, initial=-jnp.inf), jnp.array([0, 2, 3]))
-    out_filtered = fn(x_filtered)
+    out = fn(x, where=m, initial=-jnp.inf)
+    self.assertAllClose(out[m], fn(x[m]))
 
-    self.assertAllClose(out_masked, out_filtered)
+    probs = out if fn is nn.softmax else jnp.exp(out)
+    self.assertAllClose(probs.sum(), 1.0)
 
   # TODO(mattjj): include log_softmax in these extra tests if/when we add a
   # custom_jvp rule for it (since otherwise it doesn't pass the numerical
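The rewritten test compares the masked output against the softmax of the boolean-indexed input and additionally asserts that the probabilities sum to exactly 1. Assuming the test lives in tests/nn_test.py as in upstream JAX, it can be run with, e.g.:

    pytest tests/nn_test.py -k SoftmaxWhereMask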