From 973171bb6d76e1ae25a00b47e2eed3e1e0f204ca Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 1 Jun 2021 14:32:59 +0300 Subject: [PATCH] [jax2tf] Add support for pjit. --- jax/experimental/jax2tf/jax2tf.py | 79 +++++++++-- .../jax2tf/tests/sharding_test.py | 134 ++++++++++++++++-- jax/experimental/pjit.py | 9 +- jax/test_util.py | 44 +++++- tests/pjit_test.py | 84 ++++------- tests/xmap_test.py | 88 ++++-------- 6 files changed, 305 insertions(+), 133 deletions(-) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 790d49c17..328d374b5 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -31,6 +31,8 @@ from jax._src.lax import lax from jax._src.lax import linalg as lax_linalg import jax._src.random from jax.api_util import flatten_fun +from jax.experimental import maps +from jax.experimental import pjit from jax.interpreters import ad from jax.interpreters import pxla from jax.interpreters import sharded_jit @@ -397,9 +399,6 @@ def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args: TfVal) -> Sequence[TfVal]: return tuple(v for v, _ in out_with_avals) -### tracer - - def _aval_to_tf_shape(aval: core.AbstractValue) -> Tuple[Optional[int], ...]: """Generate a TF shape, possibly containing None for polymorphic dimensions.""" return tuple( @@ -729,8 +728,9 @@ class TensorFlowTrace(core.Trace): def process_call(self, call_primitive: core.Primitive, f: lu.WrappedFun, tracers: Sequence[TensorFlowTracer], params): assert call_primitive.multiple_results - vals: Sequence[TfVal] = [t.val for t in tracers] - f = _interpret_subtrace(f, self.main, tuple(t.aval for t in tracers)) + vals: Sequence[TfVal] = tuple(t.val for t in tracers) + avals: Sequence[core.AbstractValue] = tuple(t.aval for t in tracers) + f = _interpret_subtrace(f, self.main, avals) with core.new_sublevel(): if call_primitive == core.named_call_p: with tf.name_scope(_sanitize_scope_name(params["name"])): @@ -813,6 +813,8 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): for unexpected in xla.call_translations: # Call primitives are inlined + if unexpected is pjit.pjit_p: + continue tf_impl[unexpected] = functools.partial(_unexpected_primitive, unexpected) # Primitives that are not yet implemented must be explicitly declared here. @@ -2494,8 +2496,8 @@ def split_to_logical_devices(tensor: TfVal, Returns: an annotated tensor. """ - # This corresponds to the sharding annotations in - # xla_bridge._sharding_to_proto. + # TODO: this is only for sharded_jit. Either remove, or implement in terms + # of _shard_values. if partition_dimensions is None: return xla_sharding.replicate(tensor, use_sharding_op=True) num_partition_splits = np.prod(partition_dimensions) @@ -2504,6 +2506,24 @@ def split_to_logical_devices(tensor: TfVal, return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True) +def _shard_value(mesh: maps.Mesh, + val: TfVal, + aval: core.AbstractValue, + axis_resources: pjit.ParsedPartitionSpec) -> TfVal: + """Apply sharding to a TfVal.""" + sharding_proto: xla_client.OpSharding = pjit.get_sharding_proto_aval( + aval, axis_resources, mesh) + # To use xla_sharding.py, we must have a xla_data_pb2.OpSharding. 
+ xla_sharding_proto: xla_data_pb2.OpSharding = ( + xla_data_pb2.OpSharding( + type=int(sharding_proto.type), + tile_assignment_dimensions=sharding_proto.tile_assignment_dimensions, + tile_assignment_devices=sharding_proto.tile_assignment_devices, + replicate_on_last_tile_dim=sharding_proto.replicate_on_last_tile_dim)) + return xla_sharding.Sharding(proto=xla_sharding_proto).apply_to_tensor( + val, use_sharding_op=True) + + def _sharded_call(f: lu.WrappedFun, vals: Sequence[TfVal], in_parts: Sequence[pxla.PartitionsOrReplicated], out_parts_thunk, @@ -2520,12 +2540,51 @@ def _sharded_call(f: lu.WrappedFun, vals: Sequence[TfVal], return sharded_vals_out -def _sharding_constraint(arg: TfVal, *, - partitions: pxla.PartitionsOrReplicated): +def _sharded_jit_sharding_constraint(arg: TfVal, *, + partitions: pxla.PartitionsOrReplicated, + _in_avals: Sequence[core.ShapedArray], + _out_aval: core.ShapedArray): + del _in_avals, _out_aval return split_to_logical_devices(arg, partitions) -tf_impl[sharded_jit.sharding_constraint_p] = _sharding_constraint +tf_impl_with_avals[sharded_jit.sharding_constraint_p] = _sharded_jit_sharding_constraint + + +def _pjit(*args: TfVal, + jaxpr: core.ClosedJaxpr, + in_axis_resources: Sequence[pjit.ParsedPartitionSpec], + out_axis_resources: Sequence[pjit.ParsedPartitionSpec], + resource_env: maps.ResourceEnv, + donated_invars, + name: str, + _in_avals: Sequence[core.ShapedArray], + _out_aval: core.ShapedArray) -> TfVal: + del donated_invars, name + # TODO: add `name` to the name stack + shard_value_for_mesh = functools.partial(_shard_value, resource_env.physical_mesh) + # Apply sharding annotation to the arguments + sharded_args: Sequence[TfVal] = tuple( + util.safe_map(shard_value_for_mesh, args, _in_avals, in_axis_resources)) + results = _interpret_jaxpr(jaxpr, *sharded_args) + sharded_results: Sequence[TfVal] = tuple( + util.safe_map(shard_value_for_mesh, results, _out_aval, out_axis_resources)) + return tuple(sharded_results) + + +tf_impl_with_avals[pjit.pjit_p] = _pjit + + +def _pjit_sharding_constraint(arg: TfVal, *, + axis_resources: pjit.ParsedPartitionSpec, + resource_env: maps.ResourceEnv, + _in_avals: Sequence[core.ShapedArray], + _out_aval: core.ShapedArray, + **kwargs) -> TfVal: + return _shard_value(resource_env.physical_mesh, arg, _in_avals[0], axis_resources) + + +tf_impl_with_avals[pjit.sharding_constraint_p] = _pjit_sharding_constraint def _register_checkpoint_pytrees(): diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index eb4b0a4e3..8fa835831 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -13,19 +13,24 @@ # limitations under the License. 
"""Tests for the jax2tf conversion of sharded_jit.""" +import contextlib +import functools import logging import re -from typing import Sequence +from typing import Any, Generator, List, Sequence, Tuple import unittest from absl.testing import absltest import jax from jax import test_util as jtu +from jax import util from jax.config import config from jax.experimental import jax2tf from jax.experimental.jax2tf.tests import tf_test_util +from jax.experimental import maps +from jax.experimental import pjit from jax.interpreters import sharded_jit from jax.interpreters.sharded_jit import PartitionSpec as P import jax.numpy as jnp @@ -36,6 +41,13 @@ import tensorflow as tf # type: ignore[import] config.parse_flags_with_absl() +def setUpModule(): + jtu.set_spmd_lowering_flag(True) + +def tearDownModule(): + jtu.restore_spmd_lowering_flag() + + LOG_HLO = True class ShardedJitHloTest(tf_test_util.JaxToTfTestCase): @@ -46,7 +58,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase): def _check_sharding_annotations(self, f_jax, - args, + args: Sequence[Any], *, expected: Sequence[str], expected_opt: Sequence[str], @@ -57,6 +69,9 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase): We currently check the unoptimized HLO against `expected` on CPU and TPU, and we check the optimized HLO against `expected_opt` on TPU only and only for JAX. + + See `self.AssertShardingAnnotations` for documentation of `expected` + and `expected_opt`. """ if jtu.device_under_test() == "gpu": raise unittest.SkipTest("Sharding HLO tests not useful for GPU") @@ -93,20 +108,20 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase): experimental_get_compiler_ir(*args)(stage="hlo", device_name=device_name) if LOG_HLO: - logging.info(f"[{self._testMethodName}] got TF OPT HLO {tf_hlo}") + logging.info(f"[{self._testMethodName}] got TF HLO {tf_hlo}") self.AssertShardingAnnotations("TF before optimizations", tf_hlo, expected) tf_optimized_hlo = tf.function(f_tf, jit_compile=True).\ experimental_get_compiler_ir(*args)(stage="optimized_hlo", device_name=device_name) if LOG_HLO: - logging.info(f"[{self._testMethodName}] XXX got TF OPT HLO " + logging.info(f"[{self._testMethodName}] got TF optimized HLO " f"for {device_name}: {tf_optimized_hlo}") def AssertShardingAnnotations(self, what: str, hlo: str, expected: Sequence[str]): """Args: - what: either 'JAX' or 'TF' + what: either 'JAX' or 'TF', used for messages only. hlo: the text for the HLO module expected: a sequence of regexps that must occur in the hlo text. Each regexp must match a line, in order. @@ -128,7 +143,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase): f"!!! 
Not found[{next_expected_idx}] {expected[next_expected_idx]}") raise self.failureException("\n".join(failure_msg)) - def test_in_out(self): + def test_sharded_jit_in_out(self): """Test input and output sharding annotations.""" sharded_jax_func = sharded_jit.sharded_jit( jnp.dot, in_parts=(P(1, 2), P(2, 1)), out_parts=P(1, 2)) @@ -152,7 +167,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase): ], num_partitions=2) - def test_with_sharding_constraint(self): + def test_sharded_jit_with_sharding_constraint(self): """A sharding constraint in the middle.""" def jax_func(x, y): @@ -180,7 +195,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase): ], num_partitions=2) - def test_replicated(self): + def test_sharded_jit_replicated(self): """A replicated input and output.""" sharded_jax_func = sharded_jit.sharded_jit( @@ -203,6 +218,109 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase): ], num_partitions=2) + @jtu.with_mesh([('x', 2)]) + def test_pjit_basic1D(self): + + @functools.partial(pjit.pjit, + in_axis_resources=(P('x'), P('x')), + out_axis_resources=None) + def jax_func(x, y): + return x + y + + shape = (8, 10) + x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + hlo = jax.xla_computation(jax_func)(x, x).as_hlo_text() + print(f"HLO is {hlo}") + print(f"JAXPR is {jax.make_jaxpr(jax_func)(x, x)}") + self._check_sharding_annotations( + jax_func, [x, x], + expected=[ + r"f32\[8,10\].*sharding={devices=\[2,1\]", + r"f32\[8,10\].*sharding={replicated", # output + ], + expected_opt=[ + # TODO(necula): relax ordering + r"f32\[4,10\].*sharding={devices=\[2,1\]", + r"f32\[6,4\].*sharding={devices=\[1,2\]", + ], + num_partitions=2) + + @jtu.with_mesh([('x', 2), ('y', 2)]) + def test_pjit_basic2D(self): + @functools.partial(pjit.pjit, + in_axis_resources=(P(None, 'x', 'y'), P('y')), + out_axis_resources=P('x')) + def jax_func(x, y): + return x @ y + + x_shape = (8, 6, 4) + y_shape = (4, 2) + x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape) + y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape) + self._check_sharding_annotations( + jax_func, [x, y], + expected=[ + r"f32\[8,6,4\].*sharding={devices=\[1,2,2\]0,1,2,3", # arg0 + r"f32\[4,2\].*sharding={devices=\[2,1,2\]0,2,1,3 last_tile_dim_replicate", # arg1 + r"f32\[8,6,2\].*sharding={devices=\[2,1,1,2\]0,1,2,3 last_tile_dim_replicate", # output + ], + expected_opt=[ + # TODO(necula): relax ordering + r"f32\[4,10\].*sharding={devices=\[2,1\]", + r"f32\[6,4\].*sharding={devices=\[1,2\]", + ], + num_partitions=2) + + @jtu.with_mesh([('x', 2), ('y', 2)]) + def test_pjit_TwoMeshAxisSharding(self): + @functools.partial(pjit.pjit, + in_axis_resources=P(('x', 'y'),), + out_axis_resources=P(('x', 'y'),)) + def jax_func(x, y): + return x @ y + + x_shape = (24, 8) + y_shape = (8, 2) + x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape) + y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape) + self._check_sharding_annotations( + jax_func, [x, y], + expected=[ + r"f32\[24,8\].*sharding={devices=\[4,1\]0,1,2,3", # x + r"f32\[8,2\].*sharding={devices=\[4,1\]0,1,2,3", # y + r"f32\[24,2\].*sharding={devices=\[4,1\]0,1,2,3", # output + ], + expected_opt=[ + # TODO(necula): relax ordering + r"f32\[4,10\].*sharding={devices=\[2,1\]", + r"f32\[6,4\].*sharding={devices=\[1,2\]", + ], + num_partitions=2) + + @jtu.with_mesh([('x', 2), ('y', 1)]) + def test_pjit_ShardingConstraint(self): + @functools.partial(pjit.pjit, in_axis_resources=None, + out_axis_resources=None) + def 
jax_func(x): + y = jnp.tile(x, (2, 1)) + y = pjit.with_sharding_constraint(y, P('x', 'y')) + return y * 2 + + shape = (12, 8) + x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + self._check_sharding_annotations( + jax_func, [x], + expected=[ + r"f32\[12,8\].*sharding={replicated}", # x + r"f32\[24,8\].*sharding={devices=\[2,1\]0,1", # y + r"f32\[24,8\].*sharding={replicated}", # output + ], + expected_opt=[ + # TODO(necula): relax ordering + r"f32\[4,10\].*sharding={devices=\[2,1\]", + r"f32\[6,4\].*sharding={devices=\[1,2\]", + ], + num_partitions=2) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index 0766df952..250325f79 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -590,13 +590,20 @@ def get_array_mapping(axis_resources: ParsedPartitionSpec) -> pxla.ArrayMapping: for i, axes in enumerate(axis_resources) for axis in axes) -def get_sharding_proto(c, xla_op, axis_resources, mesh): +def get_sharding_proto(c, xla_op, axis_resources: ParsedPartitionSpec, + mesh: maps.Mesh) -> xc.OpSharding: xla_shape = c.GetShape(xla_op) if xla_shape.is_token(): aval = core.abstract_token assert axis_resources is REPLICATED else: aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.element_type()) + return get_sharding_proto_aval(aval, axis_resources, mesh) + + +def get_sharding_proto_aval(aval: core.AbstractValue, + axis_resources: ParsedPartitionSpec, + mesh: maps.Mesh) -> xc.OpSharding: array_mapping = get_array_mapping(axis_resources) sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)( aval, array_mapping) diff --git a/jax/test_util.py b/jax/test_util.py index 19afbe8bf..e073a2088 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -17,7 +17,7 @@ import functools import re import os import textwrap -from typing import Dict, Sequence, Union +from typing import Dict, List, Generator, Sequence, Tuple, Union import unittest import warnings import zlib @@ -33,10 +33,12 @@ from . import core from ._src import dtypes as _dtypes from . import lax from ._src.config import flags, bool_env, config -from ._src.util import partial, prod +from ._src.util import partial, prod, unzip2 from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce from .lib import xla_bridge from .interpreters import xla +from .experimental import maps +from .experimental.maps import mesh FLAGS = flags.FLAGS @@ -1012,6 +1014,44 @@ def ignore_warning(**kw): warnings.filterwarnings("ignore", **kw) yield +# -------------------- Mesh parametrization helpers -------------------- + +MeshSpec = List[Tuple[str, int]] + +@contextmanager +def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]: + """Test utility for setting up meshes given mesh data from `schedules`.""" + # This is similar to the `with_mesh` function above, but isn't a decorator. 
+ axis_names, shape = unzip2(named_shape) + size = prod(shape) + local_devices = list(api.local_devices()) + if len(local_devices) < size: + raise unittest.SkipTest(f"Test requires {size} local devices") + mesh_devices = np.array(local_devices[:size]).reshape(shape) + with mesh(mesh_devices, axis_names): + yield + +def with_mesh_from_kwargs(f): + return lambda *args, **kwargs: with_mesh(kwargs['mesh'])(f)(*args, **kwargs) + +def with_and_without_mesh(f): + return parameterized.named_parameters( + {"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources} + for name, mesh, axis_resources in ( + ('', (), ()), + ('Mesh', (('x', 2),), (('i', 'x'),)) + ))(with_mesh_from_kwargs(f)) + +old_spmd_lowering_flag = False +def set_spmd_lowering_flag(val: bool): + global old_spmd_lowering_flag + maps.make_xmap_callable.cache_clear() + old_spmd_lowering_flag = maps.EXPERIMENTAL_SPMD_LOWERING + maps.EXPERIMENTAL_SPMD_LOWERING = val + +def restore_spmd_lowering_flag(): + maps.make_xmap_callable.cache_clear() + maps.EXPERIMENTAL_SPMD_LOWERING = old_spmd_lowering_flag class _cached_property: null = object() diff --git a/tests/pjit_test.py b/tests/pjit_test.py index d028ab2ca..e4862dfbe 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -39,40 +39,16 @@ config.parse_flags_with_absl() def setUpModule(): - global old_lowering_flag - jax.experimental.maps.make_xmap_callable.cache_clear() - old_lowering_flag = jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING - jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True + jtu.set_spmd_lowering_flag(True) def tearDownModule(): - jax.experimental.maps.make_xmap_callable.cache_clear() - jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = old_lowering_flag - - -# TODO(skye): move into test_util and dedup with xmap_test.py -MeshSpec = List[Tuple[str, int]] - -@contextmanager -def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]: - """Test utility for setting up meshes given mesh data from `schedules`.""" - # This is similar to the `with_mesh` function above, but isn't a decorator. 
- axis_names, shape = unzip2(named_shape) - size = prod(shape) - local_devices = list(jax.local_devices()) - if len(local_devices) < size: - raise SkipTest(f"Test requires {size} local devices") - mesh_devices = np.array(local_devices[:size]).reshape(shape) - with mesh(mesh_devices, axis_names): - yield - -def with_mesh_from_kwargs(f): - return lambda *args, **kwargs: with_mesh(kwargs['mesh'])(f)(*args, **kwargs) + jtu.restore_spmd_lowering_flag() # TODO(skye): make the buffer donation utils part of JaxTestCase class PJitTest(jtu.BufferDonationTestCase): - @with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testBasic1D(self): @partial(pjit, in_axis_resources=(P('x'), P('x')), @@ -90,7 +66,7 @@ class PJitTest(jtu.BufferDonationTestCase): self.assertAllClose(actual.device_buffers[0].to_py(), expected, check_dtypes=False) - @with_mesh([('x', 2), ('y', 2)]) + @jtu.with_mesh([('x', 2), ('y', 2)]) def testBasic2D(self): @partial(pjit, in_axis_resources=(P(None, 'x', 'y'), P('y')), @@ -118,7 +94,7 @@ class PJitTest(jtu.BufferDonationTestCase): self.assertAllClose(actual.device_buffers[3].to_py(), split1, check_dtypes=False) - @with_mesh([('x', 2), ('y', 2)]) + @jtu.with_mesh([('x', 2), ('y', 2)]) def testTwoMeshAxisSharding(self): @partial(pjit, in_axis_resources=P(('x', 'y'),), @@ -144,7 +120,7 @@ class PJitTest(jtu.BufferDonationTestCase): self.assertAllClose(actual.device_buffers[3].to_py(), splits[3], check_dtypes=False) - @with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testBufferDonation(self): @partial(pjit, in_axis_resources=P('x'), @@ -162,7 +138,7 @@ class PJitTest(jtu.BufferDonationTestCase): self.assertNotDeleted(y) self.assertDeleted(x) - @with_mesh([('x', 2), ('y', 1)]) + @jtu.with_mesh([('x', 2), ('y', 1)]) def testShardingConstraint(self): @partial(pjit, in_axis_resources=None, out_axis_resources=None) def f(x): @@ -186,7 +162,7 @@ class PJitTest(jtu.BufferDonationTestCase): # Annotation from pjit self.assertIn("sharding={replicated}", hlo.as_hlo_text()) - @with_mesh([('x', 2), ('y', 1)]) + @jtu.with_mesh([('x', 2), ('y', 1)]) def testShardingConstraintPyTree(self): @partial(pjit, in_axis_resources=None, out_axis_resources=None) def f(x): @@ -232,7 +208,7 @@ class PJitTest(jtu.BufferDonationTestCase): should_be_tracing = False pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x) - @with_mesh([('x', 2), ('y', 1)]) + @jtu.with_mesh([('x', 2), ('y', 1)]) def testNested(self): # Add a constant captured by the nested pjit to make things more complicated h = jnp.arange(4) @@ -243,7 +219,7 @@ class PJitTest(jtu.BufferDonationTestCase): self.assertAllClose(y, jnp.sin(x).sum() + h.sum()) self.assertTrue(hasattr(y, "sharding_spec")) - @with_mesh([('x', 2), ('y', 1)]) + @jtu.with_mesh([('x', 2), ('y', 1)]) def testJVP(self): # Add a constant captured by the nested pjit to make things more complicated h = jnp.arange(4) @@ -252,7 +228,7 @@ class PJitTest(jtu.BufferDonationTestCase): jtu.check_grads(g, (jnp.arange(16, dtype=jnp.float32).reshape((4, 4)),), order=2, modes=["fwd"], eps=1) - @with_mesh([('x', 2), ('y', 1)]) + @jtu.with_mesh([('x', 2), ('y', 1)]) def testEvalJaxpr(self): x, y = jnp.arange(4), jnp.arange(5) f = pjit(lambda x, y: x.sum() + jnp.sin(y), @@ -263,13 +239,13 @@ class PJitTest(jtu.BufferDonationTestCase): r, = f_eval(x, y) self.assertAllClose(r, x.sum() + jnp.sin(y)) - @with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testNonArrayArg(self): self.assertEqual(pjit(lambda x: x + 2, in_axis_resources=None, out_axis_resources=None)(1), 3) - 
@with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testNonHashableAxisResources(self): x = jnp.arange(4) y = pjit(lambda x: {'b': x['a'] + 2}, @@ -277,7 +253,7 @@ class PJitTest(jtu.BufferDonationTestCase): out_axis_resources={'b': P('x')})({'a': x}) self.assertAllClose(y, {'b': x + 2}) - @with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testGradOfConstraint(self): # Make sure that we can compute grads through sharding constraints h = lambda x: jnp.sin(with_sharding_constraint(x, P('x'))).sum() @@ -286,7 +262,7 @@ class PJitTest(jtu.BufferDonationTestCase): x = jnp.arange(8, dtype=jnp.float32) self.assertAllClose(f(x), jnp.cos(x)) - @with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testNoopPartitionSpecs(self): noops = [P(), P(None), P(()), P((), None), P(None, None, ())] x = jnp.arange(8).reshape((2, 2, 2)) @@ -294,7 +270,7 @@ class PJitTest(jtu.BufferDonationTestCase): y = pjit(lambda x: x * 2, in_axis_resources=spec, out_axis_resources=spec)(x) self.assertAllClose(y, x * 2) - @with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testVmapModifiesAxisResources(self): h = pjit(lambda x, y: (x + y, x, y), in_axis_resources=P('x'), out_axis_resources=None) x = jnp.arange(4) @@ -310,7 +286,7 @@ class PJitTest(jtu.BufferDonationTestCase): self.assertEqual(y_sync, SpecSync.IN_SYNC) self.assertEqual(z_sync, SpecSync.DIM_PERMUTE) - @with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testVMap(self): f = pjit(lambda x, y: (x + y, x), in_axis_resources=P('x'), out_axis_resources=P('x')) x = jnp.arange(4) @@ -402,7 +378,7 @@ def check_1d_2d_mesh(f, set_mesh): ("2", (("x", 2),), "x"), ("2x1", (("x", 2), ("y", 1)), ("x", "y")), ("2x2", (("x", 2), ("y", 2)), ("x", "y")), - ))(with_mesh_from_kwargs(f) if set_mesh else f) + ))(jtu.with_mesh_from_kwargs(f) if set_mesh else f) def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)") @@ -444,7 +420,7 @@ class PJitErrorTest(jtu.JaxTestCase): in_axis_resources=None, out_axis_resources=None)(x) @check_1d_2d_mesh(set_mesh=False) - @with_mesh([('z', 1)]) + @jtu.with_mesh([('z', 1)]) def testUndefinedResourcesArgs(self, mesh, resources): x = jnp.ones((2, 2)) spec = P(resources,) @@ -454,7 +430,7 @@ class PJitErrorTest(jtu.JaxTestCase): pjit(lambda x: x, in_axis_resources=spec, out_axis_resources=None)(x) @check_1d_2d_mesh(set_mesh=False) - @with_mesh([('z', 1)]) + @jtu.with_mesh([('z', 1)]) def testUndefinedResourcesOuts(self, mesh, resources): x = jnp.ones((2, 2)) spec = P(resources,) @@ -464,7 +440,7 @@ class PJitErrorTest(jtu.JaxTestCase): pjit(lambda x: x, in_axis_resources=None, out_axis_resources=spec)(x) @check_1d_2d_mesh(set_mesh=False) - @with_mesh([('z', 1)]) + @jtu.with_mesh([('z', 1)]) def testUndefinedResourcesConstraint(self, mesh, resources): x = jnp.ones((2, 2)) spec = P(resources,) @@ -475,7 +451,7 @@ class PJitErrorTest(jtu.JaxTestCase): pjit(lambda x: with_sharding_constraint(x, spec), in_axis_resources=None, out_axis_resources=None)(x) - @with_mesh([('x', 2), ('y', 1)]) + @jtu.with_mesh([('x', 2), ('y', 1)]) def testRankTooLowArgs(self): x = jnp.arange(2) spec = P('x', 'y') @@ -484,7 +460,7 @@ class PJitErrorTest(jtu.JaxTestCase): with self.assertRaisesRegex(ValueError, error): pjit(lambda x: x.sum(), in_axis_resources=spec, out_axis_resources=None)(x) - @with_mesh([('x', 2), ('y', 1)]) + @jtu.with_mesh([('x', 2), ('y', 1)]) def testRankTooLowOuts(self): x = jnp.arange(2) spec = P('x', 'y') @@ -493,7 +469,7 @@ class PJitErrorTest(jtu.JaxTestCase): with self.assertRaisesRegex(ValueError, error): 
pjit(lambda x: x.sum(), in_axis_resources=None, out_axis_resources=spec)(x) - @with_mesh([('x', 2), ('y', 1)]) + @jtu.with_mesh([('x', 2), ('y', 1)]) def testRankTooLowConstraint(self): x = jnp.arange(2) spec = P('x', 'y') @@ -504,7 +480,7 @@ class PJitErrorTest(jtu.JaxTestCase): pjit(lambda x: with_sharding_constraint(x, spec), in_axis_resources=None, out_axis_resources=None)(x) - @with_mesh([('x', 2), ('y', 1)]) + @jtu.with_mesh([('x', 2), ('y', 1)]) def testRepeatedInResources(self): x = jnp.arange(2) for spec in [P('x', 'x'), P('x', ('y', 'x'))]: @@ -514,7 +490,7 @@ class PJitErrorTest(jtu.JaxTestCase): with self.assertRaisesRegex(ValueError, error): pjit(lambda x: x, in_axis_resources=spec, out_axis_resources=None)(x) - @with_mesh([('x', 2), ('y', 1)]) + @jtu.with_mesh([('x', 2), ('y', 1)]) def testRepeatedOutResources(self): x = jnp.arange(2) for spec in [P('x', 'x'), P('x', ('y', 'x'))]: @@ -524,7 +500,7 @@ class PJitErrorTest(jtu.JaxTestCase): with self.assertRaisesRegex(ValueError, error): pjit(lambda x: x, in_axis_resources=None, out_axis_resources=spec)(x) - @with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testInputShardsXMapAxis(self): spec = P('x') f = xmap(pjit(lambda x: x + 2, in_axis_resources=spec, out_axis_resources=None), @@ -537,7 +513,7 @@ class PJitErrorTest(jtu.JaxTestCase): with self.assertRaisesRegex(JAXTypeError, error): f(x) - @with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testOutputShardsXMapAxis(self): spec = P('x') f = xmap(pjit(lambda x: x + 2, in_axis_resources=None, out_axis_resources=spec), @@ -550,7 +526,7 @@ class PJitErrorTest(jtu.JaxTestCase): with self.assertRaisesRegex(JAXTypeError, error): f(x) - @with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testConstraintShardsXMapAxis(self): spec = P('x') f = xmap(lambda x: with_sharding_constraint(x, axis_resources=spec), @@ -563,7 +539,7 @@ class PJitErrorTest(jtu.JaxTestCase): with self.assertRaisesRegex(JAXTypeError, error): f(x) - @with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testCatchesInnerXMapErrors(self): f = pjit(xmap(lambda x, y: x, in_axes=(['i'], ['j']), out_axes=['i', 'j'], axis_resources={'i': 'x', 'j': 'x'}), diff --git a/tests/xmap_test.py b/tests/xmap_test.py index bc8224ca8..19cd309a4 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -71,34 +71,6 @@ def tearDownModule(): os.environ["XLA_FLAGS"] = prev_xla_flags xla_bridge.get_backend.cache_clear() -# -------------------- Mesh parametrization helpers -------------------- - -MeshSpec = List[Tuple[str, int]] - -@contextmanager -def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]: - """Test utility for setting up meshes given mesh data from `schedules`.""" - # This is similar to the `with_mesh` function above, but isn't a decorator. 
- axis_names, shape = unzip2(named_shape) - size = prod(shape) - local_devices = list(jax.local_devices()) - if len(local_devices) < size: - raise SkipTest(f"Test requires {size} local devices") - mesh_devices = np.array(local_devices[:size]).reshape(shape) - with mesh(mesh_devices, axis_names): - yield - -def with_mesh_from_kwargs(f): - return lambda *args, **kwargs: with_mesh(kwargs['mesh'])(f)(*args, **kwargs) - -def with_and_without_mesh(f): - return parameterized.named_parameters( - {"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources} - for name, mesh, axis_resources in ( - ('', (), ()), - ('Mesh', (('x', 2),), (('i', 'x'),)) - ))(with_mesh_from_kwargs(f)) - # -------------------- Itertools helpers -------------------- @@ -135,7 +107,7 @@ def ensure_bdim(x, axis_name, bdim): AxisResources = Dict[str, Union[str, Tuple[str, ...]]] def schedules(sizes: Dict[str, int] - ) -> Generator[Tuple[AxisResources, MeshSpec], None, None]: + ) -> Generator[Tuple[AxisResources, jtu.MeshSpec], None, None]: """Test utility generating xmap parallel schedules from logical names & sizes. Args: @@ -275,7 +247,7 @@ class XMapTest(XMapTestCase): self.assertAllClose(c, a * 2) self.assertAllClose(d, b * 4) - @with_mesh([('x', 2), ('y', 2)]) + @jtu.with_mesh([('x', 2), ('y', 2)]) def testCollectiveReduce(self): fm = xmap(lambda a, b: (lax.psum(a * 2, 'a'), b * 4), in_axes=[['a', 'b', ...], {0: 'c'}], @@ -289,7 +261,7 @@ class XMapTest(XMapTestCase): self.assertAllClose(c, (a * 2).sum(0)) self.assertAllClose(d, b * 4) - @with_mesh([('x', 2), ('y', 2)]) + @jtu.with_mesh([('x', 2), ('y', 2)]) def testCollectivePermute2D(self): perm = np.array([3, 1, 2, 0]) x = jnp.arange(4).reshape((2, 2)) @@ -313,7 +285,7 @@ class XMapTest(XMapTestCase): in_axes=['i', ...], out_axes=['i', ...])(x) self.assertAllClose(result, x + x[jnp.newaxis].T) - @with_mesh([('x', 2), ('y', 2)]) + @jtu.with_mesh([('x', 2), ('y', 2)]) def testOneLogicalTwoMeshAxesBasic(self): def f(v): return lax.psum(v * 2, 'a'), v * 4 @@ -325,7 +297,7 @@ class XMapTest(XMapTestCase): self.assertAllClose(ans, (v * 2).sum(0)) self.assertAllClose(ans2, v.T * 4) - @with_mesh([('x', 2), ('y', 2)]) + @jtu.with_mesh([('x', 2), ('y', 2)]) def testOneLogicalTwoMeshAxesSharding(self): def f(v): return v * 4 @@ -346,7 +318,7 @@ class XMapTest(XMapTestCase): pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))), (pxla.ShardedAxis(1), pxla.ShardedAxis(0)))) - @with_mesh([('x', 2), ('y', 2)]) + @jtu.with_mesh([('x', 2), ('y', 2)]) def testSkipFirstMeshDim(self): def run(axis_resources): return xmap(lambda x: x * 2, in_axes=['i', ...], out_axes=['i', ...], @@ -379,7 +351,7 @@ class XMapTest(XMapTestCase): ('OneToOne', (('x', 2), ('y', 2)), (('a', 'y'), ('b', 'x'))), ('Multiple', (('x', 2), ('y', 2), ('z', 2)), (('a', 'y'), ('b', ('x', 'z')))), )) - @with_mesh_from_kwargs + @jtu.with_mesh_from_kwargs def testNestedMesh(self, mesh, axis_resources): @partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}), axis_resources=dict([axis_resources[0]])) @@ -405,7 +377,7 @@ class XMapTest(XMapTestCase): # Make sure that there are non-partial sharding specs in the HLO self.assertRegex(hlo, r"sharding={devices=\[[0-9,]+\][0-9,]+}") - @with_and_without_mesh + @jtu.with_and_without_mesh def testMultipleCalls(self, mesh, axis_resources): def f(x, y): assert x.shape == y.shape == (3, 5) @@ -420,7 +392,7 @@ class XMapTest(XMapTestCase): for i in range(10): self.assertAllClose(f_mapped(x, x), expected) - @with_and_without_mesh + @jtu.with_and_without_mesh 
@jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU. def testBufferDonation(self, mesh, axis_resources): shard = lambda x: x @@ -443,7 +415,7 @@ class XMapTest(XMapTestCase): xmap(lambda x: lax.fori_loop(0, 10, lambda _, x: lax.psum(x, 'i'), x), in_axes=['i', ...], out_axes=['i', ...])(x) - @with_and_without_mesh + @jtu.with_and_without_mesh def testAxisSizes(self, mesh, axis_resources): result = xmap(lambda: lax.axis_index('i'), in_axes=(), out_axes=['i', ...], @@ -551,7 +523,7 @@ class XMapTest(XMapTestCase): y = jnp.arange(20, dtype=jnp.float32).reshape((4, 5)) / 100 jtu.check_grads(f, (x, y), order=2, modes=['fwd']) - @with_and_without_mesh + @jtu.with_and_without_mesh def testNamedShape(self, mesh, axis_resources): x = np.arange(4,) y = 2 @@ -563,7 +535,7 @@ class XMapTest(XMapTestCase): self.assertEqual(z.aval.named_shape, {}) self.assertEqual(w.aval.named_shape, {}) - @with_and_without_mesh + @jtu.with_and_without_mesh def testBroadcast(self, mesh, axis_resources): x = jnp.asarray(2.0) f = xmap(lambda x: x, in_axes={}, out_axes=['i'], @@ -590,7 +562,7 @@ class XMapTestSPMD(SPMDTestMixin, XMapTest): raise SkipTest super().setUp() - @with_mesh([('x', 2), ('y', 2), ('z', 2)]) + @jtu.with_mesh([('x', 2), ('y', 2), ('z', 2)]) def testNestedMeshSPMD(self): h = xmap(lambda y: (jnp.sin(y) * np.arange(y.size), lax.psum(y, ('a', 'b', 'c'))), in_axes={0: 'c'}, out_axes=({1: 'c'}, {}), @@ -670,7 +642,7 @@ class NamedRandomTest(XMapTestCase): ((f"_mesh={mesh}_resources={sorted(axis_resources.items())}", {"axis_resources": tuple(axis_resources.items()), "mesh": tuple(mesh)}) for axis_resources, mesh in schedules({'i': 4, 'j': 6})), subset=True) - @with_mesh_from_kwargs + @jtu.with_mesh_from_kwargs def testSamplerResourceIndependence(self, distr_sample, axis_resources, mesh): def sample(axis_resources): return xmap(lambda: distr_sample(jax.random.PRNGKey(0), shape=NamedShape(3, i=4, j=6)), @@ -743,7 +715,7 @@ class NewPrimitiveTest(XMapTestCase): x_explode = x.reshape((3, 3, 3)) self.assertAllClose(pgather(x, idx, 0), pgather(x_explode, idx, (0, 1))) - @with_and_without_mesh + @jtu.with_and_without_mesh def testGather(self, mesh, axis_resources): if axis_resources and not jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING: raise SkipTest("pgather over mesh axes without SPMD lowering not implemented") @@ -851,7 +823,7 @@ def gen_axis_names(): def schedules_from_pdot_spec( spec: PdotTestSpec, lhs_shape: Tuple[int], rhs_shape: Tuple[int] - ) -> Generator[Tuple[AxisResources, MeshSpec], None, None]: + ) -> Generator[Tuple[AxisResources, jtu.MeshSpec], None, None]: logical_sizes = { name: shape[ax] for shape, in_axes in [(lhs_shape, spec.lhs_in_axes), @@ -862,7 +834,7 @@ def schedules_from_pdot_spec( class PDotTests(XMapTestCase): - @with_mesh([('r1', 2)]) + @jtu.with_mesh([('r1', 2)]) def testPdotBasic(self): def f(x, y): return lax.pdot(x, y, 'i') @@ -880,7 +852,7 @@ class PDotTests(XMapTestCase): self.assertAllClose(z, jnp.dot(x, y)) - @with_mesh([('r1', 2)]) + @jtu.with_mesh([('r1', 2)]) def testPdotBatching(self): def f(x, y): return lax.pdot(x, y, 'i') @@ -898,7 +870,7 @@ class PDotTests(XMapTestCase): self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y)) - @with_mesh([('r1', 2)]) + @jtu.with_mesh([('r1', 2)]) def testPdotBatchingShardUncontractedDim(self): def f(x, y): return lax.pdot(x, y, 'i') @@ -945,7 +917,7 @@ class PDotTests(XMapTestCase): out_axes=[*pdot_spec.batch_names, ...], axis_resources=axis_resources) - with with_mesh(mesh_data): + with jtu.with_mesh(mesh_data): 
result = fun(lhs, rhs) expected = lax.dot_general(lhs, rhs, pdot_spec.dot_general_dim_nums) @@ -989,7 +961,7 @@ class PDotTests(XMapTestCase): out_axes=(pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes), axis_resources=axis_resources) - with with_mesh(mesh_data): + with jtu.with_mesh(mesh_data): lhs_bar, rhs_bar = fun(lhs, rhs, out_bar) tol = 1e-1 if jtu.device_under_test() == "tpu" else None @@ -1062,7 +1034,7 @@ class PDotTests(XMapTestCase): class XMapErrorTest(jtu.JaxTestCase): - @with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testRepeatedAxisResource(self): def f(v): return v * 4 @@ -1070,7 +1042,7 @@ class XMapErrorTest(jtu.JaxTestCase): fxy = xmap(f, in_axes=['a', ...], out_axes=['a', ...], axis_resources={'a': ('x', 'x')}) - @with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testNestedDifferentResources(self): @partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'}) def f(x): @@ -1089,7 +1061,7 @@ class XMapErrorTest(jtu.JaxTestCase): with self.assertRaisesRegex(ValueError, "Failed to infer size of axes: i."): xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...])({}) - @with_mesh([('x', 2), ('y', 2)]) + @jtu.with_mesh([('x', 2), ('y', 2)]) def testAxesNotDivisibleByResources(self): with self.assertRaisesRegex(ValueError, r"Size of axis i \(5\) is not divisible.*" r"\(\('x', 'y'\), 4 in total\)"): @@ -1154,7 +1126,7 @@ class XMapErrorTest(jtu.JaxTestCase): with self.assertRaisesRegex(TypeError, error): fm(x, y) - @with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testResourceConflictArgs(self): fm = xmap(lambda x: lax.psum(x, ('a', 'b')), in_axes=['a', 'b'], out_axes=[], @@ -1166,7 +1138,7 @@ class XMapErrorTest(jtu.JaxTestCase): with self.assertRaisesRegex(JAXTypeError, error): fm(x) - @with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testResourceConflictInner(self): fm = xmap(lambda x, y: x + y, in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...], @@ -1178,7 +1150,7 @@ class XMapErrorTest(jtu.JaxTestCase): with self.assertRaisesRegex(JAXTypeError, error): fm(x, y) - @with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testResourceConflictOut(self): fm = xmap(lambda x, y: x, in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...], @@ -1191,7 +1163,7 @@ class XMapErrorTest(jtu.JaxTestCase): with self.assertRaisesRegex(JAXTypeError, error): fm(x, y) - @with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testResourceConflictNestArgs(self): f = xmap(lambda x: x, in_axes=['i'], out_axes=['i'], axis_resources={'i': 'x'}) h = xmap(f, in_axes=['j', ...], out_axes=['j', ...], axis_resources={'j': 'x'}) @@ -1202,7 +1174,7 @@ class XMapErrorTest(jtu.JaxTestCase): with self.assertRaisesRegex(JAXTypeError, error): h(x) - @with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testResourceConflictNestInner(self): f = xmap(lambda x: lax.axis_index('i') + x, in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'}) @@ -1214,7 +1186,7 @@ class XMapErrorTest(jtu.JaxTestCase): with self.assertRaisesRegex(JAXTypeError, error): h(x) - @with_mesh([('x', 2)]) + @jtu.with_mesh([('x', 2)]) def testResourceConflictNestOut(self): f = xmap(lambda x: x, in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
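# ---------------------------------------------------------------------------
# Illustrative sketch (editor-added, not part of the patch): how the pjit
# conversion path added to jax2tf.py above is exercised end to end, following
# the patterns in sharding_test.py. `jax_func` and the shapes are hypothetical
# examples; this assumes at least two local devices and the JAX/TF APIs as of
# this patch.
# ---------------------------------------------------------------------------
import functools

import numpy as np
import tensorflow as tf  # type: ignore[import]

import jax
from jax.experimental import jax2tf
from jax.experimental import pjit
from jax.experimental.maps import mesh
from jax.interpreters.sharded_jit import PartitionSpec as P


@functools.partial(pjit.pjit,
                   in_axis_resources=(P("x"), P("x")),
                   out_axis_resources=None)
def jax_func(x, y):
  return x + y


x = np.arange(80, dtype=np.float32).reshape((8, 10))
devices = np.array(jax.local_devices()[:2])  # requires >= 2 local devices
with mesh(devices, ("x",)):
  # pjit.pjit_p is now registered in tf_impl_with_avals (see _pjit above), so
  # the converted function carries the pjit sharding annotations into the
  # TF HLO instead of failing as an "unexpected primitive".
  f_tf = tf.function(jax2tf.convert(jax_func, with_gradient=False),
                     autograph=False, jit_compile=True)
  hlo = f_tf.experimental_get_compiler_ir(x, x)(stage="hlo")
  # Expect annotations like: f32[8,10] ... sharding={devices=[2,1]0,1}
  print(hlo)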
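# ---------------------------------------------------------------------------
# Illustrative sketch (editor-added, not part of the patch): the mesh and
# SPMD-lowering helpers that this patch consolidates into jax.test_util are
# meant to be shared across test modules. A small hypothetical test module
# using them, in the same style as pjit_test.py, could look like this.
# ---------------------------------------------------------------------------
from absl.testing import absltest

import jax.numpy as jnp
from jax import test_util as jtu
from jax.config import config
from jax.experimental import pjit
from jax.interpreters.sharded_jit import PartitionSpec as P

config.parse_flags_with_absl()


def setUpModule():
  jtu.set_spmd_lowering_flag(True)

def tearDownModule():
  jtu.restore_spmd_lowering_flag()


class ExampleMeshTest(jtu.JaxTestCase):

  # jtu.with_mesh is a contextmanager, so it also works as a decorator; it
  # skips the test when fewer than two local devices are available.
  @jtu.with_mesh([("x", 2)])
  def test_pjit_double(self):
    f = pjit.pjit(lambda x: x * 2,
                  in_axis_resources=P("x"), out_axis_resources=None)
    x = jnp.arange(8, dtype=jnp.float32)
    self.assertAllClose(f(x), x * 2)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())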