mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 13:56:07 +00:00
update documentation notes to new-style typed keys
This commit is contained in:
parent
98f790f5d5
commit
0bdbe763aa
@ -304,7 +304,7 @@ per_core_batch_size=4
|
||||
seq_len=512
|
||||
emb_dim=512
|
||||
x = jax.random.normal(
|
||||
jax.random.PRNGKey(0),
|
||||
jax.random.key(0),
|
||||
shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
|
||||
dtype=jnp.bfloat16,
|
||||
)
|
||||
@ -1049,7 +1049,7 @@ per_core_batch_size=4
|
||||
seq_len=512
|
||||
emb_dim=512
|
||||
x = jax.random.normal(
|
||||
jax.random.PRNGKey(0),
|
||||
jax.random.key(0),
|
||||
shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
|
||||
dtype=jnp.bfloat16,
|
||||
)
|
||||
|
@ -9,7 +9,7 @@ program:
|
||||
>>> import numpy as np
|
||||
>>> import jax.numpy as jnp
|
||||
>>> from jax import random
|
||||
>>> x = random.uniform(random.PRNGKey(0), (1000, 1000))
|
||||
>>> x = random.uniform(random.key(0), (1000, 1000))
|
||||
>>> # Printing the result (i.e. evaluating `repr(result)` or `str(result)`)
|
||||
>>> # will block until the value is ready.
|
||||
>>> jnp.dot(x, x) + 3. # doctest: +SKIP
|
||||
|
@ -59,7 +59,7 @@ def func2(x):
|
||||
y = func1(x)
|
||||
return y, jnp.tile(x, 10) + 1
|
||||
|
||||
x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))
|
||||
x = jax.random.normal(jax.random.key(42), (1000, 1000))
|
||||
y, z = func2(x)
|
||||
|
||||
z.block_until_ready()
|
||||
@ -107,14 +107,14 @@ import jax.numpy as jnp
|
||||
import jax.profiler
|
||||
|
||||
def afunction():
|
||||
return jax.random.normal(jax.random.PRNGKey(77), (1000000,))
|
||||
return jax.random.normal(jax.random.key(77), (1000000,))
|
||||
|
||||
z = afunction()
|
||||
|
||||
def anotherfunc():
|
||||
arrays = []
|
||||
for i in range(1, 10):
|
||||
x = jax.random.normal(jax.random.PRNGKey(42), (i, 10000))
|
||||
x = jax.random.normal(jax.random.key(42), (i, 10000))
|
||||
arrays.append(x)
|
||||
x.block_until_ready()
|
||||
jax.profiler.save_device_memory_profile(f"memory{i}.prof")
|
||||
|
@ -11,7 +11,7 @@ check out the Tensorboard profiler below.
|
||||
```python
|
||||
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
|
||||
# Run the operations to be profiled
|
||||
key = jax.random.PRNGKey(0)
|
||||
key = jax.random.key(0)
|
||||
x = jax.random.normal(key, (5000, 5000))
|
||||
y = x @ x
|
||||
y.block_until_ready()
|
||||
@ -107,7 +107,7 @@ import jax
|
||||
jax.profiler.start_trace("/tmp/tensorboard")
|
||||
|
||||
# Run the operations to be profiled
|
||||
key = jax.random.PRNGKey(0)
|
||||
key = jax.random.key(0)
|
||||
x = jax.random.normal(key, (5000, 5000))
|
||||
y = x @ x
|
||||
y.block_until_ready()
|
||||
@ -126,7 +126,7 @@ alternative to `start_trace` and `stop_trace`:
|
||||
import jax
|
||||
|
||||
with jax.profiler.trace("/tmp/tensorboard"):
|
||||
key = jax.random.PRNGKey(0)
|
||||
key = jax.random.key(0)
|
||||
x = jax.random.normal(key, (5000, 5000))
|
||||
y = x @ x
|
||||
y.block_until_ready()
|
||||
|
Loading…
x
Reference in New Issue
Block a user