mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[hcb] Simplifications to the host_callback API
* dropping support for special AD handling for hcb.id_tap and id_print. From now on, only the primals are tapped. The old behavior can be obtained (for a limited time) by setting the JAX_HOST_CALLBACK_AD_TRANSFORMS environment variable, or the --jax_host_callback_ad_transforms flag. Additionally, added documentation for how to implement the old behavior using JAX custom AD APIs. This allows us to make some significant cleanup in the internals.
This commit is contained in:
parent
53318a2a7a
commit
f08156ab7c
@ -12,6 +12,15 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* [GitHub
|
||||
commits](https://github.com/google/jax/compare/jax-v0.2.26...main).
|
||||
|
||||
* Breaking changes:
|
||||
* The host_callback primitives have been simplified to drop the
|
||||
special autodiff handling for hcb.id_tap and id_print.
|
||||
From now on, only the primals are tapped. The old behavior can be
|
||||
obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS``
|
||||
environment variable, or the ``--jax_host_callback_ad_transforms`` flag.
|
||||
Additionally, added documentation for how to implement the old behavior
|
||||
using JAX custom AD APIs ({jax-issue}`#7839`).
|
||||
|
||||
## jaxlib 0.1.76 (Unreleased)
|
||||
|
||||
## jaxlib 0.1.75 (Dec 8, 2021)
|
||||
|
@ -463,6 +463,15 @@ flags.DEFINE_bool(
|
||||
'Has no effect on TPU, since only the outfeed mechanism is implemented.'
|
||||
)
|
||||
)
|
||||
flags.DEFINE_bool(
|
||||
'jax_host_callback_ad_transforms',
|
||||
bool_env('JAX_HOST_CALLBACK_AD_TRANSFORMS', False),
|
||||
help=(
|
||||
'Enable support for jvp/vjp for the host_callback primitives. Default is '
|
||||
'False, which means that host_callback operates only on primals. '
|
||||
'The flag exists only temporarily, for backward compatibility.'
|
||||
)
|
||||
)
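The flag defined above takes its default from an environment variable through ``bool_env``. As a rough illustration of that pattern (not JAX's actual implementation), the sketch below parses a boolean environment variable. The helper name ``bool_env_sketch`` and the set of accepted spellings are assumptions made for this example only.

```python
import os

# Hypothetical stand-in for the ``bool_env`` helper used in the hunk above.
# The accepted spellings are an assumption, not JAX's documented contract.
_TRUE = {"1", "true", "t", "yes", "y", "on"}
_FALSE = {"0", "false", "f", "no", "n", "off"}


def bool_env_sketch(varname: str, default: bool) -> bool:
    """Read a boolean from the environment, falling back to ``default``."""
    raw = os.environ.get(varname)
    if raw is None:
        return default
    value = raw.strip().lower()
    if value in _TRUE:
        return True
    if value in _FALSE:
        return False
    raise ValueError(f"invalid boolean value {raw!r} for {varname}")


# Opting back into the old AD behavior would then look like this. The variable
# has to be set before the flag is defined, i.e. before importing jax.
os.environ["JAX_HOST_CALLBACK_AD_TRANSFORMS"] = "1"
legacy_ad = bool_env_sketch("JAX_HOST_CALLBACK_AD_TRANSFORMS", False)
```

Raising on unrecognized values, rather than silently using the default, makes a mistyped setting visible at startup.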
|
||||
|
||||
enable_checks = config.define_bool_state(
|
||||
name='jax_enable_checks',
|
||||
|
@ -1009,6 +1009,9 @@ class JaxTestCase(parameterized.TestCase):
|
||||
ignore_space_re = re.compile(r'\s*\n\s*')
|
||||
expected_clean = re.sub(ignore_space_re, '\n', expected.strip())
|
||||
what_clean = re.sub(ignore_space_re, '\n', what.strip())
|
||||
if what_clean != expected_clean:
|
||||
# Print it so we can copy-and-paste it into the test
|
||||
print(f"Found\n{what}\n")
|
||||
self.assertMultiLineEqual(expected_clean, what_clean,
|
||||
msg="Found\n{}\nExpecting\n{}".format(what, expected))
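The comparison above ignores indentation and trailing whitespace around line breaks. A minimal standalone sketch of the same normalization, using only the regular expression shown in the hunk (the function name ``normalize_lines`` and the sample strings are illustrative, not part of the test utilities):

```python
import re

# Collapse any run of whitespace that contains a newline into a single "\n",
# using the same pattern as ``ignore_space_re`` in the hunk above.
_IGNORE_SPACE_RE = re.compile(r"\s*\n\s*")


def normalize_lines(text: str) -> str:
    """Strip the text and drop whitespace adjacent to every line break."""
    return _IGNORE_SPACE_RE.sub("\n", text.strip())


# An indented triple-quoted expectation, as written in the tests.
expected = """
    what: x,x^2
    ( 3. 9. )"""
# Output with different indentation and trailing spaces.
found = "what: x,x^2   \n( 3. 9. )\n"

same = normalize_lines(expected) == normalize_lines(found)
```

Whitespace inside a line is left untouched, so ``( 3. 9. )`` and ``(3. 9.)`` would still compare as different.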
|
||||
|
||||
|
@ -26,10 +26,8 @@ on the CPU, e.g., to use NumPy CPU custom kernels. Then we
|
||||
show uses of :func:`id_tap` and :func:`id_print`, which have the restriction
|
||||
that they cannot return values from the host to the device.
|
||||
These primitives are generally faster
|
||||
because they are executed asynchronously with the device code and they also
|
||||
support the whole spectrum of JAX transformations. In particular, they can be
|
||||
used to tap into and to debug JAX-transformed code.
|
||||
|
||||
because they are executed asynchronously with the device code.
|
||||
In particular, they can be used to tap into and to debug JAX code.
|
||||
|
||||
Using :func:`call` to call a host function and return results to device
|
||||
-----------------------------------------------------------------------
|
||||
@ -61,19 +59,21 @@ using a host computation::
|
||||
The :func:`call` function and the Python host function both take a single argument
|
||||
and return a single result, but those can be pytrees. Note that we must tell
|
||||
the :func:`call` what shape and dtype to expect from the host invocation, using
|
||||
the ``result_shape`` kwarg.
|
||||
the ``result_shape`` keyword argument.
|
||||
This is important because the device code is compiled with that expectation.
|
||||
There will be an error raised at runtime if the actual invocation produces a
|
||||
different result shape. In general, **such errors and also exceptions raised
|
||||
by the host computation may be difficult to debug**. See the Debugging section
|
||||
below.
|
||||
This is a problem for :func:`call` but not for :func:`id_tap`.
|
||||
This is a problem for :func:`call` but not for :func:`id_tap` because for the
|
||||
latter the device code does not expect a returned value.
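To make the role of the declared result shape concrete, here is a small sketch. It does not use ``host_callback``; it swaps in ``jax.pure_callback``, a different JAX API that likewise requires the caller to declare the shape and dtype of the host result up front. The names ``host_sin`` and ``device_fn`` are chosen for this example.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sin(x):
    # Runs on the host with NumPy; must return the declared shape and dtype.
    return np.sin(np.asarray(x))


@jax.jit
def device_fn(x):
    # The compiled device code is built around this declared result shape.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, result_shape, x)


x = jnp.arange(4, dtype=jnp.float32)
out = device_fn(x)
```

If ``host_sin`` returned an array of a different shape than declared, the mismatch would only surface when the computation runs, which is the debugging difficulty the paragraph above warns about.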
|
||||
|
||||
The :func:`call` API can be used inside a jit or pmap computation or inside
|
||||
cond/scan/while control flow. When used inside :func:`jax.pmap`, there will be
|
||||
separate calls to the host from each of the participating devices::
|
||||
|
||||
def host_sin(x, *, device):
|
||||
# The ``device`` argument is passed due to ``call_with_device=True`` below.
|
||||
print(f"Invoking host_sin with {x.shape} on {device}")
|
||||
return np.sin(x)
|
||||
|
||||
@ -88,12 +88,12 @@ separate calls to the host from each of the participating devices::
|
||||
# Invoking host_sin with (4,) on cpu:0
|
||||
# Invoking host_sin with (4,) on cpu:1
|
||||
|
||||
Note that :func:`call` does not (yet) support any JAX transformations, but as we
|
||||
Note that :func:`call` does not support any JAX transformations, but as we
|
||||
show below one can make use of the
|
||||
existing support for `Custom differentiation in JAX <https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html>`_.
|
||||
|
||||
Using :func:`id_tap` to call a Python function on the host, with no returned values, but full JAX transformation support
|
||||
---------------------------------------------------------------------------------------------------------------------------
|
||||
Using :func:`id_tap` to call a Python function on the host, with no returned values
|
||||
-----------------------------------------------------------------------------------
|
||||
|
||||
The :func:`id_tap` and :func:`id_print` are special cases of :func:`call`, when
|
||||
you just want the side effects of your Python callback. These functions have
|
||||
@ -104,7 +104,7 @@ For :func:`id_tap` you can specify your Python callback to be called, while
|
||||
`stdout` on the host.
|
||||
The Python function passed
|
||||
to :func:`id_tap` takes two positional arguments (the value tapped
|
||||
from the device computation along with ``transforms`` sequence,
|
||||
from the device computation along with a ``transforms`` tuple,
|
||||
described below). Optionally, the function may be passed a keyword argument
|
||||
``device`` with the Device from which the value was tapped.
|
||||
|
||||
@ -129,9 +129,8 @@ A few examples::
|
||||
id_tap(lambda tap, transforms: host_func(tap, what='data'), dict(x=x, y=y))
|
||||
|
||||
The above examples can all be adapted to use :func:`id_print` instead, with
|
||||
the difference that :func:`id_print` takes one positional argument (to print
|
||||
on the host), and possibly additional kwargs
|
||||
that are also printed along with the automatic kwarg ``transforms``.
|
||||
the difference that :func:`id_print` prints on the host the positional argument,
|
||||
along with any additional kwargs and the automatic kwarg ``transforms``.
|
||||
|
||||
Using :func:`barrier_wait` to wait until all callbacks have executed
|
||||
--------------------------------------------------------------------
|
||||
@ -184,76 +183,8 @@ of the computation, if all the callbacks are :func:`call`::
|
||||
res1.block_until_ready()
|
||||
res2.block_until_ready()
|
||||
|
||||
|
||||
Behavior under JAX transformations
|
||||
----------------------------------
|
||||
|
||||
The :func:`call` does not support any JAX transformations. However, the
|
||||
:func:`id_tap` and :func:`id_print` support all transformations. In this
|
||||
context, it is important that both these functions behave like the identity
|
||||
function::
|
||||
|
||||
# calls func((2x, 3x), []) and returns (2x, 3x)
|
||||
id_tap(func, (2 * x, 3 * x))
|
||||
|
||||
# calls func(2x, []) and returns y
|
||||
y = id_tap(func, 2 * x, result=y) # override the result of id_tap
|
||||
|
||||
We describe the behaviour under transformations for :func:`id_tap` and
|
||||
:func:`id_print` in the context of the
|
||||
following function definition::
|
||||
|
||||
def power3(x):
|
||||
y = x * x
|
||||
# Print both 'x' and 'x^2'
|
||||
_, y = id_print((x, y), what="x,x^2") # Must pack multiple arguments
|
||||
return y * x
|
||||
|
||||
power3(3.)
|
||||
# what: x,x^2 : (3., 9.)
|
||||
|
||||
(You can see these examples tested in `host_callback_test.HostCallbackIdTapTest.test_tap_transforms`.)
|
||||
|
||||
During JAX transformations the special parameter ``transforms`` is added to
|
||||
contain a list of transformation descriptors in the form
|
||||
``(transform_name, transform_params)``.
|
||||
|
||||
For :func:`jax.vmap` the arguments are batched, and ``transforms`` is extended
|
||||
with transformation name ``batch`` and ``batch_dims`` set to the tuple of
|
||||
batched dimensions (one entry per argument, ``None`` denotes an argument that
|
||||
was broadcast)::
|
||||
|
||||
jax.vmap(power3)(np.arange(3.))
|
||||
# transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : ([0, 1, 2], [0, 1,
|
||||
4])
|
||||
|
||||
For :func:`jax.jvp` there will be one callback with a pair, consisting of
|
||||
the values of the primals and those of the tangents::
|
||||
|
||||
jax.jvp(power3, (3.,), (0.1,))
|
||||
# transforms: ['jvp'] what: x,x^2 : ( (3., 9.), (0.1, 0.6) )
|
||||
|
||||
For :func:`jax.vjp` or :func:`jax.grad` there will be one callback with the
|
||||
values of the adjoints for the arguments. You may also see a callback with
|
||||
the values of the primals from the forward pass, if those values are needed for
|
||||
the backward pass::
|
||||
|
||||
jax.grad(power3)(3.)
|
||||
# what=x,x^2: (3., 9.) # from forward pass, since y is used in backward pass
|
||||
# transforms: ['jvp', 'transpose'] what: x,x^2 : (0., 3.) # from backward pass, adjoints of _, y
|
||||
|
||||
And here is an example of composed transforms. For vmap of grad, we see first
|
||||
a callback with the vmap of the forward pass (with just the 'batch' transform),
|
||||
and another callback with the vmap of the adjoints of the arguments. Note that
|
||||
the first argument is replicated (`batch_dims` is None)::
|
||||
|
||||
jax.vmap(jax.grad(power3))(np.array([2., 3.]))
|
||||
# transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
|
||||
# ( [2. 3.]
|
||||
# [4. 9.] )
|
||||
# transforms: ['jvp', 'transpose', ('batch', {'batch_dims': (None, 0)})] what: x,x^2
|
||||
# ( 0.
|
||||
# [2. 3.] )
|
||||
Behavior under parallelization transformations
|
||||
----------------------------------------------
|
||||
|
||||
In presence of :func:`jax.pmap` the code will run on multiple devices and
|
||||
each device will tap its values independently.
|
||||
@ -294,7 +225,112 @@ the operand collected
|
||||
from all participating devices on all hosts. For a :func:`call`, the callback
|
||||
must return the entire array for all devices on all hosts.
|
||||
|
||||
Behavior under JAX autodiff transformations
|
||||
-------------------------------------------
|
||||
|
||||
When used under a JAX autodiff transformation, the host callback functions
|
||||
operate on the primal values only. Consider the following example::
|
||||
|
||||
def power3(x):
|
||||
y = x * x
|
||||
# Print both 'x' and 'x^2'. Must pack as a tuple.
|
||||
hcb.id_print((x, y), what="x,x^2")
|
||||
return y * x
|
||||
|
||||
power3(3.)
|
||||
# what: x,x^2 : (3., 9.)
|
||||
|
||||
(You can see these examples tested in `host_callback_test.HostCallbackTapTest.test_tap_transforms_doc`.)
|
||||
|
||||
When used under :func:`jax.jvp` there will be one callback with the primal
|
||||
values only::
|
||||
|
||||
jax.jvp(power3, (3.,), (0.1,))
|
||||
# what: x,x^2 : (3., 9.)
|
||||
|
||||
Similarly for :func:`jax.grad`, we get a callback from the forward computation
|
||||
only::
|
||||
|
||||
jax.grad(power3)(3.)
|
||||
# what: x,x^2 : (3., 9.)
|
||||
|
||||
If you want to invoke the callback on the tangents during a :func:`jax.jvp`,
|
||||
you can use a custom_jvp. For example, you can define a function that does
|
||||
nothing interesting except that its custom_jvp will print the tangents::
|
||||
|
||||
@jax.custom_jvp
|
||||
def print_tangents(arg):
|
||||
return None
|
||||
|
||||
@print_tangents.defjvp
|
||||
def print_tangents_jvp(primals, tangents):
|
||||
arg_dot, = tangents
|
||||
hcb.id_print(arg_dot, what="tangents")
|
||||
return primals, tangents
|
||||
|
||||
Then you use this function in the places where you want to tap the tangents::
|
||||
|
||||
def power3_with_tangents(x):
|
||||
y = x * x
|
||||
# Print both 'x' and 'x^2'. Must pack as a tuple.
|
||||
hcb.id_print((x, y), what="x,x^2")
|
||||
print_tangents((x, y))
|
||||
return y * x
|
||||
|
||||
jax.jvp(power3_with_tangents, (3.,), (0.1,))
|
||||
# what: x,x^2 : (3., 9.)
|
||||
# what: tangents : (0.1, 0.6)
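The same idea can be tried outside ``host_callback``. The sketch below is an adaptation, not the code from this commit: it swaps ``hcb.id_print`` for ``jax.debug.callback``, records the tangent in a Python list, and makes the helper an identity function so that its primal and tangent outputs have matching structure. The names ``tap_tangent`` and ``seen_tangents`` are hypothetical.

```python
import jax

seen_tangents = []  # host-side record, filled in by the callback


@jax.custom_jvp
def tap_tangent(x):
    # Identity on the primal value.
    return x


@tap_tangent.defjvp
def tap_tangent_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    # Send the tangent to the host; the primal computation is unchanged.
    jax.debug.callback(lambda t: seen_tangents.append(float(t)), x_dot)
    return x, x_dot


def power3_with_tangent(x):
    y = tap_tangent(x * x)
    return y * x


primal_out, tangent_out = jax.jvp(power3_with_tangent, (3.0,), (0.1,))
jax.effects_barrier()  # wait for any pending callbacks
```

Here only the tangent of ``x * x`` is tapped. Tapping several values would need one call per value or a pytree-aware callback.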
|
||||
|
||||
You can do a similar thing for the cotangents during :func:`jax.grad`. This
|
||||
time you must be careful to use in the rest of the computation the values whose
|
||||
cotangents you want to tap. Hence we make the ``print_cotangents`` return
|
||||
its argument::
|
||||
|
||||
@jax.custom_vjp
|
||||
def print_cotangents(arg):
|
||||
# Must return the argument for which we want the cotangent.
|
||||
return arg
|
||||
|
||||
# f_fwd: a -> (b, residual)
|
||||
def print_cotangents_fwd(arg):
|
||||
return print_cotangents(arg), None
|
||||
# f_bwd: (residual, CT b) -> [CT a]
|
||||
def print_cotangents_bwd(residual, ct_b):
|
||||
hcb.id_print(ct_b, what="cotangents")
|
||||
return ct_b,
|
||||
|
||||
print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd)
|
||||
|
||||
def power3_with_cotangents(x):
|
||||
y = x * x
|
||||
# Print both 'x' and 'x^2'. Must pack as a tuple.
|
||||
hcb.id_print((x, y), what="x,x^2")
|
||||
(x1, y1) = print_cotangents((x, y))
|
||||
# Must use the output of print_cotangents
|
||||
return y1 * x1
|
||||
|
||||
jax.grad(power3_with_cotangents)(3.)
|
||||
# what: x,x^2 : (3., 9.)
|
||||
# what: cotangents : (9., 3.)
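A comparable sketch for cotangents, again as an adaptation and not the commit's code: ``jax.debug.callback`` stands in for ``hcb.id_print``, and the names ``tap_cotangent`` and ``seen_cotangents`` are hypothetical. As in the text above, the wrapped value must flow into the rest of the computation, otherwise no cotangent reaches the backward rule.

```python
import jax

seen_cotangents = []  # host-side record, filled in by the callback


@jax.custom_vjp
def tap_cotangent(x):
    # Must return its argument so that a cotangent flows back through it.
    return x


def tap_cotangent_fwd(x):
    # No residuals are needed for an identity function.
    return x, None


def tap_cotangent_bwd(_, ct):
    # Send the incoming cotangent to the host, then pass it on unchanged.
    jax.debug.callback(lambda c: seen_cotangents.append(float(c)), ct)
    return (ct,)


tap_cotangent.defvjp(tap_cotangent_fwd, tap_cotangent_bwd)


def power3_with_cotangent(x):
    y = tap_cotangent(x * x)
    # Must use the output of tap_cotangent.
    return y * x


grad_out = jax.grad(power3_with_cotangent)(3.0)
jax.effects_barrier()  # wait for any pending callbacks
```

The recorded value is the cotangent of ``y``, which for ``y * x`` equals ``x``.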
|
||||
|
||||
Behavior under jax.vmap
|
||||
-----------------------
|
||||
|
||||
The host callback functions :func:`id_print` and :func:`id_tap` support the
|
||||
vectorization transformation :func:`jax.vmap`.
|
||||
|
||||
For :func:`jax.vmap` the arguments to the callback are batched,
|
||||
and the callback function is
|
||||
passed an additional special ``transforms`` argument containing a list of transformation descriptors
|
||||
in the form ``("batch", {"batch_dims": ...})``, where ``...`` denotes the
|
||||
batched dimensions for the tapped values (one entry per argument,
|
||||
``None`` denotes an argument that was broadcast)::
|
||||
|
||||
jax.vmap(power3)(np.array([2., 3.]))
|
||||
# transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : ([2., 3.], [4., 9.])
|
||||
|
||||
See documentation for :func:`id_tap`, :func:`id_print`, and :func:`call`.
|
||||
|
||||
For more usage examples, see tests/host_callback_test.py.
|
||||
|
||||
Using :func:`call` to call a TensorFlow function, with reverse-mode autodiff support
|
||||
@ -444,7 +480,10 @@ import functools
|
||||
import itertools
|
||||
import threading
|
||||
import traceback
|
||||
from typing import (Any, Callable, Dict, List, Optional, Sequence, Tuple, cast)
|
||||
from typing import (Any, Callable, Dict, List, Optional, Sequence,
|
||||
Tuple, cast)
|
||||
import warnings
|
||||
|
||||
from absl import logging
|
||||
|
||||
from jax._src import api
|
||||
@ -533,11 +572,16 @@ def id_tap(tap_func, arg, *, result=None, tap_with_device=False, **kwargs):
|
||||
"pre-apply keyword arguments, either by using a closure or by passing "
|
||||
"``functools.partial(tap_func, **kwargs)``.")
|
||||
raise TypeError(msg)
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
warnings.warn('The flag jax_host_callback_ad_transforms is for temporary '
|
||||
'backwards compatibility mode. This flag, and the behavior '
|
||||
'it enables, will be removed soon.',
|
||||
FutureWarning)
|
||||
|
||||
if result is not None:
|
||||
flat_results, result_treedef = pytree.flatten(result)
|
||||
for result in flat_results:
|
||||
api._check_arg(result)
|
||||
for r in flat_results:
|
||||
api._check_arg(r)
|
||||
|
||||
call_res = _call(tap_func, arg, call_with_device=tap_with_device,
|
||||
result_shape=None, identity=True)
|
||||
@ -545,13 +589,16 @@ def id_tap(tap_func, arg, *, result=None, tap_with_device=False, **kwargs):
|
||||
if result is not None:
|
||||
# Return the results, but add a dependency on the call, to ensure it
|
||||
# is kept in the graph.
|
||||
call_flat_results, _ = pytree.flatten(call_res)
|
||||
if call_flat_results:
|
||||
call_flat_results = [id_tap_dep_p.bind(r, call_flat_results[0])
|
||||
for r in flat_results]
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
call_flat_results, _ = pytree.flatten(call_res)
|
||||
if call_flat_results:
|
||||
call_flat_results = [id_tap_dep_p.bind(r, call_flat_results[0])
|
||||
for r in flat_results]
|
||||
else:
|
||||
call_flat_results = flat_results
|
||||
return result_treedef.unflatten(call_flat_results)
|
||||
else:
|
||||
call_flat_results = flat_results
|
||||
return result_treedef.unflatten(call_flat_results)
|
||||
return result
|
||||
else:
|
||||
return call_res
|
||||
|
||||
@ -768,6 +815,8 @@ xla.register_translation(id_tap_dep_p,
|
||||
id_tap_dep_p.def_abstract_eval(lambda r_a, _: r_a)
|
||||
|
||||
def _id_tap_dep_jvp_rule(primals, tangents):
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
assert False
|
||||
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
|
||||
return (id_tap_dep_p.bind(primals[0], primals[1]),
|
||||
id_tap_dep_p.bind(tangents_instantiated[0], tangents_instantiated[1]))
|
||||
@ -775,6 +824,8 @@ def _id_tap_dep_jvp_rule(primals, tangents):
|
||||
ad.primitive_jvps[id_tap_dep_p] = _id_tap_dep_jvp_rule
|
||||
|
||||
def _id_tap_dep_transpose_rule(cts, arg_res, arg_tap):
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
assert False
|
||||
if ad.is_undefined_primal(arg_res):
|
||||
ct_res = _instantiate_zeros(cts, arg_res)
|
||||
else:
|
||||
@ -789,6 +840,8 @@ ad.primitive_transposes[id_tap_dep_p] = _id_tap_dep_transpose_rule
|
||||
|
||||
|
||||
def _id_tap_dep_batching_rule(batched_args, batch_dims):
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
assert False
|
||||
arg_res, arg_tap = batched_args
|
||||
return id_tap_dep_p.bind(arg_res, arg_tap), batch_dims[0]
|
||||
|
||||
@ -797,6 +850,8 @@ batching.primitive_batchers[id_tap_dep_p] = _id_tap_dep_batching_rule
|
||||
|
||||
|
||||
def _id_tap_dep_masking_rule(operands, operands_logical_shapes):
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
assert False
|
||||
arg_res, arg_tap = operands
|
||||
return id_tap_dep_p.bind(arg_res, arg_tap)
|
||||
|
||||
@ -1167,20 +1222,24 @@ def _outside_call_jvp_rule(primals, tangents, **params):
|
||||
assert "has_token" not in params
|
||||
if not params["identity"]:
|
||||
raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
|
||||
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
|
||||
|
||||
arg_treedef = params["arg_treedef"]
|
||||
# The argument to the jvp tap is a pair of the tapped primals and tangents
|
||||
jvp_flat_args, jvp_arg_treedef = api.tree_flatten(
|
||||
(arg_treedef.unflatten(primals),
|
||||
arg_treedef.unflatten(tangents_instantiated)))
|
||||
out_all = outside_call_p.bind(
|
||||
*jvp_flat_args,
|
||||
**dict(_add_transform(params, "jvp"),
|
||||
arg_treedef=jvp_arg_treedef,
|
||||
))
|
||||
out_primals_tapped, out_tangents_tapped = util.split_list(out_all, [len(primals)])
|
||||
return tuple(out_primals_tapped), tuple(out_tangents_tapped)
|
||||
arg_treedef = params["arg_treedef"]
|
||||
# The argument to the jvp tap is a pair of the tapped primals and tangents
|
||||
jvp_flat_args, jvp_arg_treedef = api.tree_flatten(
|
||||
(arg_treedef.unflatten(primals),
|
||||
arg_treedef.unflatten(tangents_instantiated)))
|
||||
out_all = outside_call_p.bind(
|
||||
*jvp_flat_args,
|
||||
**dict(_add_transform(params, "jvp"),
|
||||
arg_treedef=jvp_arg_treedef,
|
||||
))
|
||||
out_primals_tapped, out_tangents_tapped = util.split_list(out_all, [len(primals)])
|
||||
return tuple(out_primals_tapped), tuple(out_tangents_tapped)
|
||||
else:
|
||||
out_primals_tapped = outside_call_p.bind(*primals, **params)
|
||||
return tuple(out_primals_tapped), tangents
|
||||
|
||||
|
||||
ad.primitive_jvps[outside_call_p] = _outside_call_jvp_rule
|
||||
@ -1188,6 +1247,9 @@ ad.primitive_jvps[outside_call_p] = _outside_call_jvp_rule
|
||||
|
||||
def _outside_call_partial_eval_rule(trace, *args, **params):
|
||||
# partial eval is used after jvp and before transpose.
|
||||
if not FLAGS.jax_host_callback_ad_transforms:
|
||||
# TODO: just remove the partial eval rule
|
||||
return trace.default_process_primitive(outside_call_p, args, params)
|
||||
transforms = params.get("transforms", ())
|
||||
if not transforms or transforms[-1] != ("jvp",):
|
||||
# We are not in the process of computing VJP
|
||||
@ -1252,6 +1314,9 @@ def _outside_call_transpose_rule(cts, *args, **params):
|
||||
*cts_instantiated,
|
||||
**_add_transform(params, "transpose"))
|
||||
|
||||
if not FLAGS.jax_host_callback_ad_transforms:
|
||||
assert False
|
||||
|
||||
assert len(args) % 2 == 0
|
||||
nr_primals = len(args) // 2
|
||||
|
||||
|
@ -904,11 +904,18 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(100., res_primals, check_dtypes=False)
|
||||
self.assertAllClose(4., res_tangents, check_dtypes=False)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
transforms: ['jvp'] what: a * 2
|
||||
( 10.00 0.20 )
|
||||
transforms: ['jvp'] what: y * 3
|
||||
( 30.00 0.60 )""", testing_stream.output)
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
transforms: ['jvp'] what: a * 2
|
||||
( 10.00 0.20 )
|
||||
transforms: ['jvp'] what: y * 3
|
||||
( 30.00 0.60 )""", testing_stream.output)
|
||||
else:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: a * 2
|
||||
10.00
|
||||
what: y * 3
|
||||
30.00""", testing_stream.output)
|
||||
|
||||
def test_tap_grad_primal_unused(self):
|
||||
# The output of id_print is not needed for backwards pass
|
||||
@ -923,25 +930,39 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait()
|
||||
|
||||
treedef = tree_util.tree_structure(arg)
|
||||
assertMultiLineStrippedEqual(self, f"""
|
||||
{{ lambda ; a:f32[]. let
|
||||
b:f32[] = mul a 3.00
|
||||
c:f32[] = outside_call[
|
||||
arg_treedef={treedef}
|
||||
callback=...
|
||||
identity=True
|
||||
transforms=()
|
||||
] b
|
||||
_:f32[] = mul c 2.00
|
||||
d:f32[] = mul 1.00 2.00
|
||||
e:f32[] = outside_call[
|
||||
arg_treedef={treedef}
|
||||
callback=...
|
||||
identity=True
|
||||
transforms=(('jvp',), ('transpose',))
|
||||
] d
|
||||
f:f32[] = mul e 3.00
|
||||
in (f,) }}""", jaxpr)
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
assertMultiLineStrippedEqual(self, f"""
|
||||
{{ lambda ; a:f32[]. let
|
||||
b:f32[] = mul a 3.00
|
||||
c:f32[] = outside_call[
|
||||
arg_treedef={treedef}
|
||||
callback=...
|
||||
identity=True
|
||||
transforms=()
|
||||
] b
|
||||
_:f32[] = mul c 2.00
|
||||
d:f32[] = mul 1.00 2.00
|
||||
e:f32[] = outside_call[
|
||||
arg_treedef={treedef}
|
||||
callback=...
|
||||
identity=True
|
||||
transforms=(('jvp',), ('transpose',))
|
||||
] d
|
||||
f:f32[] = mul e 3.00
|
||||
in (f,) }}""", jaxpr)
|
||||
else:
|
||||
assertMultiLineStrippedEqual(self, f"""
|
||||
{{ lambda ; a:f32[]. let
|
||||
b:f32[] = mul a 3.00
|
||||
c:f32[] = outside_call[
|
||||
arg_treedef={treedef}
|
||||
callback=...
|
||||
identity=True
|
||||
] b
|
||||
_:f32[] = mul c 2.00
|
||||
d:f32[] = mul 1.00 2.00
|
||||
e:f32[] = mul d 3.00
|
||||
in (e,) }}""", jaxpr)
|
||||
assertMultiLineStrippedEqual(self, "", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
@ -949,11 +970,16 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
hcb.barrier_wait()
|
||||
|
||||
self.assertAllClose(6., res_grad, check_dtypes=False)
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: x * 3
|
||||
15.00
|
||||
transforms: ['jvp', 'transpose'] what: x * 3
|
||||
2.00""", testing_stream.output)
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: x * 3
|
||||
15.00
|
||||
transforms: ['jvp', 'transpose'] what: x * 3
|
||||
2.00""", testing_stream.output)
|
||||
else:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: x * 3
|
||||
15.00""", testing_stream.output)
|
||||
|
||||
def test_tap_grad_simple(self):
|
||||
def func(x):
|
||||
@ -966,15 +992,22 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
res_grad = grad_func(jnp.float32(5.))
|
||||
self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: x * 2
|
||||
10.00
|
||||
what: y * 3
|
||||
30.00
|
||||
transforms: ['jvp', 'transpose'] what: y * 3
|
||||
5.00
|
||||
transforms: ['jvp', 'transpose'] what: x * 2
|
||||
15.00""", testing_stream.output)
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: x * 2
|
||||
10.00
|
||||
what: y * 3
|
||||
30.00
|
||||
transforms: ['jvp', 'transpose'] what: y * 3
|
||||
5.00
|
||||
transforms: ['jvp', 'transpose'] what: x * 2
|
||||
15.00""", testing_stream.output)
|
||||
else:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: x * 2
|
||||
10.00
|
||||
what: y * 3
|
||||
30.00""", testing_stream.output)
|
||||
|
||||
def test_tap_grad_grad(self):
|
||||
def func(x):
|
||||
@ -991,15 +1024,20 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(12., res_grad, check_dtypes=False)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: x * 2
|
||||
10.00
|
||||
transforms: ['jvp', 'transpose'] what: x * 2
|
||||
15.00
|
||||
transforms: ['jvp', 'transpose', 'jvp', 'transpose'] what: x * 2
|
||||
2.00
|
||||
transforms: ['jvp', 'transpose'] what: x * 2
|
||||
3.00""", testing_stream.output)
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: x * 2
|
||||
10.00
|
||||
transforms: ['jvp', 'transpose'] what: x * 2
|
||||
15.00
|
||||
transforms: ['jvp', 'transpose', 'jvp', 'transpose'] what: x * 2
|
||||
2.00
|
||||
transforms: ['jvp', 'transpose'] what: x * 2
|
||||
3.00""", testing_stream.output)
|
||||
else:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: x * 2
|
||||
10.00""", testing_stream.output)
|
||||
|
||||
def test_tap_grad_pytree(self):
|
||||
def func(x):
|
||||
@ -1014,11 +1052,16 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
res_grad = grad_func(x)
|
||||
self.assertAllClose(14., res_grad, check_dtypes=False)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: pair
|
||||
( 10.00 15.00 )
|
||||
transforms: ['jvp', 'transpose'] what: pair
|
||||
( 0.00 0.00 )""", testing_stream.output)
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: pair
|
||||
( 10.00 15.00 )
|
||||
transforms: ['jvp', 'transpose'] what: pair
|
||||
( 0.00 0.00 )""", testing_stream.output)
|
||||
else:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: pair
|
||||
( 10.00 15.00 )""", testing_stream.output)
|
||||
|
||||
def test_tap_jvp_float0(self):
|
||||
def f(x, yint):
|
||||
@ -1038,11 +1081,16 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
res_grad = grad_func(jnp.float32(5.), jnp.int32(2))
|
||||
self.assertAllClose(2., res_grad, check_dtypes=False)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: pair
|
||||
( 5.00 2 )
|
||||
transforms: ['jvp', 'transpose'] what: pair
|
||||
( 2.00 False )""", testing_stream.output)
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: pair
|
||||
( 5.00 2 )
|
||||
transforms: ['jvp', 'transpose'] what: pair
|
||||
( 2.00 False )""", testing_stream.output)
|
||||
else:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: pair
|
||||
( 5.00 2 )""", testing_stream.output)
|
||||
|
||||
def test_tap_grad_float0_result(self):
|
||||
# https://github.com/google/jax/issues/7340
|
||||
@ -1063,10 +1111,14 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(np.array([3., 3.], dtype=np.float32), g[0])
|
||||
self.assertEqual(dtypes.float0, g[1].dtype)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
( [0.70 0.80] [11 12 13] )
|
||||
transforms: ['jvp', 'transpose']
|
||||
( [0.00 0.00] [False False False] )""", testing_stream.output)
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
( [0.70 0.80] [11 12 13] )
|
||||
transforms: ['jvp', 'transpose']
|
||||
( [0.00 0.00] [False False False] )""", testing_stream.output)
|
||||
else:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
( [0.70 0.80] [11 12 13] )""", testing_stream.output)
|
||||
|
||||
def test_tap_higher_order_grad_float0_result(self):
|
||||
# https://github.com/google/jax/issues/7340
|
||||
@ -1101,10 +1153,14 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
f_jax_vjp1, args_vjp1 = wrap_vjp(f_jax, (x,), res)
|
||||
res_vjp1 = f_jax_vjp1(*args_vjp1)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
( [0.70 0.80] [11 12 13] )
|
||||
transforms: ['jvp', 'transpose']
|
||||
( [0.00 0.00] [False False False] )""", testing_stream.output)
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
( [0.70 0.80] [11 12 13] )
|
||||
transforms: ['jvp', 'transpose']
|
||||
( [0.00 0.00] [False False False] )""", testing_stream.output)
|
||||
else:
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
( [0.70 0.80] [11 12 13] )""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
# 2nd order
|
||||
@ -1227,8 +1283,11 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
transforms: [('batch', {'batch_dims': (0,)})] where: 3
|
||||
[2 2 2 3 4]""", testing_stream.output)
|
||||
|
||||
def test_tap_transforms(self):
|
||||
def test_tap_transforms_old_doc(self):
|
||||
if not FLAGS.jax_host_callback_ad_transforms:
|
||||
raise unittest.SkipTest("disabled for new behavior")
|
||||
|
||||
# Examples from the documentation
|
||||
def power3(x):
|
||||
y = x * x
|
||||
# Print both 'x' and 'x^2'. Must pack as a tuple.
|
||||
@@ -1278,6 +1337,127 @@ class HostCallbackTapTest(jtu.JaxTestCase):
        ( 0. [2. 3.] )"""
    self.assertMultiLineStrippedEqual(expected, testing_stream.output)

  def test_tap_transforms_doc(self):
    # Examples from the documentation
    if FLAGS.jax_host_callback_ad_transforms:
      raise unittest.SkipTest("disabled for old behavior")
    def power3(x):
      y = x * x
      # Print both 'x' and 'x^2'. Must pack as a tuple.
      hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
      return y * x

    print(f"impl = {power3(3.)}")
    hcb.barrier_wait()
    expected = """
        what: x,x^2
        ( 3. 9. )"""
    self.assertMultiLineStrippedEqual(expected, testing_stream.output)
    testing_stream.reset()

    print(f"jvp = {jax.jvp(power3, (3.,), (0.1,))}")
    hcb.barrier_wait()
    expected = """
        what: x,x^2
        ( 3. 9. )"""
    self.assertMultiLineStrippedEqual(expected, testing_stream.output)
    testing_stream.reset()

    @jax.custom_jvp
    def print_tangents(arg):
      return None

    @print_tangents.defjvp
    def print_tangents_jvp(primals, tangents):
      arg_dot, = tangents
      hcb.id_print(arg_dot, what="tangents", output_stream=testing_stream)
      return primals, tangents

    def power3_with_tangents(x):
      y = x * x
      # Print both 'x' and 'x^2'. Must pack as a tuple.
      hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
      print_tangents((x, y))
      return y * x

    print(f"jvp = {jax.jvp(power3_with_tangents, (3.,), (0.1,))}")
    hcb.barrier_wait()
    expected = """
        what: x,x^2
        ( 3. 9. )
        what: tangents
        ( 0.1 0.6 )"""
    self.assertMultiLineStrippedEqual(expected, testing_stream.output)
    testing_stream.reset()

    print(f"grad = {jax.grad(power3)(3.)}")
    hcb.barrier_wait()
    # Only the primals by default
    expected = """
        what: x,x^2
        ( 3. 9. )"""
    self.assertMultiLineStrippedEqual(expected, testing_stream.output)
    testing_stream.reset()

    @jax.custom_vjp
    def print_cotangents(arg):
      # Must return the argument for which we want the cotangent.
      return arg

    # f_fwd: a -> (b, residual)
    def print_cotangents_fwd(arg):
      return print_cotangents(arg), None
    # f_bwd: (residual, CT b) -> [CT a]
    def print_cotangents_bwd(residual, ct_b):
      hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream)
      return ct_b,

    print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd)

    def power3_with_cotangents(x):
      y = x * x
      # Print both 'x' and 'x^2'. Must pack as a tuple.
      hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
      # Must use the output of print_cotangents
      (x1, y1) = print_cotangents((x, y))
      return y1 * x1

    print(f"grad = {jax.grad(power3_with_cotangents)(3.)}")
    hcb.barrier_wait()
    expected = """
        what: x,x^2
        ( 3. 9. )
        what: cotangents
        ( 9. 3. )"""
    self.assertMultiLineStrippedEqual(expected, testing_stream.output)
    testing_stream.reset()

    # TODO: grad of grad

    print(f"vmap = {jax.vmap(power3)(np.array([2., 3.]))}")
    hcb.barrier_wait()
    expected = """
        transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
        ( [2. 3.] [4. 9.] )"""
    self.assertMultiLineStrippedEqual(expected, testing_stream.output)
    testing_stream.reset()

    print(f"vmap o grad {jax.vmap(jax.grad(power3))(np.array([2., 3.]))}")
    hcb.barrier_wait()
    expected = """
        transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
        ( [2. 3.] [4. 9.] )"""
    self.assertMultiLineStrippedEqual(expected, testing_stream.output)
    testing_stream.reset()

    print(f"vmap o grad {jax.vmap(jax.grad(power3_with_cotangents))(np.array([2., 3.]))}")
    hcb.barrier_wait()
    expected = """
        transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
        ( [2. 3.] [4. 9.] )
        transforms: [('batch', {'batch_dims': (0, 0)})] what: cotangents
        ( [4. 9.] [2. 3.] )"""
    self.assertMultiLineStrippedEqual(expected, testing_stream.output)

  def test_tap_pmap(self):
    if len(local_devices()) < 2:
@@ -1416,15 +1596,24 @@ class HostCallbackTapTest(jtu.JaxTestCase):
    self.assertAllClose(expected_res, res, check_dtypes=False)
    # Assertion text is for 2 devices (also works for 1 device)
    # Device 0 will get to execute jax.jvp(jax.vmap(...)) for matrix[0, :, :]
    assertMultiDeviceOutputEqual(self, """
      device: cpu:0 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2
      ( [[ 0.00 2.00 4.00]
      [20.00 22.00 24.00]] [[0.20 0.20 0.20]
      [0.20 0.20 0.20]] )
      device: cpu:1 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2
      ( [[200.00 202.00 204.00]
      [220.00 222.00 224.00]] [[0.20 0.20 0.20]
      [0.20 0.20 0.20]] )""")
    if FLAGS.jax_host_callback_ad_transforms:
      assertMultiDeviceOutputEqual(self, """
        device: cpu:0 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2
        ( [[ 0.00 2.00 4.00]
        [20.00 22.00 24.00]] [[0.20 0.20 0.20]
        [0.20 0.20 0.20]] )
        device: cpu:1 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2
        ( [[200.00 202.00 204.00]
        [220.00 222.00 224.00]] [[0.20 0.20 0.20]
        [0.20 0.20 0.20]] )""")
    else:
      assertMultiDeviceOutputEqual(self, """
        device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
        [[ 0.00 2.00 4.00]
        [20.00 22.00 24.00]]
        device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
        [[200.00 202.00 204.00]
        [220.00 222.00 224.00]]""")

  def test_tap_vmap_pmap(self):
    # A matrix M[ijk] = i * 100 + j * 10 * k
@@ -1565,7 +1754,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
      [[ 3 3 3 3]
      [33 33 33 33]]""")

  def test_tap_tap_scan_custom_jvp(self):
  def test_tap_scan_custom_jvp(self):
    """custom JVP, inside scan.
    This exercises the custom_jvp_call_jaxpr primitives."""

@@ -1696,11 +1885,17 @@ class HostCallbackTapTest(jtu.JaxTestCase):
        jax.jvp(lambda arg: padded_sum([arg], dict(n=3)),
                (x,), (x * 0.1,)))
    hcb.barrier_wait()
    self.assertMultiLineStrippedEqual("""
        transforms: [('mask', {'logical_shapes': 5}), 'jvp'] what: x
        ( ( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )
        ( ( [0. 0.1 0.2 0.3 0.4] [0. 0.2 0.4 0.6 0.8] ) ( ( False ) ( False ) ) ) )""",
        testing_stream.output)
    if FLAGS.jax_host_callback_ad_transforms:
      self.assertMultiLineStrippedEqual("""
          transforms: [('mask', {'logical_shapes': 5}), 'jvp'] what: x
          ( ( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )
          ( ( [0. 0.1 0.2 0.3 0.4] [0. 0.2 0.4 0.6 0.8] ) ( ( False ) ( False ) ) ) )""",
          testing_stream.output)
    else:
      self.assertMultiLineStrippedEqual("""
          transforms: [('mask', {'logical_shapes': 5})] what: x
          ( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""",
          testing_stream.output)
    testing_stream.reset()

    # Now with JIT
@@ -1829,40 +2024,53 @@ class HostCallbackTapTest(jtu.JaxTestCase):
             use_result=use_result, use_remat=use_remat, grad_func=grad_func)
        for use_result in [True, False]
        for grad_func in ["grad", "value_and_grad"]
        for use_remat in [True, False]))
  def test_tap_remat(self, use_result=False, grad_func="grad", use_remat=False):
        for use_remat in ["old", "none"]))
  def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="old"):
    def f(x):
      id_print_result = hcb.id_print(x, output_stream=testing_stream)
      if use_result:
        x = id_print_result
      return 3. * x
    grad_f = jax.grad if grad_func == "grad" else jax.value_and_grad
    trans_f = jax.remat(f) if use_remat else f
    if use_remat == "old":
      trans_f = jax.remat(f)
    else:
      assert use_remat == "none"
      trans_f = f
    print(jax.make_jaxpr(grad_f(trans_f))(2.))
    grad_f(trans_f)(2.)

    hcb.barrier_wait()

    if not use_result:
      if use_remat:
        expected = ""
    if use_remat == "none":
      if use_result:
        if FLAGS.jax_host_callback_ad_transforms:
          expected = """
            2.
            transforms: ['jvp', 'transpose']
            3."""
        else:
          # GOOD: whether or not we use_result, in absence of
          # jax_host_callback_ad_transforms we get the same callback.
          expected = "2."
      else:
        # TODO: if not use_result then we should only see the primal when
        # computing value_and_grad.
        expected = "2."
    elif use_remat:
      expected = """
        2.
        2.
        transforms: ['jvp', 'transpose']
        3."""
    else:
      expected = """
        2.
        transforms: ['jvp', 'transpose']
        3."""
    else: # use_remat
      if use_result:
        if FLAGS.jax_host_callback_ad_transforms:
          expected = """
            2.
            2.
            transforms: ['jvp', 'transpose']
            3."""
        else:
          expected = """
            2.
            2."""
      else:
        # TODO: we should see two callbacks
        expected = ""
    self.assertMultiLineStrippedEqual(expected, testing_stream.output)
    testing_stream.reset()

  def test_tap_named_call(self):
    def tap_scalar(init, do_print=False):
@@ -1884,8 +2092,6 @@ class HostCallbackTapTest(jtu.JaxTestCase):
    self.assertMultiLineStrippedEqual(expected, testing_stream.output)




class HostCallbackCallTest(jtu.JaxTestCase):
  """Tests for hcb.call"""

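The custom VJP pattern that `test_tap_transforms_doc` uses to tap cotangents can also be tried outside the test harness. The sketch below is not part of this commit. It assumes a JAX release that provides `jax.debug.callback` and `jax.effects_barrier`, used here as stand-ins for `hcb.id_print` and `hcb.barrier_wait`; the names `tap_cotangents` and `seen` are hypothetical.

```python
import jax

# Hypothetical host-side log; the test writes to testing_stream instead.
seen = []


@jax.custom_vjp
def tap_cotangents(arg):
  # Identity on the primal: return the value whose cotangent we want to see.
  return arg


def _tap_fwd(arg):
  # f_fwd: a -> (b, residual). No residual is needed.
  return arg, None


def _tap_bwd(residual, ct_b):
  # f_bwd: (residual, CT b) -> (CT a,). Send the cotangent to the host,
  # then pass it through unchanged. jax.debug.callback is an assumed
  # substitute for hcb.id_print.
  jax.debug.callback(lambda ct: seen.append(ct), ct_b)
  return (ct_b,)


tap_cotangents.defvjp(_tap_fwd, _tap_bwd)


def power3_with_cotangents(x):
  y = x * x
  # The result must be built from the outputs of the tap.
  x1, y1 = tap_cotangents((x, y))
  return y1 * x1


g = jax.grad(power3_with_cotangents)(3.0)
jax.effects_barrier()  # wait for pending callbacks before reading `seen`
print("grad =", float(g))
print("cotangents =", [float(c) for c in seen[0]])
```

As in the test, the function being differentiated has to use the values returned by the tap. If it ignored them, no cotangent would flow through `tap_cotangents` and the backward rule would never run.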