mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 05:46:06 +00:00
improve Parallel daxpr construct
This commit is contained in:
parent
2f171192b9
commit
0fb88ee934
@ -1,7 +1,8 @@
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
from itertools import count
|
||||
from itertools import count, islice
|
||||
from threading import Thread
|
||||
from time import sleep, time
|
||||
from Queue import Queue
|
||||
|
||||
import numpy as onp
|
||||
@ -11,7 +12,7 @@ from jax.util import unzip2
|
||||
|
||||
# utils
|
||||
|
||||
new_queue = partial(Queue, maxsize=10)
|
||||
new_queue = partial(Queue, maxsize=3)
|
||||
|
||||
def spawn(fun):
|
||||
thread = Thread(target=fun)
|
||||
@ -133,8 +134,6 @@ def type_daxpr(daxpr):
|
||||
daxpr_types = map(type_daxpr, daxpr.daxprs)
|
||||
ins, outs = unzip2((dt.arity_in, dt.arity_out) for dt in daxpr_types)
|
||||
if s is Parallel:
|
||||
if not all(ain == aout == 1 for ain, aout in zip(ins, outs)):
|
||||
raise DaxprTypeError("parallel combination of multi-arity daxprs")
|
||||
return DaxprType(sum(ins), sum(outs))
|
||||
elif s is Pipeline:
|
||||
if ins[1:] != outs[:-1]:
|
||||
@ -188,8 +187,9 @@ def _build_threaded(daxpr, qs_in, qs_out):
|
||||
daxpr_types = map(type_daxpr, daxpr.daxprs)
|
||||
ins, outs = unzip2((dt.arity_in, dt.arity_out) for dt in daxpr_types)
|
||||
if s is Parallel:
|
||||
for d, q_in, q_out in zip(daxpr.daxprs, qs_in, qs_out):
|
||||
_build_threaded(d, [q_in], [q_out])
|
||||
qs_in, qs_out = iter(qs_in), iter(qs_out)
|
||||
for d, ain, aout in zip(daxpr.daxprs, ins, outs):
|
||||
_build_threaded(d, islice(qs_in, ain), islice(qs_out, aout))
|
||||
elif s is Pipeline:
|
||||
qs = [[new_queue() for _ in range(arity)] for arity in ins[1:]]
|
||||
map(_build_threaded, daxpr.daxprs, [qs_in] + qs, qs + [qs_out])
|
||||
@ -234,3 +234,22 @@ if __name__ == '__main__':
|
||||
for _ in range(10):
|
||||
print q.get()
|
||||
print
|
||||
|
||||
|
||||
# an example that models the situation we care about
|
||||
|
||||
num_parallel = 8
|
||||
|
||||
daxpr = Pipeline([
|
||||
Parallel([Source(partial(sleep, 1e-2)) for _ in range(num_parallel)]),
|
||||
FanInConcatenate(num_parallel),
|
||||
ProducerConsumer(lambda x: sleep(1e-2)),
|
||||
])
|
||||
|
||||
q = build_threaded(daxpr)
|
||||
|
||||
tic = time()
|
||||
for i in range(100):
|
||||
q.get()
|
||||
toc = time()
|
||||
print 'finished in {} sec'.format(toc - tic)
|
||||
|
Loading…
x
Reference in New Issue
Block a user