[host_callback] Add support for pjit of host_callback.

Currently, all XLA side-effect ops inside a sharded computation must have
explicit sharding. This includes the outfeed and infeed used by host_callback.

The implementation here uses AssignDevice sharding for both the outfeed and the
infeed. This means that before the outfeed the devices perform an all_gather and
the first device issues the outfeed. The host callback receives a single
outfeed with the entire array and is expected to return the entire array. The
result is sent back to the same device that issued the outfeed, which is
responsible for sending the respective slices to the other participating devices.
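For illustration, a minimal usage sketch (not part of the commit; it assumes at
least two local devices and the experimental maps/pjit/PartitionSpec APIs
exercised by the tests in this change):

import numpy as np
import jax
from jax.experimental import host_callback as hcb
from jax.experimental import maps, pjit
from jax.experimental import PartitionSpec as P

def tap_square(x):
  # The tap sees the entire gathered operand exactly once, from one device.
  x = hcb.id_print(x, what="full operand", tap_with_device=True)
  return x * x

f = pjit.pjit(tap_square, in_axis_resources=(P("d"),), out_axis_resources=P("d"))
with maps.mesh(np.array(jax.local_devices()[:2]), ["d"]):
  f(np.arange(8., dtype=np.float32))
hcb.barrier_wait()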

PiperOrigin-RevId: 370711606
Author: George Necula, 2021-04-27 10:29:39 -07:00; committed by jax authors
Parent: aa9d350aaf
Commit: 342d62436f
6 changed files with 404 additions and 99 deletions


@ -134,6 +134,7 @@ pytype_library(
srcs_version = "PY3",
deps = [
":jax",
":pjit",
],
)


@ -173,11 +173,14 @@ following function definition::
def power3(x):
y = x * x
# Print both 'x' and 'x^2'
_, y = id_print((x, y), what="x,x^2") # Must pack multiple arguments
return y * x
power3(3.)
# what: x,x^2 : [3., 9.]
# what: x,x^2 : (3., 9.)
(You can see these examples tested in `host_callback_test.HostCallbackIdTapTest.test_tap_transforms`.)
During JAX transformations the special parameter ``transforms`` is added to
contain a list of transformation descriptors in the form
@ -189,15 +192,14 @@ batched dimensions (one entry per argument, ``None`` denotes an argument that
was broadcast)::
jax.vmap(power3)(np.arange(3.))
# transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : [[0, 1, 2], [0, 1, 4]]
# transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : ([0, 1, 2], [0, 1, 4])
For :func:`jax.jvp` there will be two callbacks, one with the values of
the primals and one with the tangents::
For :func:`jax.jvp` there will be one callback with a pair, consisting of
the values of the primals and those of the tangents::
jax.jvp(power3, (3.,), (0.1,))
# what: x,x^2: [3., 9.]
# transforms: ['jvp'] what: x,x^2 : [0.1, 0.6]
# transforms: ['jvp'] what: x,x^2 : ( (3., 9.), (0.1, 0.6) )
For :func:`jax.vjp` or :func:`jax.grad` there will be one callback with the
values of the adjoints for the arguments. You may also see a callback with
@ -205,19 +207,48 @@ the values of the primals from the forward pass, if those values are needed for
the backward pass::
jax.grad(power3)(3.)
# what=x,x^2: [3., 9.] # from forward pass, since y is used in backward pass
# transforms: ['jvp', 'transpose'] what: x,x^2 : [0., 3.] # from backward pass, adjoints of _, y
# what=x,x^2: (3., 9.) # from forward pass, since y is used in backward pass
# transforms: ['jvp', 'transpose'] what: x,x^2 : (0., 3.) # from backward pass, adjoints of _, y
And here is an example of composed transforms. For vmap of grad, we see first
a callback with the vmap of the forward pass (with just the 'batch' transform),
and another callback with the vmap of the adjoints of the arguments. Note that
the first argument is replicated (`batch_dims` is None)::
jax.vmap(jax.grad(power3))(np.array([2., 3.]))
# transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
# ( [2. 3.]
# [4. 9.] )
# transforms: ['jvp', 'transpose', ('batch', {'batch_dims': (None, 0)})] what: x,x^2
# ( 0.
# [2. 3.] )
In presence of :func:`jax.pmap` the code will run on multiple devices and
each device will tap its values independently.
It may be helpful to use the ``tap_with_device`` option for :func:`id_print`
or :func:`id_tap`, so that you see which device is sending which data::
jax.pmap(power3, devices=jax.devices()[0:2])(np.array([3., 4.]))
# device=cpu:0 what=x,x^2: [3., 9.] # from the first device
# device=cpu:1 what=x,x^2: [4., 16.] # from the second device
jax.pmap(power3, devices=jax.local_devices()[:2])(np.array([3., 4.]))
# device=cpu:0 what=x,x^2: (3., 9.) # from the first device
# device=cpu:1 what=x,x^2: (4., 16.) # from the second device
See documentation for :func:`id_tap` and :func:`id_print`.
When using the experimental :func:`pjit.pjit` the code will run on multiple
devices on different shards of the input. The current implementation of
host callbacks will ensure that a single device will collect and outfeed
the entire operand, in a single callback. The callback function is supposed
to return the entire array, which will then be sent in a single infeed to the
same device that issued the outfeed. This device is then responsible for
sending the required shards to the other devices::
with maps.mesh(jax.local_devices()[:2], ["d"]):
pjit.pjit(power3, in_axis_resources=(P("d"),),
out_axis_resources=(P("d"),))(np.array([3., 4.]))
# device=TPU:0 what=x,x^2: ( [3., 4.],
# [9., 16.] )
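The same mechanism applies to :func:`call`. The sketch below is adapted from
the ``test_call_pjit`` test added in this change; as in that test, ``x`` is an
``i32[D, 3]`` operand sharded over ``D`` devices and ``y`` is an ``i32[3, 4]``
constant::

def callback_x5(x, device=None):
  # Invoked once, with the entire [D, 4] dot product, on the device that
  # issued the outfeed.
  return x * np.array(5, np.int32)

def fun(x):
  xy = jnp.dot(x, y)
  return call(callback_x5, xy, result_shape=xy, call_with_device=True)

with maps.mesh(jax.local_devices()[:2], ["d"]):
  pjit.pjit(fun, in_axis_resources=(P("d"),),
            out_axis_resources=P("d"))(x)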
See documentation for :func:`id_tap`, :func:`id_print`, and :func:`call`.
For more usage examples, see tests/host_callback_test.py.
Low-level details and debugging
@ -354,7 +385,9 @@ from jax.config import config
from jax import custom_derivatives
from jax._src import dtypes
from jax import lax
from jax.experimental import pjit
from jax.lib import pytree
from jax.lib import xla_bridge as xb
from jax.lib import xla_client
from jax.lib import xla_extension
from jax.interpreters import ad, xla, batching, masking, pxla
@ -780,7 +813,7 @@ outside_call_p.def_impl(_outside_call_impl)
def _outside_call_translation_rule(
comp: XlaComputationBuilder, *args_op: XlaOp, **params):
# We expect the current tokens at the end, inserted by _rewrite_jaxpr.
assert params["has_token"]
assert params["has_token"] # type: ignore[key-error]
current_token = args_op[-2]
current_itoken = args_op[-1]
# TODO: expose shape.is_token
@ -790,8 +823,9 @@ def _outside_call_translation_rule(
"The last two arguments must be tokens")
args_to_outfeed = args_op[:-2]
identity = params["identity"]
flat_results_aval = params["flat_results_aval"] if not identity else []
identity = params["identity"] # type: ignore[key-error]
flat_results_aval = params["flat_results_aval"] if not identity else [
] # type: ignore[key-error]
# Many platforms refuse to infeed empty arrays. We generate constants
# instead.
non_empty_flat_results_aval = list(filter(lambda aval: not (_aval_is_empty(aval)),
@ -814,14 +848,36 @@ def _outside_call_translation_rule(
]
if non_empty_flat_results_aval:
after_outfeed_itoken = xops.AfterAll(comp, [current_itoken, next_token])
# We shard the infeed as AssignedDevice(0). This must match the
# outfeed (from outfeed_receiver.cc). Since `lax.infeed` does not support
# this kind of sharding, we use a custom translation for infeed.
array_sharding_proto = xla_client.OpSharding()
array_sharding_proto.type = xla_client.OpSharding.Type.MAXIMAL
array_sharding_proto.tile_assignment_dimensions = [1]
array_sharding_proto.tile_assignment_devices = [0]
results_and_token = xla.translations[lax.infeed_p](comp, after_outfeed_itoken,
shapes=non_empty_flat_results_aval,
partitions=None)
token_sharding_proto = xla_client.OpSharding()
token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED
infeed_sharding_proto = xb.tuple_sharding_proto(
[array_sharding_proto] * len(non_empty_flat_results_aval) +
[token_sharding_proto])
shape = tuple(shape.with_major_to_minor_layout_if_absent()
for x in non_empty_flat_results_aval
for shape in xla.aval_to_xla_shapes(x))
build_infeed = functools.partial(xops.InfeedWithToken,
after_outfeed_itoken,
xla_client.Shape.tuple_shape(shape))
outs_and_token = xb.with_sharding_proto(comp, infeed_sharding_proto,
build_infeed)
outs = xops.GetTupleElement(outs_and_token, 0)
next_itoken = xops.GetTupleElement(outs_and_token, 1)
non_empty_results = [
xops.GetTupleElement(outs, i)
for i in range(len(non_empty_flat_results_aval))
]
expecting_infeed = True
next_itoken = xops.GetTupleElement(results_and_token, len(non_empty_flat_results_aval))
non_empty_results = [xops.GetTupleElement(results_and_token, i)
for i in range(len(non_empty_flat_results_aval))]
results = [
empty_results.pop(0) if _aval_is_empty(result_aval) else non_empty_results.pop(0)
for result_aval in flat_results_aval]
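For context, rough sketches of the two xla_bridge helpers used above, as they
are assumed to behave at the time of this change (not code from this diff):

from jax.lib import xla_client

def tuple_sharding_proto(elems):
  # An OpSharding of type TUPLE whose per-element shardings are `elems`.
  proto = xla_client.OpSharding()
  proto.type = xla_client.OpSharding.Type.TUPLE
  proto.tuple_shardings = elems
  return proto

def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
  # Build op_fn(*args, **kwargs) while the builder's sharding annotation is
  # set to `sharding_proto`, then clear the annotation afterwards.
  builder.set_sharding(sharding_proto)
  try:
    return op_fn(*args, **kwargs)
  finally:
    builder.clear_sharding()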
@ -889,7 +945,10 @@ def _outside_call_run_callback(
canonical_flat_results = tuple(util.safe_map(xla.canonicalize_dtype, actual_flat_results))
actual_flat_results_aval = _values_to_avals(canonical_flat_results)
logging.vlog(2, f"Outside call consumer {callback} result {res} : {flat_results_aval}. Sending to infeed.")
logging.vlog(
2,
f"Outside call consumer {callback} result {res} : {flat_results_aval}. Sending to infeed for device {device}."
)
if not all(ea.strip_weak_type() == ra.strip_weak_type()
for ea, ra in util.safe_zip(flat_results_aval,
@ -1318,13 +1377,26 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
eqns.append(
core.new_jaxpr_eqn(
eqn.invars + [input_token_var, input_itoken_var],
eqn.outvars + [output_token_var, output_itoken_var],
eqn.primitive,
eqn.outvars + [output_token_var, output_itoken_var], eqn.primitive,
dict(
eqn.params,
call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
),
eqn.source_info))
), eqn.source_info))
elif eqn.primitive is pjit.pjit_call_p:
call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
eqns.append(
core.new_jaxpr_eqn(
eqn.invars + [input_token_var, input_itoken_var],
eqn.outvars + [output_token_var, output_itoken_var], eqn.primitive,
dict(
eqn.params,
call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
donated_invars=eqn.params["donated_invars"] + (False, False),
in_axis_resources=(eqn.params["in_axis_resources"] +
(pjit.REPLICATED, pjit.REPLICATED)),
out_axis_resources=(eqn.params["out_axis_resources"] +
(pjit.REPLICATED, pjit.REPLICATED)),
), eqn.source_info))
else:
raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")


@ -265,9 +265,11 @@ def _pjit_translation_rule(c, axis_env, in_nodes, name_stack, backend, name,
out_nodes = xla.jaxpr_subcomp(
subc, call_jaxpr, backend, axis_env, (),
extend_name_stack(name_stack, wrap_name(name, "pjit")), *args)
out_nodes = [xb.set_sharding_proto(subc, out,
get_sharding_proto(c, n, axis_resources, mesh))
for out, axis_resources in safe_zip(out_nodes, out_axis_resources)]
out_nodes = [
xb.set_sharding_proto(subc, out,
get_sharding_proto(subc, out, axis_resources, mesh))
for out, axis_resources in safe_zip(out_nodes, out_axis_resources)
]
subc = subc.build(xops.Tuple(subc, out_nodes))
return xops.Call(c, subc, list(in_nodes))
@ -405,7 +407,11 @@ def get_array_mapping(axis_resources: ParsedPartitionSpec) -> pxla.ArrayMapping:
def get_sharding_proto(c, xla_op, axis_resources, mesh):
xla_shape = c.GetShape(xla_op)
aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.element_type())
if xla_shape.is_token():
aval = core.abstract_token
assert axis_resources is REPLICATED
else:
aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.element_type())
array_mapping = get_array_mapping(axis_resources)
sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)(
aval, array_mapping)


@ -1585,8 +1585,11 @@ def mesh_sharding_specs(axis_sizes, axis_names):
mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
# NOTE: This takes in the non-sharded avals!
def mk_sharding_spec(aval, aval_axes):
sharding = [_UNSHARDED_INSTANCE] * len(aval.shape)
mesh_mapping = [Replicated(axis_size) for axis_size in axis_sizes.values()]
if aval is core.abstract_token:
assert not aval_axes
return ShardingSpec([], mesh_mapping)
sharding = [_UNSHARDED_INSTANCE] * len(aval.shape)
next_sharded_axis = 0
aval_shape = list(aval.shape)
# NOTE: sorted is stable, which is important when multiple resources


@ -30,6 +30,9 @@ from jax._src import api
from jax.config import config
from jax import dtypes
from jax.experimental import host_callback as hcb
from jax.experimental import PartitionSpec as P
from jax.experimental import maps
from jax.experimental import pjit
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
@ -65,6 +68,7 @@ class _TestingOutputStream(object):
m = re.match(r'.*device: (\S+)', s)
if m:
by_device.append((m.group(1), []))
assert by_device, f"output does not include 'device:': {self._output}"
by_device[-1][1].append(s)
sorted_by_device = sorted(by_device, key=lambda x: x[0])
@ -99,7 +103,7 @@ def maybe_print(do_print: bool, arg, what: str, tap_with_device: Optional[bool]
return arg
def devices():
def local_devices():
# Tests require using not more than 2 devices.
return api.local_devices()[:2]
@ -162,6 +166,37 @@ def helper_print_optimized_hlo(fun, *args):
backend.compile(c).hlo_modules()[0].to_string()))
def helper_log_ir(name,
f_jax,
*args,
num_partitions=None,
strip_metadata=False):
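# Logs the jaxpr, the unoptimized HLO, and the optimized HLO for f_jax(*args).
# When num_partitions is set, compiles with SPMD partitioning over that many
# devices; strip_metadata strips HLO metadata annotations to shorten the output.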
print(f"Jaxpr[{name}]: {jax.make_jaxpr(f_jax)(*args)}")
jax_comp = jax.xla_computation(f_jax)(*args)
print(f"HLO[{name}]: {jax_comp.as_hlo_text()}")
backend = jax.lib.xla_bridge.get_backend()
if num_partitions is not None:
num_replicas = 1
device_assignment = np.arange(num_partitions * num_replicas)
device_assignment = np.reshape(device_assignment, (-1, num_partitions))
use_spmd_partitioning = num_partitions > 1
compile_options = jax.lib.xla_bridge.get_compile_options(
num_replicas=num_replicas,
num_partitions=num_partitions,
device_assignment=device_assignment,
use_spmd_partitioning=use_spmd_partitioning,
)
else:
compile_options = None
jax_optimized_hlo = backend.compile(
jax_comp, compile_options).hlo_modules()[0].to_string()
if strip_metadata:
jax_optimized_hlo = re.sub(r", metadata.*", "", jax_optimized_hlo)
print(f"Optimized HLO[{name}] for "
f"platform {backend.platform}: {jax_optimized_hlo}")
prev_xla_flags = None
@ -189,13 +224,13 @@ def assertMultiDeviceOutputEqual(tst: jtu.JaxTestCase,
device_under_test is not a CPU, then we change the names
"""
expected = expected_2CPUs
if len(devices()) == 1:
if len(local_devices()) == 1:
start_device_1 = expected.find('device: cpu:1')
if start_device_1 >= 0:
expected = expected[0:start_device_1]
def replace_device_name(m) -> str:
return str(devices()[int(m.group(1))])
return str(local_devices()[int(m.group(1))])
expected = re.sub(r'cpu:(\d+)', replace_device_name, expected)
what = testing_stream.output_sorted_by_device
@ -474,19 +509,21 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
def test_tap_jit_devices(self):
"""Running on multiple devices."""
logging.info(f"{self._testMethodName}: has devices {devices()}")
logging.info(f"{self._testMethodName}: has devices {local_devices()}")
def func(x, device_id):
x1 = hcb.id_print(x, dev=str(device_id), output_stream=testing_stream)
x2 = hcb.id_print(x1 + 1, dev=str(device_id), output_stream=testing_stream)
return x2
for d in devices():
for d in local_devices():
self.assertEqual(112, api.jit(func, device=d, static_argnums=1)(111, d.id))
hcb.barrier_wait()
logging.info(f"{self._testMethodName}: found output {testing_stream.output}")
self.assertEqual(len(devices()), len(re.findall(r"111", testing_stream.output)))
self.assertEqual(len(devices()), len(re.findall(r"112", testing_stream.output)))
self.assertEqual(
len(local_devices()), len(re.findall(r"111", testing_stream.output)))
self.assertEqual(
len(local_devices()), len(re.findall(r"112", testing_stream.output)))
testing_stream.reset()
@parameterized.named_parameters(
@ -1145,81 +1182,99 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
testing_stream.reset()
def test_tap_transforms(self):
def power(x, n):
x, n = hcb.id_print((x, n), output_stream=testing_stream)
return x * x * n * x
def f(x, n):
return x * power(x + 1., n)
def power3(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
_, y = hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
return y * x
x = 3.
print("impl = ", f(x, 2.))
print(f"impl = {power3(3.)}")
hcb.barrier_wait()
expected = """
( 4.
2. )"""
what: x,x^2
( 3.
9. )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
print("jvp = ", api.jvp(lambda x: f(x, 2.), (x,), (1.,)))
print(f"vmap = {jax.vmap(power3)(np.arange(3.))}")
hcb.barrier_wait()
expected = """
transforms: ['jvp']
( ( 4.
2. )
( 1.
0. ) )"""
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [0. 1. 2.]
[0. 1. 4.] )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
print("grad = ", api.grad(f)(x, 2.))
print(f"jvp = {jax.jvp(power3, (3.,), (0.1,))}")
hcb.barrier_wait()
expected = """
( 4.
2. )
transforms: ['jvp', 'transpose']
( 288.
192. )"""
transforms: ['jvp'] what: x,x^2
( ( 3.
9. )
( 0.1
0.6 ) )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
xv = np.array([3., 4.])
print("vmap o grad = ", api.vmap(api.grad(f))(xv, np.array([2., 3.])))
print(f"grad = {jax.grad(power3)(3.)}")
hcb.barrier_wait()
expected = """
transforms: [('batch', {'batch_dims': (0, 0)})]
( [4. 5.]
[2. 3.] )
transforms: ['jvp', 'transpose', ('batch', {'batch_dims': (0, 0)})]
( [288. 900.]
[192. 500.] )"""
what: x,x^2
( 3.
9. )
transforms: ['jvp', 'transpose'] what: x,x^2
( 0.
3. )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
print(f"vmap o grad {jax.vmap(jax.grad(power3))(np.array([2., 3.]))}")
hcb.barrier_wait()
expected = """
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [2. 3.]
[4. 9.] )
transforms: ['jvp', 'transpose', ('batch', {'batch_dims': (None, 0)})] what: x,x^2
( 0.
[2. 3.] )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
def test_tap_pmap(self):
xv = jnp.arange(len(devices()), dtype=jnp.int32)
if len(local_devices()) < 2:
raise SkipTest("test requires at least 2 devices")
def fun1(x, do_print=False): # x: i32
return maybe_print(do_print, x * 2, "x * 2", tap_with_device=True)
def power3(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
_, y = hcb.id_print((x, y),
what="x,x^2",
output_stream=testing_stream,
tap_with_device=True)
return y * x
pmap_fun1 = api.pmap(partial(fun1, do_print=True), devices=devices())
res = pmap_fun1(xv)
pmap_power3 = api.pmap(power3, devices=local_devices())
xv = np.array([3, 4], dtype=np.int32)
res = pmap_power3(xv)
hcb.barrier_wait()
expected_res = api.pmap(partial(fun1, do_print=False),
devices=devices())(xv)
self.assertAllClose(expected_res, res, check_dtypes=False)
self.assertAllClose(xv * xv * xv, res, check_dtypes=False)
# Assertion text is for 2 devices (also works for 1 device)
assertMultiDeviceOutputEqual(self, """
device: cpu:0 what: x * 2
0
device: cpu:1 what: x * 2
2""")
assertMultiDeviceOutputEqual(
self, """
device: cpu:0 what: x,x^2
( 3
9 )
device: cpu:1 what: x,x^2
( 4
16 )""")
testing_stream.reset()
def test_tap_pmap_vmap(self):
# A matrix M[ij] = i * 10 + j
nr_devices = len(devices())
nr_devices = len(local_devices())
shape = (nr_devices, 3)
matrix = np.fromfunction(lambda i, j: 10. * i + j, shape,
dtype=np.int32)
@ -1227,13 +1282,14 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
def fun1(x, do_print=False): # x: i32
return maybe_print(do_print, x * 2, "x * 2", tap_with_device=True)
pmap_vmap_fun1 = api.pmap(api.vmap(partial(fun1, do_print=True)),
devices=devices())
pmap_vmap_fun1 = api.pmap(
api.vmap(partial(fun1, do_print=True)), devices=local_devices())
res = pmap_vmap_fun1(matrix)
hcb.barrier_wait()
expected_res = api.pmap(api.vmap(partial(fun1, do_print=False)),
devices=devices())(matrix)
expected_res = api.pmap(
api.vmap(partial(fun1, do_print=False)), devices=local_devices())(
matrix)
self.assertAllClose(expected_res, res, check_dtypes=False)
# Assertion text is for 2 devices (also works for 1 device)
assertMultiDeviceOutputEqual(self, """
@ -1245,7 +1301,7 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
def test_tap_pmap_pmap_vmap(self):
# A matrix M[ijk] = i * 100 + j * 10 + k
nr_devices = len(devices())
nr_devices = len(local_devices())
if nr_devices % 2 != 0:
raise SkipTest("test works only on even number of devices")
@ -1257,12 +1313,15 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
y = maybe_print(do_print, x * 2., "x * 2", tap_with_device=True)
return y ** 2
pmap_fun1 = api.pmap(api.pmap(api.vmap(partial(fun1, do_print=True))),
devices=devices())
pmap_fun1 = api.pmap(
api.pmap(api.vmap(partial(fun1, do_print=True))),
devices=local_devices())
res = pmap_fun1(matrix)
hcb.barrier_wait()
expected_res = api.pmap(api.pmap(api.vmap(partial(fun1, do_print=False))),
devices=devices())(matrix)
expected_res = api.pmap(
api.pmap(api.vmap(partial(fun1, do_print=False))),
devices=local_devices())(
matrix)
self.assertAllClose(expected_res, res, check_dtypes=False)
# Assertion text is for 2 devices (also works for 1 device)
assertMultiDeviceOutputEqual(self, """
@ -1276,7 +1335,7 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
def test_tap_pmap_pmap_extra(self):
"""pmap of a pmap surrounded by extra code."""
# A matrix M[ij] = i * 10 + j
nr_devices = len(devices())
nr_devices = len(local_devices())
if nr_devices != 2:
raise SkipTest("test works only on 2 devices")
shape = (2, 1, 3)
@ -1312,7 +1371,7 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
def test_tap_jvp_pmap_vmap(self):
# A matrix M[ijk] = i * 100 + j * 10 + k
nr_devices = len(devices())
nr_devices = len(local_devices())
shape = (nr_devices, 2, 3)
matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape,
dtype=np.float32)
@ -1343,7 +1402,7 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
def test_tap_vmap_pmap(self):
# A matrix M[ijk] = i * 100 + j * 10 + k
nr_devices = len(devices())
nr_devices = len(local_devices())
shape = (2, nr_devices, 3)
matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape,
dtype=np.float32)
@ -1371,7 +1430,7 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
def test_tap_jit_pmap_extra(self):
"""jit of a pmap surrounded by extra code."""
# A matrix M[ij] = i * 10 + j
nr_devices = len(devices())
nr_devices = len(local_devices())
assert nr_devices in (1, 2)
shape = (nr_devices, 3)
matrix = np.fromfunction(lambda i, j: 10. * i + j, shape,
@ -1387,7 +1446,7 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
res = api.jit(partial(fun, do_print=True))(matrix)
self.assertAllClose(fun(matrix, do_print=False), res, check_dtypes=False)
hcb.barrier_wait()
if len(devices()) == 2:
if len(local_devices()) == 2:
assertMultiDeviceOutputEqual(self, """
device: cpu:0 what: before
[[ 1.00 2.00 3.00]
@ -1406,7 +1465,7 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
[[ 3.00 5.00 7.00]
[23.00 25.00 27.00]]""")
else:
assert len(devices()) == 1
assert len(local_devices()) == 1
assertMultiDeviceOutputEqual(self, """
device: cpu:0 what: before
[[1.00 2.00 3.00]]
@ -1420,7 +1479,7 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
def test_tap_cond_pmap(self):
raise SkipTest("cond of pmap does not work in JAX. Issue #5178.")
# A matrix M[ij] = i * 10 + j
nr_devices = len(devices())
nr_devices = len(local_devices())
shape = (nr_devices, 3)
matrix = np.fromfunction(lambda i, j: 10. * i + j, shape,
dtype=np.float32)
@ -1439,6 +1498,52 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
TBD""", testing_stream.output)
testing_stream.reset()
@jtu.skip_on_devices("cpu", "gpu")
# TODO(necula): file XLA:GPU bug for the 'Sharding' CustomCall
def test_tap_pjit(self):
devices = np.array(local_devices())
nr_devices = len(devices)
if nr_devices < 2:
raise SkipTest("test requires at least 2 devices")
print(f"test_tap_pjit is running on devices {devices}.")
# x: i32[D, 3] = [[0, 1, 2], [10, 11, 12], ...]
# y: i32[3, 4]
x = jnp.arange(100, dtype=jnp.int32).reshape((10, 10))[:nr_devices, :3]
y = jnp.ones((3, 4), np.int32)
@partial(jax.named_call, name="fun1") # for xprof debugging
def fun1(x, do_print=False):
z = jnp.dot(x, y)
return maybe_print(do_print, z, "z", tap_with_device=True)
res0 = fun1(x, do_print=False)
pjit_fun1 = pjit.pjit(
partial(fun1, do_print=True),
in_axis_resources=(P("d"),),
out_axis_resources=P("d"))
with maps.mesh(devices, ["d"]):
# Print the internal IR
helper_log_ir(
f"{self._testMethodName}.pjit",
pjit_fun1,
x,
num_partitions=nr_devices)
res = pjit_fun1(x)
self.assertAllClose(res0, res)
hcb.barrier_wait("before check")
# Assertion text is for 2 devices (also works for 1 device)
# Note that a single call is made.
assertMultiDeviceOutputEqual(
self, """
device: cpu:0 what: z
[[ 3 3 3 3]
[33 33 33 33]]""")
testing_stream.reset()
def test_tap_tap_scan_custom_jvp(self):
"""custom JVP, inside scan.
This exercises the custom_jvp_call_jaxpr primitives."""
@ -1933,7 +2038,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
result_shape=x,
call_with_device=True)
xv = jnp.arange(len(devices()), dtype=jnp.int32)
xv = jnp.arange(len(local_devices()), dtype=jnp.int32)
res = api.pmap(fun)(xv)
self.assertAllClose(api.pmap(lambda x: x * 6)(xv), res)
# Assertion text is for 2 devices (also works for 1 device)
@ -1954,6 +2059,53 @@ class HostCallbackCallTest(jtu.JaxTestCase):
"batching rules are implemented only for id_tap, not for call"):
api.vmap(fun)(np.ones((2, 3)))
@jtu.skip_on_devices("cpu", "gpu")
# TODO(necula): file XLA:GPU bug for the 'Sharding' CustomCall
def test_call_pjit(self):
devices = np.array(local_devices())
nr_devices = len(devices)
if nr_devices < 2:
raise SkipTest("test requires at least 2 devices")
print(f"test_call_pjit is running on devices {devices}.")
# x: i32[D, 3] = [[0, 1, 2], [10, 11, 12], ...]
# y: i32[3, 4]
x = jnp.arange(100, dtype=jnp.int32).reshape((10, 10))[:nr_devices, :3]
y = jnp.ones((3, 4), np.int32)
def callback_x5_func(x, device=None):
testing_stream.write(f"device: {device}\n Called with {x}")
return x * np.array(5, np.int32)
def fun(x):
xy = jnp.dot(x, y)
return hcb.call(
callback_x5_func, xy, result_shape=xy, call_with_device=True)
pjit_fun = pjit.pjit(
fun, in_axis_resources=(P("d"),), out_axis_resources=P("d"))
with maps.mesh(devices, ["d"]):
# Print the internal IR
helper_log_ir(
f"{self._testMethodName}.pjit",
pjit_fun,
x,
num_partitions=nr_devices)
res = pjit_fun(x)
expected_res = jnp.dot(x, y) * np.array(5, np.int32)
self.assertAllClose(expected_res, res, check_dtypes=False)
hcb.barrier_wait("before assertion")
# Assertion text is for 2 devices (also works for 1 device)
assertMultiDeviceOutputEqual(
self, """
device: cpu:0
Called with [[ 3 3 3 3]
[33 33 33 33]]""")
testing_stream.reset()
def test_call_error_bad_result_shape(self):
with self.assertRaisesRegex(
ValueError,


@ -14,6 +14,7 @@
from contextlib import contextmanager
from functools import partial
import logging
from typing import Generator, List, Tuple
from unittest import SkipTest
@ -25,6 +26,7 @@ import jax
import jax.numpy as jnp
from jax import test_util as jtu
from jax.errors import JAXTypeError
from jax import lax
# TODO(skye): do we still wanna call this PartitionSpec?
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import xmap, mesh
@ -278,6 +280,75 @@ class PJitTest(jtu.BufferDonationTestCase):
# TODO(skye): add more unit tests once API is more finalized
def testInfeed(self):
devices = np.array(jax.local_devices())
nr_devices = len(devices)
shape = (nr_devices * 3, nr_devices * 5)
def f_for_jit(x):
token = lax.create_token(x)
(y,), token = lax.infeed(
token, shape=(jax.ShapedArray(x.shape, np.float32),))
(z,), token = lax.infeed(
token, shape=(jax.ShapedArray(x.shape, np.float32),))
(w,), token = lax.infeed(
token, shape=(jax.ShapedArray(x.shape, np.float32),))
return x + y + z + w
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
# JIT
logging.info('Making jit call')
res0 = jax.jit(f_for_jit)(x)
y = x * 2.
z = x * 3.
w = x * 4.
logging.info('Transferring to infeed for the jit call')
d = devices[0]
d.transfer_to_infeed((y,))
d.transfer_to_infeed((z,))
d.transfer_to_infeed((w,))
self.assertAllClose(res0, x + y + z + w, check_dtypes=True)
# PJIT
def f_for_pjit(x):
token = lax.create_token(x)
# A replicated infeed
(y,), token = lax.infeed(
token,
shape=(jax.ShapedArray(x.shape, np.float32),),
partitions=(None,))
# An infeed sharded on first axis
(z,), token = lax.infeed(
token,
shape=(jax.ShapedArray(x.shape, np.float32),),
partitions=(P(nr_devices, 1),))
# An infeed sharded on second axis
(w,), token = lax.infeed(
token,
shape=(jax.ShapedArray(x.shape, np.float32),),
partitions=(P(1, nr_devices),))
return x + y + z + w
with mesh(devices, ['d']):
logging.info('Making pjit call')
res = pjit(
f_for_pjit, in_axis_resources=(P('d'),), out_axis_resources=P('d'))(
x)
logging.info('Transferring to infeed for the pjit call')
for didx, d in enumerate(devices):
# Transfer the whole array to all devices for replicated.
d.transfer_to_infeed((y,))
# For sharded infeed, transfer only the needed slices to each device.
d.transfer_to_infeed((z[3 * didx:3 * didx + 3, :],))
d.transfer_to_infeed((w[:, 5 * didx:5 * didx + 5],))
self.assertAllClose(res0, res, check_dtypes=True)
@curry
def check_1d_2d_mesh(f, set_mesh):
return parameterized.named_parameters(