From 342d62436f9e54bfc4934cce54005548ff9a3459 Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 27 Apr 2021 10:29:39 -0700 Subject: [PATCH] [host_callback] Add support for pjit of host_callback. Currently, all XLA side-effect ops inside a sharded computation must have explicit sharding. This includes the outfeed and infeed used by host_callback. The implementation here uses AssignDevice sharding for both the outfeed and the infeed. This means that before outfeed, the devices will do an all_gather and the first device will make the outfeed. The host callback will receive a single outfeed with the entire array, and is supposed to return the entire array. This gets sent to the same device that issued to outfeed, which is responsible to send the respective slices to the other participating devices. PiperOrigin-RevId: 370711606 --- jax/BUILD | 1 + jax/experimental/host_callback.py | 126 ++++++++++--- jax/experimental/pjit.py | 14 +- jax/interpreters/pxla.py | 5 +- tests/host_callback_test.py | 286 +++++++++++++++++++++++------- tests/pjit_test.py | 71 ++++++++ 6 files changed, 404 insertions(+), 99 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index af138cd06..f93a6acb3 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -134,6 +134,7 @@ pytype_library( srcs_version = "PY3", deps = [ ":jax", + ":pjit", ], ) diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index f524781dc..b7a9271d0 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -173,11 +173,14 @@ following function definition:: def power3(x): y = x * x + # Print both 'x' and 'x^2' _, y = id_print((x, y), what="x,x^2") # Must pack multiple arguments return y * x power3(3.) - # what: x,x^2 : [3., 9.] + # what: x,x^2 : (3., 9.) + +(You can see these examples tested in `host_callback_test.HostCallbackIdTapTest.test_tap_transforms`.) During JAX transformations the special parameter ``transforms`` is added to contain a list of transformation descriptors in the form @@ -189,15 +192,14 @@ batched dimensions (one entry per argument, ``None`` denotes an argument that was broadcast):: jax.vmap(power3)(np.arange(3.)) - # transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : [[0, 1, 2], [0, 1, - 4]] + # transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : ([0, 1, 2], [0, 1, + 4]) -For :func:`jax.jvp` there will be two callbacks, one with the values of -the primals and one with the tangents:: +For :func:`jax.jvp` there will be one callback with a pair, consisting of +the values of the primals and those of the tangents:: jax.jvp(power3, (3.,), (0.1,)) - # what: x,x^2: [3., 9.] - # transforms: ['jvp'] what: x,x^2 : [0.1, 0.6] + # transforms: ['jvp'] what: x,x^2 : ( (3., 9.), (0.1, 0.6) ) For :func:`jax.vjp` or :func:`jax.grad` there will be one callback with the values of the adjoints for the arguments. You may also see a callback with @@ -205,19 +207,48 @@ the values of the primals from the forward pass, if those values are needed for the backward pass:: jax.grad(power3)(3.) - # what=x,x^2: [3., 9.] # from forward pass, since y is used in backward pass - # transforms: ['jvp', 'transpose'] what: x,x^2 : [0., 3.] # from backward pass, adjoints of _, y + # what=x,x^2: (3., 9.) # from forward pass, since y is used in backward pass + # transforms: ['jvp', 'transpose'] what: x,x^2 : (0., 3.) # from backward pass, adjoints of _, y + +And here is an example of composed transforms. For vmap of grad, we see first +a callback with the vmap of the forward pass (with just the 'batch' transform), +and another callback with the vmap of the adjoints of the arguments. Note that +the first argument is replicated (`batch_dims` is None):: + + jax.vmap(jax.grad(power3))(np.array([2., 3.])) + # transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 + # ( [2. 3.] + # [4. 9.] ) + # transforms: ['jvp', 'transpose', ('batch', {'batch_dims': (None, 0)})] what: x,x^2 + # ( 0. + # [2. 3.] ) In presence of :func:`jax.pmap` the code will run on multiple devices and each device will tap its values independently. It may be helpful to use the ``tap_with_device`` option for :func:`id_print` or :func:`id_tap`, so that you see which device is sending which data:: - jax.pmap(power3, devices=jax.devices()[0:2])(np.array([3., 4.]) - # device=cpu:0 what=x,x^2: [3., 9.] # from the first device - # device=cpu:1 what=x,x^2: [4., 16.] # from the second device + jax.pmap(power3, devices=jax.local_devices()[:2])(np.array([3., 4.]) + # device=cpu:0 what=x,x^2: (3., 9.) # from the first device + # device=cpu:1 what=x,x^2: (4., 16.) # from the second device -See documentation for :func:`id_tap` and :func:`id_print`. + +When using the experimental :func:`pjit.pjit` the code will run on multiple +devices on different shards of the input. The current implementation of +host callbacks will ensure that a single device will collect and outfeed +the entire operand, in a single callback. The callback function is supposed +to return the entire array, which will then be sent in a single infeed to the +same device that issued the outfeed. This device is then responsible for +sending the required shards to the other devices:: + + with maps.mesh(jax.local_devices()[:2], ["d"]): + pjit.pjit(power3, in_axis_resources=(P("d"),), + out_axis_resources=(P("d"),))(np.array([3., 4.])) + + # device=TPU:0 what=x,x^2: ( [3., 4.], + # [9., 16.] ) + +See documentation for :func:`id_tap`, :func:`id_print`, and :func:`call`. For more usage example, see tests/host_callback_test.py. Low-level details and debugging @@ -354,7 +385,9 @@ from jax.config import config from jax import custom_derivatives from jax._src import dtypes from jax import lax +from jax.experimental import pjit from jax.lib import pytree +from jax.lib import xla_bridge as xb from jax.lib import xla_client from jax.lib import xla_extension from jax.interpreters import ad, xla, batching, masking, pxla @@ -780,7 +813,7 @@ outside_call_p.def_impl(_outside_call_impl) def _outside_call_translation_rule( comp: XlaComputationBuilder, *args_op: XlaOp, **params): # We expect the current tokens at the end, inserted by _rewrite_jaxpr. - assert params["has_token"] + assert params["has_token"] # type: ignore[key-error] current_token = args_op[-2] current_itoken = args_op[-1] # TODO: expose shape.is_token @@ -790,8 +823,9 @@ def _outside_call_translation_rule( "The last two arguments must be tokens") args_to_outfeed = args_op[:-2] - identity = params["identity"] - flat_results_aval = params["flat_results_aval"] if not identity else [] + identity = params["identity"] # type: ignore[key-error] + flat_results_aval = params["flat_results_aval"] if not identity else [ + ] # type: ignore[key-error] # Many platforms refuse to infeed empty arrays. We generate constants # instead. non_empty_flat_results_aval = list(filter(lambda aval: not (_aval_is_empty(aval)), @@ -814,14 +848,36 @@ def _outside_call_translation_rule( ] if non_empty_flat_results_aval: after_outfeed_itoken = xops.AfterAll(comp, [current_itoken, next_token]) + # We shard the infeed as AssignedDevice(0). This must match the + # outfeed (from outfeed_receiver.cc). Since `lax.infeed` does not support + # this kind of sharding, we use a custom translation for infeed. + array_sharding_proto = xla_client.OpSharding() + array_sharding_proto.type = xla_client.OpSharding.Type.MAXIMAL + array_sharding_proto.tile_assignment_dimensions = [1] + array_sharding_proto.tile_assignment_devices = [0] - results_and_token = xla.translations[lax.infeed_p](comp, after_outfeed_itoken, - shapes=non_empty_flat_results_aval, - partitions=None) + token_sharding_proto = xla_client.OpSharding() + token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED + infeed_sharding_proto = xb.tuple_sharding_proto( + [array_sharding_proto] * len(non_empty_flat_results_aval) + + [token_sharding_proto]) + + shape = tuple(shape.with_major_to_minor_layout_if_absent() + for x in non_empty_flat_results_aval + for shape in xla.aval_to_xla_shapes(x)) + + build_infeed = functools.partial(xops.InfeedWithToken, + after_outfeed_itoken, + xla_client.Shape.tuple_shape(shape)) + outs_and_token = xb.with_sharding_proto(comp, infeed_sharding_proto, + build_infeed) + outs = xops.GetTupleElement(outs_and_token, 0) + next_itoken = xops.GetTupleElement(outs_and_token, 1) + non_empty_results = [ + xops.GetTupleElement(outs, i) + for i in range(len(non_empty_flat_results_aval)) + ] expecting_infeed = True - next_itoken = xops.GetTupleElement(results_and_token, len(non_empty_flat_results_aval)) - non_empty_results = [xops.GetTupleElement(results_and_token, i) - for i in range(len(non_empty_flat_results_aval))] results = [ empty_results.pop(0) if _aval_is_empty(result_aval) else non_empty_results.pop(0) for result_aval in flat_results_aval] @@ -889,7 +945,10 @@ def _outside_call_run_callback( canonical_flat_results = tuple(util.safe_map(xla.canonicalize_dtype, actual_flat_results)) actual_flat_results_aval = _values_to_avals(canonical_flat_results) - logging.vlog(2, f"Outside call consumer {callback} result {res} : {flat_results_aval}. Sending to infeed.") + logging.vlog( + 2, + f"Outside call consumer {callback} result {res} : {flat_results_aval}. Sending to infeed for device {device}." + ) if not all(ea.strip_weak_type() == ra.strip_weak_type() for ea, ra in util.safe_zip(flat_results_aval, @@ -1318,13 +1377,26 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn], eqns.append( core.new_jaxpr_eqn( eqn.invars + [input_token_var, input_itoken_var], - eqn.outvars + [output_token_var, output_itoken_var], - eqn.primitive, + eqn.outvars + [output_token_var, output_itoken_var], eqn.primitive, dict( eqn.params, call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True), - ), - eqn.source_info)) + ), eqn.source_info)) + elif eqn.primitive is pjit.pjit_call_p: + call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"]) + eqns.append( + core.new_jaxpr_eqn( + eqn.invars + [input_token_var, input_itoken_var], + eqn.outvars + [output_token_var, output_itoken_var], eqn.primitive, + dict( + eqn.params, + call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True), + donated_invars=eqn.params["donated_invars"] + (False, False), + in_axis_resources=(eqn.params["in_axis_resources"] + + (pjit.REPLICATED, pjit.REPLICATED)), + out_axis_resources=(eqn.params["out_axis_resources"] + + (pjit.REPLICATED, pjit.REPLICATED)), + ), eqn.source_info)) else: raise NotImplementedError(f"outfeed rewrite {eqn.primitive}") diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index 9e1f03344..4c9795274 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -265,9 +265,11 @@ def _pjit_translation_rule(c, axis_env, in_nodes, name_stack, backend, name, out_nodes = xla.jaxpr_subcomp( subc, call_jaxpr, backend, axis_env, (), extend_name_stack(name_stack, wrap_name(name, "pjit")), *args) - out_nodes = [xb.set_sharding_proto(subc, out, - get_sharding_proto(c, n, axis_resources, mesh)) - for out, axis_resources in safe_zip(out_nodes, out_axis_resources)] + out_nodes = [ + xb.set_sharding_proto(subc, out, + get_sharding_proto(subc, out, axis_resources, mesh)) + for out, axis_resources in safe_zip(out_nodes, out_axis_resources) + ] subc = subc.build(xops.Tuple(subc, out_nodes)) return xops.Call(c, subc, list(in_nodes)) @@ -405,7 +407,11 @@ def get_array_mapping(axis_resources: ParsedPartitionSpec) -> pxla.ArrayMapping: def get_sharding_proto(c, xla_op, axis_resources, mesh): xla_shape = c.GetShape(xla_op) - aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.element_type()) + if xla_shape.is_token(): + aval = core.abstract_token + assert axis_resources is REPLICATED + else: + aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.element_type()) array_mapping = get_array_mapping(axis_resources) sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)( aval, array_mapping) diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 6a4c2dd4f..e8bbc35f7 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -1585,8 +1585,11 @@ def mesh_sharding_specs(axis_sizes, axis_names): mesh_axis_pos = {name: i for i, name in enumerate(axis_names)} # NOTE: This takes in the non-sharded avals! def mk_sharding_spec(aval, aval_axes): - sharding = [_UNSHARDED_INSTANCE] * len(aval.shape) mesh_mapping = [Replicated(axis_size) for axis_size in axis_sizes.values()] + if aval is core.abstract_token: + assert not aval_axes + return ShardingSpec([], mesh_mapping) + sharding = [_UNSHARDED_INSTANCE] * len(aval.shape) next_sharded_axis = 0 aval_shape = list(aval.shape) # NOTE: sorted is stable, which is important when multiple resources diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 6bc010afb..c73c014e8 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -30,6 +30,9 @@ from jax._src import api from jax.config import config from jax import dtypes from jax.experimental import host_callback as hcb +from jax.experimental import PartitionSpec as P +from jax.experimental import maps +from jax.experimental import pjit from jax import lax from jax import numpy as jnp from jax import test_util as jtu @@ -65,6 +68,7 @@ class _TestingOutputStream(object): m = re.match(r'.*device: (\S+)', s) if m: by_device.append((m.group(1), [])) + assert by_device, f"output does not include 'device:': {self._output}" by_device[-1][1].append(s) sorted_by_device = sorted(by_device, key=lambda x: x[0]) @@ -99,7 +103,7 @@ def maybe_print(do_print: bool, arg, what: str, tap_with_device: Optional[bool] return arg -def devices(): +def local_devices(): # Tests require using not more than 2 devices. return api.local_devices()[:2] @@ -162,6 +166,37 @@ def helper_print_optimized_hlo(fun, *args): backend.compile(c).hlo_modules()[0].to_string())) +def helper_log_ir(name, + f_jax, + *args, + num_partitions=None, + strip_metadata=False): + print(f"Jaxpr[{name}]: {jax.make_jaxpr(f_jax)(*args)}") + jax_comp = jax.xla_computation(f_jax)(*args) + print(f"HLO[{name}]: {jax_comp.as_hlo_text()}") + + backend = jax.lib.xla_bridge.get_backend() + if num_partitions is not None: + num_replicas = 1 + device_assignment = np.arange(num_partitions * num_replicas) + device_assignment = np.reshape(device_assignment, (-1, num_partitions)) + use_spmd_partitioning = num_partitions > 1 + compile_options = jax.lib.xla_bridge.get_compile_options( + num_replicas=num_replicas, + num_partitions=num_partitions, + device_assignment=device_assignment, + use_spmd_partitioning=use_spmd_partitioning, + ) + else: + compile_options = None + jax_optimized_hlo = backend.compile( + jax_comp, compile_options).hlo_modules()[0].to_string() + if strip_metadata: + jax_optimized_hlo = re.sub(r", metadata.*", "", jax_optimized_hlo) + print(f"Optimized HLO[{name}] for " + f"platform {backend.platform}: {jax_optimized_hlo}") + + prev_xla_flags = None @@ -189,13 +224,13 @@ def assertMultiDeviceOutputEqual(tst: jtu.JaxTestCase, device_under_test is not a CPU, then we change the names """ expected = expected_2CPUs - if len(devices()) == 1: + if len(local_devices()) == 1: start_device_1 = expected.find('device: cpu:1') if start_device_1 >= 0: expected = expected[0:start_device_1] def replace_device_name(m) -> str: - return str(devices()[int(m.group(1))]) + return str(local_devices()[int(m.group(1))]) expected = re.sub(r'cpu:(\d+)', replace_device_name, expected) what = testing_stream.output_sorted_by_device @@ -474,19 +509,21 @@ class HostCallbackIdTapTest(jtu.JaxTestCase): def test_tap_jit_devices(self): """Running on multiple devices.""" - logging.info(f"{self._testMethodName}: has devices {devices()}") + logging.info(f"{self._testMethodName}: has devices {local_devices()}") def func(x, device_id): x1 = hcb.id_print(x, dev=str(device_id), output_stream=testing_stream) x2 = hcb.id_print(x1 + 1, dev=str(device_id), output_stream=testing_stream) return x2 - for d in devices(): + for d in local_devices(): self.assertEqual(112, api.jit(func, device=d, static_argnums=1)(111, d.id)) hcb.barrier_wait() logging.info(f"{self._testMethodName}: found output {testing_stream.output}") - self.assertEqual(len(devices()), len(re.findall(r"111", testing_stream.output))) - self.assertEqual(len(devices()), len(re.findall(r"112", testing_stream.output))) + self.assertEqual( + len(local_devices()), len(re.findall(r"111", testing_stream.output))) + self.assertEqual( + len(local_devices()), len(re.findall(r"112", testing_stream.output))) testing_stream.reset() @parameterized.named_parameters( @@ -1145,81 +1182,99 @@ class HostCallbackIdTapTest(jtu.JaxTestCase): testing_stream.reset() def test_tap_transforms(self): - def power(x, n): - x, n = hcb.id_print((x, n), output_stream=testing_stream) - return x * x * n * x - def f(x, n): - return x * power(x + 1., n) + def power3(x): + y = x * x + # Print both 'x' and 'x^2'. Must pack as a tuple. + _, y = hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream) + return y * x - x = 3. - print("impl = ", f(x, 2.)) + print(f"impl = {power3(3.)}") hcb.barrier_wait() expected = """ - ( 4. - 2. )""" + what: x,x^2 + ( 3. + 9. )""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() - print("jvp = ", api.jvp(lambda x: f(x, 2.), (x,), (1.,))) + print(f"vmap = {jax.vmap(power3)(np.arange(3.))}") hcb.barrier_wait() expected = """ - transforms: ['jvp'] - ( ( 4. - 2. ) - ( 1. - 0. ) )""" + transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 + ( [0. 1. 2.] + [0. 1. 4.] )""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() - print("grad = ", api.grad(f)(x, 2.)) + print(f"jvp = {jax.jvp(power3, (3.,), (0.1,))}") hcb.barrier_wait() expected = """ - ( 4. - 2. ) - transforms: ['jvp', 'transpose'] - ( 288. - 192. )""" + transforms: ['jvp'] what: x,x^2 + ( ( 3. + 9. ) + ( 0.1 + 0.6 ) )""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() - xv = np.array([3., 4.]) - print("vmap o grad = ", api.vmap(api.grad(f))(xv, np.array([2., 3.]))) + print(f"grad = {jax.grad(power3)(3.)}") hcb.barrier_wait() expected = """ - transforms: [('batch', {'batch_dims': (0, 0)})] - ( [4. 5.] - [2. 3.] ) - transforms: ['jvp', 'transpose', ('batch', {'batch_dims': (0, 0)})] - ( [288. 900.] - [192. 500.] )""" + what: x,x^2 + ( 3. + 9. ) + transforms: ['jvp', 'transpose'] what: x,x^2 + ( 0. + 3. )""" + self.assertMultiLineStrippedEqual(expected, testing_stream.output) + testing_stream.reset() + + print(f"vmap o grad {jax.vmap(jax.grad(power3))(np.array([2., 3.]))}") + hcb.barrier_wait() + expected = """ + transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 + ( [2. 3.] + [4. 9.] ) + transforms: ['jvp', 'transpose', ('batch', {'batch_dims': (None, 0)})] what: x,x^2 + ( 0. + [2. 3.] )""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() def test_tap_pmap(self): - xv = jnp.arange(len(devices()), dtype=jnp.int32) + if len(local_devices()) < 2: + raise SkipTest("test requires at least 2 devices") - def fun1(x, do_print=False): # x: i32 - return maybe_print(do_print, x * 2, "x * 2", tap_with_device=True) + def power3(x): + y = x * x + # Print both 'x' and 'x^2'. Must pack as a tuple. + _, y = hcb.id_print((x, y), + what="x,x^2", + output_stream=testing_stream, + tap_with_device=True) + return y * x - pmap_fun1 = api.pmap(partial(fun1, do_print=True), devices=devices()) - res = pmap_fun1(xv) + pmap_power3 = api.pmap(power3, devices=local_devices()) + xv = np.array([3, 4], dtype=np.int32) + res = pmap_power3(xv) hcb.barrier_wait() - expected_res = api.pmap(partial(fun1, do_print=False), - devices=devices())(xv) - self.assertAllClose(expected_res, res, check_dtypes=False) + self.assertAllClose(xv * xv * xv, res, check_dtypes=False) # Assertion text is for 2 devices (also works for 1 device) - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 what: x * 2 - 0 - device: cpu:1 what: x * 2 - 2""") + assertMultiDeviceOutputEqual( + self, """ + device: cpu:0 what: x,x^2 + ( 3 + 9 ) + device: cpu:1 what: x,x^2 + ( 4 + 16 )""") testing_stream.reset() def test_tap_pmap_vmap(self): # A matrix M[ij] = i * 10 + j - nr_devices = len(devices()) + nr_devices = len(local_devices()) shape = (nr_devices, 3) matrix = np.fromfunction(lambda i, j: 10. * i + j, shape, dtype=np.int32) @@ -1227,13 +1282,14 @@ class HostCallbackIdTapTest(jtu.JaxTestCase): def fun1(x, do_print=False): # x: i32 return maybe_print(do_print, x * 2, "x * 2", tap_with_device=True) - pmap_vmap_fun1 = api.pmap(api.vmap(partial(fun1, do_print=True)), - devices=devices()) + pmap_vmap_fun1 = api.pmap( + api.vmap(partial(fun1, do_print=True)), devices=local_devices()) res = pmap_vmap_fun1(matrix) hcb.barrier_wait() - expected_res = api.pmap(api.vmap(partial(fun1, do_print=False)), - devices=devices())(matrix) + expected_res = api.pmap( + api.vmap(partial(fun1, do_print=False)), devices=local_devices())( + matrix) self.assertAllClose(expected_res, res, check_dtypes=False) # Assertion text is for 2 devices (also works for 1 device) assertMultiDeviceOutputEqual(self, """ @@ -1245,7 +1301,7 @@ class HostCallbackIdTapTest(jtu.JaxTestCase): def test_tap_pmap_pmap_vmap(self): # A matrix M[ijk] = i * 100 + j * 10 + k - nr_devices = len(devices()) + nr_devices = len(local_devices()) if nr_devices % 2 != 0: raise SkipTest("test works only on even number of devices") @@ -1257,12 +1313,15 @@ class HostCallbackIdTapTest(jtu.JaxTestCase): y = maybe_print(do_print, x * 2., "x * 2", tap_with_device=True) return y ** 2 - pmap_fun1 = api.pmap(api.pmap(api.vmap(partial(fun1, do_print=True))), - devices=devices()) + pmap_fun1 = api.pmap( + api.pmap(api.vmap(partial(fun1, do_print=True))), + devices=local_devices()) res = pmap_fun1(matrix) hcb.barrier_wait() - expected_res = api.pmap(api.pmap(api.vmap(partial(fun1, do_print=False))), - devices=devices())(matrix) + expected_res = api.pmap( + api.pmap(api.vmap(partial(fun1, do_print=False))), + devices=local_devices())( + matrix) self.assertAllClose(expected_res, res, check_dtypes=False) # Assertion text is for 2 devices (also works for 1 device) assertMultiDeviceOutputEqual(self, """ @@ -1276,7 +1335,7 @@ class HostCallbackIdTapTest(jtu.JaxTestCase): def test_tap_pmap_pmap_extra(self): """pmap of a pmap surrounded by extra code.""" # A matrix M[ij] = i * 10 + j - nr_devices = len(devices()) + nr_devices = len(local_devices()) if nr_devices != 2: raise SkipTest("test works only on 2 devices") shape = (2, 1, 3) @@ -1312,7 +1371,7 @@ class HostCallbackIdTapTest(jtu.JaxTestCase): def test_tap_jvp_pmap_vmap(self): # A matrix M[ijk] = i * 100 + j * 10 * k - nr_devices = len(devices()) + nr_devices = len(local_devices()) shape = (nr_devices, 2, 3) matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape, dtype=np.float32) @@ -1343,7 +1402,7 @@ class HostCallbackIdTapTest(jtu.JaxTestCase): def test_tap_vmap_pmap(self): # A matrix M[ijk] = i * 100 + j * 10 * k - nr_devices = len(devices()) + nr_devices = len(local_devices()) shape = (2, nr_devices, 3) matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape, dtype=np.float32) @@ -1371,7 +1430,7 @@ class HostCallbackIdTapTest(jtu.JaxTestCase): def test_tap_jit_pmap_extra(self): """jit of a pmap surrounded by extra code.""" # A matrix M[ij] = i * 10 + j - nr_devices = len(devices()) + nr_devices = len(local_devices()) assert nr_devices in (1, 2) shape = (nr_devices, 3) matrix = np.fromfunction(lambda i, j: 10. * i + j, shape, @@ -1387,7 +1446,7 @@ class HostCallbackIdTapTest(jtu.JaxTestCase): res = api.jit(partial(fun, do_print=True))(matrix) self.assertAllClose(fun(matrix, do_print=False), res, check_dtypes=False) hcb.barrier_wait() - if len(devices()) == 2: + if len(local_devices()) == 2: assertMultiDeviceOutputEqual(self, """ device: cpu:0 what: before [[ 1.00 2.00 3.00] @@ -1406,7 +1465,7 @@ class HostCallbackIdTapTest(jtu.JaxTestCase): [[ 3.00 5.00 7.00] [23.00 25.00 27.00]]""") else: - assert len(devices()) == 1 + assert len(local_devices()) == 1 assertMultiDeviceOutputEqual(self, """ device: cpu:0 what: before [[1.00 2.00 3.00]] @@ -1420,7 +1479,7 @@ class HostCallbackIdTapTest(jtu.JaxTestCase): def test_tap_cond_pmap(self): raise SkipTest("cond of pmap does not work in JAX. Issue #5178.") # A matrix M[ij] = i * 10 + j - nr_devices = len(devices()) + nr_devices = len(local_devices()) shape = (nr_devices, 3) matrix = np.fromfunction(lambda i, j: 10. * i + j, shape, dtype=np.float32) @@ -1439,6 +1498,52 @@ class HostCallbackIdTapTest(jtu.JaxTestCase): TBD""", testing_stream.output) testing_stream.reset() + @jtu.skip_on_devices("cpu", "gpu") + # TODO(necula): file XLA:GPU bug for the 'Sharding' CustomCall + def test_tap_pjit(self): + devices = np.array(local_devices()) + nr_devices = len(devices) + if nr_devices < 2: + raise SkipTest("test requires at least 2 devices") + + print(f"test_tap_pjit is running on devices {devices}.") + # x: i32[D, 3] = [[0, 1, 2], [10, 11, 12], ...] + # y: i32[3, 4] + x = jnp.arange(100, dtype=jnp.int32).reshape((10, 10))[:nr_devices, :3] + y = jnp.ones((3, 4), np.int32) + + @partial(jax.named_call, name="fun1") # for xprof debugging + def fun1(x, do_print=False): + z = jnp.dot(x, y) + return maybe_print(do_print, z, "z", tap_with_device=True) + + res0 = fun1(x, do_print=False) + pjit_fun1 = pjit.pjit( + partial(fun1, do_print=True), + in_axis_resources=(P("d"),), + out_axis_resources=P("d")) + + with maps.mesh(devices, ["d"]): + # Print the internal IR + helper_log_ir( + f"{self._testMethodName}.pjit", + pjit_fun1, + x, + num_partitions=nr_devices) + res = pjit_fun1(x) + + self.assertAllClose(res0, res) + hcb.barrier_wait("before check") + + # Assertion text is for 2 devices (also works for 1 device) + # Note that a single call is made. + assertMultiDeviceOutputEqual( + self, """ + device: cpu:0 what: z + [[ 3 3 3 3] + [33 33 33 33]]""") + testing_stream.reset() + def test_tap_tap_scan_custom_jvp(self): """custom JVP, inside scan. This exercises the custom_jvp_call_jaxpr primitives.""" @@ -1933,7 +2038,7 @@ class HostCallbackCallTest(jtu.JaxTestCase): result_shape=x, call_with_device=True) - xv = jnp.arange(len(devices()), dtype=jnp.int32) + xv = jnp.arange(len(local_devices()), dtype=jnp.int32) res = api.pmap(fun)(xv) self.assertAllClose(api.pmap(lambda x: x * 6)(xv), res) # Assertion text is for 2 devices (also works for 1 device) @@ -1954,6 +2059,53 @@ class HostCallbackCallTest(jtu.JaxTestCase): "batching rules are implemented only for id_tap, not for call"): api.vmap(fun)(np.ones((2, 3))) + @jtu.skip_on_devices("cpu", "gpu") + # TODO(necula): file XLA:GPU bug for the 'Sharding' CustomCall + def test_call_pjit(self): + devices = np.array(local_devices()) + nr_devices = len(devices) + if nr_devices < 2: + raise SkipTest("test requires at least 2 devices") + + print(f"test_call_pjit is running on devices {devices}.") + # x: i32[D, 3] = [[0, 1, 2], [10, 11, 12], ...] + # y: i32[3, 4] + x = jnp.arange(100, dtype=jnp.int32).reshape((10, 10))[:nr_devices, :3] + y = jnp.ones((3, 4), np.int32) + + def callback_x5_func(x, device=None): + testing_stream.write(f"device: {device}\n Called with {x}") + return x * np.array(5, np.int32) + + def fun(x): + xy = jnp.dot(x, y) + return hcb.call( + callback_x5_func, xy, result_shape=xy, call_with_device=True) + + pjit_fun = pjit.pjit( + fun, in_axis_resources=(P("d"),), out_axis_resources=P("d")) + with maps.mesh(devices, ["d"]): + # Print the internal IR + helper_log_ir( + f"{self._testMethodName}.pjit", + pjit_fun, + x, + num_partitions=nr_devices) + + res = pjit_fun(x) + + expected_res = jnp.dot(x, y) * np.array(5, np.int32) + self.assertAllClose(expected_res, res, check_dtypes=False) + + hcb.barrier_wait("before assertion") + # Assertion text is for 2 devices (also works for 1 device) + assertMultiDeviceOutputEqual( + self, """ + device: cpu:0 + Called with [[ 3 3 3 3] + [33 33 33 33]]""") + testing_stream.reset() + def test_call_error_bad_result_shape(self): with self.assertRaisesRegex( ValueError, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 0efc33061..b516d1c0d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -14,6 +14,7 @@ from contextlib import contextmanager from functools import partial +import logging from typing import Generator, List, Tuple from unittest import SkipTest @@ -25,6 +26,7 @@ import jax import jax.numpy as jnp from jax import test_util as jtu from jax.errors import JAXTypeError +from jax import lax # TODO(skye): do we still wanna call this PartitionSpec? from jax.experimental import PartitionSpec as P from jax.experimental.maps import xmap, mesh @@ -278,6 +280,75 @@ class PJitTest(jtu.BufferDonationTestCase): # TODO(skye): add more unit tests once API is more finalized + def testInfeed(self): + devices = np.array(jax.local_devices()) + nr_devices = len(devices) + shape = (nr_devices * 3, nr_devices * 5) + + def f_for_jit(x): + token = lax.create_token(x) + (y,), token = lax.infeed( + token, shape=(jax.ShapedArray(x.shape, np.float32),)) + (z,), token = lax.infeed( + token, shape=(jax.ShapedArray(x.shape, np.float32),)) + (w,), token = lax.infeed( + token, shape=(jax.ShapedArray(x.shape, np.float32),)) + + return x + y + z + w + + x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + # JIT + logging.info('Making jit call') + res0 = jax.jit(f_for_jit)(x) + y = x * 2. + z = x * 3. + w = x * 4. + + logging.info('Transfering to infeed for the jit call') + d = devices[0] + d.transfer_to_infeed((y,)) + d.transfer_to_infeed((z,)) + d.transfer_to_infeed((w,)) + self.assertAllClose(res0, x + y + z + w, check_dtypes=True) + + # PJIT + def f_for_pjit(x): + token = lax.create_token(x) + # A replicated infeed + (y,), token = lax.infeed( + token, + shape=(jax.ShapedArray(x.shape, np.float32),), + partitions=(None,)) + # An infeed sharded on first axis + (z,), token = lax.infeed( + token, + shape=(jax.ShapedArray(x.shape, np.float32),), + partitions=(P(nr_devices, 1),)) + # An infeed sharded on second axis + (w,), token = lax.infeed( + token, + shape=(jax.ShapedArray(x.shape, np.float32),), + partitions=(P(1, nr_devices),)) + return x + y + z + w + + with mesh(devices, ['d']): + logging.info('Making pjit call') + res = pjit( + f_for_pjit, in_axis_resources=(P('d'),), out_axis_resources=P('d'))( + x) + + logging.info('Transfering to infeed for the pjit call') + for didx, d in enumerate(devices): + # Transfer the whole array to all devices for replicated. + d.transfer_to_infeed((y,)) + # For sharded infeed, transfer only the needed slices to each device. + d.transfer_to_infeed((z[3 * didx:3 * didx + 3, :])) + d.transfer_to_infeed((w[:, 5 * didx:5 * didx + 5],)) + + self.assertAllClose(res0, res, check_dtypes=True) + + @curry def check_1d_2d_mesh(f, set_mesh): return parameterized.named_parameters(