Merge pull request #6866 from gnecula:tf_pjit

PiperOrigin-RevId: 376989780
jax authors 2021-06-01 22:50:12 -07:00
commit 8e6101c6a1
8 changed files with 314 additions and 139 deletions

File: CHANGELOG.md

@@ -11,6 +11,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
 ## jax 0.2.14 (unreleased)
 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...master).
 * New features:
+  * {func}`jax2tf.convert` now supports `pjit` and `sharded_jit`.
 * Breaking changes:

File: jax/experimental/jax2tf/README.md

@@ -475,8 +475,8 @@ in [savedmodel_test.py](https://github.com/google/jax/blob/master/jax/experiment
 ### Missing converter features
-There is currently no support for replicated (e.g. `pmap`) or multi-device
-(e.g. `sharded_jit`) functions. The collective operations are not yet handled.
+There is currently no support for `pmap` or `xmap`, nor for the collective
+operations. There is support for `sharded_jit` and `pjit`.
 ### No SavedModel fine-tuning
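
The README paragraph above announces the new `pjit` support. Below is a minimal, hedged sketch of the end-to-end path, not code from this commit: the function name `add` and the shapes are ours, it assumes at least two local devices, and it uses only APIs that appear elsewhere in this diff (`pjit.pjit`, `PartitionSpec`, `maps.mesh`, `jax2tf.convert`).

    # Sketch: converting a pjit-ed function to TF (assumes >= 2 local devices).
    import numpy as np
    import jax
    from jax.experimental import PartitionSpec as P
    from jax.experimental import jax2tf, maps, pjit

    def add(x, y):
      return x + y

    # Shard both inputs along mesh axis "x"; replicate the output.
    add_pjit = pjit.pjit(add, in_axis_resources=(P("x"), P("x")),
                         out_axis_resources=None)

    x = np.arange(80, dtype=np.float32).reshape((8, 10))
    devices = np.array(jax.local_devices()[:2])
    with maps.mesh(devices, ("x",)):
      add_tf = jax2tf.convert(add_pjit)  # emits XlaSharding ops in the TF graph
      print(add_tf(x, x))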

File: jax/experimental/jax2tf/jax2tf.py

@@ -31,6 +31,8 @@ from jax._src.lax import lax
 from jax._src.lax import linalg as lax_linalg
 import jax._src.random
 from jax.api_util import flatten_fun
+from jax.experimental import maps
+from jax.experimental import pjit
 from jax.interpreters import ad
 from jax.interpreters import pxla
 from jax.interpreters import sharded_jit
@@ -397,9 +399,6 @@ def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args: TfVal) -> Sequence[TfVal]:
   return tuple(v for v, _ in out_with_avals)

-### tracer

 def _aval_to_tf_shape(aval: core.AbstractValue) -> Tuple[Optional[int], ...]:
   """Generate a TF shape, possibly containing None for polymorphic dimensions."""
   return tuple(
@@ -729,8 +728,9 @@ class TensorFlowTrace(core.Trace):
   def process_call(self, call_primitive: core.Primitive, f: lu.WrappedFun,
                    tracers: Sequence[TensorFlowTracer], params):
     assert call_primitive.multiple_results
-    vals: Sequence[TfVal] = [t.val for t in tracers]
-    f = _interpret_subtrace(f, self.main, tuple(t.aval for t in tracers))
+    vals: Sequence[TfVal] = tuple(t.val for t in tracers)
+    avals: Sequence[core.AbstractValue] = tuple(t.aval for t in tracers)
+    f = _interpret_subtrace(f, self.main, avals)
     with core.new_sublevel():
       if call_primitive == core.named_call_p:
         with tf.name_scope(_sanitize_scope_name(params["name"])):
@@ -813,6 +813,8 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):

 for unexpected in xla.call_translations:  # Call primitives are inlined
+  if unexpected is pjit.pjit_p:
+    continue
   tf_impl[unexpected] = functools.partial(_unexpected_primitive, unexpected)

 # Primitives that are not yet implemented must be explicitly declared here.
@@ -2494,8 +2496,8 @@ def split_to_logical_devices(tensor: TfVal,
   Returns:
     an annotated tensor.
   """
-  # This corresponds to the sharding annotations in
-  # xla_bridge._sharding_to_proto.
+  # TODO: this is only for sharded_jit. Either remove, or implement in terms
+  # of _shard_value.
   if partition_dimensions is None:
     return xla_sharding.replicate(tensor, use_sharding_op=True)
   num_partition_splits = np.prod(partition_dimensions)
@@ -2504,6 +2506,24 @@ def split_to_logical_devices(tensor: TfVal,
   return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)

+def _shard_value(mesh: maps.Mesh,
+                 val: TfVal,
+                 aval: core.AbstractValue,
+                 axis_resources: pjit.ParsedPartitionSpec) -> TfVal:
+  """Apply sharding to a TfVal."""
+  sharding_proto: xla_client.OpSharding = pjit.get_aval_sharding_proto(
+      aval, axis_resources, mesh)
+  # To use xla_sharding.py, we must have a xla_data_pb2.OpSharding.
+  xla_sharding_proto: xla_data_pb2.OpSharding = (
+      xla_data_pb2.OpSharding(
+          type=int(sharding_proto.type),
+          tile_assignment_dimensions=sharding_proto.tile_assignment_dimensions,
+          tile_assignment_devices=sharding_proto.tile_assignment_devices,
+          replicate_on_last_tile_dim=sharding_proto.replicate_on_last_tile_dim))
+  return xla_sharding.Sharding(proto=xla_sharding_proto).apply_to_tensor(
+      val, use_sharding_op=True)
+
 def _sharded_call(f: lu.WrappedFun, vals: Sequence[TfVal],
                   in_parts: Sequence[pxla.PartitionsOrReplicated],
                   out_parts_thunk,
@@ -2520,12 +2540,51 @@ def _sharded_call(f: lu.WrappedFun, vals: Sequence[TfVal],
   return sharded_vals_out

-def _sharding_constraint(arg: TfVal, *,
-                         partitions: pxla.PartitionsOrReplicated):
+def _sharded_jit_sharding_constraint(arg: TfVal, *,
+                                     partitions: pxla.PartitionsOrReplicated,
+                                     _in_avals: Sequence[core.ShapedArray],
+                                     _out_aval: core.ShapedArray):
+  del _in_avals, _out_aval
   return split_to_logical_devices(arg, partitions)

-tf_impl[sharded_jit.sharding_constraint_p] = _sharding_constraint
+tf_impl_with_avals[sharded_jit.sharding_constraint_p] = _sharded_jit_sharding_constraint
+
+def _pjit(*args: TfVal,
+          jaxpr: core.ClosedJaxpr,
+          in_axis_resources: Sequence[pjit.ParsedPartitionSpec],
+          out_axis_resources: Sequence[pjit.ParsedPartitionSpec],
+          resource_env: maps.ResourceEnv,
+          donated_invars,
+          name: str,
+          _in_avals: Sequence[core.ShapedArray],
+          _out_aval: core.ShapedArray) -> TfVal:
+  del donated_invars, name
+  # TODO: add `name` to the name stack
+  shard_value_for_mesh = functools.partial(_shard_value, resource_env.physical_mesh)
+  # Apply sharding annotations to the arguments.
+  sharded_args: Sequence[TfVal] = tuple(
+      util.safe_map(shard_value_for_mesh, args, _in_avals, in_axis_resources))
+  results = _interpret_jaxpr(jaxpr, *sharded_args)
+  sharded_results: Sequence[TfVal] = tuple(
+      util.safe_map(shard_value_for_mesh, results, _out_aval, out_axis_resources))
+  return tuple(sharded_results)
+
+tf_impl_with_avals[pjit.pjit_p] = _pjit
+
+def _pjit_sharding_constraint(arg: TfVal, *,
+                              axis_resources: pjit.ParsedPartitionSpec,
+                              resource_env: maps.ResourceEnv,
+                              _in_avals: Sequence[core.ShapedArray],
+                              _out_aval: core.ShapedArray,
+                              **kwargs) -> TfVal:
+  return _shard_value(resource_env.physical_mesh, arg, _in_avals[0], axis_resources)
+
+tf_impl_with_avals[pjit.sharding_constraint_p] = _pjit_sharding_constraint

 def _register_checkpoint_pytrees():
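
The `_shard_value` handler above bridges a JAX `xla_client.OpSharding` into TF's `xla_data_pb2.OpSharding` field by field. As a standalone, hedged illustration (not code from this commit), the sketch below hand-builds such a proto and applies it with the same `apply_to_tensor` call; the 2-device tiling values and the use of the `OTHER` enum value for tiled shardings are our assumptions, and the imports are the same TF-internal modules jax2tf itself uses.

    # Standalone illustration: annotate a TF tensor with an XLA sharding proto.
    import tensorflow as tf  # type: ignore[import]
    from tensorflow.compiler.xla import xla_data_pb2  # type: ignore[import]
    from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding  # type: ignore[import]

    # Tile a rank-2 tensor along dimension 0 across 2 devices.
    proto = xla_data_pb2.OpSharding(
        type=xla_data_pb2.OpSharding.OTHER,   # OTHER is the general tiled type
        tile_assignment_dimensions=[2, 1],    # 2 tiles on dim 0, 1 on dim 1
        tile_assignment_devices=[0, 1])       # device ids, one per tile

    x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    annotated = xla_sharding.Sharding(proto=proto).apply_to_tensor(
        x, use_sharding_op=True)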

File: jax/experimental/jax2tf/tests/sharded_jit_test.py

@@ -13,9 +13,10 @@
 # limitations under the License.
 """Tests for the jax2tf conversion of sharded_jit."""

+import functools
 import logging
 import re
-from typing import Sequence
+from typing import Any, Sequence
 import unittest

 from absl.testing import absltest
@@ -25,6 +26,7 @@ from jax import test_util as jtu
 from jax.config import config
 from jax.experimental import jax2tf
+from jax.experimental import pjit
 from jax.experimental.jax2tf.tests import tf_test_util
 from jax.interpreters import sharded_jit
 from jax.interpreters.sharded_jit import PartitionSpec as P
@@ -36,6 +38,13 @@ import tensorflow as tf  # type: ignore[import]

 config.parse_flags_with_absl()

+def setUpModule():
+  jtu.set_spmd_lowering_flag(True)
+
+def tearDownModule():
+  jtu.restore_spmd_lowering_flag()
+
 LOG_HLO = True

 class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
@@ -46,7 +55,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
   def _check_sharding_annotations(self,
                                   f_jax,
-                                  args,
+                                  args: Sequence[Any],
                                   *,
                                   expected: Sequence[str],
                                   expected_opt: Sequence[str],
@@ -57,6 +66,9 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
     We currently check the unoptimized HLO against `expected` on CPU and TPU,
     and we check the optimized HLO against `expected_opt` on TPU only and
     only for JAX.
+
+    See `self.AssertShardingAnnotations` for documentation of `expected`
+    and `expected_opt`.
     """
     if jtu.device_under_test() == "gpu":
       raise unittest.SkipTest("Sharding HLO tests not useful for GPU")
@@ -93,20 +105,20 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
         experimental_get_compiler_ir(*args)(stage="hlo",
                                             device_name=device_name)
     if LOG_HLO:
-      logging.info(f"[{self._testMethodName}] got TF OPT HLO {tf_hlo}")
+      logging.info(f"[{self._testMethodName}] got TF HLO {tf_hlo}")
     self.AssertShardingAnnotations("TF before optimizations", tf_hlo, expected)
     tf_optimized_hlo = tf.function(f_tf, jit_compile=True).\
         experimental_get_compiler_ir(*args)(stage="optimized_hlo",
                                             device_name=device_name)
     if LOG_HLO:
-      logging.info(f"[{self._testMethodName}] XXX got TF OPT HLO "
+      logging.info(f"[{self._testMethodName}] got TF optimized HLO "
                    f"for {device_name}: {tf_optimized_hlo}")
   def AssertShardingAnnotations(self, what: str, hlo: str,
                                 expected: Sequence[str]):
     """Args:
-      what: either 'JAX' or 'TF'
+      what: either 'JAX' or 'TF', used for messages only.
       hlo: the text for the HLO module
       expected: a sequence of regexps that must occur in the hlo text. Each
         regexp must match a line, in order.
@@ -128,7 +140,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
             f"!!! Not found[{next_expected_idx}] {expected[next_expected_idx]}")
       raise self.failureException("\n".join(failure_msg))
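
The hunks above show how these tests obtain TF's HLO text and grep it for sharding annotations. A standalone, hedged sketch of the same inspection (not from this commit; assumes a TF version recent enough to have `experimental_get_compiler_ir`, which the tests already rely on; the helper name `get_tf_hlo` is ours):

    # Sketch: fetch the HLO text for a jit-compiled TF function.
    import tensorflow as tf

    def get_tf_hlo(f_tf, *args, stage="hlo"):
      # stage="hlo" is pre-optimization HLO; "optimized_hlo" is after XLA passes.
      return tf.function(f_tf, jit_compile=True).experimental_get_compiler_ir(
          *args)(stage=stage)

    print(get_tf_hlo(lambda x: x * 2.0, tf.constant([1.0, 2.0])))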
-  def test_in_out(self):
+  def test_sharded_jit_in_out(self):
     """Test input and output sharding annotations."""
     sharded_jax_func = sharded_jit.sharded_jit(
         jnp.dot, in_parts=(P(1, 2), P(2, 1)), out_parts=P(1, 2))
@@ -152,7 +164,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
         ],
         num_partitions=2)

-  def test_with_sharding_constraint(self):
+  def test_sharded_jit_with_sharding_constraint(self):
     """A sharding constraint in the middle."""
     def jax_func(x, y):
@@ -180,7 +192,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
         ],
         num_partitions=2)

-  def test_replicated(self):
+  def test_sharded_jit_replicated(self):
     """A replicated input and output."""
     sharded_jax_func = sharded_jit.sharded_jit(
@@ -203,6 +215,116 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
         ],
         num_partitions=2)
+  @jtu.with_mesh([("x", 2)])
+  def test_pjit_basic1D(self):
+
+    @functools.partial(pjit.pjit,
+                       in_axis_resources=(P("x"), P("x")),
+                       out_axis_resources=None)
+    def jax_func(x, y):
+      return x + y
+
+    shape = (8, 10)
+    x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+    hlo = jax.xla_computation(jax_func)(x, x).as_hlo_text()
+    print(f"HLO is {hlo}")
+    print(f"JAXPR is {jax.make_jaxpr(jax_func)(x, x)}")
+    self._check_sharding_annotations(
+        jax_func, [x, x],
+        expected=[
+            r"f32\[8,10\].*sharding={devices=\[2,1\]",  # x and y
+            r"f32\[8,10\].*sharding={replicated",  # output
+        ],
+        expected_opt=[
+            r"f32\[4,10\].*sharding={devices=\[2,1\]",  # x and y
+            # TODO: why don't we see "sharding={replicated"?
+            r"f32\[8,10\]",  # output
+        ],
+        num_partitions=2)
+
+  @jtu.with_mesh([("x", 2), ("y", 2)])
+  def test_pjit_basic2D(self):
+
+    @functools.partial(pjit.pjit,
+                       in_axis_resources=(P(None, "x", "y"), P("y")),
+                       out_axis_resources=P("x"))
+    def jax_func(x, y):
+      return x @ y
+
+    x_shape = (8, 6, 4)
+    y_shape = (4, 2)
+    x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
+    y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape)
+    self._check_sharding_annotations(
+        jax_func, [x, y],
+        expected=[
+            r"f32\[8,6,4\].*sharding={devices=\[1,2,2\]0,1,2,3",  # x
+            r"f32\[4,2\].*sharding={devices=\[2,1,2\]0,2,1,3 last_tile_dim_replicate",  # y
+            r"f32\[8,6,2\].*sharding={devices=\[2,1,1,2\]0,1,2,3 last_tile_dim_replicate",  # output
+        ],
+        expected_opt=[
+            # TODO: relax ordering
+            r"f32\[2,2\].*sharding={devices=\[2,1,2\]0,2,1,3 last_tile_dim_replicate",  # y
+            r"f32\[8,3,2\].*sharding={devices=\[1,2,2\]0,1,2,3",  # x
+            # TODO: why can't we see sharding={devices=\[2,1,1,2\]0,1,2,3 last_tile_dim_replicate?
+            r"bf16\[4,6,2\]",  # output
+        ],
+        num_partitions=4)
+
+  @jtu.with_mesh([("x", 2), ("y", 2)])
+  def test_pjit_TwoMeshAxisSharding(self):
+
+    @functools.partial(pjit.pjit,
+                       in_axis_resources=P(("x", "y"),),
+                       out_axis_resources=P(("x", "y"),))
+    def jax_func(x, y):
+      return x @ y
+
+    x_shape = (24, 8)
+    y_shape = (8, 2)
+    x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
+    y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape)
+    self._check_sharding_annotations(
+        jax_func, [x, y],
+        expected=[
+            r"f32\[24,8\].*sharding={devices=\[4,1\]0,1,2,3",  # x
+            r"f32\[8,2\].*sharding={devices=\[4,1\]0,1,2,3",  # y
+            r"f32\[24,2\].*sharding={devices=\[4,1\]0,1,2,3",  # output
+        ],
+        expected_opt=[
+            # TODO: relax ordering
+            r"f32\[2,2\].*sharding={devices=\[4,1\]0,1,2,3",  # y
+            r"f32\[6,8\].*sharding={devices=\[4,1\]0,1,2,3",  # x
+            # TODO: why can't we see .*sharding={devices=\[4,1\]0,1,2,3?
+            r"f32\[1,6,2\]",  # output
+        ],
+        num_partitions=4)
+
+  @jtu.with_mesh([("x", 2), ("y", 1)])
+  def test_pjit_ShardingConstraint(self):
+
+    @functools.partial(pjit.pjit, in_axis_resources=None,
+                       out_axis_resources=None)
+    def jax_func(x):  # x: f32[12, 8]
+      y = jnp.tile(x, (2, 1))  # y: f32[24, 8]
+      y = pjit.with_sharding_constraint(y, P("x", "y"))
+      return y[0:y.shape[0] // 4]  # res: f32[6, 8]
+
+    shape = (12, 8)
+    x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+    self._check_sharding_annotations(
+        jax_func, [x],
+        expected=[
+            r"f32\[12,8\].*sharding={replicated}",  # x
+            r"f32\[24,8\].*sharding={devices=\[2,1\]0,1",  # y
+            r"f32\[6,8\].*sharding={replicated}",  # output
+        ],
+        expected_opt=[
+            r"f32\[12,8\].*sharding={replicated}",  # x
+            # TODO: why can't we see "sharding={devices=\[2,1\]0,1"?
+            r"f32\[1,12,8\]",  # y
+            # TODO: why can't we see "sharding={replicated}"?
+            r"f32\[6,8\]",  # output
+        ],
+        num_partitions=2)
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())

File: jax/experimental/pjit.py

@@ -590,13 +590,20 @@ def get_array_mapping(axis_resources: ParsedPartitionSpec) -> pxla.ArrayMapping:
              for i, axes in enumerate(axis_resources)
              for axis in axes)

-def get_sharding_proto(c, xla_op, axis_resources, mesh):
+def get_sharding_proto(c, xla_op, axis_resources: ParsedPartitionSpec,
+                       mesh: maps.Mesh) -> xc.OpSharding:
   xla_shape = c.GetShape(xla_op)
   if xla_shape.is_token():
     aval = core.abstract_token
     assert axis_resources is REPLICATED
   else:
     aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.element_type())
+  return get_aval_sharding_proto(aval, axis_resources, mesh)
+
+def get_aval_sharding_proto(aval: core.AbstractValue,
+                            axis_resources: ParsedPartitionSpec,
+                            mesh: maps.Mesh) -> xc.OpSharding:
   array_mapping = get_array_mapping(axis_resources)
   sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)(
       aval, array_mapping)

File: jax/test_util.py

@@ -17,7 +17,7 @@ import functools
 import re
 import os
 import textwrap
-from typing import Dict, Sequence, Union
+from typing import Dict, List, Generator, Sequence, Tuple, Union
 import unittest
 import warnings
 import zlib
@@ -33,10 +33,12 @@ from . import core
 from ._src import dtypes as _dtypes
 from . import lax
 from ._src.config import flags, bool_env, config
-from ._src.util import partial, prod
+from ._src.util import partial, prod, unzip2
 from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
 from .lib import xla_bridge
 from .interpreters import xla
+from .experimental import maps
+from .experimental.maps import mesh

 FLAGS = flags.FLAGS
@@ -1012,6 +1014,44 @@ def ignore_warning(**kw):
     warnings.filterwarnings("ignore", **kw)
     yield

+# -------------------- Mesh parametrization helpers --------------------
+
+MeshSpec = List[Tuple[str, int]]
+
+@contextmanager
+def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
+  """Test utility for setting up meshes given mesh data from `schedules`."""
+  axis_names, shape = unzip2(named_shape)
+  size = prod(shape)
+  local_devices = list(api.local_devices())
+  if len(local_devices) < size:
+    raise unittest.SkipTest(f"Test requires {size} local devices")
+  mesh_devices = np.array(local_devices[:size]).reshape(shape)
+  with mesh(mesh_devices, axis_names):
+    yield
+
+def with_mesh_from_kwargs(f):
+  return lambda *args, **kwargs: with_mesh(kwargs['mesh'])(f)(*args, **kwargs)
+
+def with_and_without_mesh(f):
+  return parameterized.named_parameters(
+    {"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
+    for name, mesh, axis_resources in (
+      ('', (), ()),
+      ('Mesh', (('x', 2),), (('i', 'x'),))
+    ))(with_mesh_from_kwargs(f))
+
+old_spmd_lowering_flag = False
+
+def set_spmd_lowering_flag(val: bool):
+  global old_spmd_lowering_flag
+  maps.make_xmap_callable.cache_clear()
+  old_spmd_lowering_flag = maps.EXPERIMENTAL_SPMD_LOWERING
+  maps.EXPERIMENTAL_SPMD_LOWERING = val
+
+def restore_spmd_lowering_flag():
+  maps.make_xmap_callable.cache_clear()
+  maps.EXPERIMENTAL_SPMD_LOWERING = old_spmd_lowering_flag
 class _cached_property:
   null = object()
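
A hedged usage sketch for the helpers added to `test_util` above (standalone, not from this commit): the test class name and test body are ours, and the sketch assumes two local devices, which `with_mesh` otherwise turns into a skip.

    # Sketch: a test module using the new mesh and SPMD-lowering helpers.
    from absl.testing import absltest
    import numpy as np

    from jax import test_util as jtu
    from jax.experimental import PartitionSpec as P
    from jax.experimental.pjit import pjit

    def setUpModule():
      jtu.set_spmd_lowering_flag(True)    # turn on SPMD lowering for the module

    def tearDownModule():
      jtu.restore_spmd_lowering_flag()    # put the previous flag value back

    class MeshHelpersTest(jtu.JaxTestCase):

      @jtu.with_mesh([("x", 2)])  # skips unless 2 local devices are available
      def test_pjit_add(self):
        f = pjit(lambda x, y: x + y,
                 in_axis_resources=(P("x"), P("x")), out_axis_resources=None)
        x = np.arange(8, dtype=np.float32)
        self.assertAllClose(f(x, x), x + x)

    if __name__ == "__main__":
      absltest.main(testLoader=jtu.JaxTestLoader())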

File: tests/pjit_test.py

@@ -12,10 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from contextlib import contextmanager
 from functools import partial
 import logging
-from typing import Generator, List, Tuple
 from unittest import SkipTest

 from absl.testing import absltest
@@ -32,47 +30,23 @@ from jax.experimental import PartitionSpec as P
 from jax.experimental.maps import xmap, mesh
 from jax.experimental.pjit import pjit, pjit_p, with_sharding_constraint, SpecSync
 from jax.interpreters import pxla
-from jax._src.util import unzip2, prod, curry
+from jax._src.util import prod, curry
 from jax.config import config

 config.parse_flags_with_absl()
 def setUpModule():
-  global old_lowering_flag
-  jax.experimental.maps.make_xmap_callable.cache_clear()
-  old_lowering_flag = jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING
-  jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True
+  jtu.set_spmd_lowering_flag(True)

 def tearDownModule():
-  jax.experimental.maps.make_xmap_callable.cache_clear()
-  jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = old_lowering_flag
+  jtu.restore_spmd_lowering_flag()

-# TODO(skye): move into test_util and dedup with xmap_test.py
-MeshSpec = List[Tuple[str, int]]
-
-@contextmanager
-def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
-  """Test utility for setting up meshes given mesh data from `schedules`."""
-  # This is similar to the `with_mesh` function above, but isn't a decorator.
-  axis_names, shape = unzip2(named_shape)
-  size = prod(shape)
-  local_devices = list(jax.local_devices())
-  if len(local_devices) < size:
-    raise SkipTest(f"Test requires {size} local devices")
-  mesh_devices = np.array(local_devices[:size]).reshape(shape)
-  with mesh(mesh_devices, axis_names):
-    yield
-
-def with_mesh_from_kwargs(f):
-  return lambda *args, **kwargs: with_mesh(kwargs['mesh'])(f)(*args, **kwargs)

 # TODO(skye): make the buffer donation utils part of JaxTestCase
 class PJitTest(jtu.BufferDonationTestCase):

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testBasic1D(self):
     @partial(pjit,
              in_axis_resources=(P('x'), P('x')),
@@ -90,7 +64,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     self.assertAllClose(actual.device_buffers[0].to_py(), expected,
                         check_dtypes=False)

-  @with_mesh([('x', 2), ('y', 2)])
+  @jtu.with_mesh([('x', 2), ('y', 2)])
   def testBasic2D(self):
     @partial(pjit,
              in_axis_resources=(P(None, 'x', 'y'), P('y')),
@@ -118,7 +92,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     self.assertAllClose(actual.device_buffers[3].to_py(), split1,
                         check_dtypes=False)

-  @with_mesh([('x', 2), ('y', 2)])
+  @jtu.with_mesh([('x', 2), ('y', 2)])
   def testTwoMeshAxisSharding(self):
     @partial(pjit,
              in_axis_resources=P(('x', 'y'),),
@@ -144,7 +118,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     self.assertAllClose(actual.device_buffers[3].to_py(), splits[3],
                         check_dtypes=False)

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testBufferDonation(self):
     @partial(pjit,
              in_axis_resources=P('x'),
@@ -162,7 +136,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     self.assertNotDeleted(y)
     self.assertDeleted(x)

-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testShardingConstraint(self):
     @partial(pjit, in_axis_resources=None, out_axis_resources=None)
     def f(x):
@@ -186,7 +160,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     # Annotation from pjit
     self.assertIn("sharding={replicated}", hlo.as_hlo_text())

-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testShardingConstraintPyTree(self):
     @partial(pjit, in_axis_resources=None, out_axis_resources=None)
     def f(x):
@@ -232,7 +206,7 @@ class PJitTest(jtu.BufferDonationTestCase):
       should_be_tracing = False
       pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)

-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testNested(self):
     # Add a constant captured by the nested pjit to make things more complicated
     h = jnp.arange(4)
@@ -243,7 +217,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     self.assertAllClose(y, jnp.sin(x).sum() + h.sum())
     self.assertTrue(hasattr(y, "sharding_spec"))

-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testJVP(self):
     # Add a constant captured by the nested pjit to make things more complicated
     h = jnp.arange(4)
@@ -252,7 +226,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     jtu.check_grads(g, (jnp.arange(16, dtype=jnp.float32).reshape((4, 4)),),
                     order=2, modes=["fwd"], eps=1)

-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testEvalJaxpr(self):
     x, y = jnp.arange(4), jnp.arange(5)
     f = pjit(lambda x, y: x.sum() + jnp.sin(y),
@@ -263,13 +237,13 @@ class PJitTest(jtu.BufferDonationTestCase):
     r, = f_eval(x, y)
     self.assertAllClose(r, x.sum() + jnp.sin(y))

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testNonArrayArg(self):
     self.assertEqual(pjit(lambda x: x + 2,
                           in_axis_resources=None,
                           out_axis_resources=None)(1), 3)

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testNonHashableAxisResources(self):
     x = jnp.arange(4)
     y = pjit(lambda x: {'b': x['a'] + 2},
@@ -277,7 +251,7 @@ class PJitTest(jtu.BufferDonationTestCase):
              out_axis_resources={'b': P('x')})({'a': x})
     self.assertAllClose(y, {'b': x + 2})

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testGradOfConstraint(self):
     # Make sure that we can compute grads through sharding constraints
     h = lambda x: jnp.sin(with_sharding_constraint(x, P('x'))).sum()
@@ -286,7 +260,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     x = jnp.arange(8, dtype=jnp.float32)
     self.assertAllClose(f(x), jnp.cos(x))

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testNoopPartitionSpecs(self):
     noops = [P(), P(None), P(()), P((), None), P(None, None, ())]
     x = jnp.arange(8).reshape((2, 2, 2))
@@ -294,7 +268,7 @@ class PJitTest(jtu.BufferDonationTestCase):
       y = pjit(lambda x: x * 2, in_axis_resources=spec, out_axis_resources=spec)(x)
       self.assertAllClose(y, x * 2)

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testVmapModifiesAxisResources(self):
     h = pjit(lambda x, y: (x + y, x, y), in_axis_resources=P('x'), out_axis_resources=None)
     x = jnp.arange(4)
@@ -310,7 +284,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     self.assertEqual(y_sync, SpecSync.IN_SYNC)
     self.assertEqual(z_sync, SpecSync.DIM_PERMUTE)

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testVMap(self):
     f = pjit(lambda x, y: (x + y, x), in_axis_resources=P('x'), out_axis_resources=P('x'))
     x = jnp.arange(4)
@@ -402,7 +376,7 @@ def check_1d_2d_mesh(f, set_mesh):
       ("2", (("x", 2),), "x"),
       ("2x1", (("x", 2), ("y", 1)), ("x", "y")),
       ("2x2", (("x", 2), ("y", 2)), ("x", "y")),
-  ))(with_mesh_from_kwargs(f) if set_mesh else f)
+  ))(jtu.with_mesh_from_kwargs(f) if set_mesh else f)

 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")
@@ -444,7 +418,7 @@ class PJitErrorTest(jtu.JaxTestCase):
                            in_axis_resources=None, out_axis_resources=None)(x)

   @check_1d_2d_mesh(set_mesh=False)
-  @with_mesh([('z', 1)])
+  @jtu.with_mesh([('z', 1)])
   def testUndefinedResourcesArgs(self, mesh, resources):
     x = jnp.ones((2, 2))
     spec = P(resources,)
@@ -454,7 +428,7 @@ class PJitErrorTest(jtu.JaxTestCase):
       pjit(lambda x: x, in_axis_resources=spec, out_axis_resources=None)(x)

   @check_1d_2d_mesh(set_mesh=False)
-  @with_mesh([('z', 1)])
+  @jtu.with_mesh([('z', 1)])
   def testUndefinedResourcesOuts(self, mesh, resources):
     x = jnp.ones((2, 2))
     spec = P(resources,)
@@ -464,7 +438,7 @@ class PJitErrorTest(jtu.JaxTestCase):
       pjit(lambda x: x, in_axis_resources=None, out_axis_resources=spec)(x)

   @check_1d_2d_mesh(set_mesh=False)
-  @with_mesh([('z', 1)])
+  @jtu.with_mesh([('z', 1)])
   def testUndefinedResourcesConstraint(self, mesh, resources):
     x = jnp.ones((2, 2))
     spec = P(resources,)
@@ -475,7 +449,7 @@ class PJitErrorTest(jtu.JaxTestCase):
       pjit(lambda x: with_sharding_constraint(x, spec),
            in_axis_resources=None, out_axis_resources=None)(x)

-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testRankTooLowArgs(self):
     x = jnp.arange(2)
     spec = P('x', 'y')
@@ -484,7 +458,7 @@ class PJitErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(ValueError, error):
       pjit(lambda x: x.sum(), in_axis_resources=spec, out_axis_resources=None)(x)

-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testRankTooLowOuts(self):
     x = jnp.arange(2)
     spec = P('x', 'y')
@@ -493,7 +467,7 @@ class PJitErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(ValueError, error):
       pjit(lambda x: x.sum(), in_axis_resources=None, out_axis_resources=spec)(x)

-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testRankTooLowConstraint(self):
     x = jnp.arange(2)
     spec = P('x', 'y')
@@ -504,7 +478,7 @@ class PJitErrorTest(jtu.JaxTestCase):
       pjit(lambda x: with_sharding_constraint(x, spec),
            in_axis_resources=None, out_axis_resources=None)(x)

-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testRepeatedInResources(self):
     x = jnp.arange(2)
     for spec in [P('x', 'x'), P('x', ('y', 'x'))]:
@@ -514,7 +488,7 @@ class PJitErrorTest(jtu.JaxTestCase):
       with self.assertRaisesRegex(ValueError, error):
         pjit(lambda x: x, in_axis_resources=spec, out_axis_resources=None)(x)

-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testRepeatedOutResources(self):
     x = jnp.arange(2)
     for spec in [P('x', 'x'), P('x', ('y', 'x'))]:
@@ -524,7 +498,7 @@ class PJitErrorTest(jtu.JaxTestCase):
       with self.assertRaisesRegex(ValueError, error):
         pjit(lambda x: x, in_axis_resources=None, out_axis_resources=spec)(x)

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testInputShardsXMapAxis(self):
     spec = P('x')
     f = xmap(pjit(lambda x: x + 2, in_axis_resources=spec, out_axis_resources=None),
@@ -537,7 +511,7 @@ class PJitErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(JAXTypeError, error):
       f(x)

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testOutputShardsXMapAxis(self):
     spec = P('x')
     f = xmap(pjit(lambda x: x + 2, in_axis_resources=None, out_axis_resources=spec),
@@ -550,7 +524,7 @@ class PJitErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(JAXTypeError, error):
       f(x)

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testConstraintShardsXMapAxis(self):
     spec = P('x')
     f = xmap(lambda x: with_sharding_constraint(x, axis_resources=spec),
@@ -563,7 +537,7 @@ class PJitErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(JAXTypeError, error):
       f(x)

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testCatchesInnerXMapErrors(self):
     f = pjit(xmap(lambda x, y: x, in_axes=(['i'], ['j']), out_axes=['i', 'j'],
                   axis_resources={'i': 'x', 'j': 'x'}),

File: tests/xmap_test.py

@@ -71,34 +71,6 @@ def tearDownModule():
   os.environ["XLA_FLAGS"] = prev_xla_flags
   xla_bridge.get_backend.cache_clear()

-# -------------------- Mesh parametrization helpers --------------------
-
-MeshSpec = List[Tuple[str, int]]
-
-@contextmanager
-def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
-  """Test utility for setting up meshes given mesh data from `schedules`."""
-  # This is similar to the `with_mesh` function above, but isn't a decorator.
-  axis_names, shape = unzip2(named_shape)
-  size = prod(shape)
-  local_devices = list(jax.local_devices())
-  if len(local_devices) < size:
-    raise SkipTest(f"Test requires {size} local devices")
-  mesh_devices = np.array(local_devices[:size]).reshape(shape)
-  with mesh(mesh_devices, axis_names):
-    yield
-
-def with_mesh_from_kwargs(f):
-  return lambda *args, **kwargs: with_mesh(kwargs['mesh'])(f)(*args, **kwargs)
-
-def with_and_without_mesh(f):
-  return parameterized.named_parameters(
-    {"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
-    for name, mesh, axis_resources in (
-      ('', (), ()),
-      ('Mesh', (('x', 2),), (('i', 'x'),))
-    ))(with_mesh_from_kwargs(f))

 # -------------------- Itertools helpers --------------------
@@ -135,7 +107,7 @@ def ensure_bdim(x, axis_name, bdim):
 AxisResources = Dict[str, Union[str, Tuple[str, ...]]]

 def schedules(sizes: Dict[str, int]
-              ) -> Generator[Tuple[AxisResources, MeshSpec], None, None]:
+              ) -> Generator[Tuple[AxisResources, jtu.MeshSpec], None, None]:
   """Test utility generating xmap parallel schedules from logical names & sizes.

   Args:
@@ -275,7 +247,7 @@ class XMapTest(XMapTestCase):
     self.assertAllClose(c, a * 2)
     self.assertAllClose(d, b * 4)

-  @with_mesh([('x', 2), ('y', 2)])
+  @jtu.with_mesh([('x', 2), ('y', 2)])
   def testCollectiveReduce(self):
     fm = xmap(lambda a, b: (lax.psum(a * 2, 'a'), b * 4),
               in_axes=[['a', 'b', ...], {0: 'c'}],
@@ -289,7 +261,7 @@ class XMapTest(XMapTestCase):
     self.assertAllClose(c, (a * 2).sum(0))
     self.assertAllClose(d, b * 4)

-  @with_mesh([('x', 2), ('y', 2)])
+  @jtu.with_mesh([('x', 2), ('y', 2)])
   def testCollectivePermute2D(self):
     perm = np.array([3, 1, 2, 0])
     x = jnp.arange(4).reshape((2, 2))
@@ -313,7 +285,7 @@ class XMapTest(XMapTestCase):
                   in_axes=['i', ...], out_axes=['i', ...])(x)
     self.assertAllClose(result, x + x[jnp.newaxis].T)

-  @with_mesh([('x', 2), ('y', 2)])
+  @jtu.with_mesh([('x', 2), ('y', 2)])
   def testOneLogicalTwoMeshAxesBasic(self):
     def f(v):
       return lax.psum(v * 2, 'a'), v * 4
@@ -325,7 +297,7 @@ class XMapTest(XMapTestCase):
     self.assertAllClose(ans, (v * 2).sum(0))
     self.assertAllClose(ans2, v.T * 4)

-  @with_mesh([('x', 2), ('y', 2)])
+  @jtu.with_mesh([('x', 2), ('y', 2)])
   def testOneLogicalTwoMeshAxesSharding(self):
     def f(v):
       return v * 4
@@ -346,7 +318,7 @@ class XMapTest(XMapTestCase):
                      pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
                                        (pxla.ShardedAxis(1), pxla.ShardedAxis(0))))

-  @with_mesh([('x', 2), ('y', 2)])
+  @jtu.with_mesh([('x', 2), ('y', 2)])
   def testSkipFirstMeshDim(self):
     def run(axis_resources):
       return xmap(lambda x: x * 2, in_axes=['i', ...], out_axes=['i', ...],
@@ -379,7 +351,7 @@ class XMapTest(XMapTestCase):
       ('OneToOne', (('x', 2), ('y', 2)), (('a', 'y'), ('b', 'x'))),
       ('Multiple', (('x', 2), ('y', 2), ('z', 2)), (('a', 'y'), ('b', ('x', 'z')))),
   ))
-  @with_mesh_from_kwargs
+  @jtu.with_mesh_from_kwargs
   def testNestedMesh(self, mesh, axis_resources):
     @partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}),
              axis_resources=dict([axis_resources[0]]))
@@ -405,7 +377,7 @@ class XMapTest(XMapTestCase):
     # Make sure that there are non-partial sharding specs in the HLO
     self.assertRegex(hlo, r"sharding={devices=\[[0-9,]+\][0-9,]+}")

-  @with_and_without_mesh
+  @jtu.with_and_without_mesh
   def testMultipleCalls(self, mesh, axis_resources):
     def f(x, y):
       assert x.shape == y.shape == (3, 5)
@@ -420,7 +392,7 @@ class XMapTest(XMapTestCase):
     for i in range(10):
       self.assertAllClose(f_mapped(x, x), expected)

-  @with_and_without_mesh
+  @jtu.with_and_without_mesh
   @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
   def testBufferDonation(self, mesh, axis_resources):
     shard = lambda x: x
@@ -443,7 +415,7 @@ class XMapTest(XMapTestCase):
     xmap(lambda x: lax.fori_loop(0, 10, lambda _, x: lax.psum(x, 'i'), x),
          in_axes=['i', ...], out_axes=['i', ...])(x)

-  @with_and_without_mesh
+  @jtu.with_and_without_mesh
   def testAxisSizes(self, mesh, axis_resources):
     result = xmap(lambda: lax.axis_index('i'),
                   in_axes=(), out_axes=['i', ...],
@@ -551,7 +523,7 @@ class XMapTest(XMapTestCase):
     y = jnp.arange(20, dtype=jnp.float32).reshape((4, 5)) / 100
     jtu.check_grads(f, (x, y), order=2, modes=['fwd'])

-  @with_and_without_mesh
+  @jtu.with_and_without_mesh
   def testNamedShape(self, mesh, axis_resources):
     x = np.arange(4,)
     y = 2
@@ -563,7 +535,7 @@ class XMapTest(XMapTestCase):
     self.assertEqual(z.aval.named_shape, {})
     self.assertEqual(w.aval.named_shape, {})

-  @with_and_without_mesh
+  @jtu.with_and_without_mesh
   def testBroadcast(self, mesh, axis_resources):
     x = jnp.asarray(2.0)
     f = xmap(lambda x: x, in_axes={}, out_axes=['i'],
@@ -590,7 +562,7 @@ class XMapTestSPMD(SPMDTestMixin, XMapTest):
       raise SkipTest
     super().setUp()

-  @with_mesh([('x', 2), ('y', 2), ('z', 2)])
+  @jtu.with_mesh([('x', 2), ('y', 2), ('z', 2)])
   def testNestedMeshSPMD(self):
     h = xmap(lambda y: (jnp.sin(y) * np.arange(y.size), lax.psum(y, ('a', 'b', 'c'))),
              in_axes={0: 'c'}, out_axes=({1: 'c'}, {}),
@@ -670,7 +642,7 @@ class NamedRandomTest(XMapTestCase):
       ((f"_mesh={mesh}_resources={sorted(axis_resources.items())}",
         {"axis_resources": tuple(axis_resources.items()), "mesh": tuple(mesh)})
       for axis_resources, mesh in schedules({'i': 4, 'j': 6})), subset=True)
-  @with_mesh_from_kwargs
+  @jtu.with_mesh_from_kwargs
   def testSamplerResourceIndependence(self, distr_sample, axis_resources, mesh):
     def sample(axis_resources):
       return xmap(lambda: distr_sample(jax.random.PRNGKey(0), shape=NamedShape(3, i=4, j=6)),
@@ -743,7 +715,7 @@ class NewPrimitiveTest(XMapTestCase):
     x_explode = x.reshape((3, 3, 3))
     self.assertAllClose(pgather(x, idx, 0), pgather(x_explode, idx, (0, 1)))

-  @with_and_without_mesh
+  @jtu.with_and_without_mesh
   def testGather(self, mesh, axis_resources):
     if axis_resources and not jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING:
       raise SkipTest("pgather over mesh axes without SPMD lowering not implemented")
@@ -851,7 +823,7 @@ def gen_axis_names():

 def schedules_from_pdot_spec(
     spec: PdotTestSpec, lhs_shape: Tuple[int], rhs_shape: Tuple[int]
-    ) -> Generator[Tuple[AxisResources, MeshSpec], None, None]:
+    ) -> Generator[Tuple[AxisResources, jtu.MeshSpec], None, None]:
   logical_sizes = {
       name: shape[ax]
       for shape, in_axes in [(lhs_shape, spec.lhs_in_axes),
@@ -862,7 +834,7 @@ def schedules_from_pdot_spec(

 class PDotTests(XMapTestCase):

-  @with_mesh([('r1', 2)])
+  @jtu.with_mesh([('r1', 2)])
   def testPdotBasic(self):
     def f(x, y):
       return lax.pdot(x, y, 'i')
@@ -880,7 +852,7 @@ class PDotTests(XMapTestCase):

     self.assertAllClose(z, jnp.dot(x, y))

-  @with_mesh([('r1', 2)])
+  @jtu.with_mesh([('r1', 2)])
   def testPdotBatching(self):
     def f(x, y):
       return lax.pdot(x, y, 'i')
@@ -898,7 +870,7 @@ class PDotTests(XMapTestCase):

     self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y))

-  @with_mesh([('r1', 2)])
+  @jtu.with_mesh([('r1', 2)])
   def testPdotBatchingShardUncontractedDim(self):
     def f(x, y):
       return lax.pdot(x, y, 'i')
@@ -945,7 +917,7 @@ class PDotTests(XMapTestCase):
                out_axes=[*pdot_spec.batch_names, ...],
                axis_resources=axis_resources)

-    with with_mesh(mesh_data):
+    with jtu.with_mesh(mesh_data):
       result = fun(lhs, rhs)

     expected = lax.dot_general(lhs, rhs, pdot_spec.dot_general_dim_nums)
@@ -989,7 +961,7 @@ class PDotTests(XMapTestCase):
              out_axes=(pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes),
              axis_resources=axis_resources)

-    with with_mesh(mesh_data):
+    with jtu.with_mesh(mesh_data):
       lhs_bar, rhs_bar = fun(lhs, rhs, out_bar)

     tol = 1e-1 if jtu.device_under_test() == "tpu" else None
@@ -1062,7 +1034,7 @@ class PDotTests(XMapTestCase):

 class XMapErrorTest(jtu.JaxTestCase):

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testRepeatedAxisResource(self):
     def f(v):
       return v * 4
@@ -1070,7 +1042,7 @@ class XMapErrorTest(jtu.JaxTestCase):
       fxy = xmap(f, in_axes=['a', ...], out_axes=['a', ...],
                  axis_resources={'a': ('x', 'x')})

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testNestedDifferentResources(self):
     @partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'})
     def f(x):
@@ -1089,7 +1061,7 @@ class XMapErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(ValueError, "Failed to infer size of axes: i."):
       xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...])({})

-  @with_mesh([('x', 2), ('y', 2)])
+  @jtu.with_mesh([('x', 2), ('y', 2)])
   def testAxesNotDivisibleByResources(self):
     with self.assertRaisesRegex(ValueError, r"Size of axis i \(5\) is not divisible.*"
                                             r"\(\('x', 'y'\), 4 in total\)"):
@@ -1154,7 +1126,7 @@ class XMapErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(TypeError, error):
       fm(x, y)

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testResourceConflictArgs(self):
     fm = xmap(lambda x: lax.psum(x, ('a', 'b')),
               in_axes=['a', 'b'], out_axes=[],
@@ -1166,7 +1138,7 @@ class XMapErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(JAXTypeError, error):
       fm(x)

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testResourceConflictInner(self):
     fm = xmap(lambda x, y: x + y,
               in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...],
@@ -1178,7 +1150,7 @@ class XMapErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(JAXTypeError, error):
       fm(x, y)

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testResourceConflictOut(self):
     fm = xmap(lambda x, y: x,
               in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...],
@@ -1191,7 +1163,7 @@ class XMapErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(JAXTypeError, error):
       fm(x, y)

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testResourceConflictNestArgs(self):
     f = xmap(lambda x: x, in_axes=['i'], out_axes=['i'], axis_resources={'i': 'x'})
     h = xmap(f, in_axes=['j', ...], out_axes=['j', ...], axis_resources={'j': 'x'})
@@ -1202,7 +1174,7 @@ class XMapErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(JAXTypeError, error):
       h(x)

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testResourceConflictNestInner(self):
     f = xmap(lambda x: lax.axis_index('i') + x,
              in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
@@ -1214,7 +1186,7 @@ class XMapErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(JAXTypeError, error):
       h(x)

-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testResourceConflictNestOut(self):
     f = xmap(lambda x: x,
              in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'})