From d323ad0f2b44a3178e02867992ba7243d7e0f70c Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 29 Mar 2021 17:21:56 +0300 Subject: [PATCH] [host_callback] Add support for tapping empty arrays We make sure that both the inputs and the outputs of callbacks can contain empty arrays. Most platforms do not support empty infeed, so we ensure we do not send those. --- CHANGELOG.md | 3 +- docs/sphinxext/jax_extensions.py | 7 ++-- jax/experimental/host_callback.py | 52 +++++++++++++++++++------ tests/host_callback_test.py | 64 +++++++++++++++++++++++++++++++ 4 files changed, 111 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c0655fc3..f358e19c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,7 +25,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK. for more information. * Python integers larger than the maximum `int64` value will now lead to an overflow in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`#6047`). - +* Bug fixes: + * `host_callback` now supports empty arrays in arguments and results ({jax-issue}`#6262`). ## jaxlib 0.1.65 (unreleased) diff --git a/docs/sphinxext/jax_extensions.py b/docs/sphinxext/jax_extensions.py index 84cd923bc..b6ba7ecb8 100644 --- a/docs/sphinxext/jax_extensions.py +++ b/docs/sphinxext/jax_extensions.py @@ -15,18 +15,19 @@ from docutils import nodes def jax_issue_role(name, rawtext, text, lineno, inliner, options={}, content=[]): - """Generate links to jax issues in sphinx. + """Generate links to jax issues or PRs in sphinx. Usage:: :jax-issue:`1234` This will output a hyperlink of the form - `#1234 `_. + `#1234 `_. These links work even + for PR numbers. """ text = text.lstrip('#') if not text.isdigit(): - raise RuntimeError(f"Invalid content in {rawtext}: expected an issue number.") + raise RuntimeError(f"Invalid content in {rawtext}: expected an issue or PR number.") url = "https://github.com/google/jax/issues/{}".format(text) node = nodes.reference(rawtext, '#' + text, refuri=url, **options) return [node], [] diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index a84a23b04..272ba7d06 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -272,6 +272,8 @@ program as: RET_CHECK failure ... Mismatch between infeed source buffer shape s8[12345] ... ``` +To debug the underlying cause for these messages, see the Debugging section. + On TPU, there is currently no shape check for infeed, so we take the safer route to not send anything in case of errors, and let the computation hang. @@ -318,9 +320,14 @@ for the C++ outfeed `receiver backend You should also use the ``--verbosity=2`` flag so that you see the logs from Python. -For example: +For example, you can try to enable logging in the ``host_callback`` module: ``` -TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_jit_simple +TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple +``` + +If you want to enable logging in lower-level implementation modules try: +``` +TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple ``` (For bazel tests use --test_arg=--vmodule=... @@ -783,30 +790,47 @@ def _outside_call_translation_rule( "The last two arguments must be tokens") args_to_outfeed = args_op[:-2] - send_infeed = not params["identity"] and len(params["flat_results_aval"]) > 0 + identity = params["identity"] + flat_results_aval = params["flat_results_aval"] if not identity else [] + # Many platforms refuse to infeed empty arrays. We generate constants + # instead. + non_empty_flat_results_aval = list(filter(lambda aval: not (_aval_is_empty(aval)), + flat_results_aval)) + send_infeed = not identity and len(non_empty_flat_results_aval) > 0 callback_id = _register_callback( functools.partial(_outside_call_run_callback, send_infeed=send_infeed, **params)) next_token = _outfeed_receiver.receiver.add_outfeed(comp, current_token, callback_id, args_to_outfeed) expecting_infeed = False - if params["identity"]: + if identity: results = list(args_to_outfeed) next_itoken = current_itoken else: - flat_results_aval = params["flat_results_aval"] - if flat_results_aval: + empty_results = [ + xops.ConstantLiteral(comp, np.zeros(aval.shape, aval.dtype)) + for aval in flat_results_aval + if _aval_is_empty(aval) + ] + if non_empty_flat_results_aval: after_outfeed_itoken = xops.AfterAll(comp, [current_itoken, next_token]) results_and_token = xla.translations[lax.infeed_p](comp, after_outfeed_itoken, - shapes=flat_results_aval, partitions=None) + shapes=non_empty_flat_results_aval, + partitions=None) expecting_infeed = True - next_itoken = xops.GetTupleElement(results_and_token, len(flat_results_aval)) - results = [xops.GetTupleElement(results_and_token, i) for i in range(len(flat_results_aval))] + next_itoken = xops.GetTupleElement(results_and_token, len(non_empty_flat_results_aval)) + non_empty_results = [xops.GetTupleElement(results_and_token, i) + for i in range(len(non_empty_flat_results_aval))] + results = [ + empty_results.pop(0) if _aval_is_empty(result_aval) else non_empty_results.pop(0) + for result_aval in flat_results_aval] else: - results = [] + results = empty_results next_itoken = current_itoken + assert len(results) == len(flat_results_aval) + assert expecting_infeed == send_infeed return xops.Tuple(comp, results + [next_token, next_itoken]) @@ -877,7 +901,10 @@ def _outside_call_run_callback( raise TypeError(msg) if send_infeed: - device.transfer_to_infeed(tuple(canonical_flat_results)) + # Do not send the 0-sized arrays + non_empty_canonical_flat_results = tuple(filter(lambda r: not _aval_is_empty(r), + canonical_flat_results)) + device.transfer_to_infeed(non_empty_canonical_flat_results) return canonical_flat_results except Exception as e: @@ -910,6 +937,9 @@ def _add_transform(params: Dict, name: str, *transform_params) -> Dict: params, transforms=(params.get("transforms", ()) + (new_transform,))) +def _aval_is_empty(aval) -> bool: + return np.prod(aval.shape) == 0 + # TODO(necula): there must be a better way to do this. # The AttributeError is for regular values, the KeyError is for ConcreteArray def _instantiate_zeros(arg, tan): diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 8d5a657d2..0ea2beffd 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -330,6 +330,18 @@ class HostCallbackIdTapTest(jtu.JaxTestCase): 3""", testing_stream.output) testing_stream.reset() + def test_tap_empty(self): + """Tap empty arrays.""" + hcb.id_print((), output_stream=testing_stream) + hcb.id_print((1., np.ones((2, 0))), what="second", output_stream=testing_stream) + hcb.barrier_wait() + assertMultiLineStrippedEqual(self, """ + ( ) + what: second + ( 1.00 + [] )""", testing_stream.output) + testing_stream.reset() + def test_tap_jit_simple(self): jit_fun1 = api.jit(lambda x: 3. * hcb.id_print( 2. * x, what="here", output_stream=testing_stream)) @@ -1740,6 +1752,58 @@ class HostCallbackCallTest(jtu.JaxTestCase): res_inside = fun(2, use_outside=False) self.assertAllClose(res_inside, fun(2, use_outside=True)) + def test_call_empty_arg(self): + """Call with empty array.""" + result = np.ones((2,), dtype=np.float32) + def f_outside(_): + return result + def fun(x): + return x + hcb.call(f_outside, (), + result_shape=api.ShapeDtypeStruct(result.shape, result.dtype)) + self.assertAllClose(2. + result, fun(2.)) + + def test_call_empty_result(self): + """Call returning empty array.""" + result_shape = (2, 0) + def f_outside(_): + return np.ones(result_shape, dtype=np.float32) + def fun(x): + return x + hcb.call(f_outside, 1., + result_shape=api.ShapeDtypeStruct(result_shape, np.float32)) + self.assertAllClose(f_outside(0.), fun(2.)) + + def test_call_empty_result_inside_pytree(self): + """Call returning a tuple with an empty array and a non-empty one.""" + result_shape_0 = (2, 0) + result_shape_2 = (0,) + def f_outside(_): + return (np.ones(result_shape_0, dtype=np.float32), + np.ones((1,), dtype=np.float32), + np.ones(result_shape_2, dtype=np.float32)) + def fun(x): + res = hcb.call(f_outside, 1., + result_shape=(api.ShapeDtypeStruct(result_shape_0, np.float32), + api.ShapeDtypeStruct((1,), np.float32), + api.ShapeDtypeStruct(result_shape_2, np.float32))) + self.assertEqual(result_shape_0, res[0].shape) + self.assertEqual(result_shape_2, res[2].shape) + return x + res[1] + self.assertAllClose(2 + np.ones((1,), dtype=np.float32), fun(2.)) + + def test_call_empty_result_all_pytree(self): + """Call returning a tuple of empty arrays.""" + result_shape = (2, 0) + def f_outside(_): + return (np.ones(result_shape, dtype=np.float32), + np.ones(result_shape, dtype=np.float32)) + def fun(x): + res = hcb.call(f_outside, 1., + result_shape=(api.ShapeDtypeStruct(result_shape, np.float32), + api.ShapeDtypeStruct(result_shape, np.float32))) + return x + res[0] + res[1] + self.assertAllClose(np.ones(result_shape, dtype=np.float32), + fun(2.)) + def test_call_no_result(self): def f_outside(arg): self.call_log_testing_stream(lambda x: None, arg,