Merge pull request #9052 from jpuigcerver:main

PiperOrigin-RevId: 430680329
This commit is contained in:
jax authors 2022-02-24 05:37:02 -08:00
commit 3948fde842
2 changed files with 7 additions and 0 deletions

View File

@ -32,6 +32,12 @@ from jax import dtypes
def zeros(key, shape, dtype=jnp.float_): return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))
def ones(key, shape, dtype=jnp.float_): return jnp.ones(shape, dtypes.canonicalize_dtype(dtype))
def constant(value, dtype=jnp.float_):
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)
return jnp.full(shape, value, dtype=dtype)
return init
def uniform(scale=1e-2, dtype=jnp.float_):
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)

View File

@ -19,6 +19,7 @@ used in Keras and Sonnet.
# flake8: noqa: F401
from jax._src.nn.initializers import (
constant as constant,
delta_orthogonal as delta_orthogonal,
glorot_normal as glorot_normal,
glorot_uniform as glorot_uniform,