mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Skip doctests for initializer examples.
This commit is contained in:
parent
dc2ca18dc8
commit
ad5144f59e
@ -87,7 +87,7 @@ def uniform(scale=1e-2, dtype: DType = jnp.float_) -> Callable:
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.uniform(10.0)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
DeviceArray([[7.298188 , 8.691938 , 8.7230015],
|
||||
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
|
||||
"""
|
||||
@ -109,7 +109,7 @@ def normal(stddev=1e-2, dtype: DType = jnp.float_) -> Callable:
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.normal(5.0)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
DeviceArray([[ 3.0613258 , 5.6129413 , 5.6866574 ],
|
||||
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
|
||||
"""
|
||||
@ -271,7 +271,7 @@ def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.glorot_uniform()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
DeviceArray([[ 0.50350785, 0.8088631 , 0.81566876],
|
||||
[-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
|
||||
|
||||
@ -309,7 +309,7 @@ def glorot_normal(in_axis: Union[int, Sequence[int]] = -2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.glorot_normal()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
DeviceArray([[ 0.41770416, 0.75262755, 0.7619329 ],
|
||||
[-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
|
||||
|
||||
@ -346,7 +346,7 @@ def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.lecun_uniform()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
DeviceArray([[ 0.56293887, 0.90433645, 0.9119454 ],
|
||||
[-0.71479625, -0.7676109 , 0.12302713]], dtype=float32)
|
||||
|
||||
@ -381,7 +381,7 @@ def lecun_normal(in_axis: Union[int, Sequence[int]] = -2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.lecun_normal()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
DeviceArray([[ 0.46700746, 0.8414632 , 0.8518669 ],
|
||||
[-0.61677957, -0.67402434, 0.09683388]], dtype=float32)
|
||||
|
||||
@ -417,7 +417,7 @@ def he_uniform(in_axis: Union[int, Sequence[int]] = -2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.kaiming_uniform()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
DeviceArray([[ 0.79611576, 1.2789248 , 1.2896855 ],
|
||||
[-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
|
||||
|
||||
@ -455,7 +455,7 @@ def he_normal(in_axis: Union[int, Sequence[int]] = -2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.kaiming_normal()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
DeviceArray([[ 0.6604483 , 1.1900088 , 1.2047218 ],
|
||||
[-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user