Copybara import of the project:

--
3c400a3e588abf9e2259119c50343cba6f3477f1 by Matthew Johnson <mattjj@google.com>:

add 'inline' option to xla_call for jaxpr inlining

--
fe297e39ca37896b75d7943b9b77c0b53fad13ee by Matthew Johnson <mattjj@google.com>:

add 'inline' to jit docstring

--
ff6866c4b3757cde66fe659c2f27d8aeff024e8f by Matthew Johnson <mattjj@google.com>:

new_sublevel in jax2tf

PiperOrigin-RevId: 371542778
This commit is contained in:
jax authors 2021-05-01 22:18:13 -07:00
parent e8f209c775
commit 75b00a1235
10 changed files with 82 additions and 111 deletions

View File

@ -433,7 +433,6 @@ computation should run. For example
in (g,) }
device=None
donated_invars=(False, False)
inline=False
name=inner ] a b
d = convert_element_type[ new_dtype=float32
weak_type=False ] a

View File

@ -212,7 +212,6 @@ def jit(
device: Optional[xc.Device] = None,
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
) -> F:
"""Sets up ``fun`` for just-in-time compilation with XLA.
@ -263,10 +262,7 @@ def jit(
buffers to reduce the amount of memory needed to perform a computation,
for example recycling one of your input buffers to store a result. You
should not reuse buffers that you donate to a computation, JAX will raise
an error if you try to. By default, no arguments are donated.
inline: Specify whether this function should be inlined into enclosing
jaxprs (rather than being represented as an application of the xla_call
primitive with its own subjaxpr). Default False.
an error if you try to.
Returns:
A wrapped version of ``fun``, set up for just-in-time compilation.
@ -292,10 +288,10 @@ def jit(
static_argnames = ()
if FLAGS.experimental_cpp_jit:
return _cpp_jit(fun, static_argnums, static_argnames, device, backend,
donate_argnums, inline)
donate_argnums)
else:
return _python_jit(fun, static_argnums, static_argnames, device, backend,
donate_argnums, inline)
donate_argnums)
def _python_jit(
@ -304,10 +300,9 @@ def _python_jit(
static_argnames: Union[str, Iterable[str], None] = None,
device: Optional[xc.Device] = None,
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
donate_argnums: Union[int, Iterable[int]] = ()
) -> F:
# The Python implementation of `jax.jit`, being slowly replaced by _cpp_jit.
"""The Python implementation of `jax.jit`, being slowly replaced by _cpp_jit."""
_check_callable(fun)
static_argnums, static_argnames = _infer_argnums_and_argnames(
fun, static_argnums, static_argnames)
@ -344,8 +339,13 @@ def _python_jit(
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(f, in_tree)
out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
name=flat_fun.__name__, donated_invars=donated_invars, inline=inline)
out = xla.xla_call(
flat_fun,
*args_flat,
device=device,
backend=backend,
name=flat_fun.__name__,
donated_invars=donated_invars)
return tree_unflatten(out_tree(), out)
return f_jitted
@ -366,15 +366,16 @@ def _cpp_jit(
device: Optional[xc.Device] = None,
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
) -> F:
# An implementation of `jit` that tries to do as much as possible in C++.
# The goal of this function is to speed up the time it takes to process the
# arguments, find the correct C++ executable, start the transfer of arguments
# and schedule the computation.
# As long as it does not support all features of the Python implementation,
# the C++ code will fall back to `_python_jit` when it faces some unsupported
# feature.
"""An implementation of `jit` that tries to do as much as possible in C++.
The goal of this function is to speed up the time it takes to process the
arguments, find the correct C++ executable, start the transfer of arguments
and schedule the computation.
As long as it does not support all features of the Python implementation,
the C++ code will fall back to `_python_jit` when it faces some unsupported
feature.
"""
_check_callable(fun)
static_argnums, static_argnames = _infer_argnums_and_argnames(
fun, static_argnums, static_argnames)
@ -416,9 +417,12 @@ def _cpp_jit(
_check_arg(arg)
flat_fun, out_tree = flatten_fun(f, in_tree)
out_flat = xla.xla_call(
flat_fun, *args_flat,
device=device, backend=backend, name=flat_fun.__name__,
donated_invars=donated_invars, inline=inline)
flat_fun,
*args_flat,
device=device,
backend=backend,
name=flat_fun.__name__,
donated_invars=donated_invars)
out_pytree_def = out_tree()
out = tree_unflatten(out_pytree_def, out_flat)

View File

@ -270,13 +270,13 @@ class CustomJVPCallPrimitive(core.CallPrimitive):
jvp, env_trace_todo2 = core.process_env_traces(
jvp, self, top_trace and top_trace.level, (), None)
tracers = map(top_trace.full_raise, args) # type: ignore
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers) # type: ignore
with core.maybe_new_sublevel(top_trace):
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers) # type: ignore
_, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
return _apply_todos(env_trace_todo, map(core.full_lower, outs))
def impl(self, fun, _, *args):
with core.new_sublevel():
return fun.call_wrapped(*args)
return fun.call_wrapped(*args)
# Post-processing hook: forwards `out_tracers` and `params` unchanged to the
# given trace's `post_process_custom_jvp_call` method and returns its result.
def post_process(self, trace, out_tracers, params):
return trace.post_process_custom_jvp_call(out_tracers, params)
@ -563,15 +563,15 @@ class CustomVJPCallPrimitive(core.CallPrimitive):
fwd, env_trace_todo2 = core.process_env_traces(
fwd, self, top_trace and top_trace.level, (), None)
tracers = map(top_trace.full_raise, args) # type: ignore
outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers,
out_trees=out_trees)
with core.maybe_new_sublevel(top_trace):
outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers,
out_trees=out_trees)
_, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
return _apply_todos(env_trace_todo, map(core.full_lower, outs))
def impl(self, fun, fwd, bwd, *args, out_trees):
del fwd, bwd, out_trees
with core.new_sublevel():
return fun.call_wrapped(*args)
return fun.call_wrapped(*args)
# Post-processing hook: forwards `out_tracers` and `params` unchanged to the
# given trace's `post_process_custom_vjp_call` method and returns its result.
def post_process(self, trace, out_tracers, params):
return trace.post_process_custom_vjp_call(out_tracers, params)

View File

@ -15,7 +15,7 @@
import operator
from operator import attrgetter
from contextlib import contextmanager
from contextlib import contextmanager, suppress
from collections import namedtuple
from functools import total_ordering
import itertools as it
@ -611,13 +611,11 @@ class EvalTrace(Trace):
def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
del primitive, jvp # Unused.
with new_sublevel():
return fun.call_wrapped(*tracers)
return fun.call_wrapped(*tracers)
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
del primitive, fwd, bwd, out_trees # Unused.
with new_sublevel():
return fun.call_wrapped(*tracers)
return fun.call_wrapped(*tracers)
class MainTrace:
@ -808,6 +806,11 @@ def new_sublevel() -> Generator[None, None, None]:
if t() is not None:
raise Exception(f'Leaked sublevel {t()}.')
# Pick the context manager to wrap a call processed by `trace`:
#   - `new_sublevel()` when `trace.main` is the `dynamic` entry of the
#     thread-local trace stack;
#   - otherwise `suppress()`, which is given no exception types here and so
#     only serves as a do-nothing context manager.
# The result is meant to be used as `with maybe_new_sublevel(trace): ...`.
def maybe_new_sublevel(trace):
# dynamic traces run the WrappedFun, so we raise the sublevel for them
dynamic = thread_local_state.trace_state.trace_stack.dynamic
# Identity comparison (`is`), not equality: the trace's main must be the very
# object recorded as the dynamic trace.
return new_sublevel() if trace.main is dynamic else suppress()
def full_lower(val):
if isinstance(val, Tracer):
return val.full_lower()
@ -1549,7 +1552,8 @@ def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
fun, primitive, top_trace and top_trace.level,
params_tuple, out_axes_transforms)
tracers = map(top_trace.full_raise, args)
outs = primitive.process(top_trace, fun, tracers, params)
with maybe_new_sublevel(top_trace):
outs = primitive.process(top_trace, fun, tracers, params)
return map(full_lower, apply_todos(env_trace_todo(), outs))
@ -1568,8 +1572,7 @@ class CallPrimitive(Primitive):
def call_impl(f: lu.WrappedFun, *args, **params):
del params # params parameterize the call primitive, not the function
with new_sublevel():
return f.call_wrapped(*args)
return f.call_wrapped(*args)
call_p = CallPrimitive('call')
call = call_p.bind

View File

@ -1419,8 +1419,7 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
dict(
call_jaxpr=transformed_cond_jaxpr.jaxpr,
name="cond_before",
donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals),
inline=False),
donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals)),
eqn.source_info))
# Make a new cond "lambda pred, carry, token, itoken: pred"
new_cond_pred_invar = mk_new_var(cond_jaxpr.out_avals[0])
@ -1463,8 +1462,7 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
dict(
call_jaxpr=transformed_body_jaxpr.jaxpr,
name="body",
donated_invars=(False,) * len(transformed_body_jaxpr.in_avals),
inline=False),
donated_invars=(False,) * len(transformed_body_jaxpr.in_avals)),
eqn.source_info),
core.new_jaxpr_eqn(
new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2, new_body_itoken2],
@ -1472,8 +1470,7 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
dict(
call_jaxpr=transformed_cond_jaxpr.jaxpr,
name="cond_body",
donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals),
inline=False),
donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals)),
eqn.source_info)
]
new_body_jaxpr = core.ClosedJaxpr(

View File

@ -310,9 +310,7 @@ def _interpret_fun(fun: lu.WrappedFun,
) -> Sequence[Tuple[TfVal, core.AbstractValue]]:
with core.new_base_main(TensorFlowTrace) as main: # type: ignore
fun = _interpret_subtrace(fun, main, in_avals)
with core.new_sublevel():
out_vals: Sequence[Tuple[TfVal, core.AbstractValue]] = \
fun.call_wrapped(*in_vals)
out_vals: Sequence[Tuple[TfVal, core.AbstractValue]] = fun.call_wrapped(*in_vals)
del main
return tuple(out_vals)
@ -677,15 +675,14 @@ class TensorFlowTrace(core.Trace):
assert call_primitive.multiple_results
vals: Sequence[TfVal] = [t.val for t in tracers]
f = _interpret_subtrace(f, self.main, tuple(t.aval for t in tracers))
with core.new_sublevel():
if call_primitive == core.named_call_p:
with tf.name_scope(_sanitize_scope_name(params["name"])):
vals_out: Sequence[Tuple[TfVal, core.AbstractValue]] = \
f.call_wrapped(*vals)
elif call_primitive == sharded_jit.sharded_call_p:
vals_out = _sharded_call(f, vals, **params)
else:
vals_out = f.call_wrapped(*vals)
if call_primitive == core.named_call_p:
with tf.name_scope(_sanitize_scope_name(params["name"])):
vals_out: Sequence[Tuple[TfVal,
core.AbstractValue]] = f.call_wrapped(*vals)
elif call_primitive == sharded_jit.sharded_call_p:
vals_out = _sharded_call(f, vals, **params)
else:
vals_out = f.call_wrapped(*vals)
return [TensorFlowTracer(self, v, a) for v, a in vals_out]
def post_process_call(self, call_primitive: core.Primitive,
@ -2176,7 +2173,7 @@ def _sharded_call(f: lu.WrappedFun, vals: Sequence[TfVal],
out_parts_thunk,
**_) -> Sequence[Tuple[TfVal, core.AbstractValue]]:
sharded_vals = util.safe_map(split_to_logical_devices, vals, in_parts)
vals_out = f.call_wrapped(*sharded_vals) # caller handles new_sublevel
vals_out = f.call_wrapped(*sharded_vals)
out_parts_flat = out_parts_thunk()
assert len(out_parts_flat) == len(vals_out), f"expected {len(out_parts_flat)} == {len(vals_out)}"
sharded_vals_out = [

View File

@ -1005,7 +1005,7 @@ def _xmap_translation_rule_replica(c, axis_env,
# NOTE: We don't extend the resource env with the mesh shape, because those
# resources are already in scope! It's the outermost xmap that introduces
# them!
vectorized_jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(f, local_avals)
vectorized_jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(f, local_avals)
_check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
assert not consts
@ -1123,7 +1123,7 @@ def _xmap_translation_rule_spmd(c, axis_env,
global_in_avals = [core.ShapedArray(xla_type.dimensions(), xla_type.numpy_dtype())
for in_node in global_in_nodes
for xla_type in (c.get_shape(in_node),)]
vectorized_jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
vectorized_jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(f, global_in_avals)
assert not consts
global_sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)

View File

@ -242,7 +242,6 @@ class JaxprTrace(Trace):
# We use post_process_call to handle both call and map primitives.
def post_process_call(self, primitive, out_tracers, params):
jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
out = out_pv_consts + consts
@ -322,8 +321,7 @@ class JaxprTrace(Trace):
def jvp_jaxpr_thunk():
jvp_ = trace_to_subjaxpr(jvp, self.main, True)
jvp_, aux = partial_eval_wrapper(jvp_, tuple(in_avals) * 2)
with core.new_sublevel():
out_flat = jvp_.call_wrapped(*(in_consts * 2)) # in_consts are units
out_flat = jvp_.call_wrapped(*(in_consts * 2)) # in_consts are units
out_avals, jaxpr, env = aux()
_, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
@ -362,8 +360,7 @@ class JaxprTrace(Trace):
def fwd_jaxpr_thunk():
fwd_ = trace_to_subjaxpr(fwd, self.main, True)
fwd_, aux = partial_eval_wrapper(fwd_, tuple(in_avals))
with core.new_sublevel():
out_flat = fwd_.call_wrapped(*in_consts) # in_consts are units
out_flat = fwd_.call_wrapped(*in_consts) # in_consts are units
out_avals, jaxpr, env = aux()
_, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
@ -1061,9 +1058,8 @@ class DynamicJaxprTrace(core.Trace):
def process_call(self, call_primitive, f, tracers, params):
in_avals = [t.aval for t in tracers]
with core.new_sublevel():
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
if params.get('inline', False):
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
if not jaxpr.eqns:
return core.eval_jaxpr(jaxpr, consts, *tracers)
source_info = source_info_util.current()
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
@ -1089,9 +1085,8 @@ class DynamicJaxprTrace(core.Trace):
if in_axis is not None else a
for a, in_axis in zip(in_avals, params['in_axes'])]
with core.extend_axis_env(axis_name, axis_size, None): # type: ignore
with core.new_sublevel():
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, reduced_in_avals)
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, reduced_in_avals)
out_axes = params['out_axes_thunk']()
out_avals = [core.unmapped_aval(params['axis_size'], out_axis, a)
if out_axis is not None else a
@ -1118,8 +1113,7 @@ class DynamicJaxprTrace(core.Trace):
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
in_avals = [t.aval for t in tracers]
with core.new_sublevel():
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
jvp_jaxpr_thunk = _memoize(
lambda: trace_to_subjaxpr_dynamic(jvp, self.main, 2 * in_avals)[::2])
@ -1140,8 +1134,7 @@ class DynamicJaxprTrace(core.Trace):
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
in_avals = [t.aval for t in tracers]
with core.new_sublevel():
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
fwd_jaxpr_thunk = _memoize(
lambda: trace_to_subjaxpr_dynamic(fwd, self.main, in_avals)[::2])
@ -1213,8 +1206,7 @@ def trace_to_jaxpr_final(fun: lu.WrappedFun,
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.source_info = fun_sourceinfo(fun.f, transform_name) # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
del fun, main
return jaxpr, out_avals, consts

View File

@ -218,10 +218,7 @@ def primitive_uses_outfeed(prim: core.Primitive, params: Dict) -> bool:
### op-by-op execution
ArgSpec = Tuple[core.AbstractValue, Optional[Device]]
def arg_spec(x: Any) -> ArgSpec:
def arg_spec(x):
aval = abstractify(x)
try:
return aval, x._device
@ -243,7 +240,8 @@ def _partition_outputs(avals, outs):
@cache()
def xla_primitive_callable(prim, *arg_specs: ArgSpec, **params):
def xla_primitive_callable(prim, *arg_specs: Tuple[core.AbstractValue,
Optional[Device]], **params):
avals, arg_devices = unzip2(arg_specs)
donated_invars = (False,) * len(arg_specs)
device = _device_from_arg_devices(arg_devices)
@ -575,9 +573,7 @@ def jaxpr_collectives(jaxpr):
### xla_call underlying jit
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
del inline # Only used at tracing time
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(arg_spec, args))
try:
@ -595,8 +591,7 @@ def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
# intentional here, to avoid "Store occupied" errors we reset the stores to
# be empty.
for store in fun.stores: store and store.reset()
with core.new_sublevel():
return fun.call_wrapped(*args) # probably won't return
return fun.call_wrapped(*args) # probably won't return
def flatten_shape(s: XlaShape) -> Sequence[Tuple[Sequence[int], XlaShape]]:
"""Expands a given shape tree into a flat list of indices to arrays.
@ -697,8 +692,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
c = xb.make_computation_builder("jit_{}".format(fun.__name__))
xla_consts = _xla_consts(c, consts)
xla_args, donated_invars = _xla_callable_args(c, abstract_args, tuple_args,
donated_invars=donated_invars)
xla_args, donated_invars = _xla_callable_args(c, abstract_args, tuple_args, donated_invars=donated_invars)
out_nodes = jaxpr_subcomp(
c, jaxpr, backend, AxisEnv(nreps, (), ()), xla_consts,
extend_name_stack(wrap_name(name, 'jit')), *xla_args)
@ -795,9 +789,9 @@ def _xla_callable_args(
for (a, r, p) in safe_zip(avals, replicated, parts)
for xla_shape in aval_to_xla_shapes(a)]
if donated_invars is not None:
donated_invars = [
d for (a, _, _, d) in zip(avals, replicated, parts, donated_invars)
for xla_shape in aval_to_xla_shapes(a)]
donated_invars = [d
for (a, r, p, d) in safe_zip(avals, replicated, parts, donated_invars)
for xla_shape in aval_to_xla_shapes(a)]
return xla_args, donated_invars
else:
if replicated is not None:
@ -891,8 +885,8 @@ ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params
def _xla_call_translation_rule(c, axis_env,
in_nodes, name_stack, backend, name,
call_jaxpr, donated_invars, inline, device=None):
del device, donated_invars, inline # Ignored.
call_jaxpr, donated_invars, device=None):
del device, donated_invars # Ignored.
subc = xb.make_computation_builder(f"jit_{name}")
args = [xb.parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),

View File

@ -2343,7 +2343,7 @@ class APITest(jtu.JaxTestCase):
lst.append(x)
return x
with self.assertRaisesRegex(Exception, r"Leaked"):
with self.assertRaisesRegex(Exception, r"Leaked trace"):
f(3)
def test_leak_checker_catches_a_pmap_leak(self):
@ -2355,7 +2355,7 @@ class APITest(jtu.JaxTestCase):
lst.append(x)
return x
with self.assertRaisesRegex(Exception, r"Leaked"):
with self.assertRaisesRegex(Exception, r"Leaked trace"):
f(np.ones(1))
def test_leak_checker_catches_a_grad_leak(self):
@ -2678,21 +2678,6 @@ class APITest(jtu.JaxTestCase):
jtu.check_grads(batched_scan_over_mul, (x_batch, coeff), order=2,
modes=['rev'])
# Checks how the `inline` argument of `api.jit` shows up in the printed jaxpr
# produced by `api.make_jaxpr`.
def test_jit_inline(self):
# With inline=False the printed jaxpr is expected to mention 'xla_call'.
@partial(api.jit, inline=False)
def f(x):
return x * 2
jaxpr = api.make_jaxpr(f)(3)
self.assertIn('xla_call', str(jaxpr))
# With inline=True the same function (redefined under the same name `f`) is
# expected to leave no 'xla_call' in the printed jaxpr.
@partial(api.jit, inline=True)
def f(x):
return x * 2
jaxpr = api.make_jaxpr(f)(3)
self.assertNotIn('xla_call', str(jaxpr))
class RematTest(jtu.JaxTestCase):