Remove runtime tuple use from JAX. (#2441)

Change in preparation for upcoming runtime changes related to buffer aliasing.
Peter Hawkins 2020-03-17 17:02:22 -04:00 committed by GitHub
parent 985d5f7327
commit e46a002ead
4 changed files with 29 additions and 56 deletions
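At a glance, the dispatch-side pattern this commit removes and the one it introduces look roughly like the sketch below. It is illustrative only: `compiled`, `device_put`, and `make_tuple` stand for the XLA executable and helpers that appear in the diffs, and `dispatch_old`/`dispatch_new` are hypothetical names, not functions in the tree.

def dispatch_old(compiled, backend, device, tuple_args, *args):
  # Old path: optionally pack all argument buffers into a single runtime
  # tuple buffer on the device, execute, then destructure the result tuple.
  input_bufs = [device_put(x, device) for x in args]
  if tuple_args:
    input_bufs = [make_tuple(input_bufs, device, backend)]
  out_tuple = compiled.Execute(input_bufs)
  return out_tuple.destructure()

def dispatch_new(compiled, device, tuple_args, *args):
  # New path: ask the runtime to do the packing via `tuple_arguments`;
  # Execute now returns a flat list of output buffers, so no destructuring.
  input_bufs = [device_put(x, device) for x in args]
  return compiled.Execute(input_bufs, tuple_arguments=tuple_args)

The pmap and sharded_jit paths below make the analogous change around ExecuteOnLocalDevices, which likewise takes a tuple_arguments keyword in the new code.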

View File

@@ -45,7 +45,7 @@ _map = safe_map
 def identity(x): return x
-def shard_args(backend, devices, assignments, axis_size, tuple_args, args):
+def shard_args(backend, devices, assignments, axis_size, args):
 """Shard each argument data array along its leading axis.
 Args:
@@ -96,10 +96,6 @@ def shard_args(backend, devices, assignments, axis_size, tuple_args, args):
 for r, buf in enumerate(bufs):
 buffers[r][a] = buf
-if tuple_args:
-buffers = [[xla.make_tuple(bufs, devices[r], backend)]
-for r, bufs in enumerate(buffers)]
 return buffers
@@ -546,11 +542,12 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,
 handle_args = partial(shard_args, backend, compiled.local_devices(),
 assign_shards_to_replicas(num_local_replicas, axis_size),
-axis_size, tuple_args)
+axis_size)
 handle_outs = _pvals_to_results_handler(axis_size, num_local_replicas,
 out_pvals, compiled.local_devices(),
 backend)
-return partial(execute_replicated, compiled, backend, handle_args, handle_outs)
+return partial(execute_replicated, compiled, backend, handle_args, handle_outs,
+tuple_args)
 multi_host_supported_collectives = set()
@@ -564,7 +561,7 @@ def _pvals_to_results_handler(size, nrep, out_pvals, devices, backend):
 def handler(out_bufs):
 buffers = [[result_to_populate] * nrep for _ in range(nouts)]
 for r, tuple_buf in enumerate(out_bufs):
-for i, buf in enumerate(tuple_buf.destructure()):
+for i, buf in enumerate(tuple_buf):
 buffers[i][r] = buf
 assert not any(buf is result_to_populate for bufs in buffers
 for buf in bufs)
@@ -631,9 +628,11 @@ def _pval_to_result_handler(axis_size, nrep, pval, devices, backend):
 else:
 return aval_to_result_handler(axis_size, nrep, pv)
-def execute_replicated(compiled, backend, in_handler, out_handler, *args):
+def execute_replicated(compiled, backend, in_handler, out_handler, tuple_args,
+*args):
 input_bufs = in_handler(args)
-out_bufs = compiled.ExecuteOnLocalDevices(list(input_bufs))
+out_bufs = compiled.ExecuteOnLocalDevices(
+list(input_bufs), tuple_arguments=tuple_args)
 return out_handler(out_bufs)

View File

@@ -86,7 +86,7 @@ def _pvals_to_results_handler(nrep, npar, partitions, out_pvals):
 buffers = [[[None] * npar for _ in range(nrep)] for _ in range(nouts)]
 for raw_idx, tuple_buf in enumerate(out_bufs):
 r, p = onp.unravel_index(raw_idx, (nrep, npar))
-for i, buf in enumerate(tuple_buf.destructure()):
+for i, buf in enumerate(tuple_buf):
 buffers[i][r][p] = buf
 return [h(bufs) for h, bufs in zip(handlers, buffers)]
@@ -202,7 +202,8 @@ def _sharded_jit_translation_rule(c, axis_env, freevar_nodes,
 def _execute_spatially_partitioned(compiled, in_handler, out_handler, *args):
 input_bufs = in_handler(args)
-out_bufs = compiled.ExecuteOnLocalDevices(list(input_bufs))
+out_bufs = compiled.ExecuteOnLocalDevices(
+list(input_bufs), tuple_arguments=False)
 return out_handler(out_bufs)

View File

@@ -175,7 +175,7 @@ def xla_primitive_callable(prim, *arg_specs, **params):
 handle_result = aval_to_result_handler(device, aval_out)
 else:
 handlers = tuple(map(partial(aval_to_result_handler, device), aval_out))
-handle_result = lambda xs: tuple(h(x) for h, x in zip(handlers, xs.destructure()))
+handle_result = lambda xs: tuple(h(x) for h, x in zip(handlers, xs))
 tuple_args = len(avals) > 100
 if prim in initial_style_translations:
 nreps = initial_style_primitive_replicas(params)
@@ -249,30 +249,23 @@ def _execute_compiled_primitive(prim, compiled, backend, tuple_args,
 result_handler, *args):
 device, = compiled.local_devices()
 input_bufs = [device_put(x, device) for x in args if x is not token]
-if tuple_args:
-input_bufs = [make_tuple(input_bufs, device, backend)]
-out_buf = compiled.Execute(input_bufs)
+out_bufs = compiled.Execute(input_bufs, tuple_arguments=tuple_args)
 if FLAGS.jax_debug_nans:
-check_nans(prim, out_buf.destructure() if prim.multiple_results else out_buf)
-return result_handler(out_buf)
+check_nans(prim, out_bufs)
+return result_handler(out_bufs if prim.multiple_results else out_bufs[0])
 def _execute_replicated_primitive(prim, compiled, backend, tuple_args,
 result_handler, *args):
 input_bufs = [
 [device_put(x, device) for x in args if x is not token]
 for device in compiled.local_devices()]
-if tuple_args:
-input_bufs = [[make_tuple(bufs, device, backend)] for bufs, device in
-zip(input_bufs, compiled.local_devices())]
-out_buf = compiled.ExecuteOnLocalDevices(input_bufs)[0]
+out_buf = compiled.ExecuteOnLocalDevices(
+input_bufs, tuple_arguments=tuple_args)[0][0]
 return result_handler(out_buf)
 def check_nans(prim, bufs):
-if prim.multiple_results:
-for buf in bufs:
-_check_nans(prim.name, buf.shape(), buf)
-else:
-_check_nans(prim.name, bufs.shape(), bufs)
+for buf in bufs:
+_check_nans(prim.name, buf.shape(), buf)
 def _check_nans(name, xla_shape, buf):
 assert not xla_shape.is_tuple()
@@ -564,9 +557,7 @@ def _pval_to_result_handler(device, pval):
 def _execute_compiled(compiled, backend, handlers, tuple_args, *args):
 device, = compiled.local_devices()
 input_bufs = [device_put(x, device) for x in args if x is not token]
-if tuple_args:
-input_bufs = [make_tuple(input_bufs, device, backend)]
-out_bufs = compiled.Execute(input_bufs).destructure()
+out_bufs = compiled.Execute(input_bufs, tuple_arguments=tuple_args)
 if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
 return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)]
@@ -574,10 +565,8 @@ def _execute_replicated(compiled, backend, handlers, tuple_args, *args):
 input_bufs = [
 [device_put(x, device) for x in args if x is not token]
 for device in compiled.local_devices()]
-if tuple_args:
-input_bufs = [[make_tuple(bufs, device, backend)] for bufs, device in
-zip(input_bufs, compiled.local_devices())]
-out_bufs = compiled.ExecuteOnLocalDevices(input_bufs)[0].destructure()
+out_bufs = compiled.ExecuteOnLocalDevices(
+input_bufs, tuple_arguments=tuple_args)[0]
 if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
 return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)]
@@ -590,9 +579,6 @@ def _execute_trivial(jaxpr, device, consts, handlers, *args):
 return [_copy_device_array_to_device(x, device) if type(x) is DeviceArray
 else h(device_put(x, device)) for h, x in zip(handlers, outs)]
-def make_tuple(bufs, device, backend):
-return xb.get_backend(backend).make_tuple(bufs, device)
 @memoize
 def _get_device(device, backend):
 # TODO(mattjj): after jaxlib update, avoid compile here, just to get device
@@ -960,9 +946,11 @@ def _lazy_force_computation(sticky, aval, device, lexpr):
 result_device = device if sticky else None
 handler = partial(DeviceArray, aval, result_device, lazy.array(aval.shape))
 if lazy.is_constant(lexpr):
-force_fun = lambda _: handler(compiled.Execute([]))
+def force_fun(_):
+return handler(compiled.Execute([], tuple_arguments=False)[0])
 else:
-force_fun = lambda x: handler(compiled.Execute([x.device_buffer]))
+def force_fun(x):
+return handler(compiled.Execute([x.device_buffer], tuple_arguments=False)[0])
 return force_fun

View File

@@ -117,27 +117,12 @@ class APITest(jtu.JaxTestCase):
 f(1, 2, z=onp.zeros(3)) # doesn't crash
-def test_jit_many_args_tuples(self):
+def test_jit_with_many_args_works(self):
 @jit
 def f(args_list):
 return sum(args_list)
-make_tuple = xla.make_tuple
-counts = [0]
-def make_tuple_and_count(*args, **kwargs):
-counts[0] += 1
-return make_tuple(*args, **kwargs)
-try:
-xla.make_tuple = make_tuple_and_count
-ans = f(list(range(500)))
-finally:
-xla.make_tuple = make_tuple
-expected = sum(range(500))
-self.assertEqual(counts[0], 1) # formed a tuple on dispatch
-self.assertEqual(ans, expected) # computed the correct result
+self.assertEqual(f(list(range(500))), sum(range(500)))
 def test_grad_of_jit(self):
 side = []