Add early support in pjit for single-device shardings. Also lift the restriction that a mesh context manager must be present when config.jax_array is enabled.

PiperOrigin-RevId: 465712981
Yash Katariya 2022-08-05 22:24:46 -07:00 committed by jax authors
parent 81b6263ed0
commit c02359b924
4 changed files with 140 additions and 63 deletions
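What this enables, as a minimal usage sketch (illustrative only, mirroring the new single-device tests added in this commit): with the `jax_array` flag on, `pjit` can be applied to ordinary single-device inputs without wrapping the call in a mesh context manager.

# Illustrative sketch, assuming a single local device and the `jax_array` flag.
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit

with jax._src.config.jax_array(True):
  @pjit
  def add(x, y):
    return x + y

  out = add(jnp.arange(3.), jnp.arange(3.))  # no `with mesh:` required
  # `out` is a jax.experimental.array.Array, as asserted in the new tests.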


@@ -295,8 +295,13 @@ def pjit(fun: Callable,
resource_env = pxla.thread_resources.env
pjit_mesh = resource_env.physical_mesh
if pjit_mesh.empty:
- raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
-                    "it's defined at the call site?")
+ if config.jax_array:
+   # Don't enforce requiring a mesh when `jax_array` flag is enabled. But
+   # if mesh is not empty then pjit will respect it.
+   pass
+ else:
+   raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
+                      "it's defined at the call site?")
f = lu.wrap_init(fun)
f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=False)
@@ -955,6 +960,7 @@ def _pjit_batcher_for_sharding(
return OpShardingSharding(s._device_assignment, new_op)
else:
assert isinstance(s, OpShardingSharding)
+ assert not mesh.empty
parsed_pspec = parse_flatten_op_sharding(s._op_sharding, mesh)[0]
parsed_pspec = parsed_pspec.insert_axis_partitions(dim, val)
mps = MeshPspecSharding._from_parsed_pspec(mesh, parsed_pspec)
@@ -1613,14 +1619,18 @@ def _get_partition_spec(ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[Parti
def _get_op_sharding_from_executable(
executable) -> Tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]:
- input_op_shardings: List[xc.OpSharding] = []
- for s in executable.hlo_modules()[0].spmd_parameters_shardings:
-   input_op_shardings.extend(_get_op_sharding(s))
+ in_op_shardings: List[xc.OpSharding] = []
+ parameter_shardings_from_xla = executable.hlo_modules()[0].spmd_parameters_shardings
+ if parameter_shardings_from_xla is not None:
+   for s in parameter_shardings_from_xla:
+     in_op_shardings.extend(_get_op_sharding(s))
- output_op_shardings: Sequence[xc.OpSharding] = _get_op_sharding(
-     executable.hlo_modules()[0].spmd_output_sharding)
+ out_op_shardings: List[xc.OpSharding] = []
+ output_shardings_from_xla = executable.hlo_modules()[0].spmd_output_sharding
+ if output_shardings_from_xla is not None:
+   out_op_shardings = _get_op_sharding(output_shardings_from_xla)  # type: ignore
- return input_op_shardings, output_op_shardings
+ return in_op_shardings, out_op_shardings
def _get_ppspec_from_executable(executable, mesh) -> Tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]:
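The guard above matters because on a single-device compile XLA's SPMD partitioner does not run, so the HLO module carries no parameter or output shardings. A sketch of the None-tolerant readback (illustrative; `_read_input_op_shardings` is a hypothetical name, and `get_op_sharding` stands in for the existing `_get_op_sharding` helper):

# Illustrative sketch of the None-tolerant readback; not the library function.
def _read_input_op_shardings(executable, get_op_sharding):
  params = executable.hlo_modules()[0].spmd_parameters_shardings
  if params is None:
    # Single-device compile: the SPMD partitioner never ran, so there is
    # nothing to read back; the caller falls back to SingleDeviceSharding.
    return []
  return [op for s in params for op in get_op_sharding(s)]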


@@ -205,6 +205,13 @@ class MeshPspecSharding(XLACompatibleSharding):
return sharding_spec.sharding_proto(special_axes=special_axes)
+ @functools.lru_cache()
+ def _get_replicated_op_sharding():
+   proto = xc.OpSharding()
+   proto.type = xc.OpSharding.Type.REPLICATED
+   return proto
class SingleDeviceSharding(XLACompatibleSharding):
def __init__(self, device: Device):
@@ -237,9 +244,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
return [self._device]
def _to_xla_op_sharding(self, num_dimensions: int) -> Optional[xc.OpSharding]:
- proto = xc.OpSharding()
- proto.type = xc.OpSharding.Type.REPLICATED
- return proto
+ return _get_replicated_op_sharding()
class PmapSharding(XLACompatibleSharding):
@@ -338,6 +343,5 @@ class OpShardingSharding(XLACompatibleSharding):
@classmethod
def get_replicated(cls, device_assignment):
- proto = xc.OpSharding()
- proto.type = xc.OpSharding.Type.REPLICATED
+ proto = _get_replicated_op_sharding()
return cls(device_assignment, proto)
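For context, the new `_get_replicated_op_sharding` helper memoizes the replicated proto so that repeated single-device lowerings reuse one object instead of rebuilding it. A small sketch of that behavior (assuming `xla_client` is importable as shown):

import functools
from jax._src.lib import xla_client as xc

@functools.lru_cache()
def _get_replicated_op_sharding():
  # REPLICATED means the full value is present on every device.
  proto = xc.OpSharding()
  proto.type = xc.OpSharding.Type.REPLICATED
  return proto

# lru_cache hands back the same proto instance on every call.
assert _get_replicated_op_sharding() is _get_replicated_op_sharding()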


@@ -2630,11 +2630,21 @@ def _get_input_metadata(
return shardings, input_indices, input_avals
- def _get_op_sharding_shardings_from_executable(xla_executable, device_assignment):
+ def _get_op_sharding_shardings_from_executable(
+     xla_executable, device_assignment, num_in_avals, num_out_avals):
from jax.experimental import pjit
- from jax.experimental.sharding import OpShardingSharding
+ from jax.experimental.sharding import OpShardingSharding, SingleDeviceSharding
in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)
+ # When the device assignment only has 1 device, SPMD partitioner will not run.
+ # Hence the op shardings will not be set on the `hlo_module`. In that case,
+ # just return SingleDeviceShardings since we know the computation is running
+ # only on 1 device.
+ if not in_op_shardings and not out_op_shardings and len(device_assignment) == 1:
+   return ([SingleDeviceSharding(device_assignment[0]) for _ in range(num_in_avals)],
+           [SingleDeviceSharding(device_assignment[0]) for _ in range(num_out_avals)])
return ([OpShardingSharding(device_assignment, i) for i in in_op_shardings],
[OpShardingSharding(device_assignment, o) for o in out_op_shardings])
@@ -2746,7 +2756,8 @@ class MeshExecutable(stages.XlaExecutable):
elif out_shardings and all(_is_unspecified(o) for o in out_shardings):
assert mesh is None
in_shardings, out_shardings = _get_op_sharding_shardings_from_executable(
-     xla_executable, first_sharding._device_assignment)
+     xla_executable, first_sharding._device_assignment,
+     len(global_in_avals), len(global_out_avals))
in_shardings, input_indices, input_avals = _get_input_metadata(
global_in_avals, in_shardings, in_is_global) # type: ignore
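A sketch of the single-device fallback used above (illustrative; `_single_device_shardings` is a hypothetical name): when the executable reports no op shardings and the device assignment holds exactly one device, every input and output aval is described as living whole on that device.

# Illustrative sketch of the fallback; not the library code itself.
from jax.experimental.sharding import SingleDeviceSharding

def _single_device_shardings(device_assignment, num_in_avals, num_out_avals):
  assert len(device_assignment) == 1
  dev = device_assignment[0]
  return ([SingleDeviceSharding(dev) for _ in range(num_in_avals)],
          [SingleDeviceSharding(dev) for _ in range(num_out_avals)])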


@@ -1443,6 +1443,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
('fully_sharded_output', P('x', 'y'), (2, 4)),
('fully_replicated_output', P(None), (8, 8)),
)
+ @jax._src.config.jax_array(True)
def test_pjit_array_single_output(self, out_axis_resources, shard_shape):
global_input_shape = (8, 2)
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
@@ -1450,20 +1451,45 @@ class ArrayPjitTest(jtu.JaxTestCase):
input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes)
- with jax._src.config.jax_array(True):
-   with global_mesh:
-     f = pjit(lambda x: x @ x.T, out_axis_resources=MeshPspecSharding(
-         global_mesh, out_axis_resources))
-     expected_matrix_mul = input_data @ input_data.T
+ f = pjit(lambda x: x @ x.T, out_axis_resources=MeshPspecSharding(
+     global_mesh, out_axis_resources))
+ expected_matrix_mul = input_data @ input_data.T
-     out = f(input_array)
-     self.assertIsInstance(out, array.Array)
-     self.assertEqual(out.shape, (8, 8))
-     self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
-     for s in out.addressable_shards:
-       self.assertLen(s.data._arrays, 1)
-       self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
-     self.assertArraysEqual(out._value, expected_matrix_mul)
+ out = f(input_array)
+ self.assertIsInstance(out, array.Array)
+ self.assertEqual(out.shape, (8, 8))
+ self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
+ for s in out.addressable_shards:
+   self.assertLen(s.data._arrays, 1)
+   self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
+ self.assertArraysEqual(out._value, expected_matrix_mul)
+ @parameterized.named_parameters(
+   ('fully_sharded_output', P('x', 'y'), (2, 4)),
+   ('fully_replicated_output', P(None), (8, 8)),
+ )
+ @jax._src.config.jax_array(True)
+ def test_pjit_array_single_output_with_mesh_context_manager(
+     self, out_axis_resources, shard_shape):
+   global_input_shape = (8, 2)
+   global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+   mesh_axes = P('x', 'y')
+   input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes)
+   with global_mesh:
+     f = pjit(lambda x: x @ x.T, out_axis_resources=MeshPspecSharding(
+         global_mesh, out_axis_resources))
+     expected_matrix_mul = input_data @ input_data.T
+     out = f(input_array)
+     self.assertIsInstance(out, array.Array)
+     self.assertEqual(out.shape, (8, 8))
+     self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
+     for s in out.addressable_shards:
+       self.assertLen(s.data._arrays, 1)
+       self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
+     self.assertArraysEqual(out._value, expected_matrix_mul)
def test_non_array_input_error(self):
input_shape = (8, 2)
@@ -1498,6 +1524,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
self.assertArraysEqual(out._value, input_data)
+ @jax._src.config.jax_array(True)
def test_unspecified_out_axis_resources(self):
def _checks(out, input_data):
@@ -1516,21 +1543,20 @@ class ArrayPjitTest(jtu.JaxTestCase):
input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes)
- with jax._src.config.jax_array(True):
-   with global_mesh:
-     f = pjit(lambda x: x)
+ f = pjit(lambda x: x)
-     out = f(input_array)
-     _checks(out, input_data)
+ out = f(input_array)
+ _checks(out, input_data)
-     out2 = f(out)
-     _checks(out2, input_data)
+ out2 = f(out)
+ _checks(out2, input_data)
@parameterized.named_parameters(
('mesh1', (4, 2), (2, 1), (2, 2), (1, 2), (8, 2)),
('mesh2', (2, 2), (4, 1), (4, 2), (2, 2), (8, 2)),
('mesh3', (2, 1), (4, 2), (4, 2), (4, 2), (8, 2)),
)
+ @jax._src.config.jax_array(True)
def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape,
s2_shape, s3_shape, s4_shape):
# Disable on SE runtime type because XLA sharding propagation is not
@@ -1549,37 +1575,35 @@ class ArrayPjitTest(jtu.JaxTestCase):
spec4 = P(None)
a4, _ = create_array(global_input_shape, global_mesh, spec4)
- with jax._src.config.jax_array(True):
-   with global_mesh:
-     @pjit
-     def f(tree):
-       return tree
-     out_tree = f((a1, (a2, (a3, a4))))
-     (out1, out2, out3, out4), _ = jax.tree_util.tree_flatten(out_tree)
+ @pjit
+ def f(tree):
+   return tree
+ out_tree = f((a1, (a2, (a3, a4))))
+ (out1, out2, out3, out4), _ = jax.tree_util.tree_flatten(out_tree)
-     self.assertIsInstance(out1, array.Array)
-     self.assertEqual(out1.shape, (8, 2))
-     self.assertEqual(out1.addressable_shards[0].data.shape, s1_shape)
-     for s in out1.addressable_shards:
-       self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
+ self.assertIsInstance(out1, array.Array)
+ self.assertEqual(out1.shape, (8, 2))
+ self.assertEqual(out1.addressable_shards[0].data.shape, s1_shape)
+ for s in out1.addressable_shards:
+   self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
-     self.assertIsInstance(out2, array.Array)
-     self.assertEqual(out2.shape, (8, 2))
-     self.assertEqual(out2.addressable_shards[0].data.shape, s2_shape)
-     for s in out2.addressable_shards:
-       self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
+ self.assertIsInstance(out2, array.Array)
+ self.assertEqual(out2.shape, (8, 2))
+ self.assertEqual(out2.addressable_shards[0].data.shape, s2_shape)
+ for s in out2.addressable_shards:
+   self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
-     self.assertIsInstance(out3, array.Array)
-     self.assertEqual(out3.shape, (8, 2))
-     self.assertEqual(out3.addressable_shards[0].data.shape, s3_shape)
-     for s in out3.addressable_shards:
-       self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
+ self.assertIsInstance(out3, array.Array)
+ self.assertEqual(out3.shape, (8, 2))
+ self.assertEqual(out3.addressable_shards[0].data.shape, s3_shape)
+ for s in out3.addressable_shards:
+   self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
-     self.assertIsInstance(out4, array.Array)
-     self.assertEqual(out4.shape, (8, 2))
-     self.assertEqual(out4.addressable_shards[0].data.shape, s4_shape)
-     for s in out4.addressable_shards:
-       self.assertArraysEqual(s.data._arrays[0], input_data)
+ self.assertIsInstance(out4, array.Array)
+ self.assertEqual(out4.shape, (8, 2))
+ self.assertEqual(out4.addressable_shards[0].data.shape, s4_shape)
+ for s in out4.addressable_shards:
+   self.assertArraysEqual(s.data._arrays[0], input_data)
def test_in_axis_resources_mismatch_error(self):
global_input_shape = (8, 2)
@@ -1736,6 +1760,34 @@ class ArrayPjitTest(jtu.JaxTestCase):
compiled = f.lower(jax.ShapedArray(input_shape, jnp.float32)).compile()
compiled(a1) # no error
+ @jax._src.config.jax_array(True)
+ def test_pjit_single_device_sharding_add(self):
+   a = jnp.array([1, 2, 3], dtype=jnp.float32)
+   b = jnp.array([4, 5, 6], dtype=jnp.float32)
+   @pjit
+   def add(x, y):
+     return x + y
+   out = add(a, b)
+   self.assertIsInstance(out, array.Array)
+   self.assertArraysEqual(out, a + b)
+   out2 = add(out, out)
+   self.assertIsInstance(out2, array.Array)
+   self.assertArraysEqual(out2, 2 * (a + b))
+ @jax._src.config.jax_array(True)
+ def test_pjit_single_device_sharding_mul(self):
+   a = jnp.arange(16).reshape((8, 2))
+   @pjit
+   def mul(x):
+     return x @ x.T
+   out = mul(a)
+   self.assertIsInstance(out, array.Array)
+   self.assertArraysEqual(out, a @ a.T)
def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")