Rename jax.nn.normalize to standardize. Add normalize alias with DeprecationWarning.

dogeplusplus 2022-03-23 20:39:39 +00:00
parent c3581a2218
commit 7915c6ce27
5 changed files with 20 additions and 8 deletions

CHANGELOG.md

@@ -13,8 +13,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
 commits](https://github.com/google/jax/compare/jax-v0.3.4...main).
 * Changes:
   * added {func}`jax.random.loggamma` & improved behavior of {func}`jax.random.beta`
-    and {func}`jax.random.dirichlet` for small parameter values `({jax-issue}`9906`).
+    and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`).
 * Deprecations:
+  * {func}`jax.nn.normalize` is being deprecated. Use {func}`jax.nn.standardize` instead ({jax-issue}`#9899`).
 
 ## jaxlib 0.3.3 (Unreleased)
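For downstream code, the migration described in this changelog entry is a one-line rename. A minimal sketch (illustrative only, not part of the commit):

import jax
import jax.numpy as jnp

x = jnp.array([5.5, 1.3, -4.2, 0.9])

# Before (still works during the deprecation window, but warns):
# y = jax.nn.normalize(x)

# After:
y = jax.nn.standardize(x)  # zero mean, unit variance along the last axis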

jax/_src/nn/functions.py

@@ -16,6 +16,7 @@
 from functools import partial
 import operator
+import warnings
 
 import numpy as np
 from typing import Any, Optional, Tuple, Union
@@ -333,7 +334,7 @@ def softmax(x: Array,
   return unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
 
 @partial(jax.jit, static_argnames=("axis",))
-def normalize(x: Array,
+def standardize(x: Array,
               axis: Optional[Union[int, Tuple[int, ...]]] = -1,
               mean: Optional[Array] = None,
               variance: Optional[Array] = None,
@@ -351,6 +352,15 @@ def normalize(x: Array,
       jnp.square(x), axis, keepdims=True, where=where) - jnp.square(mean)
   return (x - mean) * lax.rsqrt(variance + epsilon)
 
+def normalize(x: Array,
+              axis: Optional[Union[int, Tuple[int, ...]]] = -1,
+              mean: Optional[Array] = None,
+              variance: Optional[Array] = None,
+              epsilon: Array = 1e-5,
+              where: Optional[Array] = None) -> Array:
+  r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`."""
+  warnings.warn("jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.", DeprecationWarning)
+  return standardize(x, axis, mean, variance, epsilon, where)
 
 @partial(jax.jit, static_argnames=("num_classes", "dtype", "axis"))
 def _one_hot(x: Array, num_classes: int, *,
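A quick sketch of what the alias buys callers during the deprecation window: the old name computes the same values as the new one but emits a DeprecationWarning first. This check is illustrative, not part of the commit:

import warnings

import jax
import jax.numpy as jnp

x = jnp.array([5.5, 1.3, -4.2, 0.9])

with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  old = jax.nn.normalize(x)  # delegates to standardize, warning first
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

new = jax.nn.standardize(x)
assert jnp.allclose(old, new)  # alias and new name agree exactly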

jax/example_libraries/stax.py

@@ -34,7 +34,7 @@ from jax import random
 import jax.numpy as jnp
 from jax.nn import (relu, log_softmax, softmax, softplus, sigmoid, elu,
-                    leaky_relu, selu, gelu, normalize)
+                    leaky_relu, selu, gelu, standardize)
 from jax.nn.initializers import glorot_normal, normal, ones, zeros
 
 # aliases for backwards compatibility
@@ -137,7 +137,7 @@ def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
     # TODO(phawkins): jnp.expand_dims should accept an axis tuple.
     # (https://github.com/numpy/numpy/issues/12290)
     ed = tuple(None if i in axis else slice(None) for i in range(jnp.ndim(x)))
-    z = normalize(x, axis, epsilon=epsilon)
+    z = standardize(x, axis, epsilon=epsilon)
     if center and scale: return gamma[ed] * z + beta[ed]
     if center: return z + beta[ed]
     if scale: return gamma[ed] * z
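BatchNorm passes an axis tuple, so the renamed function must standardize jointly over several axes while keeping per-channel statistics separate. A standalone sketch of that behavior (the NHWC shape is an assumption for illustration):

import jax.numpy as jnp
from jax import nn, random

# Standardize an NHWC batch over batch and spatial axes, as BatchNorm does,
# leaving the channel axis (3) with its own statistics.
x = random.normal(random.PRNGKey(0), (8, 32, 32, 3))
z = nn.standardize(x, axis=(0, 1, 2))

print(jnp.mean(z, axis=(0, 1, 2)))  # approximately [0., 0., 0.]
print(jnp.var(z, axis=(0, 1, 2)))   # approximately [1., 1., 1.]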

jax/nn/__init__.py

@@ -32,6 +32,7 @@ from jax._src.nn.functions import (
   log_softmax as log_softmax,
   logsumexp as logsumexp,
   normalize as normalize,
+  standardize as standardize,
   one_hot as one_hot,
   relu as relu,
   relu6 as relu6,

tests/nn_test.py

@@ -126,13 +126,13 @@ class NNFunctionsTest(jtu.JaxTestCase):
     self.assertAllClose(out_masked, out_filtered)
 
-  def testNormalizeWhereMask(self):
+  def testStandardizeWhereMask(self):
     x = jnp.array([5.5, 1.3, -4.2, 0.9])
     m = jnp.array([True, False, True, True])
 
     x_filtered = jnp.take(x, jnp.array([0, 2, 3]))
-    out_masked = jnp.take(nn.normalize(x, where=m), jnp.array([0, 2, 3]))
-    out_filtered = nn.normalize(x_filtered)
+    out_masked = jnp.take(nn.standardize(x, where=m), jnp.array([0, 2, 3]))
+    out_filtered = nn.standardize(x_filtered)
 
     self.assertAllClose(out_masked, out_filtered)
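The same where-mask semantics the test checks, as a standalone sketch outside the test harness: statistics are computed only over unmasked elements, so the unmasked outputs match standardizing the filtered array directly.

import jax.numpy as jnp
from jax import nn

x = jnp.array([5.5, 1.3, -4.2, 0.9])
mask = jnp.array([True, False, True, True])
keep = jnp.array([0, 2, 3])

out_masked = jnp.take(nn.standardize(x, where=mask), keep)
out_filtered = nn.standardize(jnp.take(x, keep))
print(jnp.allclose(out_masked, out_filtered))  # True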