diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c0655fc3..f358e19c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,7 +25,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK. for more information. * Python integers larger than the maximum `int64` value will now lead to an overflow in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`#6047`). - +* Bug fixes: + * `host_callback` now supports empty arrays in arguments and results ({jax-issue}`#6262`). ## jaxlib 0.1.65 (unreleased) diff --git a/docs/sphinxext/jax_extensions.py b/docs/sphinxext/jax_extensions.py index 84cd923bc..b6ba7ecb8 100644 --- a/docs/sphinxext/jax_extensions.py +++ b/docs/sphinxext/jax_extensions.py @@ -15,18 +15,19 @@ from docutils import nodes def jax_issue_role(name, rawtext, text, lineno, inliner, options={}, content=[]): - """Generate links to jax issues in sphinx. + """Generate links to jax issues or PRs in sphinx. Usage:: :jax-issue:`1234` This will output a hyperlink of the form - `#1234 <https://github.com/google/jax/issues/1234>`_. + `#1234 <https://github.com/google/jax/issues/1234>`_. These links work even + for PR numbers. """ text = text.lstrip('#') if not text.isdigit(): - raise RuntimeError(f"Invalid content in {rawtext}: expected an issue number.") + raise RuntimeError(f"Invalid content in {rawtext}: expected an issue or PR number.") url = "https://github.com/google/jax/issues/{}".format(text) node = nodes.reference(rawtext, '#' + text, refuri=url, **options) return [node], [] diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index a84a23b04..272ba7d06 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -272,6 +272,8 @@ program as: RET_CHECK failure ... Mismatch between infeed source buffer shape s8[12345] ... ``` +To debug the underlying cause for these messages, see the Debugging section. 
+ On TPU, there is currently no shape check for infeed, so we take the safer route to not send anything in case of errors, and let the computation hang. @@ -318,9 +320,14 @@ for the C++ outfeed `receiver backend You should also use the ``--verbosity=2`` flag so that you see the logs from Python. -For example: +For example, you can try to enable logging in the ``host_callback`` module: ``` -TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_jit_simple +TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple +``` + +If you want to enable logging in lower-level implementation modules try: +``` +TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple ``` (For bazel tests use --test_arg=--vmodule=... @@ -783,30 +790,47 @@ def _outside_call_translation_rule( "The last two arguments must be tokens") args_to_outfeed = args_op[:-2] - send_infeed = not params["identity"] and len(params["flat_results_aval"]) > 0 + identity = params["identity"] + flat_results_aval = params["flat_results_aval"] if not identity else [] + # Many platforms refuse to infeed empty arrays. We generate constants + # instead. 
+ non_empty_flat_results_aval = list(filter(lambda aval: not (_aval_is_empty(aval)), + flat_results_aval)) + send_infeed = not identity and len(non_empty_flat_results_aval) > 0 callback_id = _register_callback( functools.partial(_outside_call_run_callback, send_infeed=send_infeed, **params)) next_token = _outfeed_receiver.receiver.add_outfeed(comp, current_token, callback_id, args_to_outfeed) expecting_infeed = False - if params["identity"]: + if identity: results = list(args_to_outfeed) next_itoken = current_itoken else: - flat_results_aval = params["flat_results_aval"] - if flat_results_aval: + empty_results = [ + xops.ConstantLiteral(comp, np.zeros(aval.shape, aval.dtype)) + for aval in flat_results_aval + if _aval_is_empty(aval) + ] + if non_empty_flat_results_aval: after_outfeed_itoken = xops.AfterAll(comp, [current_itoken, next_token]) results_and_token = xla.translations[lax.infeed_p](comp, after_outfeed_itoken, - shapes=flat_results_aval, partitions=None) + shapes=non_empty_flat_results_aval, + partitions=None) expecting_infeed = True - next_itoken = xops.GetTupleElement(results_and_token, len(flat_results_aval)) - results = [xops.GetTupleElement(results_and_token, i) for i in range(len(flat_results_aval))] + next_itoken = xops.GetTupleElement(results_and_token, len(non_empty_flat_results_aval)) + non_empty_results = [xops.GetTupleElement(results_and_token, i) + for i in range(len(non_empty_flat_results_aval))] + results = [ + empty_results.pop(0) if _aval_is_empty(result_aval) else non_empty_results.pop(0) + for result_aval in flat_results_aval] else: - results = [] + results = empty_results next_itoken = current_itoken + assert len(results) == len(flat_results_aval) + assert expecting_infeed == send_infeed return xops.Tuple(comp, results + [next_token, next_itoken]) @@ -877,7 +901,10 @@ def _outside_call_run_callback( raise TypeError(msg) if send_infeed: - device.transfer_to_infeed(tuple(canonical_flat_results)) + # Do not send the 0-sized arrays + 
non_empty_canonical_flat_results = tuple(filter(lambda r: not _aval_is_empty(r), + canonical_flat_results)) + device.transfer_to_infeed(non_empty_canonical_flat_results) return canonical_flat_results except Exception as e: @@ -910,6 +937,9 @@ def _add_transform(params: Dict, name: str, *transform_params) -> Dict: params, transforms=(params.get("transforms", ()) + (new_transform,))) +def _aval_is_empty(aval) -> bool: + return np.prod(aval.shape) == 0 + # TODO(necula): there must be a better way to do this. # The AttributeError is for regular values, the KeyError is for ConcreteArray def _instantiate_zeros(arg, tan): diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 8d5a657d2..0ea2beffd 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -330,6 +330,18 @@ class HostCallbackIdTapTest(jtu.JaxTestCase): 3""", testing_stream.output) testing_stream.reset() + def test_tap_empty(self): + """Tap empty arrays.""" + hcb.id_print((), output_stream=testing_stream) + hcb.id_print((1., np.ones((2, 0))), what="second", output_stream=testing_stream) + hcb.barrier_wait() + assertMultiLineStrippedEqual(self, """ + ( ) + what: second + ( 1.00 + [] )""", testing_stream.output) + testing_stream.reset() + def test_tap_jit_simple(self): jit_fun1 = api.jit(lambda x: 3. * hcb.id_print( 2. * x, what="here", output_stream=testing_stream)) @@ -1740,6 +1752,58 @@ class HostCallbackCallTest(jtu.JaxTestCase): res_inside = fun(2, use_outside=False) self.assertAllClose(res_inside, fun(2, use_outside=True)) + def test_call_empty_arg(self): + """Call with empty array.""" + result = np.ones((2,), dtype=np.float32) + def f_outside(_): + return result + def fun(x): + return x + hcb.call(f_outside, (), + result_shape=api.ShapeDtypeStruct(result.shape, result.dtype)) + self.assertAllClose(2. 
+ result, fun(2.)) + + def test_call_empty_result(self): + """Call returning empty array.""" + result_shape = (2, 0) + def f_outside(_): + return np.ones(result_shape, dtype=np.float32) + def fun(x): + return x + hcb.call(f_outside, 1., + result_shape=api.ShapeDtypeStruct(result_shape, np.float32)) + self.assertAllClose(f_outside(0.), fun(2.)) + + def test_call_empty_result_inside_pytree(self): + """Call returning a tuple with an empty array and a non-empty one.""" + result_shape_0 = (2, 0) + result_shape_2 = (0,) + def f_outside(_): + return (np.ones(result_shape_0, dtype=np.float32), + np.ones((1,), dtype=np.float32), + np.ones(result_shape_2, dtype=np.float32)) + def fun(x): + res = hcb.call(f_outside, 1., + result_shape=(api.ShapeDtypeStruct(result_shape_0, np.float32), + api.ShapeDtypeStruct((1,), np.float32), + api.ShapeDtypeStruct(result_shape_2, np.float32))) + self.assertEqual(result_shape_0, res[0].shape) + self.assertEqual(result_shape_2, res[2].shape) + return x + res[1] + self.assertAllClose(2 + np.ones((1,), dtype=np.float32), fun(2.)) + + def test_call_empty_result_all_pytree(self): + """Call returning a tuple of empty arrays.""" + result_shape = (2, 0) + def f_outside(_): + return (np.ones(result_shape, dtype=np.float32), + np.ones(result_shape, dtype=np.float32)) + def fun(x): + res = hcb.call(f_outside, 1., + result_shape=(api.ShapeDtypeStruct(result_shape, np.float32), + api.ShapeDtypeStruct(result_shape, np.float32))) + return x + res[0] + res[1] + self.assertAllClose(np.ones(result_shape, dtype=np.float32), + fun(2.)) + def test_call_no_result(self): def f_outside(arg): self.call_log_testing_stream(lambda x: None, arg,