mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[JAX] Prune unused inputs in jit.
- Python part based on: https://github.com/google/jax/pull/6567 - Added cpp_jit path to handle pruned args PiperOrigin-RevId: 371743277
This commit is contained in:
parent
e6bdcbb674
commit
850bd66242
@ -134,13 +134,37 @@ def jit_simple_many_args(n, state):
|
||||
while state:
|
||||
f(args).block_until_ready()
|
||||
|
||||
def jit_simple_pruned_args_dispatch(n, state):
|
||||
args = [jax.device_put(i) for i in range(n)]
|
||||
f = jax.jit(lambda *xs: xs[0] + 1)
|
||||
x = f(*args)
|
||||
x.block_until_ready()
|
||||
|
||||
while state:
|
||||
x = f(*args)
|
||||
x.block_until_ready()
|
||||
|
||||
|
||||
def jit_simple_pruned_args(n, state):
|
||||
args = [jax.device_put(i) for i in range(n)]
|
||||
f = jax.jit(lambda *xs: xs[0] + 1)
|
||||
x = f(*args)
|
||||
x.block_until_ready()
|
||||
|
||||
while state:
|
||||
f(*args).block_until_ready()
|
||||
|
||||
benchmarks = []
|
||||
for n in [10, 100, 1000, 2000]:
|
||||
benchmarks += [
|
||||
google_benchmark.register(partial(jit_simple_many_args_dispatch, n),
|
||||
name=f"jit_simple_many_args_dispatch_{n}"),
|
||||
google_benchmark.register(partial(jit_simple_many_args, n),
|
||||
name=f"jit_simple_many_args_{n}")
|
||||
name=f"jit_simple_many_args_{n}"),
|
||||
google_benchmark.register(partial(jit_simple_pruned_args_dispatch, n),
|
||||
name=f"jit_simple_pruned_args_dispatch_{n}"),
|
||||
google_benchmark.register(partial(jit_simple_pruned_args, n),
|
||||
name=f"jit_simple_pruned_args_{n}")
|
||||
]
|
||||
|
||||
|
||||
|
@ -355,6 +355,13 @@ class _BackendAndDeviceInfo(NamedTuple):
|
||||
default_device: xc.Device
|
||||
committed_to_device: bool
|
||||
|
||||
class _FastpathData(NamedTuple):
|
||||
xla_executable: xla.XlaExecutable
|
||||
out_pytree_def: Any
|
||||
sticky_device: xc.Device
|
||||
avals: Iterable[Any]
|
||||
lazy_exprs: Iterable[Any]
|
||||
kept_var_bitvec: Iterable[bool]
|
||||
|
||||
if lib._xla_extension_version >= 16:
|
||||
_cpp_jit_cache = jax_jit.CompiledFunctionCache()
|
||||
@ -442,7 +449,7 @@ def _cpp_jit(
|
||||
all(xla.type_is_device_array(x) for x in out_flat))
|
||||
### If we can use the fastpath, we return required info to the caller.
|
||||
if use_fastpath:
|
||||
xla_executable, _, result_handlers = execute.args
|
||||
xla_executable, _, result_handlers, kept_var_idx = execute.args
|
||||
sticky_device = None
|
||||
avals = []
|
||||
lazy_exprs = [None] * len(result_handlers)
|
||||
@ -450,7 +457,14 @@ def _cpp_jit(
|
||||
aval, sticky_device = result_handler.args
|
||||
avals.append(aval)
|
||||
assert len(avals) == len(out_flat)
|
||||
fastpath_data = (xla_executable, out_pytree_def, sticky_device, avals, lazy_exprs)
|
||||
if xla._ALLOW_ARG_PRUNING:
|
||||
kept_var_bitvec = [i in kept_var_idx for i in range(len(args_flat))]
|
||||
fastpath_data = _FastpathData(xla_executable, out_pytree_def,
|
||||
sticky_device, avals, lazy_exprs,
|
||||
kept_var_bitvec)
|
||||
else:
|
||||
fastpath_data = (xla_executable, out_pytree_def, sticky_device, avals,
|
||||
lazy_exprs)
|
||||
else:
|
||||
fastpath_data = None
|
||||
|
||||
|
@ -38,6 +38,7 @@ from .._src.util import (partial, partialmethod, cache, prod, unzip2,
|
||||
extend_name_stack, wrap_name, safe_zip, safe_map)
|
||||
from ..lib import xla_bridge as xb
|
||||
from ..lib import xla_client as xc
|
||||
from ..lib import _xla_extension_version
|
||||
from . import partial_eval as pe
|
||||
from . import ad
|
||||
from . import masking
|
||||
@ -647,10 +648,17 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
|
||||
raise ValueError("can't specify both a device and a backend for jit, "
|
||||
"got device={} and backend={}".format(device, backend))
|
||||
|
||||
abstract_args, arg_devices = unzip2(arg_specs)
|
||||
abstract_args, _ = unzip2(arg_specs)
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit")
|
||||
if any(isinstance(c, core.Tracer) for c in consts):
|
||||
raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
|
||||
jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
|
||||
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
|
||||
pruned_arg_specs = (a for i, a in enumerate(arg_specs) if i in kept_var_idx)
|
||||
abstract_args, arg_devices = unzip2(pruned_arg_specs)
|
||||
donated_invars = [
|
||||
x for i, x in enumerate(donated_invars) if i in kept_var_idx
|
||||
]
|
||||
map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
|
||||
jaxpr = apply_outfeed_rewriter(jaxpr)
|
||||
|
||||
@ -663,7 +671,8 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
|
||||
# which are often produced from partial evaluation, don't need compilation,
|
||||
# and don't need to evaluate their arguments.
|
||||
if not jaxpr.eqns:
|
||||
return partial(_execute_trivial, jaxpr, device, consts, out_avals, result_handlers)
|
||||
return partial(_execute_trivial, jaxpr, device, consts, out_avals,
|
||||
result_handlers, kept_var_idx)
|
||||
|
||||
if not _on_exit:
|
||||
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
||||
@ -714,9 +723,12 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
|
||||
options.parameter_is_tupled_arguments = tuple_args
|
||||
compiled = backend_compile(backend, built, options)
|
||||
if nreps == 1:
|
||||
return partial(_execute_compiled, compiled, out_avals, result_handlers)
|
||||
return partial(_execute_compiled, compiled, out_avals, result_handlers,
|
||||
kept_var_idx)
|
||||
else:
|
||||
return partial(_execute_replicated, compiled, out_avals, result_handlers)
|
||||
return partial(_execute_replicated, compiled, out_avals, result_handlers,
|
||||
kept_var_idx)
|
||||
|
||||
|
||||
def set_up_aliases(c, xla_args, out_tuple, donated_args, tuple_args):
|
||||
"""Configures input/output "must" aliasing based on `donated_args`."""
|
||||
@ -746,6 +758,33 @@ def set_up_aliases(c, xla_args, out_tuple, donated_args, tuple_args):
|
||||
|
||||
return tuple(out_donated_args)
|
||||
|
||||
|
||||
# Pruning unused JIT arguments require jaxlib 0.1.66 or newer.
|
||||
# TODO(zhangqiaorjc): remove when jaxlib 0.1.66 is the minimum.
|
||||
_ALLOW_ARG_PRUNING = _xla_extension_version >= 18
|
||||
|
||||
|
||||
def _prune_unused_inputs(
|
||||
jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
|
||||
if not _ALLOW_ARG_PRUNING:
|
||||
kept_const_idx = range(len(jaxpr.constvars))
|
||||
kept_var_idx = range(len(jaxpr.invars))
|
||||
return jaxpr, set(kept_const_idx), set(kept_var_idx)
|
||||
|
||||
used = {v for v in jaxpr.outvars if isinstance(v, core.Var)}
|
||||
# TODO(zhangqiaorjc): Improve the DCE algorithm by also pruning primitive
|
||||
# applications that do not produce used outputs. Must handle side-effecting
|
||||
# primitives and nested jaxpr.
|
||||
used.update(
|
||||
v for eqn in jaxpr.eqns for v in eqn.invars if isinstance(v, core.Var))
|
||||
kept_const_idx, new_constvars = unzip2(
|
||||
(i, v) for i, v in enumerate(jaxpr.constvars) if v in used)
|
||||
kept_var_idx, new_invars = unzip2(
|
||||
(i, v) for i, v in enumerate(jaxpr.invars) if v in used)
|
||||
new_jaxpr = core.Jaxpr(new_constvars, new_invars, jaxpr.outvars, jaxpr.eqns)
|
||||
return new_jaxpr, set(kept_const_idx), set(kept_var_idx)
|
||||
|
||||
|
||||
def _xla_callable_device(nreps, backend, device, arg_devices):
|
||||
if nreps > 1:
|
||||
if device is not None or backend is not None:
|
||||
@ -823,17 +862,30 @@ def _xla_param(builder, param_num, xla_shape, replicated, partitions, parts_prot
|
||||
else:
|
||||
return with_sharding(builder, partitions, make_param)
|
||||
|
||||
def _execute_compiled(compiled: XlaExecutable, avals, handlers, *args):
|
||||
|
||||
def _execute_compiled(compiled: XlaExecutable, avals, handlers, kept_var_idx,
|
||||
*args):
|
||||
device, = compiled.local_devices()
|
||||
input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
|
||||
input_bufs = list(
|
||||
it.chain.from_iterable(
|
||||
device_put(x, device)
|
||||
for i, x in enumerate(args)
|
||||
if x is not token and i in kept_var_idx))
|
||||
out_bufs = compiled.execute(input_bufs)
|
||||
check_special(xla_call_p.name, out_bufs)
|
||||
return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]
|
||||
|
||||
def _execute_replicated(compiled: XlaExecutable, avals, handlers, *args):
|
||||
|
||||
def _execute_replicated(compiled: XlaExecutable, avals, handlers, kept_var_idx,
|
||||
*args):
|
||||
input_bufs = [
|
||||
list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
|
||||
for device in compiled.local_devices()]
|
||||
list(
|
||||
it.chain.from_iterable(
|
||||
device_put(x, device)
|
||||
for i, x in enumerate(args)
|
||||
if x is not token and i in kept_var_idx))
|
||||
for device in compiled.local_devices()
|
||||
]
|
||||
out_bufs = [
|
||||
buf[0] for buf in compiled.execute_sharded_on_local_devices(
|
||||
list(zip(*input_bufs)))
|
||||
@ -841,9 +893,12 @@ def _execute_replicated(compiled: XlaExecutable, avals, handlers, *args):
|
||||
check_special(xla_call_p.name, out_bufs)
|
||||
return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]
|
||||
|
||||
def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers, *args):
|
||||
|
||||
def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
|
||||
kept_var_idx, *args):
|
||||
env = {core.unitvar: core.unit}
|
||||
map(env.setdefault, jaxpr.invars, args)
|
||||
pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
|
||||
map(env.setdefault, jaxpr.invars, pruned_args)
|
||||
map(env.setdefault, jaxpr.constvars, consts)
|
||||
outs = [canonicalize_dtype(v.val) if type(v) is Literal else env[v]
|
||||
for v in jaxpr.outvars]
|
||||
|
@ -618,6 +618,23 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
x_is_tracer, y_is_tracer = False, True
|
||||
assert f_mixed(x='foo', y=3) == 1
|
||||
|
||||
# TODO(zhangqiaorjc): Test pruning constants after DCE pass prunes primitive
|
||||
# applications.
|
||||
@unittest.skipIf(not xla._ALLOW_ARG_PRUNING, "Test requires jaxlib 0.1.66")
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_num_args={}".format(num_args),
|
||||
"num_args": num_args}
|
||||
for num_args in [2, 3, 4]))
|
||||
def test_jit_with_pruned_args(self, num_args):
|
||||
def f(*args):
|
||||
used = np.array(2)
|
||||
return args[1] + used
|
||||
f_pruned = self.jit(f)
|
||||
args = range(num_args)
|
||||
with jtu.count_device_put() as count:
|
||||
np.testing.assert_allclose(f_pruned(*args), 3)
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
|
||||
class PythonJitTest(CPPJitTest):
|
||||
|
||||
|
40
tests/xla_interpreter_test.py
Normal file
40
tests/xla_interpreter_test.py
Normal file
@ -0,0 +1,40 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax import test_util as jtu
|
||||
from jax._src import api
|
||||
from jax.interpreters import xla
|
||||
|
||||
|
||||
class XlaInterpreterTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skipIf(not xla._ALLOW_ARG_PRUNING, "Test requires jaxlib 0.1.66")
|
||||
def test_prune_jit_args(self):
|
||||
def f(*args):
|
||||
return args[0]
|
||||
|
||||
closed_jaxpr = api.make_jaxpr(f)(*range(10))
|
||||
pruned_jaxpr, kept_const_idx, kept_var_idx = xla._prune_unused_inputs(
|
||||
closed_jaxpr.jaxpr)
|
||||
assert len(pruned_jaxpr.invars) == 1
|
||||
assert kept_const_idx == set()
|
||||
assert kept_var_idx == {0}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user