diff --git a/CHANGELOG.md b/CHANGELOG.md index e1cb45a18..b88a48224 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,15 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.26...main). +* Breaking changes: + * The host_callback primitives have been simplified to drop the + special autodiff handling for hcb.id_tap and id_print. + From now on, only the primals are tapped. The old behavior can be + obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS`` + environment variable, or the ``--jax_host_callback_ad_transforms`` flag. + Additionally, added documentation for how to implement the old behavior + using JAX custom AD APIs ({jax-issue}`#7839`). + ## jaxlib 0.1.76 (Unreleased) ## jaxlib 0.1.75 (Dec 8, 2021) diff --git a/jax/_src/config.py b/jax/_src/config.py index 62f7353d2..db2452ada 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -463,6 +463,15 @@ flags.DEFINE_bool( 'Has no effect on TPU, since only the outfeed mechanism is implemented.' ) ) +flags.DEFINE_bool( + 'jax_host_callback_ad_transforms', + bool_env('JAX_HOST_CALLBACK_AD_TRANSFORMS', False), + help=( + 'Enable support for jvp/vjp for the host_callback primitives. Default is ' + 'False, which means that host_callback operates only on primals. ' + 'The flag exists only temporarily, for backward compatibility.' 
+ ) +) enable_checks = config.define_bool_state( name='jax_enable_checks', diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index a975a5329..30c096ba4 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1009,6 +1009,9 @@ class JaxTestCase(parameterized.TestCase): ignore_space_re = re.compile(r'\s*\n\s*') expected_clean = re.sub(ignore_space_re, '\n', expected.strip()) what_clean = re.sub(ignore_space_re, '\n', what.strip()) + if what_clean != expected_clean: + # Print it so we can copy-and-paste it into the test + print(f"Found\n{what}\n") self.assertMultiLineEqual(expected_clean, what_clean, msg="Found\n{}\nExpecting\n{}".format(what, expected)) diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index f5d8c92db..f3499863c 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -26,10 +26,8 @@ on the CPU, e.g., to use NumPy CPU custom kernels. Then we show uses of :func:`id_tap` and :func:`id_print`, which have the restriction that they cannot return values from the host to the device. These primitives are generally faster -because they are executed asynchronously with the device code and they also -support the whole spectrum of JAX transformations. In particular, they can be -used to tap into and to debug JAX-transformed code. - +because they are executed asynchronously with the device code. +In particular, they can be used to tap into and to debug JAX code. Using :func:`call` to call a host function and return results to device ----------------------------------------------------------------------- @@ -61,19 +59,21 @@ using a host computation:: The :func:`call` function and the Python host function both take a single argument and return a single result, but those can be pytrees. Note that we must tell the :func:`call` what shape and dtype to expect from the host invocation, using -the ``result_shape`` kwarg. +the ``result_shape`` keyword argument. 
This is important because the device code is compiled with that expectation. There will be an error raised at runtime if the actual invocation produces a different result shape. In general, **such errors and also exceptions raised by the host computation may be difficult to debug**. See the Debugging section below. -This is a problem for :func:`call` but not for :func:`id_tap`. +This is a problem for :func:`call` but not for :func:`id_tap` because for the +latter the device code does not expect a returned value. The :func:`call` API can be used inside a jit or pmap computation or inside cond/scan/while control flow. When used inside :func:`jax.pmap`, there will be separate calls to the host from each of the participating devices:: def host_sin(x, *, device): + # The ``device`` argument is passed due to ``call_with_device=True`` below. print(f"Invoking host_sin with {x.shape} on {device}") return np.sin(x) @@ -88,12 +88,12 @@ separate calls to the host from each of the participating devices:: # Invoking host_sin with (4,) on cpu:0 # Invoking host_sin with (4,) on cpu:1 -Note that :func:`call` does not (yet) support any JAX transformations, but as we +Note that :func:`call` does not support any JAX transformations, but as we show below one can make use of the existing support for `Custom differentiation in JAX `_. -Using :func:`id_tap` to call a Python function on the host, with no returned values, but full JAX transformation support ---------------------------------------------------------------------------------------------------------------------------- +Using :func:`id_tap` to call a Python function on the host, with no returned values +----------------------------------------------------------------------------------- The :func:`id_tap` and :func:`id_print` are special cases of :func:`call`, when you just want the side effects of your Python callback. 
These functions have @@ -104,7 +104,7 @@ For :func:`id_tap` you can specify your Python callback to be called, while `stdout` on the host. The Python function passed to :func:`id_tap` takes two positional arguments (the value tapped -from the device computation along with ``transforms`` sequence, +from the device computation along with a ``transforms`` tuple, described below). Optionally, the function may be passed a keyword argument ``device`` with the Device from which the value was tapped. @@ -129,9 +129,8 @@ A few examples:: id_tap(lambda tap, transforms: host_func(tap, what='data'), dict(x=x, y=y)) The above examples can all be adapted to use :func:`id_print` instead, with -the difference that :func:`id_print` takes one positional argument (to print -on the host), and possibly additional kwargs -that are also printed along with the automatic kwarg ``transforms``. +the difference that :func:`id_print` prints on the host the positional argument, +along with any additional kwargs and the automatic kwarg ``transforms``. Using :func:`barrier_wait` to wait until all callbacks have executed -------------------------------------------------------------------- @@ -184,76 +183,8 @@ of the computation, if all the callbacks are :func:`call`:: res1.block_until_ready() res2.block_until_ready() - -Behavior under JAX transformations ----------------------------------- - -The :func:`call` does not support any JAX transformations. However, the -:func:`id_tap` and :func:`id_print` support all transformations. 
In this -context, it is important that both these functions behave like the identity -function:: - - # calls func((2x, 3x), []) and returns (2x, 3x) - id_tap(func, (2 * x, 3 * x)) - - # calls func(2x, []) and returns y - y = id_tap(func, 2 * x, result=y) # override the result of id_tap - -We describe the behaviour under transformations for :func:`id_tap` and -:func:`id_print` in the context of the -following function definition:: - - def power3(x): - y = x * x - # Print both 'x' and 'x^2' - _, y = id_print((x, y), what="x,x^2") # Must pack multiple arguments - return y * x - - power3(3.) - # what: x,x^2 : (3., 9.) - -(You can see these examples tested in `host_callback_test.HostCallbackIdTapTest.test_tap_transforms`.) - -During JAX transformations the special parameter ``transforms`` is added to -contain a list of transformation descriptors in the form -``(transform_name, transform_params)``. - -For :func:`jax.vmap` the arguments are batched, and ``transforms`` is extended -with transformation name ``batch`` and ``batch_dims`` set to the tuple of -batched dimensions (one entry per argument, ``None`` denotes an argument that -was broadcast):: - - jax.vmap(power3)(np.arange(3.)) - # transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : ([0, 1, 2], [0, 1, - 4]) - -For :func:`jax.jvp` there will be one callback with a pair, consisting of -the values of the primals and those of the tangents:: - - jax.jvp(power3, (3.,), (0.1,)) - # transforms: ['jvp'] what: x,x^2 : ( (3., 9.), (0.1, 0.6) ) - -For :func:`jax.vjp` or :func:`jax.grad` there will be one callback with the -values of the adjoints for the arguments. You may also see a callback with -the values of the primals from the forward pass, if those values are needed for -the backward pass:: - - jax.grad(power3)(3.) - # what=x,x^2: (3., 9.) # from forward pass, since y is used in backward pass - # transforms: ['jvp', 'transpose'] what: x,x^2 : (0., 3.) 
# from backward pass, adjoints of _, y - -And here is an example of composed transforms. For vmap of grad, we see first -a callback with the vmap of the forward pass (with just the 'batch' transform), -and another callback with the vmap of the adjoints of the arguments. Note that -the first argument is replicated (`batch_dims` is None):: - - jax.vmap(jax.grad(power3))(np.array([2., 3.])) - # transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 - # ( [2. 3.] - # [4. 9.] ) - # transforms: ['jvp', 'transpose', ('batch', {'batch_dims': (None, 0)})] what: x,x^2 - # ( 0. - # [2. 3.] ) +Behavior under parallelization transformations +---------------------------------------------- In presence of :func:`jax.pmap` the code will run on multiple devices and each device will tap its values independently. @@ -294,7 +225,112 @@ the operand collected from all participating devices on all hosts. For a :func:`call`, the callback must return the entire array for all devices on all hosts. +Behavior under JAX autodiff transformations +------------------------------------------- + +When used under a JAX autodiff transformation, the host callback functions +operate on the primal values only. Consider the following example:: + + def power3(x): + y = x * x + # Print both 'x' and 'x^2'. Must pack as a tuple. + hcb.id_print((x, y), what="x,x^2") + return y * x + + power3(3.) + # what: x,x^2 : (3., 9.) + +(You can see these examples tested in `host_callback_test.HostCallbackTapTest.test_tap_transforms_doc`.) + +When used under :func:`jax.jvp` there will be one callback with the primal +values only:: + + jax.jvp(power3, (3.,), (0.1,)) + # what: x,x^2 : (3., 9.) + +Similarly for :func:`jax.grad`, we get a callback from the forward computation +only:: + + jax.grad(power3)(3.) + # what: x,x^2 : (3., 9.) + +If you want to invoke the callback on the tangents during a :func:`jax.jvp`, +you can use a custom_jvp. 
For example, you can define a function that does +nothing interesting except that its custom_jvp will print the tangents:: + + @jax.custom_jvp + def print_tangents(arg): + return None + + @print_tangents.defjvp + def print_tangents_jvp(primals, tangents): + arg_dot, = tangents + hcb.id_print(arg_dot, what="tangents") + return primals, tangents + +Then you use this function in the places where you want to tap the tangents:: + + def power3_with_tangents(x): + y = x * x + # Print both 'x' and 'x^2'. Must pack as a tuple. + hcb.id_print((x, y), what="x,x^2") + print_tangents((x, y)) + return y * x + + jax.jvp(power3_with_tangents, (3.,), (0.1,)) + # what: x,x^2 : (3., 9.) + # what: tangents : (0.1, 0.6) + +You can do a similar thing for the cotangents during :func:`jax.grad`. This +time you must be careful to use in the rest of the computation the values whose +cotangents you want to tap. Hence we make the ``print_cotangents`` return +its argument:: + + @jax.custom_vjp + def print_cotangents(arg): + # Must return the argument for which we want the cotangent. + return arg + + # f_fwd: a -> (b, residual) + def print_cotangents_fwd(arg): + return print_cotangents(arg), None + # f_bwd: (residual, CT b) -> [CT a] + def print_cotangents_bwd(residual, ct_b): + hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream) + return ct_b, + + print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd) + + def power3_with_cotangents(x): + y = x * x + # Print both 'x' and 'x^2'. Must pack as a tuple. + hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream) + (x1, y1) = print_cotangents((x, y)) + # Must use the output of print_cotangents + return y1 * x1 + + jax.grad(power3_with_cotangents)(3.) + # what: x,x^2 : (3., 9.) + # what: cotangents : (9., 3.) + +Behavior under jax.vmap +----------------------- + +The host callback functions :func:`id_print` and :func:`id_tap` support the +vectorization transformation :func:`jax.vmap`. 
+ +For :func:`jax.vmap` the arguments to the callback are batched, +and the callback function is +passed an additional special ``transforms`` argument containing a list of transformation descriptors +in the form ``("batch", {"batch_dims": ...})``, where ``...`` denotes the +batched dimensions for the tapped values (one entry per argument, +``None`` denotes an argument that was broadcast):: + + jax.vmap(power3)(np.array([2., 3.])) + # transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : ([2., 3.], [4., 9.]) + See documentation for :func:`id_tap`, :func:`id_print`, and :func:`call`. + For more usage example, see tests/host_callback_test.py. Using :func:`call` to call a TensorFlow function, with reverse-mode autodiff support @@ -444,7 +480,10 @@ import functools import itertools import threading import traceback -from typing import (Any, Callable, Dict, List, Optional, Sequence, Tuple, cast) +from typing import (Any, Callable, Dict, List, Optional, Sequence, + Tuple, cast) +import warnings + from absl import logging from jax._src import api @@ -533,11 +572,16 @@ def id_tap(tap_func, arg, *, result=None, tap_with_device=False, **kwargs): "pre-apply keyword arguments, either by using a closure or by passing " "``functools.partial(tap_func, **kwargs)``.") raise TypeError(msg) + if FLAGS.jax_host_callback_ad_transforms: + warnings.warn('The flag jax_host_callback_ad_transforms is for temporary ' + 'backwards compatibility mode. 
This flag, and the behavior ' + 'it enabled will be removed soon.', + FutureWarning) if result is not None: flat_results, result_treedef = pytree.flatten(result) - for result in flat_results: - api._check_arg(result) + for r in flat_results: + api._check_arg(r) call_res = _call(tap_func, arg, call_with_device=tap_with_device, result_shape=None, identity=True) @@ -545,13 +589,16 @@ def id_tap(tap_func, arg, *, result=None, tap_with_device=False, **kwargs): if result is not None: # Return the results, but add a dependency on the call, to ensure it # is kept in the graph. - call_flat_results, _ = pytree.flatten(call_res) - if call_flat_results: - call_flat_results = [id_tap_dep_p.bind(r, call_flat_results[0]) - for r in flat_results] + if FLAGS.jax_host_callback_ad_transforms: + call_flat_results, _ = pytree.flatten(call_res) + if call_flat_results: + call_flat_results = [id_tap_dep_p.bind(r, call_flat_results[0]) + for r in flat_results] + else: + call_flat_results = flat_results + return result_treedef.unflatten(call_flat_results) else: - call_flat_results = flat_results - return result_treedef.unflatten(call_flat_results) + return result else: return call_res @@ -768,6 +815,8 @@ xla.register_translation(id_tap_dep_p, id_tap_dep_p.def_abstract_eval(lambda r_a, _: r_a) def _id_tap_dep_jvp_rule(primals, tangents): + if FLAGS.jax_host_callback_ad_transforms: + assert False tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals)) return (id_tap_dep_p.bind(primals[0], primals[1]), id_tap_dep_p.bind(tangents_instantiated[0], tangents_instantiated[1])) @@ -775,6 +824,8 @@ def _id_tap_dep_jvp_rule(primals, tangents): ad.primitive_jvps[id_tap_dep_p] = _id_tap_dep_jvp_rule def _id_tap_dep_transpose_rule(cts, arg_res, arg_tap): + if FLAGS.jax_host_callback_ad_transforms: + assert False if ad.is_undefined_primal(arg_res): ct_res = _instantiate_zeros(cts, arg_res) else: @@ -789,6 +840,8 @@ ad.primitive_transposes[id_tap_dep_p] = _id_tap_dep_transpose_rule def 
_id_tap_dep_batching_rule(batched_args, batch_dims): + if FLAGS.jax_host_callback_ad_transforms: + assert False arg_res, arg_tap = batched_args return id_tap_dep_p.bind(arg_res, arg_tap), batch_dims[0] @@ -797,6 +850,8 @@ batching.primitive_batchers[id_tap_dep_p] = _id_tap_dep_batching_rule def _id_tap_dep_masking_rule(operands, operands_logical_shapes): + if FLAGS.jax_host_callback_ad_transforms: + assert False arg_res, arg_tap = operands return id_tap_dep_p.bind(arg_res, arg_tap) @@ -1167,20 +1222,24 @@ def _outside_call_jvp_rule(primals, tangents, **params): assert "has_token" not in params if not params["identity"]: raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.") - tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals)) + if FLAGS.jax_host_callback_ad_transforms: + tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals)) - arg_treedef = params["arg_treedef"] - # The argument to the jvp tap is a pair of the tapped primals and tangents - jvp_flat_args, jvp_arg_treedef = api.tree_flatten( - (arg_treedef.unflatten(primals), - arg_treedef.unflatten(tangents_instantiated))) - out_all = outside_call_p.bind( - *jvp_flat_args, - **dict(_add_transform(params, "jvp"), - arg_treedef=jvp_arg_treedef, - )) - out_primals_tapped, out_tangents_tapped = util.split_list(out_all, [len(primals)]) - return tuple(out_primals_tapped), tuple(out_tangents_tapped) + arg_treedef = params["arg_treedef"] + # The argument to the jvp tap is a pair of the tapped primals and tangents + jvp_flat_args, jvp_arg_treedef = api.tree_flatten( + (arg_treedef.unflatten(primals), + arg_treedef.unflatten(tangents_instantiated))) + out_all = outside_call_p.bind( + *jvp_flat_args, + **dict(_add_transform(params, "jvp"), + arg_treedef=jvp_arg_treedef, + )) + out_primals_tapped, out_tangents_tapped = util.split_list(out_all, [len(primals)]) + return tuple(out_primals_tapped), tuple(out_tangents_tapped) + else: + out_primals_tapped = 
outside_call_p.bind(*primals, **params) + return tuple(out_primals_tapped), tangents ad.primitive_jvps[outside_call_p] = _outside_call_jvp_rule @@ -1188,6 +1247,9 @@ ad.primitive_jvps[outside_call_p] = _outside_call_jvp_rule def _outside_call_partial_eval_rule(trace, *args, **params): # partial eval is used after jvp and before transpose. + if not FLAGS.jax_host_callback_ad_transforms: + # TODO: just remote the partial eval rule + return trace.default_process_primitive(outside_call_p, args, params) transforms = params.get("transforms", ()) if not transforms or transforms[-1] != ("jvp",): # We are not in the process of computing VJP @@ -1252,6 +1314,9 @@ def _outside_call_transpose_rule(cts, *args, **params): *cts_instantiated, **_add_transform(params, "transpose")) + if not FLAGS.jax_host_callback_ad_transforms: + assert False + assert len(args) % 2 == 0 nr_primals = len(args) // 2 diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index ca31a7101..d8dac83b5 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -904,11 +904,18 @@ class HostCallbackTapTest(jtu.JaxTestCase): self.assertAllClose(100., res_primals, check_dtypes=False) self.assertAllClose(4., res_tangents, check_dtypes=False) hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - transforms: ['jvp'] what: a * 2 - ( 10.00 0.20 ) - transforms: ['jvp'] what: y * 3 - ( 30.00 0.60 )""", testing_stream.output) + if FLAGS.jax_host_callback_ad_transforms: + assertMultiLineStrippedEqual(self, """ + transforms: ['jvp'] what: a * 2 + ( 10.00 0.20 ) + transforms: ['jvp'] what: y * 3 + ( 30.00 0.60 )""", testing_stream.output) + else: + assertMultiLineStrippedEqual(self, """ + what: a * 2 + 10.00 + what: y * 3 + 30.00""", testing_stream.output) def test_tap_grad_primal_unused(self): # The output of id_print is not needed for backwards pass @@ -923,25 +930,39 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb.barrier_wait() treedef = tree_util.tree_structure(arg) - 
assertMultiLineStrippedEqual(self, f""" - {{ lambda ; a:f32[]. let - b:f32[] = mul a 3.00 - c:f32[] = outside_call[ - arg_treedef={treedef} - callback=... - identity=True - transforms=() - ] b - _:f32[] = mul c 2.00 - d:f32[] = mul 1.00 2.00 - e:f32[] = outside_call[ - arg_treedef={treedef} - callback=... - identity=True - transforms=(('jvp',), ('transpose',)) - ] d - f:f32[] = mul e 3.00 - in (f,) }}""", jaxpr) + if FLAGS.jax_host_callback_ad_transforms: + assertMultiLineStrippedEqual(self, f""" + {{ lambda ; a:f32[]. let + b:f32[] = mul a 3.00 + c:f32[] = outside_call[ + arg_treedef={treedef} + callback=... + identity=True + transforms=() + ] b + _:f32[] = mul c 2.00 + d:f32[] = mul 1.00 2.00 + e:f32[] = outside_call[ + arg_treedef={treedef} + callback=... + identity=True + transforms=(('jvp',), ('transpose',)) + ] d + f:f32[] = mul e 3.00 + in (f,) }}""", jaxpr) + else: + assertMultiLineStrippedEqual(self, f""" + {{ lambda ; a:f32[]. let + b:f32[] = mul a 3.00 + c:f32[] = outside_call[ + arg_treedef={treedef} + callback=... + identity=True + ] b + _:f32[] = mul c 2.00 + d:f32[] = mul 1.00 2.00 + e:f32[] = mul d 3.00 + in (e,) }}""", jaxpr) assertMultiLineStrippedEqual(self, "", testing_stream.output) testing_stream.reset() @@ -949,11 +970,16 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb.barrier_wait() self.assertAllClose(6., res_grad, check_dtypes=False) - assertMultiLineStrippedEqual(self, """ - what: x * 3 - 15.00 - transforms: ['jvp', 'transpose'] what: x * 3 - 2.00""", testing_stream.output) + if FLAGS.jax_host_callback_ad_transforms: + assertMultiLineStrippedEqual(self, """ + what: x * 3 + 15.00 + transforms: ['jvp', 'transpose'] what: x * 3 + 2.00""", testing_stream.output) + else: + assertMultiLineStrippedEqual(self, """ + what: x * 3 + 15.00""", testing_stream.output) def test_tap_grad_simple(self): def func(x): @@ -966,15 +992,22 @@ class HostCallbackTapTest(jtu.JaxTestCase): res_grad = grad_func(jnp.float32(5.)) self.assertAllClose(2. * 5. 
* 6., res_grad, check_dtypes=False) hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - what: x * 2 - 10.00 - what: y * 3 - 30.00 - transforms: ['jvp', 'transpose'] what: y * 3 - 5.00 - transforms: ['jvp', 'transpose'] what: x * 2 - 15.00""", testing_stream.output) + if FLAGS.jax_host_callback_ad_transforms: + assertMultiLineStrippedEqual(self, """ + what: x * 2 + 10.00 + what: y * 3 + 30.00 + transforms: ['jvp', 'transpose'] what: y * 3 + 5.00 + transforms: ['jvp', 'transpose'] what: x * 2 + 15.00""", testing_stream.output) + else: + assertMultiLineStrippedEqual(self, """ + what: x * 2 + 10.00 + what: y * 3 + 30.00""", testing_stream.output) def test_tap_grad_grad(self): def func(x): @@ -991,15 +1024,20 @@ class HostCallbackTapTest(jtu.JaxTestCase): self.assertAllClose(12., res_grad, check_dtypes=False) hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - what: x * 2 - 10.00 - transforms: ['jvp', 'transpose'] what: x * 2 - 15.00 - transforms: ['jvp', 'transpose', 'jvp', 'transpose'] what: x * 2 - 2.00 - transforms: ['jvp', 'transpose'] what: x * 2 - 3.00""", testing_stream.output) + if FLAGS.jax_host_callback_ad_transforms: + assertMultiLineStrippedEqual(self, """ + what: x * 2 + 10.00 + transforms: ['jvp', 'transpose'] what: x * 2 + 15.00 + transforms: ['jvp', 'transpose', 'jvp', 'transpose'] what: x * 2 + 2.00 + transforms: ['jvp', 'transpose'] what: x * 2 + 3.00""", testing_stream.output) + else: + assertMultiLineStrippedEqual(self, """ + what: x * 2 + 10.00""", testing_stream.output) def test_tap_grad_pytree(self): def func(x): @@ -1014,11 +1052,16 @@ class HostCallbackTapTest(jtu.JaxTestCase): res_grad = grad_func(x) self.assertAllClose(14., res_grad, check_dtypes=False) hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - what: pair - ( 10.00 15.00 ) - transforms: ['jvp', 'transpose'] what: pair - ( 0.00 0.00 )""", testing_stream.output) + if FLAGS.jax_host_callback_ad_transforms: + assertMultiLineStrippedEqual(self, """ + what: 
pair + ( 10.00 15.00 ) + transforms: ['jvp', 'transpose'] what: pair + ( 0.00 0.00 )""", testing_stream.output) + else: + assertMultiLineStrippedEqual(self, """ + what: pair + ( 10.00 15.00 )""", testing_stream.output) def test_tap_jvp_float0(self): def f(x, yint): @@ -1038,11 +1081,16 @@ class HostCallbackTapTest(jtu.JaxTestCase): res_grad = grad_func(jnp.float32(5.), jnp.int32(2)) self.assertAllClose(2., res_grad, check_dtypes=False) hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - what: pair - ( 5.00 2 ) - transforms: ['jvp', 'transpose'] what: pair - ( 2.00 False )""", testing_stream.output) + if FLAGS.jax_host_callback_ad_transforms: + assertMultiLineStrippedEqual(self, """ + what: pair + ( 5.00 2 ) + transforms: ['jvp', 'transpose'] what: pair + ( 2.00 False )""", testing_stream.output) + else: + assertMultiLineStrippedEqual(self, """ + what: pair + ( 5.00 2 )""", testing_stream.output) def test_tap_grad_float0_result(self): # https://github.com/google/jax/issues/7340 @@ -1063,10 +1111,14 @@ class HostCallbackTapTest(jtu.JaxTestCase): self.assertAllClose(np.array([3., 3.], dtype=np.float32), g[0]) self.assertEqual(dtypes.float0, g[1].dtype) hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - ( [0.70 0.80] [11 12 13] ) - transforms: ['jvp', 'transpose'] - ( [0.00 0.00] [False False False] )""", testing_stream.output) + if FLAGS.jax_host_callback_ad_transforms: + assertMultiLineStrippedEqual(self, """ + ( [0.70 0.80] [11 12 13] ) + transforms: ['jvp', 'transpose'] + ( [0.00 0.00] [False False False] )""", testing_stream.output) + else: + assertMultiLineStrippedEqual(self, """ + ( [0.70 0.80] [11 12 13] )""", testing_stream.output) def test_tap_higher_order_grad_float0_result(self): # https://github.com/google/jax/issues/7340 @@ -1101,10 +1153,14 @@ class HostCallbackTapTest(jtu.JaxTestCase): f_jax_vjp1, args_vjp1 = wrap_vjp(f_jax, (x,), res) res_vjp1 = f_jax_vjp1(*args_vjp1) hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - ( 
[0.70 0.80] [11 12 13] ) - transforms: ['jvp', 'transpose'] - ( [0.00 0.00] [False False False] )""", testing_stream.output) + if FLAGS.jax_host_callback_ad_transforms: + assertMultiLineStrippedEqual(self, """ + ( [0.70 0.80] [11 12 13] ) + transforms: ['jvp', 'transpose'] + ( [0.00 0.00] [False False False] )""", testing_stream.output) + else: + assertMultiLineStrippedEqual(self, """ + ( [0.70 0.80] [11 12 13] )""", testing_stream.output) testing_stream.reset() # 2nd order @@ -1227,8 +1283,11 @@ class HostCallbackTapTest(jtu.JaxTestCase): transforms: [('batch', {'batch_dims': (0,)})] where: 3 [2 2 2 3 4]""", testing_stream.output) - def test_tap_transforms(self): + def test_tap_transforms_old_doc(self): + if not FLAGS.jax_host_callback_ad_transforms: + raise unittest.SkipTest("disabled for new behavior") + # Examples from the documentation def power3(x): y = x * x # Print both 'x' and 'x^2'. Must pack as a tuple. @@ -1278,6 +1337,127 @@ class HostCallbackTapTest(jtu.JaxTestCase): ( 0. [2. 3.] )""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) + def test_tap_transforms_doc(self): + # Examples from the documentation + if FLAGS.jax_host_callback_ad_transforms: + raise unittest.SkipTest("disabled for old behavior") + def power3(x): + y = x * x + # Print both 'x' and 'x^2'. Must pack as a tuple. + hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream) + return y * x + + print(f"impl = {power3(3.)}") + hcb.barrier_wait() + expected = """ + what: x,x^2 + ( 3. 9. )""" + self.assertMultiLineStrippedEqual(expected, testing_stream.output) + testing_stream.reset() + + print(f"jvp = {jax.jvp(power3, (3.,), (0.1,))}") + hcb.barrier_wait() + expected = """ + what: x,x^2 + ( 3. 9. 
)""" + self.assertMultiLineStrippedEqual(expected, testing_stream.output) + testing_stream.reset() + + @jax.custom_jvp + def print_tangents(arg): + return None + + @print_tangents.defjvp + def print_tangents_jvp(primals, tangents): + arg_dot, = tangents + hcb.id_print(arg_dot, what="tangents", output_stream=testing_stream) + return primals, tangents + + def power3_with_tangents(x): + y = x * x + # Print both 'x' and 'x^2'. Must pack as a tuple. + hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream) + print_tangents((x, y)) + return y * x + + print(f"jvp = {jax.jvp(power3_with_tangents, (3.,), (0.1,))}") + hcb.barrier_wait() + expected = """ + what: x,x^2 + ( 3. 9. ) + what: tangents + ( 0.1 0.6 )""" + self.assertMultiLineStrippedEqual(expected, testing_stream.output) + testing_stream.reset() + + print(f"grad = {jax.grad(power3)(3.)}") + hcb.barrier_wait() + # Only the primals by default + expected = """ + what: x,x^2 + ( 3. 9. )""" + self.assertMultiLineStrippedEqual(expected, testing_stream.output) + testing_stream.reset() + + @jax.custom_vjp + def print_cotangents(arg): + # Must return the argument for which we want the cotangent. + return arg + + # f_fwd: a -> (b, residual) + def print_cotangents_fwd(arg): + return print_cotangents(arg), None + # f_bwd: (residual, CT b) -> [CT a] + def print_cotangents_bwd(residual, ct_b): + hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream) + return ct_b, + + print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd) + + def power3_with_cotangents(x): + y = x * x + # Print both 'x' and 'x^2'. Must pack as a tuple. + hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream) + # Must use the output of print_cotangents + (x1, y1) = print_cotangents((x, y)) + return y1 * x1 + + print(f"grad = {jax.grad(power3_with_cotangents)(3.)}") + hcb.barrier_wait() + expected = """ + what: x,x^2 + ( 3. 9. ) + what: cotangents + ( 9. 3. 
)""" + self.assertMultiLineStrippedEqual(expected, testing_stream.output) + testing_stream.reset() + + # TODO: grad of grad + + print(f"vmap = {jax.vmap(power3)(np.array([2., 3.]))}") + hcb.barrier_wait() + expected = """ + transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 + ( [2. 3.] [4. 9.] )""" + self.assertMultiLineStrippedEqual(expected, testing_stream.output) + testing_stream.reset() + + print(f"vmap o grad {jax.vmap(jax.grad(power3))(np.array([2., 3.]))}") + hcb.barrier_wait() + expected = """ + transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 + ( [2. 3.] [4. 9.] )""" + self.assertMultiLineStrippedEqual(expected, testing_stream.output) + testing_stream.reset() + + print(f"vmap o grad {jax.vmap(jax.grad(power3_with_cotangents))(np.array([2., 3.]))}") + hcb.barrier_wait() + expected = """ + transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 + ( [2. 3.] [4. 9.] ) + transforms: [('batch', {'batch_dims': (0, 0)})] what: cotangents + ( [4. 9.] [2. 3.] )""" + self.assertMultiLineStrippedEqual(expected, testing_stream.output) def test_tap_pmap(self): if len(local_devices()) < 2: @@ -1416,15 +1596,24 @@ class HostCallbackTapTest(jtu.JaxTestCase): self.assertAllClose(expected_res, res, check_dtypes=False) # Assertion text is for 2 devices (also works for 1 device) # Device 0 will get to execute jax.jvp(jax.vmap(...)) for matrix[0, :, :] - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2 - ( [[ 0.00 2.00 4.00] - [20.00 22.00 24.00]] [[0.20 0.20 0.20] - [0.20 0.20 0.20]] ) - device: cpu:1 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2 - ( [[200.00 202.00 204.00] - [220.00 222.00 224.00]] [[0.20 0.20 0.20] - [0.20 0.20 0.20]] )""") + if FLAGS.jax_host_callback_ad_transforms: + assertMultiDeviceOutputEqual(self, """ + device: cpu:0 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2 + ( [[ 0.00 2.00 4.00] + [20.00 22.00 24.00]] [[0.20 0.20 
0.20] + [0.20 0.20 0.20]] ) + device: cpu:1 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2 + ( [[200.00 202.00 204.00] + [220.00 222.00 224.00]] [[0.20 0.20 0.20] + [0.20 0.20 0.20]] )""") + else: + assertMultiDeviceOutputEqual(self, """ + device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2 + [[ 0.00 2.00 4.00] + [20.00 22.00 24.00]] + device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2 + [[200.00 202.00 204.00] + [220.00 222.00 224.00]]""") def test_tap_vmap_pmap(self): # A matrix M[ijk] = i * 100 + j * 10 * k @@ -1565,7 +1754,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): [[ 3 3 3 3] [33 33 33 33]]""") - def test_tap_tap_scan_custom_jvp(self): + def test_tap_scan_custom_jvp(self): """custom JVP, inside scan. This exercises the custom_jvp_call_jaxpr primitives.""" @@ -1696,11 +1885,17 @@ class HostCallbackTapTest(jtu.JaxTestCase): jax.jvp(lambda arg: padded_sum([arg], dict(n=3)), (x,), (x * 0.1,))) hcb.barrier_wait() - self.assertMultiLineStrippedEqual(""" - transforms: [('mask', {'logical_shapes': 5}), 'jvp'] what: x - ( ( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) ) - ( ( [0. 0.1 0.2 0.3 0.4] [0. 0.2 0.4 0.6 0.8] ) ( ( False ) ( False ) ) ) )""", - testing_stream.output) + if FLAGS.jax_host_callback_ad_transforms: + self.assertMultiLineStrippedEqual(""" + transforms: [('mask', {'logical_shapes': 5}), 'jvp'] what: x + ( ( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) ) + ( ( [0. 0.1 0.2 0.3 0.4] [0. 0.2 0.4 0.6 0.8] ) ( ( False ) ( False ) ) ) )""", + testing_stream.output) + else: + self.assertMultiLineStrippedEqual(""" + transforms: [('mask', {'logical_shapes': 5})] what: x + ( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] 
) ( ( 3 ) ( 3 ) ) )""", + testing_stream.output) testing_stream.reset() # Now with JIT @@ -1829,40 +2024,53 @@ class HostCallbackTapTest(jtu.JaxTestCase): use_result=use_result, use_remat=use_remat, grad_func=grad_func) for use_result in [True, False] for grad_func in ["grad", "value_and_grad"] - for use_remat in [True, False])) - def test_tap_remat(self, use_result=False, grad_func="grad", use_remat=False): + for use_remat in ["old", "none"])) + def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="old"): def f(x): id_print_result = hcb.id_print(x, output_stream=testing_stream) if use_result: x = id_print_result return 3. * x grad_f = jax.grad if grad_func == "grad" else jax.value_and_grad - trans_f = jax.remat(f) if use_remat else f + if use_remat == "old": + trans_f = jax.remat(f) + else: + assert use_remat == "none" + trans_f = f print(jax.make_jaxpr(grad_f(trans_f))(2.)) grad_f(trans_f)(2.) hcb.barrier_wait() - if not use_result: - if use_remat: - expected = "" + if use_remat == "none": + if use_result: + if FLAGS.jax_host_callback_ad_transforms: + expected = """ + 2. + transforms: ['jvp', 'transpose'] + 3.""" + else: + # GOOD: whether or not we use_result, in absence of + # jax_host_callback_ad_transforms we get the same callback. + expected = "2." else: - # TODO: if not use_result then we should only see the primal when - # computing value_and_grad. expected = "2." - elif use_remat: - expected = """ - 2. - 2. - transforms: ['jvp', 'transpose'] - 3.""" - else: - expected = """ - 2. - transforms: ['jvp', 'transpose'] - 3.""" + else: # use_remat + if use_result: + if FLAGS.jax_host_callback_ad_transforms: + expected = """ + 2. + 2. + transforms: ['jvp', 'transpose'] + 3.""" + else: + expected = """ + 2. 
+ 2.""" + else: + # TODO: we should see two callbacks + expected = "" self.assertMultiLineStrippedEqual(expected, testing_stream.output) - testing_stream.reset() def test_tap_named_call(self): def tap_scalar(init, do_print=False): @@ -1884,8 +2092,6 @@ class HostCallbackTapTest(jtu.JaxTestCase): self.assertMultiLineStrippedEqual(expected, testing_stream.output) - - class HostCallbackCallTest(jtu.JaxTestCase): """Tests for hcb.call"""