mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Rename jax.nn.normalize to standardize. Add normalize alias with DeprecationWarning.
This commit is contained in:
parent
c3581a2218
commit
7915c6ce27
@ -13,8 +13,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
commits](https://github.com/google/jax/compare/jax-v0.3.4...main).
|
||||
* Changes:
|
||||
* added {func}`jax.random.loggamma` & improved behavior of {func}`jax.random.beta`
|
||||
and {func}`jax.random.dirichlet` for small parameter values `({jax-issue}`9906`).
|
||||
|
||||
and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`).
|
||||
* Deprecations:
|
||||
* {func}`jax.nn.normalize` is being deprecated. Use {func}`jax.nn.standardize` instead ({jax-issue}`#9899`).
|
||||
|
||||
## jaxlib 0.3.3 (Unreleased)
|
||||
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
from functools import partial
|
||||
import operator
|
||||
import warnings
|
||||
import numpy as np
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
@ -333,7 +334,7 @@ def softmax(x: Array,
|
||||
return unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
|
||||
|
||||
@partial(jax.jit, static_argnames=("axis",))
|
||||
def normalize(x: Array,
|
||||
def standardize(x: Array,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = -1,
|
||||
mean: Optional[Array] = None,
|
||||
variance: Optional[Array] = None,
|
||||
@ -351,6 +352,15 @@ def normalize(x: Array,
|
||||
jnp.square(x), axis, keepdims=True, where=where) - jnp.square(mean)
|
||||
return (x - mean) * lax.rsqrt(variance + epsilon)
|
||||
|
||||
def normalize(x: Array,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = -1,
|
||||
mean: Optional[Array] = None,
|
||||
variance: Optional[Array] = None,
|
||||
epsilon: Array = 1e-5,
|
||||
where: Optional[Array] = None) -> Array:
|
||||
r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`."""
|
||||
warnings.warn("jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.", DeprecationWarning)
|
||||
return standardize(x, axis, mean, variance, epsilon, where)
|
||||
|
||||
@partial(jax.jit, static_argnames=("num_classes", "dtype", "axis"))
|
||||
def _one_hot(x: Array, num_classes: int, *,
|
||||
|
@ -34,7 +34,7 @@ from jax import random
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax.nn import (relu, log_softmax, softmax, softplus, sigmoid, elu,
|
||||
leaky_relu, selu, gelu, normalize)
|
||||
leaky_relu, selu, gelu, standardize)
|
||||
from jax.nn.initializers import glorot_normal, normal, ones, zeros
|
||||
|
||||
# aliases for backwards compatibility
|
||||
@ -137,7 +137,7 @@ def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
|
||||
# TODO(phawkins): jnp.expand_dims should accept an axis tuple.
|
||||
# (https://github.com/numpy/numpy/issues/12290)
|
||||
ed = tuple(None if i in axis else slice(None) for i in range(jnp.ndim(x)))
|
||||
z = normalize(x, axis, epsilon=epsilon)
|
||||
z = standardize(x, axis, epsilon=epsilon)
|
||||
if center and scale: return gamma[ed] * z + beta[ed]
|
||||
if center: return z + beta[ed]
|
||||
if scale: return gamma[ed] * z
|
||||
|
@ -32,6 +32,7 @@ from jax._src.nn.functions import (
|
||||
log_softmax as log_softmax,
|
||||
logsumexp as logsumexp,
|
||||
normalize as normalize,
|
||||
standardize as standardize,
|
||||
one_hot as one_hot,
|
||||
relu as relu,
|
||||
relu6 as relu6,
|
||||
|
@ -126,13 +126,13 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(out_masked, out_filtered)
|
||||
|
||||
def testNormalizeWhereMask(self):
|
||||
def testStandardizeWhereMask(self):
|
||||
x = jnp.array([5.5, 1.3, -4.2, 0.9])
|
||||
m = jnp.array([True, False, True, True])
|
||||
x_filtered = jnp.take(x, jnp.array([0, 2, 3]))
|
||||
|
||||
out_masked = jnp.take(nn.normalize(x, where=m), jnp.array([0, 2, 3]))
|
||||
out_filtered = nn.normalize(x_filtered)
|
||||
out_masked = jnp.take(nn.standardize(x, where=m), jnp.array([0, 2, 3]))
|
||||
out_filtered = nn.standardize(x_filtered)
|
||||
|
||||
self.assertAllClose(out_masked, out_filtered)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user