[host_callback] Add support for tapping empty arrays

We make sure that both the inputs and the outputs of
callbacks can contain empty arrays.
Most platforms do not support empty infeed, so we ensure
we do not send those.
This commit is contained in:
George Necula 2021-03-29 17:21:56 +03:00
parent 4fc8fb57ee
commit d323ad0f2b
4 changed files with 111 additions and 15 deletions

View File

@ -25,7 +25,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
for more information.
* Python integers larger than the maximum `int64` value will now lead to an overflow
in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`#6047`).
* Bug fixes:
* `host_callback` now supports empty arrays in arguments and results ({jax-issue}`#6262`).
## jaxlib 0.1.65 (unreleased)

View File

@ -15,18 +15,19 @@
from docutils import nodes
def jax_issue_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
"""Generate links to jax issues in sphinx.
"""Generate links to jax issues or PRs in sphinx.
Usage::
:jax-issue:`1234`
This will output a hyperlink of the form
`#1234 <http://github.com/google.jax/issues/1234>`_.
`#1234 <http://github.com/google/jax/issues/1234>`_. These links work even
for PR numbers.
"""
text = text.lstrip('#')
if not text.isdigit():
raise RuntimeError(f"Invalid content in {rawtext}: expected an issue number.")
raise RuntimeError(f"Invalid content in {rawtext}: expected an issue or PR number.")
url = "https://github.com/google/jax/issues/{}".format(text)
node = nodes.reference(rawtext, '#' + text, refuri=url, **options)
return [node], []

View File

@ -272,6 +272,8 @@ program as:
RET_CHECK failure ... Mismatch between infeed source buffer shape s8[12345] ...
```
To debug the underlying cause for these messages, see the Debugging section.
On TPU, there is currently no shape check for infeed, so we take the safer
route to not send anything in case of errors, and let the computation hang.
@ -318,9 +320,14 @@ for the C++ outfeed `receiver backend
You should also use the ``--verbosity=2`` flag so that you see the logs from Python.
For example:
For example, you can try to enable logging in the ``host_callback`` module:
```
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_jit_simple
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple
```
If you want to enable logging in lower-level implementation modules try:
```
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple
```
(For bazel tests use --test_arg=--vmodule=...
@ -783,30 +790,47 @@ def _outside_call_translation_rule(
"The last two arguments must be tokens")
args_to_outfeed = args_op[:-2]
send_infeed = not params["identity"] and len(params["flat_results_aval"]) > 0
identity = params["identity"]
flat_results_aval = params["flat_results_aval"] if not identity else []
# Many platforms refuse to infeed empty arrays. We generate constants
# instead.
non_empty_flat_results_aval = list(filter(lambda aval: not (_aval_is_empty(aval)),
flat_results_aval))
send_infeed = not identity and len(non_empty_flat_results_aval) > 0
callback_id = _register_callback(
functools.partial(_outside_call_run_callback, send_infeed=send_infeed, **params))
next_token = _outfeed_receiver.receiver.add_outfeed(comp, current_token,
callback_id,
args_to_outfeed)
expecting_infeed = False
if params["identity"]:
if identity:
results = list(args_to_outfeed)
next_itoken = current_itoken
else:
flat_results_aval = params["flat_results_aval"]
if flat_results_aval:
empty_results = [
xops.ConstantLiteral(comp, np.zeros(aval.shape, aval.dtype))
for aval in flat_results_aval
if _aval_is_empty(aval)
]
if non_empty_flat_results_aval:
after_outfeed_itoken = xops.AfterAll(comp, [current_itoken, next_token])
results_and_token = xla.translations[lax.infeed_p](comp, after_outfeed_itoken,
shapes=flat_results_aval, partitions=None)
shapes=non_empty_flat_results_aval,
partitions=None)
expecting_infeed = True
next_itoken = xops.GetTupleElement(results_and_token, len(flat_results_aval))
results = [xops.GetTupleElement(results_and_token, i) for i in range(len(flat_results_aval))]
next_itoken = xops.GetTupleElement(results_and_token, len(non_empty_flat_results_aval))
non_empty_results = [xops.GetTupleElement(results_and_token, i)
for i in range(len(non_empty_flat_results_aval))]
results = [
empty_results.pop(0) if _aval_is_empty(result_aval) else non_empty_results.pop(0)
for result_aval in flat_results_aval]
else:
results = []
results = empty_results
next_itoken = current_itoken
assert len(results) == len(flat_results_aval)
assert expecting_infeed == send_infeed
return xops.Tuple(comp, results + [next_token, next_itoken])
@ -877,7 +901,10 @@ def _outside_call_run_callback(
raise TypeError(msg)
if send_infeed:
device.transfer_to_infeed(tuple(canonical_flat_results))
# Do not send the 0-sized arrays
non_empty_canonical_flat_results = tuple(filter(lambda r: not _aval_is_empty(r),
canonical_flat_results))
device.transfer_to_infeed(non_empty_canonical_flat_results)
return canonical_flat_results
except Exception as e:
@ -910,6 +937,9 @@ def _add_transform(params: Dict, name: str, *transform_params) -> Dict:
params, transforms=(params.get("transforms", ()) + (new_transform,)))
def _aval_is_empty(aval) -> bool:
return np.prod(aval.shape) == 0
# TODO(necula): there must be a better way to do this.
# The AttributeError is for regular values, the KeyError is for ConcreteArray
def _instantiate_zeros(arg, tan):

View File

@ -330,6 +330,18 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
3""", testing_stream.output)
testing_stream.reset()
def test_tap_empty(self):
"""Tap empty arrays."""
hcb.id_print((), output_stream=testing_stream)
hcb.id_print((1., np.ones((2, 0))), what="second", output_stream=testing_stream)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
( )
what: second
( 1.00
[] )""", testing_stream.output)
testing_stream.reset()
def test_tap_jit_simple(self):
jit_fun1 = api.jit(lambda x: 3. * hcb.id_print(
2. * x, what="here", output_stream=testing_stream))
@ -1740,6 +1752,58 @@ class HostCallbackCallTest(jtu.JaxTestCase):
res_inside = fun(2, use_outside=False)
self.assertAllClose(res_inside, fun(2, use_outside=True))
def test_call_empty_arg(self):
"""Call with empty array."""
result = np.ones((2,), dtype=np.float32)
def f_outside(_):
return result
def fun(x):
return x + hcb.call(f_outside, (),
result_shape=api.ShapeDtypeStruct(result.shape, result.dtype))
self.assertAllClose(2. + result, fun(2.))
def test_call_empty_result(self):
"""Call returning empty array."""
result_shape = (2, 0)
def f_outside(_):
return np.ones(result_shape, dtype=np.float32)
def fun(x):
return x + hcb.call(f_outside, 1.,
result_shape=api.ShapeDtypeStruct(result_shape, np.float32))
self.assertAllClose(f_outside(0.), fun(2.))
def test_call_empty_result_inside_pytree(self):
"""Call returning a tuple with an empty array and a non-empty one."""
result_shape_0 = (2, 0)
result_shape_2 = (0,)
def f_outside(_):
return (np.ones(result_shape_0, dtype=np.float32),
np.ones((1,), dtype=np.float32),
np.ones(result_shape_2, dtype=np.float32))
def fun(x):
res = hcb.call(f_outside, 1.,
result_shape=(api.ShapeDtypeStruct(result_shape_0, np.float32),
api.ShapeDtypeStruct((1,), np.float32),
api.ShapeDtypeStruct(result_shape_2, np.float32)))
self.assertEqual(result_shape_0, res[0].shape)
self.assertEqual(result_shape_2, res[2].shape)
return x + res[1]
self.assertAllClose(2 + np.ones((1,), dtype=np.float32), fun(2.))
def test_call_empty_result_all_pytree(self):
"""Call returning a tuple of empty arrays."""
result_shape = (2, 0)
def f_outside(_):
return (np.ones(result_shape, dtype=np.float32),
np.ones(result_shape, dtype=np.float32))
def fun(x):
res = hcb.call(f_outside, 1.,
result_shape=(api.ShapeDtypeStruct(result_shape, np.float32),
api.ShapeDtypeStruct(result_shape, np.float32)))
return x + res[0] + res[1]
self.assertAllClose(np.ones(result_shape, dtype=np.float32),
fun(2.))
def test_call_no_result(self):
def f_outside(arg):
self.call_log_testing_stream(lambda x: None, arg,