Merge pull request #6866 from gnecula:tf_pjit

PiperOrigin-RevId: 376989780
jax authors 2021-06-01 22:50:12 -07:00
commit 8e6101c6a1
8 changed files with 314 additions and 139 deletions

View File

@ -11,6 +11,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.2.14 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...master).
* New features:
* {func}`jax2tf.convert` now supports `pjit` and `sharded_jit`.
* Breaking changes:

View File

@ -1,7 +1,7 @@
# JAX and TensorFlow interoperation (jax2tf/call_tf)
This package provides experimental support for interoperation between JAX and TensorFlow.
There are two interoperation directions:
- `jax2tf.convert`: for using JAX functions in a TensorFlow context, e.g.,
for eager or graph execution, or for saving as a TensorFlow SavedModel; and
@ -475,8 +475,8 @@ in [savedmodel_test.py](https://github.com/google/jax/blob/master/jax/experiment
### Missing converter features
There is currently no support for replicated (e.g. `pmap`) or multi-device
(e.g. `sharded_jit`) functions. The collective operations are not yet handled.
There is currently no support for `pmap` or `xmap`, nor for the collective
operations. There is support for `sharded_jit` and `pjit`.
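As an illustrative sketch (not part of this change), the newly supported path mirrors the `test_pjit_basic1D` test added later in this commit; it assumes at least two local devices:

import numpy as np
import jax
import tensorflow as tf
from jax.experimental import jax2tf, maps, PartitionSpec as P
from jax.experimental.pjit import pjit

# Hypothetical example: shard both inputs over a 2-device mesh axis "x".
f_jax = pjit(lambda x, y: x + y,
             in_axis_resources=(P("x"), P("x")),
             out_axis_resources=None)
x = np.arange(80, dtype=np.float32).reshape((8, 10))
devices = np.array(jax.local_devices()[:2])
with maps.mesh(devices, ("x",)):
  f_tf = tf.function(jax2tf.convert(f_jax), jit_compile=True)
  # The unoptimized HLO should carry sharding={devices=[2,1]...} annotations.
  print(f_tf.experimental_get_compiler_ir(x, x)(stage="hlo"))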
### No SavedModel fine-tuning

View File

@ -31,6 +31,8 @@ from jax._src.lax import lax
from jax._src.lax import linalg as lax_linalg
import jax._src.random
from jax.api_util import flatten_fun
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import ad
from jax.interpreters import pxla
from jax.interpreters import sharded_jit
@ -397,9 +399,6 @@ def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args: TfVal) -> Sequence[TfVal]:
return tuple(v for v, _ in out_with_avals)
### tracer
def _aval_to_tf_shape(aval: core.AbstractValue) -> Tuple[Optional[int], ...]:
"""Generate a TF shape, possibly containing None for polymorphic dimensions."""
return tuple(
@ -729,8 +728,9 @@ class TensorFlowTrace(core.Trace):
def process_call(self, call_primitive: core.Primitive, f: lu.WrappedFun,
tracers: Sequence[TensorFlowTracer], params):
assert call_primitive.multiple_results
vals: Sequence[TfVal] = [t.val for t in tracers]
f = _interpret_subtrace(f, self.main, tuple(t.aval for t in tracers))
vals: Sequence[TfVal] = tuple(t.val for t in tracers)
avals: Sequence[core.AbstractValue] = tuple(t.aval for t in tracers)
f = _interpret_subtrace(f, self.main, avals)
with core.new_sublevel():
if call_primitive == core.named_call_p:
with tf.name_scope(_sanitize_scope_name(params["name"])):
@ -813,6 +813,8 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
for unexpected in xla.call_translations: # Call primitives are inlined
if unexpected is pjit.pjit_p:
continue
tf_impl[unexpected] = functools.partial(_unexpected_primitive, unexpected)
# Primitives that are not yet implemented must be explicitly declared here.
@ -2494,8 +2496,8 @@ def split_to_logical_devices(tensor: TfVal,
Returns:
an annotated tensor.
"""
# This corresponds to the sharding annotations in
# xla_bridge._sharding_to_proto.
# TODO: this is only for sharded_jit. Either remove, or implement in terms
# of _shard_values.
if partition_dimensions is None:
return xla_sharding.replicate(tensor, use_sharding_op=True)
num_partition_splits = np.prod(partition_dimensions)
@ -2504,6 +2506,24 @@ def split_to_logical_devices(tensor: TfVal,
return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
def _shard_value(mesh: maps.Mesh,
val: TfVal,
aval: core.AbstractValue,
axis_resources: pjit.ParsedPartitionSpec) -> TfVal:
"""Apply sharding to a TfVal."""
sharding_proto: xla_client.OpSharding = pjit.get_aval_sharding_proto(
aval, axis_resources, mesh)
# To use xla_sharding.py, we must have an xla_data_pb2.OpSharding.
xla_sharding_proto: xla_data_pb2.OpSharding = (
xla_data_pb2.OpSharding(
type=int(sharding_proto.type),
tile_assignment_dimensions=sharding_proto.tile_assignment_dimensions,
tile_assignment_devices=sharding_proto.tile_assignment_devices,
replicate_on_last_tile_dim=sharding_proto.replicate_on_last_tile_dim))
return xla_sharding.Sharding(proto=xla_sharding_proto).apply_to_tensor(
val, use_sharding_op=True)
def _sharded_call(f: lu.WrappedFun, vals: Sequence[TfVal],
in_parts: Sequence[pxla.PartitionsOrReplicated],
out_parts_thunk,
@ -2520,12 +2540,51 @@ def _sharded_call(f: lu.WrappedFun, vals: Sequence[TfVal],
return sharded_vals_out
def _sharding_constraint(arg: TfVal, *,
partitions: pxla.PartitionsOrReplicated):
def _sharded_jit_sharding_constraint(arg: TfVal, *,
partitions: pxla.PartitionsOrReplicated,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
del _in_avals, _out_aval
return split_to_logical_devices(arg, partitions)
tf_impl[sharded_jit.sharding_constraint_p] = _sharding_constraint
tf_impl_with_avals[sharded_jit.sharding_constraint_p] = _sharded_jit_sharding_constraint
def _pjit(*args: TfVal,
jaxpr: core.ClosedJaxpr,
in_axis_resources: Sequence[pjit.ParsedPartitionSpec],
out_axis_resources: Sequence[pjit.ParsedPartitionSpec],
resource_env: maps.ResourceEnv,
donated_invars,
name: str,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray) -> TfVal:
del donated_invars, name
# TODO: add `name` to the name stack
shard_value_for_mesh = functools.partial(_shard_value, resource_env.physical_mesh)
# Apply sharding annotation to the arguments
sharded_args: Sequence[TfVal] = tuple(
util.safe_map(shard_value_for_mesh, args, _in_avals, in_axis_resources))
results = _interpret_jaxpr(jaxpr, *sharded_args)
sharded_results: Sequence[TfVal] = tuple(
util.safe_map(shard_value_for_mesh, results, _out_aval, out_axis_resources))
return tuple(sharded_results)
tf_impl_with_avals[pjit.pjit_p] = _pjit
def _pjit_sharding_constraint(arg: TfVal, *,
axis_resources: pjit.ParsedPartitionSpec,
resource_env: maps.ResourceEnv,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray,
**kwargs) -> TfVal:
return _shard_value(resource_env.physical_mesh, arg, _in_avals[0], axis_resources)
tf_impl_with_avals[pjit.sharding_constraint_p] = _pjit_sharding_constraint
def _register_checkpoint_pytrees():

View File

@ -13,9 +13,10 @@
# limitations under the License.
"""Tests for the jax2tf conversion of sharded_jit."""
import functools
import logging
import re
from typing import Sequence
from typing import Any, Sequence
import unittest
from absl.testing import absltest
@ -25,6 +26,7 @@ from jax import test_util as jtu
from jax.config import config
from jax.experimental import jax2tf
from jax.experimental import pjit
from jax.experimental.jax2tf.tests import tf_test_util
from jax.interpreters import sharded_jit
from jax.interpreters.sharded_jit import PartitionSpec as P
@ -36,6 +38,13 @@ import tensorflow as tf # type: ignore[import]
config.parse_flags_with_absl()
def setUpModule():
jtu.set_spmd_lowering_flag(True)
def tearDownModule():
jtu.restore_spmd_lowering_flag()
LOG_HLO = True
class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
@ -46,7 +55,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
def _check_sharding_annotations(self,
f_jax,
args,
args: Sequence[Any],
*,
expected: Sequence[str],
expected_opt: Sequence[str],
@ -57,6 +66,9 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
We currently check the unoptimized HLO against `expected` on CPU and TPU,
and we check the optimized HLO against `expected_opt` on TPU only and
only for JAX.
See `self.AssertShardingAnnotations` for documentation of `expected`
and `expected_opt`.
"""
if jtu.device_under_test() == "gpu":
raise unittest.SkipTest("Sharding HLO tests not useful for GPU")
@ -93,20 +105,20 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
experimental_get_compiler_ir(*args)(stage="hlo",
device_name=device_name)
if LOG_HLO:
logging.info(f"[{self._testMethodName}] got TF OPT HLO {tf_hlo}")
logging.info(f"[{self._testMethodName}] got TF HLO {tf_hlo}")
self.AssertShardingAnnotations("TF before optimizations", tf_hlo, expected)
tf_optimized_hlo = tf.function(f_tf, jit_compile=True).\
experimental_get_compiler_ir(*args)(stage="optimized_hlo",
device_name=device_name)
if LOG_HLO:
logging.info(f"[{self._testMethodName}] XXX got TF OPT HLO "
logging.info(f"[{self._testMethodName}] got TF optimized HLO "
f"for {device_name}: {tf_optimized_hlo}")
def AssertShardingAnnotations(self, what: str, hlo: str,
expected: Sequence[str]):
"""Args:
what: either 'JAX' or 'TF'
what: either 'JAX' or 'TF', used for messages only.
hlo: the text for the HLO module
expected: a sequence of regexps that must occur in the hlo text. Each
regexp must match a line, in order.
@ -128,7 +140,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
f"!!! Not found[{next_expected_idx}] {expected[next_expected_idx]}")
raise self.failureException("\n".join(failure_msg))
def test_in_out(self):
def test_sharded_jit_in_out(self):
"""Test input and output sharding annotations."""
sharded_jax_func = sharded_jit.sharded_jit(
jnp.dot, in_parts=(P(1, 2), P(2, 1)), out_parts=P(1, 2))
@ -152,7 +164,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
],
num_partitions=2)
def test_with_sharding_constraint(self):
def test_sharded_jit_with_sharding_constraint(self):
"""A sharding constraint in the middle."""
def jax_func(x, y):
@ -180,7 +192,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
],
num_partitions=2)
def test_replicated(self):
def test_sharded_jit_replicated(self):
"""A replicated input and output."""
sharded_jax_func = sharded_jit.sharded_jit(
@ -203,6 +215,116 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
],
num_partitions=2)
@jtu.with_mesh([("x", 2)])
def test_pjit_basic1D(self):
@functools.partial(pjit.pjit,
in_axis_resources=(P("x"), P("x")),
out_axis_resources=None)
def jax_func(x, y):
return x + y
shape = (8, 10)
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
hlo = jax.xla_computation(jax_func)(x, x).as_hlo_text()
print(f"HLO is {hlo}")
print(f"JAXPR is {jax.make_jaxpr(jax_func)(x, x)}")
self._check_sharding_annotations(
jax_func, [x, x],
expected=[
r"f32\[8,10\].*sharding={devices=\[2,1\]", # x and y
r"f32\[8,10\].*sharding={replicated", # output
],
expected_opt=[
r"f32\[4,10\].*sharding={devices=\[2,1\]", # x and y
# TODO: why don't we see "sharding={replicated"
r"f32\[8,10\]", # output
],
num_partitions=2)
@jtu.with_mesh([("x", 2), ("y", 2)])
def test_pjit_basic2D(self):
@functools.partial(pjit.pjit,
in_axis_resources=(P(None, "x", "y"), P("y")),
out_axis_resources=P("x"))
def jax_func(x, y):
return x @ y
x_shape = (8, 6, 4)
y_shape = (4, 2)
x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape)
self._check_sharding_annotations(
jax_func, [x, y],
expected=[
r"f32\[8,6,4\].*sharding={devices=\[1,2,2\]0,1,2,3", # x
r"f32\[4,2\].*sharding={devices=\[2,1,2\]0,2,1,3 last_tile_dim_replicate", # y
r"f32\[8,6,2\].*sharding={devices=\[2,1,1,2\]0,1,2,3 last_tile_dim_replicate", # output
],
expected_opt=[
# TODO: relax ordering
r"f32\[2,2\].*sharding={devices=\[2,1,2\]0,2,1,3 last_tile_dim_replicate", # y
r"f32\[8,3,2\].*sharding={devices=\[1,2,2\]0,1,2,3", # x
# TODO: why can't we see sharding={devices=\[2,1,1,2\]0,1,2,3 last_tile_dim_replicate?
r"bf16\[4,6,2\]", # output
],
num_partitions=4)
@jtu.with_mesh([("x", 2), ("y", 2)])
def test_pjit_TwoMeshAxisSharding(self):
@functools.partial(pjit.pjit,
in_axis_resources=P(("x", "y"),),
out_axis_resources=P(("x", "y"),))
def jax_func(x, y):
return x @ y
x_shape = (24, 8)
y_shape = (8, 2)
x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape)
self._check_sharding_annotations(
jax_func, [x, y],
expected=[
r"f32\[24,8\].*sharding={devices=\[4,1\]0,1,2,3", # x
r"f32\[8,2\].*sharding={devices=\[4,1\]0,1,2,3", # y
r"f32\[24,2\].*sharding={devices=\[4,1\]0,1,2,3", # output
],
expected_opt=[
# TODO: relax ordering
r"f32\[2,2\].*sharding={devices=\[4,1\]0,1,2,3", # y
r"f32\[6,8\].*sharding={devices=\[4,1\]0,1,2,3", # x
# TODO: why can't we see .*sharding={devices=\[4,1\]0,1,2,3
r"f32\[1,6,2\]", # output
],
num_partitions=4)
@jtu.with_mesh([("x", 2), ("y", 1)])
def test_pjit_ShardingConstraint(self):
@functools.partial(pjit.pjit, in_axis_resources=None,
out_axis_resources=None)
def jax_func(x): # x: f32[12, 8]
y = jnp.tile(x, (2, 1)) # y: f32[24, 8]
y = pjit.with_sharding_constraint(y, P("x", "y"))
return y[0:y.shape[0] // 4] # res: f32[6, 8]
shape = (12, 8)
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
self._check_sharding_annotations(
jax_func, [x],
expected=[
r"f32\[12,8\].*sharding={replicated}", # x
r"f32\[24,8\].*sharding={devices=\[2,1\]0,1", # y
r"f32\[6,8\].*sharding={replicated}", # output
],
expected_opt=[
r"f32\[12,8\].*sharding={replicated}", # x
# TODO: why can't we see "sharding={devices=\[2,1\]0,1"
r"f32\[1,12,8\]", # y
# TODO: why can't we see "sharding={replicated}" ?
r"f32\[6,8\]", # output
],
num_partitions=2)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -590,13 +590,20 @@ def get_array_mapping(axis_resources: ParsedPartitionSpec) -> pxla.ArrayMapping:
for i, axes in enumerate(axis_resources)
for axis in axes)
def get_sharding_proto(c, xla_op, axis_resources, mesh):
def get_sharding_proto(c, xla_op, axis_resources: ParsedPartitionSpec,
mesh: maps.Mesh) -> xc.OpSharding:
xla_shape = c.GetShape(xla_op)
if xla_shape.is_token():
aval = core.abstract_token
assert axis_resources is REPLICATED
else:
aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.element_type())
return get_aval_sharding_proto(aval, axis_resources, mesh)
def get_aval_sharding_proto(aval: core.AbstractValue,
axis_resources: ParsedPartitionSpec,
mesh: maps.Mesh) -> xc.OpSharding:
array_mapping = get_array_mapping(axis_resources)
sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)(
aval, array_mapping)

View File

@ -17,7 +17,7 @@ import functools
import re
import os
import textwrap
from typing import Dict, Sequence, Union
from typing import Dict, List, Generator, Sequence, Tuple, Union
import unittest
import warnings
import zlib
@ -33,10 +33,12 @@ from . import core
from ._src import dtypes as _dtypes
from . import lax
from ._src.config import flags, bool_env, config
from ._src.util import partial, prod
from ._src.util import partial, prod, unzip2
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
from .lib import xla_bridge
from .interpreters import xla
from .experimental import maps
from .experimental.maps import mesh
FLAGS = flags.FLAGS
@ -1012,6 +1014,44 @@ def ignore_warning(**kw):
warnings.filterwarnings("ignore", **kw)
yield
# -------------------- Mesh parametrization helpers --------------------
MeshSpec = List[Tuple[str, int]]
@contextmanager
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
"""Test utility for setting up meshes given mesh data from `schedules`."""
# This is similar to the `with_mesh` function above, but isn't a decorator.
axis_names, shape = unzip2(named_shape)
size = prod(shape)
local_devices = list(api.local_devices())
if len(local_devices) < size:
raise unittest.SkipTest(f"Test requires {size} local devices")
mesh_devices = np.array(local_devices[:size]).reshape(shape)
with mesh(mesh_devices, axis_names):
yield
def with_mesh_from_kwargs(f):
return lambda *args, **kwargs: with_mesh(kwargs['mesh'])(f)(*args, **kwargs)
def with_and_without_mesh(f):
return parameterized.named_parameters(
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
for name, mesh, axis_resources in (
('', (), ()),
('Mesh', (('x', 2),), (('i', 'x'),))
))(with_mesh_from_kwargs(f))
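For reference, a minimal usage sketch of these helpers (hypothetical function name, assuming at least two local devices). Because `with_mesh` is a generator-based context manager, it works both as a decorator (as in pjit_test.py) and in a `with` statement (as in xmap_test.py):

import jax.numpy as jnp
from jax import test_util as jtu
from jax.experimental import PartitionSpec as P
from jax.experimental.pjit import pjit

@jtu.with_mesh([("x", 2)])         # as a decorator
def run_sharded_add():
  f = pjit(lambda a, b: a + b,
           in_axis_resources=(P("x"), P("x")), out_axis_resources=None)
  v = jnp.arange(8, dtype=jnp.float32)
  return f(v, v)

run_sharded_add()                  # raises unittest.SkipTest if fewer than 2 local devices

with jtu.with_mesh([("x", 2)]):    # as a context manager
  pjit(lambda a: a * 2,
       in_axis_resources=P("x"), out_axis_resources=P("x"))(jnp.arange(8.0))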
old_spmd_lowering_flag = False
def set_spmd_lowering_flag(val: bool):
global old_spmd_lowering_flag
maps.make_xmap_callable.cache_clear()
old_spmd_lowering_flag = maps.EXPERIMENTAL_SPMD_LOWERING
maps.EXPERIMENTAL_SPMD_LOWERING = val
def restore_spmd_lowering_flag():
maps.make_xmap_callable.cache_clear()
maps.EXPERIMENTAL_SPMD_LOWERING = old_spmd_lowering_flag
class _cached_property:
null = object()

View File

@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from functools import partial
import logging
from typing import Generator, List, Tuple
from unittest import SkipTest
from absl.testing import absltest
@ -32,47 +30,23 @@ from jax.experimental import PartitionSpec as P
from jax.experimental.maps import xmap, mesh
from jax.experimental.pjit import pjit, pjit_p, with_sharding_constraint, SpecSync
from jax.interpreters import pxla
from jax._src.util import unzip2, prod, curry
from jax._src.util import prod, curry
from jax.config import config
config.parse_flags_with_absl()
def setUpModule():
global old_lowering_flag
jax.experimental.maps.make_xmap_callable.cache_clear()
old_lowering_flag = jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True
jtu.set_spmd_lowering_flag(True)
def tearDownModule():
jax.experimental.maps.make_xmap_callable.cache_clear()
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = old_lowering_flag
# TODO(skye): move into test_util and dedup with xmap_test.py
MeshSpec = List[Tuple[str, int]]
@contextmanager
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
"""Test utility for setting up meshes given mesh data from `schedules`."""
# This is similar to the `with_mesh` function above, but isn't a decorator.
axis_names, shape = unzip2(named_shape)
size = prod(shape)
local_devices = list(jax.local_devices())
if len(local_devices) < size:
raise SkipTest(f"Test requires {size} local devices")
mesh_devices = np.array(local_devices[:size]).reshape(shape)
with mesh(mesh_devices, axis_names):
yield
def with_mesh_from_kwargs(f):
return lambda *args, **kwargs: with_mesh(kwargs['mesh'])(f)(*args, **kwargs)
jtu.restore_spmd_lowering_flag()
# TODO(skye): make the buffer donation utils part of JaxTestCase
class PJitTest(jtu.BufferDonationTestCase):
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testBasic1D(self):
@partial(pjit,
in_axis_resources=(P('x'), P('x')),
@ -90,7 +64,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual.device_buffers[0].to_py(), expected,
check_dtypes=False)
@with_mesh([('x', 2), ('y', 2)])
@jtu.with_mesh([('x', 2), ('y', 2)])
def testBasic2D(self):
@partial(pjit,
in_axis_resources=(P(None, 'x', 'y'), P('y')),
@ -118,7 +92,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual.device_buffers[3].to_py(), split1,
check_dtypes=False)
@with_mesh([('x', 2), ('y', 2)])
@jtu.with_mesh([('x', 2), ('y', 2)])
def testTwoMeshAxisSharding(self):
@partial(pjit,
in_axis_resources=P(('x', 'y'),),
@ -144,7 +118,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual.device_buffers[3].to_py(), splits[3],
check_dtypes=False)
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testBufferDonation(self):
@partial(pjit,
in_axis_resources=P('x'),
@ -162,7 +136,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertNotDeleted(y)
self.assertDeleted(x)
@with_mesh([('x', 2), ('y', 1)])
@jtu.with_mesh([('x', 2), ('y', 1)])
def testShardingConstraint(self):
@partial(pjit, in_axis_resources=None, out_axis_resources=None)
def f(x):
@ -186,7 +160,7 @@ class PJitTest(jtu.BufferDonationTestCase):
# Annotation from pjit
self.assertIn("sharding={replicated}", hlo.as_hlo_text())
@with_mesh([('x', 2), ('y', 1)])
@jtu.with_mesh([('x', 2), ('y', 1)])
def testShardingConstraintPyTree(self):
@partial(pjit, in_axis_resources=None, out_axis_resources=None)
def f(x):
@ -232,7 +206,7 @@ class PJitTest(jtu.BufferDonationTestCase):
should_be_tracing = False
pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
@with_mesh([('x', 2), ('y', 1)])
@jtu.with_mesh([('x', 2), ('y', 1)])
def testNested(self):
# Add a constant captured by the nested pjit to make things more complicated
h = jnp.arange(4)
@ -243,7 +217,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(y, jnp.sin(x).sum() + h.sum())
self.assertTrue(hasattr(y, "sharding_spec"))
@with_mesh([('x', 2), ('y', 1)])
@jtu.with_mesh([('x', 2), ('y', 1)])
def testJVP(self):
# Add a constant captured by the nested pjit to make things more complicated
h = jnp.arange(4)
@ -252,7 +226,7 @@ class PJitTest(jtu.BufferDonationTestCase):
jtu.check_grads(g, (jnp.arange(16, dtype=jnp.float32).reshape((4, 4)),),
order=2, modes=["fwd"], eps=1)
@with_mesh([('x', 2), ('y', 1)])
@jtu.with_mesh([('x', 2), ('y', 1)])
def testEvalJaxpr(self):
x, y = jnp.arange(4), jnp.arange(5)
f = pjit(lambda x, y: x.sum() + jnp.sin(y),
@ -263,13 +237,13 @@ class PJitTest(jtu.BufferDonationTestCase):
r, = f_eval(x, y)
self.assertAllClose(r, x.sum() + jnp.sin(y))
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testNonArrayArg(self):
self.assertEqual(pjit(lambda x: x + 2,
in_axis_resources=None,
out_axis_resources=None)(1), 3)
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testNonHashableAxisResources(self):
x = jnp.arange(4)
y = pjit(lambda x: {'b': x['a'] + 2},
@ -277,7 +251,7 @@ class PJitTest(jtu.BufferDonationTestCase):
out_axis_resources={'b': P('x')})({'a': x})
self.assertAllClose(y, {'b': x + 2})
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testGradOfConstraint(self):
# Make sure that we can compute grads through sharding constraints
h = lambda x: jnp.sin(with_sharding_constraint(x, P('x'))).sum()
@ -286,7 +260,7 @@ class PJitTest(jtu.BufferDonationTestCase):
x = jnp.arange(8, dtype=jnp.float32)
self.assertAllClose(f(x), jnp.cos(x))
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testNoopPartitionSpecs(self):
noops = [P(), P(None), P(()), P((), None), P(None, None, ())]
x = jnp.arange(8).reshape((2, 2, 2))
@ -294,7 +268,7 @@ class PJitTest(jtu.BufferDonationTestCase):
y = pjit(lambda x: x * 2, in_axis_resources=spec, out_axis_resources=spec)(x)
self.assertAllClose(y, x * 2)
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testVmapModifiesAxisResources(self):
h = pjit(lambda x, y: (x + y, x, y), in_axis_resources=P('x'), out_axis_resources=None)
x = jnp.arange(4)
@ -310,7 +284,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertEqual(y_sync, SpecSync.IN_SYNC)
self.assertEqual(z_sync, SpecSync.DIM_PERMUTE)
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testVMap(self):
f = pjit(lambda x, y: (x + y, x), in_axis_resources=P('x'), out_axis_resources=P('x'))
x = jnp.arange(4)
@ -402,7 +376,7 @@ def check_1d_2d_mesh(f, set_mesh):
("2", (("x", 2),), "x"),
("2x1", (("x", 2), ("y", 1)), ("x", "y")),
("2x2", (("x", 2), ("y", 2)), ("x", "y")),
))(with_mesh_from_kwargs(f) if set_mesh else f)
))(jtu.with_mesh_from_kwargs(f) if set_mesh else f)
def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
@ -444,7 +418,7 @@ class PJitErrorTest(jtu.JaxTestCase):
in_axis_resources=None, out_axis_resources=None)(x)
@check_1d_2d_mesh(set_mesh=False)
@with_mesh([('z', 1)])
@jtu.with_mesh([('z', 1)])
def testUndefinedResourcesArgs(self, mesh, resources):
x = jnp.ones((2, 2))
spec = P(resources,)
@ -454,7 +428,7 @@ class PJitErrorTest(jtu.JaxTestCase):
pjit(lambda x: x, in_axis_resources=spec, out_axis_resources=None)(x)
@check_1d_2d_mesh(set_mesh=False)
@with_mesh([('z', 1)])
@jtu.with_mesh([('z', 1)])
def testUndefinedResourcesOuts(self, mesh, resources):
x = jnp.ones((2, 2))
spec = P(resources,)
@ -464,7 +438,7 @@ class PJitErrorTest(jtu.JaxTestCase):
pjit(lambda x: x, in_axis_resources=None, out_axis_resources=spec)(x)
@check_1d_2d_mesh(set_mesh=False)
@with_mesh([('z', 1)])
@jtu.with_mesh([('z', 1)])
def testUndefinedResourcesConstraint(self, mesh, resources):
x = jnp.ones((2, 2))
spec = P(resources,)
@ -475,7 +449,7 @@ class PJitErrorTest(jtu.JaxTestCase):
pjit(lambda x: with_sharding_constraint(x, spec),
in_axis_resources=None, out_axis_resources=None)(x)
@with_mesh([('x', 2), ('y', 1)])
@jtu.with_mesh([('x', 2), ('y', 1)])
def testRankTooLowArgs(self):
x = jnp.arange(2)
spec = P('x', 'y')
@ -484,7 +458,7 @@ class PJitErrorTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, error):
pjit(lambda x: x.sum(), in_axis_resources=spec, out_axis_resources=None)(x)
@with_mesh([('x', 2), ('y', 1)])
@jtu.with_mesh([('x', 2), ('y', 1)])
def testRankTooLowOuts(self):
x = jnp.arange(2)
spec = P('x', 'y')
@ -493,7 +467,7 @@ class PJitErrorTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, error):
pjit(lambda x: x.sum(), in_axis_resources=None, out_axis_resources=spec)(x)
@with_mesh([('x', 2), ('y', 1)])
@jtu.with_mesh([('x', 2), ('y', 1)])
def testRankTooLowConstraint(self):
x = jnp.arange(2)
spec = P('x', 'y')
@ -504,7 +478,7 @@ class PJitErrorTest(jtu.JaxTestCase):
pjit(lambda x: with_sharding_constraint(x, spec),
in_axis_resources=None, out_axis_resources=None)(x)
@with_mesh([('x', 2), ('y', 1)])
@jtu.with_mesh([('x', 2), ('y', 1)])
def testRepeatedInResources(self):
x = jnp.arange(2)
for spec in [P('x', 'x'), P('x', ('y', 'x'))]:
@ -514,7 +488,7 @@ class PJitErrorTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, error):
pjit(lambda x: x, in_axis_resources=spec, out_axis_resources=None)(x)
@with_mesh([('x', 2), ('y', 1)])
@jtu.with_mesh([('x', 2), ('y', 1)])
def testRepeatedOutResources(self):
x = jnp.arange(2)
for spec in [P('x', 'x'), P('x', ('y', 'x'))]:
@ -524,7 +498,7 @@ class PJitErrorTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, error):
pjit(lambda x: x, in_axis_resources=None, out_axis_resources=spec)(x)
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testInputShardsXMapAxis(self):
spec = P('x')
f = xmap(pjit(lambda x: x + 2, in_axis_resources=spec, out_axis_resources=None),
@ -537,7 +511,7 @@ class PJitErrorTest(jtu.JaxTestCase):
with self.assertRaisesRegex(JAXTypeError, error):
f(x)
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testOutputShardsXMapAxis(self):
spec = P('x')
f = xmap(pjit(lambda x: x + 2, in_axis_resources=None, out_axis_resources=spec),
@ -550,7 +524,7 @@ class PJitErrorTest(jtu.JaxTestCase):
with self.assertRaisesRegex(JAXTypeError, error):
f(x)
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testConstraintShardsXMapAxis(self):
spec = P('x')
f = xmap(lambda x: with_sharding_constraint(x, axis_resources=spec),
@ -563,7 +537,7 @@ class PJitErrorTest(jtu.JaxTestCase):
with self.assertRaisesRegex(JAXTypeError, error):
f(x)
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testCatchesInnerXMapErrors(self):
f = pjit(xmap(lambda x, y: x, in_axes=(['i'], ['j']), out_axes=['i', 'j'],
axis_resources={'i': 'x', 'j': 'x'}),

View File

@ -71,34 +71,6 @@ def tearDownModule():
os.environ["XLA_FLAGS"] = prev_xla_flags
xla_bridge.get_backend.cache_clear()
# -------------------- Mesh parametrization helpers --------------------
MeshSpec = List[Tuple[str, int]]
@contextmanager
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
"""Test utility for setting up meshes given mesh data from `schedules`."""
# This is similar to the `with_mesh` function above, but isn't a decorator.
axis_names, shape = unzip2(named_shape)
size = prod(shape)
local_devices = list(jax.local_devices())
if len(local_devices) < size:
raise SkipTest(f"Test requires {size} local devices")
mesh_devices = np.array(local_devices[:size]).reshape(shape)
with mesh(mesh_devices, axis_names):
yield
def with_mesh_from_kwargs(f):
return lambda *args, **kwargs: with_mesh(kwargs['mesh'])(f)(*args, **kwargs)
def with_and_without_mesh(f):
return parameterized.named_parameters(
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
for name, mesh, axis_resources in (
('', (), ()),
('Mesh', (('x', 2),), (('i', 'x'),))
))(with_mesh_from_kwargs(f))
# -------------------- Itertools helpers --------------------
@ -135,7 +107,7 @@ def ensure_bdim(x, axis_name, bdim):
AxisResources = Dict[str, Union[str, Tuple[str, ...]]]
def schedules(sizes: Dict[str, int]
) -> Generator[Tuple[AxisResources, MeshSpec], None, None]:
) -> Generator[Tuple[AxisResources, jtu.MeshSpec], None, None]:
"""Test utility generating xmap parallel schedules from logical names & sizes.
Args:
@ -275,7 +247,7 @@ class XMapTest(XMapTestCase):
self.assertAllClose(c, a * 2)
self.assertAllClose(d, b * 4)
@with_mesh([('x', 2), ('y', 2)])
@jtu.with_mesh([('x', 2), ('y', 2)])
def testCollectiveReduce(self):
fm = xmap(lambda a, b: (lax.psum(a * 2, 'a'), b * 4),
in_axes=[['a', 'b', ...], {0: 'c'}],
@ -289,7 +261,7 @@ class XMapTest(XMapTestCase):
self.assertAllClose(c, (a * 2).sum(0))
self.assertAllClose(d, b * 4)
@with_mesh([('x', 2), ('y', 2)])
@jtu.with_mesh([('x', 2), ('y', 2)])
def testCollectivePermute2D(self):
perm = np.array([3, 1, 2, 0])
x = jnp.arange(4).reshape((2, 2))
@ -313,7 +285,7 @@ class XMapTest(XMapTestCase):
in_axes=['i', ...], out_axes=['i', ...])(x)
self.assertAllClose(result, x + x[jnp.newaxis].T)
@with_mesh([('x', 2), ('y', 2)])
@jtu.with_mesh([('x', 2), ('y', 2)])
def testOneLogicalTwoMeshAxesBasic(self):
def f(v):
return lax.psum(v * 2, 'a'), v * 4
@ -325,7 +297,7 @@ class XMapTest(XMapTestCase):
self.assertAllClose(ans, (v * 2).sum(0))
self.assertAllClose(ans2, v.T * 4)
@with_mesh([('x', 2), ('y', 2)])
@jtu.with_mesh([('x', 2), ('y', 2)])
def testOneLogicalTwoMeshAxesSharding(self):
def f(v):
return v * 4
@ -346,7 +318,7 @@ class XMapTest(XMapTestCase):
pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
(pxla.ShardedAxis(1), pxla.ShardedAxis(0))))
@with_mesh([('x', 2), ('y', 2)])
@jtu.with_mesh([('x', 2), ('y', 2)])
def testSkipFirstMeshDim(self):
def run(axis_resources):
return xmap(lambda x: x * 2, in_axes=['i', ...], out_axes=['i', ...],
@ -379,7 +351,7 @@ class XMapTest(XMapTestCase):
('OneToOne', (('x', 2), ('y', 2)), (('a', 'y'), ('b', 'x'))),
('Multiple', (('x', 2), ('y', 2), ('z', 2)), (('a', 'y'), ('b', ('x', 'z')))),
))
@with_mesh_from_kwargs
@jtu.with_mesh_from_kwargs
def testNestedMesh(self, mesh, axis_resources):
@partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}),
axis_resources=dict([axis_resources[0]]))
@ -405,7 +377,7 @@ class XMapTest(XMapTestCase):
# Make sure that there are non-partial sharding specs in the HLO
self.assertRegex(hlo, r"sharding={devices=\[[0-9,]+\][0-9,]+}")
@with_and_without_mesh
@jtu.with_and_without_mesh
def testMultipleCalls(self, mesh, axis_resources):
def f(x, y):
assert x.shape == y.shape == (3, 5)
@ -420,7 +392,7 @@ class XMapTest(XMapTestCase):
for i in range(10):
self.assertAllClose(f_mapped(x, x), expected)
@with_and_without_mesh
@jtu.with_and_without_mesh
@jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU.
def testBufferDonation(self, mesh, axis_resources):
shard = lambda x: x
@ -443,7 +415,7 @@ class XMapTest(XMapTestCase):
xmap(lambda x: lax.fori_loop(0, 10, lambda _, x: lax.psum(x, 'i'), x),
in_axes=['i', ...], out_axes=['i', ...])(x)
@with_and_without_mesh
@jtu.with_and_without_mesh
def testAxisSizes(self, mesh, axis_resources):
result = xmap(lambda: lax.axis_index('i'),
in_axes=(), out_axes=['i', ...],
@ -551,7 +523,7 @@ class XMapTest(XMapTestCase):
y = jnp.arange(20, dtype=jnp.float32).reshape((4, 5)) / 100
jtu.check_grads(f, (x, y), order=2, modes=['fwd'])
@with_and_without_mesh
@jtu.with_and_without_mesh
def testNamedShape(self, mesh, axis_resources):
x = np.arange(4,)
y = 2
@ -563,7 +535,7 @@ class XMapTest(XMapTestCase):
self.assertEqual(z.aval.named_shape, {})
self.assertEqual(w.aval.named_shape, {})
@with_and_without_mesh
@jtu.with_and_without_mesh
def testBroadcast(self, mesh, axis_resources):
x = jnp.asarray(2.0)
f = xmap(lambda x: x, in_axes={}, out_axes=['i'],
@ -590,7 +562,7 @@ class XMapTestSPMD(SPMDTestMixin, XMapTest):
raise SkipTest
super().setUp()
@with_mesh([('x', 2), ('y', 2), ('z', 2)])
@jtu.with_mesh([('x', 2), ('y', 2), ('z', 2)])
def testNestedMeshSPMD(self):
h = xmap(lambda y: (jnp.sin(y) * np.arange(y.size), lax.psum(y, ('a', 'b', 'c'))),
in_axes={0: 'c'}, out_axes=({1: 'c'}, {}),
@ -670,7 +642,7 @@ class NamedRandomTest(XMapTestCase):
((f"_mesh={mesh}_resources={sorted(axis_resources.items())}",
{"axis_resources": tuple(axis_resources.items()), "mesh": tuple(mesh)})
for axis_resources, mesh in schedules({'i': 4, 'j': 6})), subset=True)
@with_mesh_from_kwargs
@jtu.with_mesh_from_kwargs
def testSamplerResourceIndependence(self, distr_sample, axis_resources, mesh):
def sample(axis_resources):
return xmap(lambda: distr_sample(jax.random.PRNGKey(0), shape=NamedShape(3, i=4, j=6)),
@ -743,7 +715,7 @@ class NewPrimitiveTest(XMapTestCase):
x_explode = x.reshape((3, 3, 3))
self.assertAllClose(pgather(x, idx, 0), pgather(x_explode, idx, (0, 1)))
@with_and_without_mesh
@jtu.with_and_without_mesh
def testGather(self, mesh, axis_resources):
if axis_resources and not jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING:
raise SkipTest("pgather over mesh axes without SPMD lowering not implemented")
@ -851,7 +823,7 @@ def gen_axis_names():
def schedules_from_pdot_spec(
spec: PdotTestSpec, lhs_shape: Tuple[int], rhs_shape: Tuple[int]
) -> Generator[Tuple[AxisResources, MeshSpec], None, None]:
) -> Generator[Tuple[AxisResources, jtu.MeshSpec], None, None]:
logical_sizes = {
name: shape[ax]
for shape, in_axes in [(lhs_shape, spec.lhs_in_axes),
@ -862,7 +834,7 @@ def schedules_from_pdot_spec(
class PDotTests(XMapTestCase):
@with_mesh([('r1', 2)])
@jtu.with_mesh([('r1', 2)])
def testPdotBasic(self):
def f(x, y):
return lax.pdot(x, y, 'i')
@ -880,7 +852,7 @@ class PDotTests(XMapTestCase):
self.assertAllClose(z, jnp.dot(x, y))
@with_mesh([('r1', 2)])
@jtu.with_mesh([('r1', 2)])
def testPdotBatching(self):
def f(x, y):
return lax.pdot(x, y, 'i')
@ -898,7 +870,7 @@ class PDotTests(XMapTestCase):
self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y))
@with_mesh([('r1', 2)])
@jtu.with_mesh([('r1', 2)])
def testPdotBatchingShardUncontractedDim(self):
def f(x, y):
return lax.pdot(x, y, 'i')
@ -945,7 +917,7 @@ class PDotTests(XMapTestCase):
out_axes=[*pdot_spec.batch_names, ...],
axis_resources=axis_resources)
with with_mesh(mesh_data):
with jtu.with_mesh(mesh_data):
result = fun(lhs, rhs)
expected = lax.dot_general(lhs, rhs, pdot_spec.dot_general_dim_nums)
@ -989,7 +961,7 @@ class PDotTests(XMapTestCase):
out_axes=(pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes),
axis_resources=axis_resources)
with with_mesh(mesh_data):
with jtu.with_mesh(mesh_data):
lhs_bar, rhs_bar = fun(lhs, rhs, out_bar)
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
@ -1062,7 +1034,7 @@ class PDotTests(XMapTestCase):
class XMapErrorTest(jtu.JaxTestCase):
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testRepeatedAxisResource(self):
def f(v):
return v * 4
@ -1070,7 +1042,7 @@ class XMapErrorTest(jtu.JaxTestCase):
fxy = xmap(f, in_axes=['a', ...], out_axes=['a', ...],
axis_resources={'a': ('x', 'x')})
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testNestedDifferentResources(self):
@partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'})
def f(x):
@ -1089,7 +1061,7 @@ class XMapErrorTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, "Failed to infer size of axes: i."):
xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...])({})
@with_mesh([('x', 2), ('y', 2)])
@jtu.with_mesh([('x', 2), ('y', 2)])
def testAxesNotDivisibleByResources(self):
with self.assertRaisesRegex(ValueError, r"Size of axis i \(5\) is not divisible.*"
r"\(\('x', 'y'\), 4 in total\)"):
@ -1154,7 +1126,7 @@ class XMapErrorTest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError, error):
fm(x, y)
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testResourceConflictArgs(self):
fm = xmap(lambda x: lax.psum(x, ('a', 'b')),
in_axes=['a', 'b'], out_axes=[],
@ -1166,7 +1138,7 @@ class XMapErrorTest(jtu.JaxTestCase):
with self.assertRaisesRegex(JAXTypeError, error):
fm(x)
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testResourceConflictInner(self):
fm = xmap(lambda x, y: x + y,
in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...],
@ -1178,7 +1150,7 @@ class XMapErrorTest(jtu.JaxTestCase):
with self.assertRaisesRegex(JAXTypeError, error):
fm(x, y)
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testResourceConflictOut(self):
fm = xmap(lambda x, y: x,
in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...],
@ -1191,7 +1163,7 @@ class XMapErrorTest(jtu.JaxTestCase):
with self.assertRaisesRegex(JAXTypeError, error):
fm(x, y)
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testResourceConflictNestArgs(self):
f = xmap(lambda x: x, in_axes=['i'], out_axes=['i'], axis_resources={'i': 'x'})
h = xmap(f, in_axes=['j', ...], out_axes=['j', ...], axis_resources={'j': 'x'})
@ -1202,7 +1174,7 @@ class XMapErrorTest(jtu.JaxTestCase):
with self.assertRaisesRegex(JAXTypeError, error):
h(x)
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testResourceConflictNestInner(self):
f = xmap(lambda x: lax.axis_index('i') + x,
in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
@ -1214,7 +1186,7 @@ class XMapErrorTest(jtu.JaxTestCase):
with self.assertRaisesRegex(JAXTypeError, error):
h(x)
@with_mesh([('x', 2)])
@jtu.with_mesh([('x', 2)])
def testResourceConflictNestOut(self):
f = xmap(lambda x: x,
in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'})