1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-20 05:46:06 +00:00

improve Parallel daxpr construct

This commit is contained in:
Matthew Johnson 2019-03-29 12:01:48 -07:00
parent 2f171192b9
commit 0fb88ee934

@ -1,7 +1,8 @@
from collections import namedtuple
from functools import partial
from itertools import count
from itertools import count, islice
from threading import Thread
from time import sleep, time
from Queue import Queue
import numpy as onp
@ -11,7 +12,7 @@ from jax.util import unzip2
# utils
new_queue = partial(Queue, maxsize=10)
new_queue = partial(Queue, maxsize=3)
def spawn(fun):
thread = Thread(target=fun)
@ -133,8 +134,6 @@ def type_daxpr(daxpr):
daxpr_types = map(type_daxpr, daxpr.daxprs)
ins, outs = unzip2((dt.arity_in, dt.arity_out) for dt in daxpr_types)
if s is Parallel:
if not all(ain == aout == 1 for ain, aout in zip(ins, outs)):
raise DaxprTypeError("parallel combination of multi-arity daxprs")
return DaxprType(sum(ins), sum(outs))
elif s is Pipeline:
if ins[1:] != outs[:-1]:
@ -188,8 +187,9 @@ def _build_threaded(daxpr, qs_in, qs_out):
daxpr_types = map(type_daxpr, daxpr.daxprs)
ins, outs = unzip2((dt.arity_in, dt.arity_out) for dt in daxpr_types)
if s is Parallel:
for d, q_in, q_out in zip(daxpr.daxprs, qs_in, qs_out):
_build_threaded(d, [q_in], [q_out])
qs_in, qs_out = iter(qs_in), iter(qs_out)
for d, ain, aout in zip(daxpr.daxprs, ins, outs):
_build_threaded(d, islice(qs_in, ain), islice(qs_out, aout))
elif s is Pipeline:
qs = [[new_queue() for _ in range(arity)] for arity in ins[1:]]
map(_build_threaded, daxpr.daxprs, [qs_in] + qs, qs + [qs_out])
@ -234,3 +234,22 @@ if __name__ == '__main__':
for _ in range(10):
print q.get()
print
# an example that models the situation we care about
num_parallel = 8
daxpr = Pipeline([
Parallel([Source(partial(sleep, 1e-2)) for _ in range(num_parallel)]),
FanInConcatenate(num_parallel),
ProducerConsumer(lambda x: sleep(1e-2)),
])
q = build_threaded(daxpr)
tic = time()
for i in range(100):
q.get()
toc = time()
print 'finished in {} sec'.format(toc - tic)