[hcb] Simplifications to the host_calback API

* dropping support for special AD handling for hcb.id_tap and id_print.
  From now on, only the primals are tapped. The old behavior can be
  obtained (for a limited time) by setting the JAX_HOST_CALLBACK_AD_TRANSFORMS
  environment variale, or the --flax_host_callback_ad_transforms flag.
  Additionally, added documentation for how to implement the old behavior
  using JAX custom AD APIs.

This allows us to make some significant cleanup in the internals.
This commit is contained in:
George Necula 2021-11-24 12:58:16 +02:00
parent 53318a2a7a
commit f08156ab7c
5 changed files with 500 additions and 208 deletions

View File

@ -12,6 +12,15 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.26...main).
* Breaking changes:
* The host_callback primitives have been simplified to drop the
special autodiff handling for hcb.id_tap and id_print.
From now on, only the primals are tapped. The old behavior can be
obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS``
environment variable, or the ```--flax_host_callback_ad_transforms``` flag.
Additionally, added documentation for how to implement the old behavior
using JAX custom AD APIs ({jax-issue}`#7839`).
## jaxlib 0.1.76 (Unreleased)
## jaxlib 0.1.75 (Dec 8, 2021)

View File

@ -463,6 +463,15 @@ flags.DEFINE_bool(
'Has no effect on TPU, since only the outfeed mechanism is implemented.'
)
)
flags.DEFINE_bool(
'jax_host_callback_ad_transforms',
bool_env('JAX_HOST_CALLBACK_AD_TRANSFORMS', False),
help=(
'Enable support for jvp/vjp for the host_callback primitives. Default is '
'False, which means that host_callback operates only on primals. '
'The flag exists only temporarily, for backward compatibility.'
)
)
enable_checks = config.define_bool_state(
name='jax_enable_checks',

View File

@ -1009,6 +1009,9 @@ class JaxTestCase(parameterized.TestCase):
ignore_space_re = re.compile(r'\s*\n\s*')
expected_clean = re.sub(ignore_space_re, '\n', expected.strip())
what_clean = re.sub(ignore_space_re, '\n', what.strip())
if what_clean != expected_clean:
# Print it so we can copy-and-paste it into the test
print(f"Found\n{what}\n")
self.assertMultiLineEqual(expected_clean, what_clean,
msg="Found\n{}\nExpecting\n{}".format(what, expected))

View File

@ -26,10 +26,8 @@ on the CPU, e.g., to use NumPy CPU custom kernels. Then we
show uses of :func:`id_tap` and :func:`id_print`, which have the restriction
that they cannot return values from the host to the device.
These primitives are generally faster
because they are executed asynchronously with the device code and they also
support the whole spectrum of JAX transformations. In particular, they can be
used to tap into and to debug JAX-transformed code.
because they are executed asynchronously with the device code.
In particular, they can be used to tap into and to debug JAX code.
Using :func:`call` to call a host function and return results to device
-----------------------------------------------------------------------
@ -61,19 +59,21 @@ using a host computation::
The :func:`call` function and the Python host function both take a single argument
and return a single result, but those can be pytrees. Note that we must tell
the :func:`call` what shape and dtype to expect from the host invocation, using
the ``result_shape`` kwarg.
the ``result_shape`` keyword argument.
This is important because the device code is compiled with that expectation.
There will be an error raised at runtime if the actual invocation produces a
different result shape. In general, **such errors and also exceptions raised
by the host computation may be difficult to debug**. See the Debugging section
below.
This is a problem for :func:`call` but not for :func:`id_tap`.
This is a problem for :func:`call` but not for :func:`id_tap` because for the
latter the decice code does not expect a returned value.
The :func:`call` API can be used inside a jit or pmap computation or inside
cond/scan/while control flow. When used inside :func:`jax.pmap`, there will be
separate calls to the host from each of the participating devices::
def host_sin(x, *, device):
# The ``device`` argument is passed due to ``call_with_device=True`` below.
print(f"Invoking host_sin with {x.shape} on {device}")
return np.sin(x)
@ -88,12 +88,12 @@ separate calls to the host from each of the participating devices::
# Invoking host_sin with (4,) on cpu:0
# Invoking host_sin with (4,) on cpu:1
Note that :func:`call` does not (yet) support any JAX transformations, but as we
Note that :func:`call` does not support any JAX transformations, but as we
show below one can make use of the
existing support for `Custom differentiation in JAX <https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html>`_.
Using :func:`id_tap` to call a Python function on the host, with no returned values, but full JAX transformation support
---------------------------------------------------------------------------------------------------------------------------
Using :func:`id_tap` to call a Python function on the host, with no returned values
-----------------------------------------------------------------------------------
The :func:`id_tap` and :func:`id_print` are special cases of :func:`call`, when
you just want the side effects of your Python callback. These functions have
@ -104,7 +104,7 @@ For :func:`id_tap` you can specify your Python callback to be called, while
`stdout` on the host.
The Python function passed
to :func:`id_tap` takes two positional arguments (the value tapped
from the device computation along with ``transforms`` sequence,
from the device computation along with a ``transforms`` tuple,
described below). Optionally, the function may be passed a keyword argument
``device`` with the Device from which the value was tapped.
@ -129,9 +129,8 @@ A few examples::
id_tap(lambda tap, transforms: host_func(tap, what='data'), dict(x=x, y=y))
The above examples can all be adapted to use :func:`id_print` instead, with
the difference that :func:`id_print` takes one positional argument (to print
on the host), and possibly additional kwargs
that are also printed along with the automatic kwarg ``transforms``.
the difference that :func:`id_print` prints on the host the positional argument,
along with any additional kwargs and the automatic kwarg ``transforms``.
Using :func:`barrier_wait` to wait until all callbacks have executed
--------------------------------------------------------------------
@ -184,76 +183,8 @@ of the computation, if all the callbacks are :func:`call`::
res1.block_until_ready()
res2.block_until_ready()
Behavior under JAX transformations
----------------------------------
The :func:`call` does not support any JAX transformations. However, the
:func:`id_tap` and :func:`id_print` support all transformations. In this
context, it is important that both these functions behave like the identity
function::
# calls func((2x, 3x), []) and returns (2x, 3x)
id_tap(func, (2 * x, 3 * x))
# calls func(2x, []) and returns y
y = id_tap(func, 2 * x, result=y) # override the result of id_tap
We describe the behaviour under transformations for :func:`id_tap` and
:func:`id_print` in the context of the
following function definition::
def power3(x):
y = x * x
# Print both 'x' and 'x^2'
_, y = id_print((x, y), what="x,x^2") # Must pack multiple arguments
return y * x
power3(3.)
# what: x,x^2 : (3., 9.)
(You can see these examples tested in `host_callback_test.HostCallbackIdTapTest.test_tap_transforms`.)
During JAX transformations the special parameter ``transforms`` is added to
contain a list of transformation descriptors in the form
``(transform_name, transform_params)``.
For :func:`jax.vmap` the arguments are batched, and ``transforms`` is extended
with transformation name ``batch`` and ``batch_dims`` set to the tuple of
batched dimensions (one entry per argument, ``None`` denotes an argument that
was broadcast)::
jax.vmap(power3)(np.arange(3.))
# transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : ([0, 1, 2], [0, 1,
4])
For :func:`jax.jvp` there will be one callback with a pair, consisting of
the values of the primals and those of the tangents::
jax.jvp(power3, (3.,), (0.1,))
# transforms: ['jvp'] what: x,x^2 : ( (3., 9.), (0.1, 0.6) )
For :func:`jax.vjp` or :func:`jax.grad` there will be one callback with the
values of the adjoints for the arguments. You may also see a callback with
the values of the primals from the forward pass, if those values are needed for
the backward pass::
jax.grad(power3)(3.)
# what=x,x^2: (3., 9.) # from forward pass, since y is used in backward pass
# transforms: ['jvp', 'transpose'] what: x,x^2 : (0., 3.) # from backward pass, adjoints of _, y
And here is an example of composed transforms. For vmap of grad, we see first
a callback with the vmap of the forward pass (with just the 'batch' transform),
and another callback with the vmap of the adjoints of the arguments. Note that
the first argument is replicated (`batch_dims` is None)::
jax.vmap(jax.grad(power3))(np.array([2., 3.]))
# transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
# ( [2. 3.]
# [4. 9.] )
# transforms: ['jvp', 'transpose', ('batch', {'batch_dims': (None, 0)})] what: x,x^2
# ( 0.
# [2. 3.] )
Behavior under parallelization transformations
----------------------------------------------
In presence of :func:`jax.pmap` the code will run on multiple devices and
each device will tap its values independently.
@ -294,7 +225,112 @@ the operand collected
from all participating devices on all hosts. For a :func:`call`, the callback
must return the entire array for all devices on all hosts.
Behavior under JAX autodiff transformations
-------------------------------------------
When used under a JAX autodiff transformation, the host callback functions
operate on the primal values only. Consider the following example:
def power3(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2")
return y * x
power3(3.)
# what: x,x^2 : (3., 9.)
(You can see these examples tested in `host_callback_test.HostCallbackTapTest.test_tap_transforms`.)
When used under :func:`jax.jvp` there will be one callback with the primal
values only::
jax.jvp(power3, (3.,), (0.1,))
# what: x,x^2 : (3., 9.)
Similarly for :func:`jax.grad`, we get a callback from the forward computation
only::
jax.grad(power3)(3.)
# what: x,x^2 : (3., 9.)
If you want to invoke the callback on the tangents during a :func:`jax.jvp`,
you can use a custom_jvp. For example, you can define a function that does
nothing interesting except that its custom_jvp will print the tangents::
@jax.custom_jvp
def print_tangents(arg):
return None
@print_tangents.defjvp
def print_tangents_jvp(primals, tangents):
arg_dot, = tangents
hcb.id_print(arg_dot, what="tangents")
return primals, tangents
Then you use this function in the places where you want to tap the tangents::
def power3_with_tangents(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2")
print_tangents((x, y))
return y * x
jax.jvp(power3_with_tangents, (3.,), (0.1,))
# what: x,x^2 : (3., 9.)
# what: tangents : (0.1, 0.6)
You can do a similar thing for the cotangents during :func:`jax.grad`. This
time you must be careful to use in the rest of the computation the values whose
cotangents you want to tap. Hence we make the ``print_cotangents`` return
its argument::
@jax.custom_vjp
def print_cotangents(arg):
# Must return the argument for which we want the cotangent.
return arg
# f_fwd: a -> (b, residual)
def print_cotangents_fwd(arg):
return print_cotangents(arg), None
# f_bwd: (residual, CT b) -> [CT a]
def print_cotangents_bwd(residual, ct_b):
hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream)
return ct_b,
print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd)
def power3_with_cotangents(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
(x1, y1) = print_cotangents((x, y))
# Must use the output of print_cotangents
return y1 * x1
jax.grad(power3_with_cotangents)(3.)
# what: x,x^2 : (3., 9.)
# what: cotangents : (9., 3.)
Behavior under jax.vmap
-----------------------
The host callback functions :func:`id_print` and :func:`id_tap` support the
vectorization transformation :func:`jax.vmap`.
For :func:`jax.vmap` the arguments to the callback are batched,
and the callback function is
passed an additional special ``transforms`` containing a list of transformation descriptors
in the form ``("batch", {"batch_dims": ...})``, where ``...``` denotes the
batched dimensions for the tapped values (one entry per argument, `
`None`` denotes an argument that was broadcast).
jax.vmap(power3)(np.array([2., 3.]))
# transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : ([2., 3.], [4., 9.])
See documentation for :func:`id_tap`, :func:`id_print`, and :func:`call`.
For more usage example, see tests/host_callback_test.py.
Using :func:`call` to call a TensorFlow function, with reverse-mode autodiff support
@ -444,7 +480,10 @@ import functools
import itertools
import threading
import traceback
from typing import (Any, Callable, Dict, List, Optional, Sequence, Tuple, cast)
from typing import (Any, Callable, Dict, List, Optional, Sequence,
Tuple, cast)
import warnings
from absl import logging
from jax._src import api
@ -533,11 +572,16 @@ def id_tap(tap_func, arg, *, result=None, tap_with_device=False, **kwargs):
"pre-apply keyword arguments, either by using a closure or by passing "
"``functools.partial(tap_func, **kwargs)``.")
raise TypeError(msg)
if FLAGS.jax_host_callback_ad_transforms:
warnings.warn('The flag jax_host_callback_ad_transforms is for temporary '
'backwards compatibility mode. This flag, and the behavior '
'it enabled will be removed soon.',
FutureWarning)
if result is not None:
flat_results, result_treedef = pytree.flatten(result)
for result in flat_results:
api._check_arg(result)
for r in flat_results:
api._check_arg(r)
call_res = _call(tap_func, arg, call_with_device=tap_with_device,
result_shape=None, identity=True)
@ -545,13 +589,16 @@ def id_tap(tap_func, arg, *, result=None, tap_with_device=False, **kwargs):
if result is not None:
# Return the results, but add a dependency on the call, to ensure it
# is kept in the graph.
call_flat_results, _ = pytree.flatten(call_res)
if call_flat_results:
call_flat_results = [id_tap_dep_p.bind(r, call_flat_results[0])
for r in flat_results]
if FLAGS.jax_host_callback_ad_transforms:
call_flat_results, _ = pytree.flatten(call_res)
if call_flat_results:
call_flat_results = [id_tap_dep_p.bind(r, call_flat_results[0])
for r in flat_results]
else:
call_flat_results = flat_results
return result_treedef.unflatten(call_flat_results)
else:
call_flat_results = flat_results
return result_treedef.unflatten(call_flat_results)
return result
else:
return call_res
@ -768,6 +815,8 @@ xla.register_translation(id_tap_dep_p,
id_tap_dep_p.def_abstract_eval(lambda r_a, _: r_a)
def _id_tap_dep_jvp_rule(primals, tangents):
if FLAGS.jax_host_callback_ad_transforms:
assert False
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
return (id_tap_dep_p.bind(primals[0], primals[1]),
id_tap_dep_p.bind(tangents_instantiated[0], tangents_instantiated[1]))
@ -775,6 +824,8 @@ def _id_tap_dep_jvp_rule(primals, tangents):
ad.primitive_jvps[id_tap_dep_p] = _id_tap_dep_jvp_rule
def _id_tap_dep_transpose_rule(cts, arg_res, arg_tap):
if FLAGS.jax_host_callback_ad_transforms:
assert False
if ad.is_undefined_primal(arg_res):
ct_res = _instantiate_zeros(cts, arg_res)
else:
@ -789,6 +840,8 @@ ad.primitive_transposes[id_tap_dep_p] = _id_tap_dep_transpose_rule
def _id_tap_dep_batching_rule(batched_args, batch_dims):
if FLAGS.jax_host_callback_ad_transforms:
assert False
arg_res, arg_tap = batched_args
return id_tap_dep_p.bind(arg_res, arg_tap), batch_dims[0]
@ -797,6 +850,8 @@ batching.primitive_batchers[id_tap_dep_p] = _id_tap_dep_batching_rule
def _id_tap_dep_masking_rule(operands, operands_logical_shapes):
if FLAGS.jax_host_callback_ad_transforms:
assert False
arg_res, arg_tap = operands
return id_tap_dep_p.bind(arg_res, arg_tap)
@ -1167,20 +1222,24 @@ def _outside_call_jvp_rule(primals, tangents, **params):
assert "has_token" not in params
if not params["identity"]:
raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
if FLAGS.jax_host_callback_ad_transforms:
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
arg_treedef = params["arg_treedef"]
# The argument to the jvp tap is a pair of the tapped primals and tangents
jvp_flat_args, jvp_arg_treedef = api.tree_flatten(
(arg_treedef.unflatten(primals),
arg_treedef.unflatten(tangents_instantiated)))
out_all = outside_call_p.bind(
*jvp_flat_args,
**dict(_add_transform(params, "jvp"),
arg_treedef=jvp_arg_treedef,
))
out_primals_tapped, out_tangents_tapped = util.split_list(out_all, [len(primals)])
return tuple(out_primals_tapped), tuple(out_tangents_tapped)
arg_treedef = params["arg_treedef"]
# The argument to the jvp tap is a pair of the tapped primals and tangents
jvp_flat_args, jvp_arg_treedef = api.tree_flatten(
(arg_treedef.unflatten(primals),
arg_treedef.unflatten(tangents_instantiated)))
out_all = outside_call_p.bind(
*jvp_flat_args,
**dict(_add_transform(params, "jvp"),
arg_treedef=jvp_arg_treedef,
))
out_primals_tapped, out_tangents_tapped = util.split_list(out_all, [len(primals)])
return tuple(out_primals_tapped), tuple(out_tangents_tapped)
else:
out_primals_tapped = outside_call_p.bind(*primals, **params)
return tuple(out_primals_tapped), tangents
ad.primitive_jvps[outside_call_p] = _outside_call_jvp_rule
@ -1188,6 +1247,9 @@ ad.primitive_jvps[outside_call_p] = _outside_call_jvp_rule
def _outside_call_partial_eval_rule(trace, *args, **params):
# partial eval is used after jvp and before transpose.
if not FLAGS.jax_host_callback_ad_transforms:
# TODO: just remote the partial eval rule
return trace.default_process_primitive(outside_call_p, args, params)
transforms = params.get("transforms", ())
if not transforms or transforms[-1] != ("jvp",):
# We are not in the process of computing VJP
@ -1252,6 +1314,9 @@ def _outside_call_transpose_rule(cts, *args, **params):
*cts_instantiated,
**_add_transform(params, "transpose"))
if not FLAGS.jax_host_callback_ad_transforms:
assert False
assert len(args) % 2 == 0
nr_primals = len(args) // 2

View File

@ -904,11 +904,18 @@ class HostCallbackTapTest(jtu.JaxTestCase):
self.assertAllClose(100., res_primals, check_dtypes=False)
self.assertAllClose(4., res_tangents, check_dtypes=False)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
transforms: ['jvp'] what: a * 2
( 10.00 0.20 )
transforms: ['jvp'] what: y * 3
( 30.00 0.60 )""", testing_stream.output)
if FLAGS.jax_host_callback_ad_transforms:
assertMultiLineStrippedEqual(self, """
transforms: ['jvp'] what: a * 2
( 10.00 0.20 )
transforms: ['jvp'] what: y * 3
( 30.00 0.60 )""", testing_stream.output)
else:
assertMultiLineStrippedEqual(self, """
what: a * 2
10.00
what: y * 3
30.00""", testing_stream.output)
def test_tap_grad_primal_unused(self):
# The output of id_print is not needed for backwards pass
@ -923,25 +930,39 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait()
treedef = tree_util.tree_structure(arg)
assertMultiLineStrippedEqual(self, f"""
{{ lambda ; a:f32[]. let
b:f32[] = mul a 3.00
c:f32[] = outside_call[
arg_treedef={treedef}
callback=...
identity=True
transforms=()
] b
_:f32[] = mul c 2.00
d:f32[] = mul 1.00 2.00
e:f32[] = outside_call[
arg_treedef={treedef}
callback=...
identity=True
transforms=(('jvp',), ('transpose',))
] d
f:f32[] = mul e 3.00
in (f,) }}""", jaxpr)
if FLAGS.jax_host_callback_ad_transforms:
assertMultiLineStrippedEqual(self, f"""
{{ lambda ; a:f32[]. let
b:f32[] = mul a 3.00
c:f32[] = outside_call[
arg_treedef={treedef}
callback=...
identity=True
transforms=()
] b
_:f32[] = mul c 2.00
d:f32[] = mul 1.00 2.00
e:f32[] = outside_call[
arg_treedef={treedef}
callback=...
identity=True
transforms=(('jvp',), ('transpose',))
] d
f:f32[] = mul e 3.00
in (f,) }}""", jaxpr)
else:
assertMultiLineStrippedEqual(self, f"""
{{ lambda ; a:f32[]. let
b:f32[] = mul a 3.00
c:f32[] = outside_call[
arg_treedef={treedef}
callback=...
identity=True
] b
_:f32[] = mul c 2.00
d:f32[] = mul 1.00 2.00
e:f32[] = mul d 3.00
in (e,) }}""", jaxpr)
assertMultiLineStrippedEqual(self, "", testing_stream.output)
testing_stream.reset()
@ -949,11 +970,16 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.barrier_wait()
self.assertAllClose(6., res_grad, check_dtypes=False)
assertMultiLineStrippedEqual(self, """
what: x * 3
15.00
transforms: ['jvp', 'transpose'] what: x * 3
2.00""", testing_stream.output)
if FLAGS.jax_host_callback_ad_transforms:
assertMultiLineStrippedEqual(self, """
what: x * 3
15.00
transforms: ['jvp', 'transpose'] what: x * 3
2.00""", testing_stream.output)
else:
assertMultiLineStrippedEqual(self, """
what: x * 3
15.00""", testing_stream.output)
def test_tap_grad_simple(self):
def func(x):
@ -966,15 +992,22 @@ class HostCallbackTapTest(jtu.JaxTestCase):
res_grad = grad_func(jnp.float32(5.))
self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
what: x * 2
10.00
what: y * 3
30.00
transforms: ['jvp', 'transpose'] what: y * 3
5.00
transforms: ['jvp', 'transpose'] what: x * 2
15.00""", testing_stream.output)
if FLAGS.jax_host_callback_ad_transforms:
assertMultiLineStrippedEqual(self, """
what: x * 2
10.00
what: y * 3
30.00
transforms: ['jvp', 'transpose'] what: y * 3
5.00
transforms: ['jvp', 'transpose'] what: x * 2
15.00""", testing_stream.output)
else:
assertMultiLineStrippedEqual(self, """
what: x * 2
10.00
what: y * 3
30.00""", testing_stream.output)
def test_tap_grad_grad(self):
def func(x):
@ -991,15 +1024,20 @@ class HostCallbackTapTest(jtu.JaxTestCase):
self.assertAllClose(12., res_grad, check_dtypes=False)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
what: x * 2
10.00
transforms: ['jvp', 'transpose'] what: x * 2
15.00
transforms: ['jvp', 'transpose', 'jvp', 'transpose'] what: x * 2
2.00
transforms: ['jvp', 'transpose'] what: x * 2
3.00""", testing_stream.output)
if FLAGS.jax_host_callback_ad_transforms:
assertMultiLineStrippedEqual(self, """
what: x * 2
10.00
transforms: ['jvp', 'transpose'] what: x * 2
15.00
transforms: ['jvp', 'transpose', 'jvp', 'transpose'] what: x * 2
2.00
transforms: ['jvp', 'transpose'] what: x * 2
3.00""", testing_stream.output)
else:
assertMultiLineStrippedEqual(self, """
what: x * 2
10.00""", testing_stream.output)
def test_tap_grad_pytree(self):
def func(x):
@ -1014,11 +1052,16 @@ class HostCallbackTapTest(jtu.JaxTestCase):
res_grad = grad_func(x)
self.assertAllClose(14., res_grad, check_dtypes=False)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
what: pair
( 10.00 15.00 )
transforms: ['jvp', 'transpose'] what: pair
( 0.00 0.00 )""", testing_stream.output)
if FLAGS.jax_host_callback_ad_transforms:
assertMultiLineStrippedEqual(self, """
what: pair
( 10.00 15.00 )
transforms: ['jvp', 'transpose'] what: pair
( 0.00 0.00 )""", testing_stream.output)
else:
assertMultiLineStrippedEqual(self, """
what: pair
( 10.00 15.00 )""", testing_stream.output)
def test_tap_jvp_float0(self):
def f(x, yint):
@ -1038,11 +1081,16 @@ class HostCallbackTapTest(jtu.JaxTestCase):
res_grad = grad_func(jnp.float32(5.), jnp.int32(2))
self.assertAllClose(2., res_grad, check_dtypes=False)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
what: pair
( 5.00 2 )
transforms: ['jvp', 'transpose'] what: pair
( 2.00 False )""", testing_stream.output)
if FLAGS.jax_host_callback_ad_transforms:
assertMultiLineStrippedEqual(self, """
what: pair
( 5.00 2 )
transforms: ['jvp', 'transpose'] what: pair
( 2.00 False )""", testing_stream.output)
else:
assertMultiLineStrippedEqual(self, """
what: pair
( 5.00 2 )""", testing_stream.output)
def test_tap_grad_float0_result(self):
# https://github.com/google/jax/issues/7340
@ -1063,10 +1111,14 @@ class HostCallbackTapTest(jtu.JaxTestCase):
self.assertAllClose(np.array([3., 3.], dtype=np.float32), g[0])
self.assertEqual(dtypes.float0, g[1].dtype)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
( [0.70 0.80] [11 12 13] )
transforms: ['jvp', 'transpose']
( [0.00 0.00] [False False False] )""", testing_stream.output)
if FLAGS.jax_host_callback_ad_transforms:
assertMultiLineStrippedEqual(self, """
( [0.70 0.80] [11 12 13] )
transforms: ['jvp', 'transpose']
( [0.00 0.00] [False False False] )""", testing_stream.output)
else:
assertMultiLineStrippedEqual(self, """
( [0.70 0.80] [11 12 13] )""", testing_stream.output)
def test_tap_higher_order_grad_float0_result(self):
# https://github.com/google/jax/issues/7340
@ -1101,10 +1153,14 @@ class HostCallbackTapTest(jtu.JaxTestCase):
f_jax_vjp1, args_vjp1 = wrap_vjp(f_jax, (x,), res)
res_vjp1 = f_jax_vjp1(*args_vjp1)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
( [0.70 0.80] [11 12 13] )
transforms: ['jvp', 'transpose']
( [0.00 0.00] [False False False] )""", testing_stream.output)
if FLAGS.jax_host_callback_ad_transforms:
assertMultiLineStrippedEqual(self, """
( [0.70 0.80] [11 12 13] )
transforms: ['jvp', 'transpose']
( [0.00 0.00] [False False False] )""", testing_stream.output)
else:
assertMultiLineStrippedEqual(self, """
( [0.70 0.80] [11 12 13] )""", testing_stream.output)
testing_stream.reset()
# 2nd order
@ -1227,8 +1283,11 @@ class HostCallbackTapTest(jtu.JaxTestCase):
transforms: [('batch', {'batch_dims': (0,)})] where: 3
[2 2 2 3 4]""", testing_stream.output)
def test_tap_transforms(self):
def test_tap_transforms_old_doc(self):
if not FLAGS.jax_host_callback_ad_transforms:
raise unittest.SkipTest("disabled for new behavior")
# Examples from the documentation
def power3(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
@ -1278,6 +1337,127 @@ class HostCallbackTapTest(jtu.JaxTestCase):
( 0. [2. 3.] )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
def test_tap_transforms_doc(self):
# Examples from the documentation
if FLAGS.jax_host_callback_ad_transforms:
raise unittest.SkipTest("disabled for old behavior")
def power3(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
return y * x
print(f"impl = {power3(3.)}")
hcb.barrier_wait()
expected = """
what: x,x^2
( 3. 9. )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
print(f"jvp = {jax.jvp(power3, (3.,), (0.1,))}")
hcb.barrier_wait()
expected = """
what: x,x^2
( 3. 9. )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
@jax.custom_jvp
def print_tangents(arg):
return None
@print_tangents.defjvp
def print_tangents_jvp(primals, tangents):
arg_dot, = tangents
hcb.id_print(arg_dot, what="tangents", output_stream=testing_stream)
return primals, tangents
def power3_with_tangents(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
print_tangents((x, y))
return y * x
print(f"jvp = {jax.jvp(power3_with_tangents, (3.,), (0.1,))}")
hcb.barrier_wait()
expected = """
what: x,x^2
( 3. 9. )
what: tangents
( 0.1 0.6 )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
print(f"grad = {jax.grad(power3)(3.)}")
hcb.barrier_wait()
# Only the primals by default
expected = """
what: x,x^2
( 3. 9. )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
@jax.custom_vjp
def print_cotangents(arg):
# Must return the argument for which we want the cotangent.
return arg
# f_fwd: a -> (b, residual)
def print_cotangents_fwd(arg):
return print_cotangents(arg), None
# f_bwd: (residual, CT b) -> [CT a]
def print_cotangents_bwd(residual, ct_b):
hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream)
return ct_b,
print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd)
def power3_with_cotangents(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
# Must use the output of print_cotangents
(x1, y1) = print_cotangents((x, y))
return y1 * x1
print(f"grad = {jax.grad(power3_with_cotangents)(3.)}")
hcb.barrier_wait()
expected = """
what: x,x^2
( 3. 9. )
what: cotangents
( 9. 3. )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
# TODO: grad of grad
print(f"vmap = {jax.vmap(power3)(np.array([2., 3.]))}")
hcb.barrier_wait()
expected = """
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [2. 3.] [4. 9.] )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
print(f"vmap o grad {jax.vmap(jax.grad(power3))(np.array([2., 3.]))}")
hcb.barrier_wait()
expected = """
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [2. 3.] [4. 9.] )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
print(f"vmap o grad {jax.vmap(jax.grad(power3_with_cotangents))(np.array([2., 3.]))}")
hcb.barrier_wait()
expected = """
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [2. 3.] [4. 9.] )
transforms: [('batch', {'batch_dims': (0, 0)})] what: cotangents
( [4. 9.] [2. 3.] )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
def test_tap_pmap(self):
if len(local_devices()) < 2:
@ -1416,15 +1596,24 @@ class HostCallbackTapTest(jtu.JaxTestCase):
self.assertAllClose(expected_res, res, check_dtypes=False)
# Assertion text is for 2 devices (also works for 1 device)
# Device 0 will get to execute jax.jvp(jax.vmap(...)) for matrix[0, :, :]
assertMultiDeviceOutputEqual(self, """
device: cpu:0 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2
( [[ 0.00 2.00 4.00]
[20.00 22.00 24.00]] [[0.20 0.20 0.20]
[0.20 0.20 0.20]] )
device: cpu:1 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2
( [[200.00 202.00 204.00]
[220.00 222.00 224.00]] [[0.20 0.20 0.20]
[0.20 0.20 0.20]] )""")
if FLAGS.jax_host_callback_ad_transforms:
assertMultiDeviceOutputEqual(self, """
device: cpu:0 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2
( [[ 0.00 2.00 4.00]
[20.00 22.00 24.00]] [[0.20 0.20 0.20]
[0.20 0.20 0.20]] )
device: cpu:1 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2
( [[200.00 202.00 204.00]
[220.00 222.00 224.00]] [[0.20 0.20 0.20]
[0.20 0.20 0.20]] )""")
else:
assertMultiDeviceOutputEqual(self, """
device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
[[ 0.00 2.00 4.00]
[20.00 22.00 24.00]]
device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
[[200.00 202.00 204.00]
[220.00 222.00 224.00]]""")
def test_tap_vmap_pmap(self):
# A matrix M[ijk] = i * 100 + j * 10 * k
@ -1565,7 +1754,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
[[ 3 3 3 3]
[33 33 33 33]]""")
def test_tap_tap_scan_custom_jvp(self):
def test_tap_scan_custom_jvp(self):
"""custom JVP, inside scan.
This exercises the custom_jvp_call_jaxpr primitives."""
@ -1696,11 +1885,17 @@ class HostCallbackTapTest(jtu.JaxTestCase):
jax.jvp(lambda arg: padded_sum([arg], dict(n=3)),
(x,), (x * 0.1,)))
hcb.barrier_wait()
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5}), 'jvp'] what: x
( ( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )
( ( [0. 0.1 0.2 0.3 0.4] [0. 0.2 0.4 0.6 0.8] ) ( ( False ) ( False ) ) ) )""",
testing_stream.output)
if FLAGS.jax_host_callback_ad_transforms:
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5}), 'jvp'] what: x
( ( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )
( ( [0. 0.1 0.2 0.3 0.4] [0. 0.2 0.4 0.6 0.8] ) ( ( False ) ( False ) ) ) )""",
testing_stream.output)
else:
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5})] what: x
( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""",
testing_stream.output)
testing_stream.reset()
# Now with JIT
@ -1829,40 +2024,53 @@ class HostCallbackTapTest(jtu.JaxTestCase):
use_result=use_result, use_remat=use_remat, grad_func=grad_func)
for use_result in [True, False]
for grad_func in ["grad", "value_and_grad"]
for use_remat in [True, False]))
def test_tap_remat(self, use_result=False, grad_func="grad", use_remat=False):
for use_remat in ["old", "none"]))
def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="old"):
def f(x):
id_print_result = hcb.id_print(x, output_stream=testing_stream)
if use_result:
x = id_print_result
return 3. * x
grad_f = jax.grad if grad_func == "grad" else jax.value_and_grad
trans_f = jax.remat(f) if use_remat else f
if use_remat == "old":
trans_f = jax.remat(f)
else:
assert use_remat == "none"
trans_f = f
print(jax.make_jaxpr(grad_f(trans_f))(2.))
grad_f(trans_f)(2.)
hcb.barrier_wait()
if not use_result:
if use_remat:
expected = ""
if use_remat == "none":
if use_result:
if FLAGS.jax_host_callback_ad_transforms:
expected = """
2.
transforms: ['jvp', 'transpose']
3."""
else:
# GOOD: whether or not we use_result, in absence of
# jax_host_callback_ad_transforms we get the same callback.
expected = "2."
else:
# TODO: if not use_result then we should only see the primal when
# computing value_and_grad.
expected = "2."
elif use_remat:
expected = """
2.
2.
transforms: ['jvp', 'transpose']
3."""
else:
expected = """
2.
transforms: ['jvp', 'transpose']
3."""
else: # use_remat
if use_result:
if FLAGS.jax_host_callback_ad_transforms:
expected = """
2.
2.
transforms: ['jvp', 'transpose']
3."""
else:
expected = """
2.
2."""
else:
# TODO: we should see two callbacks
expected = ""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
def test_tap_named_call(self):
def tap_scalar(init, do_print=False):
@ -1884,8 +2092,6 @@ class HostCallbackTapTest(jtu.JaxTestCase):
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
class HostCallbackCallTest(jtu.JaxTestCase):
"""Tests for hcb.call"""