Reimplemented the passing of tokens with a Jaxpr transform
parent 8fc96910c2
commit d8b75e1913
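This change drops the thread-local _ComputationStateCarry that used to thread the outfeed token while the HLO was being built, and instead rewrites the Jaxpr just before compilation: every equation that uses outfeed gets an extra token input and output, and while, cond, scan and xla_call are rewritten to carry the token explicitly. Below is a minimal sketch of the core rewrite step, reusing the JAX-internal helpers that appear in this diff (core.Var, core.new_jaxpr_eqn, core.abstract_token); the helper name itself is illustrative and not part of the commit.

from jax import core

def _thread_token_through_eqn(eqn, last_token_var, mk_new_id):
  """Append the running token to an outfeed equation (illustrative sketch).

  `mk_new_id` is an itertools.count of fresh variable ids, as in _rewrite_jaxpr.
  """
  new_token_var = core.Var(next(mk_new_id), '', core.abstract_token)
  new_eqn = core.new_jaxpr_eqn(eqn.invars + [last_token_var],
                               eqn.outvars + [new_token_var],
                               eqn.primitive, eqn.params)
  # The fresh token becomes the carry threaded into the next outfeed equation.
  return new_eqn, new_token_var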
@ -325,16 +325,14 @@ def xla_computation(fun: Callable,
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals,
|
||||
instantiate=instantiate_const_outputs,
|
||||
stage_out=True)
|
||||
jaxpr, _ = xla.apply_outfeed_rewriter(jaxpr)
|
||||
axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
|
||||
c = xb.make_computation_builder('xla_computation_{}'.format(fun_name))
|
||||
uses_outfeed = xla.jaxpr_uses_outfeed(jaxpr)
|
||||
xla.state_carry.start_nested_comp_without_input(c, uses_outfeed)
|
||||
xla_consts = map(partial(xb.constant, c), consts)
|
||||
xla_args = xla._xla_callable_args(c, avals, tuple_args)
|
||||
outs = xla.jaxpr_subcomp(
|
||||
c, jaxpr, backend, axis_env_, xla_consts,
|
||||
extend_name_stack(wrap_name(fun_name, 'xla_computation')), *xla_args)
|
||||
xla.state_carry.end_nested_comp_without_output(c)
|
||||
return c.Build(xc.ops.Tuple(c, outs))
|
||||
return computation_maker
|
||||
|
||||
|
@ -114,7 +114,7 @@ from jax import core
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
from jax.lib import pytree, xla_bridge
|
||||
from jax.interpreters import ad, xla, batching, masking
|
||||
from jax.interpreters import ad, xla, batching, masking, pxla
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax import pprint_util as ppu
|
||||
from jax import util
|
||||
@ -126,7 +126,7 @@ import msgpack # type: ignore
|
||||
import numpy as onp
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any, Callable, Dict, List, Optional, NamedTuple, Sequence, Tuple
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, NamedTuple, Sequence, Tuple
|
||||
|
||||
xops = xla_client._xla.ops
|
||||
|
||||
@ -281,7 +281,6 @@ positional arguments and parameters:
|
||||
# TODO: handle multiple vmap and mask
|
||||
id_tap_p = core.Primitive("id_tap")
|
||||
id_tap_p.multiple_results = True
|
||||
xla.outfeed_primitives.add(id_tap_p)
|
||||
|
||||
|
||||
def _add_transform_name(params: Dict, transform: str) -> Dict:
|
||||
@ -313,25 +312,27 @@ def _id_tap_abstract_eval(*args_a: pe.AbstractValue, **params) \
|
||||
|
||||
id_tap_p.def_abstract_eval(_id_tap_abstract_eval)
|
||||
|
||||
def _instantiate_zeros(tan, arg):
|
||||
# TODO: there must be a better way to do this.
|
||||
# The AttributeError is for regular values, the KeyError is for ConcreteArray
|
||||
def _instantiate_zeros(arg, tan):
|
||||
"""Turn special ad.zero tangents into arrays of 0s."""
|
||||
if tan is not ad.zero:
|
||||
return tan
|
||||
elif isinstance(arg, core.Tracer):
|
||||
# TODO: why do I have to do this to get a zero?
|
||||
else:
|
||||
try:
|
||||
aval = arg.aval
|
||||
return ad.instantiate_zeros_aval(aval, tan)
|
||||
except:
|
||||
# It seems that we get here for ConcreteArray
|
||||
return ad.instantiate_zeros(arg, tan)
|
||||
except (AttributeError, KeyError):
|
||||
# We get here for regular Python values
|
||||
return ad.zeros_like_jaxval(arg)
|
||||
|
||||
|
||||
def _id_tap_jvp_rule(primals, tangents, *, func, nr_untapped=0, **params):
|
||||
# Put primals through id_tap separately, so that partial evaluation
|
||||
# can do its job for grad
|
||||
out_primals = id_tap_p.bind(*primals, func=func, nr_untapped=nr_untapped, **params)
|
||||
# Add one primal output as untapped, to create dependency.
|
||||
tangent_zeros = tuple(map(_instantiate_zeros, tangents, primals))
|
||||
tangent_zeros = tuple(map(_instantiate_zeros, primals, tangents))
|
||||
out_tangents_extra = id_tap_p.bind(*tangent_zeros, out_primals[0],
|
||||
func=func, nr_untapped=nr_untapped + 1,
|
||||
**_add_transform_name(params, "jvp"))
|
||||
@ -343,7 +344,7 @@ ad.primitive_jvps[id_tap_p] = _id_tap_jvp_rule
|
||||
|
||||
def _id_tap_transpose_rule(cts, *args, func=None, nr_untapped=0, **params):
|
||||
assert len(cts) == len(args)
|
||||
cts_zeros = tuple(map(_instantiate_zeros, cts, args))
|
||||
cts_zeros = tuple(map(_instantiate_zeros, args, cts))
|
||||
ct_args = id_tap_p.bind(*cts_zeros, func=func, nr_untapped=nr_untapped,
|
||||
**_add_transform_name(params, "transpose"))
|
||||
return ct_args
|
||||
@ -376,6 +377,7 @@ def _id_tap_masking_rule(operands, operands_logical_shapes, **params):
|
||||
masking.masking_rules[id_tap_p] = _id_tap_masking_rule
|
||||
|
||||
#### XLA compilation ####
|
||||
####
|
||||
# Special consumer to mark the end of outfeed stream for a device
|
||||
_end_consumer = 0
|
||||
_unknown_consumer = 1 # for testing error cases
|
||||
@ -394,12 +396,24 @@ def _id_print_translation_rule_outfeed(comp: XlaComputationBuilder,
|
||||
params["consumer_id"] = _register_consumer(
|
||||
_ConsumerCallable(func, tuple(params.items()), arg_treedef))
|
||||
|
||||
prev_token = xla.state_carry.current_token(comp)
|
||||
nr_args_to_emit = len(args_op) - nr_untapped
|
||||
next_token = _emit_outfeed(comp, prev_token,
|
||||
|
||||
# We expect the current token to be the last argument
|
||||
current_token = args_op[-1]
|
||||
current_token_shape = comp.GetShape(current_token)
|
||||
if current_token_shape.is_array():
|
||||
# TODO: we get here because when we partially evaluate some primitives
# we implement them with JIT, but we did not rewrite them
|
||||
has_token = False
|
||||
current_token = xops.CreateToken(comp)
|
||||
else:
|
||||
has_token = True
|
||||
|
||||
nr_args_to_emit = len(args_op) - nr_untapped - (1 if has_token else 0)
|
||||
next_token = _emit_outfeed(comp, current_token,
|
||||
args_op[0:nr_args_to_emit], params["consumer_id"])
|
||||
xla.state_carry.set_current_token(comp, next_token)
|
||||
return xops.Tuple(comp, args_op)
|
||||
results = (args_op[:-1] + (next_token,)) if has_token else args_op
|
||||
return xops.Tuple(comp, results)
|
||||
|
||||
|
||||
xla.translations[id_tap_p] = _id_print_translation_rule_outfeed
|
||||
|
||||
@ -517,6 +531,7 @@ class TapFunctionException(Exception):
|
||||
"""
|
||||
pass
|
||||
|
||||
_outfeed_receiver_started = False
|
||||
@contextmanager
|
||||
def outfeed_receiver(*,
|
||||
timeout_sec=10,
|
||||
@ -550,6 +565,10 @@ def outfeed_receiver(*,
|
||||
The ``outfeed_receiver`` must be started outside any jitted computation.
|
||||
|
||||
"""
|
||||
global _outfeed_receiver_started
|
||||
if _outfeed_receiver_started:
|
||||
raise ValueError("At most one outfeed_receiver can be running at once.")
|
||||
|
||||
if not devices:
|
||||
backends = backends or xla_client._get_local_backends().keys()
|
||||
devices = tuple(itertools.chain(*[api.devices(backend)
|
||||
@ -593,13 +612,13 @@ def outfeed_receiver(*,
|
||||
# bugs in our code, not from the tap functions.
|
||||
for rf in receiver_futures:
|
||||
rf.add_done_callback(lambda rf: rf.result())
|
||||
xla.set_outfeed_allowed(True)
|
||||
xla.can_execute_outfeed_computations = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for d in devices: # Signal the end of printing
|
||||
api.jit(lambda x: id_tap(_end_consumer, None, result=x), device=d)(0) # type: ignore[arg-type]
|
||||
xla.set_outfeed_allowed(False)
|
||||
xla.can_execute_outfeed_computations = False
|
||||
for f in futures.as_completed(receiver_futures, timeout=timeout_sec):
|
||||
finished_device = f.result() # Throw exceptions here
|
||||
if _LOGGING:
|
||||
@ -607,3 +626,172 @@ def outfeed_receiver(*,
|
||||
if count_tap_exceptions > 0:
|
||||
raise TapFunctionException
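# Usage sketch (not part of this commit; mirrors the tests): computations that
# contain id_tap/id_print must run while an outfeed_receiver is active, e.g.
#
#   from jax import api
#   from jax.experimental import host_callback as hcb
#
#   with hcb.outfeed_receiver(receiver_name="example"):
#     api.jit(lambda x: hcb.id_print(x * 2., what="x * 2"))(3.)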
|
||||
|
||||
#### Jaxpr rewriting logic
|
||||
####
|
||||
def _jaxpr_var_defs(jaxpr: core.Jaxpr) -> Iterable[int]:
|
||||
"""Iterates over all the unique vars the top-level of a Jaxpr"""
|
||||
for iv in jaxpr.invars:
|
||||
yield iv.count
|
||||
for cv in jaxpr.constvars:
|
||||
yield cv.count
|
||||
for eqn in jaxpr.eqns:
|
||||
for ov in eqn.outvars:
|
||||
yield ov.count
|
||||
|
||||
|
||||
def _jaxpr_uses_outfeed(jaxpr: core.Jaxpr) -> bool:
|
||||
"""Finds if there are outfeed primitives anywhere inside a Jaxpr."""
|
||||
for eqn in jaxpr.eqns:
|
||||
if eqn.primitive is id_tap_p:
|
||||
return True
|
||||
for subjaxpr in core.subjaxprs(jaxpr):
|
||||
if _jaxpr_uses_outfeed(subjaxpr):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _rewrite_typed_jaxpr(tjaxpr: core.TypedJaxpr,
|
||||
has_input_token: bool,
|
||||
has_output_token: bool) -> Tuple[core.TypedJaxpr, bool]:
|
||||
"""Rewrites a TypedJaxpr to thread the token, if needed.
|
||||
|
||||
Returns the rewritten Jaxpr, and whether it uses outfeed."""
|
||||
new_jaxpr, uses_outfeed = _rewrite_jaxpr(tjaxpr.jaxpr, has_input_token, has_output_token)
|
||||
return (core.TypedJaxpr(new_jaxpr, tjaxpr.literals,
|
||||
tuple(map(lambda v: v.aval, new_jaxpr.invars)),
|
||||
tuple(map(lambda v: v.aval, new_jaxpr.outvars))),
|
||||
uses_outfeed)
|
||||
|
||||
|
||||
def _rewrite_jaxpr(jaxpr: core.Jaxpr,
|
||||
has_input_token: bool,
|
||||
has_output_token: bool) -> Tuple[core.Jaxpr, bool]:
|
||||
"""Rewrite a Jaxpr to thread the token, if needed."""
|
||||
assert has_input_token or not has_output_token
|
||||
|
||||
uses_outfeed = _jaxpr_uses_outfeed(jaxpr)
|
||||
if not has_input_token and not uses_outfeed:
|
||||
return (jaxpr, False)
|
||||
max_var_count = max(_jaxpr_var_defs(jaxpr))
|
||||
mk_new_id = itertools.count(start=max_var_count + 1)
|
||||
|
||||
def mk_new_var(aval: core.AbstractValue) -> core.Var:
|
||||
return core.Var(next(mk_new_id), '', aval)
|
||||
|
||||
eqns: List[core.JaxprEqn] = []
|
||||
last_token_var = mk_new_var(core.abstract_token)
|
||||
if has_input_token:
|
||||
invars = jaxpr.invars + [last_token_var]
|
||||
else:
|
||||
invars = jaxpr.invars
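# No token was passed in, so materialize one from the first regular input;
# the outfeed equations rewritten below need a token to thread.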
|
||||
eqns.append(core.new_jaxpr_eqn([jaxpr.invars[0]], [last_token_var],
|
||||
lax.create_token_p, {}))
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
if eqn.primitive is id_tap_p:
|
||||
new_token_var = mk_new_var(core.abstract_token)
|
||||
eqns.append(core.new_jaxpr_eqn(eqn.invars + [last_token_var],
|
||||
eqn.outvars + [new_token_var],
|
||||
eqn.primitive, eqn.params))
|
||||
last_token_var = new_token_var
|
||||
elif eqn.primitive is lax.while_p:
|
||||
cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict(
|
||||
eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
|
||||
if _jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
|
||||
raise NotImplementedError("outfeed not supported in the conditional of a while")
|
||||
uses_outfeed = _jaxpr_uses_outfeed(body_jaxpr.jaxpr)
|
||||
if not uses_outfeed:
|
||||
eqns.append(eqn)
|
||||
continue
|
||||
new_token_var = mk_new_var(core.abstract_token)
|
||||
eqns.append(core.new_jaxpr_eqn(
|
||||
eqn.invars + [last_token_var],
|
||||
eqn.outvars + [new_token_var],
|
||||
eqn.primitive,
|
||||
dict(eqn.params,
|
||||
body_jaxpr=_rewrite_typed_jaxpr(body_jaxpr, True, True)[0],
|
||||
cond_jaxpr=_rewrite_typed_jaxpr(cond_jaxpr, True, False)[0])))
|
||||
last_token_var = new_token_var
|
||||
elif eqn.primitive is lax.cond_p:
|
||||
true_jaxpr, false_jaxpr, linear = util.split_dict(
|
||||
eqn.params, ["true_jaxpr", "false_jaxpr", "linear"])
|
||||
uses_outfeed = _jaxpr_uses_outfeed(true_jaxpr.jaxpr) or _jaxpr_uses_outfeed(false_jaxpr.jaxpr)
|
||||
if not uses_outfeed:
|
||||
eqns.append(eqn)
|
||||
continue
|
||||
nr_true_invars = len(true_jaxpr.jaxpr.invars)
|
||||
pred, true_invars, false_invars = util.split_list(eqn.invars,
|
||||
[1, nr_true_invars])
|
||||
new_token_var = mk_new_var(core.abstract_token)
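# The same token is appended to both the true and the false operand lists;
# each rewritten branch jaxpr takes it as an extra input and returns it as an
# extra output, so the token comes back no matter which branch executes.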
|
||||
new_invars = pred + true_invars + [last_token_var] + false_invars + [last_token_var]
|
||||
eqns.append(core.new_jaxpr_eqn(
|
||||
new_invars, eqn.outvars + [new_token_var],
|
||||
eqn.primitive,
|
||||
dict(eqn.params,
|
||||
true_jaxpr=_rewrite_typed_jaxpr(true_jaxpr, True, True)[0],
|
||||
false_jaxpr=_rewrite_typed_jaxpr(false_jaxpr, True, True)[0],
|
||||
linear=linear + (False, False))))
|
||||
last_token_var = new_token_var
|
||||
elif eqn.primitive is lax.scan_p:
|
||||
num_consts, num_carry, carry_jaxpr, linear, _, _ = util.split_dict(
|
||||
eqn.params, ["num_consts", "num_carry", "jaxpr", "linear",
|
||||
"reverse", "length"])
|
||||
uses_outfeed = _jaxpr_uses_outfeed(carry_jaxpr.jaxpr)
|
||||
if not uses_outfeed:
|
||||
eqns.append(eqn)
|
||||
continue
|
||||
nr_const_and_carry = num_consts + num_carry
|
||||
new_invars = eqn.invars[0:nr_const_and_carry] + [last_token_var] + eqn.invars[nr_const_and_carry:]
|
||||
new_token_var = mk_new_var(core.abstract_token)
|
||||
new_jaxpr = _rewrite_typed_jaxpr(carry_jaxpr, True, True)[0]
|
||||
# The generic rewrite appends the token at the very end; scan needs it at the
# end of the carry segment, so reorder:
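# Layout produced by _rewrite_typed_jaxpr vs. the layout scan expects
# (the token rides along as an extra carry value):
#   invars:  [consts, carry, xs, token]  ->  [consts, carry, token, xs]
#   outvars: [carry, ys, token]          ->  [carry, token, ys]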
|
||||
new_jaxpr_invars = new_jaxpr.jaxpr.invars
|
||||
new_jaxpr_invars = (new_jaxpr_invars[0:nr_const_and_carry] +
|
||||
[new_jaxpr_invars[-1]] +
|
||||
new_jaxpr_invars[nr_const_and_carry:-1])
|
||||
new_jaxpr.jaxpr.invars = new_jaxpr_invars
|
||||
new_jaxpr.in_avals = [v.aval for v in new_jaxpr_invars]
|
||||
|
||||
new_jaxpr_outvars = new_jaxpr.jaxpr.outvars
|
||||
new_jaxpr_outvars = (new_jaxpr_outvars[0:num_carry] +
|
||||
[new_jaxpr_outvars[-1]] +
|
||||
new_jaxpr_outvars[num_carry:-1])
|
||||
new_jaxpr.jaxpr.outvars = new_jaxpr_outvars
|
||||
new_jaxpr.out_avals = [v.aval for v in new_jaxpr_outvars]
|
||||
eqns.append(core.new_jaxpr_eqn(
|
||||
new_invars,
|
||||
# Output token is at the end of carry result
|
||||
eqn.outvars[0:num_carry] + [new_token_var] + eqn.outvars[num_carry:],
|
||||
eqn.primitive,
|
||||
dict(eqn.params,
|
||||
jaxpr=new_jaxpr,
|
||||
num_carry=num_carry + 1,
|
||||
linear=linear + (False,))))
|
||||
last_token_var = new_token_var
|
||||
elif eqn.primitive is xla.xla_call_p:
|
||||
call_jaxpr = eqn.params["call_jaxpr"]
|
||||
uses_outfeed = _jaxpr_uses_outfeed(call_jaxpr)
|
||||
if not uses_outfeed:
|
||||
eqns.append(eqn)
|
||||
continue
|
||||
new_token_var = mk_new_var(core.abstract_token)
|
||||
eqns.append(core.new_jaxpr_eqn(
|
||||
eqn.invars + [last_token_var],
|
||||
eqn.outvars + [new_token_var],
|
||||
eqn.primitive,
|
||||
dict(eqn.params, call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True)[0])))
|
||||
last_token_var = new_token_var
|
||||
elif eqn.primitive is pxla.xla_pmap_p:
|
||||
raise NotImplementedError("rewrite of pmap")
|
||||
else:
|
||||
# Check no more subjaxprs
|
||||
for param in eqn.params.values():
|
||||
if type(param) is core.Jaxpr or type(param) is core.TypedJaxpr:
|
||||
assert False
|
||||
eqns.append(eqn)
|
||||
|
||||
outvars = jaxpr.outvars + ([last_token_var] if has_output_token else [])
|
||||
return (core.Jaxpr(jaxpr.constvars, invars, outvars, eqns), True)
|
||||
|
||||
|
||||
xla.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False)
|
||||
|
@ -659,6 +659,7 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
|
||||
dynamic_fun, [pval] + pvals, instantiate=False, stage_out=True, bottom=True)
|
||||
jaxpr.invars = jaxpr.invars[1:] # ignore dummy
|
||||
jaxpr, uses_outfeed = xla.apply_outfeed_rewriter(jaxpr)
|
||||
|
||||
out_pvs, out_consts = unzip2(out_pvals)
|
||||
|
||||
@ -693,13 +694,10 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,
|
||||
tuple_args = len(sharded_avals) > 100 # pass long arg lists as tuple for TPU
|
||||
|
||||
c = xb.make_computation_builder("pmap_{}".format(fun.__name__))
|
||||
uses_outfeed = xla.jaxpr_uses_outfeed(jaxpr)
|
||||
xla.state_carry.start_nested_comp_without_input(c, uses_outfeed)
|
||||
xla_consts = _map(partial(xb.constant, c), consts)
|
||||
xla_args = xla._xla_callable_args(c, sharded_avals, tuple_args)
|
||||
out_nodes = xla.jaxpr_subcomp(c, jaxpr, backend, axis_env, xla_consts,
|
||||
extend_name_stack(wrap_name(name, 'pmap')), *xla_args)
|
||||
xla.state_carry.end_nested_comp_without_output(c)
|
||||
built = c.Build(xops.Tuple(c, out_nodes))
|
||||
|
||||
if devices is None:
|
||||
@ -861,8 +859,8 @@ def _pmap_sharding_spec(nrep, axis_size, sharded_aval, mapped):
|
||||
replication_factor=replication_factor * axis_size)
|
||||
|
||||
|
||||
def execute_replicated(compiled, uses_outfeed: bool, backend, in_handler, out_handler, *args):
|
||||
xla.check_outfeed_allowed(uses_outfeed)
|
||||
def execute_replicated(compiled, uses_outfeed, backend, in_handler, out_handler, *args):
|
||||
xla.check_before_outfeed_execution(uses_outfeed)
|
||||
input_bufs = in_handler(args)
|
||||
out_bufs = compiled.ExecuteOnLocalDevices(list(input_bufs))
|
||||
return out_handler(out_bufs)
|
||||
|
@ -14,9 +14,9 @@
|
||||
|
||||
|
||||
from collections import defaultdict
|
||||
from functools import reduce
|
||||
import itertools as it
|
||||
import operator as op
|
||||
import threading
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Type, Tuple
|
||||
|
||||
from absl import logging
|
||||
@ -168,167 +168,22 @@ pytype_aval_mappings.update((t, make_shaped_array) for t in array_types)
|
||||
pytype_aval_mappings.update(
|
||||
(t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types)
|
||||
|
||||
# Keep track of which primitives are stateful (they read or read/write state).
|
||||
# This helps avoid threading the state through control-flow primitives that
|
||||
# do not need it. This is a worthwhile optimization because it seems that XLA
|
||||
# may not be good at dealing with tokens (b/154992062).
|
||||
outfeed_primitives: Set[core.Primitive] = set()
|
||||
def jaxpr_uses_outfeed(jaxpr):
|
||||
"""Finds if there are outfeed primitives anywhere inside a Jaxpr."""
|
||||
if type(jaxpr) is core.TypedJaxpr:
|
||||
jaxpr = jaxpr.jaxpr
|
||||
for eqn in jaxpr.eqns:
|
||||
if eqn.primitive in outfeed_primitives:
|
||||
return True
|
||||
for subjaxpr in core.subjaxprs(jaxpr):
|
||||
if jaxpr_uses_outfeed(subjaxpr):
|
||||
return True
|
||||
return False
|
||||
# We can optionally set a Jaxpr rewriter that can be applied just before
|
||||
# compilation. This mechanism is used for compiling id_tap; we can
# remove it once we bring the id_tap implementation into the core.
|
||||
outfeed_rewriter: Optional[Callable[[core.Jaxpr], Tuple[core.Jaxpr, bool]]] = None
|
||||
def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, bool]:
|
||||
if outfeed_rewriter is not None:
|
||||
return outfeed_rewriter(jaxpr)
|
||||
else:
|
||||
return jaxpr, False
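# host_callback installs this hook elsewhere in this commit:
#   xla.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False)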
|
||||
|
||||
class _ComputationStateCarry(threading.local):
|
||||
"""Carries some state globally as we build the HLO.
|
||||
|
||||
For now the state is only a token, obtained from the last OutFeed. The
|
||||
state is carried through nested loops and conditionals, if they use
|
||||
state.
|
||||
"""
|
||||
# A stack of nested computations and the current token for each. A None token
|
||||
# means that we cannot use state in this computation and any nested
|
||||
# computations. The last one is the most recent.
|
||||
_computations: Tuple[XlaComputationBuilder, ...]
|
||||
_tokens: List[XlaOp]
|
||||
_log_idx: int # For debugging
|
||||
|
||||
# TODO(necula): remove these experiment flags
|
||||
# Flags to force compilation of while, cond, call with in/out state, to
|
||||
# test the XLA compiler's handling of tokens.
|
||||
FORCE_OUTFEED = False
|
||||
_LOG_STATE = False
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._computations = ()
|
||||
self._tokens = []
|
||||
self._log_idx = 0
|
||||
|
||||
def check_and_reset_state(self) -> bool:
|
||||
"""Checks that the state is empty, and resets it if not."""
|
||||
ok = not self._computations and not self._tokens
|
||||
self._computations = ()
|
||||
self._tokens = []
|
||||
self._log_idx = 0
|
||||
return ok
|
||||
|
||||
def _log_state(self, msg: str):
|
||||
if self._LOG_STATE:
|
||||
top_comp = id(self._computations[-1]) if self._computations else 0
|
||||
logging.warning(
|
||||
f"[{self._log_idx} @ 0x{top_comp:x}/{len(self._computations)}]: {msg}.")
|
||||
self._log_idx += 1
|
||||
|
||||
def current_state(self, comp: XlaComputationBuilder, uses_state: bool) -> List[XlaOp]:
|
||||
"""Get the current state for the current computation."""
|
||||
if uses_state:
|
||||
assert self._computations and self._tokens
|
||||
# Ensure we kept track of computations properly
|
||||
assert (comp is self._computations[-1]), f"Reading state from unexpected computation (0x{id(comp):x})"
|
||||
assert self._tokens[-1], "Reading non-initialized state"
|
||||
return [self._tokens[-1]]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def start_nested_comp_without_input(self, comp: XlaComputationBuilder,
|
||||
uses_state: bool):
|
||||
"""Starts a nested computation, with no state passed in.
|
||||
`comp` is the new computation.
|
||||
"""
|
||||
self._computations = self._computations + (comp,)
|
||||
if uses_state:
|
||||
self._log_state(f"start nested computation without input, initialize state")
|
||||
assert all([t is None for t in self._tokens]), "Ignoring upstream state"
|
||||
token = xops.CreateToken(comp)
|
||||
else:
|
||||
self._log_state(f"start nested computation without input state, no state")
|
||||
token = None # Cannot use state, here or in nested computations
|
||||
self._tokens = self._tokens + [token]
|
||||
|
||||
def start_nested_comp_with_input(self, comp: XlaComputationBuilder,
|
||||
tuple_op: XlaOp, nr_regular: int,
|
||||
uses_state: bool
|
||||
) -> Tuple[Sequence[XlaOp], Sequence[XlaOp]]:
|
||||
"""Start a nested computation with inputs.
|
||||
`comp` is the new computation. `tuple_op` is a tuple in the new computation
|
||||
with `nr_regular` elements and perhaps some state elements (if `uses_state`).
|
||||
Stores the new state for the new computation.
|
||||
|
||||
Returns the regular elements of the tuple, and the state.
|
||||
"""
|
||||
self._computations = self._computations + (comp,)
|
||||
self._log_state(f"start nested computation with input")
|
||||
self._tokens = self._tokens + [None] # Will set the token below
|
||||
return self.set_state_from_tuple(comp, tuple_op, nr_regular, uses_state)
|
||||
|
||||
def end_nested_comp_without_output(self, comp: XlaComputationBuilder) -> None:
|
||||
"""Ends a nested computation, with no state returned.
|
||||
`comp` is the ending computation.
|
||||
"""
|
||||
self._log_state(f"end nested computation without output")
|
||||
if self._tokens[-1]:
|
||||
assert self._computations and self._computations[-1] is comp
|
||||
self._computations = self._computations[:-1]
|
||||
self._tokens = self._tokens[:-1]
|
||||
|
||||
def end_nested_comp_with_output(self, comp: XlaComputationBuilder,
|
||||
tuple_op: XlaOp, nr_regular: int,
|
||||
uses_state: bool) -> Sequence[XlaOp]:
|
||||
"""Ends a nested computation with results.
|
||||
`comp` is the parent computation into which we return, `tuple_op` is
|
||||
a tuple in the parent computation with `nr_regular` elements and
|
||||
perhaps some state elements (if `uses_state`).
|
||||
Stores the new state for the parent computation.
|
||||
|
||||
Returns the regular elements of the output.
|
||||
"""
|
||||
self._log_state(f"end nested computation with output")
|
||||
if uses_state:
|
||||
assert len(self._computations) >= 2 and self._computations[-2] is comp
|
||||
self._computations = self._computations[:-1]
|
||||
self._tokens = self._tokens[:-1]
|
||||
regulars, _ = self.set_state_from_tuple(comp, tuple_op,
|
||||
nr_regular, uses_state)
|
||||
return regulars
|
||||
|
||||
|
||||
def set_state_from_tuple(self, comp: XlaComputationBuilder,
|
||||
tuple_op: XlaOp, nr_regular: int,
|
||||
uses_state: bool
|
||||
) -> Tuple[Sequence[XlaOp], Sequence[XlaOp]]:
|
||||
"""Decomposes a tuple, returns the regular elements and the state.
|
||||
|
||||
We assume that the `tuple_op` represents a tuple with `nr_regular` regular
|
||||
elements, followed by some elements encoding the state.
|
||||
Stores the new state.
|
||||
"""
|
||||
regular_ops = [xops.GetTupleElement(tuple_op, i) for i in range(nr_regular)]
|
||||
if uses_state:
|
||||
assert self._computations and self._computations[-1] is comp
|
||||
self._log_state(f"set state from tuple_op")
|
||||
self._tokens[-1] = xops.GetTupleElement(tuple_op, nr_regular)
|
||||
return regular_ops, self.current_state(comp, uses_state)
|
||||
else:
|
||||
return regular_ops, []
|
||||
|
||||
def current_token(self, comp: XlaComputationBuilder) -> XlaOp:
|
||||
self._log_state("reading token")
|
||||
state = self.current_state(comp, True)
|
||||
return state[0]
|
||||
|
||||
def set_current_token(self, comp: XlaComputationBuilder, token: XlaOp):
|
||||
assert comp == self._computations[-1]
|
||||
assert self._tokens[-1]
|
||||
self._tokens[-1] = token
|
||||
|
||||
state_carry = _ComputationStateCarry()
|
||||
# TODO(necula): remove this when we start the outfeed receiver automatically.
|
||||
can_execute_outfeed_computations: bool = False
|
||||
def check_before_outfeed_execution(uses_outfeed: bool):
|
||||
if uses_outfeed and not can_execute_outfeed_computations:
|
||||
raise ValueError("Attempting to execute compiled code using outfeed, "
|
||||
"but outfeed_receiver is not started.")
|
||||
|
||||
### op-by-op execution
|
||||
|
||||
@ -399,7 +254,6 @@ def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[De
|
||||
@cache()
|
||||
def primitive_computation(prim, axis_env, backend, tuple_args, *avals, **params):
|
||||
c = xb.make_computation_builder(f"primitive_computation_{prim.name}")
|
||||
state_carry.start_nested_comp_without_input(c, False)
|
||||
c.SetOpMetadata(xc.OpMetadata(
|
||||
op_type=prim.name,
|
||||
op_name=str(pp_eqn_compact(prim.name, params))))
|
||||
@ -420,7 +274,6 @@ def primitive_computation(prim, axis_env, backend, tuple_args, *avals, **params)
|
||||
raise NotImplementedError(f"XLA translation rule for {prim} not found")
|
||||
assert isinstance(ans, xe.XlaOp)
|
||||
c.ClearOpMetadata()
|
||||
state_carry.end_nested_comp_without_output(c)
|
||||
try:
|
||||
return c.Build()
|
||||
except RuntimeError as e:
|
||||
@ -653,7 +506,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, *arg_specs):
|
||||
pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
|
||||
jaxpr, pvals, consts = pe.trace_to_jaxpr(
|
||||
fun, pvals, instantiate=False, stage_out=True, bottom=True)
|
||||
|
||||
jaxpr, uses_outfeed = apply_outfeed_rewriter(jaxpr)
|
||||
_map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
|
||||
|
||||
nreps = jaxpr_replicas(jaxpr)
|
||||
@ -683,8 +536,6 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, *arg_specs):
|
||||
tuple_args = len(abstract_args) > 100 # pass long arg lists as tuple for TPU
|
||||
|
||||
c = xb.make_computation_builder("jit_{}".format(fun.__name__))
|
||||
uses_outfeed = jaxpr_uses_outfeed(jaxpr)
|
||||
state_carry.start_nested_comp_without_input(c, uses_outfeed)
|
||||
xla_consts = _map(partial(xb.constant, c), consts)
|
||||
xla_args = _xla_callable_args(c, abstract_args, tuple_args)
|
||||
out_nodes = jaxpr_subcomp(
|
||||
@ -698,9 +549,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, *arg_specs):
|
||||
device_assignment=(device.id,) if device else None)
|
||||
options.tuple_arguments = tuple_args
|
||||
backend = xb.get_backend(backend)
|
||||
state_carry.end_nested_comp_without_output(c)
|
||||
compiled = backend.compile(built, compile_options=options)
|
||||
|
||||
if nreps == 1:
|
||||
return partial(_execute_compiled, compiled, uses_outfeed, result_handlers)
|
||||
else:
|
||||
@ -745,19 +594,9 @@ def _pval_to_result_handler(device, pval):
|
||||
else:
|
||||
return aval_to_result_handler(device, pv)
|
||||
|
||||
_outfeed_allowed = False
|
||||
def set_outfeed_allowed(allowed: bool):
|
||||
global _outfeed_allowed
|
||||
_outfeed_allowed = allowed
|
||||
|
||||
def check_outfeed_allowed(uses_outfeed: bool):
|
||||
if uses_outfeed and not _outfeed_allowed:
|
||||
raise ValueError("Attempting to execute compiled code using outfeed, "
|
||||
"but outfeed_consumer is not started.")
|
||||
|
||||
def _execute_compiled(compiled: XlaExecutable, uses_outfeed: bool,
|
||||
handlers, *args):
|
||||
check_outfeed_allowed(uses_outfeed)
|
||||
check_before_outfeed_execution(uses_outfeed)
|
||||
device, = compiled.local_devices()
|
||||
input_bufs = [device_put(x, device) for x in args if x is not token]
|
||||
out_bufs = compiled.Execute(input_bufs)
|
||||
@ -766,7 +605,7 @@ def _execute_compiled(compiled: XlaExecutable, uses_outfeed: bool,
|
||||
|
||||
def _execute_replicated(compiled: XlaExecutable, uses_outfeed: bool,
|
||||
handlers, *args):
|
||||
check_outfeed_allowed(uses_outfeed)
|
||||
check_before_outfeed_execution(uses_outfeed)
|
||||
input_bufs = [
|
||||
[device_put(x, device) for x in args if x is not token]
|
||||
for device in compiled.local_devices()]
|
||||
@ -810,23 +649,11 @@ def _xla_call_translation_rule(c, axis_env,
|
||||
call_jaxpr, device=None):
|
||||
del device # Ignored.
|
||||
subc = xb.make_computation_builder(f"jit_{name}")
|
||||
|
||||
uses_outfeed = jaxpr_uses_outfeed(call_jaxpr)
|
||||
prev_state = state_carry.current_state(c, uses_outfeed)
|
||||
input_op = xops.Tuple(c, list(in_nodes) + list(prev_state))
|
||||
arg = xb.parameter(subc, 0, c.GetShape(input_op))
|
||||
nr_regular_args = len(in_nodes)
|
||||
args, input_state = state_carry.start_nested_comp_with_input(subc, arg, nr_regular_args, uses_outfeed)
|
||||
args = [xb.parameter(subc, i, c.GetShape(n)) for i, n in enumerate(in_nodes)]
|
||||
out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
|
||||
extend_name_stack(name_stack, wrap_name(name, 'jit')),
|
||||
*(args + input_state))
|
||||
result_state = state_carry.current_state(subc, uses_outfeed)
|
||||
subc = subc.Build(xops.Tuple(subc, list(out_nodes) + result_state))
|
||||
call_op = xops.Call(c, subc, [input_op])
|
||||
nr_outs = len(out_nodes)
|
||||
regular_outs = state_carry.end_nested_comp_with_output(c, call_op, nr_outs, uses_outfeed)
|
||||
return xops.Tuple(c, regular_outs)
|
||||
|
||||
extend_name_stack(name_stack, wrap_name(name, 'jit')), *args)
|
||||
subc = subc.Build(xops.Tuple(subc, out_nodes))
|
||||
return xops.Call(c, subc, list(in_nodes))
|
||||
ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
|
||||
|
||||
|
||||
@ -1230,21 +1057,16 @@ def _remat_translation_rule(c, axis_env, in_nodes,
|
||||
xb.constant(c, onp.array(1, dtype=onp.float32)),
|
||||
xc.Shape.array_shape(xc.PrimitiveType.F32, []))
|
||||
pred = xops.Lt(rng, xb.constant(c, onp.array(2, dtype=onp.float32)))
|
||||
uses_outfeed = jaxpr_uses_outfeed(call_jaxpr)
|
||||
prev_state = state_carry.current_state(c, uses_outfeed)
|
||||
|
||||
true_op = xops.Tuple(c, list(in_nodes) + list(prev_state))
|
||||
true_op = xops.Tuple(c, in_nodes)
|
||||
remat_subc = xb.make_computation_builder("remat_call_subcomputation")
|
||||
input_op = xb.parameter(remat_subc, 0, c.GetShape(true_op), replicated=[])
|
||||
args, input_state = state_carry.start_nested_comp_with_input(remat_subc, input_op, len(in_nodes), uses_outfeed)
|
||||
args = [xops.GetTupleElement(input_op, i) for i in range(len(in_nodes))]
|
||||
out_nodes = jaxpr_subcomp(remat_subc, call_jaxpr, backend, axis_env, (),
|
||||
extend_name_stack(name_stack, wrap_name(name, 'remat')),
|
||||
*(args + input_state))
|
||||
result_state = state_carry.current_state(remat_subc, uses_outfeed)
|
||||
result_tuple = xops.Tuple(remat_subc, out_nodes + tuple(result_state))
|
||||
regular_outs = state_carry.end_nested_comp_with_output(c, result_tuple, len(out_nodes), uses_outfeed)
|
||||
out_node_shapes = [remat_subc.GetShape(o) for o in regular_outs]
|
||||
remat_subc = remat_subc.Build(xops.Tuple(remat_subc, regular_outs))
|
||||
*args)
|
||||
out_node_shapes = [remat_subc.GetShape(o) for o in out_nodes]
|
||||
remat_subc = remat_subc.Build(xops.Tuple(remat_subc, out_nodes))
|
||||
|
||||
false_op = true_op
|
||||
dummy_subc = xb.make_computation_builder("remat_call_dummy_subcomputation")
|
||||
|
@ -254,23 +254,15 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
|
||||
cond_consts, body_consts, init_vals = split_list(args, [cond_nconsts, body_nconsts])
|
||||
batched = bool(cond_jaxpr.out_avals[0].shape)
|
||||
|
||||
if xla.jaxpr_uses_outfeed(cond_jaxpr) and not xla.state_carry.FORCE_OUTFEED:
|
||||
# TODO: implement the boolean as an extra carry
|
||||
raise NotImplementedError("State not supported in while_loop conditionals")
|
||||
|
||||
uses_outfeed = xla.jaxpr_uses_outfeed(body_jaxpr)
|
||||
prev_state = xla.state_carry.current_state(c, uses_outfeed)
|
||||
|
||||
# Since jaxprs don't have tuples and have multiple return values, but we need
|
||||
# the HLO While loop to take a single tuple input and output a single boolean
|
||||
# (for the cond computation) or a single tuple output (for the body
|
||||
# computation), we build XLA computations that handle the tuple munging before
|
||||
# generating a Call into the computations formed from the jaxprs.
|
||||
|
||||
init_carry = xops.Tuple(c, cond_consts + body_consts + init_vals + prev_state)
|
||||
init_carry = xops.Tuple(c, cond_consts + body_consts + init_vals)
|
||||
|
||||
cond_c = xb.make_computation_builder("cond_computation")
|
||||
xla.state_carry.start_nested_comp_without_input(cond_c, False)
|
||||
cond_carry = xb.parameter(cond_c, 0, c.GetShape(init_carry))
|
||||
cond_carry_elts = [xops.GetTupleElement(cond_carry, i) for i in range(len(args))]
|
||||
x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
|
||||
@ -283,12 +275,10 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
|
||||
or_ = xla.primitive_subcomputation(lax.or_p, scalar, scalar)
|
||||
pred = xops.Reduce(cond_c, [pred], [xb.constant(cond_c, onp.array(False))], or_,
|
||||
list(range(cond_jaxpr.out_avals[0].ndim)))
|
||||
cond_comp = cond_c.Build(pred)
|
||||
xla.state_carry.end_nested_comp_without_output(cond_c)
|
||||
|
||||
body_c = xb.make_computation_builder("body_computation")
|
||||
body_carry = xb.parameter(body_c, 0, c.GetShape(init_carry))
|
||||
body_carry_elts, _ = xla.state_carry.start_nested_comp_with_input(
|
||||
body_c, body_carry, len(args), uses_outfeed)
|
||||
body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
|
||||
x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
|
||||
new_z = xla.jaxpr_subcomp(body_c, body_jaxpr.jaxpr, backend, axis_env,
|
||||
_map(partial(xb.constant, body_c), body_jaxpr.literals),
|
||||
@ -299,11 +289,10 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
|
||||
extend_name_stack(name_stack, 'body_pred'), *(x + z))
|
||||
new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z)
|
||||
assert _map(body_c.GetShape, new_z) == _map(body_c.GetShape, z) # no broadcast
|
||||
result_state = xla.state_carry.current_state(body_c, uses_outfeed)
|
||||
new_carry = xops.Tuple(body_c, list(itertools.chain(x, y, new_z, result_state)))
|
||||
new_carry = xops.Tuple(body_c, list(itertools.chain(x, y, new_z)))
|
||||
|
||||
ans = xops.While(cond_comp, body_c.Build(new_carry), init_carry)
|
||||
ans_elts = xla.state_carry.end_nested_comp_with_output(c, ans, len(args), uses_outfeed)
|
||||
ans = xops.While(cond_c.Build(pred), body_c.Build(new_carry), init_carry)
|
||||
ans_elts = [xops.GetTupleElement(ans, i) for i in range(len(args))]
|
||||
_, _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
|
||||
return xops.Tuple(c, z)
|
||||
|
||||
@ -567,35 +556,23 @@ def _cond_translation_rule(c, axis_env, name_stack, avals, backend,
|
||||
pred, *args, true_jaxpr, false_jaxpr, linear):
|
||||
del linear # Unused.
|
||||
true_ops, false_ops = split_list(args, [len(true_jaxpr.in_avals)])
|
||||
uses_outfeed = xla.jaxpr_uses_outfeed(true_jaxpr) or xla.jaxpr_uses_outfeed(false_jaxpr)
|
||||
prev_state = xla.state_carry.current_state(c, uses_outfeed)
|
||||
|
||||
def make_computation(name, jaxpr, op_shape):
|
||||
sub_c = xb.make_computation_builder(name + '_comp')
|
||||
op = xb.parameter(sub_c, 0, op_shape)
|
||||
ops, _ = xla.state_carry.start_nested_comp_with_input(
|
||||
sub_c, op, len(jaxpr.in_avals), uses_outfeed)
|
||||
outs = xla.jaxpr_subcomp(sub_c, jaxpr.jaxpr, backend, axis_env,
|
||||
_map(partial(xb.constant, sub_c), jaxpr.literals),
|
||||
c = xb.make_computation_builder(name + '_comp')
|
||||
op = xb.parameter(c, 0, op_shape)
|
||||
ops = [xops.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))]
|
||||
outs = xla.jaxpr_subcomp(c, jaxpr.jaxpr, backend, axis_env,
|
||||
_map(partial(xb.constant, c), jaxpr.literals),
|
||||
extend_name_stack(name_stack, name + '_fun'), *ops)
|
||||
result_state = xla.state_carry.current_state(sub_c, uses_outfeed)
|
||||
result_tuple = xops.Tuple(sub_c, list(outs) + result_state)
|
||||
xla.state_carry.end_nested_comp_with_output(c, result_tuple, len(outs), uses_outfeed)
|
||||
return sub_c.Build(result_tuple)
|
||||
return c.Build(xops.Tuple(c, outs))
|
||||
|
||||
true_op = xops.Tuple(c, true_ops + prev_state)
|
||||
true_op = xops.Tuple(c, true_ops)
|
||||
true_c = make_computation('true', true_jaxpr, c.GetShape(true_op))
|
||||
|
||||
false_op = xops.Tuple(c, false_ops + prev_state)
|
||||
false_op = xops.Tuple(c, false_ops)
|
||||
false_c = make_computation('false', false_jaxpr, c.GetShape(false_op))
|
||||
cond_op = xops.Conditional(pred, true_op, true_c, false_op, false_c)
|
||||
if not uses_outfeed:
|
||||
return cond_op
|
||||
else:
|
||||
nr_outs = len(true_jaxpr.out_avals)
|
||||
regular_outs, _ = xla.state_carry.set_state_from_tuple(
|
||||
c, cond_op, nr_outs, uses_outfeed)
|
||||
return xops.Tuple(c, regular_outs)
|
||||
|
||||
return xops.Conditional(pred, true_op, true_c, false_op, false_c)
|
||||
|
||||
def _cond_pred_bcast_select(pred, x, y):
|
||||
if core.get_aval(x) is core.get_aval(y) is core.abstract_unit:
|
||||
|
@ -21,6 +21,7 @@ import logging
|
||||
import numpy as onp
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Callable, List, Sequence, Tuple
|
||||
from unittest import SkipTest
|
||||
|
||||
from absl.testing import absltest
|
||||
@ -78,6 +79,29 @@ def fun1(a):
|
||||
def fun1_equiv(a):  # Numerical equivalent of fun1
|
||||
return (a * 2.)**2
|
||||
|
||||
def assertMultiLineStrippedEqual(tst: jtu.JaxTestCase, expected: str, what: str):
|
||||
"""A variant that preprocesses the string to eliminate non-determinism in
|
||||
floating point values, and several uninteresting id_tap primitive params."""
|
||||
# Sometimes we get floating points in the output; we round them
|
||||
def repl_floats(match_group):
|
||||
matched = match_group.group(0)
|
||||
if matched == ".": return matched
|
||||
# TODO: why can't we use here np.around?
|
||||
x = onp.around(float(matched), decimals=2)
|
||||
return f"{x:.2f}"
|
||||
what = re.sub(r"\-?\d*\.[\-\def]*", repl_floats, what)
|
||||
what = re.sub(r"output_stream=[^\]\n]*", "", what)
|
||||
what = re.sub(r"threshold=[^\]\n]*", "", what)
|
||||
# Empty lines
|
||||
what = re.sub(r"^\s*\n", "", what, flags=re.MULTILINE)
|
||||
def repl_func(match_group):
|
||||
matched = match_group.group(0)
|
||||
if "function _print_consumer" in matched:
|
||||
return "func=_print"
|
||||
else:
|
||||
return "..."
|
||||
what = re.sub(r"func=(.*)", repl_func, what)
|
||||
tst.assertMultiLineStrippedEqual(expected, what)
|
||||
|
||||
class HostCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
@ -100,50 +124,24 @@ class HostCallbackTest(jtu.JaxTestCase):
|
||||
# Clear any cached backends so new CPU backend will pick up the env var.
|
||||
xla_bridge.get_backend.cache_clear()
|
||||
|
||||
def assertMultiLineStrippedEqual(self, expected, what):
|
||||
"""A variant that preprocesses the string to eliminate non-determinism."""
|
||||
# Sometimes we get floating points in the output; we round them
|
||||
def repl_floats(match_group):
|
||||
matched = match_group.group(0)
|
||||
if matched == ".": return matched
|
||||
# TODO: why can't we use here np.around?
|
||||
x = onp.around(float(matched), decimals=2)
|
||||
return f"{x:.2f}"
|
||||
what = re.sub(r"\-?\d*\.[\-\def]*", repl_floats, what)
|
||||
# We rewrite consumer_id because it changes
|
||||
what = re.sub(r"consumer_id=(\d+)", "consumer_id=...", what)
|
||||
what = re.sub(r"output_stream=[^\]\n]*", "output_stream=...", what)
|
||||
def repl_func(match_group):
|
||||
matched = match_group.group(0)
|
||||
if "function _print_consumer" in matched:
|
||||
return "func=_print"
|
||||
else:
|
||||
return "..."
|
||||
what = re.sub(r"func=(.*)", repl_func, what)
|
||||
super().assertMultiLineStrippedEqual(expected, what)
|
||||
|
||||
def test_eval(self):
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
{ lambda ; a.
|
||||
let b = mul a 2.00
|
||||
c = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
output_stream=...
|
||||
what=a * 2 ] b
|
||||
d = mul c 3.00
|
||||
e f = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
nr_untapped=1
|
||||
output_stream=...
|
||||
what=y * 3 ] d c
|
||||
g = pow f 2.00
|
||||
in (g,) }""", str(api.make_jaxpr(fun1)(5.)))
|
||||
self.assertEqual("", testing_stream.output)
|
||||
|
||||
self.assertEqual((5. * 2.) ** 2, fun1(5.))
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: a * 2
|
||||
10.00
|
||||
what: y * 3
|
||||
@ -155,18 +153,17 @@ what: y * 3
|
||||
x1, y1 = hcb.id_print((x * 2., x * 3.), output_stream=testing_stream)
|
||||
return x1 + y1
|
||||
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
{ lambda ; a.
|
||||
let b = mul a 2.00
|
||||
c = mul a 3.00
|
||||
d e = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
|
||||
func=_print
|
||||
output_stream=...] b c
|
||||
] b c
|
||||
f = add d e
|
||||
in (f,) }""", str(api.make_jaxpr(func2)(3.)))
|
||||
self.assertEqual(3. * (2. + 3.), func2(3.))
|
||||
self.assertMultiLineStrippedEqual("""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
[ 6.00
|
||||
9.00 ]""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
@ -176,18 +173,17 @@ what: y * 3
|
||||
res = hcb.id_print(dict(a=x * 2., b=x * 3.), output_stream=testing_stream)
|
||||
return res["a"] + res["b"]
|
||||
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
{ lambda ; a.
|
||||
let b = mul a 2.00
|
||||
c = mul a 3.00
|
||||
d e = id_tap[ arg_treedef=PyTreeDef(dict[['a', 'b']], [*,*])
|
||||
func=_print
|
||||
output_stream=...] b c
|
||||
] b c
|
||||
f = add d e
|
||||
in (f,) }""", str(api.make_jaxpr(func2)(3.)))
|
||||
self.assertEqual(3. * (2. + 3.), func2(3.))
|
||||
self.assertMultiLineStrippedEqual("""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
{ a=6.00
|
||||
b=9.00 }""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
@ -198,8 +194,7 @@ what: y * 3
|
||||
output_stream=testing_stream)
|
||||
return x1
|
||||
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
{ lambda ; a.
|
||||
let b = mul a 2.00
|
||||
c = mul a 3.00
|
||||
@ -207,10 +202,10 @@ what: y * 3
|
||||
e f g = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
|
||||
func=_print
|
||||
nr_untapped=1
|
||||
output_stream=...] b c d
|
||||
] b c d
|
||||
in (g,) }""", str(api.make_jaxpr(func2)(3.)))
|
||||
self.assertEqual(3. * 4., func2(3.))
|
||||
self.assertMultiLineStrippedEqual("""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
[ 6.00
|
||||
9.00 ]""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
@ -255,7 +250,7 @@ what: y * 3
|
||||
res = func(0)
|
||||
|
||||
# We should have received everything before the error
|
||||
self.assertMultiLineStrippedEqual(
|
||||
assertMultiLineStrippedEqual(self,
|
||||
"""
|
||||
what: x1
|
||||
1""", testing_stream.output)
|
||||
@ -264,15 +259,13 @@ what: x1
|
||||
def test_jit_simple(self):
|
||||
jit_fun1 = api.jit(lambda x: 3. * hcb.id_print(
|
||||
2. * x, what="here", output_stream=testing_stream))
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
{ lambda ; a.
|
||||
let b = xla_call[ backend=None
|
||||
call_jaxpr={ lambda ; a.
|
||||
let b = mul a 2.00
|
||||
c = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
output_stream=...
|
||||
what=here ] b
|
||||
d = mul c 3.00
|
||||
in (d,) }
|
||||
@ -285,7 +278,7 @@ what: x1
|
||||
res = jit_fun1(5.)
|
||||
|
||||
self.assertAllClose(6. * 5., res, check_dtypes=True)
|
||||
self.assertMultiLineStrippedEqual(
|
||||
assertMultiLineStrippedEqual(self,
|
||||
"""
|
||||
what: here
|
||||
10.00""", testing_stream.output)
|
||||
@ -303,8 +296,7 @@ what: here
|
||||
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
self.assertEqual(2, api.jit(func)(1))
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
where: 1
|
||||
1
|
||||
where: 2
|
||||
@ -322,8 +314,7 @@ where: 2
|
||||
self.assertEqual(2, api.jit(func)(1))
|
||||
self.assertEqual(11, api.jit(func)(10))
|
||||
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
where: 1
|
||||
1
|
||||
where: 2
|
||||
@ -341,7 +332,7 @@ where: 2
|
||||
x2 = hcb.id_print(x + 1, where="nested", output_stream=testing_stream)
|
||||
return x2
|
||||
x3 = api.jit(func_nested)(x1)
|
||||
return hcb.id_print(x3 + 1, where="2", output_stream=testing_stream)
|
||||
return hcb.id_print(x3 + 1, where="3", output_stream=testing_stream)
|
||||
|
||||
logging.warning("%s: %s", self._testMethodName,
|
||||
api.make_jaxpr(func)(1))
|
||||
@ -349,13 +340,12 @@ where: 2
|
||||
api.xla_computation(func)(1).GetHloText())
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
self.assertEqual(3, api.jit(func)(1))
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
where: 1
|
||||
1
|
||||
where: nested
|
||||
2
|
||||
where: 2
|
||||
where: 3
|
||||
3""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
@ -421,7 +411,7 @@ where: 2
|
||||
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
self.assertEqual(4, api.jit(func)(1))
|
||||
self.assertMultiLineStrippedEqual("""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
where: 1
|
||||
1
|
||||
where: 2
|
||||
@ -452,8 +442,7 @@ where: w_2
|
||||
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
self.assertEqual(10, api.jit(func)(1))
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
where: 1
|
||||
1
|
||||
where: 2
|
||||
@ -507,7 +496,7 @@ where: 10
|
||||
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
self.assertEqual(10, api.jit(func)(1))
|
||||
self.assertMultiLineStrippedEqual(
|
||||
assertMultiLineStrippedEqual(self,
|
||||
"""
|
||||
""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
@ -536,8 +525,7 @@ where: 10
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
res = api.jit(func)(1)
|
||||
self.assertAllClose(np.array([2, 3, 4]), res, check_dtypes=True)
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
where: 1
|
||||
1
|
||||
where: 2
|
||||
@ -641,8 +629,7 @@ where: 10
|
||||
# return 3.
|
||||
self.assertEqual(3, res)
|
||||
# We should have received all others
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: x1
|
||||
1
|
||||
what: x3
|
||||
@ -664,8 +651,7 @@ what: x3
|
||||
# return 3.
|
||||
self.assertEqual(3, res)
|
||||
# We should have received all others
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: x1
|
||||
1
|
||||
what: x3
|
||||
@ -689,9 +675,9 @@ what: x3
|
||||
|
||||
assert False # It seems that the previous jit blocks above
|
||||
|
||||
def test_jit_without_consumer_error(self):
|
||||
def test_jit_error_no_consumer(self):
|
||||
# Check for errors if starting jit without a consumer active
|
||||
with self.assertRaisesRegex(ValueError, "outfeed_consumer is not started"):
|
||||
with self.assertRaisesRegex(ValueError, "outfeed_receiver is not started"):
|
||||
api.jit(lambda x: hcb.id_print(x))(0)
|
||||
|
||||
# On CPU and GPU the device code blocks
|
||||
@ -727,34 +713,29 @@ what: x3
|
||||
|
||||
def test_jvp(self):
|
||||
jvp_fun1 = lambda x, xt: api.jvp(fun1, (x,), (xt,))
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
{ lambda ; a b.
|
||||
let c = mul a 2.00
|
||||
d = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
nr_untapped=0
|
||||
output_stream=...
|
||||
what=a * 2 ] c
|
||||
e = mul d 3.00
|
||||
f g = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
nr_untapped=1
|
||||
output_stream=...
|
||||
what=y * 3 ] e d
|
||||
h = pow g 2.00
|
||||
i = mul b 2.00
|
||||
j k = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
nr_untapped=1
|
||||
output_stream=...
|
||||
transforms=('jvp',)
|
||||
what=a * 2 ] i d
|
||||
l = mul j 3.00
|
||||
m n o = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
nr_untapped=2
|
||||
output_stream=...
|
||||
transforms=('jvp',)
|
||||
what=y * 3 ] l j f
|
||||
p = pow g 1.00
|
||||
@ -764,8 +745,7 @@ what: x3
|
||||
str(api.make_jaxpr(jvp_fun1)(np.float32(5.), np.float32(0.1))))
|
||||
|
||||
res_primals, res_tangents = jvp_fun1(np.float32(5.), np.float32(0.1))
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: a * 2
|
||||
10.00
|
||||
transforms: ('jvp',) what: a * 2
|
||||
@ -777,26 +757,24 @@ transforms: ('jvp',) what: y * 3
|
||||
testing_stream.reset()
|
||||
|
||||
def test_grad_primal_unused(self):
|
||||
# The output of id_print is not needed for backwards pass
|
||||
def func(x):
|
||||
return 2. * hcb.id_print(x * 3., what="x * 3", output_stream=testing_stream)
|
||||
|
||||
grad_func = api.grad(func)
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
{ lambda ; a.
|
||||
let
|
||||
in (6.00,) }""", str(api.make_jaxpr(grad_func)(5.)))
|
||||
# Just making the Jaxpr invokes the id_print once
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
transforms: ('jvp', 'transpose') what: x * 3
|
||||
2.00""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
res_grad = grad_func(np.float32(5.))
|
||||
self.assertAllClose(6., res_grad, check_dtypes=False)
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: x * 3
|
||||
15.00
|
||||
transforms: ('jvp', 'transpose') what: x * 3
|
||||
@ -808,22 +786,18 @@ transforms: ('jvp', 'transpose') what: x * 3
|
||||
y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
|
||||
return x * hcb.id_print(y * 3., what="y * 3", output_stream=testing_stream)
|
||||
grad_func = api.grad(func)
|
||||
# TODO: why is the order like that?
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
{ lambda ; a.
|
||||
let b = mul 1.00 a
|
||||
c d = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
nr_untapped=1
|
||||
output_stream=...
|
||||
transforms=('jvp', 'transpose')
|
||||
what=y * 3 ] b 0.00
|
||||
e = mul c 3.00
|
||||
f g = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
nr_untapped=1
|
||||
output_stream=...
|
||||
transforms=('jvp', 'transpose')
|
||||
what=x * 2 ] e 0.00
|
||||
h = mul f 2.00
|
||||
@ -831,13 +805,11 @@ transforms: ('jvp', 'transpose') what: x * 3
|
||||
j = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
nr_untapped=0
|
||||
output_stream=...
|
||||
what=x * 2 ] i
|
||||
k = mul j 3.00
|
||||
l = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
nr_untapped=0
|
||||
output_stream=...
|
||||
what=y * 3 ] k
|
||||
m = mul 1.00 l
|
||||
n = add_any h m
|
||||
@ -845,8 +817,7 @@ transforms: ('jvp', 'transpose') what: x * 3
|
||||
|
||||
res_grad = grad_func(np.float32(5.))
|
||||
self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False)
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: x * 2
|
||||
10.00
|
||||
what: y * 3
|
||||
@ -863,20 +834,21 @@ transforms: ('jvp', 'transpose') what: x * 2
|
||||
return x * (y * 3.)
|
||||
|
||||
grad_func = api.grad(api.grad(func))
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
{ lambda ; a.
|
||||
let
|
||||
in (12.00,) }""", str(api.make_jaxpr(grad_func)(5.)))
|
||||
|
||||
res_grad = grad_func(np.float32(5.))
|
||||
self.assertAllClose(12., res_grad, check_dtypes=False)
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
# Just making the Jaxpr invokes the id_print twice
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
transforms: ('jvp', 'transpose') what: x * 2
|
||||
3.00
|
||||
transforms: ('jvp', 'transpose', 'jvp', 'transpose') what: x * 2
|
||||
2.00
|
||||
2.00""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
res_grad = grad_func(np.float32(5.))
|
||||
self.assertAllClose(12., res_grad, check_dtypes=False)
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: x * 2
|
||||
10.00
|
||||
transforms: ('jvp', 'transpose') what: x * 2
|
||||
@ -891,14 +863,12 @@ transforms: ('jvp', 'transpose') what: x * 2
|
||||
def test_vmap(self):
|
||||
vmap_fun1 = api.vmap(fun1)
|
||||
vargs = np.array([np.float32(4.), np.float32(5.)])
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
{ lambda ; a.
|
||||
let b = mul a 2.00
|
||||
c = id_tap[ arg_treedef=*
|
||||
batch_dims=(0,)
|
||||
func=_print
|
||||
output_stream=...
|
||||
transforms=('batch',)
|
||||
what=a * 2 ] b
|
||||
d = mul c 3.00
|
||||
@ -906,15 +876,13 @@ transforms: ('jvp', 'transpose') what: x * 2
|
||||
batch_dims=(0, 0)
|
||||
func=_print
|
||||
nr_untapped=1
|
||||
output_stream=...
|
||||
transforms=('batch',)
|
||||
what=y * 3 ] d c
|
||||
g = pow f 2.00
|
||||
in (g,) }""", str(api.make_jaxpr(vmap_fun1)(vargs)))
|
||||
|
||||
res_vmap = vmap_fun1(vargs)
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
batch_dims: (0,) transforms: ('batch',) what: a * 2
|
||||
[ 8.00 10.00]
|
||||
batch_dims: (0, 0) transforms: ('batch',) what: y * 3
|
||||
@ -930,27 +898,23 @@ batch_dims: (0, 0) transforms: ('batch',) what: y * 3
|
||||
|
||||
vmap_func = api.vmap(func)
|
||||
vargs = np.array([np.float32(4.), np.float32(5.)])
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
{ lambda ; a.
|
||||
let b c = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
|
||||
batch_dims=(None, 0)
|
||||
func=_print
|
||||
output_stream=...
|
||||
transforms=('batch',) ] 3.00 a
|
||||
d = add c 3.00
|
||||
in (d,) }""", str(api.make_jaxpr(vmap_func)(vargs)))
|
||||
|
||||
res_vmap = vmap_func(vargs)
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
batch_dims: (None, 0) transforms: ('batch',)
|
||||
[ 3.00
|
||||
[4.00 5.00] ]
|
||||
""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
|
||||
def test_pmap(self):
|
||||
vargs = 2. + np.arange(api.local_device_count(), dtype=np.float32)
|
||||
|
||||
@ -960,10 +924,10 @@ batch_dims: (None, 0) transforms: ('batch',)
|
||||
expected_res = np.stack([fun1_equiv(2. + a) for a in range(api.local_device_count())])
|
||||
self.assertAllClose(expected_res, res, check_dtypes=False)
|
||||
|
||||
def test_pmap_without_consumer_error(self):
|
||||
def test_pmap_error_no_receiver(self):
|
||||
# Check for errors if starting jit without a consumer active
|
||||
vargs = 2. + np.arange(api.local_device_count(), dtype=np.float32)
|
||||
with self.assertRaisesRegex(ValueError, "outfeed_consumer is not started"):
|
||||
with self.assertRaisesRegex(ValueError, "outfeed_receiver is not started"):
|
||||
api.pmap(lambda x: hcb.id_print(x))(vargs)
|
||||
|
||||
def test_mask(self):
|
||||
@ -972,13 +936,11 @@ batch_dims: (None, 0) transforms: ('batch',)
|
||||
def padded_sum(x):
|
||||
return np.sum(hcb.id_print(x, what="x", output_stream=testing_stream))
|
||||
args = [np.arange(4)], dict(n=onp.int64(2))
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
{ lambda c f ; a b.
|
||||
let d = lt c b
|
||||
e = id_tap[ func=_print
|
||||
logical_shapes=[(Traced<ShapedArray(int32[]):JaxprTrace(level=0/0)>,)]
|
||||
output_stream=...
|
||||
transforms=('mask',)
|
||||
what=x ] a
|
||||
g = select d e f
|
||||
@ -986,12 +948,108 @@ batch_dims: (None, 0) transforms: ('batch',)
|
||||
in (h,) }""", str(api.make_jaxpr(padded_sum)(*args)))
|
||||
|
||||
res = padded_sum(*args)
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
self.assertMultiLineStrippedEqual("""
|
||||
logical_shapes: [(2,)] transforms: ('mask',) what: x
|
||||
[0 1 2 3]
|
||||
""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
class OutfeedRewriterTest(jtu.JaxTestCase):
|
||||
def assertRewrite(self, expected: str, func: Callable, args: Sequence,
|
||||
has_input_token=True, has_output_token=True):
|
||||
"""Check that the rewrite of func(*args) matches expected."""
|
||||
jaxpr = api.make_jaxpr(func)(*args)
|
||||
assertMultiLineStrippedEqual(self, expected,
|
||||
str(hcb._rewrite_typed_jaxpr(jaxpr, has_input_token, has_output_token)[0]))
|
||||
|
||||
def test_no_outfeed(self):
|
||||
self.assertRewrite("""
|
||||
{ lambda ; a.
|
||||
let b = mul a a
|
||||
c = add a b
|
||||
in (c,) }""", lambda x: x + x * x, [0], has_input_token=False, has_output_token=False)
|
||||
self.assertRewrite("""
|
||||
{ lambda ; a d.
|
||||
let b = mul a a
|
||||
c = add a b
|
||||
in (c,) }""", lambda x: x + x * x, [0], has_output_token=False)
|
||||
self.assertRewrite("""
|
||||
{ lambda ; a d.
|
||||
let b = mul a a
|
||||
c = add a b
|
||||
in (c, d) }""", lambda x: x + x * x, [0])
|
||||
|
||||
def test_simple_outfeed(self):
|
||||
self.assertRewrite("""
|
||||
{ lambda ; a d.
|
||||
let b = add a a
|
||||
c e = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
] b d
|
||||
in (c, e) }""", lambda x: hcb.id_print(x + x), [0])
|
||||
|
||||
def test_cond(self):
|
||||
y = np.ones(5) # captured const
|
||||
def func(x, z):
|
||||
return lax.cond(z > 0, (1, 2), lambda a: (a[0], np.zeros(5)),
|
||||
z, lambda a: (hcb.id_print(a), y))
|
||||
self.assertRewrite("""
|
||||
{ lambda d e ; a b h.
|
||||
let c = gt b 0
|
||||
f g i = cond[ false_jaxpr={ lambda ; c a d.
|
||||
let b e = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
] a d
|
||||
in (b, c, e) }
|
||||
linear=(False, False, False, False, False, False, False)
|
||||
true_jaxpr={ lambda ; c a b d.
|
||||
let
|
||||
in (a, c, d) } ] c d 1 2 h e b h
|
||||
in (f, g, i) }""", func, [y, 5])
|
||||
|
||||
def test_while(self):
|
||||
y = np.ones(5) # captured const
|
||||
|
||||
def func(x):
|
||||
return lax.while_loop(lambda c: c[1] < 5,
|
||||
lambda c: (y, hcb.id_print(c[1]) + 1), (x, 1))
|
||||
# TODO: we should not need to start a receiver here!!! I believe this is
|
||||
# because of the partial evaluation of while, which calls impl, which
|
||||
# uses JIT.
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
self.assertRewrite("""
|
||||
{ lambda b ; a e.
|
||||
let c d f = while[ body_jaxpr={ lambda ; c a b f.
|
||||
let d g = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
] b f
|
||||
e = add d 1
|
||||
in (c, e, g) }
|
||||
body_nconsts=1
|
||||
cond_jaxpr={ lambda ; a b d.
|
||||
let c = lt b 5
|
||||
in (c,) }
|
||||
cond_nconsts=0 ] b a 1 e
|
||||
in (c, 5, f) }""", func, [y])
|
||||
|
||||
def test_scan(self):
|
||||
y = np.ones(5) # captured const
|
||||
def func(x):
|
||||
return lax.scan(lambda c, a: (hcb.id_print(c), y), (1, 2), x)
|
||||
self.assertRewrite("""
|
||||
{ lambda b ; a f.
|
||||
let c d g e = scan[ jaxpr={ lambda ; f a b g c.
|
||||
let d e h = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
|
||||
func=_print
|
||||
] a b g
|
||||
in (d, e, h, f) }
|
||||
length=5
|
||||
linear=(False, False, False, False, False)
|
||||
num_carry=3
|
||||
num_consts=1
|
||||
reverse=False ] b 1 2 f a
|
||||
in (c, d, e, g) }""", func, [y])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
|