Mirror of https://github.com/ROCm/jax.git
Merge pull request #6866 from gnecula:tf_pjit
PiperOrigin-RevId: 376989780
Commit: 8e6101c6a1
@@ -11,6 +11,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
 ## jax 0.2.14 (unreleased)
 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...master).
 * New features:
+  * The {func}`jax2tf.convert` now has support for `pjit` and `sharded_jit`.
 
 * Breaking changes:
 
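A minimal sketch of the new feature, assuming at least two local devices; the function, shapes, and mesh below are illustrative and not taken from this commit:

```python
# Hypothetical sketch: converting a pjit'd function with jax2tf.convert.
import numpy as np
import jax
import tensorflow as tf
from jax.experimental import PartitionSpec as P
from jax.experimental import jax2tf, pjit
from jax.experimental.maps import mesh

f = pjit.pjit(lambda x, y: x + y,
              in_axis_resources=(P("x"), P("x")),
              out_axis_resources=None)

x = np.arange(80, dtype=np.float32).reshape((8, 10))
devices = np.array(jax.local_devices()[:2])
with mesh(devices, ("x",)):
  # jit_compile=True so the sharding annotations reach XLA, as in the tests.
  f_tf = tf.function(jax2tf.convert(f), autograph=False, jit_compile=True)
  print(f_tf(x, x))
```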
@@ -475,8 +475,8 @@ in [savedmodel_test.py](https://github.com/google/jax/blob/master/jax/experiment
 
 ### Missing converter features
 
-There is currently no support for replicated (e.g. `pmap`) or multi-device
-(e.g. `sharded_jit`) functions. The collective operations are not yet handled.
+There is currently no support for `pmap` or `xmap`, nor for the collective
+operations. There is support for `sharded_jit` and `pjit`.
 
 ### No SavedModel fine-tuning
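For `sharded_jit`, a hedged sketch that mirrors the `test_sharded_jit_in_out` test in this change (two partitions assumed; not part of the commit):

```python
# Sketch only: convert a sharded_jit function; partition specs as in the tests.
import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf
from jax.interpreters import sharded_jit
from jax.interpreters.sharded_jit import PartitionSpec as P

sharded_jax_func = sharded_jit.sharded_jit(
    jnp.dot, in_parts=(P(1, 2), P(2, 1)), out_parts=P(1, 2))
f_tf = jax2tf.convert(sharded_jax_func)
x = np.ones((8, 8), dtype=np.float32)
# Calling f_tf(x, x) emits XlaSharding annotations into the TF graph.
```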
@@ -31,6 +31,8 @@ from jax._src.lax import lax
 from jax._src.lax import linalg as lax_linalg
 import jax._src.random
 from jax.api_util import flatten_fun
+from jax.experimental import maps
+from jax.experimental import pjit
 from jax.interpreters import ad
 from jax.interpreters import pxla
 from jax.interpreters import sharded_jit
@@ -397,9 +399,6 @@ def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args: TfVal) -> Sequence[TfVal]:
   return tuple(v for v, _ in out_with_avals)
 
 
-### tracer
-
-
 def _aval_to_tf_shape(aval: core.AbstractValue) -> Tuple[Optional[int], ...]:
   """Generate a TF shape, possibly containing None for polymorphic dimensions."""
   return tuple(
@@ -729,8 +728,9 @@ class TensorFlowTrace(core.Trace):
   def process_call(self, call_primitive: core.Primitive, f: lu.WrappedFun,
                    tracers: Sequence[TensorFlowTracer], params):
     assert call_primitive.multiple_results
-    vals: Sequence[TfVal] = [t.val for t in tracers]
-    f = _interpret_subtrace(f, self.main, tuple(t.aval for t in tracers))
+    vals: Sequence[TfVal] = tuple(t.val for t in tracers)
+    avals: Sequence[core.AbstractValue] = tuple(t.aval for t in tracers)
+    f = _interpret_subtrace(f, self.main, avals)
     with core.new_sublevel():
       if call_primitive == core.named_call_p:
         with tf.name_scope(_sanitize_scope_name(params["name"])):
@@ -813,6 +813,8 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
 
 
 for unexpected in xla.call_translations:  # Call primitives are inlined
+  if unexpected is pjit.pjit_p:
+    continue
   tf_impl[unexpected] = functools.partial(_unexpected_primitive, unexpected)
 
 # Primitives that are not yet implemented must be explicitly declared here.
@@ -2494,8 +2496,8 @@ def split_to_logical_devices(tensor: TfVal,
   Returns:
     an annotated tensor.
   """
-  # This corresponds to the sharding annotations in
-  # xla_bridge._sharding_to_proto.
+  # TODO: this is only for sharded_jit. Either remove, or implement in terms
+  # of _shard_values.
   if partition_dimensions is None:
     return xla_sharding.replicate(tensor, use_sharding_op=True)
   num_partition_splits = np.prod(partition_dimensions)
@@ -2504,6 +2506,24 @@
   return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
 
 
+def _shard_value(mesh: maps.Mesh,
+                 val: TfVal,
+                 aval: core.AbstractValue,
+                 axis_resources: pjit.ParsedPartitionSpec) -> TfVal:
+  """Apply sharding to a TfVal."""
+  sharding_proto: xla_client.OpSharding = pjit.get_aval_sharding_proto(
+      aval, axis_resources, mesh)
+  # To use xla_sharding.py, we must have a xla_data_pb2.OpSharding.
+  xla_sharding_proto: xla_data_pb2.OpSharding = (
+      xla_data_pb2.OpSharding(
+          type=int(sharding_proto.type),
+          tile_assignment_dimensions=sharding_proto.tile_assignment_dimensions,
+          tile_assignment_devices=sharding_proto.tile_assignment_devices,
+          replicate_on_last_tile_dim=sharding_proto.replicate_on_last_tile_dim))
+  return xla_sharding.Sharding(proto=xla_sharding_proto).apply_to_tensor(
+      val, use_sharding_op=True)
+
+
 def _sharded_call(f: lu.WrappedFun, vals: Sequence[TfVal],
                   in_parts: Sequence[pxla.PartitionsOrReplicated],
                   out_parts_thunk,
@@ -2520,12 +2540,51 @@ def _sharded_call(f: lu.WrappedFun, vals: Sequence[TfVal],
   return sharded_vals_out
 
 
-def _sharding_constraint(arg: TfVal, *,
-                         partitions: pxla.PartitionsOrReplicated):
+def _sharded_jit_sharding_constraint(arg: TfVal, *,
+                                     partitions: pxla.PartitionsOrReplicated,
+                                     _in_avals: Sequence[core.ShapedArray],
+                                     _out_aval: core.ShapedArray):
+  del _in_avals, _out_aval
   return split_to_logical_devices(arg, partitions)
 
 
-tf_impl[sharded_jit.sharding_constraint_p] = _sharding_constraint
+tf_impl_with_avals[sharded_jit.sharding_constraint_p] = _sharded_jit_sharding_constraint
+
+
+def _pjit(*args: TfVal,
+          jaxpr: core.ClosedJaxpr,
+          in_axis_resources: Sequence[pjit.ParsedPartitionSpec],
+          out_axis_resources: Sequence[pjit.ParsedPartitionSpec],
+          resource_env: maps.ResourceEnv,
+          donated_invars,
+          name: str,
+          _in_avals: Sequence[core.ShapedArray],
+          _out_aval: core.ShapedArray) -> TfVal:
+  del donated_invars, name
+  # TODO: add `name` to the name stack
+  shard_value_for_mesh = functools.partial(_shard_value, resource_env.physical_mesh)
+  # Apply sharding annotation to the arguments
+  sharded_args: Sequence[TfVal] = tuple(
+      util.safe_map(shard_value_for_mesh, args, _in_avals, in_axis_resources))
+  results = _interpret_jaxpr(jaxpr, *sharded_args)
+  sharded_results: Sequence[TfVal] = tuple(
+      util.safe_map(shard_value_for_mesh, results, _out_aval, out_axis_resources))
+  return tuple(sharded_results)
+
+
+tf_impl_with_avals[pjit.pjit_p] = _pjit
+
+
+def _pjit_sharding_constraint(arg: TfVal, *,
+                              axis_resources: pjit.ParsedPartitionSpec,
+                              resource_env: maps.ResourceEnv,
+                              _in_avals: Sequence[core.ShapedArray],
+                              _out_aval: core.ShapedArray,
+                              **kwargs) -> TfVal:
+  return _shard_value(resource_env.physical_mesh, arg, _in_avals[0], axis_resources)
+
+
+tf_impl_with_avals[pjit.sharding_constraint_p] = _pjit_sharding_constraint
+
+
 def _register_checkpoint_pytrees():
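Building on the convert sketch above (same hypothetical `f_tf` and `x`, still inside the mesh context), one hedged way to check that these rules fire is to inspect the compiler IR of the converted function, via the same `experimental_get_compiler_ir` API the tests below use; this is a sketch, not part of the commit:

```python
# Sketch only: `f_tf` and `x` as in the earlier pjit example, under the mesh.
hlo_text = f_tf.experimental_get_compiler_ir(x, x)(stage="hlo")
assert "sharding=" in hlo_text  # the XlaSharding annotations survive into HLO
```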
@@ -13,9 +13,10 @@
 # limitations under the License.
 """Tests for the jax2tf conversion of sharded_jit."""
 
+import functools
 import logging
 import re
-from typing import Sequence
+from typing import Any, Sequence
 import unittest
 
 from absl.testing import absltest
@@ -25,6 +26,7 @@ from jax import test_util as jtu
 from jax.config import config
 
 from jax.experimental import jax2tf
+from jax.experimental import pjit
 from jax.experimental.jax2tf.tests import tf_test_util
 from jax.interpreters import sharded_jit
 from jax.interpreters.sharded_jit import PartitionSpec as P
@@ -36,6 +38,13 @@ import tensorflow as tf  # type: ignore[import]
 
 config.parse_flags_with_absl()
 
+def setUpModule():
+  jtu.set_spmd_lowering_flag(True)
+
+def tearDownModule():
+  jtu.restore_spmd_lowering_flag()
+
+
 LOG_HLO = True
 
 class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
@@ -46,7 +55,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
 
   def _check_sharding_annotations(self,
                                   f_jax,
-                                  args,
+                                  args: Sequence[Any],
                                   *,
                                   expected: Sequence[str],
                                   expected_opt: Sequence[str],
@@ -57,6 +66,9 @@
     We currently check the unoptimized HLO against `expected` on CPU and TPU,
     and we check the optimized HLO against `expected_opt` on TPU only and
     only for JAX.
+
+    See `self.AssertShardingAnnotations` for documentation of `expected`
+    and `expected_opt`.
     """
     if jtu.device_under_test() == "gpu":
       raise unittest.SkipTest("Sharding HLO tests not useful for GPU")
@@ -93,20 +105,20 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
         experimental_get_compiler_ir(*args)(stage="hlo",
                                             device_name=device_name)
     if LOG_HLO:
-      logging.info(f"[{self._testMethodName}] got TF OPT HLO {tf_hlo}")
+      logging.info(f"[{self._testMethodName}] got TF HLO {tf_hlo}")
     self.AssertShardingAnnotations("TF before optimizations", tf_hlo, expected)
     tf_optimized_hlo = tf.function(f_tf, jit_compile=True).\
         experimental_get_compiler_ir(*args)(stage="optimized_hlo",
                                             device_name=device_name)
     if LOG_HLO:
-      logging.info(f"[{self._testMethodName}] XXX got TF OPT HLO "
+      logging.info(f"[{self._testMethodName}] got TF optimized HLO "
                    f"for {device_name}: {tf_optimized_hlo}")
 
   def AssertShardingAnnotations(self, what: str, hlo: str,
                                 expected: Sequence[str]):
     """Args:
-      what: either 'JAX' or 'TF'
+      what: either 'JAX' or 'TF', used for messages only.
       hlo: the text for the HLO module
       expected: a sequence of regexps that must occur in the hlo text. Each
         regexp must match a line, in order.
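The docstring above specifies ordered matching: every regexp must match some line of the HLO text, and the matches must appear in order. A standalone sketch of that check (hypothetical helper, not the test's actual implementation):

```python
# Ordered-regexp matching, as described by AssertShardingAnnotations.
import re
from typing import Sequence

def find_in_order(hlo: str, expected: Sequence[str]) -> bool:
  """True if every regexp in `expected` matches a line of `hlo`, in order."""
  lines = iter(hlo.splitlines())
  for pattern in expected:
    rx = re.compile(pattern)
    # Consume lines until this pattern matches; later patterns continue
    # from the line after the match, which enforces the ordering.
    if not any(rx.search(line) for line in lines):
      return False
  return True
```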
@@ -128,7 +140,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
             f"!!! Not found[{next_expected_idx}] {expected[next_expected_idx]}")
       raise self.failureException("\n".join(failure_msg))
 
-  def test_in_out(self):
+  def test_sharded_jit_in_out(self):
     """Test input and output sharding annotations."""
     sharded_jax_func = sharded_jit.sharded_jit(
         jnp.dot, in_parts=(P(1, 2), P(2, 1)), out_parts=P(1, 2))
@@ -152,7 +164,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
         ],
         num_partitions=2)
 
-  def test_with_sharding_constraint(self):
+  def test_sharded_jit_with_sharding_constraint(self):
     """A sharding constraint in the middle."""
 
     def jax_func(x, y):
@@ -180,7 +192,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
         ],
         num_partitions=2)
 
-  def test_replicated(self):
+  def test_sharded_jit_replicated(self):
     """A replicated input and output."""
 
     sharded_jax_func = sharded_jit.sharded_jit(
@@ -203,6 +215,116 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
         ],
         num_partitions=2)
 
+  @jtu.with_mesh([("x", 2)])
+  def test_pjit_basic1D(self):
+
+    @functools.partial(pjit.pjit,
+                       in_axis_resources=(P("x"), P("x")),
+                       out_axis_resources=None)
+    def jax_func(x, y):
+      return x + y
+
+    shape = (8, 10)
+    x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+    hlo = jax.xla_computation(jax_func)(x, x).as_hlo_text()
+    print(f"HLO is {hlo}")
+    print(f"JAXPR is {jax.make_jaxpr(jax_func)(x, x)}")
+    self._check_sharding_annotations(
+        jax_func, [x, x],
+        expected=[
+            r"f32\[8,10\].*sharding={devices=\[2,1\]",  # x and y
+            r"f32\[8,10\].*sharding={replicated",  # output
+        ],
+        expected_opt=[
+            r"f32\[4,10\].*sharding={devices=\[2,1\]",  # x and y
+            # TODO: why don't we see "sharding={replicated"
+            r"f32\[8,10\]",  # output
+        ],
+        num_partitions=2)
+
+  @jtu.with_mesh([("x", 2), ("y", 2)])
+  def test_pjit_basic2D(self):
+    @functools.partial(pjit.pjit,
+                       in_axis_resources=(P(None, "x", "y"), P("y")),
+                       out_axis_resources=P("x"))
+    def jax_func(x, y):
+      return x @ y
+
+    x_shape = (8, 6, 4)
+    y_shape = (4, 2)
+    x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
+    y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape)
+    self._check_sharding_annotations(
+        jax_func, [x, y],
+        expected=[
+            r"f32\[8,6,4\].*sharding={devices=\[1,2,2\]0,1,2,3",  # x
+            r"f32\[4,2\].*sharding={devices=\[2,1,2\]0,2,1,3 last_tile_dim_replicate",  # y
+            r"f32\[8,6,2\].*sharding={devices=\[2,1,1,2\]0,1,2,3 last_tile_dim_replicate",  # output
+        ],
+        expected_opt=[
+            # TODO: relax ordering
+            r"f32\[2,2\].*sharding={devices=\[2,1,2\]0,2,1,3 last_tile_dim_replicate",  # y
+            r"f32\[8,3,2\].*sharding={devices=\[1,2,2\]0,1,2,3",  # x
+            # TODO: why can't we see sharding={devices=\[2,1,1,2\]0,1,2,3 last_tile_dim_replicate?
+            r"bf16\[4,6,2\]",  # output
+        ],
+        num_partitions=4)
+
+  @jtu.with_mesh([("x", 2), ("y", 2)])
+  def test_pjit_TwoMeshAxisSharding(self):
+    @functools.partial(pjit.pjit,
+                       in_axis_resources=P(("x", "y"),),
+                       out_axis_resources=P(("x", "y"),))
+    def jax_func(x, y):
+      return x @ y
+
+    x_shape = (24, 8)
+    y_shape = (8, 2)
+    x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
+    y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape)
+    self._check_sharding_annotations(
+        jax_func, [x, y],
+        expected=[
+            r"f32\[24,8\].*sharding={devices=\[4,1\]0,1,2,3",  # x
+            r"f32\[8,2\].*sharding={devices=\[4,1\]0,1,2,3",  # y
+            r"f32\[24,2\].*sharding={devices=\[4,1\]0,1,2,3",  # output
+        ],
+        expected_opt=[
+            # TODO: relax ordering
+            r"f32\[2,2\].*sharding={devices=\[4,1\]0,1,2,3",  # y
+            r"f32\[6,8\].*sharding={devices=\[4,1\]0,1,2,3",  # x
+            # TODO: why can't we see .*sharding={devices=\[4,1\]0,1,2,3
+            r"f32\[1,6,2\]",  # output
+        ],
+        num_partitions=4)
+
+  @jtu.with_mesh([("x", 2), ("y", 1)])
+  def test_pjit_ShardingConstraint(self):
+    @functools.partial(pjit.pjit, in_axis_resources=None,
+                       out_axis_resources=None)
+    def jax_func(x):  # x: f32[12, 8]
+      y = jnp.tile(x, (2, 1))  # y: f32[24, 8]
+      y = pjit.with_sharding_constraint(y, P("x", "y"))
+      return y[0:y.shape[0] // 4]  # res: f32[6, 8]
+
+    shape = (12, 8)
+    x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+    self._check_sharding_annotations(
+        jax_func, [x],
+        expected=[
+            r"f32\[12,8\].*sharding={replicated}",  # x
+            r"f32\[24,8\].*sharding={devices=\[2,1\]0,1",  # y
+            r"f32\[6,8\].*sharding={replicated}",  # output
+        ],
+        expected_opt=[
+            r"f32\[12,8\].*sharding={replicated}",  # x
+            # TODO: why can't we see "sharding={devices=\[2,1\]0,1"
+            r"f32\[1,12,8\]",  # y
+            # TODO: why can't we see "sharding={replicated}"?
+            r"f32\[6,8\]",  # output
+        ],
+        num_partitions=2)
+
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
@@ -590,13 +590,20 @@ def get_array_mapping(axis_resources: ParsedPartitionSpec) -> pxla.ArrayMapping:
                       for i, axes in enumerate(axis_resources)
                       for axis in axes)
 
-def get_sharding_proto(c, xla_op, axis_resources, mesh):
+def get_sharding_proto(c, xla_op, axis_resources: ParsedPartitionSpec,
+                       mesh: maps.Mesh) -> xc.OpSharding:
   xla_shape = c.GetShape(xla_op)
   if xla_shape.is_token():
     aval = core.abstract_token
     assert axis_resources is REPLICATED
   else:
     aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.element_type())
+  return get_aval_sharding_proto(aval, axis_resources, mesh)
+
+
+def get_aval_sharding_proto(aval: core.AbstractValue,
+                            axis_resources: ParsedPartitionSpec,
+                            mesh: maps.Mesh) -> xc.OpSharding:
   array_mapping = get_array_mapping(axis_resources)
   sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)(
       aval, array_mapping)
@@ -17,7 +17,7 @@ import functools
 import re
 import os
 import textwrap
-from typing import Dict, Sequence, Union
+from typing import Dict, List, Generator, Sequence, Tuple, Union
 import unittest
 import warnings
 import zlib
@@ -33,10 +33,12 @@ from . import core
 from ._src import dtypes as _dtypes
 from . import lax
 from ._src.config import flags, bool_env, config
-from ._src.util import partial, prod
+from ._src.util import partial, prod, unzip2
 from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
 from .lib import xla_bridge
 from .interpreters import xla
+from .experimental import maps
+from .experimental.maps import mesh
 
 
 FLAGS = flags.FLAGS
@@ -1012,6 +1014,44 @@ def ignore_warning(**kw):
     warnings.filterwarnings("ignore", **kw)
     yield
 
+# -------------------- Mesh parametrization helpers --------------------
+
+MeshSpec = List[Tuple[str, int]]
+
+@contextmanager
+def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
+  """Test utility for setting up meshes given mesh data from `schedules`."""
+  # This is similar to the `with_mesh` function above, but isn't a decorator.
+  axis_names, shape = unzip2(named_shape)
+  size = prod(shape)
+  local_devices = list(api.local_devices())
+  if len(local_devices) < size:
+    raise unittest.SkipTest(f"Test requires {size} local devices")
+  mesh_devices = np.array(local_devices[:size]).reshape(shape)
+  with mesh(mesh_devices, axis_names):
+    yield
+
+def with_mesh_from_kwargs(f):
+  return lambda *args, **kwargs: with_mesh(kwargs['mesh'])(f)(*args, **kwargs)
+
+def with_and_without_mesh(f):
+  return parameterized.named_parameters(
+      {"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
+      for name, mesh, axis_resources in (
+          ('', (), ()),
+          ('Mesh', (('x', 2),), (('i', 'x'),))
+      ))(with_mesh_from_kwargs(f))
+
+old_spmd_lowering_flag = False
+def set_spmd_lowering_flag(val: bool):
+  global old_spmd_lowering_flag
+  maps.make_xmap_callable.cache_clear()
+  old_spmd_lowering_flag = maps.EXPERIMENTAL_SPMD_LOWERING
+  maps.EXPERIMENTAL_SPMD_LOWERING = val
+
+def restore_spmd_lowering_flag():
+  maps.make_xmap_callable.cache_clear()
+  maps.EXPERIMENTAL_SPMD_LOWERING = old_spmd_lowering_flag
+
 
 class _cached_property:
   null = object()
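A hedged usage sketch of the helpers added above (test names hypothetical): the context-manager form also works as a decorator, and both skip the test when too few devices are available:

```python
# Sketch only: reusing the shared mesh helpers from jax.test_util.
from jax import test_util as jtu

class MyMeshTest(jtu.JaxTestCase):

  @jtu.with_mesh([("x", 2)])  # runs under a 2-device "x" mesh
  def test_under_mesh(self):
    ...

  def test_explicit_context(self):
    with jtu.with_mesh([("x", 2)]):  # or use it as a context manager
      ...
```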
@@ -12,10 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from contextlib import contextmanager
 from functools import partial
 import logging
-from typing import Generator, List, Tuple
 from unittest import SkipTest
 
 from absl.testing import absltest
@@ -32,47 +30,23 @@ from jax.experimental import PartitionSpec as P
 from jax.experimental.maps import xmap, mesh
 from jax.experimental.pjit import pjit, pjit_p, with_sharding_constraint, SpecSync
 from jax.interpreters import pxla
-from jax._src.util import unzip2, prod, curry
+from jax._src.util import prod, curry
 
 from jax.config import config
 config.parse_flags_with_absl()
 
 
 def setUpModule():
-  global old_lowering_flag
-  jax.experimental.maps.make_xmap_callable.cache_clear()
-  old_lowering_flag = jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING
-  jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True
+  jtu.set_spmd_lowering_flag(True)
 
 def tearDownModule():
-  jax.experimental.maps.make_xmap_callable.cache_clear()
-  jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = old_lowering_flag
-
-
-# TODO(skye): move into test_util and dedup with xmap_test.py
-MeshSpec = List[Tuple[str, int]]
-
-@contextmanager
-def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
-  """Test utility for setting up meshes given mesh data from `schedules`."""
-  # This is similar to the `with_mesh` function above, but isn't a decorator.
-  axis_names, shape = unzip2(named_shape)
-  size = prod(shape)
-  local_devices = list(jax.local_devices())
-  if len(local_devices) < size:
-    raise SkipTest(f"Test requires {size} local devices")
-  mesh_devices = np.array(local_devices[:size]).reshape(shape)
-  with mesh(mesh_devices, axis_names):
-    yield
-
-def with_mesh_from_kwargs(f):
-  return lambda *args, **kwargs: with_mesh(kwargs['mesh'])(f)(*args, **kwargs)
+  jtu.restore_spmd_lowering_flag()
 
 
 # TODO(skye): make the buffer donation utils part of JaxTestCase
 class PJitTest(jtu.BufferDonationTestCase):
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testBasic1D(self):
     @partial(pjit,
              in_axis_resources=(P('x'), P('x')),
@@ -90,7 +64,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     self.assertAllClose(actual.device_buffers[0].to_py(), expected,
                         check_dtypes=False)
 
-  @with_mesh([('x', 2), ('y', 2)])
+  @jtu.with_mesh([('x', 2), ('y', 2)])
   def testBasic2D(self):
     @partial(pjit,
              in_axis_resources=(P(None, 'x', 'y'), P('y')),
@@ -118,7 +92,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     self.assertAllClose(actual.device_buffers[3].to_py(), split1,
                         check_dtypes=False)
 
-  @with_mesh([('x', 2), ('y', 2)])
+  @jtu.with_mesh([('x', 2), ('y', 2)])
   def testTwoMeshAxisSharding(self):
     @partial(pjit,
              in_axis_resources=P(('x', 'y'),),
@@ -144,7 +118,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     self.assertAllClose(actual.device_buffers[3].to_py(), splits[3],
                         check_dtypes=False)
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testBufferDonation(self):
     @partial(pjit,
              in_axis_resources=P('x'),
@@ -162,7 +136,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     self.assertNotDeleted(y)
     self.assertDeleted(x)
 
-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testShardingConstraint(self):
     @partial(pjit, in_axis_resources=None, out_axis_resources=None)
     def f(x):
@@ -186,7 +160,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     # Annotation from pjit
     self.assertIn("sharding={replicated}", hlo.as_hlo_text())
 
-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testShardingConstraintPyTree(self):
     @partial(pjit, in_axis_resources=None, out_axis_resources=None)
     def f(x):
@@ -232,7 +206,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     should_be_tracing = False
     pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
 
-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testNested(self):
     # Add a constant captured by the nested pjit to make things more complicated
     h = jnp.arange(4)
@@ -243,7 +217,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     self.assertAllClose(y, jnp.sin(x).sum() + h.sum())
     self.assertTrue(hasattr(y, "sharding_spec"))
 
-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testJVP(self):
     # Add a constant captured by the nested pjit to make things more complicated
     h = jnp.arange(4)
@@ -252,7 +226,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     jtu.check_grads(g, (jnp.arange(16, dtype=jnp.float32).reshape((4, 4)),),
                     order=2, modes=["fwd"], eps=1)
 
-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testEvalJaxpr(self):
     x, y = jnp.arange(4), jnp.arange(5)
     f = pjit(lambda x, y: x.sum() + jnp.sin(y),
@@ -263,13 +237,13 @@ class PJitTest(jtu.BufferDonationTestCase):
     r, = f_eval(x, y)
     self.assertAllClose(r, x.sum() + jnp.sin(y))
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testNonArrayArg(self):
     self.assertEqual(pjit(lambda x: x + 2,
                           in_axis_resources=None,
                           out_axis_resources=None)(1), 3)
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testNonHashableAxisResources(self):
     x = jnp.arange(4)
     y = pjit(lambda x: {'b': x['a'] + 2},
@@ -277,7 +251,7 @@ class PJitTest(jtu.BufferDonationTestCase):
              out_axis_resources={'b': P('x')})({'a': x})
     self.assertAllClose(y, {'b': x + 2})
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testGradOfConstraint(self):
     # Make sure that we can compute grads through sharding constraints
     h = lambda x: jnp.sin(with_sharding_constraint(x, P('x'))).sum()
@@ -286,7 +260,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     x = jnp.arange(8, dtype=jnp.float32)
     self.assertAllClose(f(x), jnp.cos(x))
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testNoopPartitionSpecs(self):
     noops = [P(), P(None), P(()), P((), None), P(None, None, ())]
     x = jnp.arange(8).reshape((2, 2, 2))
@@ -294,7 +268,7 @@ class PJitTest(jtu.BufferDonationTestCase):
       y = pjit(lambda x: x * 2, in_axis_resources=spec, out_axis_resources=spec)(x)
       self.assertAllClose(y, x * 2)
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testVmapModifiesAxisResources(self):
     h = pjit(lambda x, y: (x + y, x, y), in_axis_resources=P('x'), out_axis_resources=None)
     x = jnp.arange(4)
@@ -310,7 +284,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     self.assertEqual(y_sync, SpecSync.IN_SYNC)
     self.assertEqual(z_sync, SpecSync.DIM_PERMUTE)
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testVMap(self):
     f = pjit(lambda x, y: (x + y, x), in_axis_resources=P('x'), out_axis_resources=P('x'))
     x = jnp.arange(4)
@@ -402,7 +376,7 @@ def check_1d_2d_mesh(f, set_mesh):
     ("2", (("x", 2),), "x"),
     ("2x1", (("x", 2), ("y", 1)), ("x", "y")),
     ("2x2", (("x", 2), ("y", 2)), ("x", "y")),
-  ))(with_mesh_from_kwargs(f) if set_mesh else f)
+  ))(jtu.with_mesh_from_kwargs(f) if set_mesh else f)
 
 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")
@@ -444,7 +418,7 @@ class PJitErrorTest(jtu.JaxTestCase):
            in_axis_resources=None, out_axis_resources=None)(x)
 
   @check_1d_2d_mesh(set_mesh=False)
-  @with_mesh([('z', 1)])
+  @jtu.with_mesh([('z', 1)])
   def testUndefinedResourcesArgs(self, mesh, resources):
     x = jnp.ones((2, 2))
     spec = P(resources,)
@@ -454,7 +428,7 @@ class PJitErrorTest(jtu.JaxTestCase):
       pjit(lambda x: x, in_axis_resources=spec, out_axis_resources=None)(x)
 
   @check_1d_2d_mesh(set_mesh=False)
-  @with_mesh([('z', 1)])
+  @jtu.with_mesh([('z', 1)])
   def testUndefinedResourcesOuts(self, mesh, resources):
     x = jnp.ones((2, 2))
     spec = P(resources,)
@@ -464,7 +438,7 @@ class PJitErrorTest(jtu.JaxTestCase):
       pjit(lambda x: x, in_axis_resources=None, out_axis_resources=spec)(x)
 
   @check_1d_2d_mesh(set_mesh=False)
-  @with_mesh([('z', 1)])
+  @jtu.with_mesh([('z', 1)])
   def testUndefinedResourcesConstraint(self, mesh, resources):
     x = jnp.ones((2, 2))
     spec = P(resources,)
@@ -475,7 +449,7 @@ class PJitErrorTest(jtu.JaxTestCase):
       pjit(lambda x: with_sharding_constraint(x, spec),
           in_axis_resources=None, out_axis_resources=None)(x)
 
-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testRankTooLowArgs(self):
     x = jnp.arange(2)
     spec = P('x', 'y')
@@ -484,7 +458,7 @@ class PJitErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(ValueError, error):
       pjit(lambda x: x.sum(), in_axis_resources=spec, out_axis_resources=None)(x)
 
-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testRankTooLowOuts(self):
     x = jnp.arange(2)
     spec = P('x', 'y')
@@ -493,7 +467,7 @@ class PJitErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(ValueError, error):
       pjit(lambda x: x.sum(), in_axis_resources=None, out_axis_resources=spec)(x)
 
-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testRankTooLowConstraint(self):
     x = jnp.arange(2)
     spec = P('x', 'y')
@@ -504,7 +478,7 @@ class PJitErrorTest(jtu.JaxTestCase):
      pjit(lambda x: with_sharding_constraint(x, spec),
           in_axis_resources=None, out_axis_resources=None)(x)
 
-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testRepeatedInResources(self):
     x = jnp.arange(2)
     for spec in [P('x', 'x'), P('x', ('y', 'x'))]:
@@ -514,7 +488,7 @@ class PJitErrorTest(jtu.JaxTestCase):
       with self.assertRaisesRegex(ValueError, error):
         pjit(lambda x: x, in_axis_resources=spec, out_axis_resources=None)(x)
 
-  @with_mesh([('x', 2), ('y', 1)])
+  @jtu.with_mesh([('x', 2), ('y', 1)])
   def testRepeatedOutResources(self):
     x = jnp.arange(2)
     for spec in [P('x', 'x'), P('x', ('y', 'x'))]:
@@ -524,7 +498,7 @@ class PJitErrorTest(jtu.JaxTestCase):
       with self.assertRaisesRegex(ValueError, error):
        pjit(lambda x: x, in_axis_resources=None, out_axis_resources=spec)(x)
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testInputShardsXMapAxis(self):
     spec = P('x')
     f = xmap(pjit(lambda x: x + 2, in_axis_resources=spec, out_axis_resources=None),
@@ -537,7 +511,7 @@ class PJitErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(JAXTypeError, error):
       f(x)
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testOutputShardsXMapAxis(self):
     spec = P('x')
     f = xmap(pjit(lambda x: x + 2, in_axis_resources=None, out_axis_resources=spec),
@@ -550,7 +524,7 @@ class PJitErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(JAXTypeError, error):
       f(x)
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testConstraintShardsXMapAxis(self):
     spec = P('x')
     f = xmap(lambda x: with_sharding_constraint(x, axis_resources=spec),
@@ -563,7 +537,7 @@ class PJitErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(JAXTypeError, error):
       f(x)
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testCatchesInnerXMapErrors(self):
     f = pjit(xmap(lambda x, y: x, in_axes=(['i'], ['j']), out_axes=['i', 'j'],
                   axis_resources={'i': 'x', 'j': 'x'}),
@@ -71,34 +71,6 @@ def tearDownModule():
   os.environ["XLA_FLAGS"] = prev_xla_flags
   xla_bridge.get_backend.cache_clear()
 
-# -------------------- Mesh parametrization helpers --------------------
-
-MeshSpec = List[Tuple[str, int]]
-
-@contextmanager
-def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
-  """Test utility for setting up meshes given mesh data from `schedules`."""
-  # This is similar to the `with_mesh` function above, but isn't a decorator.
-  axis_names, shape = unzip2(named_shape)
-  size = prod(shape)
-  local_devices = list(jax.local_devices())
-  if len(local_devices) < size:
-    raise SkipTest(f"Test requires {size} local devices")
-  mesh_devices = np.array(local_devices[:size]).reshape(shape)
-  with mesh(mesh_devices, axis_names):
-    yield
-
-def with_mesh_from_kwargs(f):
-  return lambda *args, **kwargs: with_mesh(kwargs['mesh'])(f)(*args, **kwargs)
-
-def with_and_without_mesh(f):
-  return parameterized.named_parameters(
-      {"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
-      for name, mesh, axis_resources in (
-          ('', (), ()),
-          ('Mesh', (('x', 2),), (('i', 'x'),))
-      ))(with_mesh_from_kwargs(f))
-
 
 # -------------------- Itertools helpers --------------------
 
@@ -135,7 +107,7 @@ def ensure_bdim(x, axis_name, bdim):
 AxisResources = Dict[str, Union[str, Tuple[str, ...]]]
 
 def schedules(sizes: Dict[str, int]
-              ) -> Generator[Tuple[AxisResources, MeshSpec], None, None]:
+              ) -> Generator[Tuple[AxisResources, jtu.MeshSpec], None, None]:
   """Test utility generating xmap parallel schedules from logical names & sizes.
 
   Args:
@@ -275,7 +247,7 @@ class XMapTest(XMapTestCase):
     self.assertAllClose(c, a * 2)
     self.assertAllClose(d, b * 4)
 
-  @with_mesh([('x', 2), ('y', 2)])
+  @jtu.with_mesh([('x', 2), ('y', 2)])
   def testCollectiveReduce(self):
     fm = xmap(lambda a, b: (lax.psum(a * 2, 'a'), b * 4),
               in_axes=[['a', 'b', ...], {0: 'c'}],
@@ -289,7 +261,7 @@ class XMapTest(XMapTestCase):
     self.assertAllClose(c, (a * 2).sum(0))
     self.assertAllClose(d, b * 4)
 
-  @with_mesh([('x', 2), ('y', 2)])
+  @jtu.with_mesh([('x', 2), ('y', 2)])
   def testCollectivePermute2D(self):
     perm = np.array([3, 1, 2, 0])
     x = jnp.arange(4).reshape((2, 2))
@@ -313,7 +285,7 @@ class XMapTest(XMapTestCase):
                   in_axes=['i', ...], out_axes=['i', ...])(x)
     self.assertAllClose(result, x + x[jnp.newaxis].T)
 
-  @with_mesh([('x', 2), ('y', 2)])
+  @jtu.with_mesh([('x', 2), ('y', 2)])
   def testOneLogicalTwoMeshAxesBasic(self):
     def f(v):
       return lax.psum(v * 2, 'a'), v * 4
@@ -325,7 +297,7 @@ class XMapTest(XMapTestCase):
     self.assertAllClose(ans, (v * 2).sum(0))
     self.assertAllClose(ans2, v.T * 4)
 
-  @with_mesh([('x', 2), ('y', 2)])
+  @jtu.with_mesh([('x', 2), ('y', 2)])
   def testOneLogicalTwoMeshAxesSharding(self):
     def f(v):
       return v * 4
@@ -346,7 +318,7 @@ class XMapTest(XMapTestCase):
         pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
                           (pxla.ShardedAxis(1), pxla.ShardedAxis(0))))
 
-  @with_mesh([('x', 2), ('y', 2)])
+  @jtu.with_mesh([('x', 2), ('y', 2)])
   def testSkipFirstMeshDim(self):
     def run(axis_resources):
       return xmap(lambda x: x * 2, in_axes=['i', ...], out_axes=['i', ...],
@@ -379,7 +351,7 @@ class XMapTest(XMapTestCase):
     ('OneToOne', (('x', 2), ('y', 2)), (('a', 'y'), ('b', 'x'))),
     ('Multiple', (('x', 2), ('y', 2), ('z', 2)), (('a', 'y'), ('b', ('x', 'z')))),
   ))
-  @with_mesh_from_kwargs
+  @jtu.with_mesh_from_kwargs
   def testNestedMesh(self, mesh, axis_resources):
     @partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}),
              axis_resources=dict([axis_resources[0]]))
@@ -405,7 +377,7 @@ class XMapTest(XMapTestCase):
     # Make sure that there are non-partial sharding specs in the HLO
     self.assertRegex(hlo, r"sharding={devices=\[[0-9,]+\][0-9,]+}")
 
-  @with_and_without_mesh
+  @jtu.with_and_without_mesh
   def testMultipleCalls(self, mesh, axis_resources):
     def f(x, y):
       assert x.shape == y.shape == (3, 5)
@@ -420,7 +392,7 @@ class XMapTest(XMapTestCase):
     for i in range(10):
       self.assertAllClose(f_mapped(x, x), expected)
 
-  @with_and_without_mesh
+  @jtu.with_and_without_mesh
   @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
   def testBufferDonation(self, mesh, axis_resources):
     shard = lambda x: x
@@ -443,7 +415,7 @@ class XMapTest(XMapTestCase):
       xmap(lambda x: lax.fori_loop(0, 10, lambda _, x: lax.psum(x, 'i'), x),
            in_axes=['i', ...], out_axes=['i', ...])(x)
 
-  @with_and_without_mesh
+  @jtu.with_and_without_mesh
   def testAxisSizes(self, mesh, axis_resources):
     result = xmap(lambda: lax.axis_index('i'),
                   in_axes=(), out_axes=['i', ...],
@@ -551,7 +523,7 @@ class XMapTest(XMapTestCase):
     y = jnp.arange(20, dtype=jnp.float32).reshape((4, 5)) / 100
     jtu.check_grads(f, (x, y), order=2, modes=['fwd'])
 
-  @with_and_without_mesh
+  @jtu.with_and_without_mesh
   def testNamedShape(self, mesh, axis_resources):
     x = np.arange(4,)
     y = 2
@@ -563,7 +535,7 @@ class XMapTest(XMapTestCase):
     self.assertEqual(z.aval.named_shape, {})
     self.assertEqual(w.aval.named_shape, {})
 
-  @with_and_without_mesh
+  @jtu.with_and_without_mesh
   def testBroadcast(self, mesh, axis_resources):
     x = jnp.asarray(2.0)
     f = xmap(lambda x: x, in_axes={}, out_axes=['i'],
@@ -590,7 +562,7 @@ class XMapTestSPMD(SPMDTestMixin, XMapTest):
       raise SkipTest
     super().setUp()
 
-  @with_mesh([('x', 2), ('y', 2), ('z', 2)])
+  @jtu.with_mesh([('x', 2), ('y', 2), ('z', 2)])
   def testNestedMeshSPMD(self):
     h = xmap(lambda y: (jnp.sin(y) * np.arange(y.size), lax.psum(y, ('a', 'b', 'c'))),
              in_axes={0: 'c'}, out_axes=({1: 'c'}, {}),
@@ -670,7 +642,7 @@ class NamedRandomTest(XMapTestCase):
       ((f"_mesh={mesh}_resources={sorted(axis_resources.items())}",
         {"axis_resources": tuple(axis_resources.items()), "mesh": tuple(mesh)})
        for axis_resources, mesh in schedules({'i': 4, 'j': 6})), subset=True)
-  @with_mesh_from_kwargs
+  @jtu.with_mesh_from_kwargs
  def testSamplerResourceIndependence(self, distr_sample, axis_resources, mesh):
    def sample(axis_resources):
      return xmap(lambda: distr_sample(jax.random.PRNGKey(0), shape=NamedShape(3, i=4, j=6)),
@@ -743,7 +715,7 @@ class NewPrimitiveTest(XMapTestCase):
     x_explode = x.reshape((3, 3, 3))
     self.assertAllClose(pgather(x, idx, 0), pgather(x_explode, idx, (0, 1)))
 
-  @with_and_without_mesh
+  @jtu.with_and_without_mesh
   def testGather(self, mesh, axis_resources):
     if axis_resources and not jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING:
       raise SkipTest("pgather over mesh axes without SPMD lowering not implemented")
@@ -851,7 +823,7 @@ def gen_axis_names():
 
 def schedules_from_pdot_spec(
     spec: PdotTestSpec, lhs_shape: Tuple[int], rhs_shape: Tuple[int]
-    ) -> Generator[Tuple[AxisResources, MeshSpec], None, None]:
+    ) -> Generator[Tuple[AxisResources, jtu.MeshSpec], None, None]:
   logical_sizes = {
       name: shape[ax]
       for shape, in_axes in [(lhs_shape, spec.lhs_in_axes),
@@ -862,7 +834,7 @@ def schedules_from_pdot_spec(
 
 class PDotTests(XMapTestCase):
 
-  @with_mesh([('r1', 2)])
+  @jtu.with_mesh([('r1', 2)])
   def testPdotBasic(self):
     def f(x, y):
       return lax.pdot(x, y, 'i')
@@ -880,7 +852,7 @@ class PDotTests(XMapTestCase):
 
     self.assertAllClose(z, jnp.dot(x, y))
 
-  @with_mesh([('r1', 2)])
+  @jtu.with_mesh([('r1', 2)])
   def testPdotBatching(self):
     def f(x, y):
       return lax.pdot(x, y, 'i')
@@ -898,7 +870,7 @@ class PDotTests(XMapTestCase):
 
     self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y))
 
-  @with_mesh([('r1', 2)])
+  @jtu.with_mesh([('r1', 2)])
   def testPdotBatchingShardUncontractedDim(self):
     def f(x, y):
       return lax.pdot(x, y, 'i')
@@ -945,7 +917,7 @@ class PDotTests(XMapTestCase):
                out_axes=[*pdot_spec.batch_names, ...],
                axis_resources=axis_resources)
 
-    with with_mesh(mesh_data):
+    with jtu.with_mesh(mesh_data):
       result = fun(lhs, rhs)
 
     expected = lax.dot_general(lhs, rhs, pdot_spec.dot_general_dim_nums)
@@ -989,7 +961,7 @@ class PDotTests(XMapTestCase):
                out_axes=(pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes),
                axis_resources=axis_resources)
 
-    with with_mesh(mesh_data):
+    with jtu.with_mesh(mesh_data):
       lhs_bar, rhs_bar = fun(lhs, rhs, out_bar)
 
     tol = 1e-1 if jtu.device_under_test() == "tpu" else None
@@ -1062,7 +1034,7 @@ class PDotTests(XMapTestCase):
 
 class XMapErrorTest(jtu.JaxTestCase):
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testRepeatedAxisResource(self):
     def f(v):
       return v * 4
@@ -1070,7 +1042,7 @@ class XMapErrorTest(jtu.JaxTestCase):
       fxy = xmap(f, in_axes=['a', ...], out_axes=['a', ...],
                  axis_resources={'a': ('x', 'x')})
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testNestedDifferentResources(self):
     @partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'})
     def f(x):
@@ -1089,7 +1061,7 @@ class XMapErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(ValueError, "Failed to infer size of axes: i."):
       xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...])({})
 
-  @with_mesh([('x', 2), ('y', 2)])
+  @jtu.with_mesh([('x', 2), ('y', 2)])
   def testAxesNotDivisibleByResources(self):
     with self.assertRaisesRegex(ValueError, r"Size of axis i \(5\) is not divisible.*"
                                             r"\(\('x', 'y'\), 4 in total\)"):
@@ -1154,7 +1126,7 @@ class XMapErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(TypeError, error):
       fm(x, y)
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testResourceConflictArgs(self):
     fm = xmap(lambda x: lax.psum(x, ('a', 'b')),
               in_axes=['a', 'b'], out_axes=[],
@@ -1166,7 +1138,7 @@ class XMapErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(JAXTypeError, error):
       fm(x)
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testResourceConflictInner(self):
     fm = xmap(lambda x, y: x + y,
               in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...],
@@ -1178,7 +1150,7 @@ class XMapErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(JAXTypeError, error):
       fm(x, y)
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testResourceConflictOut(self):
     fm = xmap(lambda x, y: x,
               in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...],
@@ -1191,7 +1163,7 @@ class XMapErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(JAXTypeError, error):
       fm(x, y)
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testResourceConflictNestArgs(self):
     f = xmap(lambda x: x, in_axes=['i'], out_axes=['i'], axis_resources={'i': 'x'})
     h = xmap(f, in_axes=['j', ...], out_axes=['j', ...], axis_resources={'j': 'x'})
@@ -1202,7 +1174,7 @@ class XMapErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(JAXTypeError, error):
       h(x)
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testResourceConflictNestInner(self):
     f = xmap(lambda x: lax.axis_index('i') + x,
              in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
@@ -1214,7 +1186,7 @@ class XMapErrorTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(JAXTypeError, error):
       h(x)
 
-  @with_mesh([('x', 2)])
+  @jtu.with_mesh([('x', 2)])
   def testResourceConflictNestOut(self):
     f = xmap(lambda x: x,
              in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'})