Reimplemented the passing of tokens with a Jaxpr transform

George Necula 2020-05-06 11:02:10 +03:00
parent 8fc96910c2
commit d8b75e1913
6 changed files with 427 additions and 386 deletions
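At a high level, the commit replaces the global token state that used to be carried while building HLO with a Jaxpr-to-Jaxpr rewrite: each tap equation takes the current token as an extra input and returns a fresh token as an extra output, so ordering becomes an ordinary data dependency. A rough, self-contained sketch of that idea (toy data structures only, not the actual JAX code):

# Toy model: a "program" is a list of (op, inputs, outputs) tuples.
def thread_token(eqns, token_in="t0"):
    """Rewrite eqns so that every 'tap' op consumes and produces a token."""
    new_eqns, token = [], token_in
    for k, (op, ins, outs) in enumerate(eqns):
        if op == "tap":
            new_token = f"t{k + 1}"  # any fresh name will do
            new_eqns.append((op, ins + [token], outs + [new_token]))
            token = new_token
        else:
            new_eqns.append((op, ins, outs))
    return new_eqns, token  # rewritten equations and the final token

eqns = [("mul", ["a", "a"], ["b"]),
        ("tap", ["b"], ["c"]),
        ("add", ["c", "a"], ["d"]),
        ("tap", ["d"], ["e"])]
print(thread_token(eqns))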

View File

@ -325,16 +325,14 @@ def xla_computation(fun: Callable,
jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals,
instantiate=instantiate_const_outputs,
stage_out=True)
jaxpr, _ = xla.apply_outfeed_rewriter(jaxpr)
axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
c = xb.make_computation_builder('xla_computation_{}'.format(fun_name))
uses_outfeed = xla.jaxpr_uses_outfeed(jaxpr)
xla.state_carry.start_nested_comp_without_input(c, uses_outfeed)
xla_consts = map(partial(xb.constant, c), consts)
xla_args = xla._xla_callable_args(c, avals, tuple_args)
outs = xla.jaxpr_subcomp(
c, jaxpr, backend, axis_env_, xla_consts,
extend_name_stack(wrap_name(fun_name, 'xla_computation')), *xla_args)
xla.state_carry.end_nested_comp_without_output(c)
return c.Build(xc.ops.Tuple(c, outs))
return computation_maker

View File

@ -114,7 +114,7 @@ from jax import core
from jax import dtypes
from jax import lax
from jax.lib import pytree, xla_bridge
from jax.interpreters import ad, xla, batching, masking
from jax.interpreters import ad, xla, batching, masking, pxla
from jax.interpreters import partial_eval as pe
from jax import pprint_util as ppu
from jax import util
@ -126,7 +126,7 @@ import msgpack # type: ignore
import numpy as onp
import sys
import traceback
from typing import Any, Callable, Dict, List, Optional, NamedTuple, Sequence, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, NamedTuple, Sequence, Tuple
xops = xla_client._xla.ops
@ -281,7 +281,6 @@ positional arguments and parameters:
# TODO: handle multiple vmap and mask
id_tap_p = core.Primitive("id_tap")
id_tap_p.multiple_results = True
xla.outfeed_primitives.add(id_tap_p)
def _add_transform_name(params: Dict, transform: str) -> Dict:
@ -313,25 +312,27 @@ def _id_tap_abstract_eval(*args_a: pe.AbstractValue, **params) \
id_tap_p.def_abstract_eval(_id_tap_abstract_eval)
def _instantiate_zeros(tan, arg):
# TODO: there must be a better way to do this.
# The AttributeError is for regular values, the KeyError is for ConcreteArray
def _instantiate_zeros(arg, tan):
"""Turn special ad.zero tangents into arrays of 0s."""
if tan is not ad.zero:
return tan
elif isinstance(arg, core.Tracer):
# TODO: why do I have to do this to get a zero?
else:
try:
aval = arg.aval
return ad.instantiate_zeros_aval(aval, tan)
except:
# It seems that we get here for ConcreteArray
return ad.instantiate_zeros(arg, tan)
except (AttributeError, KeyError):
# We get here for regular Python values
return ad.zeros_like_jaxval(arg)
def _id_tap_jvp_rule(primals, tangents, *, func, nr_untapped=0, **params):
# Put primals through id_tap separately, so that partial evaluation
# can do its job for grad
out_primals = id_tap_p.bind(*primals, func=func, nr_untapped=nr_untapped, **params)
# Add one primal output as untapped, to create dependency.
tangent_zeros = tuple(map(_instantiate_zeros, tangents, primals))
tangent_zeros = tuple(map(_instantiate_zeros, primals, tangents))
out_tangents_extra = id_tap_p.bind(*tangent_zeros, out_primals[0],
func=func, nr_untapped=nr_untapped + 1,
**_add_transform_name(params, "jvp"))
@ -343,7 +344,7 @@ ad.primitive_jvps[id_tap_p] = _id_tap_jvp_rule
def _id_tap_transpose_rule(cts, *args, func=None, nr_untapped=0, **params):
assert len(cts) == len(args)
cts_zeros = tuple(map(_instantiate_zeros, cts, args))
cts_zeros = tuple(map(_instantiate_zeros, args, cts))
ct_args = id_tap_p.bind(*cts_zeros, func=func, nr_untapped=nr_untapped,
**_add_transform_name(params, "transpose"))
return ct_args
@ -376,6 +377,7 @@ def _id_tap_masking_rule(operands, operands_logical_shapes, **params):
masking.masking_rules[id_tap_p] = _id_tap_masking_rule
#### XLA compilation ####
####
# Special consumer to mark the end of the outfeed stream for a device
_end_consumer = 0
_unknown_consumer = 1 # for testing error cases
@ -394,12 +396,24 @@ def _id_print_translation_rule_outfeed(comp: XlaComputationBuilder,
params["consumer_id"] = _register_consumer(
_ConsumerCallable(func, tuple(params.items()), arg_treedef))
prev_token = xla.state_carry.current_token(comp)
nr_args_to_emit = len(args_op) - nr_untapped
next_token = _emit_outfeed(comp, prev_token,
# We expect the current token as the last argument
current_token = args_op[-1]
current_token_shape = comp.GetShape(current_token)
if current_token_shape.is_array():
# TODO: we get here because when we partially evaluate some primitives
# we implement them with JIT, but we did not rewrite them
has_token = False
current_token = xops.CreateToken(comp)
else:
has_token = True
nr_args_to_emit = len(args_op) - nr_untapped - (1 if has_token else 0)
next_token = _emit_outfeed(comp, current_token,
args_op[0:nr_args_to_emit], params["consumer_id"])
xla.state_carry.set_current_token(comp, next_token)
return xops.Tuple(comp, args_op)
results = (args_op[:-1] + (next_token,)) if has_token else args_op
return xops.Tuple(comp, results)
xla.translations[id_tap_p] = _id_print_translation_rule_outfeed
@ -517,6 +531,7 @@ class TapFunctionException(Exception):
"""
pass
_outfeed_receiver_started = False
@contextmanager
def outfeed_receiver(*,
timeout_sec=10,
@ -550,6 +565,10 @@ def outfeed_receiver(*,
The ``outfeed_receiver`` must be started outside any jitted computation.
"""
global _outfeed_receiver_started
if _outfeed_receiver_started:
raise ValueError("At most one outfeed_receiver can be running at once.")
if not devices:
backends = backends or xla_client._get_local_backends().keys()
devices = tuple(itertools.chain(*[api.devices(backend)
@ -593,13 +612,13 @@ def outfeed_receiver(*,
# bugs in our code, not from the tap functions.
for rf in receiver_futures:
rf.add_done_callback(lambda rf: rf.result())
xla.set_outfeed_allowed(True)
xla.can_execute_outfeed_computations = True
try:
yield
finally:
for d in devices: # Signal the end of printing
api.jit(lambda x: id_tap(_end_consumer, None, result=x), device=d)(0) # type: ignore[arg-type]
xla.set_outfeed_allowed(False)
xla.can_execute_outfeed_computations = False
for f in futures.as_completed(receiver_futures, timeout=timeout_sec):
finished_device = f.result() # Throw exceptions here
if _LOGGING:
@ -607,3 +626,172 @@ def outfeed_receiver(*,
if count_tap_exceptions > 0:
raise TapFunctionException
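For context, this is the user-facing shape of the feature that the tests later in this diff exercise: id_print taps values out of a jitted computation, and the taps are consumed only while an outfeed_receiver is active. A minimal usage sketch, assuming the API as of this commit (the function body is made up):

import jax
from jax.experimental import host_callback as hcb

def func(x):
    # The tapped value is printed on the host by the receiver.
    return hcb.id_print(x + 1, where="example") * 2

with hcb.outfeed_receiver(receiver_name="example"):
    print(jax.jit(func)(1))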
#### Jaxpr rewriting logic
####
def _jaxpr_var_defs(jaxpr: core.Jaxpr) -> Iterable[int]:
"""Iterates over all the unique vars the top-level of a Jaxpr"""
for iv in jaxpr.invars:
yield iv.count
for cv in jaxpr.constvars:
yield cv.count
for eqn in jaxpr.eqns:
for ov in eqn.outvars:
yield ov.count
def _jaxpr_uses_outfeed(jaxpr: core.Jaxpr) -> bool:
"""Finds if there are outfeed primitives anywhere inside a Jaxpr."""
for eqn in jaxpr.eqns:
if eqn.primitive is id_tap_p:
return True
for subjaxpr in core.subjaxprs(jaxpr):
if _jaxpr_uses_outfeed(subjaxpr):
return True
return False
def _rewrite_typed_jaxpr(tjaxpr: core.TypedJaxpr,
has_input_token: bool,
has_output_token: bool) -> Tuple[core.TypedJaxpr, bool]:
"""Rewrites a TypedJaxpr to thread the token, if needed.
Returns the rewritten TypedJaxpr and whether it uses outfeed."""
new_jaxpr, uses_outfeed = _rewrite_jaxpr(tjaxpr.jaxpr, has_input_token, has_output_token)
return (core.TypedJaxpr(new_jaxpr, tjaxpr.literals,
tuple(map(lambda v: v.aval, new_jaxpr.invars)),
tuple(map(lambda v: v.aval, new_jaxpr.outvars))),
uses_outfeed)
def _rewrite_jaxpr(jaxpr: core.Jaxpr,
has_input_token: bool,
has_output_token: bool) -> Tuple[core.Jaxpr, bool]:
"""Rewrite a Jaxpr to thread the token, if needed."""
assert has_input_token or not has_output_token
uses_outfeed = _jaxpr_uses_outfeed(jaxpr)
if not has_input_token and not uses_outfeed:
return (jaxpr, False)
max_var_count = max(_jaxpr_var_defs(jaxpr))
mk_new_id = itertools.count(start=max_var_count + 1)
def mk_new_var(aval: core.AbstractValue) -> core.Var:
return core.Var(next(mk_new_id), '', aval)
eqns: List[core.JaxprEqn] = []
last_token_var = mk_new_var(core.abstract_token)
if has_input_token:
invars = jaxpr.invars + [last_token_var]
else:
invars = jaxpr.invars
eqns.append(core.new_jaxpr_eqn([jaxpr.invars[0]], [last_token_var],
lax.create_token_p, {}))
for eqn in jaxpr.eqns:
if eqn.primitive is id_tap_p:
new_token_var = mk_new_var(core.abstract_token)
eqns.append(core.new_jaxpr_eqn(eqn.invars + [last_token_var],
eqn.outvars + [new_token_var],
eqn.primitive, eqn.params))
last_token_var = new_token_var
elif eqn.primitive is lax.while_p:
cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict(
eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
if _jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
raise NotImplementedError("outfeed not supported in the conditional of a while")
uses_outfeed = _jaxpr_uses_outfeed(body_jaxpr.jaxpr)
if not uses_outfeed:
eqns.append(eqn)
continue
new_token_var = mk_new_var(core.abstract_token)
eqns.append(core.new_jaxpr_eqn(
eqn.invars + [last_token_var],
eqn.outvars + [new_token_var],
eqn.primitive,
dict(eqn.params,
body_jaxpr=_rewrite_typed_jaxpr(body_jaxpr, True, True)[0],
cond_jaxpr=_rewrite_typed_jaxpr(cond_jaxpr, True, False)[0])))
last_token_var = new_token_var
elif eqn.primitive is lax.cond_p:
true_jaxpr, false_jaxpr, linear = util.split_dict(
eqn.params, ["true_jaxpr", "false_jaxpr", "linear"])
uses_outfeed = _jaxpr_uses_outfeed(true_jaxpr.jaxpr) or _jaxpr_uses_outfeed(false_jaxpr.jaxpr)
if not uses_outfeed:
eqns.append(eqn)
continue
nr_true_invars = len(true_jaxpr.jaxpr.invars)
pred, true_invars, false_invars = util.split_list(eqn.invars,
[1, nr_true_invars])
new_token_var = mk_new_var(core.abstract_token)
new_invars = pred + true_invars + [last_token_var] + false_invars + [last_token_var]
eqns.append(core.new_jaxpr_eqn(
new_invars, eqn.outvars + [new_token_var],
eqn.primitive,
dict(eqn.params,
true_jaxpr=_rewrite_typed_jaxpr(true_jaxpr, True, True)[0],
false_jaxpr=_rewrite_typed_jaxpr(false_jaxpr, True, True)[0],
linear=linear + (False, False))))
last_token_var = new_token_var
elif eqn.primitive is lax.scan_p:
num_consts, num_carry, carry_jaxpr, linear, _, _ = util.split_dict(
eqn.params, ["num_consts", "num_carry", "jaxpr", "linear",
"reverse", "length"])
uses_outfeed = _jaxpr_uses_outfeed(carry_jaxpr.jaxpr)
if not uses_outfeed:
eqns.append(eqn)
continue
nr_const_and_carry = num_consts + num_carry
new_invars = eqn.invars[0:nr_const_and_carry] + [last_token_var] + eqn.invars[nr_const_and_carry:]
new_token_var = mk_new_var(core.abstract_token)
new_jaxpr = _rewrite_typed_jaxpr(carry_jaxpr, True, True)[0]
# The rewrite put the new token at the end of the invars; it has to go at the end of the carry
new_jaxpr_invars = new_jaxpr.jaxpr.invars
new_jaxpr_invars = (new_jaxpr_invars[0:nr_const_and_carry] +
[new_jaxpr_invars[-1]] +
new_jaxpr_invars[nr_const_and_carry:-1])
new_jaxpr.jaxpr.invars = new_jaxpr_invars
new_jaxpr.in_avals = [v.aval for v in new_jaxpr_invars]
new_jaxpr_outvars = new_jaxpr.jaxpr.outvars
new_jaxpr_outvars = (new_jaxpr_outvars[0:num_carry] +
[new_jaxpr_outvars[-1]] +
new_jaxpr_outvars[num_carry:-1])
new_jaxpr.jaxpr.outvars = new_jaxpr_outvars
new_jaxpr.out_avals = [v.aval for v in new_jaxpr_outvars]
eqns.append(core.new_jaxpr_eqn(
new_invars,
# Output token is at the end of carry result
eqn.outvars[0:num_carry] + [new_token_var] + eqn.outvars[num_carry:],
eqn.primitive,
dict(eqn.params,
jaxpr=new_jaxpr,
num_carry=num_carry + 1,
linear=linear + (False,))))
last_token_var = new_token_var
elif eqn.primitive is xla.xla_call_p:
call_jaxpr = eqn.params["call_jaxpr"]
uses_outfeed = _jaxpr_uses_outfeed(call_jaxpr)
if not uses_outfeed:
eqns.append(eqn)
continue
new_token_var = mk_new_var(core.abstract_token)
eqns.append(core.new_jaxpr_eqn(
eqn.invars + [last_token_var],
eqn.outvars + [new_token_var],
eqn.primitive,
dict(eqn.params, call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True)[0])))
last_token_var = new_token_var
elif eqn.primitive is pxla.xla_pmap_p:
raise NotImplementedError("rewrite of pmap")
else:
# Check no more subjaxprs
for param in eqn.params.values():
if type(param) is core.Jaxpr or type(param) is core.TypedJaxpr:
assert False
eqns.append(eqn)
outvars = jaxpr.outvars + ([last_token_var] if has_output_token else [])
return (core.Jaxpr(jaxpr.constvars, invars, outvars, eqns), True)
xla.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False)
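The trickiest part of the rewrite above is the scan case: the extra token has to sit at the end of the carry segment rather than at the end of the full argument list, so the rewritten body's invars and outvars are permuted. An isolated sketch of that index shuffle on plain lists (variable names are illustrative):

def move_last_after(prefix_len, xs):
    """Move the last element of xs so it follows the first prefix_len elements."""
    return xs[:prefix_len] + [xs[-1]] + xs[prefix_len:-1]

# invars: [consts..., carry..., xs..., token] -> [consts..., carry..., token, xs...]
print(move_last_after(3, ["c0", "carry0", "carry1", "xs0", "token"]))
# outvars: [carry..., ys..., token] -> [carry..., token, ys...]
print(move_last_after(2, ["carry0", "carry1", "ys0", "token"]))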

View File

@ -659,6 +659,7 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
dynamic_fun, [pval] + pvals, instantiate=False, stage_out=True, bottom=True)
jaxpr.invars = jaxpr.invars[1:] # ignore dummy
jaxpr, uses_outfeed = xla.apply_outfeed_rewriter(jaxpr)
out_pvs, out_consts = unzip2(out_pvals)
@ -693,13 +694,10 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,
tuple_args = len(sharded_avals) > 100 # pass long arg lists as tuple for TPU
c = xb.make_computation_builder("pmap_{}".format(fun.__name__))
uses_outfeed = xla.jaxpr_uses_outfeed(jaxpr)
xla.state_carry.start_nested_comp_without_input(c, uses_outfeed)
xla_consts = _map(partial(xb.constant, c), consts)
xla_args = xla._xla_callable_args(c, sharded_avals, tuple_args)
out_nodes = xla.jaxpr_subcomp(c, jaxpr, backend, axis_env, xla_consts,
extend_name_stack(wrap_name(name, 'pmap')), *xla_args)
xla.state_carry.end_nested_comp_without_output(c)
built = c.Build(xops.Tuple(c, out_nodes))
if devices is None:
@ -861,8 +859,8 @@ def _pmap_sharding_spec(nrep, axis_size, sharded_aval, mapped):
replication_factor=replication_factor * axis_size)
def execute_replicated(compiled, uses_outfeed: bool, backend, in_handler, out_handler, *args):
xla.check_outfeed_allowed(uses_outfeed)
def execute_replicated(compiled, uses_outfeed, backend, in_handler, out_handler, *args):
xla.check_before_outfeed_execution(uses_outfeed)
input_bufs = in_handler(args)
out_bufs = compiled.ExecuteOnLocalDevices(list(input_bufs))
return out_handler(out_bufs)

View File

@ -14,9 +14,9 @@
from collections import defaultdict
from functools import reduce
import itertools as it
import operator as op
import threading
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Type, Tuple
from absl import logging
@ -168,167 +168,22 @@ pytype_aval_mappings.update((t, make_shaped_array) for t in array_types)
pytype_aval_mappings.update(
(t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types)
# Keep track of which primitives are stateful (they read or read/write state).
# This helps avoid threading the state through control-flow primitives that
# do not need it. This is a worthwhile optimization because it seems that XLA
# may not be good at dealing with tokens (b/154992062).
outfeed_primitives: Set[core.Primitive] = set()
def jaxpr_uses_outfeed(jaxpr):
"""Finds if there are outfeed primitives anywhere inside a Jaxpr."""
if type(jaxpr) is core.TypedJaxpr:
jaxpr = jaxpr.jaxpr
for eqn in jaxpr.eqns:
if eqn.primitive in outfeed_primitives:
return True
for subjaxpr in core.subjaxprs(jaxpr):
if jaxpr_uses_outfeed(subjaxpr):
return True
return False
# We can optionally set a Jaxpr rewriter that can be applied just before
# compilation. This mechanism is used for compiling id_tap; we can
# remove it once we bring the id_tap implementation into the core.
outfeed_rewriter: Optional[Callable[[core.Jaxpr], Tuple[core.Jaxpr, bool]]] = None
def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, bool]:
if outfeed_rewriter is not None:
return outfeed_rewriter(jaxpr)
else:
return jaxpr, False
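A minimal sketch of how this hook is meant to be used, matching the registration at the end of the host_callback diff above (the rewriter body here is only a placeholder):

from typing import Tuple
from jax import core
from jax.interpreters import xla

def _noop_rewriter(jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, bool]:
    # A real rewriter would thread a token through the outfeed equations;
    # this placeholder just reports that the jaxpr does not use outfeed.
    return jaxpr, False

xla.outfeed_rewriter = _noop_rewriter
# Compilation entry points then call, right after tracing:
#   jaxpr, uses_outfeed = xla.apply_outfeed_rewriter(jaxpr)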
class _ComputationStateCarry(threading.local):
"""Carries some state globally as we build the HLO.
For now the state is only a token, obtained from the last OutFeed. The
state is carried through nested loops and conditionals, if they use
state.
"""
# A stack of nested computations and the current token for each. A None token
# means that we cannot use state in this computation and any nested
# computations. The last one is the most recent.
_computations: Tuple[XlaComputationBuilder, ...]
_tokens: List[XlaOp]
_log_idx: int # For debugging
# TODO(necula): remove these experiment flags
# Flags to force compilation of while, cond, call with in/out state, to
# test the XLA compiler's handling of tokens.
FORCE_OUTFEED = False
_LOG_STATE = False
def __init__(self) -> None:
self._computations = ()
self._tokens = []
self._log_idx = 0
def check_and_reset_state(self) -> bool:
"""Checks that the state is empty, and resets it if not."""
ok = not self._computations and not self._tokens
self._computations = ()
self._tokens = []
self._log_idx = 0
return ok
def _log_state(self, msg: str):
if self._LOG_STATE:
top_comp = id(self._computations[-1]) if self._computations else 0
logging.warning(
f"[{self._log_idx} @ 0x{top_comp:x}/{len(self._computations)}]: {msg}.")
self._log_idx += 1
def current_state(self, comp: XlaComputationBuilder, uses_state: bool) -> List[XlaOp]:
"""Get the current state for the current computation."""
if uses_state:
assert self._computations and self._tokens
# Ensure we kept track of computations properly
assert (comp is self._computations[-1]), f"Reading state from unexpected computation (0x{id(comp):x})"
assert self._tokens[-1], "Reading non-initialized state"
return [self._tokens[-1]]
else:
return []
def start_nested_comp_without_input(self, comp: XlaComputationBuilder,
uses_state: bool):
"""Starts a nested computation, with no state passed in.
`comp` is the new computation.
"""
self._computations = self._computations + (comp,)
if uses_state:
self._log_state(f"start nested computation without input, initialize state")
assert all([t is None for t in self._tokens]), "Ignoring upstream state"
token = xops.CreateToken(comp)
else:
self._log_state(f"start nested computation without input state, no state")
token = None # Cannot use state, here or in nested computations
self._tokens = self._tokens + [token]
def start_nested_comp_with_input(self, comp: XlaComputationBuilder,
tuple_op: XlaOp, nr_regular: int,
uses_state: bool
) -> Tuple[Sequence[XlaOp], Sequence[XlaOp]]:
"""Start a nested computation with inputs.
`comp` is the new computation. `tuple_op` is a tuple in the new computation
with `nr_regular` elements and perhaps some state elements (if `uses_state`).
Stores the new state for the new computation.
Returns the regular elements of the tuple, and the state.
"""
self._computations = self._computations + (comp,)
self._log_state(f"start nested computation with input")
self._tokens = self._tokens + [None] # Will set the token below
return self.set_state_from_tuple(comp, tuple_op, nr_regular, uses_state)
def end_nested_comp_without_output(self, comp: XlaComputationBuilder) -> None:
"""Ends a nested computation, with no state returned.
`comp` is the ending computation.
"""
self._log_state(f"end nested computation without output")
if self._tokens[-1]:
assert self._computations and self._computations[-1] is comp
self._computations = self._computations[:-1]
self._tokens = self._tokens[:-1]
def end_nested_comp_with_output(self, comp: XlaComputationBuilder,
tuple_op: XlaOp, nr_regular: int,
uses_state: bool) -> Sequence[XlaOp]:
"""Ends a nested computation with results.
`comp` is the parent computation into which we return, `tuple_op` is
a tuple in the parent computation with `nr_regular` elements and
perhaps some state elements (if `uses_state`).
Stores the new state for the parent computation.
Returns the regular elements of the output.
"""
self._log_state(f"end nested computation with output")
if uses_state:
assert len(self._computations) >= 2 and self._computations[-2] is comp
self._computations = self._computations[:-1]
self._tokens = self._tokens[:-1]
regulars, _ = self.set_state_from_tuple(comp, tuple_op,
nr_regular, uses_state)
return regulars
def set_state_from_tuple(self, comp: XlaComputationBuilder,
tuple_op: XlaOp, nr_regular: int,
uses_state: bool
) -> Tuple[Sequence[XlaOp], Sequence[XlaOp]]:
"""Decomposes a tuple, returns the regular elements and the state.
We assume that the `tuple_op` represents a tuple with `nr_regular` regular
elements, followed by some elements encoding the state.
Stores the new state.
"""
regular_ops = [xops.GetTupleElement(tuple_op, i) for i in range(nr_regular)]
if uses_state:
assert self._computations and self._computations[-1] is comp
self._log_state(f"set state from tuple_op")
self._tokens[-1] = xops.GetTupleElement(tuple_op, nr_regular)
return regular_ops, self.current_state(comp, uses_state)
else:
return regular_ops, []
def current_token(self, comp: XlaComputationBuilder) -> XlaOp:
self._log_state("reading token")
state = self.current_state(comp, True)
return state[0]
def set_current_token(self, comp: XlaComputationBuilder, token: XlaOp):
assert comp == self._computations[-1]
assert self._tokens[-1]
self._tokens[-1] = token
state_carry = _ComputationStateCarry()
# TODO(necula): remove this when we start the outfeed receiver automatically.
can_execute_outfeed_computations: bool = False
def check_before_outfeed_execution(uses_outfeed: bool):
if uses_outfeed and not can_execute_outfeed_computations:
raise ValueError("Attempting to execute compiled code using outfeed, "
"but outfeed_receiver is not started.")
### op-by-op execution
@ -399,7 +254,6 @@ def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[De
@cache()
def primitive_computation(prim, axis_env, backend, tuple_args, *avals, **params):
c = xb.make_computation_builder(f"primitive_computation_{prim.name}")
state_carry.start_nested_comp_without_input(c, False)
c.SetOpMetadata(xc.OpMetadata(
op_type=prim.name,
op_name=str(pp_eqn_compact(prim.name, params))))
@ -420,7 +274,6 @@ def primitive_computation(prim, axis_env, backend, tuple_args, *avals, **params)
raise NotImplementedError(f"XLA translation rule for {prim} not found")
assert isinstance(ans, xe.XlaOp)
c.ClearOpMetadata()
state_carry.end_nested_comp_without_output(c)
try:
return c.Build()
except RuntimeError as e:
@ -653,7 +506,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, *arg_specs):
pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
jaxpr, pvals, consts = pe.trace_to_jaxpr(
fun, pvals, instantiate=False, stage_out=True, bottom=True)
jaxpr, uses_outfeed = apply_outfeed_rewriter(jaxpr)
_map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
nreps = jaxpr_replicas(jaxpr)
@ -683,8 +536,6 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, *arg_specs):
tuple_args = len(abstract_args) > 100 # pass long arg lists as tuple for TPU
c = xb.make_computation_builder("jit_{}".format(fun.__name__))
uses_outfeed = jaxpr_uses_outfeed(jaxpr)
state_carry.start_nested_comp_without_input(c, uses_outfeed)
xla_consts = _map(partial(xb.constant, c), consts)
xla_args = _xla_callable_args(c, abstract_args, tuple_args)
out_nodes = jaxpr_subcomp(
@ -698,9 +549,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, *arg_specs):
device_assignment=(device.id,) if device else None)
options.tuple_arguments = tuple_args
backend = xb.get_backend(backend)
state_carry.end_nested_comp_without_output(c)
compiled = backend.compile(built, compile_options=options)
if nreps == 1:
return partial(_execute_compiled, compiled, uses_outfeed, result_handlers)
else:
@ -745,19 +594,9 @@ def _pval_to_result_handler(device, pval):
else:
return aval_to_result_handler(device, pv)
_outfeed_allowed = False
def set_outfeed_allowed(allowed: bool):
global _outfeed_allowed
_outfeed_allowed = allowed
def check_outfeed_allowed(uses_outfeed: bool):
if uses_outfeed and not _outfeed_allowed:
raise ValueError("Attempting to execute compiled code using outfeed, "
"but outfeed_consumer is not started.")
def _execute_compiled(compiled: XlaExecutable, uses_outfeed: bool,
handlers, *args):
check_outfeed_allowed(uses_outfeed)
check_before_outfeed_execution(uses_outfeed)
device, = compiled.local_devices()
input_bufs = [device_put(x, device) for x in args if x is not token]
out_bufs = compiled.Execute(input_bufs)
@ -766,7 +605,7 @@ def _execute_compiled(compiled: XlaExecutable, uses_outfeed: bool,
def _execute_replicated(compiled: XlaExecutable, uses_outfeed: bool,
handlers, *args):
check_outfeed_allowed(uses_outfeed)
check_before_outfeed_execution(uses_outfeed)
input_bufs = [
[device_put(x, device) for x in args if x is not token]
for device in compiled.local_devices()]
@ -810,23 +649,11 @@ def _xla_call_translation_rule(c, axis_env,
call_jaxpr, device=None):
del device # Ignored.
subc = xb.make_computation_builder(f"jit_{name}")
uses_outfeed = jaxpr_uses_outfeed(call_jaxpr)
prev_state = state_carry.current_state(c, uses_outfeed)
input_op = xops.Tuple(c, list(in_nodes) + list(prev_state))
arg = xb.parameter(subc, 0, c.GetShape(input_op))
nr_regular_args = len(in_nodes)
args, input_state = state_carry.start_nested_comp_with_input(subc, arg, nr_regular_args, uses_outfeed)
args = [xb.parameter(subc, i, c.GetShape(n)) for i, n in enumerate(in_nodes)]
out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
extend_name_stack(name_stack, wrap_name(name, 'jit')),
*(args + input_state))
result_state = state_carry.current_state(subc, uses_outfeed)
subc = subc.Build(xops.Tuple(subc, list(out_nodes) + result_state))
call_op = xops.Call(c, subc, [input_op])
nr_outs = len(out_nodes)
regular_outs = state_carry.end_nested_comp_with_output(c, call_op, nr_outs, uses_outfeed)
return xops.Tuple(c, regular_outs)
extend_name_stack(name_stack, wrap_name(name, 'jit')), *args)
subc = subc.Build(xops.Tuple(subc, out_nodes))
return xops.Call(c, subc, list(in_nodes))
ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
@ -1230,21 +1057,16 @@ def _remat_translation_rule(c, axis_env, in_nodes,
xb.constant(c, onp.array(1, dtype=onp.float32)),
xc.Shape.array_shape(xc.PrimitiveType.F32, []))
pred = xops.Lt(rng, xb.constant(c, onp.array(2, dtype=onp.float32)))
uses_outfeed = jaxpr_uses_outfeed(call_jaxpr)
prev_state = state_carry.current_state(c, uses_outfeed)
true_op = xops.Tuple(c, list(in_nodes) + list(prev_state))
true_op = xops.Tuple(c, in_nodes)
remat_subc = xb.make_computation_builder("remat_call_subcomputation")
input_op = xb.parameter(remat_subc, 0, c.GetShape(true_op), replicated=[])
args, input_state = state_carry.start_nested_comp_with_input(remat_subc, input_op, len(in_nodes), uses_outfeed)
args = [xops.GetTupleElement(input_op, i) for i in range(len(in_nodes))]
out_nodes = jaxpr_subcomp(remat_subc, call_jaxpr, backend, axis_env, (),
extend_name_stack(name_stack, wrap_name(name, 'remat')),
*(args + input_state))
result_state = state_carry.current_state(remat_subc, uses_outfeed)
result_tuple = xops.Tuple(remat_subc, out_nodes + tuple(result_state))
regular_outs = state_carry.end_nested_comp_with_output(c, result_tuple, len(out_nodes), uses_outfeed)
out_node_shapes = [remat_subc.GetShape(o) for o in regular_outs]
remat_subc = remat_subc.Build(xops.Tuple(remat_subc, regular_outs))
*args)
out_node_shapes = [remat_subc.GetShape(o) for o in out_nodes]
remat_subc = remat_subc.Build(xops.Tuple(remat_subc, out_nodes))
false_op = true_op
dummy_subc = xb.make_computation_builder("remat_call_dummy_subcomputation")

View File

@ -254,23 +254,15 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
cond_consts, body_consts, init_vals = split_list(args, [cond_nconsts, body_nconsts])
batched = bool(cond_jaxpr.out_avals[0].shape)
if xla.jaxpr_uses_outfeed(cond_jaxpr) and not xla.state_carry.FORCE_OUTFEED:
# TODO: implement the boolean as an extra carry
raise NotImplementedError("State not supported in while_loop conditionals")
uses_outfeed = xla.jaxpr_uses_outfeed(body_jaxpr)
prev_state = xla.state_carry.current_state(c, uses_outfeed)
# Since jaxprs don't have tuples but do have multiple return values, while the
# HLO While loop must take a single tuple input and produce a single boolean
# (for the cond computation) or a single tuple output (for the body
# computation), we build XLA computations that handle the tuple munging before
# generating a Call into the computations formed from the jaxprs.
init_carry = xops.Tuple(c, cond_consts + body_consts + init_vals + prev_state)
init_carry = xops.Tuple(c, cond_consts + body_consts + init_vals)
cond_c = xb.make_computation_builder("cond_computation")
xla.state_carry.start_nested_comp_without_input(cond_c, False)
cond_carry = xb.parameter(cond_c, 0, c.GetShape(init_carry))
cond_carry_elts = [xops.GetTupleElement(cond_carry, i) for i in range(len(args))]
x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
@ -283,12 +275,10 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
or_ = xla.primitive_subcomputation(lax.or_p, scalar, scalar)
pred = xops.Reduce(cond_c, [pred], [xb.constant(cond_c, onp.array(False))], or_,
list(range(cond_jaxpr.out_avals[0].ndim)))
cond_comp = cond_c.Build(pred)
xla.state_carry.end_nested_comp_without_output(cond_c)
body_c = xb.make_computation_builder("body_computation")
body_carry = xb.parameter(body_c, 0, c.GetShape(init_carry))
body_carry_elts, _ = xla.state_carry.start_nested_comp_with_input(
body_c, body_carry, len(args), uses_outfeed)
body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
new_z = xla.jaxpr_subcomp(body_c, body_jaxpr.jaxpr, backend, axis_env,
_map(partial(xb.constant, body_c), body_jaxpr.literals),
@ -299,11 +289,10 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
extend_name_stack(name_stack, 'body_pred'), *(x + z))
new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z)
assert _map(body_c.GetShape, new_z) == _map(body_c.GetShape, z) # no broadcast
result_state = xla.state_carry.current_state(body_c, uses_outfeed)
new_carry = xops.Tuple(body_c, list(itertools.chain(x, y, new_z, result_state)))
new_carry = xops.Tuple(body_c, list(itertools.chain(x, y, new_z)))
ans = xops.While(cond_comp, body_c.Build(new_carry), init_carry)
ans_elts = xla.state_carry.end_nested_comp_with_output(c, ans, len(args), uses_outfeed)
ans = xops.While(cond_c.Build(pred), body_c.Build(new_carry), init_carry)
ans_elts = [xops.GetTupleElement(ans, i) for i in range(len(args))]
_, _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
return xops.Tuple(c, z)
@ -567,35 +556,23 @@ def _cond_translation_rule(c, axis_env, name_stack, avals, backend,
pred, *args, true_jaxpr, false_jaxpr, linear):
del linear # Unused.
true_ops, false_ops = split_list(args, [len(true_jaxpr.in_avals)])
uses_outfeed = xla.jaxpr_uses_outfeed(true_jaxpr) or xla.jaxpr_uses_outfeed(false_jaxpr)
prev_state = xla.state_carry.current_state(c, uses_outfeed)
def make_computation(name, jaxpr, op_shape):
sub_c = xb.make_computation_builder(name + '_comp')
op = xb.parameter(sub_c, 0, op_shape)
ops, _ = xla.state_carry.start_nested_comp_with_input(
sub_c, op, len(jaxpr.in_avals), uses_outfeed)
outs = xla.jaxpr_subcomp(sub_c, jaxpr.jaxpr, backend, axis_env,
_map(partial(xb.constant, sub_c), jaxpr.literals),
c = xb.make_computation_builder(name + '_comp')
op = xb.parameter(c, 0, op_shape)
ops = [xops.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))]
outs = xla.jaxpr_subcomp(c, jaxpr.jaxpr, backend, axis_env,
_map(partial(xb.constant, c), jaxpr.literals),
extend_name_stack(name_stack, name + '_fun'), *ops)
result_state = xla.state_carry.current_state(sub_c, uses_outfeed)
result_tuple = xops.Tuple(sub_c, list(outs) + result_state)
xla.state_carry.end_nested_comp_with_output(c, result_tuple, len(outs), uses_outfeed)
return sub_c.Build(result_tuple)
return c.Build(xops.Tuple(c, outs))
true_op = xops.Tuple(c, true_ops + prev_state)
true_op = xops.Tuple(c, true_ops)
true_c = make_computation('true', true_jaxpr, c.GetShape(true_op))
false_op = xops.Tuple(c, false_ops + prev_state)
false_op = xops.Tuple(c, false_ops)
false_c = make_computation('false', false_jaxpr, c.GetShape(false_op))
cond_op = xops.Conditional(pred, true_op, true_c, false_op, false_c)
if not uses_outfeed:
return cond_op
else:
nr_outs = len(true_jaxpr.out_avals)
regular_outs, _ = xla.state_carry.set_state_from_tuple(
c, cond_op, nr_outs, uses_outfeed)
return xops.Tuple(c, regular_outs)
return xops.Conditional(pred, true_op, true_c, false_op, false_c)
def _cond_pred_bcast_select(pred, x, y):
if core.get_aval(x) is core.get_aval(y) is core.abstract_unit:

View File

@ -21,6 +21,7 @@ import logging
import numpy as onp
import os
import re
from typing import Any, Callable, List, Sequence, Tuple
from unittest import SkipTest
from absl.testing import absltest
@ -78,6 +79,29 @@ def fun1(a):
def fun1_equiv(a): # Numerical equivalent of fun1
return (a * 2.)**2
def assertMultiLineStrippedEqual(tst: jtu.JaxTestCase, expected: str, what: str):
"""A variant that preprocesses the string to eliminate non-determinism in
floating point values, and several uninteresting id_tap primitive params."""
# Sometimes we get floating points in the output; we round them
def repl_floats(match_group):
matched = match_group.group(0)
if matched == ".": return matched
# TODO: why can't we use np.around here?
x = onp.around(float(matched), decimals=2)
return f"{x:.2f}"
what = re.sub(r"\-?\d*\.[\-\def]*", repl_floats, what)
what = re.sub(r"output_stream=[^\]\n]*", "", what)
what = re.sub(r"threshold=[^\]\n]*", "", what)
# Empty lines
what = re.sub(r"^\s*\n", "", what, flags=re.MULTILINE)
def repl_func(match_group):
matched = match_group.group(0)
if "function _print_consumer" in matched:
return "func=_print"
else:
return "..."
what = re.sub(r"func=(.*)", repl_func, what)
tst.assertMultiLineStrippedEqual(expected, what)
class HostCallbackTest(jtu.JaxTestCase):
@ -100,50 +124,24 @@ class HostCallbackTest(jtu.JaxTestCase):
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
def assertMultiLineStrippedEqual(self, expected, what):
"""A variant that preprocesses the string to eliminate non-determinism."""
# Sometimes we get floating points in the output; we round them
def repl_floats(match_group):
matched = match_group.group(0)
if matched == ".": return matched
# TODO: why can't we use here np.around?
x = onp.around(float(matched), decimals=2)
return f"{x:.2f}"
what = re.sub(r"\-?\d*\.[\-\def]*", repl_floats, what)
# We rewrite consumer_id because it changes
what = re.sub(r"consumer_id=(\d+)", "consumer_id=...", what)
what = re.sub(r"output_stream=[^\]\n]*", "output_stream=...", what)
def repl_func(match_group):
matched = match_group.group(0)
if "function _print_consumer" in matched:
return "func=_print"
else:
return "..."
what = re.sub(r"func=(.*)", repl_func, what)
super().assertMultiLineStrippedEqual(expected, what)
def test_eval(self):
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
{ lambda ; a.
let b = mul a 2.00
c = id_tap[ arg_treedef=*
func=_print
output_stream=...
what=a * 2 ] b
d = mul c 3.00
e f = id_tap[ arg_treedef=*
func=_print
nr_untapped=1
output_stream=...
what=y * 3 ] d c
g = pow f 2.00
in (g,) }""", str(api.make_jaxpr(fun1)(5.)))
self.assertEqual("", testing_stream.output)
self.assertEqual((5. * 2.) ** 2, fun1(5.))
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
what: a * 2
10.00
what: y * 3
@ -155,18 +153,17 @@ what: y * 3
x1, y1 = hcb.id_print((x * 2., x * 3.), output_stream=testing_stream)
return x1 + y1
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
{ lambda ; a.
let b = mul a 2.00
c = mul a 3.00
d e = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
func=_print
output_stream=...] b c
] b c
f = add d e
in (f,) }""", str(api.make_jaxpr(func2)(3.)))
self.assertEqual(3. * (2. + 3.), func2(3.))
self.assertMultiLineStrippedEqual("""
assertMultiLineStrippedEqual(self, """
[ 6.00
9.00 ]""", testing_stream.output)
testing_stream.reset()
@ -176,18 +173,17 @@ what: y * 3
res = hcb.id_print(dict(a=x * 2., b=x * 3.), output_stream=testing_stream)
return res["a"] + res["b"]
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
{ lambda ; a.
let b = mul a 2.00
c = mul a 3.00
d e = id_tap[ arg_treedef=PyTreeDef(dict[['a', 'b']], [*,*])
func=_print
output_stream=...] b c
] b c
f = add d e
in (f,) }""", str(api.make_jaxpr(func2)(3.)))
self.assertEqual(3. * (2. + 3.), func2(3.))
self.assertMultiLineStrippedEqual("""
assertMultiLineStrippedEqual(self, """
{ a=6.00
b=9.00 }""", testing_stream.output)
testing_stream.reset()
@ -198,8 +194,7 @@ what: y * 3
output_stream=testing_stream)
return x1
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
{ lambda ; a.
let b = mul a 2.00
c = mul a 3.00
@ -207,10 +202,10 @@ what: y * 3
e f g = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
func=_print
nr_untapped=1
output_stream=...] b c d
] b c d
in (g,) }""", str(api.make_jaxpr(func2)(3.)))
self.assertEqual(3. * 4., func2(3.))
self.assertMultiLineStrippedEqual("""
assertMultiLineStrippedEqual(self, """
[ 6.00
9.00 ]""", testing_stream.output)
testing_stream.reset()
@ -255,7 +250,7 @@ what: y * 3
res = func(0)
# We should have received everything before the error
self.assertMultiLineStrippedEqual(
assertMultiLineStrippedEqual(self,
"""
what: x1
1""", testing_stream.output)
@ -264,15 +259,13 @@ what: x1
def test_jit_simple(self):
jit_fun1 = api.jit(lambda x: 3. * hcb.id_print(
2. * x, what="here", output_stream=testing_stream))
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
{ lambda ; a.
let b = xla_call[ backend=None
call_jaxpr={ lambda ; a.
let b = mul a 2.00
c = id_tap[ arg_treedef=*
func=_print
output_stream=...
what=here ] b
d = mul c 3.00
in (d,) }
@ -285,7 +278,7 @@ what: x1
res = jit_fun1(5.)
self.assertAllClose(6. * 5., res, check_dtypes=True)
self.assertMultiLineStrippedEqual(
assertMultiLineStrippedEqual(self,
"""
what: here
10.00""", testing_stream.output)
@ -303,8 +296,7 @@ what: here
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
self.assertEqual(2, api.jit(func)(1))
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
@ -322,8 +314,7 @@ where: 2
self.assertEqual(2, api.jit(func)(1))
self.assertEqual(11, api.jit(func)(10))
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
@ -341,7 +332,7 @@ where: 2
x2 = hcb.id_print(x + 1, where="nested", output_stream=testing_stream)
return x2
x3 = api.jit(func_nested)(x1)
return hcb.id_print(x3 + 1, where="2", output_stream=testing_stream)
return hcb.id_print(x3 + 1, where="3", output_stream=testing_stream)
logging.warning("%s: %s", self._testMethodName,
api.make_jaxpr(func)(1))
@ -349,13 +340,12 @@ where: 2
api.xla_computation(func)(1).GetHloText())
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
self.assertEqual(3, api.jit(func)(1))
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
where: 1
1
where: nested
2
where: 2
where: 3
3""", testing_stream.output)
testing_stream.reset()
@ -421,7 +411,7 @@ where: 2
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
self.assertEqual(4, api.jit(func)(1))
self.assertMultiLineStrippedEqual("""
assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
@ -452,8 +442,7 @@ where: w_2
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
self.assertEqual(10, api.jit(func)(1))
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
@ -507,7 +496,7 @@ where: 10
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
self.assertEqual(10, api.jit(func)(1))
self.assertMultiLineStrippedEqual(
assertMultiLineStrippedEqual(self,
"""
""", testing_stream.output)
testing_stream.reset()
@ -536,8 +525,7 @@ where: 10
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
res = api.jit(func)(1)
self.assertAllClose(np.array([2, 3, 4]), res, check_dtypes=True)
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
@ -641,8 +629,7 @@ where: 10
# return 3.
self.assertEqual(3, res)
# We should have received all others
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
what: x1
1
what: x3
@ -664,8 +651,7 @@ what: x3
# return 3.
self.assertEqual(3, res)
# We should have received all others
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
what: x1
1
what: x3
@ -689,9 +675,9 @@ what: x3
assert False # It seems that the previous jit blocks above
def test_jit_without_consumer_error(self):
def test_jit_error_no_consumer(self):
# Check for errors if starting jit without a receiver active
with self.assertRaisesRegex(ValueError, "outfeed_consumer is not started"):
with self.assertRaisesRegex(ValueError, "outfeed_receiver is not started"):
api.jit(lambda x: hcb.id_print(x))(0)
# On CPU and GPU the device code blocks
@ -727,34 +713,29 @@ what: x3
def test_jvp(self):
jvp_fun1 = lambda x, xt: api.jvp(fun1, (x,), (xt,))
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
{ lambda ; a b.
let c = mul a 2.00
d = id_tap[ arg_treedef=*
func=_print
nr_untapped=0
output_stream=...
what=a * 2 ] c
e = mul d 3.00
f g = id_tap[ arg_treedef=*
func=_print
nr_untapped=1
output_stream=...
what=y * 3 ] e d
h = pow g 2.00
i = mul b 2.00
j k = id_tap[ arg_treedef=*
func=_print
nr_untapped=1
output_stream=...
transforms=('jvp',)
what=a * 2 ] i d
l = mul j 3.00
m n o = id_tap[ arg_treedef=*
func=_print
nr_untapped=2
output_stream=...
transforms=('jvp',)
what=y * 3 ] l j f
p = pow g 1.00
@ -764,8 +745,7 @@ what: x3
str(api.make_jaxpr(jvp_fun1)(np.float32(5.), np.float32(0.1))))
res_primals, res_tangents = jvp_fun1(np.float32(5.), np.float32(0.1))
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
what: a * 2
10.00
transforms: ('jvp',) what: a * 2
@ -777,26 +757,24 @@ transforms: ('jvp',) what: y * 3
testing_stream.reset()
def test_grad_primal_unused(self):
# The output of id_print is not needed for backwards pass
def func(x):
return 2. * hcb.id_print(x * 3., what="x * 3", output_stream=testing_stream)
grad_func = api.grad(func)
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
{ lambda ; a.
let
in (6.00,) }""", str(api.make_jaxpr(grad_func)(5.)))
# Just making the Jaxpr invokes the id_print once
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
transforms: ('jvp', 'transpose') what: x * 3
2.00""", testing_stream.output)
testing_stream.reset()
res_grad = grad_func(np.float32(5.))
self.assertAllClose(6., res_grad, check_dtypes=False)
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
what: x * 3
15.00
transforms: ('jvp', 'transpose') what: x * 3
@ -808,22 +786,18 @@ transforms: ('jvp', 'transpose') what: x * 3
y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
return x * hcb.id_print(y * 3., what="y * 3", output_stream=testing_stream)
grad_func = api.grad(func)
# TODO: why is the order like that?
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
{ lambda ; a.
let b = mul 1.00 a
c d = id_tap[ arg_treedef=*
func=_print
nr_untapped=1
output_stream=...
transforms=('jvp', 'transpose')
what=y * 3 ] b 0.00
e = mul c 3.00
f g = id_tap[ arg_treedef=*
func=_print
nr_untapped=1
output_stream=...
transforms=('jvp', 'transpose')
what=x * 2 ] e 0.00
h = mul f 2.00
@ -831,13 +805,11 @@ transforms: ('jvp', 'transpose') what: x * 3
j = id_tap[ arg_treedef=*
func=_print
nr_untapped=0
output_stream=...
what=x * 2 ] i
k = mul j 3.00
l = id_tap[ arg_treedef=*
func=_print
nr_untapped=0
output_stream=...
what=y * 3 ] k
m = mul 1.00 l
n = add_any h m
@ -845,8 +817,7 @@ transforms: ('jvp', 'transpose') what: x * 3
res_grad = grad_func(np.float32(5.))
self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False)
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
what: x * 2
10.00
what: y * 3
@ -863,20 +834,21 @@ transforms: ('jvp', 'transpose') what: x * 2
return x * (y * 3.)
grad_func = api.grad(api.grad(func))
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
{ lambda ; a.
let
in (12.00,) }""", str(api.make_jaxpr(grad_func)(5.)))
res_grad = grad_func(np.float32(5.))
self.assertAllClose(12., res_grad, check_dtypes=False)
self.assertMultiLineStrippedEqual(
"""
# Just making the Jaxpr invokes the id_print twice
assertMultiLineStrippedEqual(self, """
transforms: ('jvp', 'transpose') what: x * 2
3.00
transforms: ('jvp', 'transpose', 'jvp', 'transpose') what: x * 2
2.00
2.00""", testing_stream.output)
testing_stream.reset()
res_grad = grad_func(np.float32(5.))
self.assertAllClose(12., res_grad, check_dtypes=False)
assertMultiLineStrippedEqual(self, """
what: x * 2
10.00
transforms: ('jvp', 'transpose') what: x * 2
@ -891,14 +863,12 @@ transforms: ('jvp', 'transpose') what: x * 2
def test_vmap(self):
vmap_fun1 = api.vmap(fun1)
vargs = np.array([np.float32(4.), np.float32(5.)])
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
{ lambda ; a.
let b = mul a 2.00
c = id_tap[ arg_treedef=*
batch_dims=(0,)
func=_print
output_stream=...
transforms=('batch',)
what=a * 2 ] b
d = mul c 3.00
@ -906,15 +876,13 @@ transforms: ('jvp', 'transpose') what: x * 2
batch_dims=(0, 0)
func=_print
nr_untapped=1
output_stream=...
transforms=('batch',)
what=y * 3 ] d c
g = pow f 2.00
in (g,) }""", str(api.make_jaxpr(vmap_fun1)(vargs)))
res_vmap = vmap_fun1(vargs)
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
batch_dims: (0,) transforms: ('batch',) what: a * 2
[ 8.00 10.00]
batch_dims: (0, 0) transforms: ('batch',) what: y * 3
@ -930,27 +898,23 @@ batch_dims: (0, 0) transforms: ('batch',) what: y * 3
vmap_func = api.vmap(func)
vargs = np.array([np.float32(4.), np.float32(5.)])
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
{ lambda ; a.
let b c = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
batch_dims=(None, 0)
func=_print
output_stream=...
transforms=('batch',) ] 3.00 a
d = add c 3.00
in (d,) }""", str(api.make_jaxpr(vmap_func)(vargs)))
res_vmap = vmap_func(vargs)
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
batch_dims: (None, 0) transforms: ('batch',)
[ 3.00
[4.00 5.00] ]
""", testing_stream.output)
testing_stream.reset()
def test_pmap(self):
vargs = 2. + np.arange(api.local_device_count(), dtype=np.float32)
@ -960,10 +924,10 @@ batch_dims: (None, 0) transforms: ('batch',)
expected_res = np.stack([fun1_equiv(2. + a) for a in range(api.local_device_count())])
self.assertAllClose(expected_res, res, check_dtypes=False)
def test_pmap_without_consumer_error(self):
def test_pmap_error_no_receiver(self):
# Check for errors if starting pmap without a receiver active
vargs = 2. + np.arange(api.local_device_count(), dtype=np.float32)
with self.assertRaisesRegex(ValueError, "outfeed_consumer is not started"):
with self.assertRaisesRegex(ValueError, "outfeed_receiver is not started"):
api.pmap(lambda x: hcb.id_print(x))(vargs)
def test_mask(self):
@ -972,13 +936,11 @@ batch_dims: (None, 0) transforms: ('batch',)
def padded_sum(x):
return np.sum(hcb.id_print(x, what="x", output_stream=testing_stream))
args = [np.arange(4)], dict(n=onp.int64(2))
self.assertMultiLineStrippedEqual(
"""
assertMultiLineStrippedEqual(self, """
{ lambda c f ; a b.
let d = lt c b
e = id_tap[ func=_print
logical_shapes=[(Traced<ShapedArray(int32[]):JaxprTrace(level=0/0)>,)]
output_stream=...
transforms=('mask',)
what=x ] a
g = select d e f
@ -986,12 +948,108 @@ batch_dims: (None, 0) transforms: ('batch',)
in (h,) }""", str(api.make_jaxpr(padded_sum)(*args)))
res = padded_sum(*args)
self.assertMultiLineStrippedEqual(
"""
self.assertMultiLineStrippedEqual("""
logical_shapes: [(2,)] transforms: ('mask',) what: x
[0 1 2 3]
""", testing_stream.output)
testing_stream.reset()
class OutfeedRewriterTest(jtu.JaxTestCase):
def assertRewrite(self, expected: str, func: Callable, args: Sequence,
has_input_token=True, has_output_token=True):
"""Check that the rewrite of func(*args) matches expected."""
jaxpr = api.make_jaxpr(func)(*args)
assertMultiLineStrippedEqual(self, expected,
str(hcb._rewrite_typed_jaxpr(jaxpr, has_input_token, has_output_token)[0]))
def test_no_outfeed(self):
self.assertRewrite("""
{ lambda ; a.
let b = mul a a
c = add a b
in (c,) }""", lambda x: x + x * x, [0], has_input_token=False, has_output_token=False)
self.assertRewrite("""
{ lambda ; a d.
let b = mul a a
c = add a b
in (c,) }""", lambda x: x + x * x, [0], has_output_token=False)
self.assertRewrite("""
{ lambda ; a d.
let b = mul a a
c = add a b
in (c, d) }""", lambda x: x + x * x, [0])
def test_simple_outfeed(self):
self.assertRewrite("""
{ lambda ; a d.
let b = add a a
c e = id_tap[ arg_treedef=*
func=_print
] b d
in (c, e) }""", lambda x: hcb.id_print(x + x), [0])
def test_cond(self):
y = np.ones(5) # captured const
def func(x, z):
return lax.cond(z > 0, (1, 2), lambda a: (a[0], np.zeros(5)),
z, lambda a: (hcb.id_print(a), y))
self.assertRewrite("""
{ lambda d e ; a b h.
let c = gt b 0
f g i = cond[ false_jaxpr={ lambda ; c a d.
let b e = id_tap[ arg_treedef=*
func=_print
] a d
in (b, c, e) }
linear=(False, False, False, False, False, False, False)
true_jaxpr={ lambda ; c a b d.
let
in (a, c, d) } ] c d 1 2 h e b h
in (f, g, i) }""", func, [y, 5])
def test_while(self):
y = np.ones(5) # captured const
def func(x):
return lax.while_loop(lambda c: c[1] < 5,
lambda c: (y, hcb.id_print(c[1]) + 1), (x, 1))
# TODO: we should not need to start a receiver here!!! I believe this is
# because of the partial evaluation of while, which calls impl, which
# uses JIT.
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
self.assertRewrite("""
{ lambda b ; a e.
let c d f = while[ body_jaxpr={ lambda ; c a b f.
let d g = id_tap[ arg_treedef=*
func=_print
] b f
e = add d 1
in (c, e, g) }
body_nconsts=1
cond_jaxpr={ lambda ; a b d.
let c = lt b 5
in (c,) }
cond_nconsts=0 ] b a 1 e
in (c, 5, f) }""", func, [y])
def test_scan(self):
y = np.ones(5) # captured const
def func(x):
return lax.scan(lambda c, a: (hcb.id_print(c), y), (1, 2), x)
self.assertRewrite("""
{ lambda b ; a f.
let c d g e = scan[ jaxpr={ lambda ; f a b g c.
let d e h = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
func=_print
] a b g
in (d, e, h, f) }
length=5
linear=(False, False, False, False, False)
num_carry=3
num_consts=1
reverse=False ] b 1 2 f a
in (c, d, e, g) }""", func, [y])
if __name__ == "__main__":
absltest.main()