Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Remove runtime tuple use from JAX. (#2441)
Change in preparation for upcoming runtime changes related to buffer aliasing.
This commit is contained in:
parent 985d5f7327
commit e46a002ead
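
What the change means for callers, in a nutshell: instead of JAX packing many arguments into a runtime tuple itself (and destructuring tuple results), the compiled executable now takes a flat list of buffers plus a tuple_arguments flag and returns a flat list of output buffers. The sketch below is illustrative only and is not part of the commit; OldExecutable, NewExecutable, and TupleBuf are hypothetical stand-ins, and only the Execute, destructure, and tuple_arguments names mirror the APIs touched in the diff.

# Sketch of the old calling convention: the caller pre-tuples arguments and
# destructures the single tuple result on the host.
class OldExecutable:
  def Execute(self, input_bufs):
    class TupleBuf:
      def __init__(self, bufs):
        self._bufs = bufs
      def destructure(self):
        return list(self._bufs)
    # Pretend the computation simply returns its inputs as one tuple buffer.
    return TupleBuf(input_bufs)

# Sketch of the new calling convention: argument tupling is a runtime flag and
# outputs come back as a flat list of per-output buffers, so no destructure().
class NewExecutable:
  def Execute(self, input_bufs, tuple_arguments=False):
    return list(input_bufs)

args = list(range(3))
old_out = OldExecutable().Execute(args).destructure()
new_out = NewExecutable().Execute(args, tuple_arguments=len(args) > 100)
assert old_out == new_out == args

As the hunks below show, this lets JAX delete its make_tuple helper and every .destructure() call on the dispatch path.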
@@ -45,7 +45,7 @@ _map = safe_map

def identity(x): return x

-def shard_args(backend, devices, assignments, axis_size, tuple_args, args):
+def shard_args(backend, devices, assignments, axis_size, args):
  """Shard each argument data array along its leading axis.

  Args:
@@ -96,10 +96,6 @@ def shard_args(backend, devices, assignments, axis_size, tuple_args, args):
      for r, buf in enumerate(bufs):
        buffers[r][a] = buf

-  if tuple_args:
-    buffers = [[xla.make_tuple(bufs, devices[r], backend)]
-               for r, bufs in enumerate(buffers)]
-
  return buffers
@@ -546,11 +542,12 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,

  handle_args = partial(shard_args, backend, compiled.local_devices(),
                        assign_shards_to_replicas(num_local_replicas, axis_size),
-                        axis_size, tuple_args)
+                        axis_size)
  handle_outs = _pvals_to_results_handler(axis_size, num_local_replicas,
                                          out_pvals, compiled.local_devices(),
                                          backend)
-  return partial(execute_replicated, compiled, backend, handle_args, handle_outs)
+  return partial(execute_replicated, compiled, backend, handle_args, handle_outs,
+                 tuple_args)

multi_host_supported_collectives = set()
@@ -564,7 +561,7 @@ def _pvals_to_results_handler(size, nrep, out_pvals, devices, backend):
  def handler(out_bufs):
    buffers = [[result_to_populate] * nrep for _ in range(nouts)]
    for r, tuple_buf in enumerate(out_bufs):
-      for i, buf in enumerate(tuple_buf.destructure()):
+      for i, buf in enumerate(tuple_buf):
        buffers[i][r] = buf
    assert not any(buf is result_to_populate for bufs in buffers
                   for buf in bufs)
@@ -631,9 +628,11 @@ def _pval_to_result_handler(axis_size, nrep, pval, devices, backend):
  else:
    return aval_to_result_handler(axis_size, nrep, pv)

-def execute_replicated(compiled, backend, in_handler, out_handler, *args):
+def execute_replicated(compiled, backend, in_handler, out_handler, tuple_args,
+                       *args):
  input_bufs = in_handler(args)
-  out_bufs = compiled.ExecuteOnLocalDevices(list(input_bufs))
+  out_bufs = compiled.ExecuteOnLocalDevices(
+      list(input_bufs), tuple_arguments=tuple_args)
  return out_handler(out_bufs)
@@ -86,7 +86,7 @@ def _pvals_to_results_handler(nrep, npar, partitions, out_pvals):
    buffers = [[[None] * npar for _ in range(nrep)] for _ in range(nouts)]
    for raw_idx, tuple_buf in enumerate(out_bufs):
      r, p = onp.unravel_index(raw_idx, (nrep, npar))
-      for i, buf in enumerate(tuple_buf.destructure()):
+      for i, buf in enumerate(tuple_buf):
        buffers[i][r][p] = buf
    return [h(bufs) for h, bufs in zip(handlers, buffers)]
@@ -202,7 +202,8 @@ def _sharded_jit_translation_rule(c, axis_env, freevar_nodes,

def _execute_spatially_partitioned(compiled, in_handler, out_handler, *args):
  input_bufs = in_handler(args)
-  out_bufs = compiled.ExecuteOnLocalDevices(list(input_bufs))
+  out_bufs = compiled.ExecuteOnLocalDevices(
+      list(input_bufs), tuple_arguments=False)
  return out_handler(out_bufs)
@@ -175,7 +175,7 @@ def xla_primitive_callable(prim, *arg_specs, **params):
    handle_result = aval_to_result_handler(device, aval_out)
  else:
    handlers = tuple(map(partial(aval_to_result_handler, device), aval_out))
-    handle_result = lambda xs: tuple(h(x) for h, x in zip(handlers, xs.destructure()))
+    handle_result = lambda xs: tuple(h(x) for h, x in zip(handlers, xs))
  tuple_args = len(avals) > 100
  if prim in initial_style_translations:
    nreps = initial_style_primitive_replicas(params)
@@ -249,30 +249,23 @@ def _execute_compiled_primitive(prim, compiled, backend, tuple_args,
                                result_handler, *args):
  device, = compiled.local_devices()
  input_bufs = [device_put(x, device) for x in args if x is not token]
-  if tuple_args:
-    input_bufs = [make_tuple(input_bufs, device, backend)]
-  out_buf = compiled.Execute(input_bufs)
+  out_bufs = compiled.Execute(input_bufs, tuple_arguments=tuple_args)
  if FLAGS.jax_debug_nans:
-    check_nans(prim, out_buf.destructure() if prim.multiple_results else out_buf)
-  return result_handler(out_buf)
+    check_nans(prim, out_bufs)
+  return result_handler(out_bufs if prim.multiple_results else out_bufs[0])

def _execute_replicated_primitive(prim, compiled, backend, tuple_args,
                                  result_handler, *args):
  input_bufs = [
      [device_put(x, device) for x in args if x is not token]
      for device in compiled.local_devices()]
-  if tuple_args:
-    input_bufs = [[make_tuple(bufs, device, backend)] for bufs, device in
-                  zip(input_bufs, compiled.local_devices())]
-  out_buf = compiled.ExecuteOnLocalDevices(input_bufs)[0]
+  out_buf = compiled.ExecuteOnLocalDevices(
+      input_bufs, tuple_arguments=tuple_args)[0][0]
  return result_handler(out_buf)

def check_nans(prim, bufs):
-  if prim.multiple_results:
-    for buf in bufs:
-      _check_nans(prim.name, buf.shape(), buf)
-  else:
-    _check_nans(prim.name, bufs.shape(), bufs)
+  for buf in bufs:
+    _check_nans(prim.name, buf.shape(), buf)

def _check_nans(name, xla_shape, buf):
  assert not xla_shape.is_tuple()
@@ -564,9 +557,7 @@ def _pval_to_result_handler(device, pval):

def _execute_compiled(compiled, backend, handlers, tuple_args, *args):
  device, = compiled.local_devices()
  input_bufs = [device_put(x, device) for x in args if x is not token]
-  if tuple_args:
-    input_bufs = [make_tuple(input_bufs, device, backend)]
-  out_bufs = compiled.Execute(input_bufs).destructure()
+  out_bufs = compiled.Execute(input_bufs, tuple_arguments=tuple_args)
  if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
  return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)]
@@ -574,10 +565,8 @@ def _execute_replicated(compiled, backend, handlers, tuple_args, *args):
  input_bufs = [
      [device_put(x, device) for x in args if x is not token]
      for device in compiled.local_devices()]
-  if tuple_args:
-    input_bufs = [[make_tuple(bufs, device, backend)] for bufs, device in
-                  zip(input_bufs, compiled.local_devices())]
-  out_bufs = compiled.ExecuteOnLocalDevices(input_bufs)[0].destructure()
+  out_bufs = compiled.ExecuteOnLocalDevices(
+      input_bufs, tuple_arguments=tuple_args)[0]
  if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
  return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)]
@@ -590,9 +579,6 @@ def _execute_trivial(jaxpr, device, consts, handlers, *args):
  return [_copy_device_array_to_device(x, device) if type(x) is DeviceArray
          else h(device_put(x, device)) for h, x in zip(handlers, outs)]

-def make_tuple(bufs, device, backend):
-  return xb.get_backend(backend).make_tuple(bufs, device)
-
@memoize
def _get_device(device, backend):
  # TODO(mattjj): after jaxlib update, avoid compile here, just to get device
@@ -960,9 +946,11 @@ def _lazy_force_computation(sticky, aval, device, lexpr):
  result_device = device if sticky else None
  handler = partial(DeviceArray, aval, result_device, lazy.array(aval.shape))
  if lazy.is_constant(lexpr):
-    force_fun = lambda _: handler(compiled.Execute([]))
+    def force_fun(_):
+      return handler(compiled.Execute([], tuple_arguments=False)[0])
  else:
-    force_fun = lambda x: handler(compiled.Execute([x.device_buffer]))
+    def force_fun(x):
+      return handler(compiled.Execute([x.device_buffer], tuple_arguments=False)[0])
  return force_fun
@@ -117,27 +117,12 @@ class APITest(jtu.JaxTestCase):

    f(1, 2, z=onp.zeros(3))  # doesn't crash

-  def test_jit_many_args_tuples(self):
+  def test_jit_with_many_args_works(self):
    @jit
    def f(args_list):
      return sum(args_list)

-    make_tuple = xla.make_tuple
-
-    counts = [0]
-    def make_tuple_and_count(*args, **kwargs):
-      counts[0] += 1
-      return make_tuple(*args, **kwargs)
-
-    try:
-      xla.make_tuple = make_tuple_and_count
-      ans = f(list(range(500)))
-    finally:
-      xla.make_tuple = make_tuple
-
-    expected = sum(range(500))
-    self.assertEqual(counts[0], 1)  # formed a tuple on dispatch
-    self.assertEqual(ans, expected)  # computed the correct result
+    self.assertEqual(f(list(range(500))), sum(range(500)))

  def test_grad_of_jit(self):
    side = []