mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
A couple of typo/gap fixes in PRNG design notes
parent b7ec636cfa
commit 11e6f49c3d
@@ -29,7 +29,8 @@ def foo(): return bar() + baz()
 def bar(): return rand(RNG, (3, 4))
 def baz(): return rand(RNG, (3, 4))
 def main():
-  global RNG = RandomState(0)
+  global RNG
+  RNG = RandomState(0)
   return foo()
 ```
 
@@ -80,7 +81,6 @@ In short, making the code functional by explicitly threading state isn’t enoug
 The key problem in both the previous models is that there’s too much sequencing. To reduce the amount of sequential dependence we use **functional [splittable](https://dl.acm.org/citation.cfm?id=2503784) PRNGs**. Splitting is a mechanism to ‘fork’ a new PRNG state into two PRNG states while maintaining the usual desirable PRNG properties (the two new streams are computationally parallelizable and produce independent random values, i.e. they behave like [multistreams](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf)).
 
 ```python
-
 def foo(rng_1):
   rng_2, rng_3 = split(rng_1, 2)
   return bar(rng_2) + baz(rng_3)
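Review note on the hunk above: the splitting idea in the context paragraph can be tried out with a small, self-contained sketch. Everything here is an assumption made for illustration: `keyed_hash` uses SHA-256 as a stand-in hash (the design note itself names Threefry), and `split`, `rand`, `bar` and `baz` are toy versions that only mirror the names used in the note. The point is that `foo` hands each callee its own key, so the two calls no longer depend on a shared, sequentially updated state.

```python
import hashlib

def keyed_hash(key, counter):
    """Toy keyed hash (assumption: SHA-256 stands in for Threefry)."""
    data = key.to_bytes(32, "big") + counter.to_bytes(32, "big")
    return int.from_bytes(hashlib.sha256(data).digest(), "big")

def split(key, n=2):
    """Derive n child keys from key; the parent key should not be reused afterwards."""
    return [keyed_hash(key, i) for i in range(1, n + 1)]

def rand(key, shape):
    """Fill a rows x cols nested list with floats in [0, 1), one hash per element."""
    rows, cols = shape
    return [
        # Keep the top 53 bits so the quotient is an exact float below 1.0.
        [(keyed_hash(key, r * cols + c + 1) >> (256 - 53)) / 2**53 for c in range(cols)]
        for r in range(rows)
    ]

def bar(key):
    return rand(key, (3, 4))

def baz(key):
    return rand(key, (3, 4))

def foo(key):
    key_bar, key_baz = split(key, 2)
    a, b = bar(key_bar), baz(key_baz)  # independent calls: order does not matter
    return [[x + y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]

result = foo(0)
```

Because `bar` and `baz` receive different keys, they could be evaluated in either order (or in parallel) and `foo(0)` would still return the same values.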
@@ -104,8 +104,7 @@ The example doesn’t show it, but as a consequence of the choice (2) the only w
 
 ## Design
 
-We can use the *counter-based PRNG* design, and in particular the Threefry hash function, as described in [Parallel random numbers: as easy as 1, 2, 3](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf). We use the counter to achieve efficient vectorization: for a given key we can generate an array of values in a vectorized fashion by mapping the hash function over a range of integers [k + 1, …, k + sample_size]. We use the key together with the hash function to implement [
-](https://dl.acm.org/citation.cfm?id=2503784): that is, splitting is a way to generate two new keys from an existing one.
+We can use the *counter-based PRNG* design, and in particular the Threefry hash function, as described in [Parallel random numbers: as easy as 1, 2, 3](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf). We use the counter to achieve efficient vectorization: for a given key we can generate an array of values in a vectorized fashion by mapping the hash function over a range of integers [k + 1, …, k + sample_size]. We use the key together with the hash function to implement [splittable PRNGs](https://dl.acm.org/citation.cfm?id=2503784): that is, splitting is a way to generate two new keys from an existing one.
 
 ```haskell
 type Sample = Int256
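Review note on the hunk above: the counter-based scheme described in the changed paragraph (mapping the hash over the integers k + 1, …, k + sample_size) can be sketched in a few lines. This is only an illustration under stated assumptions: `keyed_hash` uses SHA-256 rather than Threefry, `draw_samples` is a hypothetical helper name, and treating `k` as a caller-supplied offset is one reading of the paragraph. The 256-bit output loosely matches the `Int256` sample type in the Haskell snippet.

```python
import hashlib

def keyed_hash(key, counter):
    """Toy keyed hash (assumption: SHA-256 stands in for Threefry)."""
    data = key.to_bytes(32, "big") + counter.to_bytes(32, "big")
    return int.from_bytes(hashlib.sha256(data).digest(), "big")

def draw_samples(key, sample_size, k=0):
    """Map the hash over the counters k + 1, ..., k + sample_size."""
    return [keyed_hash(key, c) for c in range(k + 1, k + sample_size + 1)]

key = 42
samples = draw_samples(key, 5)

# Each sample depends only on (key, counter), so any slice can be produced
# on its own, without first generating the samples that precede it.
tail = draw_samples(key, 3, k=2)
```

Since no state is carried from one sample to the next, the map over counters is what makes the generation vectorizable.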
@@ -159,7 +158,7 @@ The rng value here is just the key used for the hash, not a special object. The
 def serial(*layers):
   init_funs, apply_funs = zip(*layers)
   def init_fun(input_shape):
-    # …
+    ...
   def apply_fun(rng, params, inputs):
     rngs = split(rng, len(layers))
     for rng, param, apply_fun in zip(rngs, params, apply_funs):
@@ -170,7 +169,7 @@ def serial(*layers):
 def parallel(*layers):
   init_funs, apply_funs = zip(*layers)
   def init_fun(input_shape):
-    ...
+    ...
   def apply_fun(rng, params, inputs):
     rngs = split(rng, len(layers))
     return [f(r, p, x) for f, r, p, x in zip(apply_funs, rngs, params, inputs)]
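Review note on the last two hunks: both combinators follow the same pattern on the apply side, splitting the incoming key once and giving each layer its own child key. The sketch below isolates that pattern. It is not the code from the design note: `serial_apply` and `parallel_apply` are hypothetical names, `split` is a toy SHA-256 stand-in, `noisy_scale` is a made-up layer, and the elided `init_fun` bodies are deliberately left out.

```python
import hashlib

def split(key, n):
    """Toy stand-in for split: derive n child keys by hashing (key, index)."""
    return [
        int.from_bytes(hashlib.sha256(f"{key}:{i}".encode()).digest()[:8], "big")
        for i in range(n)
    ]

def noisy_scale(rng, param, x):
    """Hypothetical layer: scale x by param and add key-derived noise in [0, 1)."""
    noise = (rng % 1000) / 1000
    return param * x + noise

def serial_apply(apply_funs, rng, params, inputs):
    """Chain layers, feeding each output to the next; one child key per layer."""
    rngs = split(rng, len(apply_funs))
    for layer_rng, param, apply_fun in zip(rngs, params, apply_funs):
        inputs = apply_fun(layer_rng, param, inputs)
    return inputs

def parallel_apply(apply_funs, rng, params, inputs):
    """Apply layers side by side to separate inputs; one child key per layer."""
    rngs = split(rng, len(apply_funs))
    return [f(r, p, x) for f, r, p, x in zip(apply_funs, rngs, params, inputs)]

layers = [noisy_scale, noisy_scale]
chained = serial_apply(layers, 0, [2.0, 3.0], 1.0)
side_by_side = parallel_apply(layers, 0, [2.0, 3.0], [1.0, 1.0])
```

In both cases the caller passes a single key and never sees the per-layer keys, which keeps the key handling local to the combinator.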