mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
refactor call primitives, simpler param processing (#3491)
This commit is contained in:
parent
d5a5d301f2
commit
75278309aa
13
jax/api.py
13
jax/api.py
@ -1150,16 +1150,11 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
|
||||
for arg in args: _check_arg(arg)
|
||||
flat_fun, out_tree = flatten_fun(f, in_tree)
|
||||
out = pxla.xla_pmap(
|
||||
flat_fun,
|
||||
*args,
|
||||
backend=backend,
|
||||
axis_name=axis_name,
|
||||
axis_size=local_axis_size,
|
||||
global_axis_size=axis_size,
|
||||
devices=tuple(devices) if devices is not None else devices,
|
||||
name=flat_fun.__name__,
|
||||
flat_fun, *args, backend=backend, axis_name=axis_name,
|
||||
axis_size=local_axis_size, global_axis_size=axis_size,
|
||||
devices=None if devices is None else tuple(devices),
|
||||
mapped_invars=tuple(axis is not None for axis in in_axes_flat),
|
||||
donated_invars=tuple(donated_invars))
|
||||
name=flat_fun.__name__, donated_invars=tuple(donated_invars))
|
||||
return tree_unflatten(out_tree(), out)
|
||||
|
||||
return f_pmapped
|
||||
|
64
jax/core.py
64
jax/core.py
@ -1062,7 +1062,7 @@ def canonicalize_shape(shape):
|
||||
raise TypeError(msg.format(shape))
|
||||
|
||||
|
||||
# ------------------- Call and map -------------------
|
||||
# ------------------- Call -------------------
|
||||
|
||||
def apply_todos(todos, outs):
|
||||
todos_list = list(todos)
|
||||
@ -1071,7 +1071,7 @@ def apply_todos(todos, outs):
|
||||
return outs
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def process_env_traces(post_processor: str, primitive: Primitive,
|
||||
def process_env_traces(primitive: Union['CallPrimitive', 'MapPrimitive'],
|
||||
level: int, params_tuple: tuple, *args):
|
||||
outs = yield args, {}
|
||||
params = dict(params_tuple)
|
||||
@ -1084,41 +1084,59 @@ def process_env_traces(post_processor: str, primitive: Primitive,
|
||||
break
|
||||
trace = type(ans._trace)(ans._trace.master, cur_sublevel())
|
||||
outs = map(trace.full_raise, outs)
|
||||
post_process = getattr(trace, post_processor)
|
||||
outs, cur_todo = post_process(primitive, outs, params)
|
||||
outs, cur_todo = primitive.post_process(trace, outs, params)
|
||||
todo.append(cur_todo)
|
||||
yield outs, tuple(todo) # Ensure the aux output is immutable
|
||||
|
||||
def _call_bind(processor: str, post_processor: str, primitive: Primitive,
|
||||
f: lu.WrappedFun, *args, **params):
|
||||
def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
|
||||
fun: lu.WrappedFun, *args, **params):
|
||||
params_tuple = tuple(params.items())
|
||||
top_trace = find_top_trace(args)
|
||||
level = trace_state.trace_stack.next_level(True) if top_trace is None else top_trace.level
|
||||
params_tuple = tuple(params.items())
|
||||
f, env_trace_todo = process_env_traces(f, post_processor, primitive, level, params_tuple)
|
||||
fun, env_trace_todo = process_env_traces(fun, primitive, level, params_tuple)
|
||||
if top_trace is None:
|
||||
with new_sublevel():
|
||||
outs = primitive.impl(f, *args, **params)
|
||||
outs = primitive.impl(fun, *args, **params)
|
||||
else:
|
||||
tracers = map(top_trace.full_raise, args)
|
||||
process = getattr(top_trace, processor)
|
||||
outs = map(full_lower, process(primitive, f, tracers, params))
|
||||
outs = primitive.process(top_trace, fun, tracers, params)
|
||||
return apply_todos(env_trace_todo(), outs)
|
||||
|
||||
call_bind = partial(_call_bind, 'process_call', 'post_process_call')
|
||||
map_bind = partial(_call_bind, 'process_map', 'post_process_map')
|
||||
class CallPrimitive(Primitive):
|
||||
multiple_results = True
|
||||
call_primitive = True
|
||||
bind = call_bind
|
||||
|
||||
def process(self, trace, fun, tracers, params):
|
||||
return trace.process_call(self, fun, tracers, params)
|
||||
|
||||
def post_process(self, trace, out_tracers, params):
|
||||
return trace.post_process_call(self, out_tracers, params)
|
||||
|
||||
def call_impl(f: lu.WrappedFun, *args, **params):
|
||||
del params # params parameterize the call primitive, not the function
|
||||
return f.call_wrapped(*args)
|
||||
|
||||
call_p = Primitive('call')
|
||||
call_p.multiple_results = True
|
||||
call_p.call_primitive = True
|
||||
call = partial(call_bind, call_p)
|
||||
call_p.def_custom_bind(call)
|
||||
call_p = CallPrimitive('call')
|
||||
call = call_p.bind
|
||||
call_p.def_impl(call_impl)
|
||||
|
||||
# ------------------- Map -------------------
|
||||
|
||||
class MapPrimitive(Primitive):
|
||||
multiple_results = True
|
||||
map_primitive = True
|
||||
|
||||
def bind(self, fun, *args, **params):
|
||||
assert len(params['mapped_invars']) == len(args)
|
||||
return call_bind(self, fun, *args, **params)
|
||||
|
||||
def process(self, trace, fun, tracers, params):
|
||||
return trace.process_map(self, fun, tracers, params)
|
||||
|
||||
def post_process(self, trace, out_tracers, params):
|
||||
return trace.post_process_map(self, out_tracers, params)
|
||||
|
||||
# ------------------- Jaxpr checking -------------------
|
||||
|
||||
@ -1168,14 +1186,13 @@ def check_jaxpr(jaxpr: Jaxpr):
|
||||
try:
|
||||
_check_jaxpr(jaxpr, [v.aval for v in jaxpr.invars])
|
||||
except Exception as e:
|
||||
exception_type = type(e)
|
||||
msg_context = f"while checking jaxpr:\n\n{jaxpr}\n"
|
||||
if len(e.args) == 0:
|
||||
exception_args = [msg_context]
|
||||
else:
|
||||
msg = f"{e.args[0]}\n\n" + msg_context
|
||||
msg = f"{e.args[0]}\n\n{msg_context}"
|
||||
exception_args = [msg, *e.args[1:]]
|
||||
raise exception_type(*exception_args) from e
|
||||
raise type(e)(*exception_args) from e
|
||||
|
||||
def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
|
||||
|
||||
@ -1203,6 +1220,11 @@ def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
|
||||
map(write, jaxpr.invars, in_avals)
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
|
||||
if eqn.primitive in skip_check_primitives:
|
||||
map(write, eqn.outvars, [v.aval for v in eqn.outvars]) # skip checking
|
||||
continue
|
||||
|
||||
in_avals = map(read, eqn.invars)
|
||||
if eqn.primitive.call_primitive:
|
||||
out_avals = check_call(eqn.primitive, in_avals, eqn.params)
|
||||
@ -1218,6 +1240,8 @@ def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
|
||||
|
||||
map(read, jaxpr.outvars)
|
||||
|
||||
skip_check_primitives: Set[Primitive] = set()
|
||||
|
||||
def check_eqn(prim, in_avals, params):
|
||||
for jaxpr in jaxprs_in_params(params):
|
||||
check_jaxpr(jaxpr)
|
||||
|
@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial, update_wrapper, reduce
|
||||
from functools import update_wrapper, reduce
|
||||
import inspect
|
||||
import operator as op
|
||||
|
||||
@ -262,25 +262,30 @@ def _flatten_jvp(in_tree, *args):
|
||||
raise TypeError(msg.format('\n'.join(disagreements)))
|
||||
yield primals_out + tangents_out, out_tree
|
||||
|
||||
def _custom_jvp_call_bind(prim, fun, jvp, *args):
|
||||
args = map(core.full_lower, args)
|
||||
top_trace = core.find_top_trace(args)
|
||||
if top_trace is None:
|
||||
with core.new_sublevel():
|
||||
outs = prim.impl(fun, jvp, *args)
|
||||
else:
|
||||
tracers = map(top_trace.full_raise, args)
|
||||
outs = top_trace.process_custom_jvp_call(prim, fun, jvp, tracers)
|
||||
return map(core.full_lower, outs)
|
||||
class CustomJVPCallPrimitive(core.CallPrimitive):
|
||||
def bind(self, fun, jvp, *args):
|
||||
args = map(core.full_lower, args)
|
||||
top_trace = core.find_top_trace(args)
|
||||
fun, env_trace_todo1 = core.process_env_traces(
|
||||
fun, self, top_trace and top_trace.level, ())
|
||||
jvp, env_trace_todo2 = core.process_env_traces(
|
||||
jvp, self, top_trace and top_trace.level, ())
|
||||
if top_trace is None:
|
||||
with core.new_sublevel():
|
||||
outs = self.impl(fun, jvp, *args)
|
||||
else:
|
||||
tracers = map(top_trace.full_raise, args)
|
||||
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers)
|
||||
_, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
|
||||
if env_trace_todo:
|
||||
raise core.UnexpectedTracerError
|
||||
return map(core.full_lower, outs)
|
||||
|
||||
def _custom_jvp_call_impl(fun, _, *args):
|
||||
return fun.call_wrapped(*args)
|
||||
def impl(self, fun, _, *args):
|
||||
return fun.call_wrapped(*args)
|
||||
|
||||
custom_jvp_call_p = core.Primitive('custom_jvp_call')
|
||||
custom_jvp_call_p.multiple_results = True
|
||||
custom_jvp_call = partial(_custom_jvp_call_bind, custom_jvp_call_p)
|
||||
custom_jvp_call_p.def_custom_bind(custom_jvp_call)
|
||||
custom_jvp_call_p.def_impl(_custom_jvp_call_impl)
|
||||
custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')
|
||||
custom_jvp_call = custom_jvp_call_p.bind
|
||||
|
||||
|
||||
def custom_jvp_call_jaxpr(fun, jvp, *args):
|
||||
@ -501,28 +506,25 @@ def _flatten_bwd(in_tree, out_trees, *args):
|
||||
raise TypeError(msg.format(in_tree2, in_tree)) from None
|
||||
yield cts_in
|
||||
|
||||
def _custom_vjp_call_bind(prim, fun, fwd, bwd, *args, out_trees):
|
||||
args = map(core.full_lower, args)
|
||||
top_trace = core.find_top_trace(args)
|
||||
if top_trace is None:
|
||||
with core.new_sublevel():
|
||||
outs = prim.impl(fun, fwd, bwd, *args, out_trees=out_trees)
|
||||
else:
|
||||
tracers = map(top_trace.full_raise, args)
|
||||
outs = top_trace.process_custom_vjp_call(prim, fun, fwd, bwd, tracers,
|
||||
out_trees=out_trees)
|
||||
outs = map(core.full_lower, outs)
|
||||
return map(core.full_lower, outs)
|
||||
|
||||
def _custom_vjp_call_impl(fun, fwd, bwd, *args, out_trees):
|
||||
del fwd, bwd, out_trees # Unused.
|
||||
return fun.call_wrapped(*args)
|
||||
class CustomVJPCallPrimitive(core.CallPrimitive):
|
||||
def bind(self, fun, fwd, bwd, *args, out_trees):
|
||||
args = map(core.full_lower, args)
|
||||
top_trace = core.find_top_trace(args)
|
||||
if top_trace is None:
|
||||
outs = fun.call_wrapped(*args)
|
||||
else:
|
||||
tracers = map(top_trace.full_raise, args)
|
||||
outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers,
|
||||
out_trees=out_trees)
|
||||
return map(core.full_lower, outs)
|
||||
|
||||
custom_vjp_call_p = core.Primitive('custom_vjp_call')
|
||||
custom_vjp_call_p.multiple_results = True
|
||||
custom_vjp_call = partial(_custom_vjp_call_bind, custom_vjp_call_p)
|
||||
custom_vjp_call_p.def_custom_bind(custom_vjp_call)
|
||||
custom_vjp_call_p.def_impl(_custom_vjp_call_impl)
|
||||
def impl(self, fun, fwd, bwd, *args, out_trees):
|
||||
del fwd, bwd, out_trees
|
||||
return fun.call_wrapped(*args)
|
||||
|
||||
custom_vjp_call_p = CustomVJPCallPrimitive('custom_vjp_call')
|
||||
custom_vjp_call = custom_vjp_call_p.bind
|
||||
|
||||
def custom_vjp_call_jaxpr(fun, fwd, bwd, *args, out_trees):
|
||||
in_avals = [raise_to_shaped(core.get_aval(x)) for x in args]
|
||||
|
@ -632,7 +632,8 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn,
|
||||
pred1_and_token1,
|
||||
xla.xla_call_p,
|
||||
dict(call_jaxpr=transformed_cond_jaxpr.jaxpr,
|
||||
name="cond_before"),
|
||||
name="cond_before",
|
||||
donated_invars=(False,) * (cond_nconsts + len(carry_invars) + 1)),
|
||||
eqn.source_info))
|
||||
# Make a new cond "lambda pred, carry, token: pred"
|
||||
new_cond_pred_invar = mk_new_var(cond_jaxpr.out_avals[0])
|
||||
@ -667,14 +668,19 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn,
|
||||
new_body_carry2 + [new_body_token2],
|
||||
xla.xla_call_p,
|
||||
dict(call_jaxpr=transformed_body_jaxpr.jaxpr,
|
||||
name="body"),
|
||||
name="body",
|
||||
donated_invars=(False,) * (len(new_body_invars_body_constvars) +
|
||||
len(new_body_invars_carry) +
|
||||
1 + len(new_body_carry2) + 1)),
|
||||
eqn.source_info),
|
||||
core.new_jaxpr_eqn(
|
||||
new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2],
|
||||
[new_body_pred2, new_body_token3],
|
||||
xla.xla_call_p,
|
||||
dict(call_jaxpr=transformed_cond_jaxpr.jaxpr,
|
||||
name="cond_body"),
|
||||
name="cond_body",
|
||||
donated_invars=(False,) * (len(new_body_invars_cond_constvars) +
|
||||
len(new_body_carry2) + 1 + 2)),
|
||||
eqn.source_info)
|
||||
]
|
||||
new_body_jaxpr = _mk_typed_jaxpr(
|
||||
|
@ -243,12 +243,10 @@ class JVPTrace(Trace):
|
||||
|
||||
def process_primitive(self, primitive, tracers, params):
|
||||
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
|
||||
try:
|
||||
jvp = primitive_jvps[primitive]
|
||||
except KeyError as err:
|
||||
raise NotImplementedError(
|
||||
"Forward-mode differentiation rule for '{}' not implemented"
|
||||
.format(primitive)) from err
|
||||
jvp = primitive_jvps.get(primitive)
|
||||
if not jvp:
|
||||
msg = f"Differentiation rule for '{primitive}' not implemented"
|
||||
raise NotImplementedError(msg)
|
||||
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
|
||||
if primitive.multiple_results:
|
||||
return [JVPTracer(self, x, t) for x, t in zip(primal_out, tangent_out)]
|
||||
@ -258,52 +256,35 @@ class JVPTrace(Trace):
|
||||
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
|
||||
assert call_primitive.multiple_results
|
||||
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
|
||||
nonzero_tangents, in_tree_def = tree_flatten(tangents)
|
||||
nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
|
||||
f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master),
|
||||
len(primals), in_tree_def)
|
||||
name = params.get('name', f.__name__)
|
||||
new_params = dict(params, name=wrap_name(name, 'jvp'))
|
||||
if 'donated_invars' in new_params:
|
||||
new_donated_invars = (*params['donated_invars'],
|
||||
*[m for m, t in zip(params['donated_invars'], tangents)
|
||||
if type(t) is not Zero])
|
||||
new_params['donated_invars'] = tuple(new_donated_invars)
|
||||
len(primals), tangent_tree_def)
|
||||
nz_tangents = [type(t) is not Zero for t in tangents]
|
||||
params = dict(params, name=wrap_name(params['name'], 'jvp'))
|
||||
if isinstance(call_primitive, core.MapPrimitive):
|
||||
mapped_invars = params['mapped_invars']
|
||||
mapped_tangents = [m for m, nz in zip(mapped_invars, nz_tangents) if nz]
|
||||
params = dict(params, mapped_invars=(*mapped_invars, *mapped_tangents))
|
||||
update_params = call_param_updaters.get(call_primitive)
|
||||
new_params = update_params(params, nz_tangents) if update_params else params
|
||||
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
|
||||
primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
|
||||
return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
|
||||
|
||||
def post_process_call(self, call_primitive, out_tracers, params):
|
||||
primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
|
||||
out = primals + tangents
|
||||
out, treedef = tree_flatten((primals, tangents))
|
||||
del primals, tangents
|
||||
master = self.master
|
||||
def todo(x):
|
||||
n = len(x) // 2
|
||||
primals, tangents = x[:n], x[n:]
|
||||
primals, tangents = tree_unflatten(treedef, x)
|
||||
trace = JVPTrace(master, core.cur_sublevel())
|
||||
return map(partial(JVPTracer, trace), primals, tangents)
|
||||
return out, todo
|
||||
|
||||
def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
|
||||
# only differs from process_call in that it must update mapped_invars
|
||||
# TODO de-duplicate code
|
||||
assert map_primitive.multiple_results
|
||||
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
|
||||
nonzero_tangents, in_tree_def = tree_flatten(tangents)
|
||||
f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master),
|
||||
len(primals), in_tree_def)
|
||||
new_name = wrap_name(params.get('name', f.__name__), 'jvp')
|
||||
new_mapped_invars = (*params['mapped_invars'],
|
||||
*[m for m, t in zip(params['mapped_invars'], tangents)
|
||||
if type(t) is not Zero])
|
||||
new_donated_invars = (*params['donated_invars'],
|
||||
*[m for m, t in zip(params['donated_invars'], tangents)
|
||||
if type(t) is not Zero])
|
||||
new_params = dict(params, name=new_name, mapped_invars=new_mapped_invars,
|
||||
donated_invars=new_donated_invars)
|
||||
result = map_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
|
||||
primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
|
||||
return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
|
||||
# The only difference between process_map and process_call is that
|
||||
# the `mapped_invars` param must be updated; that's handled in process_call.
|
||||
process_map = process_call
|
||||
post_process_map = post_process_call
|
||||
|
||||
def process_custom_jvp_call(self, _, __, f_jvp, tracers):
|
||||
@ -363,11 +344,14 @@ def _primal_tangent_shapes_match(primal, tangent):
|
||||
if type(tangent) is not Zero:
|
||||
primal_aval = raise_to_shaped(get_aval(primal))
|
||||
tangent_aval = raise_to_shaped(get_aval(tangent))
|
||||
assert primal_aval == tangent_aval
|
||||
assert primal_aval == tangent_aval, (primal_aval, tangent_aval)
|
||||
|
||||
call_param_updaters: Dict[core.Primitive, Callable] = {}
|
||||
call_transpose_param_updaters: Dict[core.Primitive, Callable] = {}
|
||||
|
||||
|
||||
# -------------------- Primitives --------------------
|
||||
|
||||
|
||||
primitive_jvps : Dict[core.Primitive, Callable] = {}
|
||||
|
||||
primitive_transposes: Dict[core.Primitive, Callable] = {}
|
||||
@ -492,13 +476,12 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _):
|
||||
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
|
||||
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
params = dict(params, name=wrap_name(params['name'], 'transpose'))
|
||||
if 'donated_invars' in params:
|
||||
new_donated_invars = (*[d for d, x in zip(params['donated_invars'], args)
|
||||
if not is_undefined_primal(x)],
|
||||
*[False for x in ct if type(x) is not Zero])
|
||||
params['donated_invars'] = tuple(new_donated_invars)
|
||||
out_flat = primitive.bind(fun, *all_args, **params)
|
||||
new_params = dict(params, name=wrap_name(params['name'], 'transpose'))
|
||||
update_params = call_transpose_param_updaters.get(primitive)
|
||||
if update_params:
|
||||
new_params = update_params(new_params, map(is_undefined_primal, args),
|
||||
[type(x) is not Zero for x in ct])
|
||||
out_flat = primitive.bind(fun, *all_args, **new_params)
|
||||
return tree_unflatten(out_tree(), out_flat)
|
||||
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
|
||||
|
||||
@ -507,14 +490,12 @@ def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_
|
||||
# backward_pass can only transpose linear computations, but the call_jaxpr embedded in
|
||||
# remat contains primal (non-linear) equations too. Hence, we have to eliminate those
|
||||
# (in this case via partial_eval) before we call into backward_pass again.
|
||||
typed_call_jaxpr = core.TypedJaxpr(
|
||||
call_jaxpr, [],
|
||||
[raise_to_shaped(p.aval if is_undefined_primal(p) else get_aval(p)) for p in primals_in],
|
||||
cotangent_in_avals)
|
||||
in_avals = [raise_to_shaped(p.aval if is_undefined_primal(p) else get_aval(p))
|
||||
for p in primals_in]
|
||||
typed_call_jaxpr = core.TypedJaxpr(call_jaxpr, [], in_avals, cotangent_in_avals)
|
||||
unknowns = map(is_undefined_primal, primals_in)
|
||||
primal_jaxpr, tangent_jaxpr, out_unknowns = \
|
||||
pe.partial_eval_jaxpr(typed_call_jaxpr,
|
||||
unknowns=map(is_undefined_primal, primals_in),
|
||||
instantiate=True,
|
||||
pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True,
|
||||
trace_type=None)
|
||||
|
||||
def do_transpose(primals_in, cotangents_in):
|
||||
@ -541,12 +522,12 @@ def map_transpose(primitive, params, call_jaxpr, args, ct, _):
|
||||
new_mapped_invars = (*[m for m, x in zip(params['mapped_invars'], args)
|
||||
if not is_undefined_primal(x)],
|
||||
*[True for x in ct if type(x) is not Zero])
|
||||
new_donated_invars = (*[d for d, x in zip(params['donated_invars'], args)
|
||||
if not is_undefined_primal(x)],
|
||||
*[False for x in ct if type(x) is not Zero])
|
||||
new_params = dict(params, name=wrap_name(params['name'], 'transpose'),
|
||||
mapped_invars=tuple(new_mapped_invars),
|
||||
donated_invars=tuple(new_donated_invars))
|
||||
mapped_invars=new_mapped_invars)
|
||||
update_params = call_transpose_param_updaters.get(primitive)
|
||||
if update_params:
|
||||
new_params = update_params(new_params, map(is_undefined_primal, args),
|
||||
[type(x) is not Zero for x in ct])
|
||||
out_flat = primitive.bind(fun, *all_args, **new_params)
|
||||
arg_cts = tree_unflatten(out_tree(), out_flat)
|
||||
|
||||
|
@ -36,12 +36,9 @@ zip = safe_zip
|
||||
# Reverse call primitive
|
||||
################################################################################
|
||||
|
||||
invertible_call_p = core.Primitive('invertible_call')
|
||||
invertible_call_p.call_primitive = True
|
||||
invertible_call = partial(core.call_bind, invertible_call_p)
|
||||
invertible_call_p.def_custom_bind(invertible_call)
|
||||
invertible_call_p = core.CallPrimitive('invertible_call')
|
||||
invertible_call = invertible_call_p.bind
|
||||
invertible_call_p.def_impl(core.call_impl)
|
||||
invertible_call_p.multiple_results = True
|
||||
|
||||
def _invertible_call_make_output_tracers(trace, in_tracers, out_tracers, params):
|
||||
uks = [not t.pval.is_known() for t in out_tracers]
|
||||
@ -67,6 +64,8 @@ def _invertible_call_make_output_tracers(trace, in_tracers, out_tracers, params)
|
||||
pe.call_partial_eval_rules[invertible_call_p] = partial(
|
||||
pe._remat_partial_eval, _invertible_call_make_output_tracers)
|
||||
|
||||
# TODO(mattjj): remove this when #3370 lands
|
||||
core.skip_check_primitives.add(invertible_call_p)
|
||||
|
||||
@cache()
|
||||
def _append_invars(jaxpr, avals):
|
||||
@ -259,10 +258,8 @@ def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotang
|
||||
|
||||
in_avals = map(abstract, primals_in + primals_out + primals_out)
|
||||
ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
|
||||
complete_ivjp_flat,
|
||||
map(PartialVal.unknown, in_avals),
|
||||
instantiate=True,
|
||||
stage_out=False)
|
||||
complete_ivjp_flat, map(PartialVal.unknown, in_avals),
|
||||
instantiate=True, stage_out=False)
|
||||
assert not ivjp_jaxpr.constvars # That might happen some time, but don't bother until then
|
||||
out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
|
||||
ivjp_jaxpr = core.TypedJaxpr(ivjp_jaxpr, [], in_avals, out_avals)
|
||||
@ -273,10 +270,8 @@ def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotang
|
||||
unknowns = (map(ad.is_undefined_primal, primals_in) +
|
||||
map(ad.is_undefined_primal, primals_out) +
|
||||
[False] * len(cts_in))
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(ivjp_jaxpr,
|
||||
unknowns,
|
||||
instantiate=False,
|
||||
trace_type=None)
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
|
||||
ivjp_jaxpr, unknowns, instantiate=False, trace_type=None)
|
||||
unknown_rec_primals_in, unknown_cotangents = split_list(out_unknowns, [num_inputs])
|
||||
# Make sure we're able to compute all cotangents. We don't really care if we
|
||||
# can reconstruct or primals or not, although failure to do so might result in
|
||||
|
@ -27,7 +27,7 @@ from .. import linear_util as lu
|
||||
from ..abstract_arrays import ConcreteArray, raise_to_shaped
|
||||
from ..ad_util import Zero
|
||||
from ..util import (unzip2, safe_zip, safe_map, toposort, partial, split_list,
|
||||
cache, curry)
|
||||
cache)
|
||||
from ..core import (Trace, Tracer, new_master, Jaxpr, Literal, get_aval,
|
||||
AbstractValue, unit, unitvar, abstract_unit,
|
||||
TypedJaxpr, new_jaxpr_eqn)
|
||||
@ -73,7 +73,7 @@ class PartialVal(tuple):
|
||||
return self[1] if self[0] is None else None
|
||||
|
||||
def get_aval(self) -> AbstractValue:
|
||||
"""Get the AbstractValue either directly for unknown values, or from the known constant."""
|
||||
"""Get AbstractValue directly (if unknown) or from the constant (known)."""
|
||||
known = self.get_known()
|
||||
if known is not None:
|
||||
return get_aval(known)
|
||||
@ -170,35 +170,29 @@ class JaxprTrace(Trace):
|
||||
return out_tracer
|
||||
|
||||
def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
|
||||
name = params.get('name', f.__name__)
|
||||
if (self.master.trace_type is StagingJaxprTrace
|
||||
and primitive in staged_out_calls):
|
||||
tracers = map(self.instantiate_const_abstracted, tracers)
|
||||
params = dict(params, name=name)
|
||||
|
||||
if primitive in call_partial_eval_rules:
|
||||
return call_partial_eval_rules[primitive](self, primitive, f, tracers, params)
|
||||
|
||||
@curry
|
||||
def modify_aval(modify, args):
|
||||
pval, is_mapped = args
|
||||
if pval.is_known() or not is_mapped:
|
||||
return pval
|
||||
return PartialVal((modify(params['axis_size'], pval[0]), pval[1]))
|
||||
|
||||
in_pvals = [t.pval for t in tracers]
|
||||
if primitive.map_primitive:
|
||||
in_pvals = map(modify_aval(core.mapped_aval), zip(in_pvals, params['mapped_invars']))
|
||||
mapped_aval = partial(core.mapped_aval, params['axis_size'])
|
||||
in_pvals = [pval if pval.is_known() or not is_mapped
|
||||
else PartialVal.unknown(mapped_aval(pval[0]))
|
||||
for pval, is_mapped in zip(in_pvals, params['mapped_invars'])]
|
||||
jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
|
||||
f, in_pvals, partial(primitive.bind, **params))
|
||||
if primitive.map_primitive:
|
||||
out_pvals = map(modify_aval(core.unmapped_aval),
|
||||
[(pval, True) for pval in out_pvals])
|
||||
unmapped_aval = partial(core.unmapped_aval, params['axis_size'])
|
||||
out_pvals = [pval if pval.is_known()
|
||||
else PartialVal.unknown(unmapped_aval(pval[0]))
|
||||
for pval in out_pvals]
|
||||
|
||||
# Don't bother if the traced jaxpr is trivial. Simply evaluate it in here.
|
||||
# XXX: We don't allow this fast path for map primitives, because this simplification might
|
||||
# e.g. reduce the number of required devices if someone pmaps an identity function.
|
||||
if not primitive.map_primitive and not jaxpr.eqns:
|
||||
# Avoid staging out trivial calls, but maps may involve broadcasting.
|
||||
if not jaxpr.eqns and not primitive.map_primitive:
|
||||
env = {core.unitvar: core.unit}
|
||||
map(env.setdefault, jaxpr.invars, (*env_tracers, *tracers))
|
||||
map(env.setdefault, jaxpr.constvars, consts)
|
||||
@ -212,94 +206,76 @@ class JaxprTrace(Trace):
|
||||
out_unknowns = tuple(not pval.is_known() for pval in out_pvals)
|
||||
jaxpr = _drop_invars(jaxpr, in_knowns)
|
||||
jaxpr = _dce_untyped_jaxpr(jaxpr, out_unknowns, drop_outputs=True)
|
||||
lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
|
||||
|
||||
# Known tracers get propagated as if they were constants
|
||||
known_tracers_out = [self.new_const(pval.get_known()) for pval in out_pvals if pval.is_known()]
|
||||
known_tracers_out = [self.new_const(pval.get_known()) for pval in out_pvals
|
||||
if pval.is_known()]
|
||||
|
||||
# Unknown tracers need to have the jaxpr set up as their recipe
|
||||
unknown_tracers_out = [JaxprTracer(self, pval, None) for pval in out_pvals if not pval.is_known()]
|
||||
unknown_tracers_out = [JaxprTracer(self, pval, None) for pval in out_pvals
|
||||
if not pval.is_known()]
|
||||
unknown_tracers_in = [t for t in tracers if not t.pval.is_known()]
|
||||
const_tracers = map(self.new_instantiated_const, consts)
|
||||
new_params = dict(params, call_jaxpr=lifted_jaxpr)
|
||||
if 'donated_invars' in params:
|
||||
new_donated_invars = ((False,) * len(const_tracers) +
|
||||
(False,) * len(env_tracers) +
|
||||
tuple(v for v, t in zip(params['donated_invars'], tracers) if not t.pval.is_known()))
|
||||
new_params['donated_invars'] = new_donated_invars
|
||||
in_tracers = (*const_tracers, *env_tracers, *unknown_tracers_in)
|
||||
|
||||
# Set up new params
|
||||
new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
|
||||
if primitive.map_primitive:
|
||||
mapped_invars = params['mapped_invars']
|
||||
new_mapped_invars = ((True,) * len(const_tracers) +
|
||||
(False,) * len(env_tracers) +
|
||||
tuple(v for v, t in zip(params['mapped_invars'], tracers) if not t.pval.is_known()))
|
||||
new_params['mapped_invars'] = new_mapped_invars
|
||||
eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers, unknown_tracers_in)),
|
||||
unknown_tracers_out, primitive, new_params,
|
||||
source_info_util.current())
|
||||
for t in unknown_tracers_out:
|
||||
t.recipe = eqn
|
||||
tuple(v for v, t in zip(mapped_invars, tracers)
|
||||
if not t.pval.is_known()))
|
||||
new_params = dict(new_params, mapped_invars=new_mapped_invars)
|
||||
update_params = call_param_updaters.get(primitive)
|
||||
if update_params:
|
||||
new_params = update_params(new_params, [not t.pval.is_known() for t in tracers])
|
||||
|
||||
eqn = new_eqn_recipe(in_tracers, unknown_tracers_out, primitive, new_params,
|
||||
source_info_util.current())
|
||||
for t in unknown_tracers_out: t.recipe = eqn
|
||||
return _zip_knowns(known_tracers_out, unknown_tracers_out, out_unknowns)
|
||||
|
||||
def post_process_call(self, call_primitive, out_tracers, params):
|
||||
process_map = process_call
|
||||
|
||||
# We use post_process_call to handle both call and map primitives.
|
||||
def post_process_call(self, primitive, out_tracers, params):
|
||||
jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
|
||||
out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
|
||||
out = out_pv_consts + consts
|
||||
del consts, out_pv_consts
|
||||
master = self.master
|
||||
|
||||
if primitive.map_primitive:
|
||||
sz = params['axis_size']
|
||||
out_pvs = [None if pv is None else core.unmapped_aval(sz, pv)
|
||||
for pv in out_pvs]
|
||||
|
||||
def todo(x):
|
||||
n = len(jaxpr.outvars)
|
||||
out_pv_consts, consts = x[:n], x[n:]
|
||||
trace = JaxprTrace(master, core.cur_sublevel())
|
||||
const_tracers = map(trace.new_instantiated_const, consts)
|
||||
env_tracers = map(trace.full_raise, env)
|
||||
lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
|
||||
out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
|
||||
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
|
||||
invars = tuple(it.chain(const_tracers, env_tracers))
|
||||
new_params = dict(params, call_jaxpr=lifted_jaxpr)
|
||||
if 'donated_invars' in params:
|
||||
new_params['donated_invars'] = (False,) * len(invars)
|
||||
# The `jaxpr` already contains the env_vars at start of invars
|
||||
eqn = new_eqn_recipe(invars, out_tracers, call_primitive, new_params,
|
||||
in_tracers = (*const_tracers, *map(trace.full_raise, env))
|
||||
|
||||
new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
|
||||
if primitive.map_primitive:
|
||||
new_mapped_invars = (True,) * len(const_tracers) + (False,) * len(env)
|
||||
new_params = dict(new_params, mapped_invars=new_mapped_invars)
|
||||
update_params = call_param_updaters.get(primitive)
|
||||
if update_params:
|
||||
new_params = update_params(new_params, [])
|
||||
|
||||
eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params,
|
||||
source_info_util.current())
|
||||
for t in out_tracers:
|
||||
t.recipe = eqn
|
||||
return out_tracers
|
||||
return out, todo
|
||||
|
||||
process_map = process_call
|
||||
|
||||
def post_process_map(self, map_primitive, out_tracers, params):
|
||||
jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
|
||||
out_pvs_reduced, out_pv_consts = unzip2(t.pval for t in out_tracers)
|
||||
out_pvs = [None if pv is None
|
||||
else core.unmapped_aval(params['axis_size'], pv)
|
||||
for pv in out_pvs_reduced]
|
||||
out = out_pv_consts + consts
|
||||
del consts, out_pv_consts
|
||||
master = self.master
|
||||
def todo(x):
|
||||
n = len(jaxpr.outvars)
|
||||
out_pv_consts, consts = x[:n], x[n:]
|
||||
trace = JaxprTrace(master, core.cur_sublevel())
|
||||
const_tracers = map(trace.new_instantiated_const, consts)
|
||||
# The `jaxpr` already contains the env_vars at start of invars
|
||||
lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
|
||||
out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
|
||||
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
|
||||
new_donated_invars = (False,) * (len(const_tracers) + len(env))
|
||||
new_mapped_invars = (True,) * len(const_tracers) + (False,) * len(env)
|
||||
new_params = dict(params, donated_invars=tuple(new_donated_invars),
|
||||
mapped_invars=tuple(new_mapped_invars),
|
||||
call_jaxpr=lifted_jaxpr)
|
||||
env_tracers = map(trace.full_raise, env)
|
||||
eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers)),
|
||||
out_tracers, map_primitive, new_params,
|
||||
source_info_util.current())
|
||||
for t in out_tracers:
|
||||
t.recipe = eqn
|
||||
return out_tracers
|
||||
return out, todo
|
||||
post_process_map = post_process_call
|
||||
|
||||
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
|
||||
# See comment at top of `JaxprTrace`. This method should be reachable
|
||||
@ -333,23 +309,24 @@ class StagingJaxprTrace(JaxprTrace):
|
||||
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def partial_eval_wrapper(avals: Sequence[Optional[AbstractValue]], *consts):
|
||||
py_args = (map(PartialVal, zip(avals, consts)),)
|
||||
jaxpr, (out_pvals, consts, env) = yield py_args, {}
|
||||
def partial_eval_wrapper(pvs: Sequence[Optional[AbstractValue]], *consts):
|
||||
py_args = map(PartialVal, zip(pvs, consts))
|
||||
jaxpr, (out_pvals, consts, env) = yield (py_args,), {}
|
||||
out_pvs, out_consts = unzip2(out_pvals)
|
||||
out = tuple(out_consts) + tuple(consts) # TODO: can consts be traced?
|
||||
out = tuple(out_consts) + tuple(consts)
|
||||
yield out, (out_pvs, jaxpr, env)
|
||||
|
||||
|
||||
custom_partial_eval_rules: Dict[core.Primitive, Callable] = {}
|
||||
call_partial_eval_rules: Dict[core.Primitive, Callable] = {}
|
||||
staged_out_calls: Set[core.Primitive] = set()
|
||||
call_param_updaters: Dict[core.Primitive, Callable] = {}
|
||||
|
||||
|
||||
def abstract_eval_fun(fun, *avals, **params):
|
||||
pvals_in = [PartialVal.unknown(a) for a in avals]
|
||||
_, pvals_out, _ = trace_to_jaxpr(lu.wrap_init(fun, params), pvals_in,
|
||||
instantiate=True, stage_out=True)
|
||||
instantiate=True, stage_out=True)
|
||||
avals_out, _ = unzip2(pvals_out)
|
||||
for aval_out in avals_out:
|
||||
assert isinstance(aval_out, AbstractValue) # instantiate=True
|
||||
@ -406,15 +383,12 @@ def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
|
||||
-> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
|
||||
"""Traces a function into a Jaxpr, given PartialVals for inputs.
|
||||
|
||||
`trace_type` can be one of `StagingJaxprTrace` or `JaxprTrace` (see
|
||||
comments for that class).
|
||||
|
||||
Returns (`jaxpr`, `out_pvals`, `consts`).
|
||||
The `jaxpr` contains only the computation that depends on unknown inputs.
|
||||
The `out_pvals` are the PartialVal for the outputs. The intermediate
|
||||
values that depend only on known inputs and are needed to compute the output
|
||||
of `jaxpr` are in `consts` and are passed in as the constvars of
|
||||
the `jaxpr`. The handling of the known outputs depends on `instantiate`.
|
||||
Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the
|
||||
computation that depends on unknown inputs. The `out_pvals` are the PartialVal
|
||||
for the outputs. The intermediate values that depend only on known inputs and
|
||||
are needed to compute the output of `jaxpr` are in `consts` and are passed in
|
||||
as the constvars of the `jaxpr`. The handling of the known outputs depends on
|
||||
`instantiate`.
|
||||
|
||||
For example, given `fun` defined as follows::
|
||||
|
||||
@ -425,11 +399,11 @@ def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
|
||||
|
||||
with `ki` the known PartialVal `1.`, and `ui` an unknown PartialVal. The only
|
||||
computation that depends on unknown inputs is `ui + ka` and will be the only
|
||||
computation in the body of the `jaxpr`. This computation depends on the
|
||||
known intermediate value `ka`, which will be computed statically. Currently,
|
||||
such constants are either embedded in the Jaxpr if they are scalars, or
|
||||
passed as a constvar to `jaxpr`, and then the value of the actual constant
|
||||
will be in `consts`:
|
||||
computation in the body of the `jaxpr`. This computation depends on the known
|
||||
intermediate value `ka`, which will be computed statically. Currently, such
|
||||
constants are either embedded in the Jaxpr if they are scalars, or passed as a
|
||||
constvar to `jaxpr`, and then the value of the actual constant will be in
|
||||
`consts`:
|
||||
|
||||
When `instantiate=False` we get::
|
||||
|
||||
@ -437,7 +411,7 @@ def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
|
||||
{ lambda ka ; ki ui.
|
||||
let c = add ui ka
|
||||
in (*, c) } # known outputs are `*`
|
||||
out_pvals = [known(6), unknown(ShapedArray)] # the known outputs are known PartialVal
|
||||
out_pvals = [PartialVal.known(6), PartialVal.unknown(ShapedArray)]
|
||||
consts = [3] # the constant for `ka`
|
||||
|
||||
When `instantiate=True` we get::
|
||||
@ -446,7 +420,7 @@ def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
|
||||
{ lambda ka kb ; ki ui.
|
||||
let c = add ui ka
|
||||
in (kb, c) } # known output are explicit
|
||||
out_pvals = [abstract(ConcreteArray(6)), abstract(ShapedArray)] # all are unknown PartialVal
|
||||
out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)]
|
||||
consts = [3, 6] # values for `ka` and `kb` constvars
|
||||
"""
|
||||
trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace)
|
||||
@ -509,19 +483,20 @@ def new_eqn_recipe(invars: Sequence[JaxprTracer],
|
||||
if primitive.call_primitive or primitive.map_primitive:
|
||||
assert "call_jaxpr" in params
|
||||
if primitive.map_primitive:
|
||||
assert "mapped_invars" in params
|
||||
assert "donated_invars" in params
|
||||
assert ("mapped_invars" in params and
|
||||
len(params["mapped_invars"]) == len(params["call_jaxpr"].invars))
|
||||
assert ("donated_invars" in params and
|
||||
len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
|
||||
return JaxprEqnRecipe(object(), tuple(invars), map(ref, outvars), primitive,
|
||||
params, source_info)
|
||||
|
||||
|
||||
def recipe_to_eqn(unused_var: Callable[[], core.Var],
|
||||
getvar: Callable[[JaxprTracer], core.Atom],
|
||||
def recipe_to_eqn(getvar: Callable[[JaxprTracer], core.Atom],
|
||||
recipe: JaxprEqnRecipe) -> core.JaxprEqn:
|
||||
_, in_tracers, out_tracer_refs, primitive, params, source_info = recipe
|
||||
out_tracers = [t_ref() for t_ref in out_tracer_refs]
|
||||
invars = [getvar(t) for t in in_tracers]
|
||||
outvars = [unused_var() if t is None else cast(core.Var, getvar(t))
|
||||
outvars = [core.dropvar if t is None else cast(core.Var, getvar(t))
|
||||
for t in out_tracers]
|
||||
return new_jaxpr_eqn(invars, outvars, primitive, params, source_info)
|
||||
|
||||
@ -564,7 +539,7 @@ def tracers_to_jaxpr(
|
||||
recipe = t.recipe
|
||||
if isinstance(recipe, JaxprEqnRecipe):
|
||||
if recipe.eqn_id not in processed_eqn_ids:
|
||||
eqns.append(recipe_to_eqn(lambda: core.dropvar, getvar, recipe))
|
||||
eqns.append(recipe_to_eqn(getvar, recipe))
|
||||
processed_eqn_ids.add(recipe.eqn_id)
|
||||
elif isinstance(recipe, LambdaBinding):
|
||||
if not any(t is in_tracer for in_tracer in in_tracers):
|
||||
@ -586,7 +561,7 @@ def tracers_to_jaxpr(
|
||||
env_vars, env_vals = unzip2(env.items())
|
||||
const_vars, const_vals = unzip2(consts.items())
|
||||
# The env_vars are pre-pended to the invars
|
||||
jaxpr = Jaxpr(const_vars, list(it.chain(env_vars, invars)), list(map(getvar, out_tracers)), eqns)
|
||||
jaxpr = Jaxpr(const_vars, [*env_vars, *invars], map(getvar, out_tracers), eqns)
|
||||
core.skip_checks or core.check_jaxpr(jaxpr)
|
||||
return jaxpr, const_vals, env_vals
|
||||
|
||||
@ -682,12 +657,9 @@ def _split_aval(unknown: bool, aval: AbstractValue) -> Tuple[AbstractValue, Abst
|
||||
return (abstract_unit, aval) if unknown else (aval, abstract_unit)
|
||||
|
||||
|
||||
remat_call_p = core.Primitive('remat_call')
|
||||
remat_call_p.call_primitive = True
|
||||
remat_call = partial(core.call_bind, remat_call_p)
|
||||
remat_call_p.def_custom_bind(remat_call)
|
||||
remat_call_p = core.CallPrimitive('remat_call')
|
||||
remat_call = remat_call_p.bind
|
||||
remat_call_p.def_impl(core.call_impl)
|
||||
remat_call_p.multiple_results = True
|
||||
|
||||
# We reuse the _remat_partial_eval function both for remat_call and for
|
||||
# invertible_call, both of which in a sense stage out operations to
|
||||
|
@ -11,24 +11,22 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Implementation of pmap and related functionality.
|
||||
"""Implementation of pmap and related functionality."""
|
||||
|
||||
Note on ShardingSpecs and spec_to_indices():
|
||||
A ShardingSpec describes at a high level how a logical array is sharded across
|
||||
devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also
|
||||
describe how to shard inputs to a parallel computation). spec_to_indices()
|
||||
encodes exactly how a given ShardingSpec is translated to device buffers,
|
||||
i.e. how the sharded array is "laid out" across devices. Given a sequence of
|
||||
devices, we shard the data across the devices in row-major order, with
|
||||
replication treated as an extra inner dimension.
|
||||
|
||||
For example, given the logical data array [1, 2, 3, 4], if we were to partition
|
||||
this array 4 ways with a replication factor of 2, for a total of 8 devices, the
|
||||
data on each device would be: [1, 1], [2, 2], [3, 3], [4, 4].
|
||||
|
||||
This encoding is assumed by various parts of the system, e.g. generating
|
||||
replica groups for collective operations.
|
||||
"""
|
||||
# A ShardingSpec describes at a high level how a logical array is sharded across
|
||||
# devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also
|
||||
# describe how to shard inputs to a parallel computation). spec_to_indices()
|
||||
# encodes exactly how a given ShardingSpec is translated to device buffers, i.e.
|
||||
# how the sharded array is "laid out" across devices. Given a sequence of
|
||||
# devices, we shard the data across the devices in row-major order, with
|
||||
# replication treated as an extra inner dimension.
|
||||
#
|
||||
# For example, given the logical data array [1, 2, 3, 4], if we were to
|
||||
# partition this array 4 ways with a replication factor of 2, for a total of 8
|
||||
# devices, the data on each device would be: [1, 1], [2, 2], [3, 3], [4, 4].
|
||||
#
|
||||
# This encoding is assumed by various parts of the system, e.g. generating
|
||||
# replica groups for collective operations.
|
||||
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
@ -64,7 +62,7 @@ xops = xc.ops
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
_map = safe_map
|
||||
unsafe_map, map = map, safe_map
|
||||
|
||||
Index = Union[int, slice, Tuple[Union[int, slice], ...]]
|
||||
|
||||
@ -120,7 +118,7 @@ class ShardingSpec:
|
||||
|
||||
def __repr__(self):
|
||||
return ("ShardingSpec(shards_per_axis=%s, is_axis_materialized=%s, "
|
||||
"replication_factor=%s)" %
|
||||
"replication_factors=%s)" %
|
||||
(self.shards_per_axis, self.is_axis_materialized,
|
||||
self.replication_factors))
|
||||
|
||||
@ -186,7 +184,7 @@ def spec_to_indices(shape: Tuple[int, ...],
|
||||
|
||||
def _axis_indices(axis_size, num_shards, is_materialized):
|
||||
if not is_materialized:
|
||||
assert axis_size == num_shards
|
||||
assert axis_size == num_shards, f'{axis_size} != {num_shards}'
|
||||
return list(range(axis_size))
|
||||
if num_shards == 1:
|
||||
return [slice(None)]
|
||||
@ -785,7 +783,7 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,
|
||||
tuple_args = len(sharded_avals) > 100 # pass long arg lists as tuple for TPU
|
||||
|
||||
c = xb.make_computation_builder("pmap_{}".format(fun.__name__))
|
||||
xla_consts = _map(partial(xb.constant, c), consts)
|
||||
xla_consts = map(partial(xb.constant, c), consts)
|
||||
replicated = [not m for m in mapped_invars]
|
||||
xla_args = xla._xla_callable_args(c, sharded_avals, tuple_args, replicated,
|
||||
arg_parts)
|
||||
@ -1115,22 +1113,23 @@ def execute_replicated(compiled,
|
||||
return out_handler(out_bufs)
|
||||
|
||||
|
||||
xla_pmap_p = core.Primitive('xla_pmap')
|
||||
xla_pmap_p.map_primitive = True
|
||||
xla_pmap_p.multiple_results = True
|
||||
xla_pmap = partial(core.map_bind, xla_pmap_p)
|
||||
xla_pmap_p.def_custom_bind(xla_pmap)
|
||||
xla_pmap_p = core.MapPrimitive('xla_pmap')
|
||||
xla_pmap = xla_pmap_p.bind
|
||||
xla_pmap_p.def_impl(xla_pmap_impl)
|
||||
pe.staged_out_calls.add(xla_pmap_p)
|
||||
|
||||
# Set param update handlers to update `donated_invars` just like xla_call_p
|
||||
pe.call_param_updaters[xla_pmap_p] = pe.call_param_updaters[xla.xla_call_p]
|
||||
ad.call_param_updaters[xla_pmap_p] = ad.call_param_updaters[xla.xla_call_p]
|
||||
ad.call_transpose_param_updaters[xla_pmap_p] = \
|
||||
ad.call_transpose_param_updaters[xla.xla_call_p]
|
||||
|
||||
def _pmap_translation_rule(c, axis_env,
|
||||
in_nodes, name_stack, axis_name, axis_size,
|
||||
global_axis_size, devices, name,
|
||||
call_jaxpr, *, backend=None, mapped_invars,
|
||||
donated_invars):
|
||||
if any(donated_invars):
|
||||
raise ValueError("Donating buffers passed to a a pmap nested inside a jit "
|
||||
"or another pmap is not supported.")
|
||||
del donated_invars # Unused.
|
||||
# We in-line here rather than generating a Call HLO as in the xla_call
|
||||
# translation rule just because the extra tuple stuff is a pain.
|
||||
if axis_env.names and devices is not None:
|
||||
|
@ -187,11 +187,8 @@ def _sharded_call_impl(fun, *args, num_partitions, in_parts, out_parts_thunk,
|
||||
return compiled_fun(*args)
|
||||
|
||||
|
||||
sharded_call_p = core.Primitive("sharded_call")
|
||||
sharded_call_p.call_primitive = True
|
||||
sharded_call_p.multiple_results = True
|
||||
sharded_call = partial(core.call_bind, sharded_call_p)
|
||||
sharded_call_p.def_custom_bind(sharded_call)
|
||||
sharded_call_p = core.CallPrimitive("sharded_call")
|
||||
sharded_call = sharded_call_p.bind
|
||||
sharded_call_p.def_impl(_sharded_call_impl)
|
||||
xla.call_translations[sharded_call_p] = _sharded_jit_translation_rule
|
||||
|
||||
|
@ -35,13 +35,16 @@ from ..abstract_arrays import (ConcreteArray, ShapedArray, AbstractToken,
|
||||
from ..core import Literal, pp_eqn_compact
|
||||
from ..pprint_util import pp
|
||||
from ..util import (partial, partialmethod, cache, prod, unzip2, memoize,
|
||||
extend_name_stack, wrap_name, safe_zip)
|
||||
extend_name_stack, wrap_name, safe_zip, safe_map)
|
||||
from ..lib import xla_bridge as xb
|
||||
from ..lib import xla_client as xc
|
||||
from . import partial_eval as pe
|
||||
from . import ad
|
||||
from . import masking
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
xe = xc._xla
|
||||
xops = xc._xla.ops
|
||||
|
||||
@ -63,7 +66,6 @@ flags.DEFINE_bool('jax_log_compiles',
|
||||
bool_env('JAX_LOG_COMPILES', False),
|
||||
'Print a message each time a `jit` computation is compiled.')
|
||||
|
||||
def _map(f, *xs): return tuple(map(f, *xs))
|
||||
def identity(x): return x
|
||||
|
||||
_scalar_types = dtypes.python_scalar_dtypes.keys()
|
||||
@ -199,7 +201,7 @@ def primitive_uses_outfeed(prim: core.Primitive, params: Dict) -> bool:
|
||||
return True
|
||||
for param in params.values():
|
||||
if isinstance(param, tuple):
|
||||
if any(_map(_param_uses_outfeed, param)):
|
||||
if any(unsafe_map(_param_uses_outfeed, param)):
|
||||
return True
|
||||
elif _param_uses_outfeed(param):
|
||||
return True
|
||||
@ -223,7 +225,7 @@ def arg_spec(x):
|
||||
|
||||
def apply_primitive(prim, *args, **params):
|
||||
"""Impl rule that compiles and runs a single primitive 'prim' using XLA."""
|
||||
compiled_fun = xla_primitive_callable(prim, *map(arg_spec, args), **params)
|
||||
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
|
||||
return compiled_fun(*args)
|
||||
|
||||
@cache()
|
||||
@ -243,8 +245,8 @@ def xla_primitive_callable(prim, *arg_specs: Tuple[core.AbstractValue,
|
||||
if not prim.multiple_results:
|
||||
handle_result = aval_to_result_handler(device, aval_out)
|
||||
else:
|
||||
handlers = tuple(map(partial(aval_to_result_handler, device), aval_out))
|
||||
handle_result = lambda xs: tuple(h(x) for h, x in zip(handlers, xs))
|
||||
handlers = map(partial(aval_to_result_handler, device), aval_out)
|
||||
handle_result = lambda xs: tuple(h(x) for h, x in unsafe_zip(handlers, xs))
|
||||
tuple_args = len(avals) > 100
|
||||
if prim in initial_style_translations:
|
||||
nreps = initial_style_primitive_replicas(params)
|
||||
@ -386,8 +388,8 @@ def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
|
||||
|
||||
env = {}
|
||||
write(core.unitvar, _make_unit(c))
|
||||
_map(write, jaxpr.constvars, consts)
|
||||
_map(write, jaxpr.invars, args)
|
||||
map(write, jaxpr.constvars, consts)
|
||||
map(write, jaxpr.invars, args)
|
||||
for eqn in jaxpr.eqns:
|
||||
frame = source_info_util.user_frame(eqn.source_info)
|
||||
c.set_op_metadata(xc.OpMetadata(
|
||||
@ -396,7 +398,7 @@ def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
|
||||
eqn.primitive.name, eqn.params)),
|
||||
source_file=frame.file_name if frame else None,
|
||||
source_line=frame.line_num if frame else None))
|
||||
in_nodes = list(map(read, eqn.invars))
|
||||
in_nodes = map(read, eqn.invars)
|
||||
if eqn.primitive in backend_specific_translations[platform]:
|
||||
rule = backend_specific_translations[platform][eqn.primitive]
|
||||
ans = rule(c, *in_nodes, **eqn.params)
|
||||
@ -432,8 +434,8 @@ def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
|
||||
c.get_shape(ans) # force xla to do shape error checking
|
||||
out_nodes = xla_destructure(c, ans) if eqn.primitive.multiple_results else [ans]
|
||||
c.clear_op_metadata()
|
||||
_map(write, eqn.outvars, out_nodes)
|
||||
return _map(read, jaxpr.outvars)
|
||||
map(write, eqn.outvars, out_nodes)
|
||||
return map(read, jaxpr.outvars)
|
||||
|
||||
def xla_destructure(c, ans):
|
||||
num_elements = len(c.get_shape(ans).tuple_shapes())
|
||||
@ -470,7 +472,7 @@ def axis_read(axis_env, axis_name):
|
||||
|
||||
def axis_groups(axis_env, name):
|
||||
if isinstance(name, (list, tuple)):
|
||||
mesh_axes = tuple(map(partial(axis_read, axis_env), name))
|
||||
mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name))
|
||||
else:
|
||||
mesh_axes = (axis_read(axis_env, name),)
|
||||
return _axis_groups(axis_env.nreps, axis_env.sizes, mesh_axes)
|
||||
@ -483,7 +485,7 @@ def _axis_groups(nrep, mesh_spec, mesh_axes):
|
||||
groups = onp.reshape(
|
||||
onp.moveaxis(iota, mesh_axes, onp.arange(len(mesh_axes))),
|
||||
(prod(onp.take(full_spec, mesh_axes)), -1))
|
||||
return tuple(map(tuple, groups.T))
|
||||
return tuple(unsafe_map(tuple, groups.T))
|
||||
|
||||
def jaxpr_replicas(jaxpr):
|
||||
"""The number of replicas needed for a jaxpr.
|
||||
@ -536,7 +538,8 @@ def jaxpr_collectives(jaxpr):
|
||||
### xla_call underlying jit
|
||||
|
||||
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
|
||||
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars, *map(arg_spec, args))
|
||||
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
|
||||
*unsafe_map(arg_spec, args))
|
||||
try:
|
||||
return compiled_fun(*args)
|
||||
except FloatingPointError:
|
||||
@ -597,8 +600,8 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
|
||||
pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
|
||||
jaxpr, pvals, consts = pe.trace_to_jaxpr(
|
||||
fun, pvals, instantiate=False, stage_out=True, bottom=True)
|
||||
map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
|
||||
jaxpr, uses_outfeed = apply_outfeed_rewriter(jaxpr)
|
||||
_map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
|
||||
|
||||
nreps = jaxpr_replicas(jaxpr)
|
||||
device = _xla_callable_device(nreps, backend, device, arg_devices)
|
||||
@ -635,7 +638,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
|
||||
tuple_args = len(abstract_args) > 100 # pass long arg lists as tuple for TPU
|
||||
|
||||
c = xb.make_computation_builder("jit_{}".format(fun.__name__))
|
||||
xla_consts = _map(partial(xb.constant, c), consts)
|
||||
xla_consts = map(partial(xb.constant, c), consts)
|
||||
xla_args = _xla_callable_args(c, abstract_args, tuple_args)
|
||||
out_nodes = jaxpr_subcomp(
|
||||
c, jaxpr, backend, AxisEnv(nreps, (), ()), xla_consts,
|
||||
@ -779,8 +782,8 @@ def _execute_replicated(compiled: XlaExecutable, uses_outfeed: bool,
|
||||
|
||||
def _execute_trivial(jaxpr, device: Optional[Device], consts, handlers, *args):
|
||||
env = {core.unitvar: core.unit}
|
||||
_map(env.setdefault, jaxpr.invars, args)
|
||||
_map(env.setdefault, jaxpr.constvars, consts)
|
||||
map(env.setdefault, jaxpr.invars, args)
|
||||
map(env.setdefault, jaxpr.constvars, consts)
|
||||
outs = [canonicalize_dtype(v.val) if type(v) is Literal else env[v]
|
||||
for v in jaxpr.outvars]
|
||||
return [_copy_device_array_to_device(x, device) if type(x) is DeviceArray
|
||||
@ -800,24 +803,43 @@ def _get_device(device, backend):
|
||||
out, = compiled.local_devices()
|
||||
return out
|
||||
|
||||
xla_call_p = core.Primitive('xla_call')
|
||||
xla_call_p.call_primitive = True
|
||||
xla_call_p.multiple_results = True
|
||||
xla_call = partial(core.call_bind, xla_call_p)
|
||||
xla_call_p.def_custom_bind(xla_call)
|
||||
xla_call_p = core.CallPrimitive('xla_call')
|
||||
xla_call = xla_call_p.bind
|
||||
xla_call_p.def_impl(_xla_call_impl)
|
||||
pe.staged_out_calls.add(xla_call_p)
|
||||
|
||||
def _xla_call_partial_eval_update_params(params, in_unknowns):
|
||||
call_jaxpr = params['call_jaxpr']
|
||||
donated_invars = params['donated_invars']
|
||||
if not in_unknowns and donated_invars:
|
||||
# JaxprTrace.post_process_call creates a call with no input tracers
|
||||
new_donated_invars = (False,) * len(call_jaxpr.invars)
|
||||
else:
|
||||
# JaxprTrace.process_call drops known input tracers
|
||||
donated_invars = [d for d, uk in zip(donated_invars, in_unknowns) if uk]
|
||||
new_donated_invars = ((False,) * (len(call_jaxpr.invars) - len(donated_invars))
|
||||
+ tuple(donated_invars))
|
||||
return dict(params, donated_invars=new_donated_invars)
|
||||
pe.call_param_updaters[xla_call_p] = _xla_call_partial_eval_update_params
|
||||
|
||||
def _xla_call_jvp_update_params(params, nz_tangents):
|
||||
donated_invars = params['donated_invars']
|
||||
donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
|
||||
new_donated_invars = (*donated_invars, *donated_tangents)
|
||||
return dict(params, donated_invars=new_donated_invars)
|
||||
ad.call_param_updaters[xla_call_p] = _xla_call_jvp_update_params
|
||||
|
||||
def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
|
||||
donated_invars = params['donated_invars']
|
||||
donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
|
||||
donated_cotangents = [False for nz in nonzero_cts if nz]
|
||||
return dict(params, donated_invars=(*donated_primals, *donated_cotangents))
|
||||
ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params
|
||||
|
||||
def _xla_call_translation_rule(c, axis_env,
|
||||
in_nodes, name_stack, backend, name,
|
||||
call_jaxpr, device=None, donated_invars=None):
|
||||
del device # Ignored.
|
||||
if donated_invars is None:
|
||||
donated_invars = (False,) * len(in_nodes)
|
||||
elif any(donated_invars):
|
||||
raise ValueError("Donating buffers passed to a jit nested inside a jit or "
|
||||
"pmap is not supported.")
|
||||
|
||||
call_jaxpr, donated_invars, device=None):
|
||||
del device, donated_invars # Ignored.
|
||||
subc = xb.make_computation_builder(f"jit_{name}")
|
||||
args = [xb.parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
|
||||
out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
|
||||
@ -874,7 +896,7 @@ def lower_fun(fun, multiple_results):
|
||||
wrapped_fun = _tuple_output(wrapped_fun)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr(wrapped_fun, pvals, instantiate=True,
|
||||
stage_out=True)
|
||||
consts = _map(partial(xb.constant, c), consts)
|
||||
consts = map(partial(xb.constant, c), consts)
|
||||
outs = jaxpr_subcomp(c, jaxpr, None, AxisEnv(1), consts, '', *xla_args)
|
||||
if multiple_results:
|
||||
return xops.Tuple(c, outs)
|
||||
@ -895,7 +917,7 @@ def lower_fun_initial_style(fun):
|
||||
pvals = [pe.PartialVal.unknown(a) for a in avals]
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr(
|
||||
lu.wrap_init(fun, params), pvals, instantiate=True, stage_out=True)
|
||||
consts = _map(partial(xb.constant, c), consts)
|
||||
consts = map(partial(xb.constant, c), consts)
|
||||
outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack,
|
||||
*xla_args)
|
||||
return xops.Tuple(c, outs)
|
||||
|
@ -259,7 +259,7 @@ class APITest(jtu.JaxTestCase):
|
||||
"Abstract evaluation for 'foo' not implemented")
|
||||
|
||||
jtu.check_raises(lambda: grad(foo)(1.0), NotImplementedError,
|
||||
"Forward-mode differentiation rule for 'foo' not implemented")
|
||||
"Differentiation rule for 'foo' not implemented")
|
||||
|
||||
foo_p.def_abstract_eval(lambda x: x)
|
||||
|
||||
@ -1759,7 +1759,9 @@ class RematTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.util.curry
|
||||
def call(f, *args):
|
||||
return jax.core.call(jax.linear_util.wrap_init(lambda *args: [f(*args)]), *args)[0]
|
||||
return jax.core.call(
|
||||
jax.linear_util.wrap_init(lambda *args: [f(*args)]),
|
||||
*args, name='foo')[0]
|
||||
|
||||
f = call(add_one)
|
||||
g = jax.remat(lambda x: add_one(f(x)))
|
||||
@ -3198,8 +3200,12 @@ class BufferDonationTest(jtu.JaxTestCase):
|
||||
def test_jit_nested_donate_ignored(self):
|
||||
jit_fun = jit(lambda x: jit(lambda y: y ** 2, donate_argnums=0)(x))
|
||||
a = jax.device_put(jnp.array(1))
|
||||
with self.assertRaisesRegex(ValueError, "nested.*not supported"):
|
||||
jit_fun(a)
|
||||
|
||||
# NOTE(mattjj): stopped raising error here and instead just ignored
|
||||
# with self.assertRaisesRegex(ValueError, "nested.*not supported"):
|
||||
# jit_fun(a)
|
||||
|
||||
jit_fun(a) # doesn't crash
|
||||
|
||||
def test_jnp_array_copy(self):
|
||||
# https://github.com/google/jax/issues/3412
|
||||
@ -3232,8 +3238,12 @@ class BufferDonationTest(jtu.JaxTestCase):
|
||||
def test_pmap_nested_donate_raises(self):
|
||||
pmap_fun = jit(lambda x: api.pmap(lambda y: y ** 2, donate_argnums=0)(x))
|
||||
a = api.pmap(lambda x: x)(jnp.array([1]))
|
||||
with self.assertRaisesRegex(ValueError, "nested.*not supported"):
|
||||
pmap_fun(a)
|
||||
|
||||
# NOTE(mattjj): stopped raising error here and instead just ignored
|
||||
# with self.assertRaisesRegex(ValueError, "nested.*not supported"):
|
||||
# pmap_fun(a)
|
||||
|
||||
pmap_fun(a) # doesn't crash
|
||||
|
||||
assertDeleted = lambda self, x: self._assertDeleted(x, True)
|
||||
assertNotDeleted = lambda self, x: self._assertDeleted(x, False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user