mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[host_callback] Add support for tapping empty arrays
We make sure that both the inputs and the outputs of callbacks can contain empty arrays. Most platforms do not support empty infeed, so we ensure we do not send those.
This commit is contained in:
parent
4fc8fb57ee
commit
d323ad0f2b
@ -25,7 +25,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
for more information.
|
||||
* Python integers larger than the maximum `int64` value will now lead to an overflow
|
||||
in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`#6047`).
|
||||
|
||||
* Bug fixes:
|
||||
* `host_callback` now supports empty arrays in arguments and results ({jax-issue}`#6262`).
|
||||
|
||||
## jaxlib 0.1.65 (unreleased)
|
||||
|
||||
|
@ -15,18 +15,19 @@
|
||||
from docutils import nodes
|
||||
|
||||
def jax_issue_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
|
||||
"""Generate links to jax issues in sphinx.
|
||||
"""Generate links to jax issues or PRs in sphinx.
|
||||
|
||||
Usage::
|
||||
|
||||
:jax-issue:`1234`
|
||||
|
||||
This will output a hyperlink of the form
|
||||
`#1234 <http://github.com/google.jax/issues/1234>`_.
|
||||
`#1234 <http://github.com/google/jax/issues/1234>`_. These links work even
|
||||
for PR numbers.
|
||||
"""
|
||||
text = text.lstrip('#')
|
||||
if not text.isdigit():
|
||||
raise RuntimeError(f"Invalid content in {rawtext}: expected an issue number.")
|
||||
raise RuntimeError(f"Invalid content in {rawtext}: expected an issue or PR number.")
|
||||
url = "https://github.com/google/jax/issues/{}".format(text)
|
||||
node = nodes.reference(rawtext, '#' + text, refuri=url, **options)
|
||||
return [node], []
|
||||
|
@ -272,6 +272,8 @@ program as:
|
||||
RET_CHECK failure ... Mismatch between infeed source buffer shape s8[12345] ...
|
||||
```
|
||||
|
||||
To debug the underlying cause for these messages, see the Debugging section.
|
||||
|
||||
On TPU, there is currently no shape check for infeed, so we take the safer
|
||||
route to not send anything in case of errors, and let the computation hang.
|
||||
|
||||
@ -318,9 +320,14 @@ for the C++ outfeed `receiver backend
|
||||
|
||||
You should also use the ``--verbosity=2`` flag so that you see the logs from Python.
|
||||
|
||||
For example:
|
||||
For example, you can try to enable logging in the ``host_callback`` module:
|
||||
```
|
||||
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_jit_simple
|
||||
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple
|
||||
```
|
||||
|
||||
If you want to enable logging in lower-level implementation modules try:
|
||||
```
|
||||
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple
|
||||
```
|
||||
|
||||
(For bazel tests use --test_arg=--vmodule=...
|
||||
@ -783,30 +790,47 @@ def _outside_call_translation_rule(
|
||||
"The last two arguments must be tokens")
|
||||
|
||||
args_to_outfeed = args_op[:-2]
|
||||
send_infeed = not params["identity"] and len(params["flat_results_aval"]) > 0
|
||||
identity = params["identity"]
|
||||
flat_results_aval = params["flat_results_aval"] if not identity else []
|
||||
# Many platforms refuse to infeed empty arrays. We generate constants
|
||||
# instead.
|
||||
non_empty_flat_results_aval = list(filter(lambda aval: not (_aval_is_empty(aval)),
|
||||
flat_results_aval))
|
||||
send_infeed = not identity and len(non_empty_flat_results_aval) > 0
|
||||
callback_id = _register_callback(
|
||||
functools.partial(_outside_call_run_callback, send_infeed=send_infeed, **params))
|
||||
next_token = _outfeed_receiver.receiver.add_outfeed(comp, current_token,
|
||||
callback_id,
|
||||
args_to_outfeed)
|
||||
expecting_infeed = False
|
||||
if params["identity"]:
|
||||
if identity:
|
||||
results = list(args_to_outfeed)
|
||||
next_itoken = current_itoken
|
||||
else:
|
||||
flat_results_aval = params["flat_results_aval"]
|
||||
if flat_results_aval:
|
||||
empty_results = [
|
||||
xops.ConstantLiteral(comp, np.zeros(aval.shape, aval.dtype))
|
||||
for aval in flat_results_aval
|
||||
if _aval_is_empty(aval)
|
||||
]
|
||||
if non_empty_flat_results_aval:
|
||||
after_outfeed_itoken = xops.AfterAll(comp, [current_itoken, next_token])
|
||||
|
||||
results_and_token = xla.translations[lax.infeed_p](comp, after_outfeed_itoken,
|
||||
shapes=flat_results_aval, partitions=None)
|
||||
shapes=non_empty_flat_results_aval,
|
||||
partitions=None)
|
||||
expecting_infeed = True
|
||||
next_itoken = xops.GetTupleElement(results_and_token, len(flat_results_aval))
|
||||
results = [xops.GetTupleElement(results_and_token, i) for i in range(len(flat_results_aval))]
|
||||
next_itoken = xops.GetTupleElement(results_and_token, len(non_empty_flat_results_aval))
|
||||
non_empty_results = [xops.GetTupleElement(results_and_token, i)
|
||||
for i in range(len(non_empty_flat_results_aval))]
|
||||
results = [
|
||||
empty_results.pop(0) if _aval_is_empty(result_aval) else non_empty_results.pop(0)
|
||||
for result_aval in flat_results_aval]
|
||||
else:
|
||||
results = []
|
||||
results = empty_results
|
||||
next_itoken = current_itoken
|
||||
|
||||
assert len(results) == len(flat_results_aval)
|
||||
|
||||
assert expecting_infeed == send_infeed
|
||||
return xops.Tuple(comp, results + [next_token, next_itoken])
|
||||
|
||||
@ -877,7 +901,10 @@ def _outside_call_run_callback(
|
||||
raise TypeError(msg)
|
||||
|
||||
if send_infeed:
|
||||
device.transfer_to_infeed(tuple(canonical_flat_results))
|
||||
# Do not send the 0-sized arrays
|
||||
non_empty_canonical_flat_results = tuple(filter(lambda r: not _aval_is_empty(r),
|
||||
canonical_flat_results))
|
||||
device.transfer_to_infeed(non_empty_canonical_flat_results)
|
||||
return canonical_flat_results
|
||||
|
||||
except Exception as e:
|
||||
@ -910,6 +937,9 @@ def _add_transform(params: Dict, name: str, *transform_params) -> Dict:
|
||||
params, transforms=(params.get("transforms", ()) + (new_transform,)))
|
||||
|
||||
|
||||
def _aval_is_empty(aval) -> bool:
|
||||
return np.prod(aval.shape) == 0
|
||||
|
||||
# TODO(necula): there must be a better way to do this.
|
||||
# The AttributeError is for regular values, the KeyError is for ConcreteArray
|
||||
def _instantiate_zeros(arg, tan):
|
||||
|
@ -330,6 +330,18 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
|
||||
3""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
def test_tap_empty(self):
|
||||
"""Tap empty arrays."""
|
||||
hcb.id_print((), output_stream=testing_stream)
|
||||
hcb.id_print((1., np.ones((2, 0))), what="second", output_stream=testing_stream)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
( )
|
||||
what: second
|
||||
( 1.00
|
||||
[] )""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
def test_tap_jit_simple(self):
|
||||
jit_fun1 = api.jit(lambda x: 3. * hcb.id_print(
|
||||
2. * x, what="here", output_stream=testing_stream))
|
||||
@ -1740,6 +1752,58 @@ class HostCallbackCallTest(jtu.JaxTestCase):
|
||||
res_inside = fun(2, use_outside=False)
|
||||
self.assertAllClose(res_inside, fun(2, use_outside=True))
|
||||
|
||||
def test_call_empty_arg(self):
|
||||
"""Call with empty array."""
|
||||
result = np.ones((2,), dtype=np.float32)
|
||||
def f_outside(_):
|
||||
return result
|
||||
def fun(x):
|
||||
return x + hcb.call(f_outside, (),
|
||||
result_shape=api.ShapeDtypeStruct(result.shape, result.dtype))
|
||||
self.assertAllClose(2. + result, fun(2.))
|
||||
|
||||
def test_call_empty_result(self):
|
||||
"""Call returning empty array."""
|
||||
result_shape = (2, 0)
|
||||
def f_outside(_):
|
||||
return np.ones(result_shape, dtype=np.float32)
|
||||
def fun(x):
|
||||
return x + hcb.call(f_outside, 1.,
|
||||
result_shape=api.ShapeDtypeStruct(result_shape, np.float32))
|
||||
self.assertAllClose(f_outside(0.), fun(2.))
|
||||
|
||||
def test_call_empty_result_inside_pytree(self):
|
||||
"""Call returning a tuple with an empty array and a non-empty one."""
|
||||
result_shape_0 = (2, 0)
|
||||
result_shape_2 = (0,)
|
||||
def f_outside(_):
|
||||
return (np.ones(result_shape_0, dtype=np.float32),
|
||||
np.ones((1,), dtype=np.float32),
|
||||
np.ones(result_shape_2, dtype=np.float32))
|
||||
def fun(x):
|
||||
res = hcb.call(f_outside, 1.,
|
||||
result_shape=(api.ShapeDtypeStruct(result_shape_0, np.float32),
|
||||
api.ShapeDtypeStruct((1,), np.float32),
|
||||
api.ShapeDtypeStruct(result_shape_2, np.float32)))
|
||||
self.assertEqual(result_shape_0, res[0].shape)
|
||||
self.assertEqual(result_shape_2, res[2].shape)
|
||||
return x + res[1]
|
||||
self.assertAllClose(2 + np.ones((1,), dtype=np.float32), fun(2.))
|
||||
|
||||
def test_call_empty_result_all_pytree(self):
|
||||
"""Call returning a tuple of empty arrays."""
|
||||
result_shape = (2, 0)
|
||||
def f_outside(_):
|
||||
return (np.ones(result_shape, dtype=np.float32),
|
||||
np.ones(result_shape, dtype=np.float32))
|
||||
def fun(x):
|
||||
res = hcb.call(f_outside, 1.,
|
||||
result_shape=(api.ShapeDtypeStruct(result_shape, np.float32),
|
||||
api.ShapeDtypeStruct(result_shape, np.float32)))
|
||||
return x + res[0] + res[1]
|
||||
self.assertAllClose(np.ones(result_shape, dtype=np.float32),
|
||||
fun(2.))
|
||||
|
||||
def test_call_no_result(self):
|
||||
def f_outside(arg):
|
||||
self.call_log_testing_stream(lambda x: None, arg,
|
||||
|
Loading…
x
Reference in New Issue
Block a user