mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Copybara import of the project:
-- 3c400a3e588abf9e2259119c50343cba6f3477f1 by Matthew Johnson <mattjj@google.com>: add 'inline' option to xla_call for jaxpr inlining -- fe297e39ca37896b75d7943b9b77c0b53fad13ee by Matthew Johnson <mattjj@google.com>: add 'inline' to jit docstring -- ff6866c4b3757cde66fe659c2f27d8aeff024e8f by Matthew Johnson <mattjj@google.com>: new_sublevel in jax2tf PiperOrigin-RevId: 371542778
This commit is contained in:
parent
e8f209c775
commit
75b00a1235
@ -433,7 +433,6 @@ computation should run. For example
|
||||
in (g,) }
|
||||
device=None
|
||||
donated_invars=(False, False)
|
||||
inline=False
|
||||
name=inner ] a b
|
||||
d = convert_element_type[ new_dtype=float32
|
||||
weak_type=False ] a
|
||||
|
@ -212,7 +212,6 @@ def jit(
|
||||
device: Optional[xc.Device] = None,
|
||||
backend: Optional[str] = None,
|
||||
donate_argnums: Union[int, Iterable[int]] = (),
|
||||
inline: bool = False,
|
||||
) -> F:
|
||||
"""Sets up ``fun`` for just-in-time compilation with XLA.
|
||||
|
||||
@ -263,10 +262,7 @@ def jit(
|
||||
buffers to reduce the amount of memory needed to perform a computation,
|
||||
for example recycling one of your input buffers to store a result. You
|
||||
should not reuse buffers that you donate to a computation, JAX will raise
|
||||
an error if you try to. By default, no arguments are donated.
|
||||
inline: Specify whether this function should be inlined into enclosing
|
||||
jaxprs (rather than being represented as an application of the xla_call
|
||||
primitive with its own subjaxpr). Default False.
|
||||
an error if you try to.
|
||||
|
||||
Returns:
|
||||
A wrapped version of ``fun``, set up for just-in-time compilation.
|
||||
@ -292,10 +288,10 @@ def jit(
|
||||
static_argnames = ()
|
||||
if FLAGS.experimental_cpp_jit:
|
||||
return _cpp_jit(fun, static_argnums, static_argnames, device, backend,
|
||||
donate_argnums, inline)
|
||||
donate_argnums)
|
||||
else:
|
||||
return _python_jit(fun, static_argnums, static_argnames, device, backend,
|
||||
donate_argnums, inline)
|
||||
donate_argnums)
|
||||
|
||||
|
||||
def _python_jit(
|
||||
@ -304,10 +300,9 @@ def _python_jit(
|
||||
static_argnames: Union[str, Iterable[str], None] = None,
|
||||
device: Optional[xc.Device] = None,
|
||||
backend: Optional[str] = None,
|
||||
donate_argnums: Union[int, Iterable[int]] = (),
|
||||
inline: bool = False,
|
||||
donate_argnums: Union[int, Iterable[int]] = ()
|
||||
) -> F:
|
||||
# The Python implementation of `jax.jit`, being slowly replaced by _cpp_jit.
|
||||
"""The Python implementation of `jax.jit`, being slowly replaced by _cpp_jit."""
|
||||
_check_callable(fun)
|
||||
static_argnums, static_argnames = _infer_argnums_and_argnames(
|
||||
fun, static_argnums, static_argnames)
|
||||
@ -344,8 +339,13 @@ def _python_jit(
|
||||
for arg in args_flat:
|
||||
_check_arg(arg)
|
||||
flat_fun, out_tree = flatten_fun(f, in_tree)
|
||||
out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
|
||||
name=flat_fun.__name__, donated_invars=donated_invars, inline=inline)
|
||||
out = xla.xla_call(
|
||||
flat_fun,
|
||||
*args_flat,
|
||||
device=device,
|
||||
backend=backend,
|
||||
name=flat_fun.__name__,
|
||||
donated_invars=donated_invars)
|
||||
return tree_unflatten(out_tree(), out)
|
||||
|
||||
return f_jitted
|
||||
@ -366,15 +366,16 @@ def _cpp_jit(
|
||||
device: Optional[xc.Device] = None,
|
||||
backend: Optional[str] = None,
|
||||
donate_argnums: Union[int, Iterable[int]] = (),
|
||||
inline: bool = False,
|
||||
) -> F:
|
||||
# An implementation of `jit` that tries to do as much as possible in C++.
|
||||
# The goal of this function is to speed up the time it takes to process the
|
||||
# arguments, find the correct C++ executable, start the transfer of arguments
|
||||
# and schedule the computation.
|
||||
# As long as it does not support all features of the Python implementation
|
||||
# the C++ code will fallback to `_python_jit` when it faces some unsupported
|
||||
# feature.
|
||||
"""An implementation of `jit` that tries to do as much as possible in C++.
|
||||
|
||||
The goal of this function is to speed up the time it takes to process the
|
||||
arguments, find the correct C++ executable, start the transfer of arguments
|
||||
and schedule the computation.
|
||||
As long as it does not support all features of the Python implementation
|
||||
the C++ code will fallback to `_python_jit` when it faces some unsupported
|
||||
feature.
|
||||
"""
|
||||
_check_callable(fun)
|
||||
static_argnums, static_argnames = _infer_argnums_and_argnames(
|
||||
fun, static_argnums, static_argnames)
|
||||
@ -416,9 +417,12 @@ def _cpp_jit(
|
||||
_check_arg(arg)
|
||||
flat_fun, out_tree = flatten_fun(f, in_tree)
|
||||
out_flat = xla.xla_call(
|
||||
flat_fun, *args_flat,
|
||||
device=device, backend=backend, name=flat_fun.__name__,
|
||||
donated_invars=donated_invars, inline=inline)
|
||||
flat_fun,
|
||||
*args_flat,
|
||||
device=device,
|
||||
backend=backend,
|
||||
name=flat_fun.__name__,
|
||||
donated_invars=donated_invars)
|
||||
out_pytree_def = out_tree()
|
||||
out = tree_unflatten(out_pytree_def, out_flat)
|
||||
|
||||
|
@ -270,13 +270,13 @@ class CustomJVPCallPrimitive(core.CallPrimitive):
|
||||
jvp, env_trace_todo2 = core.process_env_traces(
|
||||
jvp, self, top_trace and top_trace.level, (), None)
|
||||
tracers = map(top_trace.full_raise, args) # type: ignore
|
||||
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers) # type: ignore
|
||||
with core.maybe_new_sublevel(top_trace):
|
||||
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers) # type: ignore
|
||||
_, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
|
||||
return _apply_todos(env_trace_todo, map(core.full_lower, outs))
|
||||
|
||||
def impl(self, fun, _, *args):
|
||||
with core.new_sublevel():
|
||||
return fun.call_wrapped(*args)
|
||||
return fun.call_wrapped(*args)
|
||||
|
||||
def post_process(self, trace, out_tracers, params):
|
||||
return trace.post_process_custom_jvp_call(out_tracers, params)
|
||||
@ -563,15 +563,15 @@ class CustomVJPCallPrimitive(core.CallPrimitive):
|
||||
fwd, env_trace_todo2 = core.process_env_traces(
|
||||
fwd, self, top_trace and top_trace.level, (), None)
|
||||
tracers = map(top_trace.full_raise, args) # type: ignore
|
||||
outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers,
|
||||
out_trees=out_trees)
|
||||
with core.maybe_new_sublevel(top_trace):
|
||||
outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers,
|
||||
out_trees=out_trees)
|
||||
_, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
|
||||
return _apply_todos(env_trace_todo, map(core.full_lower, outs))
|
||||
|
||||
def impl(self, fun, fwd, bwd, *args, out_trees):
|
||||
del fwd, bwd, out_trees
|
||||
with core.new_sublevel():
|
||||
return fun.call_wrapped(*args)
|
||||
return fun.call_wrapped(*args)
|
||||
|
||||
def post_process(self, trace, out_tracers, params):
|
||||
return trace.post_process_custom_vjp_call(out_tracers, params)
|
||||
|
19
jax/core.py
19
jax/core.py
@ -15,7 +15,7 @@
|
||||
|
||||
import operator
|
||||
from operator import attrgetter
|
||||
from contextlib import contextmanager
|
||||
from contextlib import contextmanager, suppress
|
||||
from collections import namedtuple
|
||||
from functools import total_ordering
|
||||
import itertools as it
|
||||
@ -611,13 +611,11 @@ class EvalTrace(Trace):
|
||||
|
||||
def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
|
||||
del primitive, jvp # Unused.
|
||||
with new_sublevel():
|
||||
return fun.call_wrapped(*tracers)
|
||||
return fun.call_wrapped(*tracers)
|
||||
|
||||
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
|
||||
del primitive, fwd, bwd, out_trees # Unused.
|
||||
with new_sublevel():
|
||||
return fun.call_wrapped(*tracers)
|
||||
return fun.call_wrapped(*tracers)
|
||||
|
||||
|
||||
class MainTrace:
|
||||
@ -808,6 +806,11 @@ def new_sublevel() -> Generator[None, None, None]:
|
||||
if t() is not None:
|
||||
raise Exception(f'Leaked sublevel {t()}.')
|
||||
|
||||
def maybe_new_sublevel(trace):
|
||||
# dynamic traces run the WrappedFun, so we raise the sublevel for them
|
||||
dynamic = thread_local_state.trace_state.trace_stack.dynamic
|
||||
return new_sublevel() if trace.main is dynamic else suppress()
|
||||
|
||||
def full_lower(val):
|
||||
if isinstance(val, Tracer):
|
||||
return val.full_lower()
|
||||
@ -1549,7 +1552,8 @@ def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
|
||||
fun, primitive, top_trace and top_trace.level,
|
||||
params_tuple, out_axes_transforms)
|
||||
tracers = map(top_trace.full_raise, args)
|
||||
outs = primitive.process(top_trace, fun, tracers, params)
|
||||
with maybe_new_sublevel(top_trace):
|
||||
outs = primitive.process(top_trace, fun, tracers, params)
|
||||
return map(full_lower, apply_todos(env_trace_todo(), outs))
|
||||
|
||||
|
||||
@ -1568,8 +1572,7 @@ class CallPrimitive(Primitive):
|
||||
|
||||
def call_impl(f: lu.WrappedFun, *args, **params):
|
||||
del params # params parameterize the call primitive, not the function
|
||||
with new_sublevel():
|
||||
return f.call_wrapped(*args)
|
||||
return f.call_wrapped(*args)
|
||||
|
||||
call_p = CallPrimitive('call')
|
||||
call = call_p.bind
|
||||
|
@ -1419,8 +1419,7 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
dict(
|
||||
call_jaxpr=transformed_cond_jaxpr.jaxpr,
|
||||
name="cond_before",
|
||||
donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals),
|
||||
inline=False),
|
||||
donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals)),
|
||||
eqn.source_info))
|
||||
# Make a new cond "lambda pred, carry, token, itoken: pred"
|
||||
new_cond_pred_invar = mk_new_var(cond_jaxpr.out_avals[0])
|
||||
@ -1463,8 +1462,7 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
dict(
|
||||
call_jaxpr=transformed_body_jaxpr.jaxpr,
|
||||
name="body",
|
||||
donated_invars=(False,) * len(transformed_body_jaxpr.in_avals),
|
||||
inline=False),
|
||||
donated_invars=(False,) * len(transformed_body_jaxpr.in_avals)),
|
||||
eqn.source_info),
|
||||
core.new_jaxpr_eqn(
|
||||
new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2, new_body_itoken2],
|
||||
@ -1472,8 +1470,7 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
dict(
|
||||
call_jaxpr=transformed_cond_jaxpr.jaxpr,
|
||||
name="cond_body",
|
||||
donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals),
|
||||
inline=False),
|
||||
donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals)),
|
||||
eqn.source_info)
|
||||
]
|
||||
new_body_jaxpr = core.ClosedJaxpr(
|
||||
|
@ -310,9 +310,7 @@ def _interpret_fun(fun: lu.WrappedFun,
|
||||
) -> Sequence[Tuple[TfVal, core.AbstractValue]]:
|
||||
with core.new_base_main(TensorFlowTrace) as main: # type: ignore
|
||||
fun = _interpret_subtrace(fun, main, in_avals)
|
||||
with core.new_sublevel():
|
||||
out_vals: Sequence[Tuple[TfVal, core.AbstractValue]] = \
|
||||
fun.call_wrapped(*in_vals)
|
||||
out_vals: Sequence[Tuple[TfVal, core.AbstractValue]] = fun.call_wrapped(*in_vals)
|
||||
del main
|
||||
return tuple(out_vals)
|
||||
|
||||
@ -677,15 +675,14 @@ class TensorFlowTrace(core.Trace):
|
||||
assert call_primitive.multiple_results
|
||||
vals: Sequence[TfVal] = [t.val for t in tracers]
|
||||
f = _interpret_subtrace(f, self.main, tuple(t.aval for t in tracers))
|
||||
with core.new_sublevel():
|
||||
if call_primitive == core.named_call_p:
|
||||
with tf.name_scope(_sanitize_scope_name(params["name"])):
|
||||
vals_out: Sequence[Tuple[TfVal, core.AbstractValue]] = \
|
||||
f.call_wrapped(*vals)
|
||||
elif call_primitive == sharded_jit.sharded_call_p:
|
||||
vals_out = _sharded_call(f, vals, **params)
|
||||
else:
|
||||
vals_out = f.call_wrapped(*vals)
|
||||
if call_primitive == core.named_call_p:
|
||||
with tf.name_scope(_sanitize_scope_name(params["name"])):
|
||||
vals_out: Sequence[Tuple[TfVal,
|
||||
core.AbstractValue]] = f.call_wrapped(*vals)
|
||||
elif call_primitive == sharded_jit.sharded_call_p:
|
||||
vals_out = _sharded_call(f, vals, **params)
|
||||
else:
|
||||
vals_out = f.call_wrapped(*vals)
|
||||
return [TensorFlowTracer(self, v, a) for v, a in vals_out]
|
||||
|
||||
def post_process_call(self, call_primitive: core.Primitive,
|
||||
@ -2176,7 +2173,7 @@ def _sharded_call(f: lu.WrappedFun, vals: Sequence[TfVal],
|
||||
out_parts_thunk,
|
||||
**_) -> Sequence[Tuple[TfVal, core.AbstractValue]]:
|
||||
sharded_vals = util.safe_map(split_to_logical_devices, vals, in_parts)
|
||||
vals_out = f.call_wrapped(*sharded_vals) # caller handles new_sublevel
|
||||
vals_out = f.call_wrapped(*sharded_vals)
|
||||
out_parts_flat = out_parts_thunk()
|
||||
assert len(out_parts_flat) == len(vals_out), f"expected {len(out_parts_flat)} == {len(vals_out)}"
|
||||
sharded_vals_out = [
|
||||
|
@ -1005,7 +1005,7 @@ def _xmap_translation_rule_replica(c, axis_env,
|
||||
# NOTE: We don't extend the resource env with the mesh shape, because those
|
||||
# resources are already in scope! It's the outermost xmap that introduces
|
||||
# them!
|
||||
vectorized_jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(f, local_avals)
|
||||
vectorized_jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(f, local_avals)
|
||||
_check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
|
||||
assert not consts
|
||||
|
||||
@ -1123,7 +1123,7 @@ def _xmap_translation_rule_spmd(c, axis_env,
|
||||
global_in_avals = [core.ShapedArray(xla_type.dimensions(), xla_type.numpy_dtype())
|
||||
for in_node in global_in_nodes
|
||||
for xla_type in (c.get_shape(in_node),)]
|
||||
vectorized_jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
|
||||
vectorized_jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(f, global_in_avals)
|
||||
assert not consts
|
||||
|
||||
global_sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)
|
||||
|
@ -242,7 +242,6 @@ class JaxprTrace(Trace):
|
||||
|
||||
# We use post_process_call to handle both call and map primitives.
|
||||
def post_process_call(self, primitive, out_tracers, params):
|
||||
|
||||
jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
|
||||
out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
|
||||
out = out_pv_consts + consts
|
||||
@ -322,8 +321,7 @@ class JaxprTrace(Trace):
|
||||
def jvp_jaxpr_thunk():
|
||||
jvp_ = trace_to_subjaxpr(jvp, self.main, True)
|
||||
jvp_, aux = partial_eval_wrapper(jvp_, tuple(in_avals) * 2)
|
||||
with core.new_sublevel():
|
||||
out_flat = jvp_.call_wrapped(*(in_consts * 2)) # in_consts are units
|
||||
out_flat = jvp_.call_wrapped(*(in_consts * 2)) # in_consts are units
|
||||
out_avals, jaxpr, env = aux()
|
||||
_, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
|
||||
converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
|
||||
@ -362,8 +360,7 @@ class JaxprTrace(Trace):
|
||||
def fwd_jaxpr_thunk():
|
||||
fwd_ = trace_to_subjaxpr(fwd, self.main, True)
|
||||
fwd_, aux = partial_eval_wrapper(fwd_, tuple(in_avals))
|
||||
with core.new_sublevel():
|
||||
out_flat = fwd_.call_wrapped(*in_consts) # in_consts are units
|
||||
out_flat = fwd_.call_wrapped(*in_consts) # in_consts are units
|
||||
out_avals, jaxpr, env = aux()
|
||||
_, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
|
||||
converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
|
||||
@ -1061,9 +1058,8 @@ class DynamicJaxprTrace(core.Trace):
|
||||
|
||||
def process_call(self, call_primitive, f, tracers, params):
|
||||
in_avals = [t.aval for t in tracers]
|
||||
with core.new_sublevel():
|
||||
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
|
||||
if params.get('inline', False):
|
||||
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
|
||||
if not jaxpr.eqns:
|
||||
return core.eval_jaxpr(jaxpr, consts, *tracers)
|
||||
source_info = source_info_util.current()
|
||||
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
|
||||
@ -1089,9 +1085,8 @@ class DynamicJaxprTrace(core.Trace):
|
||||
if in_axis is not None else a
|
||||
for a, in_axis in zip(in_avals, params['in_axes'])]
|
||||
with core.extend_axis_env(axis_name, axis_size, None): # type: ignore
|
||||
with core.new_sublevel():
|
||||
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
|
||||
f, self.main, reduced_in_avals)
|
||||
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
|
||||
f, self.main, reduced_in_avals)
|
||||
out_axes = params['out_axes_thunk']()
|
||||
out_avals = [core.unmapped_aval(params['axis_size'], out_axis, a)
|
||||
if out_axis is not None else a
|
||||
@ -1118,8 +1113,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
|
||||
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
|
||||
in_avals = [t.aval for t in tracers]
|
||||
with core.new_sublevel():
|
||||
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
|
||||
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
|
||||
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
|
||||
jvp_jaxpr_thunk = _memoize(
|
||||
lambda: trace_to_subjaxpr_dynamic(jvp, self.main, 2 * in_avals)[::2])
|
||||
@ -1140,8 +1134,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
|
||||
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
|
||||
in_avals = [t.aval for t in tracers]
|
||||
with core.new_sublevel():
|
||||
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
|
||||
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
|
||||
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
|
||||
fwd_jaxpr_thunk = _memoize(
|
||||
lambda: trace_to_subjaxpr_dynamic(fwd, self.main, in_avals)[::2])
|
||||
@ -1213,8 +1206,7 @@ def trace_to_jaxpr_final(fun: lu.WrappedFun,
|
||||
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
|
||||
main.source_info = fun_sourceinfo(fun.f, transform_name) # type: ignore
|
||||
main.jaxpr_stack = () # type: ignore
|
||||
with core.new_sublevel():
|
||||
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
|
||||
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
|
||||
del fun, main
|
||||
return jaxpr, out_avals, consts
|
||||
|
||||
|
@ -218,10 +218,7 @@ def primitive_uses_outfeed(prim: core.Primitive, params: Dict) -> bool:
|
||||
|
||||
### op-by-op execution
|
||||
|
||||
|
||||
ArgSpec = Tuple[core.AbstractValue, Optional[Device]]
|
||||
|
||||
def arg_spec(x: Any) -> ArgSpec:
|
||||
def arg_spec(x):
|
||||
aval = abstractify(x)
|
||||
try:
|
||||
return aval, x._device
|
||||
@ -243,7 +240,8 @@ def _partition_outputs(avals, outs):
|
||||
|
||||
|
||||
@cache()
|
||||
def xla_primitive_callable(prim, *arg_specs: ArgSpec, **params):
|
||||
def xla_primitive_callable(prim, *arg_specs: Tuple[core.AbstractValue,
|
||||
Optional[Device]], **params):
|
||||
avals, arg_devices = unzip2(arg_specs)
|
||||
donated_invars = (False,) * len(arg_specs)
|
||||
device = _device_from_arg_devices(arg_devices)
|
||||
@ -575,9 +573,7 @@ def jaxpr_collectives(jaxpr):
|
||||
|
||||
### xla_call underlying jit
|
||||
|
||||
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
|
||||
donated_invars, inline):
|
||||
del inline # Only used at tracing time
|
||||
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
|
||||
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
|
||||
*unsafe_map(arg_spec, args))
|
||||
try:
|
||||
@ -595,8 +591,7 @@ def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
|
||||
# intentional here, to avoid "Store occupied" errors we reset the stores to
|
||||
# be empty.
|
||||
for store in fun.stores: store and store.reset()
|
||||
with core.new_sublevel():
|
||||
return fun.call_wrapped(*args) # probably won't return
|
||||
return fun.call_wrapped(*args) # probably won't return
|
||||
|
||||
def flatten_shape(s: XlaShape) -> Sequence[Tuple[Sequence[int], XlaShape]]:
|
||||
"""Expands a given shape tree into a flat list of indices to arrays.
|
||||
@ -697,8 +692,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
|
||||
|
||||
c = xb.make_computation_builder("jit_{}".format(fun.__name__))
|
||||
xla_consts = _xla_consts(c, consts)
|
||||
xla_args, donated_invars = _xla_callable_args(c, abstract_args, tuple_args,
|
||||
donated_invars=donated_invars)
|
||||
xla_args, donated_invars = _xla_callable_args(c, abstract_args, tuple_args, donated_invars=donated_invars)
|
||||
out_nodes = jaxpr_subcomp(
|
||||
c, jaxpr, backend, AxisEnv(nreps, (), ()), xla_consts,
|
||||
extend_name_stack(wrap_name(name, 'jit')), *xla_args)
|
||||
@ -795,9 +789,9 @@ def _xla_callable_args(
|
||||
for (a, r, p) in safe_zip(avals, replicated, parts)
|
||||
for xla_shape in aval_to_xla_shapes(a)]
|
||||
if donated_invars is not None:
|
||||
donated_invars = [
|
||||
d for (a, _, _, d) in zip(avals, replicated, parts, donated_invars)
|
||||
for xla_shape in aval_to_xla_shapes(a)]
|
||||
donated_invars = [d
|
||||
for (a, r, p, d) in safe_zip(avals, replicated, parts, donated_invars)
|
||||
for xla_shape in aval_to_xla_shapes(a)]
|
||||
return xla_args, donated_invars
|
||||
else:
|
||||
if replicated is not None:
|
||||
@ -891,8 +885,8 @@ ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params
|
||||
|
||||
def _xla_call_translation_rule(c, axis_env,
|
||||
in_nodes, name_stack, backend, name,
|
||||
call_jaxpr, donated_invars, inline, device=None):
|
||||
del device, donated_invars, inline # Ignored.
|
||||
call_jaxpr, donated_invars, device=None):
|
||||
del device, donated_invars # Ignored.
|
||||
subc = xb.make_computation_builder(f"jit_{name}")
|
||||
args = [xb.parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
|
||||
out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
|
||||
|
@ -2343,7 +2343,7 @@ class APITest(jtu.JaxTestCase):
|
||||
lst.append(x)
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(Exception, r"Leaked"):
|
||||
with self.assertRaisesRegex(Exception, r"Leaked trace"):
|
||||
f(3)
|
||||
|
||||
def test_leak_checker_catches_a_pmap_leak(self):
|
||||
@ -2355,7 +2355,7 @@ class APITest(jtu.JaxTestCase):
|
||||
lst.append(x)
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(Exception, r"Leaked"):
|
||||
with self.assertRaisesRegex(Exception, r"Leaked trace"):
|
||||
f(np.ones(1))
|
||||
|
||||
def test_leak_checker_catches_a_grad_leak(self):
|
||||
@ -2678,21 +2678,6 @@ class APITest(jtu.JaxTestCase):
|
||||
jtu.check_grads(batched_scan_over_mul, (x_batch, coeff), order=2,
|
||||
modes=['rev'])
|
||||
|
||||
def test_jit_inline(self):
|
||||
@partial(api.jit, inline=False)
|
||||
def f(x):
|
||||
return x * 2
|
||||
|
||||
jaxpr = api.make_jaxpr(f)(3)
|
||||
self.assertIn('xla_call', str(jaxpr))
|
||||
|
||||
@partial(api.jit, inline=True)
|
||||
def f(x):
|
||||
return x * 2
|
||||
|
||||
jaxpr = api.make_jaxpr(f)(3)
|
||||
self.assertNotIn('xla_call', str(jaxpr))
|
||||
|
||||
|
||||
class RematTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user