An experiment for id_print implemented with outfeed

* Added print descriptors, support multiple types
* Added a state-passing mechanism to XLA interpreter
George Necula 2020-04-22 12:10:18 +02:00
parent 970e475e0a
commit de685c9d5a
5 changed files with 902 additions and 12 deletions
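A minimal usage sketch, assuming the API added below in jax/experimental/host_callback.py (illustrative only; all names are from this commit):

    from jax import jit
    from jax.experimental import host_callback as hcb

    def func(x):
      y = hcb.id_print(x * 2., what="x * 2")  # identity on x * 2, but also outfeeds it
      return hcb.end_printing(y)              # signals the receiver to stop

    with hcb.print_receiver():                # consumes the outfeed on a helper thread
      jit(func)(3.)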


@@ -100,3 +100,10 @@ pytype_library(
srcs_version = "PY3",
deps = [":jax"],
)
pytype_library(
name = "experimental_host_callback",
srcs = ["experimental/host_callback.py"],
srcs_version = "PY3",
deps = [":jax"],
)


@@ -0,0 +1,308 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of an experimental primitive for printing, including
from transformed and compiled code.
See documentation for `id_print` below.
For usage example, see tests/host_callback_test.py.
Implementation plan:
* Write the API for the `id_print` primitive, using data-dependence as
explained in `id_print` documentation (DONE).
* Implement the transformations. DONE (except pmap)
* Implement the JIT for CPU using CustomCall in C++. DONE (except unit tests
do not run in OSS; also missing float16 and bfloat16).
* Implement the JIT for GPU using also CustomCall in C++. DONE.
* Explore how to pipe the printed data back to the Colab cell,
when running in Colab. ?
* Explore implementation using outfeed, hoping that it works for all
platforms, and can pipe data more easily. STARTED.
* Explore feeding the data back to the Python program (the `id_tap`
primitive). ?
* Explore a simpler API that uses Python program-order, instead of
data dependency-order. Need to add support to JAX for stateful primitives.
"""
from contextlib import contextmanager
from functools import partial
import io
import itertools
import logging
import os
import threading
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as onp

from jax import abstract_arrays
from jax import core
from jax import dtypes
from jax import lax
from jax import util
from jax.lib import pytree, xla_bridge
from jax.interpreters import ad, xla, batching
from jax.interpreters import partial_eval as pe
from jaxlib import xla_client
from jaxlib import xla_extension
# TODO(necula): fix mypy errors if I define the type aliases below
XlaOp = Any # xla_extension.XlaOp
XlaShape = Any # xla_client.Shape
XlaComputationBuilder = Any # xla_bridge._JaxComputationBuilder
id_print_p = core.Primitive("id_print")
id_print_p.multiple_results = True
xops = xla_client._xla.ops
def id_print(*args, result=None, **kwargs):
"""Behaves like the identity function for positional arguments, but prints all
arguments on the host, even from transformed or compiled code.

The return value is a tuple with the values of `args`, or the value of the
keyword parameter `result` if present. If there is a single positional
argument, it returns just that argument without packing it in a tuple.

The positional arguments must be JAX values. The keyword arguments are
serialized to a string and printed along with the positional arguments.
There are a few special keyword arguments that are not printed:

* `result`: the value to return from `id_print`; must be a JAX value or a
pytree of values.
* `output_stream`: the stream to which the values should be printed.
(Note: does not yet work from under JIT.)

Usage:
>>> y = id_print(x * 2)  # prints and returns 2x
>>> y, z = id_print(x * 2, x * 3)  # prints and returns 2x and 3x
>>> y = id_print(x * 2, result=y)  # prints 2x and returns y
>>> y = id_print(x * 2, what='x')  # prints what=x followed by 2x

The order of execution is given by data dependency: after all the arguments
are computed and before the result is used. At least one of the returned
values must be used in the rest of the computation, or else this operation
has no effect.

Upon JAX transformations, the transformed values are wrapped with
`id_print`, and a special `transforms` tuple keyword argument is added with
the sequence of transformations applied:

- For `vmap` the arguments are batched, and `transforms=('batch',)`.
- For `jvp` there will be one `id_print` for the primal values, and a
separate `id_print` for the tangents, with `transforms=('jvp',)`.
- For `grad` there will be an `id_print` for the primal values (if
needed in the computation of `grad`) and an `id_print` with the
adjoints of the results, with `transforms=('jvp', 'transpose')`.
"""
flat_args, args_treedef = pytree.flatten(args)
params = dict(kwargs) # copy
if result is not None:
flat_results, results_treedef = pytree.flatten(result)
params["nr_results"] = len(flat_results)
all_args = flat_args + flat_results
else:
all_args = flat_args
flat_outs = id_print_p.bind(*all_args, **params) # Always returns a tuple of all args
if result is not None:
return results_treedef.unflatten(flat_outs[-params["nr_results"]:])
else:
res = args_treedef.unflatten(flat_outs)
return res if len(args) > 1 else res[0]
def _expand_params_transform(params: Dict, transform: str) -> Dict:
"""Adds the `transform` to the params["transforms"]."""
return dict(params, transforms=params.get("transforms", ()) + (transform,))
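# For example (illustrative):
#   _expand_params_transform(dict(what="x"), "jvp")
# returns dict(what="x", transforms=("jvp",)); a second application appends
# to the existing `transforms` tuple.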
def _id_print_impl(*args, **params):
output_stream = params.get("output_stream")
if output_stream is not None:
print_params = dict(params)
del print_params["output_stream"]
else:
import sys
output_stream = sys.stdout
print_params = params
# TODO: use the JITed version to do the actual printing.
to_print = f"{args} {print_params}"
output_stream.write(to_print)
return args
id_print_p.def_impl(_id_print_impl)
def _id_print_abstract_eval(*args_a: pe.AbstractValue, **params) \
-> Sequence[pe.AbstractValue]:
return args_a
id_print_p.def_abstract_eval(_id_print_abstract_eval)
# Each array sent to the outfeed is preceded by a descriptor, which
# is an array s32[32] with the format:
#   [0]: special header value
#   [1]: an encoding of the element_type()
#   [2]: the number of dimensions
#   [3:...]: the sizes of the dimensions, padded with 0s
#
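# For example (illustrative; dtype codes are given by _CODE_TO_DTYPE below),
# an f32[2, 3] array is described by:
#   [13579, 9, 2, 2, 3, 0, 0, ..., 0]  # header, f32 code, rank 2, dims 2 and 3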
_OUTFEED_DESCRIPTOR_PRINT_HEADER = 13579
_OUTFEED_DESCRIPTOR_LENGTH = 32
_CODE_TO_DTYPE = {
0: onp.dtype(onp.int8),
1: onp.dtype(onp.int16),
2: onp.dtype(onp.int32),
3: onp.dtype(onp.int64),
4: onp.dtype(onp.uint8),
5: onp.dtype(onp.uint16),
6: onp.dtype(onp.uint32),
7: onp.dtype(onp.uint64),
8: onp.dtype(onp.float16),
9: onp.dtype(onp.float32),
10: onp.dtype(onp.float64),
11: onp.dtype(dtypes.bfloat16),
}
_DTYPE_STR_TO_CODE = {str(d): c for c, d in _CODE_TO_DTYPE.items()}
def _id_print_translation_rule_outfeed(
comp: XlaComputationBuilder,
*args_op: XlaOp, **params):
# TODO: outfeed a whole tuple at once?
def _outfeed_one_array(a: XlaOp, token: XlaOp) -> XlaOp:
a_shape = comp.GetShape(a)
dimensions = a_shape.dimensions()
descriptor = [_OUTFEED_DESCRIPTOR_PRINT_HEADER,
_DTYPE_STR_TO_CODE[str(onp.dtype(a_shape.element_type()))],
len(dimensions)] + list(dimensions)
if len(descriptor) > _OUTFEED_DESCRIPTOR_LENGTH:
raise ValueError(f"Too many dimensions in array to print: {a_shape}")
descriptor += [0] * (_OUTFEED_DESCRIPTOR_LENGTH - len(descriptor))
data = xops.ConstantLiteral(comp, onp.array(descriptor, dtype=onp.int32))
token = xops.OutfeedWithToken(data, token, comp.GetShape(data))
token = xops.OutfeedWithToken(a, token, a_shape)
return token
prev_token = xla.computation_state_carry.current_token(comp)
nr_args_to_emit = len(args_op) - params.get("nr_results", 0)
for i, a_op in enumerate(args_op):
if i < nr_args_to_emit:
prev_token = _outfeed_one_array(a_op, prev_token)
xla.computation_state_carry.set_current_token(comp, prev_token)
return xops.Tuple(comp, args_op)
xla.translations[id_print_p] = _id_print_translation_rule_outfeed
# TODO: find a better way to signal the end of printing
_END_PRINTING = onp.int32(12345678)
def end_printing(res):
return id_print(_END_PRINTING, result=res)
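# Typical pattern (a sketch; see tests/host_callback_test.py): thread the
# final result through `end_printing` so the receiver loop in `print_receiver`
# below knows when to stop, e.g.:
#   with print_receiver():
#     res = jax.jit(lambda x: end_printing(id_print(x)))(x)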
@contextmanager
def print_receiver(output_stream=None,
receiver_name="",
timeout_sec=10):
"""Starts a receiver for the id_print outfeed.
The receiver runs on a separate thread and stops once it reads the special
marker emitted by `end_printing` above.
Usage:
with print_receiver():
jax.jit(func)(args)
"""
# TODO: better timeout management
# TODO: start receivers for each device
platform = xla_client.get_local_backend(None).platform
def _get_data(data_shape: XlaShape) -> onp.ndarray:
if platform == "gpu":
return xla_client.transfer_from_outfeed(data_shape)
else:
return xla_client.transfer_from_outfeed(
xla_client.Shape.tuple_shape((data_shape,)))[0]
def _consume_one_array_from_outfeed():
descriptor_shape = xla_client.Shape.array_shape(onp.dtype(onp.int32),
(_OUTFEED_DESCRIPTOR_LENGTH,))
descriptor = _get_data(descriptor_shape)
logging.info(f"[{receiver_name}] Read descriptor: {descriptor}")
if descriptor[0] != _OUTFEED_DESCRIPTOR_PRINT_HEADER:
raise ValueError(f"Read unexpected print descriptor {descriptor} [{receiver_name}]")
data_dimensions = tuple(descriptor[3:3+descriptor[2]])
data_shape = xla_client.Shape.array_shape(_CODE_TO_DTYPE[descriptor[1]],
data_dimensions)
data = _get_data(data_shape)
logging.info(f"[{receiver_name}] Read data of shape {data.dtype}{data.shape}")
return data
def receiver_loop():
i = 0
while True:
got = _consume_one_array_from_outfeed()
if not got.shape and got == _END_PRINTING:
logging.info(f"[{receiver_name}] Received END_PRINTING")
return
got_str = onp.array2string(got, threshold=1024)
logging.info(f"[{receiver_name}] Received {i} ({got.dtype}{got.shape}): {got_str}")
if output_stream is not None:
output_stream.write(got_str)
i += 1
receiver = threading.Thread(target=receiver_loop)
receiver.start()
try:
yield receiver
finally:
# TODO: proper termination
receiver.join(timeout=timeout_sec)
if receiver.is_alive():
logging.error(f"[{receiver_name}] Receiver still alive")
else:
logging.info(f"[{receiver_name}] Receiver finished")
def _id_print_jvp_rule(primals, tangents, **params):
primals_out = id_print(primals, **params)
tangents_out = id_print(tangents, **_expand_params_transform(params, "jvp"))
return primals_out, tangents_out
ad.primitive_jvps[id_print_p] = _id_print_jvp_rule
def _id_print_transpose_rule(cts, *args, **params):
assert all([ad.is_undefined_primal(x) for x in args])
assert len(cts) == len(args)
ct_args = id_print_p.bind(*cts,
**_expand_params_transform(params, "transpose"))
return ct_args
ad.primitive_transposes[id_print_p] = _id_print_transpose_rule
def _id_print_batching_rule(batched_args, batch_dims, **params):
res = id_print_p.bind(*batched_args,
**_expand_params_transform(params, "batch"))
return res, batch_dims
batching.primitive_batchers[id_print_p] = _id_print_batching_rule


@@ -14,9 +14,11 @@
from collections import defaultdict
from contextlib import contextmanager
import itertools as it
import operator as op
import threading
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Tuple
from absl import logging
import numpy as onp
@@ -33,13 +35,14 @@ from ..abstract_arrays import (ConcreteArray, ShapedArray, AbstractToken,
from ..core import Literal, pp_eqn_compact
from ..pprint_util import pp
from ..util import (partial, partialmethod, cache, prod, unzip2, memoize,
extend_name_stack, wrap_name, split_list)
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from . import partial_eval as pe
from . import ad
from . import masking
xe = xc._xla
xops = xc._xla.ops
@@ -48,6 +51,11 @@ Backend = Any # xc.LocalBackend (why does mypy not like this?)
Device = Any # xc.Device
PyLocalBuffer = Any
XlaOp = Any # xla_extension.XlaOp
XlaShape = Any # xla_client.Shape
XlaComputationBuilder = Any # xla_bridge._JaxComputationBuilder
FLAGS = flags.FLAGS
flags.DEFINE_bool('jax_debug_nans',
bool_env('JAX_DEBUG_NANS', False),
@@ -162,6 +170,63 @@ pytype_aval_mappings.update((t, make_shaped_array) for t in array_types)
pytype_aval_mappings.update(
(t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types)
class _ComputationStateCarry(threading.local):
"""Carries per-computation state as thread-local global state.
For now the state is only a token, obtained from the last outfeed.
The translation rules for primitives can read and write this state.
This assumes that the primitives are processed in order!
"""
_current_computation: Optional[XlaComputationBuilder]
_current_token: Optional[XlaOp]
def __init__(self) -> None:
self._current_computation = None
self._current_token = None
def state_len(self):
return 1
def current_state(self, comp: XlaComputationBuilder) -> List[XlaOp]:
if comp is not self._current_computation:
if self._current_computation is not None:
# TODO: add more error checking
logging.warning("Overwriting previous computation")
self._current_computation = comp
self._current_token = xops.CreateToken(comp)
return [self._current_token]
def current_token(self, comp: XlaComputationBuilder) -> XlaOp:
state = self.current_state(comp)
return state[0]
def set_current_token(self, comp: XlaComputationBuilder, token: XlaOp):
assert comp == self._current_computation
self._current_token = token
def set_comp_and_current_state(self, comp: XlaComputationBuilder,
state: Sequence[XlaOp]):
self._current_computation = comp
self._current_token = state[0]
def extract_state_from_tuple_op(self, comp: XlaComputationBuilder,
tuple_op: XlaOp, nr_regular: int):
"""Given a tuple extract the state from its tail.
We assume that the `tuple_op` represents a tuple with `nr_outs` regular
elements, followed by some elements encoding the state. Stores the new state,
and returns a Tuple with the regular elemennts.
"""
regular_ops = [xops.GetTupleElement(tuple_op, i) for i in range(nr_regular)]
self._current_computation = comp
current_state = [xops.GetTupleElement(tuple_op, i)
for i in range(nr_regular, nr_regular + self.state_len())]
self._current_token = current_state[0]
return xops.Tuple(comp, regular_ops)
computation_state_carry = _ComputationStateCarry()
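# Sketch of how a translation rule is expected to use the carry (mirrors the
# id_print outfeed rule in jax/experimental/host_callback.py; `data_op` and
# `data_shape` are placeholders):
#   token = computation_state_carry.current_token(comp)        # read
#   token = xops.OutfeedWithToken(data_op, token, data_shape)  # thread through
#   computation_state_carry.set_current_token(comp, token)     # write back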
### op-by-op execution
def arg_spec(x):
@@ -622,12 +687,23 @@ def _xla_call_translation_rule(c, axis_env,
in_nodes, name_stack, backend, name,
call_jaxpr, device=None):
del device # Ignored.
prev_state = computation_state_carry.current_state(c)
all_in_nodes = list(in_nodes) + list(prev_state)
subc = xb.make_computation_builder(f"jit_{name}")
all_args = [xb.parameter(subc, i, c.GetShape(n))
for i, n in enumerate(all_in_nodes)]
args, input_state = split_list(all_args, [len(in_nodes)])
computation_state_carry.set_comp_and_current_state(subc, input_state)
out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
extend_name_stack(name_stack, wrap_name(name, 'jit')), *args)
result_state = computation_state_carry.current_state(subc)
subc = subc.Build(xops.Tuple(subc, list(out_nodes) + result_state))
call_op = xops.Call(c, subc, all_in_nodes)
nr_outs = len(out_nodes)
return computation_state_carry.extract_state_from_tuple_op(c, call_op, nr_outs)
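# Net effect (sketch): the jit subcomputation takes the current token as an
# extra trailing parameter and returns the updated token as an extra trailing
# result, so outfeeds inside the called computation stay ordered with those
# outside it.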
ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)


@@ -254,13 +254,15 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
cond_consts, body_consts, init_vals = split_list(args, [cond_nconsts, body_nconsts])
batched = bool(cond_jaxpr.out_avals[0].shape)
prev_state = xla.computation_state_carry.current_state(c)
# Since jaxprs don't have tuples and have multiple return values, but we need
# the HLO While loop to take a single tuple input and output a single boolean
# (for the cond computation) or a single tuple output (for the body
# computation), we build XLA computations that handle the tuple munging before
# generating a Call into the computations formed from the jaxprs.
init_carry = xops.Tuple(c, cond_consts + body_consts + init_vals + prev_state)
cond_c = xb.make_computation_builder("cond_computation")
cond_carry = xb.parameter(cond_c, 0, c.GetShape(init_carry))
@@ -279,6 +281,8 @@
body_c = xb.make_computation_builder("body_computation")
body_carry = xb.parameter(body_c, 0, c.GetShape(init_carry))
body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
xla.computation_state_carry.extract_state_from_tuple_op(
body_c, body_carry, len(args))
x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
new_z = xla.jaxpr_subcomp(body_c, body_jaxpr.jaxpr, backend, axis_env,
_map(partial(xb.constant, body_c), body_jaxpr.literals),
@@ -289,10 +293,13 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
extend_name_stack(name_stack, 'body_pred'), *(x + z))
new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z)
assert _map(body_c.GetShape, new_z) == _map(body_c.GetShape, z) # no broadcast
result_state = xla.computation_state_carry.current_state(body_c)
new_carry = xops.Tuple(body_c, list(itertools.chain(x, y, new_z, result_state)))
ans = xops.While(cond_c.Build(pred), body_c.Build(new_carry), init_carry)
ans_elts = [xops.GetTupleElement(ans, i) for i in range(len(args))]
xla.computation_state_carry.extract_state_from_tuple_op(
c, ans, len(args))
_, _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
return xops.Tuple(c, z)
@@ -556,23 +563,30 @@ def _cond_translation_rule(c, axis_env, name_stack, avals, backend,
pred, *args, true_jaxpr, false_jaxpr, linear):
del linear # Unused.
true_ops, false_ops = split_list(args, [len(true_jaxpr.in_avals)])
current_state = xla.computation_state_carry.current_state(c)
def make_computation(name, jaxpr, op_shape):
c = xb.make_computation_builder(name + '_comp')
op = xb.parameter(c, 0, op_shape)
ops = [xops.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))]
xla.computation_state_carry.extract_state_from_tuple_op(
c, op, len(jaxpr.in_avals))
outs = xla.jaxpr_subcomp(c, jaxpr.jaxpr, backend, axis_env,
_map(partial(xb.constant, c), jaxpr.literals),
extend_name_stack(name_stack, name + '_fun'), *ops)
result_state = xla.computation_state_carry.current_state(c)
return c.Build(xops.Tuple(c, list(outs) + result_state))
true_op = xops.Tuple(c, true_ops + current_state)
true_c = make_computation('true', true_jaxpr, c.GetShape(true_op))
false_op = xops.Tuple(c, false_ops + current_state)
false_c = make_computation('false', false_jaxpr, c.GetShape(false_op))
cond_op = xops.Conditional(pred, true_op, true_c, false_op, false_c)
nr_outs = len(true_jaxpr.out_avals)
return xla.computation_state_carry.extract_state_from_tuple_op(
c, cond_op, nr_outs)
def _cond_pred_bcast_select(pred, x, y):
if core.get_aval(x) is core.get_aval(y) is core.abstract_unit:

tests/host_callback_test.py

@@ -0,0 +1,485 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import numpy as onp
import os
import re
import threading
from unittest import SkipTest
from absl.testing import absltest
from absl.testing import parameterized
from jax import api
from jax import lax
from jax import numpy as np
from jax import test_util as jtu
from jax.config import config
from jax.experimental import host_callback as hcb
from jax.interpreters import xla
from jax.interpreters import partial_eval as pe
from jax.lib import xla_bridge
config.parse_flags_with_absl()
FLAGS = config.FLAGS
def skip_if_jit_not_enabled():
if os.getenv("JAX_ENABLE_JIT_PRINT", "false") == "false":
raise SkipTest("print jit not enabled yet; use JAX_ENABLE_JIT_PRINT env.")
class _TestingOutputStream(object):
"""Use as `output_stream` for tests."""
def __init__(self):
self._output = []
def write(self, what: str) -> None:
# Sometimes we get floating-point numbers in the output; we round them.
def repl(match_group):
# TODO: why can't we use np.around here?
matched = match_group.group(0)
if matched == ".": return matched
x = onp.around(float(matched), decimals=2)
return f"{x:.2f}"
what = re.sub(r"\-?\d*\.[\-\def]*", repl, what)
print(f"output_stream: {what}")
self._output.append(what)
@property
def output(self):
return "\n".join(self._output)
def __str__(self):
return "TestingOutputStream"
def reset(self):
self._output = []
testing_stream = _TestingOutputStream()
def fun1(a):
y = hcb.id_print(a * 2., what="a * 2", output_stream=testing_stream)
y = hcb.id_print(y * 3., what="y * 3", output_stream=testing_stream, result=y)
return y**4 # Some computation to make the gradient interesting
def fun1_equiv(a): # Numerical equivalent of fun1
return (a * 2.)**4
class HostCallbackTest(jtu.JaxTestCase):
def setUp(self):
testing_stream.reset()
def helper_set_devices(self, nr_devices):
flags_str = os.getenv("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = (
flags_str +
" --xla_force_host_platform_device_count={}".format(nr_devices))
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
return api.devices()
def helper_set_hlo_dump(self):
flags_str = os.getenv("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = f"{flags_str} --xla_dump_to=/tmp/xla_dump"
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
def helper_print_serialization(self, description, np_vals, params):
"""Encode a print_descriptor and print it.
Args:
np_vals: a list of np.ndarray from which to extract the shapes.
"""
encoded = hcb._make_id_print_metadata(
[xla.aval_to_xla_shape(pe.get_aval(np_val)) for np_val in np_vals],
params)
print(f"test_serialization: {description}")
print(", ".join([f"{b}" for b in encoded]))
def test_serialization(self):
"""Prints encodings used in host_callback_test::TestParseDescriptor."""
raise SkipTest("Not implemented")
self.helper_print_serialization("no args, separator=sep, param=0", [],
dict(param=0, separator="sep"))
self.helper_print_serialization("1 scalar int, separator= param=1",
[np.int32(0)], dict(param=1, separator=""))
self.helper_print_serialization("1 array f32[2, 3], separator= param=2",
[np.ones((2, 3), dtype=np.float32)],
dict(param=2, separator=""))
self.helper_print_serialization(
"1 array f32[2, 3] and 1 f64, separator= param=3",
[np.ones((2, 3), dtype=np.float32),
np.float64(0)], dict(param=3, separator=""))
def test_with_tuple_result(self):
def func2(x):
x1, y1 = hcb.id_print(x * 2., x * 3., output_stream=testing_stream)
return x1 + y1
self.assertMultiLineStrippedEqual(
"""
{ lambda ; a.
let b = mul a 2.0
c = mul a 3.0
d e = id_print[ output_stream=TestingOutputStream ] b c
f = add d e
in (f,) }""", str(api.make_jaxpr(func2)(3.)))
self.assertEqual(3. * (2. + 3.), func2(3.))
self.assertMultiLineStrippedEqual("""
(6.00, 9.00) {}""", testing_stream.output)
testing_stream.reset()
def test_eval(self):
self.assertMultiLineStrippedEqual(
"""
{ lambda ; a.
let b = mul a 2.0
c = id_print[ output_stream=TestingOutputStream
what=a * 2 ] b
d = mul c 3.0
e f = id_print[ nr_results=1
output_stream=TestingOutputStream
what=y * 3 ] d c
g = pow f 4.0
in (g,) }""", str(api.make_jaxpr(fun1)(5.)))
self.assertEqual("", testing_stream.output)
self.assertEqual((5. * 2.)**4, fun1(5.))
self.assertMultiLineStrippedEqual(
"""
(10.00,) {'what': 'a * 2'}
(30.00, 10.00) {'what': 'y * 3', 'nr_results': 1}""", testing_stream.output)
testing_stream.reset()
def test_jit_simple(self):
jit_fun1 = api.jit(lambda x: hcb.end_printing(3. * hcb.id_print(
2. * x, what="here")))
self.assertMultiLineStrippedEqual(
"""
{ lambda ; a.
let b = xla_call[ backend=None
call_jaxpr={ lambda ; a.
let b = mul a 2.0
c = id_print[ what=here ] b
d = mul c 3.0
e f = id_print[ nr_results=1 ] 12345678 d
in (f,) }
device=None
name=<lambda> ] a
in (b,) }""", str(api.make_jaxpr(jit_fun1)(5.)))
logging.info("%s: %s",
self._testMethodName, api.xla_computation(jit_fun1)(5.).GetHloText())
with hcb.print_receiver(output_stream=testing_stream,
receiver_name=self._testMethodName):
res = jit_fun1(5.)
self.assertAllClose(6. * 5., res, check_dtypes=True)
self.assertMultiLineStrippedEqual(
"""
10.00""", testing_stream.output)
testing_stream.reset()
def test_simple_jit_sequencing(self):
def func(x):
x1 = hcb.id_print(x, where="1")
x2 = hcb.id_print(x1 + 1, where="2")
return hcb.end_printing(x2)
logging.info("%s: %s", self._testMethodName,
api.make_jaxpr(func)(1))
logging.info("%s: %s", self._testMethodName,
api.xla_computation(func)(1).GetHloText())
with hcb.print_receiver(output_stream=testing_stream,
receiver_name=self._testMethodName):
self.assertEqual(2, api.jit(func)(1))
self.assertMultiLineStrippedEqual(
"""
1
2""", testing_stream.output)
testing_stream.reset()
def test_jit2(self):
"""A sequence of JIT."""
def func(x):
x1 = hcb.id_print(x, where="1")
x2 = hcb.id_print(x1 + 1, where="2")
return x2
with hcb.print_receiver(output_stream=testing_stream,
receiver_name=self._testMethodName):
self.assertEqual(2, api.jit(func)(1))
self.assertEqual(11, api.jit(func)(10))
# Now send the end of printing
api.jit(lambda x: hcb.end_printing(x))(0)
self.assertMultiLineStrippedEqual(
"""
1
2
10
11""", testing_stream.output)
testing_stream.reset()
def test_jit_nested(self):
def func(x):
x1 = hcb.id_print(x, where="1")
def func_nested(x):
x2 = hcb.id_print(x, where="nested")
return x2
x3 = api.jit(func_nested)(x1)
x2 = hcb.id_print(x3 + 1, where="2")
return hcb.end_printing(x2)
logging.warning("%s: %s", self._testMethodName,
api.make_jaxpr(func)(1))
logging.warning("%s: %s", self._testMethodName,
api.xla_computation(func)(1).GetHloText())
return
with hcb.print_receiver(output_stream=testing_stream,
receiver_name=self._testMethodName):
pass # self.assertEqual(2, api.jit(func)(1))
self.assertMultiLineStrippedEqual(
"""
1
2""", testing_stream.output)
testing_stream.reset()
def test_jit_cond1(self):
"""A conditional"""
def func(x):
x1 = hcb.id_print(x, where="1")
x2 = hcb.id_print(x1 + 1, where="2")
x4 = lax.cond(x % 2 == 0,
x2 + 1, lambda x: hcb.id_print(x, where="cond_t"),
x2 + 1, lambda x: hcb.id_print(-1, where="cond_f", result=x))
x5 = hcb.id_print(x4 + 1, where="w.2")
return hcb.end_printing(x5)
logging.warning("%s: %s", self._testMethodName, api.make_jaxpr(func)(1))
logging.warning("%s: %s", self._testMethodName,
api.xla_computation(func)(1).GetHloText())
with hcb.print_receiver(output_stream=testing_stream,
receiver_name=self._testMethodName):
self.assertEqual(4, api.jit(func)(1))
self.assertMultiLineStrippedEqual("""
1
2
-1
4""", testing_stream.output)
testing_stream.reset()
def test_jit_while_cond(self):
def func(x):
x1 = hcb.id_print(x, where="1")
x2 = hcb.id_print(x1 + 1, where="2")
def body(x):
x3 = hcb.id_print(x, where="w.1")
x4 = lax.cond(x % 2 == 0,
x3 + 1, lambda x: hcb.id_print(x, where="w.t"),
x3 + 1, lambda x: hcb.id_print(-1, where="w.f", result=x))
return hcb.id_print(x4 + 1, where="w.2")
x10 = lax.while_loop(lambda x: x < 10, body, x2)
res = hcb.id_print(x10, where="10")
return hcb.end_printing(res)
logging.warning("%s: %s", self._testMethodName, api.make_jaxpr(func)(1))
logging.warning("%s: %s", self._testMethodName,
api.xla_computation(func)(1).GetHloText())
with hcb.print_receiver(output_stream=testing_stream,
receiver_name=self._testMethodName):
self.assertEqual(10, api.jit(func)(1))
self.assertMultiLineStrippedEqual(
"""
1
2
2
3
4
4
5
6
6
7
8
8
9
10
10""", testing_stream.output)
testing_stream.reset()
@parameterized.named_parameters(
jtu.cases_from_list(
dict(
testcase_name=f"_shape_{shape}_dtype_{dtype}_nr_args={nr_args}",
shape=shape,
dtype=dtype,
nr_args=nr_args) for nr_args in [1, 2]
for shape in [(), (2,), (2, 3), (2, 3, 4)]
for dtype in jtu.supported_dtypes()))
def test_jit_types(self, nr_args=2, dtype=np.int16, shape=(2,)):
if dtype in (np.complex64, np.complex128, np.bool_):
raise SkipTest(f"id_print jit not implemented for {dtype}.")
if jtu.device_under_test() == "tpu":
if dtype in (np.int16,):
raise SkipTest(f"transfering {dtype} not supported on TPU")
self.helper_set_hlo_dump()
args = [np.arange(np.prod(shape), dtype=dtype).reshape(shape)]
if nr_args > 1:
args = args * nr_args
jit_fun1 = api.jit(lambda xs: hcb.end_printing(
hcb.id_print(
*xs,
a_new_test="************",
testcase_name=f"shape_{shape}_dtype_{dtype}_nr_args={nr_args}")))
with hcb.print_receiver(receiver_name=self._testMethodName):
res = jit_fun1(args)
# self.assertAllClose(args, res, check_dtypes=True)
def test_jit_large(self):
arg = np.arange(10000, dtype=np.int32).reshape((10, 10, 5, -1))
with hcb.print_receiver(output_stream=testing_stream,
receiver_name=self._testMethodName):
api.jit(lambda x: hcb.end_printing(hcb.id_print(x)))(arg)
def test_jvp(self):
jvp_fun1 = lambda x, xt: api.jvp(fun1, (x,), (xt,))
self.assertMultiLineStrippedEqual(
"""
{ lambda ; a b.
let c = mul a 2.0
d = id_print[ output_stream=TestingOutputStream
what=a * 2 ] c
e = mul d 3.0
f g = id_print[ nr_results=1
output_stream=TestingOutputStream
what=y * 3 ] e d
h = pow g 4.0
i = mul b 2.0
j = id_print[ output_stream=TestingOutputStream
transforms=('jvp',)
what=a * 2 ] i
k = mul j 3.0
l m = id_print[ nr_results=1
output_stream=TestingOutputStream
transforms=('jvp',)
what=y * 3 ] k j
n = pow g 3.0
o = mul 4.0 n
p = mul m o
in (h, p) }""",
str(api.make_jaxpr(jvp_fun1)(np.float32(5.), np.float32(0.1))))
res_primals, res_tangents = jvp_fun1(np.float32(5.), np.float32(0.1))
self.assertMultiLineStrippedEqual(
"""
(DeviceArray(10.00, dtype=float32),) {'what': 'a * 2'}
(DeviceArray(0.20, dtype=float32),) {'what': 'a * 2', 'transforms': ('jvp',)}
(DeviceArray(30.00, dtype=float32), DeviceArray(10.00, dtype=float32)) {'what': 'y * 3', 'nr_results': 1}
(DeviceArray(0.60, dtype=float32), DeviceArray(0.20, dtype=float32)) {'what': 'y * 3', 'nr_results': 1, 'transforms': ('jvp',)}
""", testing_stream.output)
testing_stream.reset()
def test_grad(self):
raise SkipTest("failing with new implementation")
grad_fun1 = api.grad(fun1)
self.assertMultiLineStrippedEqual(
"""
{ lambda ; a.
let b = mul a 2.0
c = id_print[ output_stream=TestingOutputStream
what=a * 2 ] b
d = mul c 3.0
e = id_print[ output_stream=TestingOutputStream
what=y * 3 ] d
f = tie_in e c
g = pow f 3.0
h = mul 4.0 g
i = mul 1.0 h
j = id_print[ output_stream=TestingOutputStream
transforms=('jvp', 'transpose')
what=a * 2 ] i
k = mul j 2.0
in (k,) }""", str(api.make_jaxpr(grad_fun1)(5.)))
# This comes from the actual partial evaluation
self.assertMultiLineStrippedEqual(
"""
(Zero,) {'what': 'y * 3', 'transforms': ('jvp', 'transpose')}
""", testing_stream.output)
testing_stream.reset()
res_grad = grad_fun1(np.float32(5.))
self.assertMultiLineStrippedEqual(
"""
(DeviceArray(10.00, dtype=float32),) {'what': 'a * 2'}
(DeviceArray(30.00, dtype=float32),) {'what': 'y * 3'}
(Zero,) {'what': 'y * 3', 'transforms': ('jvp', 'transpose')}
(DeviceArray(4000.00, dtype=float32),) {'what': 'a * 2', 'transforms': ('jvp', 'transpose')}
""", testing_stream.output)
testing_stream.reset()
def test_vmap(self):
vmap_fun1 = api.vmap(fun1)
vargs = np.array([np.float32(4.), np.float32(5.)])
self.assertMultiLineStrippedEqual(
"""
{ lambda ; a.
let b = mul a 2.0
c = id_print[ output_stream=TestingOutputStream
transforms=('batch',)
what=a * 2 ] b
d = mul c 3.0
e f = id_print[ nr_results=1
output_stream=TestingOutputStream
transforms=('batch',)
what=y * 3 ] d c
g = pow f 4.0
in (g,) }""", str(api.make_jaxpr(vmap_fun1)(vargs)))
res_vmap = vmap_fun1(vargs)
self.assertMultiLineStrippedEqual(
"""
(DeviceArray([ 8.00, 10.00], dtype=float32),) {'what': 'a * 2', 'transforms': ('batch',)}
(DeviceArray([24.00, 30.00], dtype=float32), DeviceArray([ 8.00, 10.00], dtype=float32)) {'what': 'y * 3', 'nr_results': 1, 'transforms': ('batch',)}
""", testing_stream.output)
testing_stream.reset()
def test_pmap(self):
skip_if_jit_not_enabled()
self.helper_set_devices(4)
vargs = np.arange(api.local_device_count(), dtype=np.float32)
pmap_fun1 = api.pmap(fun1, axis_name="i")
res = pmap_fun1(vargs)
if __name__ == "__main__":
absltest.main()