A couple of typo/gap fixes in PRNG design notes

Louis Maddox 2020-10-19 21:35:18 +01:00
parent b7ec636cfa
commit 11e6f49c3d

@@ -29,7 +29,8 @@ def foo(): return bar() + baz()
 def bar(): return rand(RNG, (3, 4))
 def baz(): return rand(RNG, (3, 4))
 def main():
-  global RNG = RandomState(0)
+  global RNG
+  RNG = RandomState(0)
   return foo()
```
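The corrected `main` above can be tried out with a small stand-in. This is only a sketch: `rand` here is a hypothetical helper built on the standard library's `random.Random` rather than NumPy's `RandomState`, and `foo` adds the two results elementwise because plain lists are used instead of arrays.

```python
import random

RNG = None  # module-level PRNG state shared by every function below

def rand(rng, shape):
  # Hypothetical stand-in for the note's rand(): a rows x cols list of floats.
  rows, cols = shape
  return [[rng.random() for _ in range(cols)] for _ in range(rows)]

def bar(): return rand(RNG, (3, 4))
def baz(): return rand(RNG, (3, 4))

def foo():
  # Elementwise sum; bar() and baz() both advance the shared RNG state.
  x, y = bar(), baz()
  return [[a + b for a, b in zip(r, s)] for r, s in zip(x, y)]

def main():
  global RNG  # `global` must be its own statement, separate from the assignment
  RNG = random.Random(0)
  return foo()
```

Because `main` reseeds the shared state on every call, repeated calls return the same values, but the values drawn by `bar` and `baz` depend on the order in which they run.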
@@ -80,7 +81,6 @@ In short, making the code functional by explicitly threading state isn't enoug
The key problem in both the previous models is that there's too much sequencing. To reduce the amount of sequential dependence we use **functional [splittable](https://dl.acm.org/citation.cfm?id=2503784) PRNGs**. Splitting is a mechanism to fork a new PRNG state into two PRNG states while maintaining the usual desirable PRNG properties (the two new streams are computationally parallelizable and produce independent random values, i.e. they behave like [multistreams](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf)).
```python
def foo(rng_1):
  rng_2, rng_3 = split(rng_1, 2)
  return bar(rng_2) + baz(rng_3)
@@ -104,8 +104,7 @@ The example doesn't show it, but as a consequence of the choice (2) the only w
## Design
-We can use the *counter-based PRNG* design, and in particular the Threefry hash function, as described in [Parallel random numbers: as easy as 1, 2, 3](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf). We use the counter to achieve efficient vectorization: for a given key we can generate an array of values in a vectorized fashion by mapping the hash function over a range of integers [k + 1, …, k + sample_size]. We use the key together with the hash function to implement [
-](https://dl.acm.org/citation.cfm?id=2503784): that is, splitting is a way to generate two new keys from an existing one.
+We can use the *counter-based PRNG* design, and in particular the Threefry hash function, as described in [Parallel random numbers: as easy as 1, 2, 3](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf). We use the counter to achieve efficient vectorization: for a given key we can generate an array of values in a vectorized fashion by mapping the hash function over a range of integers [k + 1, …, k + sample_size]. We use the key together with the hash function to implement [splittable PRNGs](https://dl.acm.org/citation.cfm?id=2503784): that is, splitting is a way to generate two new keys from an existing one.
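As a rough illustration of the counter-based idea described here, the sketch below hashes a key together with a counter to produce values, and hashes the key with a different tag to derive child keys. It is not the design's actual implementation: SHA-256 from the standard library stands in for Threefry, and the names `_hash` and `random_bits`, the 64-bit key width, and the domain-separation tags are all assumptions made for this example.

```python
import hashlib

def _hash(key, counter, tag):
  # Stand-in for Threefry (assumption): SHA-256 over (tag, key, counter),
  # truncated to a 64-bit integer.
  data = tag + key.to_bytes(8, "big") + counter.to_bytes(8, "big")
  return int.from_bytes(hashlib.sha256(data).digest()[:8], "big")

def random_bits(key, sample_size):
  # Each value depends only on (key, counter), so this map has no
  # sequential dependence and could be evaluated in parallel.
  return [_hash(key, i, b"bits") for i in range(1, sample_size + 1)]

def split(key, num=2):
  # Child keys are derived by hashing the parent key with distinct counters;
  # the separate tag keeps them from colliding with sampled values.
  return [_hash(key, i, b"split") for i in range(num)]

key = 0
k1, k2 = split(key)
a = random_bits(k1, 4)
b = random_bits(k2, 4)
```

Here `a` and `b` are reproducible from `key` alone, and neither computation has to wait for the other, which is the property the splittable design is after.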
```haskell
type Sample = Int256
@@ -159,7 +158,7 @@ The rng value here is just the key used for the hash, not a special object. The
 def serial(*layers):
   init_funs, apply_funs = zip(*layers)
   def init_fun(input_shape):
-    # …
+    ...
   def apply_fun(rng, params, inputs):
     rngs = split(rng, len(layers))
     for rng, param, apply_fun in zip(rngs, params, apply_funs):
@@ -170,7 +169,7 @@ def serial(*layers):
 def parallel(*layers):
   init_funs, apply_funs = zip(*layers)
   def init_fun(input_shape):
-    # …
+    ...
   def apply_fun(rng, params, inputs):
     rngs = split(rng, len(layers))
     return [f(r, p, x) for f, r, p, x in zip(apply_funs, rngs, params, inputs)]