mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #9052 from jpuigcerver:main
PiperOrigin-RevId: 430680329
This commit is contained in:
commit
3948fde842
@ -32,6 +32,12 @@ from jax import dtypes
|
||||
def zeros(key, shape, dtype=jnp.float_): return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))
|
||||
def ones(key, shape, dtype=jnp.float_): return jnp.ones(shape, dtypes.canonicalize_dtype(dtype))
|
||||
|
||||
def constant(value, dtype=jnp.float_):
|
||||
def init(key, shape, dtype=dtype):
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
return jnp.full(shape, value, dtype=dtype)
|
||||
return init
|
||||
|
||||
def uniform(scale=1e-2, dtype=jnp.float_):
|
||||
def init(key, shape, dtype=dtype):
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
|
@ -19,6 +19,7 @@ used in Keras and Sonnet.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from jax._src.nn.initializers import (
|
||||
constant as constant,
|
||||
delta_orthogonal as delta_orthogonal,
|
||||
glorot_normal as glorot_normal,
|
||||
glorot_uniform as glorot_uniform,
|
||||
|
Loading…
x
Reference in New Issue
Block a user