Mirror of https://github.com/ROCm/jax.git
Add early support in pjit for single device shardings. Also lift the
restriction of needing the mesh context manager when config.jax_array
is enabled.

PiperOrigin-RevId: 465712981
parent 81b6263ed0
commit c02359b924
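As a quick orientation, here is roughly what this change enables (a minimal sketch, not code from this commit; it assumes a JAX build of this vintage, and toggling the experimental flag via `jax.config.update('jax_array', True)` is an assumption):

# Sketch: pjit-ing a computation on a single device, with no mesh
# context manager in scope. Before this commit, pjit raised
# "RuntimeError: pjit requires a non-empty mesh!" here.
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit

jax.config.update('jax_array', True)  # enable the experimental Array path

@pjit
def add(x, y):
  return x + y

a = jnp.array([1., 2., 3.])
b = jnp.array([4., 5., 6.])
out = add(a, b)  # runs with a single-device sharding; no mesh required
print(out)       # [5. 7. 9.]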
@@ -295,8 +295,13 @@ def pjit(fun: Callable,
   resource_env = pxla.thread_resources.env
   pjit_mesh = resource_env.physical_mesh
   if pjit_mesh.empty:
-    raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
-                       "it's defined at the call site?")
+    if config.jax_array:
+      # Don't enforce requiring a mesh when `jax_array` flag is enabled. But
+      # if mesh is not empty then pjit will respect it.
+      pass
+    else:
+      raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
+                         "it's defined at the call site?")

   f = lu.wrap_init(fun)
   f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=False)
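Distilled, the new check is equivalent to the following standalone sketch (`check_mesh` is a hypothetical helper written for illustration, not part of the diff):

def check_mesh(mesh_empty: bool, jax_array_enabled: bool) -> None:
  # An empty mesh is now an error only when the `jax_array` flag is off;
  # with the flag on, pjit tolerates the empty mesh and still respects a
  # non-empty one if the caller provides it.
  if mesh_empty and not jax_array_enabled:
    raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
                       "it's defined at the call site?")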
@@ -955,6 +960,7 @@ def _pjit_batcher_for_sharding(
     return OpShardingSharding(s._device_assignment, new_op)
   else:
     assert isinstance(s, OpShardingSharding)
+    assert not mesh.empty
     parsed_pspec = parse_flatten_op_sharding(s._op_sharding, mesh)[0]
     parsed_pspec = parsed_pspec.insert_axis_partitions(dim, val)
     mps = MeshPspecSharding._from_parsed_pspec(mesh, parsed_pspec)
@@ -1613,14 +1619,18 @@ def _get_partition_spec(ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[PartitionSpec]:

 def _get_op_sharding_from_executable(
     executable) -> Tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]:
-  input_op_shardings: List[xc.OpSharding] = []
-  for s in executable.hlo_modules()[0].spmd_parameters_shardings:
-    input_op_shardings.extend(_get_op_sharding(s))
+  in_op_shardings: List[xc.OpSharding] = []
+  parameter_shardings_from_xla = executable.hlo_modules()[0].spmd_parameters_shardings
+  if parameter_shardings_from_xla is not None:
+    for s in parameter_shardings_from_xla:
+      in_op_shardings.extend(_get_op_sharding(s))

-  output_op_shardings: Sequence[xc.OpSharding] = _get_op_sharding(
-      executable.hlo_modules()[0].spmd_output_sharding)
+  out_op_shardings: List[xc.OpSharding] = []
+  output_shardings_from_xla = executable.hlo_modules()[0].spmd_output_sharding
+  if output_shardings_from_xla is not None:
+    out_op_shardings = _get_op_sharding(output_shardings_from_xla)  # type: ignore

-  return input_op_shardings, output_op_shardings
+  return in_op_shardings, out_op_shardings


 def _get_ppspec_from_executable(executable, mesh) -> Tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]:
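The rewrite above guards against XLA attaching no shardings at all: `spmd_parameters_shardings` and `spmd_output_sharding` can be `None` when the SPMD partitioner never ran (the single-device case handled later in this commit), and the old code would have crashed iterating `None`. The same defensive pattern in isolation (a sketch with hypothetical names):

from typing import List, Optional, Sequence

def collect_shardings(from_xla: Optional[Sequence[str]]) -> List[str]:
  # Treat "XLA attached no shardings" (None) the same as "zero shardings"
  # rather than raising a TypeError on iteration.
  collected: List[str] = []
  if from_xla is not None:
    collected.extend(from_xla)
  return collected

assert collect_shardings(None) == []
assert collect_shardings(['replicated']) == ['replicated']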
@@ -205,6 +205,13 @@ class MeshPspecSharding(XLACompatibleSharding):
     return sharding_spec.sharding_proto(special_axes=special_axes)


+@functools.lru_cache()
+def _get_replicated_op_sharding():
+  proto = xc.OpSharding()
+  proto.type = xc.OpSharding.Type.REPLICATED
+  return proto
+
+
 class SingleDeviceSharding(XLACompatibleSharding):

   def __init__(self, device: Device):
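Wrapping the zero-argument factory in `functools.lru_cache()` turns it into a memoized singleton: every replicated sharding now shares one proto instead of allocating a fresh one per call. The pattern in isolation (a sketch with a stand-in class, not the real `xc.OpSharding`):

import functools

class StubProto:
  # Stand-in for xc.OpSharding in this sketch.
  def __init__(self):
    self.type = None

@functools.lru_cache()
def get_replicated():
  proto = StubProto()
  proto.type = 'REPLICATED'
  return proto

assert get_replicated() is get_replicated()  # one shared instance

The trade-off to keep in mind: callers now share a single mutable proto, so none of them may modify it in place.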
@@ -237,9 +244,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
     return [self._device]

   def _to_xla_op_sharding(self, num_dimensions: int) -> Optional[xc.OpSharding]:
-    proto = xc.OpSharding()
-    proto.type = xc.OpSharding.Type.REPLICATED
-    return proto
+    return _get_replicated_op_sharding()


 class PmapSharding(XLACompatibleSharding):
@@ -338,6 +343,5 @@ class OpShardingSharding(XLACompatibleSharding):

   @classmethod
   def get_replicated(cls, device_assignment):
-    proto = xc.OpSharding()
-    proto.type = xc.OpSharding.Type.REPLICATED
+    proto = _get_replicated_op_sharding()
     return cls(device_assignment, proto)
@@ -2630,11 +2630,21 @@ def _get_input_metadata(
   return shardings, input_indices, input_avals


-def _get_op_sharding_shardings_from_executable(xla_executable, device_assignment):
+def _get_op_sharding_shardings_from_executable(
+    xla_executable, device_assignment, num_in_avals, num_out_avals):
   from jax.experimental import pjit
-  from jax.experimental.sharding import OpShardingSharding
+  from jax.experimental.sharding import OpShardingSharding, SingleDeviceSharding

   in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)

+  # When the device assignment only has 1 device, SPMD partitioner will not run.
+  # Hence the op shardings will not be set on the `hlo_module`. In that case,
+  # just return SingleDeviceShardings since we know the computation is running
+  # only on 1 device.
+  if not in_op_shardings and not out_op_shardings and len(device_assignment) == 1:
+    return ([SingleDeviceSharding(device_assignment[0]) for _ in range(num_in_avals)],
+            [SingleDeviceSharding(device_assignment[0]) for _ in range(num_out_avals)])
+
   return ([OpShardingSharding(device_assignment, i) for i in in_op_shardings],
           [OpShardingSharding(device_assignment, o) for o in out_op_shardings])
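The important decision here: when XLA reported no input or output op shardings and the device assignment has exactly one device, the SPMD partitioner was skipped, so a single-device sharding per aval is the correct answer. The decision logic as a standalone sketch (`pick_shardings` and the string tags are hypothetical):

from typing import List, Sequence, Tuple

def pick_shardings(in_ops: Sequence[str], out_ops: Sequence[str],
                   devices: Sequence[str], num_in: int,
                   num_out: int) -> Tuple[List[str], List[str]]:
  # No op shardings plus a single device: the partitioner never ran, so
  # fall back to single-device shardings for every input and output.
  if not in_ops and not out_ops and len(devices) == 1:
    return (['single_device'] * num_in, ['single_device'] * num_out)
  return (list(in_ops), list(out_ops))

assert pick_shardings([], [], ['tpu:0'], 2, 1) == (['single_device'] * 2,
                                                   ['single_device'])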
@@ -2746,7 +2756,8 @@ class MeshExecutable(stages.XlaExecutable):
     elif out_shardings and all(_is_unspecified(o) for o in out_shardings):
       assert mesh is None
       in_shardings, out_shardings = _get_op_sharding_shardings_from_executable(
-          xla_executable, first_sharding._device_assignment)
+          xla_executable, first_sharding._device_assignment,
+          len(global_in_avals), len(global_out_avals))

     in_shardings, input_indices, input_avals = _get_input_metadata(
         global_in_avals, in_shardings, in_is_global)  # type: ignore
@@ -1443,6 +1443,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
       ('fully_sharded_output', P('x', 'y'), (2, 4)),
       ('fully_replicated_output', P(None), (8, 8)),
   )
+  @jax._src.config.jax_array(True)
   def test_pjit_array_single_output(self, out_axis_resources, shard_shape):
     global_input_shape = (8, 2)
     global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
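A side note on the test mechanics: JAX's boolean config states can be applied either as a decorator, as in `@jax._src.config.jax_array(True)` above, or as a context manager, as in the `with jax._src.config.jax_array(True):` blocks this commit deletes below. Both spellings appear in this diff; the sketch just juxtaposes them (assuming a JAX build of this vintage):

import jax

@jax._src.config.jax_array(True)   # flag forced on for the whole function
def run_decorated():
  pass

def run_scoped():
  with jax._src.config.jax_array(True):  # flag forced on inside the block
    pass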
@@ -1450,20 +1451,45 @@ class ArrayPjitTest(jtu.JaxTestCase):

     input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes)

-    with jax._src.config.jax_array(True):
-      with global_mesh:
-        f = pjit(lambda x: x @ x.T, out_axis_resources=MeshPspecSharding(
-            global_mesh, out_axis_resources))
-        expected_matrix_mul = input_data @ input_data.T
+    f = pjit(lambda x: x @ x.T, out_axis_resources=MeshPspecSharding(
+        global_mesh, out_axis_resources))
+    expected_matrix_mul = input_data @ input_data.T

-        out = f(input_array)
-        self.assertIsInstance(out, array.Array)
-        self.assertEqual(out.shape, (8, 8))
-        self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
-        for s in out.addressable_shards:
-          self.assertLen(s.data._arrays, 1)
-          self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
-        self.assertArraysEqual(out._value, expected_matrix_mul)
+    out = f(input_array)
+    self.assertIsInstance(out, array.Array)
+    self.assertEqual(out.shape, (8, 8))
+    self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
+    for s in out.addressable_shards:
+      self.assertLen(s.data._arrays, 1)
+      self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
+    self.assertArraysEqual(out._value, expected_matrix_mul)
+
+  @parameterized.named_parameters(
+      ('fully_sharded_output', P('x', 'y'), (2, 4)),
+      ('fully_replicated_output', P(None), (8, 8)),
+  )
+  @jax._src.config.jax_array(True)
+  def test_pjit_array_single_output_with_mesh_context_manager(
+      self, out_axis_resources, shard_shape):
+    global_input_shape = (8, 2)
+    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    mesh_axes = P('x', 'y')
+
+    input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes)
+
+    with global_mesh:
+      f = pjit(lambda x: x @ x.T, out_axis_resources=MeshPspecSharding(
+          global_mesh, out_axis_resources))
+      expected_matrix_mul = input_data @ input_data.T
+
+      out = f(input_array)
+      self.assertIsInstance(out, array.Array)
+      self.assertEqual(out.shape, (8, 8))
+      self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
+      for s in out.addressable_shards:
+        self.assertLen(s.data._arrays, 1)
+        self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
+      self.assertArraysEqual(out._value, expected_matrix_mul)

   def test_non_array_input_error(self):
     input_shape = (8, 2)
@@ -1498,6 +1524,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
       self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
     self.assertArraysEqual(out._value, input_data)

+  @jax._src.config.jax_array(True)
   def test_unspecified_out_axis_resources(self):

     def _checks(out, input_data):
@@ -1516,21 +1543,20 @@ class ArrayPjitTest(jtu.JaxTestCase):

     input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes)

-    with jax._src.config.jax_array(True):
-      with global_mesh:
-        f = pjit(lambda x: x)
+    f = pjit(lambda x: x)

-        out = f(input_array)
-        _checks(out, input_data)
+    out = f(input_array)
+    _checks(out, input_data)

-        out2 = f(out)
-        _checks(out2, input_data)
+    out2 = f(out)
+    _checks(out2, input_data)

   @parameterized.named_parameters(
       ('mesh1', (4, 2), (2, 1), (2, 2), (1, 2), (8, 2)),
       ('mesh2', (2, 2), (4, 1), (4, 2), (2, 2), (8, 2)),
       ('mesh3', (2, 1), (4, 2), (4, 2), (4, 2), (8, 2)),
   )
+  @jax._src.config.jax_array(True)
   def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape,
                                                s2_shape, s3_shape, s4_shape):
     # Disable on SE runtime type because XLA sharding propagation is not
@@ -1549,37 +1575,35 @@ class ArrayPjitTest(jtu.JaxTestCase):
     spec4 = P(None)
     a4, _ = create_array(global_input_shape, global_mesh, spec4)

-    with jax._src.config.jax_array(True):
-      with global_mesh:
-        @pjit
-        def f(tree):
-          return tree
-        out_tree = f((a1, (a2, (a3, a4))))
-        (out1, out2, out3, out4), _ = jax.tree_util.tree_flatten(out_tree)
+    @pjit
+    def f(tree):
+      return tree
+    out_tree = f((a1, (a2, (a3, a4))))
+    (out1, out2, out3, out4), _ = jax.tree_util.tree_flatten(out_tree)

-        self.assertIsInstance(out1, array.Array)
-        self.assertEqual(out1.shape, (8, 2))
-        self.assertEqual(out1.addressable_shards[0].data.shape, s1_shape)
-        for s in out1.addressable_shards:
-          self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
+    self.assertIsInstance(out1, array.Array)
+    self.assertEqual(out1.shape, (8, 2))
+    self.assertEqual(out1.addressable_shards[0].data.shape, s1_shape)
+    for s in out1.addressable_shards:
+      self.assertArraysEqual(s.data._arrays[0], input_data[s.index])

-        self.assertIsInstance(out2, array.Array)
-        self.assertEqual(out2.shape, (8, 2))
-        self.assertEqual(out2.addressable_shards[0].data.shape, s2_shape)
-        for s in out2.addressable_shards:
-          self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
+    self.assertIsInstance(out2, array.Array)
+    self.assertEqual(out2.shape, (8, 2))
+    self.assertEqual(out2.addressable_shards[0].data.shape, s2_shape)
+    for s in out2.addressable_shards:
+      self.assertArraysEqual(s.data._arrays[0], input_data[s.index])

-        self.assertIsInstance(out3, array.Array)
-        self.assertEqual(out3.shape, (8, 2))
-        self.assertEqual(out3.addressable_shards[0].data.shape, s3_shape)
-        for s in out3.addressable_shards:
-          self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
+    self.assertIsInstance(out3, array.Array)
+    self.assertEqual(out3.shape, (8, 2))
+    self.assertEqual(out3.addressable_shards[0].data.shape, s3_shape)
+    for s in out3.addressable_shards:
+      self.assertArraysEqual(s.data._arrays[0], input_data[s.index])

-        self.assertIsInstance(out4, array.Array)
-        self.assertEqual(out4.shape, (8, 2))
-        self.assertEqual(out4.addressable_shards[0].data.shape, s4_shape)
-        for s in out4.addressable_shards:
-          self.assertArraysEqual(s.data._arrays[0], input_data)
+    self.assertIsInstance(out4, array.Array)
+    self.assertEqual(out4.shape, (8, 2))
+    self.assertEqual(out4.addressable_shards[0].data.shape, s4_shape)
+    for s in out4.addressable_shards:
+      self.assertArraysEqual(s.data._arrays[0], input_data)

   def test_in_axis_resources_mismatch_error(self):
     global_input_shape = (8, 2)
@@ -1736,6 +1760,34 @@ class ArrayPjitTest(jtu.JaxTestCase):
     compiled = f.lower(jax.ShapedArray(input_shape, jnp.float32)).compile()
     compiled(a1)  # no error

+  @jax._src.config.jax_array(True)
+  def test_pjit_single_device_sharding_add(self):
+    a = jnp.array([1, 2, 3], dtype=jnp.float32)
+    b = jnp.array([4, 5, 6], dtype=jnp.float32)
+
+    @pjit
+    def add(x, y):
+      return x + y
+    out = add(a, b)
+    self.assertIsInstance(out, array.Array)
+    self.assertArraysEqual(out, a + b)
+
+    out2 = add(out, out)
+    self.assertIsInstance(out2, array.Array)
+    self.assertArraysEqual(out2, 2 * (a + b))
+
+  @jax._src.config.jax_array(True)
+  def test_pjit_single_device_sharding_mul(self):
+    a = jnp.arange(16).reshape((8, 2))
+
+    @pjit
+    def mul(x):
+      return x @ x.T
+
+    out = mul(a)
+    self.assertIsInstance(out, array.Array)
+    self.assertArraysEqual(out, a @ a.T)
+

 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")