1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-20 13:56:07 +00:00

update documentation notes to new-style typed keys

This commit is contained in:
Roy Frostig 2023-08-14 16:31:17 -07:00
parent 98f790f5d5
commit 0bdbe763aa
4 changed files with 9 additions and 9 deletions

@ -304,7 +304,7 @@ per_core_batch_size=4
seq_len=512
emb_dim=512
x = jax.random.normal(
jax.random.PRNGKey(0),
jax.random.key(0),
shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
dtype=jnp.bfloat16,
)
@ -1049,7 +1049,7 @@ per_core_batch_size=4
seq_len=512
emb_dim=512
x = jax.random.normal(
jax.random.PRNGKey(0),
jax.random.key(0),
shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
dtype=jnp.bfloat16,
)

@ -9,7 +9,7 @@ program:
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jax import random
>>> x = random.uniform(random.PRNGKey(0), (1000, 1000))
>>> x = random.uniform(random.key(0), (1000, 1000))
>>> # Printing the result (i.e. evaluating `repr(result)` or `str(result)`)
>>> # will block until the value is ready.
>>> jnp.dot(x, x) + 3. # doctest: +SKIP

@ -59,7 +59,7 @@ def func2(x):
y = func1(x)
return y, jnp.tile(x, 10) + 1
x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))
x = jax.random.normal(jax.random.key(42), (1000, 1000))
y, z = func2(x)
z.block_until_ready()
@ -107,14 +107,14 @@ import jax.numpy as jnp
import jax.profiler
def afunction():
return jax.random.normal(jax.random.PRNGKey(77), (1000000,))
return jax.random.normal(jax.random.key(77), (1000000,))
z = afunction()
def anotherfunc():
arrays = []
for i in range(1, 10):
x = jax.random.normal(jax.random.PRNGKey(42), (i, 10000))
x = jax.random.normal(jax.random.key(42), (i, 10000))
arrays.append(x)
x.block_until_ready()
jax.profiler.save_device_memory_profile(f"memory{i}.prof")

@ -11,7 +11,7 @@ check out the Tensorboard profiler below.
```python
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
# Run the operations to be profiled
key = jax.random.PRNGKey(0)
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
@ -107,7 +107,7 @@ import jax
jax.profiler.start_trace("/tmp/tensorboard")
# Run the operations to be profiled
key = jax.random.PRNGKey(0)
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
@ -126,7 +126,7 @@ alternative to `start_trace` and `stop_trace`:
import jax
with jax.profiler.trace("/tmp/tensorboard"):
key = jax.random.PRNGKey(0)
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()